217 lines
8.6 KiB
Plaintext
217 lines
8.6 KiB
Plaintext
//
|
||
// MNNLLMBridge.mm
|
||
// 康康
|
||
//
|
||
// ObjC++ 实现。device 真机用 <MNN/llm/llm.hpp>;模拟器编为桩(返回不可用,上层回退 MLX)。
|
||
//
|
||
|
||
#import "MNNLLMBridge.h"
|
||
#include <sys/sysctl.h>
|
||
|
||
// MARK: - 性能统计(私有 readwrite 重声明)
|
||
@interface MNNGenerateStats ()
|
||
@property (nonatomic, readwrite) int promptTokens;
|
||
@property (nonatomic, readwrite) int genTokens;
|
||
@property (nonatomic, readwrite) double prefillMs;
|
||
@property (nonatomic, readwrite) double decodeMs;
|
||
@end
|
||
|
||
@implementation MNNGenerateStats
|
||
- (double)decodeTokensPerSecond {
|
||
return self.decodeMs > 0 ? (self.genTokens / (self.decodeMs / 1000.0)) : 0;
|
||
}
|
||
@end
|
||
|
||
// MARK: - SME2 / 可用性探测(device + simulator 都可编)
|
||
|
||
static BOOL kk_sysctlFlag(const char *name) {
|
||
int64_t v = 0; size_t sz = sizeof(v);
|
||
if (sysctlbyname(name, &v, &sz, NULL, 0) != 0) return NO;
|
||
return v != 0;
|
||
}
|
||
|
||
#if TARGET_OS_SIMULATOR
|
||
|
||
// ============ 模拟器桩:无真实 MNN ============
|
||
@implementation MNNLLMBridge
|
||
+ (BOOL)isAvailable { return NO; }
|
||
+ (BOOL)cpuSupportsSME2 { return NO; }
|
||
- (nullable instancetype)initWithConfigPath:(NSString *)configPath { return nil; }
|
||
- (BOOL)isLoaded { return NO; }
|
||
- (MNNGenerateStats *)generateText:(NSString *)prompt maxTokens:(int)maxTokens
|
||
onToken:(void (^)(NSString *))onToken { return [MNNGenerateStats new]; }
|
||
- (nullable MNNGenerateStats *)analyzeImages:(NSArray<NSString *> *)imagePaths prompt:(NSString *)prompt
|
||
maxTokens:(int)maxTokens onToken:(void (^)(NSString *))onToken
|
||
error:(NSError **)error {
|
||
if (error) *error = [NSError errorWithDomain:@"MNN" code:-1
|
||
userInfo:@{NSLocalizedDescriptionKey: @"MNN 在模拟器不可用"}];
|
||
return nil;
|
||
}
|
||
- (void)cancel {}
|
||
@end
|
||
|
||
#else
|
||
|
||
// ============ 真机:真实 MNN-LLM ============
|
||
// MNN 第三方头文件的文档注释不规范,会触发一堆 -Wdocumentation 警告(Executor/
|
||
// Tensor/Interpreter/ImageProcess.hpp)。只在解析 MNN 头时关掉该警告,不影响本项目。
|
||
#pragma clang diagnostic push
|
||
#pragma clang diagnostic ignored "-Wdocumentation"
|
||
#include <MNN/llm/llm.hpp>
|
||
#pragma clang diagnostic pop
|
||
#include <string>
|
||
#include <ostream>
|
||
#include <streambuf>
|
||
#include <atomic>
|
||
|
||
using MNN::Transformer::Llm;
|
||
|
||
namespace {
|
||
/// 把 MNN 写入 ostream 的解码文本转成 NSString 回调;按 UTF-8 完整边界聚合,避免截断多字节。
|
||
class TokenStreamBuf : public std::streambuf {
|
||
public:
|
||
TokenStreamBuf(void (^onToken)(NSString *), std::atomic<bool> *cancel)
|
||
: _onToken(onToken), _cancel(cancel) {}
|
||
void flush() {
|
||
if (_pending.empty()) return;
|
||
emitPending(); // 末尾尽力 emit(即便非完整 UTF-8 也交出去)
|
||
_pending.clear();
|
||
}
|
||
protected:
|
||
std::streamsize xsputn(const char *s, std::streamsize n) override {
|
||
append(s, (size_t)n);
|
||
return n;
|
||
}
|
||
int overflow(int c) override {
|
||
if (c != EOF) { char ch = (char)c; append(&ch, 1); }
|
||
return c;
|
||
}
|
||
private:
|
||
void append(const char *s, size_t n) {
|
||
if (_cancel && _cancel->load()) return; // 已取消,吞掉不回调
|
||
_pending.append(s, n);
|
||
// 仅当整个 pending 是合法 UTF-8 才 emit(token 通常是完整字/词,边界自然对齐)
|
||
NSString *str = [[NSString alloc] initWithBytes:_pending.data()
|
||
length:_pending.size()
|
||
encoding:NSUTF8StringEncoding];
|
||
if (str) { if (_onToken) _onToken(str); _pending.clear(); }
|
||
}
|
||
void emitPending() {
|
||
NSString *str = [[NSString alloc] initWithBytes:_pending.data()
|
||
length:_pending.size()
|
||
encoding:NSUTF8StringEncoding];
|
||
if (str && _onToken) _onToken(str);
|
||
}
|
||
void (^_onToken)(NSString *);
|
||
std::atomic<bool> *_cancel;
|
||
std::string _pending;
|
||
};
|
||
} // namespace
|
||
|
||
@implementation MNNLLMBridge {
|
||
Llm *_llm;
|
||
std::atomic<bool> _cancel;
|
||
BOOL _loaded;
|
||
}
|
||
|
||
+ (BOOL)isAvailable { return YES; }
|
||
|
||
+ (BOOL)cpuSupportsSME2 {
|
||
// Apple 通过 sysctl 暴露 ARM 特性位:FEAT_SME2(A19/iPhone17+)。
|
||
return kk_sysctlFlag("hw.optional.arm.FEAT_SME2");
|
||
}
|
||
|
||
- (nullable instancetype)initWithConfigPath:(NSString *)configPath {
|
||
self = [super init];
|
||
if (!self) return nil;
|
||
_cancel = false;
|
||
_llm = Llm::createLLM(std::string(configPath.UTF8String));
|
||
if (_llm == nullptr) return nil;
|
||
// load 前以 merge-patch 调三件事(只翻这几个叶子,保留 chat_template 等其余配置):
|
||
// ① enable_thinking=false:config.json 默认 true,模板会给每个 assistant 回合硬塞
|
||
// <think>\n 开启思考,吞掉 token 预算并污染 JSON(prompt 里的 /no_think 对此模板无效)。
|
||
// ② 降温:config.json 默认 temperature=1.0 对结构化 JSON 太高,随机性大→经常吐成非 JSON。
|
||
// 本 App 所有任务都是"直答/JSON",压到 0.3 + topP 0.85 让输出更确定、JSON 更稳。
|
||
// ③ 重复惩罚:MNN 默认 mixed_samplers 不含 "penalty"、penalty/ngram_factor=1.0(全关),
|
||
// 叠加低温 → 长文本(如「关键指标」列表)会陷入逐行复读死循环(收缩压 107 mmHg ×N)。
|
||
// 显式把 "penalty" 放进 mixed 链首,开 repetition penalty(1.1)+ n-gram 惩罚(ngram_factor 1.05):
|
||
// n-gram 命中整段重复时惩罚升到 max_penalty,直接掐断逐行复读。
|
||
_llm->set_config("{"
|
||
"\"jinja\":{\"context\":{\"enable_thinking\":false}},"
|
||
"\"sampler_type\":\"mixed\","
|
||
"\"mixed_samplers\":[\"penalty\",\"topK\",\"topP\",\"temperature\"],"
|
||
"\"temperature\":0.3,\"topP\":0.85,\"topK\":40,"
|
||
"\"penalty\":1.1,\"n_gram\":8,\"ngram_factor\":1.05"
|
||
"}");
|
||
_loaded = _llm->load();
|
||
if (!_loaded) { Llm::destroy(_llm); _llm = nullptr; return nil; }
|
||
return self;
|
||
}
|
||
|
||
- (void)dealloc {
|
||
if (_llm) { Llm::destroy(_llm); _llm = nullptr; }
|
||
}
|
||
|
||
- (BOOL)isLoaded { return _loaded; }
|
||
|
||
- (void)cancel { _cancel = true; }
|
||
|
||
// 统一生成:full 已是最终 prompt(文本,或含 <img>路径</img> 标签)。
|
||
// 多模态模型 createLLM 返回 Omni,response 解析 <img> 标签并对路径 CV::imread(OMNI 框架内)。
|
||
- (MNNGenerateStats *)runResponse:(NSString *)full
|
||
maxTokens:(int)maxTokens
|
||
onToken:(void (^)(NSString *))onToken {
|
||
_cancel = false;
|
||
TokenStreamBuf buf(onToken, &_cancel);
|
||
std::ostream os(&buf);
|
||
if (_llm) {
|
||
// 红线:本 App 每次 generate/analyze 都是一次性独立推理(无多轮对话语义)。
|
||
// MNN 的 Llm::response 默认把本轮 prompt+输出累积进 history_tokens / KV cache,
|
||
// 不 reset 的话第二次导出会把上一次的完整上下文叠加进来 → all_seq_len 暴涨、
|
||
// 冲过上下文上限 → 崩溃(用户报「再次导出死机」)。每轮先 reset 清空历史,
|
||
// 与 MLX LLMSession 的「每次 generate 无状态」保持一致。
|
||
_llm->reset();
|
||
_llm->response(std::string(full.UTF8String), &os, nullptr, maxTokens);
|
||
}
|
||
buf.flush();
|
||
return [self statsFromContext];
|
||
}
|
||
|
||
- (MNNGenerateStats *)generateText:(NSString *)prompt
|
||
maxTokens:(int)maxTokens
|
||
onToken:(void (^)(NSString *))onToken {
|
||
return [self runResponse:prompt maxTokens:maxTokens onToken:onToken];
|
||
}
|
||
|
||
- (nullable MNNGenerateStats *)analyzeImages:(NSArray<NSString *> *)imagePaths
|
||
prompt:(NSString *)prompt
|
||
maxTokens:(int)maxTokens
|
||
onToken:(void (^)(NSString *))onToken
|
||
error:(NSError **)error {
|
||
// 在 prompt 前拼 <img>本地路径</img>;Omni 解析标签并对路径 imread(需 OMNI 框架)。
|
||
NSMutableString *full = [NSMutableString string];
|
||
for (NSString *p in imagePaths) {
|
||
[full appendFormat:@"<img>%@</img>", p];
|
||
}
|
||
[full appendString:prompt];
|
||
return [self runResponse:full maxTokens:maxTokens onToken:onToken];
|
||
}
|
||
|
||
- (MNNGenerateStats *)statsFromContext {
|
||
MNNGenerateStats *s = [MNNGenerateStats new];
|
||
if (_llm) {
|
||
const MNN::Transformer::LlmContext *ctx = _llm->getContext();
|
||
if (ctx) {
|
||
s.promptTokens = ctx->prompt_len;
|
||
s.genTokens = ctx->gen_seq_len;
|
||
s.prefillMs = ctx->prefill_us / 1000.0;
|
||
s.decodeMs = ctx->decode_us / 1000.0;
|
||
}
|
||
}
|
||
return s;
|
||
}
|
||
|
||
@end
|
||
|
||
#endif
|