Files
dstalk/plugins/context/src/context_plugin.cpp
XiuChengWu f2da0f2ed4 Add metadata validation script and module documentation
- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references.
- Added usage instructions and exit codes for the script.
- Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules.
- Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
2026-05-31 00:00:58 +08:00

447 lines
20 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* @file context_plugin.cpp
* @brief Context plugin: token counting and context window trimming.
* 上下文插件: token 计数和上下文窗口裁剪。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-context: 上下文管理服务插件 / Context management service plugin
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <string>
#include <vector>
// ============================================================
// 全局状态 / Global state
// ============================================================
// Host API table received in on_init(); null before init and after on_shutdown().
static const dstalk_host_api_t* g_host{nullptr};
// Session service resolved in on_init() as a declared dependency; no function
// visible in this file reads it after that.
static const dstalk_session_service_t* g_session{nullptr};
// ============================================================
// 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过长编码 (overlong encoding) 检测 (F-11.1-6)
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
// ============================================================
// Estimates the token count of the UTF-8 byte range [text, text + len).
//
// Heuristic: ~4 ASCII characters per token, ~2 CJK ideographs per token and
// ~3 other characters per token. Each category is floored separately, then the
// fixed per-message `overhead` (role + separators; every caller in this file
// passes 4) is added.
//
// Only well-formed UTF-8 sequences (Unicode Table 3-7) count as one character.
// Each of the following is counted as a single "other" byte, and scanning
// resumes at the next byte:
//   * truncated sequences and invalid continuation bytes;
//   * overlong forms (C0/C1 leads, E0 80..9F, F0 80..8F);
//   * UTF-16 surrogates (ED A0..BF);
//   * code points above U+10FFFF (F4 90..BF, F5..FF leads);
//   * stray continuation bytes.
// Malformed input therefore never causes reads at or beyond `len`. Scanning
// also stops at the first NUL byte.
//
// "CJK" means any well-formed 3-byte sequence with lead byte E4..E9, i.e.
// U+4000..U+9FFF. This covers the CJK Unified Ideographs (U+4E00..U+9FFF) plus
// the neighbouring U+4000..U+4DFF range.
//
// @param text     bytes to scan; nullptr is treated as empty
// @param len      number of bytes available at `text`
// @param overhead fixed tokens added to every result
// @return estimated token count, always >= overhead
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
    if (!text || len == 0) return overhead;
    size_t ascii_chars = 0;
    size_t chinese_chars = 0;
    size_t other_chars = 0;
    // Byte at `pos`, or 0 past the end (0 is never a valid trailing byte, so an
    // out-of-range read simply makes the sequence ill-formed).
    auto byte_at = [text, len](size_t pos) -> unsigned char {
        return pos < len ? static_cast<unsigned char>(text[pos]) : 0;
    };
    size_t i = 0;
    while (i < len && text[i] != '\0') {
        const unsigned char c = static_cast<unsigned char>(text[i]);
        if (c < 0x80) {
            ascii_chars++;
            i += 1;
            continue;
        }
        // Expected sequence length and the allowed range of its second byte.
        size_t seq_len = 0;
        unsigned char lo = 0x80;
        unsigned char hi = 0xBF;
        if (c >= 0xC2 && c <= 0xDF) {
            seq_len = 2;
        } else if (c >= 0xE0 && c <= 0xEF) {
            seq_len = 3;
            if (c == 0xE0) lo = 0xA0;       // reject overlong 3-byte forms
            else if (c == 0xED) hi = 0x9F;  // reject UTF-16 surrogates
        } else if (c >= 0xF0 && c <= 0xF4) {
            seq_len = 4;
            if (c == 0xF0) lo = 0x90;       // reject overlong 4-byte forms
            else if (c == 0xF4) hi = 0x8F;  // reject code points above U+10FFFF
        }
        // seq_len == 0: continuation byte 80..BF, overlong lead C0/C1, or F5..FF.
        bool well_formed = seq_len != 0;
        if (well_formed) {
            const unsigned char second = byte_at(i + 1);
            well_formed = second >= lo && second <= hi;
            for (size_t k = 2; well_formed && k < seq_len; ++k) {
                well_formed = (byte_at(i + k) & 0xC0) == 0x80;
            }
        }
        if (!well_formed) {
            other_chars++;
            i += 1;
        } else {
            if (c >= 0xE4 && c <= 0xE9) {
                chinese_chars++;
            } else {
                other_chars++;
            }
            i += seq_len;
        }
    }
    return (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3) + overhead;
}
// ============================================================
// 消息级 token 计数 (context_count_tokens 背后的薄封装; trim_impl 改用 count_tokens_trim) / Message-level token counting (thin wrappers behind context_count_tokens; trim_impl uses count_tokens_trim instead)
// ============================================================
// Estimated token cost of one C message: its content plus the fixed 4-token
// overhead. A null content pointer costs only the overhead.
// NOTE(review): role, tool_call_id and tool_calls_json are not counted; keep
// this in step with count_tokens_trim if that ever changes.
static size_t count_tokens_one_message(const dstalk_message_t& msg) {
    constexpr size_t kOverhead = 4;
    const char* const body = msg.content;
    return body == nullptr ? kOverhead
                           : count_tokens_utf8(body, std::strlen(body), kOverhead);
}
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
size_t total = 0;
for (int i = 0; i < count; ++i) {
total += count_tokens_one_message(msgs[i]);
}
return total;
}
// ============================================================
// 内部 trim 逻辑 / Internal trim logic
// ============================================================
// Owning, std::string-based copy of one dstalk_message_t. trim_impl converts
// the incoming C array into these so it can reorder and erase messages with
// value semantics; a null C string field becomes an empty string here.
struct TrimMessage {
std::string role;             // compared against "system" / "user" / "assistant" in trim_impl
std::string content;          // the only field count_tokens_trim counts
std::string tool_call_id;
std::string tool_calls_json;
};
// Estimated token cost of one TrimMessage: content plus the fixed 4-token
// overhead. Mirrors count_tokens_one_message, so trim_impl budgets with the
// same estimate that context_count_tokens reports.
static size_t count_tokens_trim(const TrimMessage& msg) {
    const std::string& body = msg.content;
    return body.empty() ? size_t{4} : count_tokens_utf8(body.data(), body.size(), 4);
}
static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
size_t total = 0;
for (const auto& m : msgs) total += count_tokens_trim(m);
return total;
}
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
static void free_msg_strs(dstalk_message_t* msg) {
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
if (msg->tool_call_id) { g_host->free((void*)msg->tool_call_id); msg->tool_call_id = nullptr; }
if (msg->tool_calls_json) { g_host->free((void*)msg->tool_calls_json); msg->tool_calls_json = nullptr; }
}
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
// 成功返回 0OOM 时释放当前消息已分配字段并返回 -1。
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
// On OOM, frees already-allocated fields and returns -1.
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
memset(dst, 0, sizeof(dstalk_message_t));
if (!src.role.empty()) {
dst->role = g_host->strdup(src.role.c_str());
if (!dst->role) goto oom;
}
if (!src.content.empty()) {
dst->content = g_host->strdup(src.content.c_str());
if (!dst->content) goto oom;
}
if (!src.tool_call_id.empty()) {
dst->tool_call_id = g_host->strdup(src.tool_call_id.c_str());
if (!dst->tool_call_id) goto oom;
}
if (!src.tool_calls_json.empty()) {
dst->tool_calls_json = g_host->strdup(src.tool_calls_json.c_str());
if (!dst->tool_calls_json) goto oom;
}
return 0;
oom:
free_msg_strs(dst);
return -1;
}
// W12.1 修复trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
static int trim_impl(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count,
size_t max_tokens) {
try {
if (!in || in_count <= 0 || !out || !out_count) return -1;
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens
// 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
// when 0 is passed, use hardcoded default 4096.
if (max_tokens == 0) max_tokens = 4096;
// 将 C 数组转换为内部 vector / Convert C array to internal vector
std::vector<TrimMessage> messages;
messages.reserve(in_count);
for (int i = 0; i < in_count; ++i) {
TrimMessage tm;
if (in[i].role) tm.role = in[i].role;
if (in[i].content) tm.content = in[i].content;
if (in[i].tool_call_id) tm.tool_call_id = in[i].tool_call_id;
if (in[i].tool_calls_json) tm.tool_calls_json = in[i].tool_calls_json;
messages.push_back(std::move(tm));
}
// 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
size_t current = count_tokens_trim_vec(messages);
if (current <= max_tokens) {
*out_count = in_count;
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
for (int i = 0; i < in_count; ++i) {
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
g_host->free(*out);
*out = nullptr;
return -1;
}
}
return 0;
}
// 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
std::vector<TrimMessage> system_msgs;
std::vector<TrimMessage> non_system_msgs;
for (const auto& msg : messages) {
if (msg.role == "system") {
system_msgs.push_back(msg);
} else {
non_system_msgs.push_back(msg);
}
}
size_t system_tokens = count_tokens_trim_vec(system_msgs);
if (system_tokens > max_tokens) {
std::fprintf(stderr, "[context] WARNING: system messages alone "
"(%zu tokens) exceed max_context_tokens (%zu)\n",
system_tokens, max_tokens);
}
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
for (const auto& msg : non_system_msgs) {
size_t msg_tokens = count_tokens_trim(msg);
if (msg_tokens > max_tokens) {
std::fprintf(stderr, "[context] WARNING: single message "
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
"Returning empty list.\n",
msg.role.c_str(), msg_tokens, max_tokens);
*out = nullptr;
*out_count = 0;
return -1;
}
}
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
while (!non_system_msgs.empty()) {
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
if (current <= max_tokens) break;
// 找第一个 "user" 消息 / Find first "user" message
auto user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") {
++user_it;
}
if (user_it == non_system_msgs.end()) break;
// 找下一个 "assistant" / Find next "assistant"
auto assistant_it = user_it + 1;
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
++assistant_it;
}
if (assistant_it == non_system_msgs.end()) {
non_system_msgs.erase(user_it);
} else {
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
non_system_msgs.erase(assistant_it);
user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
}
}
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
{
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
if (max_msg_count < 1) max_msg_count = 1;
while (non_system_msgs.size() > max_msg_count) {
non_system_msgs.erase(non_system_msgs.begin());
}
}
// 组装结果 / Assemble result
std::vector<TrimMessage> result;
result.reserve(system_msgs.size() + non_system_msgs.size());
result.insert(result.end(), system_msgs.begin(), system_msgs.end());
result.insert(result.end(), non_system_msgs.begin(), non_system_msgs.end());
int result_count = static_cast<int>(result.size());
*out_count = result_count;
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
for (int i = 0; i < result_count; ++i) {
if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
g_host->free(*out);
*out = nullptr;
return -1;
}
}
return 0;
} catch (const std::exception& e) {
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
return -1;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl unknown exception");
return -1;
}
}
// ============================================================
// Context 服务 vtable 实现 / Context service vtable implementation
// ============================================================
// W12.1: try/catch keeps C++ exceptions from crossing the C ABI (std::terminate()).
// Service entry: estimated tokens across msgs[0 .. count). Returns 0 for null
// `msgs`, for a non-positive `count`, or when any exception is caught.
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
    size_t total = 0;
    try {
        if (msgs != nullptr && count > 0) {
            total = count_tokens_all(msgs, count);
        }
    } catch (...) {
        total = 0;
    }
    return total;
}
// W12.1: try/catch keeps C++ exceptions from crossing the C ABI boundary.
// Service entry: trims `in` to fit `max_tokens` and returns a newly
// host-allocated array through *out / *out_count. trim_impl treats 0 as its
// 4096 default. Returns trim_impl's result (0 on success, -1 on failure), or
// -1 if anything escapes trim_impl's own handlers.
static int context_trim(const dstalk_message_t* in, int in_count,
                        dstalk_message_t** out, int* out_count,
                        size_t max_tokens) {
    int rc = -1;
    try {
        rc = trim_impl(in, in_count, out, out_count, max_tokens);
    } catch (...) {
        rc = -1;
    }
    return rc;
}
// Context service vtable, registered by on_init() under the name "context",
// version 1. Entries are positional: token counting first, then trimming.
// This presumably matches the field order of dstalk_context_service_t, which
// is declared elsewhere — confirm when that header changes.
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens were removed. Callers
// pass max_tokens directly to trim(), and trim_impl falls back to a hard-coded
// default of 4096 when it is 0.
static dstalk_context_service_t g_context_service = {
context_count_tokens,
context_trim
};
// ============================================================
// 插件生命周期 / Plugin lifecycle
// ============================================================
// W12.1: try/catch keeps C++ exceptions from crossing the C ABI boundary.
// Plugin init, in order:
//   1. store the host API table;
//   2. resolve the required "session" service (version 1);
//   3. register this plugin's "context" service (version 1).
// Returns -1 for a null `host`, a missing session service or a caught
// exception; otherwise returns host->register_service's result.
static int on_init(const dstalk_host_api_t* host) {
    // No host means nothing to log through; bail out instead of dereferencing null.
    if (!host) return -1;
    try {
        g_host = host;
        // Resolve the declared dependency before exposing our own service.
        void* raw = host->query_service("session", 1);
        if (!raw) {
            host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
            return -1;
        }
        g_session = static_cast<const dstalk_session_service_t*>(raw);
        return host->register_service("context", 1, &g_context_service);
    } catch (const std::exception& e) {
        if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
        return -1;
    } catch (...) {
        if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
        return -1;
    }
}
// W16.2: a void ABI entry point guarded by try/catch; failures can only be logged.
// Plugin shutdown: forgets the cached session and host pointers. The pointer
// resets cannot throw. The handlers exist only to keep the ABI-guard pattern,
// and they still clear both pointers.
static void on_shutdown() {
    const auto forget_pointers = []() {
        g_session = nullptr;
        g_host = nullptr;
    };
    try {
        forget_pointers();
    } catch (const std::exception& e) {
        if (g_host != nullptr) {
            g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
        }
        forget_pointers();
    } catch (...) {
        if (g_host != nullptr) {
            g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
        }
        forget_pointers();
    }
}
// Static plugin descriptor returned by dstalk_plugin_init(). Fields are
// positional, presumably in this order:
//   * name ("context"), version ("1.0.0") and description;
//   * the API version this plugin was built against;
//   * the dependency list: "session" followed by nullptr padding;
//   * the on_init / on_shutdown callbacks;
//   * a trailing field left nullptr.
// The exact meaning of each slot is defined by dstalk_plugin_info_t, which is
// not visible here — confirm against that header.
static dstalk_plugin_info_t g_info = {
"context",
"1.0.0",
"Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
DSTALK_API_VERSION,
{"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
on_init,
on_shutdown,
nullptr
};
// Mandatory exported C entry point: returns the plugin descriptor g_info to
// the host. g_info has static storage duration, so the returned pointer stays
// valid for as long as this plugin library remains loaded.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info;
}