import Foundation

// MARK: - 体己/AI/AIRuntime.swift

/// Errors surfaced by `AIRuntime` to its callers.
enum AIRuntimeError: Error, LocalizedError {
    /// The model is not available or has not been loaded yet.
    case notReady
    /// Loading the model failed; the payload is the underlying error's description.
    case modelLoadFailed(String)
    /// Generation failed; the payload is the underlying error's description.
    case inferenceFailed(String)

    var errorDescription: String? {
        switch self {
        case .notReady: return "AI 模型尚未准备好"
        case .modelLoadFailed(let m): return "模型加载失败:\(m)"
        case .inferenceFailed(let m): return "推理失败:\(m)"
        }
    }
}

/// Owns the LLM session and tracks its load state.
actor AIRuntime {
    static let shared = AIRuntime()

    /// Load state of the LLM session.
    enum Status: Sendable, Equatable {
        case notReady
        case loading
        case ready
        case error(String)
    }

    private(set) var status: Status = .notReady
    /// Most recent positive decode rate reported by a generated chunk.
    private(set) var lastDecodeRate: Double = 0

    private var llmSession: LLMSession?

    /// The in-flight model load, if any. Concurrent `prepare()` callers await
    /// this task instead of returning before the model is actually usable.
    private var loadTask: Task<Void, Error>?

    private init() {}

    /// Loads the model. The first call performs the load; later calls are idempotent.
    ///
    /// When a load is already in progress, the call waits for that load and
    /// reports its outcome, so a successful return always means `status == .ready`.
    ///
    /// - Throws: `AIRuntimeError.notReady` when the model store reports the LLM
    ///   as not ready; `AIRuntimeError.modelLoadFailed` when loading fails.
    func prepare() async throws {
        if status == .ready { return }

        // The actor is re-entrant across `await`: a second caller can arrive while
        // the first one is still loading. Join the existing load rather than
        // returning early with no session available.
        if let inFlight = loadTask {
            try await inFlight.value
            return
        }

        guard ModelStore.shared.isReady(.llm) else {
            status = .error("LLM 模型未就绪")
            throw AIRuntimeError.notReady
        }

        status = .loading
        // The task captures `self`, so it runs isolated to this actor and updates
        // state before any awaiter of `task.value` is resumed.
        let task = Task<Void, Error> {
            defer { self.loadTask = nil }
            do {
                let session = try await LLMSession.load(
                    folderURL: ModelStore.shared.localURL(for: .llm)
                )
                self.llmSession = session
                self.status = .ready
            } catch {
                self.status = .error("\(error)")
                throw AIRuntimeError.modelLoadFailed("\(error)")
            }
        }
        loadTask = task
        try await task.value
    }

    /// Streams generated chunks for `prompt`. Call `await prepare()` first.
    ///
    /// The stream itself is created synchronously; the call into `LLMSession`
    /// crosses an actor boundary and is awaited inside a background task.
    ///
    /// - Parameters:
    ///   - prompt: Prompt text forwarded to the session.
    ///   - maxTokens: Token limit forwarded to the session (default 256).
    /// - Returns: A stream that finishes with `AIRuntimeError.notReady` when the
    ///   runtime is not ready, or with `AIRuntimeError.inferenceFailed` when the
    ///   session's stream throws.
    func generate(prompt: String, maxTokens: Int = 256) -> AsyncThrowingStream<TokenChunk, Error> {
        // Snapshot state while still isolated to the actor; the task below does
        // not read `status` or `llmSession` again.
        let snapshotStatus = status
        let snapshotSession = llmSession

        return AsyncThrowingStream { continuation in
            let worker = Task { [weak self] in
                guard snapshotStatus == .ready, let session = snapshotSession else {
                    continuation.finish(throwing: AIRuntimeError.notReady)
                    return
                }
                do {
                    // `session.generate` crosses an actor boundary, hence the await.
                    let stream = await session.generate(prompt: prompt, maxTokens: maxTokens)
                    for try await chunk in stream {
                        await self?.recordRate(chunk.decodeRate)
                        continuation.yield(chunk)
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: AIRuntimeError.inferenceFailed("\(error)"))
                }
            }
            // Cancel the worker when the consumer stops iterating or is cancelled,
            // so generation does not keep running with nobody listening.
            continuation.onTermination = { _ in worker.cancel() }
        }
    }

    /// Remembers the latest decode rate, ignoring non-positive values.
    private func recordRate(_ rate: Double) {
        if rate > 0 { lastDecodeRate = rate }
    }
}

// MARK: - 体己/AI/TokenChunk.swift

/// One streamed piece of generated text plus the decode rate reported with it.
struct TokenChunk: Sendable {
    let text: String
    let decodeRate: Double // tokens / second, estimated
}