- InferenceEngine:引擎枚举(.mnn 默认 / .mlx 兜底)+ UserDefaults 持久化 + 可用性/SME2 运行时探测(经 MNNLLMBridge) - MNNBackend:actor 封装 MNNLLMBridge 文本流式生成,detached 线程跑同步 response、按 UTF-8 边界 yield TokenChunk,串行化交给 AIRuntime 闸门 - AIRuntime:prepare/generate 按引擎分发;.mnn 且模型就绪→MNN,否则回退 MLX (过渡期 App 始终可用);prepareVL/单模型常驻时互卸 MNN↔MLX 释放内存 公有 API 不变,各 Service 零改动 模拟器 BUILD SUCCEEDED,0 error。引擎切换 UI + SME2 指示留待 Phase 5。 Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
335 lines
14 KiB
Swift
335 lines
14 KiB
Swift
import Foundation
|
|
import MLX
|
|
|
|
enum AIRuntimeError: Error, LocalizedError {
|
|
case notReady
|
|
case modelLoadFailed(String)
|
|
case inferenceFailed(String)
|
|
|
|
var errorDescription: String? {
|
|
switch self {
|
|
case .notReady: return String(appLoc: "AI 模型尚未准备好")
|
|
case .modelLoadFailed(let m): return String(appLoc: "模型加载失败:\(m)")
|
|
case .inferenceFailed(let m): return String(appLoc: "推理失败:\(m)")
|
|
}
|
|
}
|
|
}
|
|
|
|
actor AIRuntime {
|
|
static let shared = AIRuntime()
|
|
|
|
enum Status: Sendable, Equatable {
|
|
case notReady
|
|
case loading
|
|
case ready
|
|
case error(String)
|
|
}
|
|
|
|
private(set) var status: Status = .notReady
|
|
private(set) var vlStatus: Status = .notReady
|
|
private(set) var lastDecodeRate: Double = 0
|
|
|
|
private var llmSession: LLMSession?
|
|
private var vlSession: VLSession?
|
|
|
|
// MARK: - MNN 后端(CPU/SME2,挑战赛考核路径)
|
|
// 文本生成在 .mnn 引擎下走 MNN;VL(图→文)暂仍走 MLX(MNN VL 需 OMNI 构建)。
|
|
private let mnn = MNNBackend()
|
|
private(set) var mnnStatus: Status = .notReady
|
|
/// MNN 模型目录(下载/旁路导入到 Models/Qwen3.5-2B-MNN)。
|
|
nonisolated static var mnnModelFolder: URL {
|
|
ModelStore.shared.rootURL.appendingPathComponent("Qwen3.5-2B-MNN", isDirectory: true)
|
|
}
|
|
|
|
// MARK: - 串行推理闸门(§3.1 OOM 防护的真正落地)
|
|
//
|
|
// actor 只串行化「方法入口」,但 generate() 同步返回流、真正解码在内部 Task;
|
|
// analyzeReport 也在 await 期间让出 actor。若不加闸门,LLM 流正在解码时触发 VL,
|
|
// 两个模型会同时在 GPU 上解码 → 冲过单 App 内存上限被 jetsam 杀
|
|
//(MEMORY 记录的「in-flight 流并发窄口」)。
|
|
//
|
|
// 这里用 actor 内信号量(count = 1):所有「会占显存的重活」(解码 + 模型加载)
|
|
// 进入前先 await acquireGate(),结束后 releaseGate()。actor 串行执行保证
|
|
// gateBusy / gateWaiters 的读写天然无并发。
|
|
private var gateBusy = false
|
|
private var gateWaiters: [CheckedContinuation<Void, Never>] = []
|
|
|
|
private func acquireGate() async {
|
|
if !gateBusy {
|
|
gateBusy = true
|
|
return
|
|
}
|
|
await withCheckedContinuation { (cont: CheckedContinuation<Void, Never>) in
|
|
gateWaiters.append(cont)
|
|
}
|
|
// 被 releaseGate 唤醒时即已持有闸门(gateBusy 保持 true)。
|
|
}
|
|
|
|
private func releaseGate() {
|
|
if gateWaiters.isEmpty {
|
|
gateBusy = false
|
|
} else {
|
|
// 把闸门直接交给队首等待者,gateBusy 维持 true,不留空窗。
|
|
let next = gateWaiters.removeFirst()
|
|
next.resume()
|
|
}
|
|
}
|
|
|
|
private init() {}
|
|
|
|
/// App 启动时调用一次:给 MLX 的 GPU 缓冲池设上限,避免 reuse cache 在大模型常驻之上
|
|
/// 继续膨胀、把峰值推过单 App 内存上限。仅真机生效(模拟器走 CPU,且部分 Metal 路径会 abort)。
|
|
/// 与 increased-memory-limit entitlement + LLM/VL 互斥卸载配合,三管齐下防 jetsam OOM。
|
|
nonisolated static func configureMLXMemory() {
|
|
#if !targetEnvironment(simulator)
|
|
// 256MB cache 上限:够复用、不至于在 3GB 模型之上再囤几百 MB 空闲缓冲。
|
|
MLX.Memory.cacheLimit = 256 * 1024 * 1024
|
|
#endif
|
|
}
|
|
|
|
/// 加载文本模型。首次调用会真正加载,后续幂等。
|
|
/// 按当前引擎路由:.mnn → MNN(CPU/SME2);.mlx → 现有 MLX(GPU)。
|
|
func prepare() async throws {
|
|
// 选了 MNN 且模型已就绪才走 MNN;否则(选 MLX,或 MNN 模型尚未下载)回退 MLX,
|
|
// 保证过渡期 App 始终可用。引擎指示器(Phase 5)展示实际生效后端。
|
|
let mnnReady = FileManager.default.fileExists(
|
|
atPath: Self.mnnModelFolder.appendingPathComponent("config.json").path)
|
|
if InferenceEngine.current == .mnn, mnnReady {
|
|
try await prepareMNN()
|
|
return
|
|
}
|
|
// 走 MLX:先卸 MNN 释放内存(单模型常驻策略)。
|
|
await unloadMNN()
|
|
// 已有其他调用方在加载时,轮询等其结束再判定结果。
|
|
// 不能像旧实现那样裸 return:那会让调用方误以为已 ready,随后 generate 的
|
|
// `guard status == .ready` 失败 → 用户撞上「假错误屏」(模型其实正常加载中)。
|
|
while status == .loading {
|
|
try await Task.sleep(nanoseconds: 80_000_000)
|
|
}
|
|
if status == .ready { return }
|
|
|
|
// 用 isComplete(逐文件字节校验)而非 isReady(只看 config.json):config.json 最小最先下完,
|
|
// 半下载时 isReady 仍 true 会让加载在残缺 safetensors 上崩溃。与 ModelDownloadService 的
|
|
// 完成判据保持一致(它也用 isComplete)。
|
|
guard ModelStore.shared.isComplete(for: .llm) else {
|
|
status = .error("LLM 模型未就绪")
|
|
throw AIRuntimeError.notReady
|
|
}
|
|
|
|
// 进闸门:等所有在跑的推理(可能是 VL 解码)结束,再卸 VL + 载 LLM,
|
|
// 避免「VL 解码 + LLM 加载」内存峰值叠加 OOM。
|
|
await acquireGate()
|
|
defer { releaseGate() }
|
|
// 拿到闸门后复查:排队期间可能已被别的调用方加载好,避免重复 load。
|
|
if status == .ready { return }
|
|
|
|
// OOM 闸门(§3.1):LLM(~1GB)与 VL(~3GB)不可同时常驻,叠加会冲过单 App 内存上限被 jetsam 杀。
|
|
unloadVL()
|
|
|
|
status = .loading
|
|
do {
|
|
let session = try await LLMSession.load(
|
|
folderURL: ModelStore.shared.localURL(for: .llm)
|
|
)
|
|
self.llmSession = session
|
|
status = .ready
|
|
} catch {
|
|
status = .error("\(error)")
|
|
throw AIRuntimeError.modelLoadFailed("\(error)")
|
|
}
|
|
}
|
|
|
|
/// 加载 MNN 文本模型。幂等。单模型常驻:载入前卸掉 MLX 的 LLM/VL 释放内存。
|
|
private func prepareMNN() async throws {
|
|
while mnnStatus == .loading {
|
|
try await Task.sleep(nanoseconds: 80_000_000)
|
|
}
|
|
if mnnStatus == .ready { return }
|
|
|
|
let folder = Self.mnnModelFolder
|
|
let config = folder.appendingPathComponent("config.json").path
|
|
guard FileManager.default.fileExists(atPath: config) else {
|
|
mnnStatus = .error("MNN 模型未就绪")
|
|
throw AIRuntimeError.notReady
|
|
}
|
|
|
|
await acquireGate()
|
|
defer { releaseGate() }
|
|
if mnnStatus == .ready { return }
|
|
|
|
// 单模型常驻:卸 MLX LLM/VL,避免与 MNN 模型叠加占内存。
|
|
unloadLLM()
|
|
unloadVL()
|
|
|
|
mnnStatus = .loading
|
|
do {
|
|
try await mnn.load(folderURL: folder)
|
|
mnnStatus = .ready
|
|
} catch {
|
|
mnnStatus = .error("\(error)")
|
|
throw AIRuntimeError.modelLoadFailed("\(error)")
|
|
}
|
|
}
|
|
|
|
/// 卸载 MNN,释放桥与权重。幂等。
|
|
private func unloadMNN() async {
|
|
guard mnnStatus != .notReady else { return }
|
|
await mnn.unload()
|
|
mnnStatus = .notReady
|
|
MLX.Memory.clearCache()
|
|
}
|
|
|
|
/// 流式生成。调用前应先 await prepare()。
|
|
/// 注意:返回流是同步创建的,但跨 actor 调用 LLMSession 需要 await。
|
|
func generate(prompt: String, maxTokens: Int = 256) -> AsyncThrowingStream<TokenChunk, Error> {
|
|
if InferenceEngine.current == .mnn, mnnStatus == .ready {
|
|
return mnnGenerate(prompt: prompt, maxTokens: maxTokens)
|
|
}
|
|
// 在 actor 隔离上下文中捕获快照,Task 内不再访问 self.status / self.llmSession
|
|
let snapshotStatus = status
|
|
let snapshotSession = llmSession
|
|
|
|
return AsyncThrowingStream { continuation in
|
|
let task = Task {
|
|
guard snapshotStatus == .ready, let session = snapshotSession else {
|
|
continuation.finish(throwing: AIRuntimeError.notReady)
|
|
return
|
|
}
|
|
// 进闸门:保证本次 LLM 解码与任何 VL 解码 / 模型加载串行,绝不并发占显存。
|
|
await self.acquireGate()
|
|
do {
|
|
// session.generate 跨 actor 边界,需要 await
|
|
let stream = await session.generate(prompt: prompt, maxTokens: maxTokens)
|
|
for try await chunk in stream {
|
|
// 消费者(UI)提前关闭/取消时,下面的 checkCancellation 让本 Task 尽快退出,
|
|
// 连带丢弃 session 流并触发其 onTermination,停止底层 MLX 解码,不空耗 GPU。
|
|
try Task.checkCancellation()
|
|
// Task 闭包在 generate() 内启动,继承 AIRuntime 的 actor 隔离;
|
|
// 调用同 actor 的 recordRate 不需要 await
|
|
self.recordRate(chunk.decodeRate)
|
|
continuation.yield(chunk)
|
|
}
|
|
continuation.finish()
|
|
} catch {
|
|
continuation.finish(throwing: AIRuntimeError.inferenceFailed("\(error)"))
|
|
}
|
|
// 正常结束 / 异常 / 取消(checkCancellation 抛出后被上面 catch 吞掉)都会走到这,
|
|
// 闸门一定释放,不会死锁后续推理。
|
|
self.releaseGate()
|
|
}
|
|
// 消费者取消/流终止时取消内部 Task(与 LLMSession / HealthExportService 一致)。
|
|
continuation.onTermination = { _ in task.cancel() }
|
|
}
|
|
}
|
|
|
|
/// MNN(CPU/SME2)文本流式生成。结构与 MLX 分支一致:进闸门、串行解码、记录速率。
|
|
private func mnnGenerate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
|
|
let ready = (mnnStatus == .ready)
|
|
return AsyncThrowingStream { continuation in
|
|
let task = Task {
|
|
guard ready else {
|
|
continuation.finish(throwing: AIRuntimeError.notReady)
|
|
return
|
|
}
|
|
await self.acquireGate()
|
|
do {
|
|
let stream = await self.mnn.generate(prompt: prompt, maxTokens: maxTokens)
|
|
for try await chunk in stream {
|
|
try Task.checkCancellation()
|
|
self.recordRate(chunk.decodeRate)
|
|
continuation.yield(chunk)
|
|
}
|
|
continuation.finish()
|
|
} catch {
|
|
continuation.finish(throwing: AIRuntimeError.inferenceFailed("\(error)"))
|
|
}
|
|
self.releaseGate()
|
|
}
|
|
continuation.onTermination = { _ in task.cancel() }
|
|
}
|
|
}
|
|
|
|
private func recordRate(_ rate: Double) {
|
|
if rate > 0 { lastDecodeRate = rate }
|
|
}
|
|
|
|
// MARK: - VL
|
|
|
|
/// 加载 VL 模型。幂等,首调真正 load。
|
|
func prepareVL() async throws {
|
|
while vlStatus == .loading {
|
|
try await Task.sleep(nanoseconds: 80_000_000)
|
|
}
|
|
if vlStatus == .ready { return }
|
|
|
|
// 同 prepare():用 isComplete 排除半下载(避免在残缺权重上崩溃),与下载服务判据一致。
|
|
guard ModelStore.shared.isComplete(for: .vl) else {
|
|
vlStatus = .error("VL 模型未就绪")
|
|
throw AIRuntimeError.notReady
|
|
}
|
|
|
|
// 进闸门:等所有在跑的推理(可能是 LLM 文本流)结束,再卸 LLM + 载 VL。
|
|
// —— 这正是「异常项快拍识别时 App 自动退出」的主因防护。
|
|
await acquireGate()
|
|
defer { releaseGate() }
|
|
if vlStatus == .ready { return }
|
|
|
|
// OOM 闸门(§3.1):加载 VL(~3GB)前先卸 LLM(~1GB),否则两者常驻叠加冲过内存上限被 jetsam 杀。
|
|
unloadLLM()
|
|
await unloadMNN()
|
|
|
|
vlStatus = .loading
|
|
do {
|
|
let session = try await VLSession.load(
|
|
folderURL: ModelStore.shared.localURL(for: .vl)
|
|
)
|
|
self.vlSession = session
|
|
vlStatus = .ready
|
|
} catch {
|
|
vlStatus = .error("\(error)")
|
|
throw AIRuntimeError.modelLoadFailed("\(error)")
|
|
}
|
|
}
|
|
|
|
// MARK: - 卸载(OOM 闸门)
|
|
|
|
/// 卸载 LLM,释放 ModelContainer 引用并清 MLX 显存缓存。幂等。
|
|
/// 注:只在持有推理闸门时调用(prepareVL 内),此刻不会有 LLM 流在解码,卸载即时生效。
|
|
private func unloadLLM() {
|
|
guard llmSession != nil else { return }
|
|
llmSession = nil
|
|
status = .notReady
|
|
MLX.Memory.clearCache()
|
|
}
|
|
|
|
/// 卸载 VL,释放 ModelContainer 引用并清 MLX 显存缓存。幂等。
|
|
private func unloadVL() {
|
|
guard vlSession != nil else { return }
|
|
vlSession = nil
|
|
vlStatus = .notReady
|
|
MLX.Memory.clearCache()
|
|
}
|
|
|
|
/// 图像 → JSON 字符串(由 VLPrompts.reportExtraction 引导)。
|
|
/// 调用方负责解析 + 失败回退(§3.2)。
|
|
/// 推理闸门保证本调用与 LLM.generate() 的解码串行,不会同时占显存 OOM。
|
|
func analyzeReport(imageURLs: [URL],
|
|
prompt: String,
|
|
maxTokens: Int = 512) async throws -> String {
|
|
guard vlStatus == .ready, let session = vlSession else {
|
|
throw AIRuntimeError.notReady
|
|
}
|
|
await acquireGate()
|
|
defer { releaseGate() }
|
|
do {
|
|
return try await session.analyze(
|
|
imageURLs: imageURLs,
|
|
prompt: prompt,
|
|
maxTokens: maxTokens
|
|
)
|
|
} catch {
|
|
throw AIRuntimeError.inferenceFailed("\(error)")
|
|
}
|
|
}
|
|
}
|