import Foundation
import MLX
import MLXLLM
import MLXLMCommon

/// Wraps streaming text generation on top of an MLX language model.
///
/// Declared as an actor so that access to the session is serialized.
/// Written against the mlx-swift-examples 2.29.1 API (commit 9bff95ca).
actor LLMSession {
    /// The loaded model container that every generation request runs against.
    let container: ModelContainer

    /// Creates a session around an already loaded model container.
    init(container: ModelContainer) {
        self.container = container
    }

    /// Loads a model from a local directory (containing config.json + weights + tokenizer).
    ///
    /// - Parameter folderURL: Directory holding the model files.
    /// - Returns: A session wrapping the loaded container.
    /// - Throws: Whatever `LLMModelFactory.loadContainer(configuration:)` throws.
    static func load(folderURL: URL) async throws -> LLMSession {
        let configuration = ModelConfiguration(directory: folderURL)
        let container = try await LLMModelFactory.shared.loadContainer(
            configuration: configuration
        )
        return LLMSession(container: container)
    }

    /// Generates text for `prompt` as a stream of chunks.
    ///
    /// Terminating the returned stream cancels the internal generation task. Errors thrown
    /// while preparing the input or generating are delivered by finishing the stream with
    /// that error.
    ///
    /// - Parameters:
    ///   - prompt: Raw prompt text; converted to an `LMInput` by the context's processor.
    ///   - maxTokens: Maximum number of tokens, enforced through `GenerateParameters`.
    /// - Returns: A stream of `TokenChunk` values, each carrying the chunk text and the
    ///   decode rate observed so far (chunk events per second of elapsed wall-clock time).
    func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
        AsyncThrowingStream { continuation in
            let task = Task {
                do {
                    let parameters = GenerateParameters(
                        maxTokens: maxTokens,
                        temperature: Float(0.6),
                        topP: Float(0.9)
                    )

                    try await container.perform { (context: ModelContext) in
                        let userInput = UserInput(prompt: prompt)
                        let lmInput = try await context.processor.prepare(input: userInput)

                        // Timing starts after input preparation, so the reported rate
                        // covers the generation loop only.
                        let start = Date()
                        var produced = 0

                        for await event in try MLXLMCommon.generate(
                            input: lmInput,
                            parameters: parameters,
                            context: context
                        ) {
                            // Stop promptly once the consumer has terminated the stream.
                            if Task.isCancelled { break }

                            switch event {
                            case .chunk(let text):
                                produced += 1
                                let elapsed = Date().timeIntervalSince(start)
                                // Guard against a zero interval on the very first chunk.
                                let rate = elapsed > 0 ? Double(produced) / elapsed : 0
                                continuation.yield(TokenChunk(text: text, decodeRate: rate))

                            case .info:
                                // Completion statistics; the last event of the stream.
                                break

                            case .toolCall:
                                // Not expected for plain-text generation; listed so the
                                // switch stays exhaustive.
                                break
                            }
                        }
                        // NOTE(review): presumably waits for outstanding GPU work before
                        // leaving the perform scope - confirm against the MLX API.
                        MLX.GPU.synchronize()
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            continuation.onTermination = { _ in task.cancel() }
        }
    }
}