Files
kangkang/康康/AI/LLMSession.swift
2026-06-10 06:42:59 +08:00

115 lines
5.3 KiB
Swift

import Foundation
import MLX
import MLXLLM
import MLXLMCommon
/// Actor wrapper around an MLX model container, so that access to its state is serialized.
/// Written against the mlx-swift-examples 2.29.1 (commit 9bff95ca) API.
actor LLMSession {
    /// The loaded model container that every generation request runs against.
    let container: ModelContainer

    /// Statistics captured from the most recent `.info` generation event;
    /// `nil` until such an event has been received.
    private(set) var lastStats: GenerateStats?

    init(container: ModelContainer) {
        self.container = container
    }

    /// Stores `stats` as the latest generation statistics (actor-isolated write).
    private func record(_ stats: GenerateStats) {
        lastStats = stats
    }

    /// Runs `body`, switching MLX's default device to the CPU when compiled for the simulator.
    ///
    /// On every other target `body` is awaited unchanged. The original author's note
    /// says the MLX Metal backend aborts under the simulator; that cannot be
    /// verified from this file.
    ///
    /// - Parameter body: The asynchronous work to run.
    /// - Returns: Whatever `body` returns.
    /// - Throws: Rethrows any error thrown by `body`.
    private static func withDeviceOverride<R>(
        _ body: () async throws -> R
    ) async rethrows -> R {
        #if targetEnvironment(simulator)
        return try await Device.withDefaultDevice(.cpu, body)
        #else
        return try await body()
        #endif
    }

    /// Loads a model from a local directory and wraps it in a new session.
    ///
    /// - Parameter folderURL: Directory handed to `ModelConfiguration(directory:)`
    ///   (presumably containing config.json, weights and tokenizer files — confirm).
    /// - Returns: A session owning the loaded container.
    /// - Throws: Any error raised by `LLMModelFactory.shared.loadContainer`.
    static func load(folderURL: URL) async throws -> LLMSession {
        let loaded = try await withDeviceOverride {
            try await LLMModelFactory.shared.loadContainer(
                configuration: ModelConfiguration(directory: folderURL)
            )
        }
        return LLMSession(container: loaded)
    }

    /// Starts generating text for `prompt` and streams the output chunk by chunk.
    ///
    /// Generation runs in an unstructured `Task`; when the returned stream
    /// terminates, that task is cancelled. Each yielded `TokenChunk` carries a
    /// decode rate computed as chunks produced so far divided by the seconds
    /// elapsed since just before generation was started. Statistics from the
    /// `.info` event are stored in `lastStats`.
    ///
    /// - Parameters:
    ///   - prompt: Raw prompt text; it is wrapped in a `UserInput` and prepared by the
    ///     context's processor.
    ///   - maxTokens: Upper bound on generated tokens, forwarded to `GenerateParameters`.
    /// - Returns: A stream that finishes normally when generation ends, or finishes
    ///   with the thrown error if any step fails.
    func generate(prompt: String, maxTokens: Int) -> AsyncThrowingStream<TokenChunk, Error> {
        // Fixed sampling settings: temperature 0.3, topP 0.85, and a repetition
        // penalty of 1.1 over a 64-token context. The original notes suggest these
        // were chosen for stable JSON-style output and to mirror an earlier MNN
        // configuration — confirm with the author.
        let sampling = GenerateParameters(
            maxTokens: maxTokens,
            temperature: Float(0.3),
            topP: Float(0.85),
            repetitionPenalty: Float(1.1),
            repetitionContextSize: 64
        )

        return AsyncThrowingStream { continuation in
            let worker = Task {
                do {
                    try await Self.withDeviceOverride {
                        try await container.perform { (context: ModelContext) in
                            let prepared = try await context.processor.prepare(
                                input: UserInput(prompt: prompt)
                            )
                            let startedAt = Date()
                            var chunkCount = 0

                            for await event in try MLXLMCommon.generate(
                                input: prepared,
                                parameters: sampling,
                                context: context
                            ) {
                                // Stop consuming as soon as the surrounding task is cancelled.
                                guard !Task.isCancelled else { break }

                                switch event {
                                case .chunk(let piece):
                                    chunkCount += 1
                                    let elapsed = Date().timeIntervalSince(startedAt)
                                    let chunksPerSecond: Double
                                    if elapsed > 0 {
                                        chunksPerSecond = Double(chunkCount) / elapsed
                                    } else {
                                        chunksPerSecond = 0
                                    }
                                    continuation.yield(
                                        TokenChunk(text: piece, decodeRate: chunksPerSecond)
                                    )
                                case .info(let completion):
                                    // Hop onto the actor to store the reported statistics.
                                    let stats = GenerateStats(
                                        promptTokens: completion.promptTokenCount,
                                        genTokens: completion.generationTokenCount,
                                        prefillSeconds: completion.promptTime,
                                        decodeSeconds: completion.generateTime
                                    )
                                    await self.record(stats)
                                case .toolCall:
                                    // Tool calls are ignored; the case only keeps the switch exhaustive.
                                    break
                                }
                            }
                            // No explicit MLX.GPU.synchronize() is issued after the loop;
                            // the original author's note marks this as deliberate.
                        }
                    }
                    continuation.finish()
                } catch {
                    continuation.finish(throwing: error)
                }
            }
            // Cancel the generation task whenever the stream terminates.
            continuation.onTermination = { _ in worker.cancel() }
        }
    }
}