import Foundation
import MLX
import MLXLLM
import MLXLMCommon

/// Wraps streaming text generation for an MLX language model.
///
/// Being an `actor`, access to the session is serialized.
/// Written against the mlx-swift-examples 2.29.1 API (commit 9bff95ca).
actor LLMSession {
    let container: ModelContainer

    init(container: ModelContainer) {
        self.container = container
    }

    /// Loads a model from a local directory.
    ///
    /// The directory is expected to contain `config.json`, the weights and the tokenizer files.
    /// - Parameter folderURL: Local directory holding the model files.
    /// - Returns: A session wrapping the loaded `ModelContainer`.
    /// - Throws: Any error raised by `LLMModelFactory.shared.loadContainer(configuration:)`.
    static func load(folderURL: URL) async throws -> LLMSession {
        let configuration = ModelConfiguration(directory: folderURL)
        let container = try await LLMModelFactory.shared.loadContainer(
            configuration: configuration
        )
        return LLMSession(container: container)
    }

    /// Streams generated text for `prompt`.
    ///
    /// When the returned stream terminates (e.g. the consumer cancels or drops it), the internal
    /// generation `Task` is cancelled; the generation loop checks `Task.isCancelled` on every
    /// event and stops early.
    /// - Parameters:
    ///   - prompt: Raw prompt text; converted to `LMInput` by the model's processor.
    ///   - maxTokens: Upper bound on generated tokens, forwarded via `GenerateParameters`.
    ///   - temperature: Sampling temperature. Defaults to 0.6, the previously hard-coded value.
    ///   - topP: Nucleus-sampling threshold. Defaults to 0.9, the previously hard-coded value.
    /// - Returns: A stream of `TokenChunk`s. Each chunk's `decodeRate` is the number of chunks
    ///   emitted after the first one divided by the seconds elapsed since the first one.
    ///   Prompt preparation and prefill are therefore excluded from the rate; it is 0 for the
    ///   first chunk. Errors from input preparation or generation are delivered by finishing the
    ///   stream with that error.
    func generate(
        prompt: String,
        maxTokens: Int,
        temperature: Float = 0.6,
        topP: Float = 0.9
    ) -> AsyncThrowingStream<TokenChunk, Error> {
        AsyncThrowingStream { continuation in
            let task = Task {
                do {
                    let parameters = GenerateParameters(
                        maxTokens: maxTokens,
                        temperature: temperature,
                        topP: topP
                    )
                    try await container.perform { (context: ModelContext) in
                        let userInput = UserInput(prompt: prompt)
                        let lmInput = try await context.processor.prepare(input: userInput)

                        // The rate clock starts at the first chunk, so prompt preparation and
                        // prefill (which precede the first chunk) are not counted as decode time.
                        var firstChunkTime: Date?
                        var produced = 0

                        for await event in try MLXLMCommon.generate(
                            input: lmInput,
                            parameters: parameters,
                            context: context
                        ) {
                            if Task.isCancelled { break }

                            switch event {
                            case .chunk(let text):
                                produced += 1
                                let now = Date()
                                let rate: Double
                                if let firstChunkTime {
                                    let elapsed = now.timeIntervalSince(firstChunkTime)
                                    // `produced - 1` chunks were decoded within `elapsed`.
                                    rate = elapsed > 0 ? Double(produced - 1) / elapsed : 0
                                } else {
                                    firstChunkTime = now
                                    rate = 0
                                }
                                continuation.yield(TokenChunk(text: text, decodeRate: rate))
                            case .info:
                                // Generation-completion statistics; the original author notes this
                                // is the stream's final event. Not forwarded to consumers.
                                break
                            case .toolCall:
                                // Not produced by plain-text generation; listed to keep the switch
                                // exhaustive.
                                break
                            }
                        }
                        MLX.GPU.synchronize()
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            continuation.onTermination = { _ in
                task.cancel()
            }
        }
    }
}