- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references. - Added usage instructions and exit codes for the script. - Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules. - Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
452 lines
19 KiB
C++
452 lines
19 KiB
C++
/*
 * @file context_plugin_test.cpp
 * @brief Context plugin unit tests: token counting (ASCII, CJK, mixed, emoji),
 *        UTF-8 truncation safety, trim edge cases, and system message preservation.
 *        Context 插件单元测试:token 计数(ASCII、CJK、混合、emoji)、UTF-8 截断安全、trim 边界情况、系统消息保留。
 * Copyright (c) 2026 dstalk contributors. GPLv3.
 */
|
||
|
||
#include "dstalk/dstalk_host.h"

#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
|
||
|
||
// Number of CHECK() assertions that have failed so far in this test binary.
static int g_failures = 0;

// Lightweight assertion macro.
//   cond - expression evaluated exactly once and tested for truth.
//   msg  - streamable description printed with the verdict.
// On success prints "[OK] <msg>" to stdout; on failure prints "[FAIL] <msg>"
// to stderr and increments g_failures. Execution continues in both cases so a
// single run reports every failing check.
// Wrapped in do { ... } while (0) so that CHECK(...); behaves as one statement,
// including inside an unbraced if/else.
#define CHECK(cond, msg)                              \
    do {                                              \
        if (cond) {                                   \
            std::cout << "[OK] " << (msg) << "\n";    \
        } else {                                      \
            std::cerr << "[FAIL] " << (msg) << "\n";  \
            ++g_failures;                             \
        }                                             \
    } while (0)
