- W18.1 (王测+林深): Remove g_max_tokens dead API, UTF-8 bounds protection, deduplicate token counting, 0xC0/0xC1 handling, add 13 test blocks (36 checks) - W18.2 (赵码+朱晴): Fix /context no-session error message, /status 3-state connection display - W18.3 (曹武+徐磊): plugin_loader security audit — 9 dimensions, rating C, 1 HIGH + 2 MEDIUM findings - W18.4 (马奔+胡桐): CI dual-platform matrix (Ubuntu clang-18 + Windows clang-cl), ccache, build timing baseline Build 0 error, ctest 5/5 pass, metadata check clean. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
430 lines
17 KiB
C++
430 lines
17 KiB
C++
// ============================================================================
// context_plugin_test.cpp — unit tests for the context plugin
// ============================================================================
// W18.1 (qa-wang + architect-lin): covers token counting, trim, UTF-8
// boundary handling, and detection of 0xC0/0xC1 overlong encodings.
// Tests added after the fixes for F-11.1-3/4/5/6.
// ============================================================================
|
|
|
|
#include "dstalk/dstalk_host.h"

#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
|
|
|
|
// Number of failed checks recorded so far; main() turns it into the exit code.
static int g_failures = 0;

// Prints the verdict line for one check: "[OK] <msg>" on stdout when it
// passed, "[FAIL] <msg>" on stderr (and one more counted failure) otherwise.
template <typename Message>
static void report_check(bool passed, const Message& msg)
{
    if (!passed) {
        std::cerr << "[FAIL] " << msg << "\n";
        ++g_failures;
        return;
    }
    std::cout << "[OK] " << msg << "\n";
}

