按 code quality review(P0)反馈:for-await 因 Task.isCancelled 退出时不必执行 GPU.synchronize()——这是一个阻塞的 GPU 同步操作,在取消场景下属于浪费。W3 引入"用户取消推理"UI 后会更频繁地触发此路径。

P1/P2 留待 W3 再行考量:
- decodeRate 改用窗口平均(目前是累积平均)
- AIRuntime 目前持有具体的 LLMSession 类型,W3 抽出 protocol 以便 mock
- prompt 空字符串守门
- Float(0.6) 写法风格

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
106 lines
3.7 KiB
Swift
106 lines
3.7 KiB
Swift
import Foundation
|
||
|
||
// TODO: 添加 MLX 依赖后取消注释
|
||
// import MLX
|
||
// import MLXLLM
|
||
// import MLXLMCommon
|
||
|
||
// Real MLX-backed implementation, compiled out with `#if false` until the MLX dependency is added.
|
||
#if false
|
||
|
||
/// Wraps streaming text generation for an MLX language model.
///
/// Written against the API of mlx-swift-examples 2.29.1 (commit 9bff95ca).
///
/// Note on isolation: `generate(prompt:maxTokens:)` is non-async and returns as soon
/// as the stream is built; the actual work runs in an unstructured `Task` inside
/// `container.perform`. Actor isolation therefore protects only this type's own
/// state — it does not by itself serialize concurrent generations.
/// NOTE(review): confirm `ModelContainer.perform` serializes model access.
actor LLMSession {
    let container: ModelContainer

    init(container: ModelContainer) {
        self.container = container
    }

    /// Loads a model from a local directory (config.json + weights + tokenizer).
    ///
    /// - Parameter folderURL: Directory handed to `ModelConfiguration(directory:)`.
    /// - Returns: A session wrapping the loaded `ModelContainer`.
    /// - Throws: Whatever `LLMModelFactory.shared.loadContainer(configuration:)` throws.
    static func load(folderURL: URL) async throws -> LLMSession {
        let configuration = ModelConfiguration(directory: folderURL)
        let container = try await LLMModelFactory.shared.loadContainer(
            configuration: configuration
        )
        return LLMSession(container: container)
    }

    /// Streams generated text for `prompt`.
    ///
    /// Terminating the returned stream (e.g. the consumer cancels or stops iterating)
    /// cancels the internal `Task`; the generation loop then exits at its next
    /// `Task.isCancelled` check and the final GPU synchronization is skipped.
    ///
    /// Each yielded `TokenChunk.decodeRate` is the number of `.chunk` events received
    /// so far divided by the seconds elapsed since just before `generate` was called.
    /// It is a cumulative average, so it includes any work done before the first chunk
    /// (presumably prompt prefill), and it counts chunks rather than tokens.
    /// NOTE(review): confirm whether one `.chunk` always corresponds to one token.
    ///
    /// - Parameters:
    ///   - prompt: Raw prompt text; converted to `LMInput` by `context.processor`.
    ///   - maxTokens: Token limit passed through `GenerateParameters`.
    /// - Returns: A stream of `TokenChunk`s that finishes with an error if anything
    ///   inside the generation task throws.
    func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
        AsyncThrowingStream { continuation in
            let task = Task {
                do {
                    let parameters = GenerateParameters(
                        maxTokens: maxTokens,
                        temperature: 0.6,
                        topP: 0.9
                    )

                    try await container.perform { (context: ModelContext) in
                        let userInput = UserInput(prompt: prompt)
                        let lmInput = try await context.processor.prepare(input: userInput)

                        // Monotonic clock: `Date()` follows wall-clock time, which can jump
                        // (NTP sync, manual changes) and distort or even negate the rate.
                        let start = ProcessInfo.processInfo.systemUptime
                        var produced = 0

                        for await event in try MLXLMCommon.generate(
                            input: lmInput,
                            parameters: parameters,
                            context: context
                        ) {
                            if Task.isCancelled { break }

                            switch event {
                            case .chunk(let text):
                                produced += 1
                                let elapsed = ProcessInfo.processInfo.systemUptime - start
                                let rate = elapsed > 0 ? Double(produced) / elapsed : 0
                                continuation.yield(TokenChunk(text: text, decodeRate: rate))

                            case .info:
                                // Completion statistics; the last event of the stream.
                                break

                            case .toolCall:
                                // Not produced by plain-text generation; listed for exhaustiveness.
                                break
                            }
                        }

                        // GPU.synchronize() is a blocking GPU sync; when the consumer has
                        // cancelled, the output is discarded, so skip the wait.
                        if !Task.isCancelled {
                            MLX.GPU.synchronize()
                        }
                    }

                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }

            continuation.onTermination = { _ in task.cancel() }
        }
    }
}
|
||
|
||
#endif
|
||
|
||
// Temporary placeholder types so the target compiles without MLX; delete them once the MLX dependency is added.
|
||
/// Placeholder for MLX's `ModelContainer` so the target compiles without MLX.
///
/// It holds no state, so it can be a `final` class whose `Sendable` conformance
/// is checked by the compiler instead of asserted with `@unchecked Sendable`.
final class ModelContainer: Sendable {
    init() {}
}
|
||
|
||
/// Placeholder for MLX's `ModelConfiguration`: remembers the directory a model
/// is loaded from. `init(directory:)` is the synthesized memberwise initializer.
struct ModelConfiguration {
    /// Local folder containing the model files.
    let directory: URL
}
|
||
|
||
/// Placeholder for MLX's `LLMModelFactory`; loading always fails until the MLX
/// dependency is added.
///
/// It holds no state, so it can be a `final` class whose `Sendable` conformance
/// is checked by the compiler instead of asserted with `@unchecked Sendable`.
final class LLMModelFactory: Sendable {
    static let shared = LLMModelFactory()

    /// Always throws, because MLX is not linked.
    ///
    /// - Parameter configuration: Ignored by this placeholder.
    /// - Throws: `NSError` with domain `"MLXNotAvailable"` and code `-1`.
    func loadContainer(configuration: ModelConfiguration) async throws -> ModelContainer {
        throw NSError(domain: "MLXNotAvailable", code: -1, userInfo: [NSLocalizedDescriptionKey: "MLX framework not available"])
    }
}
|