Files
kangkang/体己/AI/LLMSession.swift
link2026 1ee512dce1 harden(ai): LLMSession 取消时跳过 MLX.GPU.synchronize
按 code quality review(P0)反馈,for-await 因 Task.isCancelled
退出时,GPU.synchronize() 不必执行——这是一个阻塞的 GPU 同步操作,
取消场景下属浪费。

W3 引入"用户取消推理"UI 时会更频繁触发此路径。

P1/P2 留待 W3 退散考量:
- decodeRate 用窗口平均(目前是累积)
- AIRuntime 持具体 LLMSession 类型,W3 抽 protocol 做 mock
- prompt 空字符串守门
- Float(0.6) 风格

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 16:06:09 +08:00

106 lines
3.7 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import Foundation
// TODO: re-enable the MLX imports below once the MLX packages are available.
// import MLX
// import MLXLLM
// import MLXLMCommon
// The MLX-backed implementation below is compiled out (#if false) while MLX is unavailable.
#if false
/// An MLX inference session, declared as an actor so the session state is actor-isolated.
///
/// Written against the mlx-swift-examples 2.29.1 (commit 9bff95ca) API.
actor LLMSession {
    /// The loaded model container that every generation request runs against.
    let container: ModelContainer

    /// Wraps an already-loaded model container.
    init(container: ModelContainer) {
        self.container = container
    }

    /// Loads a model from a local folder and wraps it in a session.
    ///
    /// - Parameter folderURL: Model directory (config.json + weights + tokenizer).
    /// - Returns: A session backed by the loaded container.
    /// - Throws: Any error raised by `LLMModelFactory.shared.loadContainer`.
    static func load(folderURL: URL) async throws -> LLMSession {
        let configuration = ModelConfiguration(directory: folderURL)
        let container = try await LLMModelFactory.shared.loadContainer(
            configuration: configuration
        )
        return LLMSession(container: container)
    }

    /// Streams generated text for `prompt` as an `AsyncThrowingStream`.
    ///
    /// Terminating the stream (consumer cancelled or stream dropped) cancels the
    /// underlying generation `Task`.
    ///
    /// - Parameters:
    ///   - prompt: Raw prompt text; the model's processor turns it into an `LMInput`.
    ///   - maxTokens: Upper bound on generated tokens, forwarded to `GenerateParameters`.
    /// - Returns: A stream of `TokenChunk` values, each carrying the text chunk and the
    ///   cumulative decode rate (chunks per second since generation started).
    func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
        AsyncThrowingStream { continuation in
            let task = Task {
                do {
                    let parameters = GenerateParameters(
                        maxTokens: maxTokens,
                        temperature: Float(0.6),
                        topP: Float(0.9)
                    )
                    try await container.perform { (context: ModelContext) in
                        let userInput = UserInput(prompt: prompt)
                        let lmInput = try await context.processor.prepare(input: userInput)
                        // Use a monotonic clock for the rate: wall-clock `Date()` can jump
                        // (NTP sync, manual clock change) and skew or zero the rate.
                        let start = ProcessInfo.processInfo.systemUptime
                        var produced = 0
                        for await event in try MLXLMCommon.generate(
                            input: lmInput,
                            parameters: parameters,
                            context: context
                        ) {
                            // Cooperative cancellation: stop consuming as soon as the
                            // surrounding task has been cancelled.
                            if Task.isCancelled { break }
                            switch event {
                            case .chunk(let text):
                                produced += 1
                                let elapsed = ProcessInfo.processInfo.systemUptime - start
                                let rate = elapsed > 0 ? Double(produced) / elapsed : 0
                                continuation.yield(TokenChunk(text: text, decodeRate: rate))
                            case .info:
                                // Completion statistics are not forwarded to the consumer.
                                break
                            case .toolCall:
                                // Tool calls are not handled; listed to keep the switch exhaustive.
                                break
                            }
                        }
                        // GPU synchronization blocks; skip it when the run was cancelled,
                        // since nobody is waiting for the result.
                        if !Task.isCancelled {
                            MLX.GPU.synchronize()
                        }
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            // Propagate stream termination to the generation task.
            continuation.onTermination = { _ in task.cancel() }
        }
    }
}
#endif
// Placeholder types that keep the module compiling while MLX is unavailable.
/// Placeholder model container with no state.
///
/// NOTE(review): appears to stand in for the MLX type of the same name while the
/// MLX imports in this file are commented out — confirm before re-enabling MLX.
class ModelContainer: @unchecked Sendable {
    /// Creates an empty placeholder container.
    init() {}
}
/// Placeholder model configuration that only records where the model lives on disk.
///
/// Relies on the synthesized memberwise initializer `init(directory:)`.
struct ModelConfiguration {
    /// Folder containing the model files.
    let directory: URL
}
/// Placeholder model factory whose loading entry point always fails.
class LLMModelFactory: @unchecked Sendable {
    /// Shared factory instance.
    static let shared = LLMModelFactory()

    /// Always throws, signalling that MLX is not available in this build.
    ///
    /// - Parameter configuration: Ignored by this placeholder.
    /// - Throws: An `NSError` in domain `MLXNotAvailable` with code `-1`.
    func loadContainer(configuration: ModelConfiguration) async throws -> ModelContainer {
        let unavailable = NSError(
            domain: "MLXNotAvailable",
            code: -1,
            userInfo: [NSLocalizedDescriptionKey: "MLX framework not available"]
        )
        throw unavailable
    }
}