diff --git a/康康/AI/FileDownloader.swift b/康康/AI/FileDownloader.swift new file mode 100644 index 0000000..673a215 --- /dev/null +++ b/康康/AI/FileDownloader.swift @@ -0,0 +1,158 @@ +import Foundation + +enum DownloadError: Error, LocalizedError { + case badStatus(Int) + case sizeMismatch(expected: Int, got: Int) + + var errorDescription: String? { + switch self { + case .badStatus(let code): + return "下载失败(HTTP \(code))" + case .sizeMismatch(let expected, let got): + return "文件大小校验失败(预期 \(expected),实际 \(got))" + } + } +} + +/// 下载单个文件,支持 HTTP Range 断点续传 + 完成后大小校验。 +/// 用 `URLSessionDataDelegate` 把响应体分块写入 `.part`,完成后原子改名为成品。 +/// +/// 注意:文件大小一律用 `FileManager.attributesOfItem` 读取,**不用** +/// `URL.resourceValues(.fileSizeKey)` —— 后者会把结果缓存在 URL 实例上, +/// 续传时先读 offset 再读 finalSize 会拿到下载前的陈旧大小,导致误判校验失败。 +/// +/// 一个实例一次处理一个文件(串行)。共享状态用锁保证可见性。 +final class FileDownloader: NSObject, URLSessionDataDelegate, @unchecked Sendable { + private let configuration: URLSessionConfiguration + + private let lock = NSLock() + private var handle: FileHandle? + private var written: Int = 0 + private var onProgress: ((Int) -> Void)? + private var responseError: Error? + private var continuation: CheckedContinuation? + + init(configuration: URLSessionConfiguration = .default) { + self.configuration = configuration + super.init() + } + + /// 不走 URL 资源值缓存的文件大小读取。 + static func fileSize(at url: URL) -> Int { + guard let attrs = try? FileManager.default.attributesOfItem(atPath: url.path), + let size = attrs[.size] as? Int else { return 0 } + return size + } + + /// 从 `url` 下载到 `destination`。若存在 `destination.part` 则发 Range 请求续传; + /// 完成后校验总大小 == `expectedBytes`,通过则原子改名为 `destination`。 + nonisolated func download( + from url: URL, + to destination: URL, + expectedBytes: Int, + onProgress: (@Sendable (Int) -> Void)? = nil + ) async throws { + let fm = FileManager.default + let part = destination.appendingPathExtension("part") + + // 成品已存在且大小正确 → 跳过 + if Self.fileSize(at: destination) == expectedBytes, + fm.fileExists(atPath: destination.path) { + return + } + + try fm.createDirectory( + at: destination.deletingLastPathComponent(), withIntermediateDirectories: true) + + var offset = 0 + if fm.fileExists(atPath: part.path) { + offset = Self.fileSize(at: part) + } else { + fm.createFile(atPath: part.path, contents: nil) + } + + let fileHandle = try FileHandle(forWritingTo: part) + try fileHandle.seekToEnd() + + lock.lock() + self.handle = fileHandle + self.written = offset + self.onProgress = onProgress + self.responseError = nil + lock.unlock() + + var request = URLRequest(url: url) + if offset > 0 { + request.setValue("bytes=\(offset)-", forHTTPHeaderField: "Range") + } + + let session = URLSession(configuration: configuration, delegate: self, delegateQueue: nil) + defer { session.finishTasksAndInvalidate() } + + // 句柄在 didCompleteWithError 内关闭(同一 delegate 队列,串行于 didReceive)。 + try await withCheckedThrowingContinuation { (cont: CheckedContinuation) in + lock.lock() + self.continuation = cont + lock.unlock() + session.dataTask(with: request).resume() + } + + let finalSize = Self.fileSize(at: part) + guard finalSize == expectedBytes else { + try? fm.removeItem(at: part) + throw DownloadError.sizeMismatch(expected: expectedBytes, got: finalSize) + } + + if fm.fileExists(atPath: destination.path) { + try fm.removeItem(at: destination) + } + try fm.moveItem(at: part, to: destination) + } + + // MARK: - URLSessionDataDelegate (全部在串行 delegate 队列执行) + + nonisolated func urlSession( + _ session: URLSession, dataTask: URLSessionDataTask, + didReceive response: URLResponse, + completionHandler: @escaping (URLSession.ResponseDisposition) -> Void + ) { + if let http = response as? HTTPURLResponse, http.statusCode >= 400 { + lock.lock(); responseError = DownloadError.badStatus(http.statusCode); lock.unlock() + completionHandler(.cancel) + } else { + completionHandler(.allow) + } + } + + nonisolated func urlSession( + _ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data + ) { + lock.lock() + try? handle?.write(contentsOf: data) + written += data.count + let progress = written + let callback = onProgress + lock.unlock() + callback?(progress) + } + + nonisolated func urlSession( + _ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error? + ) { + lock.lock() + try? handle?.close() + handle = nil + let cont = continuation + continuation = nil + let respErr = responseError + lock.unlock() + + if let respErr { + cont?.resume(throwing: respErr) + } else if let error { + cont?.resume(throwing: error) + } else { + cont?.resume() + } + } +} diff --git a/康康/AI/ModelManifest.swift b/康康/AI/ModelManifest.swift new file mode 100644 index 0000000..fc4413c --- /dev/null +++ b/康康/AI/ModelManifest.swift @@ -0,0 +1,59 @@ +import Foundation + +/// 模型文件清单中的一项:相对模型目录的路径 + 预期字节数(用于总进度计算与下载后大小校验)。 +struct ModelFile: Equatable, Sendable { + let path: String + let bytes: Int +} + +/// 硬编码的模型文件清单与下载源。 +/// 只列加载必需的功能文件,排除 README.md / .gitattributes(省下载)。 +/// 字节数与服务器素材逐一核对一致,见 +/// docs/superpowers/specs/2026-05-29-model-download-design.md 附录 A。 +enum ModelManifest { + /// 自建 Caddy 静态服务(用户自建 HTTPS 反代)。 + /// 备选纯 IP(需 App 端 ATS 例外): http://101.132.124.52:5244/ + static let baseURL = URL(string: "https://file.myv0.com/")! + + static func files(for kind: ModelKind) -> [ModelFile] { + switch kind { + case .llm: + return [ + ModelFile(path: "config.json", bytes: 937), + ModelFile(path: "model.safetensors", bytes: 968_080_210), + ModelFile(path: "model.safetensors.index.json", bytes: 49_731), + ModelFile(path: "tokenizer.json", bytes: 11_422_654), + ModelFile(path: "tokenizer_config.json", bytes: 9_706), + ModelFile(path: "vocab.json", bytes: 2_776_833), + ModelFile(path: "merges.txt", bytes: 1_671_853), + ModelFile(path: "special_tokens_map.json", bytes: 613), + ModelFile(path: "added_tokens.json", bytes: 707), + ] + case .vl: + return [ + ModelFile(path: "config.json", bytes: 1_659), + ModelFile(path: "model.safetensors", bytes: 3_073_720_461), + ModelFile(path: "model.safetensors.index.json", bytes: 108_307), + ModelFile(path: "tokenizer.json", bytes: 11_421_896), + ModelFile(path: "tokenizer_config.json", bytes: 7_256), + ModelFile(path: "vocab.json", bytes: 2_776_833), + ModelFile(path: "merges.txt", bytes: 1_671_853), + ModelFile(path: "special_tokens_map.json", bytes: 613), + ModelFile(path: "added_tokens.json", bytes: 605), + ModelFile(path: "chat_template.json", bytes: 1_050), + ModelFile(path: "preprocessor_config.json", bytes: 350), + ] + } + } + + static func totalBytes(for kind: ModelKind) -> Int { + files(for: kind).reduce(0) { $0 + $1.bytes } + } + + /// 单个文件的下载 URL = baseURL / <仓库名> / <相对路径>。 + static func fileURL(for kind: ModelKind, file: ModelFile) -> URL { + baseURL + .appendingPathComponent(kind.rawValue, isDirectory: true) + .appendingPathComponent(file.path) + } +} diff --git a/康康/AI/ModelStore.swift b/康康/AI/ModelStore.swift index 613ca4b..c76a2f4 100644 --- a/康康/AI/ModelStore.swift +++ b/康康/AI/ModelStore.swift @@ -84,4 +84,55 @@ final class ModelStore: @unchecked Sendable { } try FileManager.default.copyItem(at: bundleURL, to: target) } + + // MARK: - 下载 / 导入支撑 + + /// 模型目录下某个相对路径文件的本地 URL。 + nonisolated func fileURL(for kind: ModelKind, relativePath: String) -> URL { + localURL(for: kind).appendingPathComponent(relativePath) + } + + /// 本地该文件当前字节数,不存在返回 0(用于断点续传偏移与跳过判断)。 + nonisolated func localBytes(for kind: ModelKind, relativePath: String) -> Int { + let url = fileURL(for: kind, relativePath: relativePath) + guard let size = try? url.resourceValues(forKeys: [.fileSizeKey]).fileSize else { return 0 } + return size + } + + /// 按清单校验模型是否完整:每个文件都存在且大小等于预期。 + /// `files` 默认取 `ModelManifest`;测试可注入小清单。 + nonisolated func isComplete(for kind: ModelKind, files: [ModelFile]? = nil) -> Bool { + let manifest = files ?? ModelManifest.files(for: kind) + guard !manifest.isEmpty else { return false } + for file in manifest where localBytes(for: kind, relativePath: file.path) != file.bytes { + return false + } + return true + } + + /// 旁路导入:把一个含 config.json 的模型文件夹整体拷入沙盒(现场重装兜底)。 + nonisolated func importModel(_ kind: ModelKind, from sourceFolder: URL) throws { + let configPath = sourceFolder.appendingPathComponent(kind.sentinelFilename).path + guard FileManager.default.fileExists(atPath: configPath) else { + throw ModelStoreError.missingConfig + } + let target = localURL(for: kind) + if FileManager.default.fileExists(atPath: target.path) { + try FileManager.default.removeItem(at: target) + } + try FileManager.default.createDirectory( + at: target.deletingLastPathComponent(), withIntermediateDirectories: true) + try FileManager.default.copyItem(at: sourceFolder, to: target) + } +} + +enum ModelStoreError: Error, LocalizedError { + case missingConfig + + var errorDescription: String? { + switch self { + case .missingConfig: + return "所选文件夹缺少 config.json,不是有效的模型目录" + } + } } diff --git a/康康/Features/Me/MeView.swift b/康康/Features/Me/MeView.swift index a7479b5..bcccea3 100644 --- a/康康/Features/Me/MeView.swift +++ b/康康/Features/Me/MeView.swift @@ -7,6 +7,8 @@ struct MeView: View { @Query private var reminders: [MetricReminder] @Query private var customMetrics: [CustomMonitorMetric] + @State private var downloadService = ModelDownloadService.shared + private var profile: UserProfile? { profiles.first } private var enabledReminderCount: Int { reminders.filter(\.enabled).count } @@ -17,9 +19,7 @@ struct MeView: View { profileCard remindersCard customMetricsCard - settingsCard(title: "模型管理", - detail: "未配置", - icon: "cpu") + modelManagementCard settingsCard(title: "Face ID 启动锁", detail: "关闭", icon: "faceid") @@ -42,6 +42,7 @@ struct MeView: View { if profiles.isEmpty { _ = UserProfileStore.loadOrCreate(in: ctx) } + downloadService.refreshStates() } } } @@ -161,6 +162,23 @@ struct MeView: View { return "\(customMetrics.count) 项" } + private var modelManagementCard: some View { + NavigationLink { + ModelManagementView() + } label: { + settingsCard(title: "模型管理", detail: modelDetail, icon: "cpu") + } + .buttonStyle(.plain) + } + + private var modelDetail: String { + let states = downloadService.states + if ModelKind.allCases.allSatisfy({ states[$0]?.phase == .ready }) { return "已就绪" } + if downloadService.isAnyDownloading { return "下载中…" } + let readyCount = ModelKind.allCases.filter { states[$0]?.phase == .ready }.count + return readyCount == 0 ? "未下载" : "\(readyCount)/\(ModelKind.allCases.count) 就绪" + } + private func settingsCard(title: String, detail: String, icon: String) -> some View { HStack(spacing: 12) { ZStack { diff --git a/康康/Features/Me/ModelManagementView.swift b/康康/Features/Me/ModelManagementView.swift new file mode 100644 index 0000000..6836b41 --- /dev/null +++ b/康康/Features/Me/ModelManagementView.swift @@ -0,0 +1,226 @@ +import SwiftUI +import Network +import UniformTypeIdentifiers + +/// 「我的 · 模型管理」页:分模型卡片显示下载状态/进度,支持下载全部/暂停 + 旁路文件导入。 +/// 只观察 ModelDownloadService 的状态,不直接碰 URLSession(§3.1)。 +struct ModelManagementView: View { + @State private var service = ModelDownloadService.shared + @State private var isCellular = false + @State private var showCellularConfirm = false + @State private var showImporter = false + @State private var importError: String? + + private let monitor = NWPathMonitor() + private let monitorQueue = DispatchQueue(label: "kk.netmonitor") + + private var allReady: Bool { + ModelKind.allCases.allSatisfy { service.states[$0]?.phase == .ready } + } + + var body: some View { + ScrollView { + VStack(spacing: 14) { + ForEach(ModelKind.allCases, id: \.self) { kind in + modelCard(kind) + } + + actionButtons + .padding(.top, 4) + + if let importError { + Text(importError) + .font(.system(size: 12)) + .foregroundStyle(Tj.Palette.brick) + .frame(maxWidth: .infinity, alignment: .leading) + } + + footer + .padding(.top, 8) + } + .padding(.horizontal, 16) + .padding(.vertical, 18) + } + .background(Tj.Palette.sand.ignoresSafeArea()) + .navigationTitle("模型管理") + .navigationBarTitleDisplayMode(.inline) + .onAppear { + service.refreshStates() + monitor.pathUpdateHandler = { path in + let cellular = path.status == .satisfied && path.usesInterfaceType(.cellular) + Task { @MainActor in isCellular = cellular } + } + monitor.start(queue: monitorQueue) + } + .onDisappear { monitor.cancel() } + .fileImporter(isPresented: $showImporter, + allowedContentTypes: [.folder]) { handleImport($0) } + .alert("使用蜂窝网络下载?", isPresented: $showCellularConfirm) { + Button("取消", role: .cancel) {} + Button("继续下载") { service.downloadAll() } + } message: { + Text("模型约 \(formatBytes(totalAllBytes)),建议在 Wi-Fi 下下载。") + } + } + + // MARK: - 模型卡片 + + private func modelCard(_ kind: ModelKind) -> some View { + let state = service.states[kind] + ?? DownloadState(phase: .idle, receivedBytes: 0, + totalBytes: ModelManifest.totalBytes(for: kind), bytesPerSecond: 0) + return VStack(alignment: .leading, spacing: 10) { + HStack(alignment: .top) { + VStack(alignment: .leading, spacing: 3) { + Text(kind.displayName) + .font(.system(size: 15, weight: .semibold)) + .foregroundStyle(Tj.Palette.text) + Text(subtitle(kind)) + .font(.system(size: 12)) + .foregroundStyle(Tj.Palette.text3) + } + Spacer() + statusBadge(state.phase) + } + + if state.phase == .downloading { + ProgressView(value: min(max(state.fraction, 0), 1)) + .tint(Tj.Palette.ink) + HStack { + Text("\(Int(state.fraction * 100))%") + Spacer() + Text(speedText(state)) + } + .font(.system(size: 11, design: .monospaced)) + .foregroundStyle(Tj.Palette.text3) + } else { + HStack { + Text(formatBytes(ModelManifest.totalBytes(for: kind))) + .font(.system(size: 11, design: .monospaced)) + .foregroundStyle(Tj.Palette.text3) + Spacer() + if case .failed(let message) = state.phase { + Text(message) + .font(.system(size: 11)) + .foregroundStyle(Tj.Palette.brick) + .lineLimit(1) + } + } + } + } + .padding(14) + .frame(maxWidth: .infinity, alignment: .leading) + .tjCard() + .contentShape(Rectangle()) + .onTapGesture { + if case .failed = state.phase { service.download(kind) } + } + } + + private func statusBadge(_ phase: DownloadPhase) -> some View { + switch phase { + case .idle: return TjBadge(text: "待下载", style: .neutral) + case .downloading: return TjBadge(text: "下载中", style: .amber) + case .verifying: return TjBadge(text: "校验中", style: .amber) + case .ready: return TjBadge(text: "已就绪", style: .leaf) + case .failed: return TjBadge(text: "失败 · 重试", style: .brick) + } + } + + // MARK: - 动作按钮 + + @ViewBuilder + private var actionButtons: some View { + if service.isAnyDownloading { + Button { + for kind in ModelKind.allCases { service.cancel(kind) } + } label: { + Text("暂停下载").frame(maxWidth: .infinity) + } + .buttonStyle(TjGhostButton()) + } else if allReady { + HStack(spacing: 6) { + Image(systemName: "checkmark.seal.fill") + Text("两个模型都已就绪") + } + .font(.system(size: 13, weight: .semibold)) + .foregroundStyle(Tj.Palette.leaf) + .frame(maxWidth: .infinity) + .padding(.vertical, 6) + } else { + Button { + if isCellular { showCellularConfirm = true } else { service.downloadAll() } + } label: { + Text("下载全部模型 · \(formatBytes(totalAllBytes))") + .frame(maxWidth: .infinity) + } + .buttonStyle(TjPrimaryButton()) + } + + Button { + importError = nil + showImporter = true + } label: { + Text("从文件导入(离线)").frame(maxWidth: .infinity) + } + .buttonStyle(TjGhostButton()) + } + + private var footer: some View { + VStack(spacing: 8) { + TjLockChip() + Text("100% 本地推理 · 模型仅需下载一次") + .font(.system(size: 11)) + .foregroundStyle(Tj.Palette.text3) + } + .frame(maxWidth: .infinity) + } + + // MARK: - 旁路导入 + + private func handleImport(_ result: Result) { + do { + let folder = try result.get() + let scoped = folder.startAccessingSecurityScopedResource() + defer { if scoped { folder.stopAccessingSecurityScopedResource() } } + + let name = folder.lastPathComponent + guard let kind = ModelKind.allCases.first(where: { $0.rawValue == name }) else { + importError = "请选择名为 Qwen3-1.7B-4bit 或 Qwen2.5-VL-3B-Instruct-4bit 的文件夹" + return + } + try service.importModel(kind, from: folder) + importError = nil + } catch { + importError = "导入失败:\(error.localizedDescription)" + } + } + + // MARK: - 辅助 + + private var totalAllBytes: Int { + ModelKind.allCases.reduce(0) { $0 + ModelManifest.totalBytes(for: $1) } + } + + private func subtitle(_ kind: ModelKind) -> String { + switch kind { + case .llm: return "文本解读 · 趋势 / 问答" + case .vl: return "拍照识别报告 → 结构化指标" + } + } + + private func formatBytes(_ bytes: Int) -> String { + ByteCountFormatter.string(fromByteCount: Int64(bytes), countStyle: .file) + } + + private func speedText(_ state: DownloadState) -> String { + guard state.bytesPerSecond > 0 else { return "—" } + return formatBytes(Int(state.bytesPerSecond)) + "/s" + } +} + +#Preview { + NavigationStack { + ModelManagementView() + } +} diff --git a/康康/Services/ModelDownloadService.swift b/康康/Services/ModelDownloadService.swift new file mode 100644 index 0000000..8ef3e26 --- /dev/null +++ b/康康/Services/ModelDownloadService.swift @@ -0,0 +1,147 @@ +import Foundation +import Observation + +/// 模型下载编排:遍历 ModelManifest 逐文件串行下载,聚合进度,支持暂停/重试/旁路导入。 +/// UI 只观察 `states`,不直接碰 URLSession(§3.1 模块边界)。 +/// 核心下载/校验逻辑在 `FileDownloader`,文件路径/就绪判定在 `ModelStore`。 +@MainActor +@Observable +final class ModelDownloadService { + static let shared = ModelDownloadService() + + private(set) var states: [ModelKind: DownloadState] = [:] + + private let store: ModelStore + private var tasks: [ModelKind: Task] = [:] + private var lastSampleTime: [ModelKind: Date] = [:] + private var lastSampleBytes: [ModelKind: Int] = [:] + + init(store: ModelStore = .shared) { + self.store = store + refreshStates() + } + + /// 根据沙盒现状刷新每个模型的状态(已完整→ready,否则 idle)。 + func refreshStates() { + for kind in ModelKind.allCases { + let total = ModelManifest.totalBytes(for: kind) + if store.isComplete(for: kind) { + states[kind] = DownloadState(phase: .ready, receivedBytes: total, + totalBytes: total, bytesPerSecond: 0) + } else if states[kind]?.phase == .downloading { + continue // 不打断进行中的下载 + } else { + states[kind] = DownloadState(phase: .idle, receivedBytes: completedBytes(for: kind), + totalBytes: total, bytesPerSecond: 0) + } + } + } + + var isAnyDownloading: Bool { + states.values.contains { $0.phase == .downloading } + } + + /// 下载某个模型。幂等:已在下载或已就绪则忽略。 + func download(_ kind: ModelKind) { + guard tasks[kind] == nil, states[kind]?.phase != .ready else { return } + let total = ModelManifest.totalBytes(for: kind) + states[kind] = DownloadState(phase: .downloading, receivedBytes: completedBytes(for: kind), + totalBytes: total, bytesPerSecond: 0) + lastSampleTime[kind] = Date() + lastSampleBytes[kind] = completedBytes(for: kind) + + let task = Task { [weak self] in + guard let self else { return } + await self.run(kind) + } + tasks[kind] = task + } + + func downloadAll() { + for kind in ModelKind.allCases { download(kind) } + } + + /// 暂停下载。已下载的 .part 保留,下次从断点续传。 + func cancel(_ kind: ModelKind) { + tasks[kind]?.cancel() + tasks[kind] = nil + let total = ModelManifest.totalBytes(for: kind) + states[kind] = DownloadState(phase: .idle, receivedBytes: completedBytes(for: kind), + totalBytes: total, bytesPerSecond: 0) + } + + /// 旁路导入:从用户选择的文件夹拷入模型(现场重装兜底)。 + func importModel(_ kind: ModelKind, from folder: URL) throws { + try store.importModel(kind, from: folder) + refreshStates() + } + + // MARK: - 内部 + + private func run(_ kind: ModelKind) async { + let files = ModelManifest.files(for: kind) + let downloader = FileDownloader() + var completedBefore = 0 + + do { + for file in files { + if Task.isCancelled { return } + let destination = store.fileURL(for: kind, relativePath: file.path) + let base = completedBefore + try await downloader.download( + from: ModelManifest.fileURL(for: kind, file: file), + to: destination, + expectedBytes: file.bytes, + onProgress: { [weak self] received in + Task { @MainActor in + self?.applyProgress(kind, currentTotal: base + received) + } + } + ) + completedBefore += file.bytes + } + finish(kind, success: true, message: nil) + } catch { + if Task.isCancelled { + // cancel() 已设置 idle 状态 + } else { + finish(kind, success: false, message: error.localizedDescription) + } + } + } + + private func applyProgress(_ kind: ModelKind, currentTotal: Int) { + guard var state = states[kind], state.phase == .downloading else { return } + let now = Date() + if let lastTime = lastSampleTime[kind], let lastBytes = lastSampleBytes[kind] { + let dt = now.timeIntervalSince(lastTime) + if dt >= 0.5 { + state.bytesPerSecond = Double(currentTotal - lastBytes) / dt + lastSampleTime[kind] = now + lastSampleBytes[kind] = currentTotal + } + } + state.receivedBytes = currentTotal + states[kind] = state + } + + private func finish(_ kind: ModelKind, success: Bool, message: String?) { + tasks[kind] = nil + let total = ModelManifest.totalBytes(for: kind) + if success { + states[kind] = DownloadState(phase: .ready, receivedBytes: total, + totalBytes: total, bytesPerSecond: 0) + } else { + states[kind] = DownloadState(phase: .failed(message ?? "下载失败"), + receivedBytes: completedBytes(for: kind), + totalBytes: total, bytesPerSecond: 0) + } + } + + /// 已完整下载的文件字节之和(用于续传时的起始进度)。 + private func completedBytes(for kind: ModelKind) -> Int { + ModelManifest.files(for: kind).reduce(0) { sum, file in + store.localBytes(for: kind, relativePath: file.path) == file.bytes ? sum + file.bytes : sum + } + } +} diff --git a/康康/Services/ModelDownloadTypes.swift b/康康/Services/ModelDownloadTypes.swift new file mode 100644 index 0000000..3ceefff --- /dev/null +++ b/康康/Services/ModelDownloadTypes.swift @@ -0,0 +1,22 @@ +import Foundation + +/// 单个模型的下载阶段。 +enum DownloadPhase: Equatable, Sendable { + case idle // 待下载 + case downloading // 下载中 + case verifying // 校验中 + case ready // 已就绪 + case failed(String) // 失败 · 可重试 +} + +/// 单个模型的下载状态快照,供 UI 观察。 +struct DownloadState: Equatable, Sendable { + var phase: DownloadPhase + var receivedBytes: Int + var totalBytes: Int + var bytesPerSecond: Double + + var fraction: Double { + totalBytes > 0 ? Double(receivedBytes) / Double(totalBytes) : 0 + } +} diff --git a/康康Tests/ModelDownloadCoreTests.swift b/康康Tests/ModelDownloadCoreTests.swift new file mode 100644 index 0000000..d0f1da5 --- /dev/null +++ b/康康Tests/ModelDownloadCoreTests.swift @@ -0,0 +1,136 @@ +import Testing +import Foundation +@testable import 康康 + +// MARK: - Mock 网络层 + +/// 按 URL 注册完整响应体,startLoading 时按请求的 Range header 自动切片返回(206)或全量(200)。 +/// 每个测试用唯一 URL 注册自己的内容 → 测试间不会互相覆盖,无需依赖执行顺序或可见性。 +final class MockURLProtocol: URLProtocol, @unchecked Sendable { + private static let lock = NSLock() + private static var bodies: [String: Data] = [:] + + static func register(_ url: URL, body: Data) { + lock.lock(); defer { lock.unlock() } + bodies[url.path] = body + } + static func reset() { + lock.lock(); defer { lock.unlock() } + bodies.removeAll() + } + private static func body(forPath path: String) -> Data? { + lock.lock(); defer { lock.unlock() } + return bodies[path] + } + + override class func canInit(with request: URLRequest) -> Bool { true } + override class func canonicalRequest(for request: URLRequest) -> URLRequest { request } + + override func startLoading() { + guard let url = request.url, let full = Self.body(forPath: url.path) else { + client?.urlProtocol(self, didFailWithError: URLError(.fileDoesNotExist)) + return + } + var data = full + var status = 200 + var headers: [String: String] = [:] + if let range = request.value(forHTTPHeaderField: "Range"), + let start = Self.parseRangeStart(range), start <= full.count { + data = Data(full.suffix(from: start)) + status = 206 + headers["Content-Range"] = "bytes \(start)-\(full.count - 1)/\(full.count)" + } + let response = HTTPURLResponse( + url: url, statusCode: status, httpVersion: nil, headerFields: headers)! + client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) + client?.urlProtocol(self, didLoad: data) + client?.urlProtocolDidFinishLoading(self) + } + + override func stopLoading() {} + + /// "bytes=2-" → 2 + private static func parseRangeStart(_ s: String) -> Int? { + guard let eq = s.firstIndex(of: "="), let dash = s.firstIndex(of: "-") else { return nil } + return Int(s[s.index(after: eq).. URLSessionConfiguration { + let config = URLSessionConfiguration.ephemeral + config.protocolClasses = [MockURLProtocol.self] + return config +} + +private func tempFile() -> URL { + FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + .appendingPathComponent("a.bin") +} + +private func uniqueURL() -> URL { + URL(string: "https://mock.test/\(UUID().uuidString).bin")! +} + +// MARK: - DownloadState + +struct DownloadStateTests { + @Test func fractionZeroWhenTotalZero() { + let s = DownloadState(phase: .idle, receivedBytes: 0, totalBytes: 0, bytesPerSecond: 0) + #expect(s.fraction == 0) + } + + @Test func fractionComputed() { + let s = DownloadState(phase: .downloading, receivedBytes: 50, totalBytes: 200, bytesPerSecond: 0) + #expect(s.fraction == 0.25) + } +} + +// MARK: - FileDownloader + +/// 串行执行:这些测试共享全局 URLProtocol / URLSession 状态,并行会互相干扰。 +@Suite(.serialized) +struct FileDownloaderTests { + + @Test func downloadsFileContent() async throws { + let url = uniqueURL() + MockURLProtocol.register(url, body: Data("hello".utf8)) + let dst = tempFile() + defer { try? FileManager.default.removeItem(at: dst.deletingLastPathComponent()) } + + let dl = FileDownloader(configuration: mockConfiguration()) + try await dl.download(from: url, to: dst, expectedBytes: 5) + + #expect(try Data(contentsOf: dst) == Data("hello".utf8)) + #expect(!FileManager.default.fileExists(atPath: dst.appendingPathExtension("part").path)) + } + + @Test func resumesFromPartialFile() async throws { + let url = uniqueURL() + MockURLProtocol.register(url, body: Data("hello".utf8)) + let dst = tempFile() + defer { try? FileManager.default.removeItem(at: dst.deletingLastPathComponent()) } + // 预置已下载的一半,download 应从 offset 2 续传 + try FileManager.default.createDirectory( + at: dst.deletingLastPathComponent(), withIntermediateDirectories: true) + try Data("he".utf8).write(to: dst.appendingPathExtension("part")) + + let dl = FileDownloader(configuration: mockConfiguration()) + try await dl.download(from: url, to: dst, expectedBytes: 5) + + #expect(try Data(contentsOf: dst) == Data("hello".utf8)) + } + + @Test func throwsOnSizeMismatch() async throws { + let url = uniqueURL() + MockURLProtocol.register(url, body: Data("hi".utf8)) // 仅 2 字节,期望 5 + let dst = tempFile() + defer { try? FileManager.default.removeItem(at: dst.deletingLastPathComponent()) } + + let dl = FileDownloader(configuration: mockConfiguration()) + await #expect(throws: (any Error).self) { + try await dl.download(from: url, to: dst, expectedBytes: 5) + } + #expect(!FileManager.default.fileExists(atPath: dst.path)) + } +} diff --git a/康康Tests/ModelManifestTests.swift b/康康Tests/ModelManifestTests.swift new file mode 100644 index 0000000..1bf0cb4 --- /dev/null +++ b/康康Tests/ModelManifestTests.swift @@ -0,0 +1,47 @@ +import Testing +import Foundation +@testable import 康康 + +struct ModelManifestTests { + + @Test func llmHasNineFunctionalFiles() { + #expect(ModelManifest.files(for: .llm).count == 9) + } + + @Test func vlHasElevenFunctionalFiles() { + #expect(ModelManifest.files(for: .vl).count == 11) + } + + @Test func llmTotalBytesMatchesManifest() { + #expect(ModelManifest.totalBytes(for: .llm) == 984_013_244) + } + + @Test func vlTotalBytesMatchesManifest() { + #expect(ModelManifest.totalBytes(for: .vl) == 3_089_710_883) + } + + @Test func excludesReadmeAndGitattributes() { + for kind in [ModelKind.llm, .vl] { + let names = ModelManifest.files(for: kind).map(\.path) + #expect(!names.contains("README.md")) + #expect(!names.contains(".gitattributes")) + } + } + + @Test func includesEssentialFiles() { + let llm = ModelManifest.files(for: .llm).map(\.path) + #expect(llm.contains("config.json")) + #expect(llm.contains("model.safetensors")) + #expect(llm.contains("tokenizer.json")) + + let vl = ModelManifest.files(for: .vl).map(\.path) + #expect(vl.contains("preprocessor_config.json")) // VL 拍照识别必需 + #expect(vl.contains("model.safetensors")) + } + + @Test func fileURLIsBaseSlashRepoSlashPath() { + let file = ModelFile(path: "config.json", bytes: 937) + let url = ModelManifest.fileURL(for: .llm, file: file) + #expect(url.absoluteString == "https://file.myv0.com/Qwen3-1.7B-4bit/config.json") + } +} diff --git a/康康Tests/ModelStoreDownloadSupportTests.swift b/康康Tests/ModelStoreDownloadSupportTests.swift new file mode 100644 index 0000000..6aa4622 --- /dev/null +++ b/康康Tests/ModelStoreDownloadSupportTests.swift @@ -0,0 +1,92 @@ +import Testing +import Foundation +@testable import 康康 + +struct ModelStoreDownloadSupportTests { + + private func isolatedStore() throws -> ModelStore { + let temp = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + return try ModelStore(rootURL: temp) + } + + @Test func fileURLPointsIntoModelFolder() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let url = store.fileURL(for: .llm, relativePath: "config.json") + #expect(url == store.localURL(for: .llm).appendingPathComponent("config.json")) + } + + @Test func localBytesZeroWhenMissing() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + #expect(store.localBytes(for: .llm, relativePath: "config.json") == 0) + } + + @Test func localBytesReturnsFileSize() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let folder = store.localURL(for: .llm) + try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true) + try Data(repeating: 7, count: 512).write(to: folder.appendingPathComponent("config.json")) + #expect(store.localBytes(for: .llm, relativePath: "config.json") == 512) + } + + @Test func isCompleteFalseWhenFilesMissing() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let files = [ModelFile(path: "a.bin", bytes: 1024)] + #expect(store.isComplete(for: .llm, files: files) == false) + } + + @Test func isCompleteTrueWhenAllFilesPresentWithExpectedSize() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let folder = store.localURL(for: .llm) + try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true) + try Data(repeating: 1, count: 1024).write(to: folder.appendingPathComponent("a.bin")) + let files = [ModelFile(path: "a.bin", bytes: 1024)] + #expect(store.isComplete(for: .llm, files: files) == true) + } + + @Test func isCompleteFalseWhenSizeMismatch() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let folder = store.localURL(for: .llm) + try FileManager.default.createDirectory(at: folder, withIntermediateDirectories: true) + try Data(repeating: 1, count: 999).write(to: folder.appendingPathComponent("a.bin")) + let files = [ModelFile(path: "a.bin", bytes: 1024)] + #expect(store.isComplete(for: .llm, files: files) == false) + } + + @Test func importModelCopiesFolderAndMarksReady() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + + let src = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + try FileManager.default.createDirectory(at: src, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: src) } + try "{}".write(to: src.appendingPathComponent("config.json"), atomically: true, encoding: .utf8) + try "x".write(to: src.appendingPathComponent("tokenizer.json"), atomically: true, encoding: .utf8) + + try store.importModel(.llm, from: src) + + #expect(store.isReady(.llm) == true) + #expect(FileManager.default.fileExists( + atPath: store.fileURL(for: .llm, relativePath: "tokenizer.json").path)) + } + + @Test func importModelThrowsWhenNoConfig() throws { + let store = try isolatedStore() + defer { try? FileManager.default.removeItem(at: store.rootURL) } + let src = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + try FileManager.default.createDirectory(at: src, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: src) } + + #expect(throws: (any Error).self) { + try store.importModel(.llm, from: src) + } + } +}