- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references. - Added usage instructions and exit codes for the script. - Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules. - Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
447 lines
20 KiB
C++
447 lines
20 KiB
C++
/*
 * @file context_plugin.cpp
 * @brief Context plugin: token counting and context window trimming.
 * 上下文插件:token 计数和上下文窗口裁剪。
 * Copyright (c) 2026 dstalk contributors. GPLv3.
 */
|
||
|
||
// plugin-context: 上下文管理服务插件 / Context management service plugin
|
||
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
|
||
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
|
||
#include "dstalk/dstalk_host.h"
|
||
#include "dstalk/dstalk_types.h"
|
||
#include "dstalk/dstalk_services.h"
|
||
|
||
#include <algorithm>
|
||
#include <cstddef>
|
||
#include <cstdint>
|
||
#include <cstdio>
|
||
#include <cstring>
|
||
#include <exception>
|
||
#include <string>
|
||
#include <vector>
|
||
|
||
// ============================================================
|
||
// 全局状态 / Global state
|
||
// ============================================================
|
||
|
||
static const dstalk_host_api_t* g_host = nullptr;
|
||
static const dstalk_session_service_t* g_session = nullptr;
|
||
|
||
// ============================================================
|
||
// 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
|
||
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
|
||
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
|
||
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6)
|
||
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
|
||
// ============================================================
|
||
|
||
// Estimate the token count of the UTF-8 byte range [text, text + len).
//
// Scanning stops at `len` bytes or at the first NUL byte, whichever comes
// first. Every decoded unit falls into one of three buckets:
//   - ASCII bytes (< 0x80);
//   - well-formed 3-byte sequences whose lead byte is 0xE4-0xE9, used here as
//     an approximation of the CJK Unified Ideographs range;
//   - everything else: other well-formed 2/3/4-byte sequences, plus any byte
//     that cannot start or complete a sequence (stray continuation bytes,
//     0xC0/0xC1, 0xF8-0xFF, truncated or malformed sequences). A bad lead byte
//     is consumed on its own, so the scan never reads past `len`.
//
// The estimate is ascii/4 + cjk/2 + other/3 (integer division) + overhead.
//
// @param text     Bytes to scan; may be null.
// @param len      Number of readable bytes at `text`.
// @param overhead Fixed per-message token cost added to the estimate.
// @return `overhead` alone when `text` is null or `len` is 0, otherwise the
//         estimate described above.
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
    if (text == nullptr || len == 0) return overhead;

    size_t ascii_count = 0;
    size_t cjk_count = 0;
    size_t other_count = 0;

    size_t pos = 0;
    while (pos < len) {
        const unsigned char lead = static_cast<unsigned char>(text[pos]);
        if (lead == 0) break;  // embedded NUL terminates the scan

        if (lead < 0x80) {
            ++ascii_count;
            ++pos;
            continue;
        }

        // Sequence length announced by the lead byte. Length 1 marks a byte
        // that cannot start a sequence: continuation bytes (0x80-0xBF), the
        // overlong starters 0xC0/0xC1, and 0xF8-0xFF.
        size_t seq_len = 1;
        if (lead >= 0xC2 && lead <= 0xDF) {
            seq_len = 2;
        } else if (lead >= 0xE0 && lead <= 0xEF) {
            seq_len = 3;
        } else if (lead >= 0xF0 && lead <= 0xF7) {
            seq_len = 4;
        }

        // The sequence must fit inside the buffer and every trailing byte
        // must look like 10xxxxxx. The bounds test runs first so no byte
        // beyond `len` is ever read.
        bool well_formed = (seq_len > 1) && (len - pos >= seq_len);
        for (size_t k = 1; well_formed && k < seq_len; ++k) {
            well_formed = (static_cast<unsigned char>(text[pos + k]) & 0xC0) == 0x80;
        }

        if (!well_formed) {
            // Consume only the offending byte and resynchronise on the next.
            ++other_count;
            ++pos;
            continue;
        }

        if (lead >= 0xE4 && lead <= 0xE9) {
            ++cjk_count;
        } else {
            ++other_count;
        }
        pos += seq_len;
    }

    return (ascii_count / 4) + (cjk_count / 2) + (other_count / 3) + overhead;
}
|
||
|
||
// ============================================================
|
||
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl)
|
||
// ============================================================
|
||
|
||
// 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct.
|
||
static size_t count_tokens_one_message(const dstalk_message_t& msg) {
|
||
const char* text = msg.content;
|
||
if (!text) return 4; // 只有 overhead / overhead only
|
||
return count_tokens_utf8(text, std::strlen(text), 4);
|
||
}
|
||
|
||
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
|
||
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
|
||
size_t total = 0;
|
||
for (int i = 0; i < count; ++i) {
|
||
total += count_tokens_one_message(msgs[i]);
|
||
}
|
||
return total;
|
||
}
|
||
|
||
// ============================================================
|
||
// 内部 trim 逻辑 / Internal trim logic
|
||
// ============================================================
|
||
|
||
// Owning, value-semantic copy of one dstalk_message_t used while trimming.
// Each member mirrors the C struct field of the same name; a null C string
// on the input side is represented here by an empty std::string.
struct TrimMessage {
    // Copy of dstalk_message_t::role; trim_impl compares it against
    // "system", "user" and "assistant".
    std::string role;

