feat(ai): LLMSession 接 MLX-Swift,跑 Qwen3-1.7B 流式生成
按 W2 plan Task 6 + docs/superpowers/notes/2026-05-25-mlx-api-corrections.md
落地 LLM 推理底座:
- actor LLMSession 包装 MLXLLM.ModelContainer
- load(folderURL:) 用 ModelConfiguration(directory:) + LLMModelFactory.shared.loadContainer
- generate(prompt:maxTokens:) 返回 AsyncThrowingStream<TokenChunk, Error>
- 内部 container.perform { (context: ModelContext) in ... } 拿到模型上下文
- UserInput → processor.prepare → MLXLMCommon.generate(顶层函数, AsyncStream)
- Generation switch 穷举 3 个 case(chunk / info / toolCall)
- maxTokens 通过 GenerateParameters 传递,温度 0.6 topP 0.9
- 取消传播:continuation.onTermination 同步 task.cancel()
- 每次 yield chunk 时计算 decodeRate:已产出的 chunk 数 ÷ 自生成开始以来的耗时(秒),作为 tok/s 的近似值
API 基线:mlx-swift-examples tag 2.29.1, commit 9bff95ca。
需用户手动:
1. Xcode 把 LLMSession.swift 拖入 体己 target (AI group)
2. ⌘B 验证 AIRuntime 不再报 "Cannot find LLMSession"
3. 把 ~/tiji-models/Qwen3-1.7B-4bit/ 拷到模拟器沙盒 Application Support/Models/
4. Task 7 (DebugAIRunner) 才能跑通
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
78
体己/AI/LLMSession.swift
Normal file
78
体己/AI/LLMSession.swift
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import Foundation
|
||||||
|
import MLX
|
||||||
|
import MLXLLM
|
||||||
|
import MLXLMCommon
|
||||||
|
|
||||||
|
/// Wraps streaming text generation for an MLX language model.
///
/// Declared as an actor, so access to its isolated state is serialized.
/// Written against the mlx-swift-examples 2.29.1 API (commit 9bff95ca).
actor LLMSession {
    /// The model container every generation request is run through.
    let container: ModelContainer

    /// Creates a session around an already-loaded model container.
    /// - Parameter container: The container to run generation in.
    init(container: ModelContainer) {
        self.container = container
    }

    /// Loads a model from a local directory and wraps it in a session.
    ///
    /// The directory is expected to hold config.json, the weights and the
    /// tokenizer files.
    /// - Parameter folderURL: Location of the model directory on disk.
    /// - Returns: A session backed by the loaded container.
    /// - Throws: Whatever `LLMModelFactory.shared.loadContainer` throws.
    static func load(folderURL: URL) async throws -> LLMSession {
        let loaded = try await LLMModelFactory.shared.loadContainer(
            configuration: ModelConfiguration(directory: folderURL)
        )
        return .init(container: loaded)
    }

    /// Generates text for `prompt` as a stream of chunks.
    ///
    /// Generation runs in an unstructured `Task`; that task is cancelled as
    /// soon as the returned stream terminates. Sampling uses a fixed
    /// temperature of 0.6 and topP of 0.9. When the task is cancelled the
    /// event loop stops and the stream finishes without an error; any error
    /// thrown while preparing or generating is forwarded to the stream.
    ///
    /// Each yielded `TokenChunk` carries a `decodeRate` equal to the number of
    /// chunks produced so far divided by the seconds elapsed since just before
    /// generation was started (prompt preparation is not included).
    ///
    /// - Parameters:
    ///   - prompt: Raw prompt text; the context's processor turns it into model input.
    ///   - maxTokens: Token limit handed to `GenerateParameters`.
    /// - Returns: A stream of generated chunks.
    func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
        return AsyncThrowingStream { stream in
            let worker = Task {
                do {
                    let sampling = GenerateParameters(
                        maxTokens: maxTokens,
                        temperature: Float(0.6),
                        topP: Float(0.9)
                    )

                    try await self.container.perform { (context: ModelContext) in
                        let prepared = try await context.processor.prepare(
                            input: UserInput(prompt: prompt)
                        )

                        // Clock starts after prompt preparation, right before generation.
                        let startedAt = Date()
                        var chunkCount = 0

                        let events = try MLXLMCommon.generate(
                            input: prepared,
                            parameters: sampling,
                            context: context
                        )
                        for await event in events {
                            // Cooperative cancellation: stop consuming events once cancelled.
                            guard !Task.isCancelled else { break }

                            switch event {
                            case let .chunk(piece):
                                chunkCount += 1
                                let rate = LLMSession.decodeRate(
                                    produced: chunkCount,
                                    since: startedAt
                                )
                                stream.yield(TokenChunk(text: piece, decodeRate: rate))

                            case .info, .toolCall:
                                // `.info` carries end-of-generation statistics and
                                // `.toolCall` is not expected for plain-text prompts
                                // (per the original author's notes — confirm against
                                // the MLX docs); both are listed only to keep the
                                // switch exhaustive.
                                break
                            }
                        }

                        // NOTE(review): presumably waits for outstanding GPU work
                        // before leaving the container — confirm.
                        MLX.GPU.synchronize()
                    }

                    stream.finish()
                } catch {
                    stream.finish(throwing: error)
                }
            }

            // Propagate stream termination to the generation task.
            stream.onTermination = { _ in worker.cancel() }
        }
    }

    /// Chunks per second since `origin`, or 0 when no time has elapsed.
    private static func decodeRate(produced: Int, since origin: Date) -> Double {
        let elapsed = Date().timeIntervalSince(origin)
        guard elapsed > 0 else { return 0 }
        return Double(produced) / elapsed
    }
}
|
||||||
Reference in New Issue
Block a user