// CHECK(cond, msg): evaluates cond exactly once, then msg exactly once, and
// reports the outcome. Wrapped in do/while so it behaves as a single statement.
#define CHECK(cond, msg) do { \
    const bool check_passed_ = static_cast<bool>(cond); \
    report_check(check_passed_, (msg)); \
} while (0)
|
|
|
|
int main()
|
|
{
|
|
const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test";
|
|
std::filesystem::create_directories(dir);
|
|
|
|
const auto config_path = dir / "config.toml";
|
|
{
|
|
std::ofstream config(config_path);
|
|
config << "[api]\n"
|
|
<< "provider = \"deepseek\"\n"
|
|
<< "base_url = \"https://api.deepseek.com/v1\"\n"
|
|
<< "api_key = \"test-key\"\n"
|
|
<< "model = \"deepseek-v4-pro\"\n";
|
|
}
|
|
|
|
if (dstalk_init(config_path.string().c_str()) != 0) {
|
|
std::cerr << "dstalk_init failed\n";
|
|
return 1;
|
|
}
|
|
|
|
auto* ctx = static_cast<const dstalk_context_service_t*>(
|
|
dstalk_service_query("context", 1));
|
|
if (!ctx) {
|
|
std::cerr << "context service not found\n";
|
|
dstalk_shutdown();
|
|
return 1;
|
|
}
|
|
std::cout << "[OK] context service found\n";
|
|
|
|
// ================================================================
|
|
// Test Block 1: count_tokens edge cases (null / empty)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 1: count_tokens edge cases ---\n";
|
|
|
|
size_t tokens = ctx->count_tokens(nullptr, 0);
|
|
CHECK(tokens == 0, "T1.1: count_tokens(nullptr, 0) == 0");
|
|
|
|
tokens = ctx->count_tokens(nullptr, 5);
|
|
CHECK(tokens == 0, "T1.2: count_tokens(nullptr, 5) == 0");
|
|
|
|
{
|
|
dstalk_message_t empty_msg = {nullptr, nullptr, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&empty_msg, 1);
|
|
CHECK(tokens == 4, "T1.3: null-content message == 4 (overhead only)");
|
|
}
|
|
|
|
{
|
|
dstalk_message_t empty_str_msg = {"user", "", nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&empty_str_msg, 1);
|
|
CHECK(tokens == 4, "T1.4: empty-string content == 4 (overhead only)");
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 2: count_tokens — ASCII
|
|
// ================================================================
|
|
std::cout << "\n--- Block 2: count_tokens ASCII ---\n";
|
|
|
|
{
|
|
dstalk_message_t msg = {"user", "Hello World", nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
// 11 ascii chars / 4 = 2 + 4 overhead = 6
|
|
CHECK(tokens == 6, "T2.1: 'Hello World' (11 ASCII) == 6 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
dstalk_message_t msg = {"user", "abcd", nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
// 4 ascii chars / 4 = 1 + 4 overhead = 5
|
|
CHECK(tokens == 5, "T2.2: 'abcd' (4 ASCII) == 5 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
dstalk_message_t msg = {"user",
|
|
"This is a longer ASCII sentence for testing token counts",
|
|
nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T2.3: long ASCII sentence returns valid count");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n";
|
|
|
|
{
|
|
// 中文 = U+4E2D U+6587 = E4 B8 AD E6 96 87 (2 CJK chars)
|
|
dstalk_message_t msg = {"user",
|
|
"\xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
// 2 chinese / 2 = 1 + 4 overhead = 5
|
|
CHECK(tokens == 5, "T3.1: 2 Chinese chars == 5 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 你好世界 = 4 CJK chars = 4/2 + 4 = 6
|
|
dstalk_message_t msg = {"user",
|
|
"\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c",
|
|
nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens == 6, "T3.2: 4 Chinese chars == 6 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 4: count_tokens — Mixed content
|
|
// ================================================================
|
|
std::cout << "\n--- Block 4: count_tokens mixed content ---\n";
|
|
|
|
{
|
|
// "Hi 中文" = 3 ASCII + 2 CJK = 3/4 + 2/2 + 4 = 0+1+4 = 5
|
|
dstalk_message_t msg = {"user",
|
|
"Hi \xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens == 5, "T4.1: 'Hi ' + 2 CJK == 5 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n";
|
|
|
|
{
|
|
// Lone 0xE4 (3-byte sequence lead byte alone)
|
|
char buf[3] = {static_cast<char>(0xE4), 'A', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T5.1: lone 0xE4 does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 0xE4 + 0x80 (3-byte missing last continuation byte)
|
|
char buf[4] = {static_cast<char>(0xE4), static_cast<char>(0x80),
|
|
'B', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T5.2: 0xE4 0x80 (2/3 bytes) does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// Lone 0xF0 (4-byte sequence lead byte alone)
|
|
char buf[3] = {static_cast<char>(0xF0), 'X', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T5.3: lone 0xF0 (4-byte lead) does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 0xC2 alone (2-byte sequence missing continuation byte)
|
|
char buf[3] = {static_cast<char>(0xC2), 'Y', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T5.4: 0xC2 alone (missing cont.) does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 2-byte lead + invalid continuation (0x00 instead of 0x80-0xBF)
|
|
char buf[4] = {static_cast<char>(0xC3), '\x00', 'Z', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T5.5: invalid continuation byte does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n";
|
|
|
|
{
|
|
// 0xC0 0x80 = overlong encoding of NUL (U+0000)
|
|
char buf[4] = {static_cast<char>(0xC0), static_cast<char>(0x80), '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T6.1: 0xC0 0x80 overlong does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 0xC1 0xBF = overlong encoding of U+007F
|
|
char buf[4] = {static_cast<char>(0xC1), static_cast<char>(0xBF), '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T6.2: 0xC1 0xBF overlong does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// 0xC0 alone (overlong lead without continuation)
|
|
char buf[3] = {static_cast<char>(0xC0), 'Q', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T6.3: lone 0xC0 does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
// Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences
|
|
// They should each count as 1 other_char, not as 2-byte sequence
|
|
// 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii
|
|
// = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4
|
|
// Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other
|
|
// So 0 + 0 + 4 = 4 tokens
|
|
char buf[6] = {static_cast<char>(0xC0), static_cast<char>(0xC1),
|
|
'a', 'b', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens == 4, "T6.4: 0xC0+0xC1+2 ascii token count as expected");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 7: count_tokens — multiple messages
|
|
// ================================================================
|
|
std::cout << "\n--- Block 7: multiple messages ---\n";
|
|
|
|
{
|
|
dstalk_message_t msgs[3] = {
|
|
{"system", "You are helpful", nullptr, nullptr},
|
|
{"user", "Hello", nullptr, nullptr},
|
|
{"assistant", "Hi there", nullptr, nullptr}
|
|
};
|
|
tokens = ctx->count_tokens(msgs, 3);
|
|
// system: 15 ascii /4 = 3 + 4 = 7
|
|
// user: 5 ascii /4 = 1 + 4 = 5
|
|
// assistant: 8 ascii /4 = 2 + 4 = 6
|
|
// total = 7+5+6 = 18
|
|
CHECK(tokens == 18, "T7.1: 3 messages token count == 18");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
{
|
|
dstalk_message_t msgs[2] = {
|
|
{"user", "hi", nullptr, nullptr},
|
|
{"assistant", "ok", nullptr, nullptr}
|
|
};
|
|
tokens = ctx->count_tokens(msgs, 2);
|
|
// 2/4 + 4 + 2/4 + 4 = 0+4+0+4 = 8
|
|
CHECK(tokens == 8, "T7.2: 2 short messages == 8 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 8: trim — null and edge cases
|
|
// ================================================================
|
|
std::cout << "\n--- Block 8: trim edge cases ---\n";
|
|
|
|
{
|
|
dstalk_message_t* out = nullptr;
|
|
int out_count = 0;
|
|
|
|
int ret = ctx->trim(nullptr, 0, &out, &out_count, 100);
|
|
CHECK(ret == -1, "T8.1: trim(nullptr, 0) returns -1");
|
|
|
|
ret = ctx->trim(nullptr, 0, nullptr, nullptr, 100);
|
|
CHECK(ret == -1, "T8.2: trim with null output pointers returns -1");
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 9: trim — within limit (no trimming needed)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 9: trim within limit ---\n";
|
|
|
|
{
|
|
dstalk_message_t msgs[2] = {
|
|
{"user", "hi", nullptr, nullptr},
|
|
{"assistant", "hello", nullptr, nullptr}
|
|
};
|
|
dstalk_message_t* out = nullptr;
|
|
int out_count = 0;
|
|
|
|
int ret = ctx->trim(msgs, 2, &out, &out_count, 4096);
|
|
CHECK(ret == 0, "T9.1: trim within limit returns 0");
|
|
CHECK(out != nullptr, "T9.2: trim allocates output");
|
|
CHECK(out_count == 2, "T9.3: trim preserves message count");
|
|
|
|
if (out && out_count >= 2) {
|
|
CHECK(out[0].role && std::strcmp(out[0].role, "user") == 0,
|
|
"T9.4: first message role preserved");
|
|
CHECK(out[0].content && std::strcmp(out[0].content, "hi") == 0,
|
|
"T9.5: first message content preserved");
|
|
dstalk_free(out);
|
|
} else if (out) {
|
|
dstalk_free(out);
|
|
}
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 10: trim — exceeds limit (trimming required)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 10: trim exceeds limit ---\n";
|
|
|
|
{
|
|
// 4 long messages, each ~70 chars (~18 ASCII tokens + 4 overhead = 22),
|
|
// total ~88 tokens > 30 limit
|
|
dstalk_message_t msgs[4] = {
|
|
{"user",
|
|
"This is a long message with enough text to consume many tokens",
|
|
nullptr, nullptr},
|
|
{"assistant",
|
|
"Another long response that also uses up tokens with lots of words",
|
|
nullptr, nullptr},
|
|
{"user",
|
|
"A third long message pushing us well over the token budget limit",
|
|
nullptr, nullptr},
|
|
{"assistant",
|
|
"The fourth long message will cause us to exceed the max budget",
|
|
nullptr, nullptr}
|
|
};
|
|
dstalk_message_t* out = nullptr;
|
|
int out_count = 0;
|
|
|
|
int ret = ctx->trim(msgs, 4, &out, &out_count, 30);
|
|
// trim may return -1 if a single message exceeds limit, or 0 with reduced count
|
|
if (ret == 0 && out) {
|
|
CHECK(out_count <= 4, "T10.1: trim output count <= input count");
|
|
std::cout << " output count = " << out_count << " (in=4, limit=30)\n";
|
|
dstalk_free(out);
|
|
} else {
|
|
// Single message exceeds limit => returns -1 with empty output
|
|
std::cout << " trim returned " << ret << " (single msg > limit path)\n";
|
|
CHECK(ret == -1, "T10.2: expected ret=-1 (single msg exceeds 30 tokens)");
|
|
}
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 11: trim — system message preservation
|
|
// ================================================================
|
|
std::cout << "\n--- Block 11: trim preserves system messages ---\n";
|
|
|
|
{
|
|
dstalk_message_t msgs[3] = {
|
|
{"system", "You are a helpful assistant", nullptr, nullptr},
|
|
{"user",
|
|
"Hello this is a very long user message that will push us over the token budget",
|
|
nullptr, nullptr},
|
|
{"assistant",
|
|
"I am a very long assistant response designed to consume tokens for testing",
|
|
nullptr, nullptr}
|
|
};
|
|
dstalk_message_t* out = nullptr;
|
|
int out_count = 0;
|
|
|
|
int ret = ctx->trim(msgs, 3, &out, &out_count, 25);
|
|
if (ret >= 0 && out && out_count > 0) {
|
|
CHECK(out[0].role && std::strcmp(out[0].role, "system") == 0,
|
|
"T11.1: system message preserved as first in output");
|
|
std::cout << " output count = " << out_count << "\n";
|
|
dstalk_free(out);
|
|
} else if (out) {
|
|
dstalk_free(out);
|
|
}
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary)
|
|
// ================================================================
|
|
std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n";
|
|
|
|
{
|
|
// U+1F600 (😀) = F0 9F 98 80
|
|
char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F),
|
|
static_cast<char>(0x98), static_cast<char>(0x80), '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
// 1 other_char / 3 + 4 overhead = 0 + 4 = 4
|
|
CHECK(tokens == 4, "T12.1: single 4-byte char (emoji) == 4 tokens");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
// ================================================================
|
|
// Test Block 13: count_tokens — continuation bytes as lone chars
|
|
// ================================================================
|
|
std::cout << "\n--- Block 13: lone continuation bytes ---\n";
|
|
|
|
{
|
|
// 0x80 alone (continuation byte without lead byte)
|
|
char buf[3] = {static_cast<char>(0x80), 'A', '\0'};
|
|
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
|
tokens = ctx->count_tokens(&msg, 1);
|
|
CHECK(tokens >= 4, "T13.1: lone continuation byte does not crash");
|
|
std::cout << " tokens = " << tokens << "\n";
|
|
}
|
|
|
|
dstalk_shutdown();
|
|
std::cout << "[OK] dstalk_shutdown succeeded\n";
|
|
|
|
std::cout << "\n";
|
|
if (g_failures == 0) {
|
|
std::cout << "=== All context plugin tests passed ===\n";
|
|
return 0;
|
|
} else {
|
|
std::cerr << "=== " << g_failures << " test(s) FAILED ===\n";
|
|
return 1;
|
|
}
|
|
}
|