|
||
|
||
// Context 插件测试:token 计数边界(null、空、ASCII、CJK、混合)、截断 UTF-8 边界保护 (F-11.1-4)、
|
||
// 0xC0/0xC1 超长编码 (F-11.1-6)、多消息 token、trim 的各种场景、系统消息保留、4 字节 emoji、孤立的续字节。
|
||
// Context plugin tests: token counting edge cases (null, empty, ASCII, CJK, mixed),
|
||
// truncated UTF-8 bounds protection (F-11.1-4), 0xC0/0xC1 overlong encoding (F-11.1-6),
|
||
// multiple-message tokens, trim null/edge/within-limit/exceeds-limit scenarios,
|
||
// system message preservation, 4-byte emoji, and lone continuation bytes.
|
||
int main()
|
||
{
|
||
const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test";
|
||
std::filesystem::create_directories(dir);
|
||
|
||
const auto config_path = dir / "config.toml";
|
||
{
|
||
std::ofstream config(config_path);
|
||
config << "[api]\n"
|
||
<< "provider = \"deepseek\"\n"
|
||
<< "base_url = \"https://api.deepseek.com/v1\"\n"
|
||
<< "api_key = \"test-key\"\n"
|
||
<< "model = \"deepseek-v4-pro\"\n";
|
||
}
|
||
|
||
if (dstalk_init(config_path.string().c_str()) != 0) {
|
||
std::cerr << "dstalk_init failed\n";
|
||
return 1;
|
||
}
|
||
|
||
auto* ctx = static_cast<const dstalk_context_service_t*>(
|
||
dstalk_service_query("context", 1));
|
||
if (!ctx) {
|
||
std::cerr << "context service not found\n";
|
||
dstalk_shutdown();
|
||
return 1;
|
||
}
|
||
std::cout << "[OK] context service found\n";
|
||
|
||
// ================================================================
|
||
// Test Block 1: count_tokens edge cases (null / empty)
|
||
// 测试块 1:count_tokens 边界情况(null / 空)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 1: count_tokens edge cases ---\n";
|
||
|
||
size_t tokens = ctx->count_tokens(nullptr, 0);
|
||
CHECK(tokens == 0, "T1.1: count_tokens(nullptr, 0) == 0");
|
||
|
||
tokens = ctx->count_tokens(nullptr, 5);
|
||
CHECK(tokens == 0, "T1.2: count_tokens(nullptr, 5) == 0");
|
||
|
||
{
|
||
dstalk_message_t empty_msg = {nullptr, nullptr, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&empty_msg, 1);
|
||
CHECK(tokens == 4, "T1.3: null-content message == 4 (overhead only)");
|
||
}
|
||
|
||
{
|
||
dstalk_message_t empty_str_msg = {"user", "", nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&empty_str_msg, 1);
|
||
CHECK(tokens == 4, "T1.4: empty-string content == 4 (overhead only)");
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 2: count_tokens — ASCII
|
||
// 测试块 2:count_tokens — ASCII
|
||
// ================================================================
|
||
std::cout << "\n--- Block 2: count_tokens ASCII ---\n";
|
||
|
||
{
|
||
dstalk_message_t msg = {"user", "Hello World", nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
// 11 ascii chars / 4 = 2 + 4 overhead = 6
|
||
CHECK(tokens == 6, "T2.1: 'Hello World' (11 ASCII) == 6 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
dstalk_message_t msg = {"user", "abcd", nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
// 4 ascii chars / 4 = 1 + 4 overhead = 5
|
||
CHECK(tokens == 5, "T2.2: 'abcd' (4 ASCII) == 5 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
dstalk_message_t msg = {"user",
|
||
"This is a longer ASCII sentence for testing token counts",
|
||
nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T2.3: long ASCII sentence returns valid count");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF)
|
||
// 测试块 3:count_tokens — 中文 (CJK U+4E00-U+9FFF)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n";
|
||
|
||
{
|
||
// 中文 = U+4E2D U+6587 = E4 B8 AD E6 96 87 (2 CJK chars)
|
||
dstalk_message_t msg = {"user",
|
||
"\xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
// 2 chinese / 2 = 1 + 4 overhead = 5
|
||
CHECK(tokens == 5, "T3.1: 2 Chinese chars == 5 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 你好世界 = 4 CJK chars = 4/2 + 4 = 6
|
||
dstalk_message_t msg = {"user",
|
||
"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c",
|
||
nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens == 6, "T3.2: 4 Chinese chars == 6 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 4: count_tokens — Mixed content
|
||
// 测试块 4:count_tokens — 混合内容
|
||
// ================================================================
|
||
std::cout << "\n--- Block 4: count_tokens mixed content ---\n";
|
||
|
||
{
|
||
// "Hi 中文" = 3 ASCII + 2 CJK = 3/4 + 2/2 + 4 = 0+1+4 = 5
|
||
dstalk_message_t msg = {"user",
|
||
"Hi \xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens == 5, "T4.1: 'Hi ' + 2 CJK == 5 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4)
|
||
// 测试块 5:截断 UTF-8 边界保护 (F-11.1-4)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n";
|
||
|
||
{
|
||
// Lone 0xE4 (3-byte sequence lead byte alone)
|
||
char buf[3] = {static_cast<char>(0xE4), 'A', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T5.1: lone 0xE4 does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 0xE4 + 0x80 (3-byte missing last continuation byte)
|
||
char buf[4] = {static_cast<char>(0xE4), static_cast<char>(0x80),
|
||
'B', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T5.2: 0xE4 0x80 (2/3 bytes) does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// Lone 0xF0 (4-byte sequence lead byte alone)
|
||
char buf[3] = {static_cast<char>(0xF0), 'X', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T5.3: lone 0xF0 (4-byte lead) does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 0xC2 alone (2-byte sequence missing continuation byte)
|
||
char buf[3] = {static_cast<char>(0xC2), 'Y', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T5.4: 0xC2 alone (missing cont.) does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 2-byte lead + invalid continuation (0x00 instead of 0x80-0xBF)
|
||
char buf[4] = {static_cast<char>(0xC3), '\x00', 'Z', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T5.5: invalid continuation byte does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6)
|
||
// 测试块 6:0xC0/0xC1 超长编码 (F-11.1-6)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n";
|
||
|
||
{
|
||
// 0xC0 0x80 = overlong encoding of NUL (U+0000)
|
||
char buf[4] = {static_cast<char>(0xC0), static_cast<char>(0x80), '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T6.1: 0xC0 0x80 overlong does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 0xC1 0xBF = overlong encoding of U+007F
|
||
char buf[4] = {static_cast<char>(0xC1), static_cast<char>(0xBF), '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T6.2: 0xC1 0xBF overlong does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// 0xC0 alone (overlong lead without continuation)
|
||
char buf[3] = {static_cast<char>(0xC0), 'Q', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T6.3: lone 0xC0 does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
// Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences
|
||
// They should each count as 1 other_char, not as 2-byte sequence
|
||
// 验证 0xC0/0xC1 不被视为合法的 2 字节序列 / 它们每个应计为 1 个 other_char,而非 2 字节序列
|
||
// 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii
|
||
// = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4
|
||
// Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other
|
||
// So 0 + 0 + 4 = 4 tokens
|
||
char buf[6] = {static_cast<char>(0xC0), static_cast<char>(0xC1),
|
||
'a', 'b', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens == 4, "T6.4: 0xC0+0xC1+2 ascii token count as expected");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 7: count_tokens — multiple messages
|
||
// 测试块 7:count_tokens — 多消息
|
||
// ================================================================
|
||
std::cout << "\n--- Block 7: multiple messages ---\n";
|
||
|
||
{
|
||
dstalk_message_t msgs[3] = {
|
||
{"system", "You are helpful", nullptr, nullptr},
|
||
{"user", "Hello", nullptr, nullptr},
|
||
{"assistant", "Hi there", nullptr, nullptr}
|
||
};
|
||
tokens = ctx->count_tokens(msgs, 3);
|
||
// system: 15 ascii /4 = 3 + 4 = 7
|
||
// user: 5 ascii /4 = 1 + 4 = 5
|
||
// assistant: 8 ascii /4 = 2 + 4 = 6
|
||
// total = 7+5+6 = 18
|
||
CHECK(tokens == 18, "T7.1: 3 messages token count == 18");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
{
|
||
dstalk_message_t msgs[2] = {
|
||
{"user", "hi", nullptr, nullptr},
|
||
{"assistant", "ok", nullptr, nullptr}
|
||
};
|
||
tokens = ctx->count_tokens(msgs, 2);
|
||
// 2/4 + 4 + 2/4 + 4 = 0+4+0+4 = 8
|
||
CHECK(tokens == 8, "T7.2: 2 short messages == 8 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 8: trim — null and edge cases
|
||
// 测试块 8:trim — null 和边界情况
|
||
// ================================================================
|
||
std::cout << "\n--- Block 8: trim edge cases ---\n";
|
||
|
||
{
|
||
dstalk_message_t* out = nullptr;
|
||
int out_count = 0;
|
||
|
||
int ret = ctx->trim(nullptr, 0, &out, &out_count, 100);
|
||
CHECK(ret == -1, "T8.1: trim(nullptr, 0) returns -1");
|
||
|
||
ret = ctx->trim(nullptr, 0, nullptr, nullptr, 100);
|
||
CHECK(ret == -1, "T8.2: trim with null output pointers returns -1");
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 9: trim — within limit (no trimming needed)
|
||
// 测试块 9:trim — 预算内(无需裁剪)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 9: trim within limit ---\n";
|
||
|
||
{
|
||
dstalk_message_t msgs[2] = {
|
||
{"user", "hi", nullptr, nullptr},
|
||
{"assistant", "hello", nullptr, nullptr}
|
||
};
|
||
dstalk_message_t* out = nullptr;
|
||
int out_count = 0;
|
||
|
||
int ret = ctx->trim(msgs, 2, &out, &out_count, 4096);
|
||
CHECK(ret == 0, "T9.1: trim within limit returns 0");
|
||
CHECK(out != nullptr, "T9.2: trim allocates output");
|
||
CHECK(out_count == 2, "T9.3: trim preserves message count");
|
||
|
||
if (out && out_count >= 2) {
|
||
CHECK(out[0].role && std::strcmp(out[0].role, "user") == 0,
|
||
"T9.4: first message role preserved");
|
||
CHECK(out[0].content && std::strcmp(out[0].content, "hi") == 0,
|
||
"T9.5: first message content preserved");
|
||
dstalk_free(out);
|
||
} else if (out) {
|
||
dstalk_free(out);
|
||
}
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 10: trim — exceeds limit (trimming required)
|
||
// 测试块 10:trim — 超预算(需要裁剪)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 10: trim exceeds limit ---\n";
|
||
|
||
{
|
||
// 4 long messages, each ~70 chars (~18 ASCII tokens + 4 overhead = 22),
|
||
// total ~88 tokens > 30 limit
|
||
dstalk_message_t msgs[4] = {
|
||
{"user",
|
||
"This is a long message with enough text to consume many tokens",
|
||
nullptr, nullptr},
|
||
{"assistant",
|
||
"Another long response that also uses up tokens with lots of words",
|
||
nullptr, nullptr},
|
||
{"user",
|
||
"A third long message pushing us well over the token budget limit",
|
||
nullptr, nullptr},
|
||
{"assistant",
|
||
"The fourth long message will cause us to exceed the max budget",
|
||
nullptr, nullptr}
|
||
};
|
||
dstalk_message_t* out = nullptr;
|
||
int out_count = 0;
|
||
|
||
int ret = ctx->trim(msgs, 4, &out, &out_count, 30);
|
||
// trim may return -1 if a single message exceeds limit, or 0 with reduced count
|
||
if (ret == 0 && out) {
|
||
CHECK(out_count <= 4, "T10.1: trim output count <= input count");
|
||
std::cout << " output count = " << out_count << " (in=4, limit=30)\n";
|
||
dstalk_free(out);
|
||
} else {
|
||
// Single message exceeds limit => returns -1 with empty output
|
||
std::cout << " trim returned " << ret << " (single msg > limit path)\n";
|
||
CHECK(ret == -1, "T10.2: expected ret=-1 (single msg exceeds 30 tokens)");
|
||
}
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 11: trim — system message preservation
|
||
// 测试块 11:trim — 系统消息保留
|
||
// ================================================================
|
||
std::cout << "\n--- Block 11: trim preserves system messages ---\n";
|
||
|
||
{
|
||
dstalk_message_t msgs[3] = {
|
||
{"system", "You are a helpful assistant", nullptr, nullptr},
|
||
{"user",
|
||
"Hello this is a very long user message that will push us over the token budget",
|
||
nullptr, nullptr},
|
||
{"assistant",
|
||
"I am a very long assistant response designed to consume tokens for testing",
|
||
nullptr, nullptr}
|
||
};
|
||
dstalk_message_t* out = nullptr;
|
||
int out_count = 0;
|
||
|
||
int ret = ctx->trim(msgs, 3, &out, &out_count, 25);
|
||
if (ret >= 0 && out && out_count > 0) {
|
||
CHECK(out[0].role && std::strcmp(out[0].role, "system") == 0,
|
||
"T11.1: system message preserved as first in output");
|
||
std::cout << " output count = " << out_count << "\n";
|
||
dstalk_free(out);
|
||
} else if (out) {
|
||
dstalk_free(out);
|
||
}
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary)
|
||
// 测试块 12:count_tokens — 4 字节 UTF-8(emoji / 补充平面)
|
||
// ================================================================
|
||
std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n";
|
||
|
||
{
|
||
// U+1F600 (<28><>) = F0 9F 98 80
|
||
char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F),
|
||
static_cast<char>(0x98), static_cast<char>(0x80), '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
// 1 other_char / 3 + 4 overhead = 0 + 4 = 4
|
||
CHECK(tokens == 4, "T12.1: single 4-byte char (emoji) == 4 tokens");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
// ================================================================
|
||
// Test Block 13: count_tokens — continuation bytes as lone chars
|
||
// 测试块 13:count_tokens — 孤立的续字节
|
||
// ================================================================
|
||
std::cout << "\n--- Block 13: lone continuation bytes ---\n";
|
||
|
||
{
|
||
// 0x80 alone (continuation byte without lead byte)
|
||
char buf[3] = {static_cast<char>(0x80), 'A', '\0'};
|
||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||
tokens = ctx->count_tokens(&msg, 1);
|
||
CHECK(tokens >= 4, "T13.1: lone continuation byte does not crash");
|
||
std::cout << " tokens = " << tokens << "\n";
|
||
}
|
||
|
||
dstalk_shutdown();
|
||
std::cout << "[OK] dstalk_shutdown succeeded\n";
|
||
|
||
std::cout << "\n";
|
||
if (g_failures == 0) {
|
||
std::cout << "=== All context plugin tests passed ===\n";
|
||
return 0;
|
||
} else {
|
||
std::cerr << "=== " << g_failures << " test(s) FAILED ===\n";
|
||
return 1;
|
||
}
|
||
}
|