/* * @file context_plugin.cpp * @brief Context plugin: token counting and context window trimming. * 上下文插件:token 计数和上下文窗口裁剪。 * Copyright (c) 2026 dstalk contributors. GPLv3. */ // plugin-context: 上下文管理服务插件 / Context management service plugin // 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation // 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting) #include "dstalk/dstalk_host.h" #include "dstalk/dstalk_types.h" #include "dstalk/dstalk_services.h" #include #include #include #include #include #include #include #include // ============================================================ // 全局状态 / Global state // ============================================================ static const dstalk_host_api_t* g_host = nullptr; static const dstalk_session_service_t* g_session = nullptr; // ============================================================ // 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting // W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5) // Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5) // 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6) // Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6) // ============================================================ // 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。 // overhead: 每条消息的固定开销 token(role + separators = 4) // 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。 // Count estimated tokens for UTF-8 byte sequence [text, text+len). // overhead: fixed token overhead per message (role + separators = 4). // Multi-byte sequences fall back to single-byte other_chars counting when out-of-bounds or invalid continuation bytes. static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) { if (!text || len == 0) return overhead; size_t ascii_chars = 0; size_t chinese_chars = 0; size_t other_chars = 0; size_t i = 0; while (i < len && text[i] != '\0') { unsigned char c = static_cast(text[i]); if (c < 0x80) { // ASCII / ASCII ascii_chars++; i += 1; } else if (c >= 0xE4 && c <= 0xE9) { // CJK 统一表意文字 (U+4E00-U+9FFF): 3 字节 UTF-8 0xE4-0xE9 / CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9 // W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 / Check if subsequent 2 bytes are in valid range if (i + 2 >= len || (static_cast(text[i + 1]) & 0xC0) != 0x80 || (static_cast(text[i + 2]) & 0xC0) != 0x80) { other_chars++; i += 1; } else { chinese_chars++; i += 3; } } else if (c >= 0xC2 && c < 0xE0) { // 2 字节序列 (有效范围 0xC2-0xDF) / 2-byte sequence (valid range 0xC2-0xDF) // W18.1 (F-11.1-4): 检查后续 1 字节 / Check subsequent 1 byte if (i + 1 >= len || (static_cast(text[i + 1]) & 0xC0) != 0x80) { other_chars++; i += 1; } else { other_chars++; i += 2; } } else if (c == 0xC0 || c == 0xC1) { // W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 / Overlong encoding, invalid UTF-8 start byte // 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars / 0xC0/0xC1 never appear in valid UTF-8; counted as single-byte in other_chars other_chars++; i += 1; } else if (c >= 0xE0 && c < 0xF0) { // 非 CJK 3 字节序列 (0xE0-0xE3, 0xEA-0xEF) / Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF) // CJK 范围 0xE4-0xE9 已在上方分支处理 / CJK range 0xE4-0xE9 handled in branch above if (i + 2 >= len || (static_cast(text[i + 1]) & 0xC0) != 0x80 || (static_cast(text[i + 2]) & 0xC0) != 0x80) { other_chars++; i += 1; } else { other_chars++; i += 3; } } else if (c >= 0xF0 && c < 0xF8) { // 4 字节序列 / 4-byte sequence if (i + 3 >= len || (static_cast(text[i + 1]) & 0xC0) != 0x80 || (static_cast(text[i + 2]) & 0xC0) != 0x80 || (static_cast(text[i + 3]) & 0xC0) != 0x80) { other_chars++; i += 1; } else { other_chars++; i += 4; } } else { // 续字节 (0x80-0xBF) 和其他无效起始字节 (0xF8-0xFF) / Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF) other_chars++; i += 1; } } return (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3) + overhead; } // ============================================================ // 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl) // ============================================================ // 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct. static size_t count_tokens_one_message(const dstalk_message_t& msg) { const char* text = msg.content; if (!text) return 4; // 只有 overhead / overhead only return count_tokens_utf8(text, std::strlen(text), 4); } // 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages. static size_t count_tokens_all(const dstalk_message_t* msgs, int count) { size_t total = 0; for (int i = 0; i < count; ++i) { total += count_tokens_one_message(msgs[i]); } return total; } // ============================================================ // 内部 trim 逻辑 / Internal trim logic // ============================================================ // 为 trim 操作将 C 消息数组复制到内部 struct / Copy C message array to internal struct for trim operation struct TrimMessage { std::string role; std::string content; std::string tool_call_id; std::string tool_calls_json; }; static size_t count_tokens_trim(const TrimMessage& msg) { if (msg.content.empty()) return 4; return count_tokens_utf8(msg.content.c_str(), msg.content.size(), 4); } static size_t count_tokens_trim_vec(const std::vector& msgs) { size_t total = 0; for (const auto& m : msgs) total += count_tokens_trim(m); return total; } // 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper). static void free_msg_strs(dstalk_message_t* msg) { if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; } if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; } if (msg->tool_call_id) { g_host->free((void*)msg->tool_call_id); msg->tool_call_id = nullptr; } if (msg->tool_calls_json) { g_host->free((void*)msg->tool_calls_json); msg->tool_calls_json = nullptr; } } // 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。 // 成功返回 0;OOM 时释放当前消息已分配字段并返回 -1。 // Copy TrimMessage string fields into a dstalk_message_t via host->strdup. // On OOM, frees already-allocated fields and returns -1. static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) { memset(dst, 0, sizeof(dstalk_message_t)); if (!src.role.empty()) { dst->role = g_host->strdup(src.role.c_str()); if (!dst->role) goto oom; } if (!src.content.empty()) { dst->content = g_host->strdup(src.content.c_str()); if (!dst->content) goto oom; } if (!src.tool_call_id.empty()) { dst->tool_call_id = g_host->strdup(src.tool_call_id.c_str()); if (!dst->tool_call_id) goto oom; } if (!src.tool_calls_json.empty()) { dst->tool_calls_json = g_host->strdup(src.tool_calls_json.c_str()); if (!dst->tool_calls_json) goto oom; } return 0; oom: free_msg_strs(dst); return -1; } // W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3) // 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。 // 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing // oldest user/assistant pairs. Preserves system messages. try/catch guards ABI. static int trim_impl(const dstalk_message_t* in, int in_count, dstalk_message_t** out, int* out_count, size_t max_tokens) { try { if (!in || in_count <= 0 || !out || !out_count) return -1; // W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens; // 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens; // when 0 is passed, use hardcoded default 4096. if (max_tokens == 0) max_tokens = 4096; // 将 C 数组转换为内部 vector / Convert C array to internal vector std::vector messages; messages.reserve(in_count); for (int i = 0; i < in_count; ++i) { TrimMessage tm; if (in[i].role) tm.role = in[i].role; if (in[i].content) tm.content = in[i].content; if (in[i].tool_call_id) tm.tool_call_id = in[i].tool_call_id; if (in[i].tool_calls_json) tm.tool_calls_json = in[i].tool_calls_json; messages.push_back(std::move(tm)); } // 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly size_t current = count_tokens_trim_vec(messages); if (current <= max_tokens) { *out_count = in_count; *out = static_cast(g_host->alloc(sizeof(dstalk_message_t) * in_count)); if (!*out) return -1; // W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM for (int i = 0; i < in_count; ++i) { if (strdup_message_fields(&(*out)[i], messages[i]) != 0) { for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]); g_host->free(*out); *out = nullptr; return -1; } } return 0; } // 分离 system 消息和非 system 消息 / Separate system messages from non-system messages std::vector system_msgs; std::vector non_system_msgs; for (const auto& msg : messages) { if (msg.role == "system") { system_msgs.push_back(msg); } else { non_system_msgs.push_back(msg); } } size_t system_tokens = count_tokens_trim_vec(system_msgs); if (system_tokens > max_tokens) { std::fprintf(stderr, "[context] WARNING: system messages alone " "(%zu tokens) exceed max_context_tokens (%zu)\n", system_tokens, max_tokens); } // 检查是否有单条消息超过限制 / Check if any single message exceeds the limit for (const auto& msg : non_system_msgs) { size_t msg_tokens = count_tokens_trim(msg); if (msg_tokens > max_tokens) { std::fprintf(stderr, "[context] WARNING: single message " "(%s, %zu tokens) exceeds max_context_tokens (%zu). " "Returning empty list.\n", msg.role.c_str(), msg_tokens, max_tokens); *out = nullptr; *out_count = 0; return -1; } } // 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together while (!non_system_msgs.empty()) { current = system_tokens + count_tokens_trim_vec(non_system_msgs); if (current <= max_tokens) break; // 找第一个 "user" 消息 / Find first "user" message auto user_it = non_system_msgs.begin(); while (user_it != non_system_msgs.end() && user_it->role != "user") { ++user_it; } if (user_it == non_system_msgs.end()) break; // 找下一个 "assistant" / Find next "assistant" auto assistant_it = user_it + 1; while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") { ++assistant_it; } if (assistant_it == non_system_msgs.end()) { non_system_msgs.erase(user_it); } else { // 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation non_system_msgs.erase(assistant_it); user_it = non_system_msgs.begin(); while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it; if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it); } } // W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens { size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100) if (max_msg_count < 1) max_msg_count = 1; while (non_system_msgs.size() > max_msg_count) { non_system_msgs.erase(non_system_msgs.begin()); } } // 组装结果 / Assemble result std::vector result; result.reserve(system_msgs.size() + non_system_msgs.size()); result.insert(result.end(), system_msgs.begin(), system_msgs.end()); result.insert(result.end(), non_system_msgs.begin(), non_system_msgs.end()); int result_count = static_cast(result.size()); *out_count = result_count; *out = static_cast(g_host->alloc(sizeof(dstalk_message_t) * result_count)); if (!*out) return -1; // W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM for (int i = 0; i < result_count; ++i) { if (strdup_message_fields(&(*out)[i], result[i]) != 0) { for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]); g_host->free(*out); *out = nullptr; return -1; } } return 0; } catch (const std::exception& e) { // W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate() if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what()); return -1; } catch (...) { if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl unknown exception"); return -1; } } // ============================================================ // Context 服务 vtable 实现 / Context service vtable implementation // ============================================================ // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate() // 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input. static size_t context_count_tokens(const dstalk_message_t* msgs, int count) { try { if (!msgs || count <= 0) return 0; return count_tokens_all(msgs, count); } catch (...) { return 0; } } // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary // 裁剪消息列表以适应 max_tokens,返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array. static int context_trim(const dstalk_message_t* in, int in_count, dstalk_message_t** out, int* out_count, size_t max_tokens) { try { return trim_impl(in, in_count, out, out_count, max_tokens); } catch (...) { return -1; } } // W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。 // max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入; // 传 0 时 trim_impl 使用硬编码默认值 4096。 // g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly // by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0. static dstalk_context_service_t g_context_service = { context_count_tokens, context_trim }; // ============================================================ // 插件生命周期 / Plugin lifecycle // ============================================================ // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary // 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service. static int on_init(const dstalk_host_api_t* host) { try { g_host = host; // 查询依赖服务: session / Query dependency service: session void* raw = host->query_service("session", 1); if (!raw) { host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found"); return -1; } g_session = static_cast(raw); return host->register_service("context", 1, &g_context_service); } catch (const std::exception& e) { if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what()); return -1; } catch (...) { if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception"); return -1; } } // W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs // 插件关闭:清空指针。try/catch 保护 ABI(void 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function). static void on_shutdown() { try { g_session = nullptr; g_host = nullptr; } catch (const std::exception& e) { if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what()); g_session = nullptr; g_host = nullptr; } catch (...) { if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception"); g_session = nullptr; g_host = nullptr; } } static dstalk_plugin_info_t g_info = { "context", "1.0.0", "Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件", DSTALK_API_VERSION, {"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, on_init, on_shutdown, nullptr }; // 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host. extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { return &g_info; }