import Foundation
import MLX
import MLXVLM
import MLXLMCommon

/// Wraps image → text inference for an MLX vision-language model (Qwen2.5-VL).
///
/// Uses the same actor isolation approach as `LLMSession`; serialization of
/// requests is guaranteed upstream by `AIRuntime`.
actor VLSession {
    /// The loaded model container that every inference call runs against.
    let container: ModelContainer

    init(container: ModelContainer) {
        self.container = container
    }

    /// Runs `body`, forcing the CPU as MLX's default device when built for the simulator.
    ///
    /// On real hardware `body` runs unchanged on whatever device MLX selects.
    /// NOTE(review): presumably the simulator cannot use the GPU backend — confirm.
    ///
    /// - Parameter body: The async work to execute.
    /// - Returns: Whatever `body` returns.
    /// - Throws: Rethrows any error thrown by `body`.
    private static func withDeviceOverride<R>(
        _ body: () async throws -> R
    ) async rethrows -> R {
        #if targetEnvironment(simulator)
        return try await Device.withDefaultDevice(.cpu, body)
        #else
        return try await body()
        #endif
    }

    /// Loads a VL model from a local directory
    /// (containing config.json + weights + tokenizer + processor).
    ///
    /// - Parameter folderURL: Directory holding the model files.
    /// - Returns: A session wrapping the loaded model container.
    /// - Throws: Any error raised by `VLMModelFactory` while loading.
    static func load(folderURL: URL) async throws -> VLSession {
        let configuration = ModelConfiguration(directory: folderURL)
        let container = try await withDeviceOverride {
            try await VLMModelFactory.shared.loadContainer(
                configuration: configuration
            )
        }
        return VLSession(container: container)
    }

    /// One-shot generation: waits for all tokens, then returns the complete string.
    ///
    /// The VL model is used for structured JSON extraction, so streaming is not
    /// needed — this also avoids half-finished JSON making the UI flicker.
    ///
    /// - Parameters:
    ///   - imageURLs: Local `file://` URLs, obtained from FileVault.
    ///   - prompt: The text instruction (`VLPrompts.reportExtraction`).
    ///   - maxTokens: Defaults to 512 (the JSON payload is ≈ 200–400 tokens).
    /// - Returns: The concatenation of every generated text chunk.
    /// - Throws: Errors from input preparation or generation, and
    ///   `CancellationError` if the surrounding task is cancelled mid-generation.
    func analyze(imageURLs: [URL], prompt: String, maxTokens: Int = 512) async throws -> String {
        try await Self.withDeviceOverride {
            try await container.perform { (context: ModelContext) in
                let images = imageURLs.map { UserInput.Image.url($0) }
                let userInput = UserInput(prompt: prompt, images: images)
                let lmInput = try await context.processor.prepare(input: userInput)
                let parameters = GenerateParameters(
                    maxTokens: maxTokens,
                    temperature: Float(0.2), // JSON must be stable, so keep the temperature low
                    topP: Float(0.9)
                )

                var collected = ""
                for await event in try MLXLMCommon.generate(
                    input: lmInput,
                    parameters: parameters,
                    context: context
                ) {
                    // Surface cancellation as an error instead of returning a
                    // truncated string that callers would try to parse as JSON.
                    try Task.checkCancellation()
                    if case .chunk(let text) = event {
                        collected.append(text)
                    }
                }
                return collected
            }
        }
    }
}