    // Copy of dstalk_message_t::content; the only field that contributes to
    // the token estimate.
    std::string content;

    // Copy of dstalk_message_t::tool_call_id; carried through unchanged.
    std::string tool_call_id;

    // Copy of dstalk_message_t::tool_calls_json; carried through unchanged.
    std::string tool_calls_json;
};
|
||
|
||
static size_t count_tokens_trim(const TrimMessage& msg) {
|
||
if (msg.content.empty()) return 4;
|
||
return count_tokens_utf8(msg.content.c_str(), msg.content.size(), 4);
|
||
}
|
||
|
||
static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
|
||
size_t total = 0;
|
||
for (const auto& m : msgs) total += count_tokens_trim(m);
|
||
return total;
|
||
}
|
||
|
||
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
|
||
static void free_msg_strs(dstalk_message_t* msg) {
|
||
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
|
||
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
|
||
if (msg->tool_call_id) { g_host->free((void*)msg->tool_call_id); msg->tool_call_id = nullptr; }
|
||
if (msg->tool_calls_json) { g_host->free((void*)msg->tool_calls_json); msg->tool_calls_json = nullptr; }
|
||
}
|
||
|
||
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
|
||
// 成功返回 0;OOM 时释放当前消息已分配字段并返回 -1。
|
||
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
|
||
// On OOM, frees already-allocated fields and returns -1.
|
||
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
|
||
memset(dst, 0, sizeof(dstalk_message_t));
|
||
|
||
if (!src.role.empty()) {
|
||
dst->role = g_host->strdup(src.role.c_str());
|
||
if (!dst->role) goto oom;
|
||
}
|
||
if (!src.content.empty()) {
|
||
dst->content = g_host->strdup(src.content.c_str());
|
||
if (!dst->content) goto oom;
|
||
}
|
||
if (!src.tool_call_id.empty()) {
|
||
dst->tool_call_id = g_host->strdup(src.tool_call_id.c_str());
|
||
if (!dst->tool_call_id) goto oom;
|
||
}
|
||
if (!src.tool_calls_json.empty()) {
|
||
dst->tool_calls_json = g_host->strdup(src.tool_calls_json.c_str());
|
||
if (!dst->tool_calls_json) goto oom;
|
||
}
|
||
return 0;
|
||
|
||
oom:
|
||
free_msg_strs(dst);
|
||
return -1;
|
||
}
|
||
|
||
// W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
|
||
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
|
||
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
|
||
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
|
||
static int trim_impl(const dstalk_message_t* in, int in_count,
|
||
dstalk_message_t** out, int* out_count,
|
||
size_t max_tokens) {
|
||
try {
|
||
if (!in || in_count <= 0 || !out || !out_count) return -1;
|
||
|
||
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens;
|
||
// 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
|
||
// when 0 is passed, use hardcoded default 4096.
|
||
if (max_tokens == 0) max_tokens = 4096;
|
||
|
||
// 将 C 数组转换为内部 vector / Convert C array to internal vector
|
||
std::vector<TrimMessage> messages;
|
||
messages.reserve(in_count);
|
||
for (int i = 0; i < in_count; ++i) {
|
||
TrimMessage tm;
|
||
if (in[i].role) tm.role = in[i].role;
|
||
if (in[i].content) tm.content = in[i].content;
|
||
if (in[i].tool_call_id) tm.tool_call_id = in[i].tool_call_id;
|
||
if (in[i].tool_calls_json) tm.tool_calls_json = in[i].tool_calls_json;
|
||
messages.push_back(std::move(tm));
|
||
}
|
||
|
||
// 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
|
||
size_t current = count_tokens_trim_vec(messages);
|
||
if (current <= max_tokens) {
|
||
*out_count = in_count;
|
||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
|
||
if (!*out) return -1;
|
||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
|
||
for (int i = 0; i < in_count; ++i) {
|
||
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
|
||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||
g_host->free(*out);
|
||
*out = nullptr;
|
||
return -1;
|
||
}
|
||
}
|
||
return 0;
|
||
}
|
||
|
||
// 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
|
||
std::vector<TrimMessage> system_msgs;
|
||
std::vector<TrimMessage> non_system_msgs;
|
||
for (const auto& msg : messages) {
|
||
if (msg.role == "system") {
|
||
system_msgs.push_back(msg);
|
||
} else {
|
||
non_system_msgs.push_back(msg);
|
||
}
|
||
}
|
||
|
||
size_t system_tokens = count_tokens_trim_vec(system_msgs);
|
||
if (system_tokens > max_tokens) {
|
||
std::fprintf(stderr, "[context] WARNING: system messages alone "
|
||
"(%zu tokens) exceed max_context_tokens (%zu)\n",
|
||
system_tokens, max_tokens);
|
||
}
|
||
|
||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||
for (const auto& msg : non_system_msgs) {
|
||
size_t msg_tokens = count_tokens_trim(msg);
|
||
if (msg_tokens > max_tokens) {
|
||
std::fprintf(stderr, "[context] WARNING: single message "
|
||
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
|
||
"Returning empty list.\n",
|
||
msg.role.c_str(), msg_tokens, max_tokens);
|
||
*out = nullptr;
|
||
*out_count = 0;
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
|
||
while (!non_system_msgs.empty()) {
|
||
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
|
||
if (current <= max_tokens) break;
|
||
|
||
// 找第一个 "user" 消息 / Find first "user" message
|
||
auto user_it = non_system_msgs.begin();
|
||
while (user_it != non_system_msgs.end() && user_it->role != "user") {
|
||
++user_it;
|
||
}
|
||
if (user_it == non_system_msgs.end()) break;
|
||
|
||
// 找下一个 "assistant" / Find next "assistant"
|
||
auto assistant_it = user_it + 1;
|
||
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
|
||
++assistant_it;
|
||
}
|
||
|
||
if (assistant_it == non_system_msgs.end()) {
|
||
non_system_msgs.erase(user_it);
|
||
} else {
|
||
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
|
||
non_system_msgs.erase(assistant_it);
|
||
user_it = non_system_msgs.begin();
|
||
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
|
||
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
|
||
}
|
||
}
|
||
|
||
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
|
||
{
|
||
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
|
||
if (max_msg_count < 1) max_msg_count = 1;
|
||
while (non_system_msgs.size() > max_msg_count) {
|
||
non_system_msgs.erase(non_system_msgs.begin());
|
||
}
|
||
}
|
||
|
||
// 组装结果 / Assemble result
|
||
std::vector<TrimMessage> result;
|
||
result.reserve(system_msgs.size() + non_system_msgs.size());
|
||
result.insert(result.end(), system_msgs.begin(), system_msgs.end());
|
||
result.insert(result.end(), non_system_msgs.begin(), non_system_msgs.end());
|
||
|
||
int result_count = static_cast<int>(result.size());
|
||
*out_count = result_count;
|
||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
|
||
if (!*out) return -1;
|
||
|
||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
|
||
for (int i = 0; i < result_count; ++i) {
|
||
if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
|
||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||
g_host->free(*out);
|
||
*out = nullptr;
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
return 0;
|
||
} catch (const std::exception& e) {
|
||
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
|
||
return -1;
|
||
} catch (...) {
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl unknown exception");
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
// ============================================================
|
||
// Context 服务 vtable 实现 / Context service vtable implementation
|
||
// ============================================================
|
||
|
||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate()
|
||
// 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input.
|
||
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
|
||
try {
|
||
if (!msgs || count <= 0) return 0;
|
||
return count_tokens_all(msgs, count);
|
||
} catch (...) {
|
||
return 0;
|
||
}
|
||
}
|
||
|
||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||
// 裁剪消息列表以适应 max_tokens,返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array.
|
||
static int context_trim(const dstalk_message_t* in, int in_count,
|
||
dstalk_message_t** out, int* out_count,
|
||
size_t max_tokens) {
|
||
try {
|
||
return trim_impl(in, in_count, out, out_count, max_tokens);
|
||
} catch (...) {
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。
|
||
// max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入;
|
||
// 传 0 时 trim_impl 使用硬编码默认值 4096。
|
||
// g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly
|
||
// by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0.
|
||
static dstalk_context_service_t g_context_service = {
|
||
context_count_tokens,
|
||
context_trim
|
||
};
|
||
|
||
// ============================================================
|
||
// 插件生命周期 / Plugin lifecycle
|
||
// ============================================================
|
||
|
||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||
// 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service.
|
||
static int on_init(const dstalk_host_api_t* host) {
|
||
try {
|
||
g_host = host;
|
||
|
||
// 查询依赖服务: session / Query dependency service: session
|
||
void* raw = host->query_service("session", 1);
|
||
if (!raw) {
|
||
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
|
||
return -1;
|
||
}
|
||
g_session = static_cast<const dstalk_session_service_t*>(raw);
|
||
|
||
return host->register_service("context", 1, &g_context_service);
|
||
} catch (const std::exception& e) {
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
|
||
return -1;
|
||
} catch (...) {
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
|
||
return -1;
|
||
}
|
||
}
|
||
|
||
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs
|
||
// 插件关闭:清空指针。try/catch 保护 ABI(void 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function).
|
||
static void on_shutdown() {
|
||
try {
|
||
g_session = nullptr;
|
||
g_host = nullptr;
|
||
} catch (const std::exception& e) {
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
|
||
g_session = nullptr;
|
||
g_host = nullptr;
|
||
} catch (...) {
|
||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
|
||
g_session = nullptr;
|
||
g_host = nullptr;
|
||
}
|
||
}
|
||
|
||
// Plugin descriptor returned to the host by dstalk_plugin_init().
// NOTE(review): dstalk_plugin_info_t is declared in the dstalk headers, not in
// this file; the field roles noted below are inferred from the initializer
// values and should be confirmed against that declaration.
static dstalk_plugin_info_t g_info = {
    // presumably the plugin name
    "context",
    // presumably the plugin version
    "1.0.0",
    // presumably a human-readable description
    "Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
    // API version constant supplied by the dstalk headers
    DSTALK_API_VERSION,
    // presumably the dependency list: "session" followed by unused slots
    {"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
    // initialisation callback defined in this file
    on_init,
    // shutdown callback defined in this file
    on_shutdown,
    // final slot left unset; purpose not visible here -- TODO confirm
    nullptr
};
|
||
|
||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||
return &g_info;
|
||
}
|