import Foundation

/// Errors surfaced by `AIRuntime`.
enum AIRuntimeError: Error, LocalizedError {
    /// The runtime is not ready (model not reported ready by `ModelStore`, or no loaded session).
    case notReady
    /// Loading the LLM session failed; the payload is the underlying error's description.
    case modelLoadFailed(String)
    /// Generation failed; the payload is the underlying error's description.
    case inferenceFailed(String)

    var errorDescription: String? {
        switch self {
        case .notReady:
            return "AI 模型尚未准备好"
        case .modelLoadFailed(let m):
            return "模型加载失败:\(m)"
        case .inferenceFailed(let m):
            return "推理失败:\(m)"
        }
    }
}

/// Shared owner of the LLM session.
///
/// Await `prepare()` before calling `generate(prompt:maxTokens:)`.
actor AIRuntime {
    static let shared = AIRuntime()

    /// Lifecycle of the LLM session.
    enum Status: Sendable, Equatable {
        case notReady
        case loading
        case ready
        /// The last load attempt failed; the payload describes why.
        case error(String)
    }

    private(set) var status: Status = .notReady
    /// Most recent positive decode rate reported by a generated chunk (0 until one arrives).
    private(set) var lastDecodeRate: Double = 0

    private var llmSession: LLMSession?
    /// The load currently in flight, if any. It is set right after `status` becomes `.loading`
    /// and cleared by the task itself after it has published `.ready` / `.error`, so concurrent
    /// `prepare()` callers can all await the same load.
    private var loadTask: Task<Void, Error>?

    private init() {}

    /// Loads the model. The first call performs the load; later calls return immediately once ready.
    ///
    /// Callers that arrive while a load is in flight await that same load and return (or throw)
    /// when it completes, so a normal return from `await prepare()` means `status == .ready`.
    /// A failed load leaves `status == .error(_)`, and the next call retries. Cancelling a
    /// caller does not cancel the shared load.
    ///
    /// - Throws: `AIRuntimeError.notReady` if `ModelStore` does not report the LLM model as ready;
    ///   `AIRuntimeError.modelLoadFailed` if loading the session fails.
    func prepare() async throws {
        switch status {
        case .ready:
            return
        case .loading:
            // Join the in-flight load instead of returning before the session exists.
            if let inFlight = loadTask {
                try await inFlight.value
            }
            return
        case .error, .notReady:
            break
        }

        guard ModelStore.shared.isReady(.llm) else {
            status = .error("LLM 模型未就绪")
            throw AIRuntimeError.notReady
        }

        status = .loading
        // The task is created in actor-isolated code and captures `self`, so it runs isolated
        // to this actor. It cannot start until this method suspends below, which guarantees
        // `loadTask` is assigned before the task's `defer` can clear it. The task publishes the
        // final state itself, so every awaiter observes `.ready`/`.error` on resumption.
        let task = Task<Void, Error> {
            defer { self.loadTask = nil }
            do {
                let session = try await LLMSession.load(
                    folderURL: ModelStore.shared.localURL(for: .llm)
                )
                self.llmSession = session
                self.status = .ready
            } catch {
                self.status = .error("\(error)")
                throw AIRuntimeError.modelLoadFailed("\(error)")
            }
        }
        loadTask = task
        try await task.value
    }

    /// Streams generated chunks for `prompt`. Await `prepare()` before calling this.
    ///
    /// The stream is created synchronously. If the runtime is not `.ready` at call time, the
    /// stream finishes with `AIRuntimeError.notReady`. Errors from the underlying session are
    /// wrapped in `AIRuntimeError.inferenceFailed`. When the stream terminates (the consumer
    /// stops iterating or its task is cancelled), the producing task is cancelled too.
    ///
    /// NOTE(review): the stream's generic arguments are not visible in this view of the file;
    /// they must match the chunk type yielded by `LLMSession.generate` — confirm.
    func generate(prompt: String, maxTokens: Int = 256) -> AsyncThrowingStream {
        // Snapshot state while actor-isolated; the producer task does not re-read
        // `self.status` / `self.llmSession`.
        let snapshotStatus = status
        let snapshotSession = llmSession

        return AsyncThrowingStream { continuation in
            // This Task is created in an actor-isolated, non-escaping context and captures
            // `self`, so it inherits AIRuntime's isolation; `recordRate` is therefore called
            // without `await`.
            let producer = Task {
                guard snapshotStatus == .ready, let session = snapshotSession else {
                    continuation.finish(throwing: AIRuntimeError.notReady)
                    return
                }
                do {
                    // `session.generate` crosses an actor boundary, hence the `await`.
                    let stream = await session.generate(prompt: prompt, maxTokens: maxTokens)
                    for try await chunk in stream {
                        // Stop promptly once the consumer has gone away.
                        try Task.checkCancellation()
                        self.recordRate(chunk.decodeRate)
                        continuation.yield(chunk)
                    }
                    continuation.finish()
                } catch {
                    // After cancellation the continuation is already terminated, so this is a no-op.
                    continuation.finish(throwing: AIRuntimeError.inferenceFailed("\(error)"))
                }
            }
            // Without this, abandoning the stream would leave inference running to completion.
            continuation.onTermination = { _ in
                producer.cancel()
            }
        }
    }

    /// Stores `rate` as `lastDecodeRate`, ignoring non-positive values.
    private func recordRate(_ rate: Double) {
        if rate > 0 { lastDecodeRate = rate }
    }
}