feat(iOS): 更新MNN后端模型配置优化性能 将MNN主模型从Qwen3.5-4B(~2.64GiB)降级为Qwen3.5-2B(~1.1GiB),因为4B版本 实测运行过慢,影响用户体验。iPhone17+/SME2设备使用2B模型,保留MLX 兜底方案用于模拟器和备用场景,确保AI推理性能和存储效率的平衡。 ```
105 lines
4.7 KiB
Swift
105 lines
4.7 KiB
Swift
import Foundation
|
|
import MLX
|
|
import MLXLLM
|
|
import MLXLMCommon
|
|
|
|
/// 封装 MLX 语言模型的流式生成,actor 保证单线程访问。
|
|
/// 基于 mlx-swift-examples 2.29.1(commit 9bff95ca)的 API。
|
|
actor LLMSession {
|
|
let container: ModelContainer
|
|
|
|
init(container: ModelContainer) {
|
|
self.container = container
|
|
}
|
|
|
|
/// 在 simulator 把默认设备强切为 CPU(MLX 的 Metal backend 在部分 Sim 路径会 abort)。
|
|
/// 真机走 body 默认设备(GPU/ANE)。
|
|
/// 用 task-scoped `withDefaultDevice`,TaskLocal 会传递到 child Task / actor 调用。
|
|
private static func withDeviceOverride<R>(
|
|
_ body: () async throws -> R
|
|
) async rethrows -> R {
|
|
#if targetEnvironment(simulator)
|
|
return try await Device.withDefaultDevice(.cpu, body)
|
|
#else
|
|
return try await body()
|
|
#endif
|
|
}
|
|
|
|
/// 从本地目录加载模型(包含 config.json + weights + tokenizer)。
|
|
static func load(folderURL: URL) async throws -> LLMSession {
|
|
let configuration = ModelConfiguration(directory: folderURL)
|
|
let container = try await withDeviceOverride {
|
|
try await LLMModelFactory.shared.loadContainer(
|
|
configuration: configuration
|
|
)
|
|
}
|
|
return LLMSession(container: container)
|
|
}
|
|
|
|
/// 流式生成。返回的 AsyncThrowingStream 被取消时,内部 Task 也会取消。
|
|
/// - Parameters:
|
|
/// - prompt: 原始 prompt 文本(经 processor 转 LMInput)
|
|
/// - maxTokens: 最大 token 数,由 GenerateParameters 控制
|
|
func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
|
|
AsyncThrowingStream { continuation in
|
|
let task = Task {
|
|
do {
|
|
try await Self.withDeviceOverride {
|
|
// 低温:本 App 文本任务多为"直答/JSON 抽取",高温随机性会经常吐成非 JSON。
|
|
// 0.3 + topP 0.85 让输出更确定、JSON 更稳(与 MNN set_config 降温对齐)。
|
|
// repetitionPenalty:低温 + 无惩罚时,长文本(如「关键指标」列表)会逐行复读
|
|
// 进入死循环;1.1 的重复惩罚 + 64 token 上下文窗口掐断复读(与 MNN penalty 对齐)。
|
|
let parameters = GenerateParameters(
|
|
maxTokens: maxTokens,
|
|
temperature: Float(0.3),
|
|
topP: Float(0.85),
|
|
repetitionPenalty: Float(1.1),
|
|
repetitionContextSize: 64
|
|
)
|
|
|
|
try await container.perform { (context: ModelContext) in
|
|
let userInput = UserInput(prompt: prompt)
|
|
let lmInput = try await context.processor.prepare(input: userInput)
|
|
|
|
let start = Date()
|
|
var produced = 0
|
|
|
|
for await event in try MLXLMCommon.generate(
|
|
input: lmInput,
|
|
parameters: parameters,
|
|
context: context
|
|
) {
|
|
if Task.isCancelled { break }
|
|
|
|
switch event {
|
|
case .chunk(let text):
|
|
produced += 1
|
|
let elapsed = Date().timeIntervalSince(start)
|
|
let rate = elapsed > 0 ? Double(produced) / elapsed : 0
|
|
continuation.yield(TokenChunk(text: text, decodeRate: rate))
|
|
|
|
case .info:
|
|
// 生成完成统计,是流的最后一个事件
|
|
break
|
|
|
|
case .toolCall:
|
|
// 纯文本生成不会触发,switch 穷举
|
|
break
|
|
}
|
|
}
|
|
// 注:研究笔记里曾建议尾部 MLX.GPU.synchronize() 以确保
|
|
// GPU 操作全部完成。但 AsyncStream 已经 yield 真实解码后的
|
|
// 文字,GPU 是否完全空闲不影响数据正确性。去掉此调用同时省
|
|
// 一份 transitive import MLX 的依赖,简化 SPM 链接。
|
|
}
|
|
}
|
|
continuation.finish()
|
|
} catch {
|
|
continuation.finish(throwing: error)
|
|
}
|
|
}
|
|
continuation.onTermination = { _ in task.cancel() }
|
|
}
|
|
}
|
|
}
|