/*
 * @file context_plugin_test.cpp
 * @brief Context plugin unit tests: token counting (ASCII, CJK, mixed, emoji),
 *        UTF-8 truncation safety, trim edge cases, and system message preservation.
 * Copyright (c) 2026 dstalk contributors. GPLv3.
 */
#include "dstalk/dstalk_host.h"

#include <cstddef>       // size_t
#include <cstring>       // std::strcmp
#include <filesystem>    // temp directory / config path handling
#include <fstream>       // std::ofstream for the test config
#include <iostream>      // std::cout / std::cerr reporting
#include <system_error>  // std::error_code for best-effort cleanup

// Number of failed CHECKs; main() maps a non-zero value to exit status 1.
static int g_failures = 0;

// Lightweight assertion macro. On success it prints "[OK] <msg>" to stdout.
// On failure it prints "[FAIL] <msg>" to stderr and increments g_failures.
// Unlike assert() it never aborts, so every check in the run is reported.
// The do { } while (0) wrapper makes it behave as a single statement.
#define CHECK(cond, msg) do { \
    if (cond) { \
        std::cout << "[OK] " << (msg) << "\n"; \
    } else { \
        std::cerr << "[FAIL] " << (msg) << "\n"; \
        g_failures++; \
    } \
} while (0)

// Context plugin tests: token counting edge cases (null, empty, ASCII, CJK, mixed),
// truncated UTF-8 bounds protection (F-11.1-4), 0xC0/0xC1 overlong encoding (F-11.1-6),
// multiple-message tokens, trim null/edge/within-limit/exceeds-limit scenarios,
// system message preservation, 4-byte emoji, and lone continuation bytes.
int main() { const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test"; std::filesystem::create_directories(dir); const auto config_path = dir / "config.toml"; { std::ofstream config(config_path); config << "[api]\n" << "provider = \"openai\"\n" << "base_url = \"https://api.openai.com/v1\"\n" << "api_key = \"test-key\"\n" << "model = \"gpt-4o\"\n"; } if (dstalk_init(config_path.string().c_str()) != 0) { std::cerr << "dstalk_init failed\n"; return 1; } auto* ctx = static_cast( dstalk_service_query("context", 1)); if (!ctx) { std::cerr << "context service not found\n"; dstalk_shutdown(); return 1; } std::cout << "[OK] context service found\n"; // ================================================================ // Test Block 1: count_tokens edge cases (null / empty) // 测试块 1:count_tokens 边界情况(null / 空) // ================================================================ std::cout << "\n--- Block 1: count_tokens edge cases ---\n"; size_t tokens = ctx->count_tokens(nullptr, 0); CHECK(tokens == 0, "T1.1: count_tokens(nullptr, 0) == 0"); tokens = ctx->count_tokens(nullptr, 5); CHECK(tokens == 0, "T1.2: count_tokens(nullptr, 5) == 0"); { dstalk_message_t empty_msg = {nullptr, nullptr, nullptr, nullptr}; tokens = ctx->count_tokens(&empty_msg, 1); CHECK(tokens == 4, "T1.3: null-content message == 4 (overhead only)"); } { dstalk_message_t empty_str_msg = {"user", "", nullptr, nullptr}; tokens = ctx->count_tokens(&empty_str_msg, 1); CHECK(tokens == 4, "T1.4: empty-string content == 4 (overhead only)"); } // ================================================================ // Test Block 2: count_tokens — ASCII // 测试块 2:count_tokens — ASCII // ================================================================ std::cout << "\n--- Block 2: count_tokens ASCII ---\n"; { dstalk_message_t msg = {"user", "Hello World", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); // 11 ascii chars / 4 = 2 + 4 overhead = 6 CHECK(tokens == 6, "T2.1: 'Hello World' (11 ASCII) == 6 
tokens"); std::cout << " tokens = " << tokens << "\n"; } { dstalk_message_t msg = {"user", "abcd", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); // 4 ascii chars / 4 = 1 + 4 overhead = 5 CHECK(tokens == 5, "T2.2: 'abcd' (4 ASCII) == 5 tokens"); std::cout << " tokens = " << tokens << "\n"; } { dstalk_message_t msg = {"user", "This is a longer ASCII sentence for testing token counts", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T2.3: long ASCII sentence returns valid count"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF) // 测试块 3:count_tokens — 中文 (CJK U+4E00-U+9FFF) // ================================================================ std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n"; { // 中文 = U+4E2D U+6587 = E4 B8 AD E6 96 87 (2 CJK chars) dstalk_message_t msg = {"user", "\xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); // 2 chinese / 2 = 1 + 4 overhead = 5 CHECK(tokens == 5, "T3.1: 2 Chinese chars == 5 tokens"); std::cout << " tokens = " << tokens << "\n"; } { // 你好世界 = 4 CJK chars = 4/2 + 4 = 6 dstalk_message_t msg = {"user", "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens == 6, "T3.2: 4 Chinese chars == 6 tokens"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 4: count_tokens — Mixed content // 测试块 4:count_tokens — 混合内容 // ================================================================ std::cout << "\n--- Block 4: count_tokens mixed content ---\n"; { // "Hi 中文" = 3 ASCII + 2 CJK = 3/4 + 2/2 + 4 = 0+1+4 = 5 dstalk_message_t msg = {"user", "Hi \xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens == 5, "T4.1: 'Hi ' + 2 CJK == 5 tokens"); std::cout << " 
tokens = " << tokens << "\n"; } // ================================================================ // Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4) // 测试块 5:截断 UTF-8 边界保护 (F-11.1-4) // ================================================================ std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n"; { // Lone 0xE4 (3-byte sequence lead byte alone) char buf[3] = {static_cast(0xE4), 'A', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T5.1: lone 0xE4 does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // 0xE4 + 0x80 (3-byte missing last continuation byte) char buf[4] = {static_cast(0xE4), static_cast(0x80), 'B', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T5.2: 0xE4 0x80 (2/3 bytes) does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // Lone 0xF0 (4-byte sequence lead byte alone) char buf[3] = {static_cast(0xF0), 'X', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T5.3: lone 0xF0 (4-byte lead) does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // 0xC2 alone (2-byte sequence missing continuation byte) char buf[3] = {static_cast(0xC2), 'Y', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T5.4: 0xC2 alone (missing cont.) 
does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // 2-byte lead + invalid continuation (0x00 instead of 0x80-0xBF) char buf[4] = {static_cast(0xC3), '\x00', 'Z', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T5.5: invalid continuation byte does not crash"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6) // 测试块 6:0xC0/0xC1 超长编码 (F-11.1-6) // ================================================================ std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n"; { // 0xC0 0x80 = overlong encoding of NUL (U+0000) char buf[4] = {static_cast(0xC0), static_cast(0x80), '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T6.1: 0xC0 0x80 overlong does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // 0xC1 0xBF = overlong encoding of U+007F char buf[4] = {static_cast(0xC1), static_cast(0xBF), '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T6.2: 0xC1 0xBF overlong does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // 0xC0 alone (overlong lead without continuation) char buf[3] = {static_cast(0xC0), 'Q', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T6.3: lone 0xC0 does not crash"); std::cout << " tokens = " << tokens << "\n"; } { // Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences // They should each count as 1 other_char, not as 2-byte sequence // 验证 0xC0/0xC1 不被视为合法的 2 字节序列 / 它们每个应计为 1 个 other_char,而非 2 字节序列 // 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii // = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4 // Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other // So 0 + 0 + 4 = 4 
tokens char buf[6] = {static_cast(0xC0), static_cast(0xC1), 'a', 'b', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens == 4, "T6.4: 0xC0+0xC1+2 ascii token count as expected"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 7: count_tokens — multiple messages // 测试块 7:count_tokens — 多消息 // ================================================================ std::cout << "\n--- Block 7: multiple messages ---\n"; { dstalk_message_t msgs[3] = { {"system", "You are helpful", nullptr, nullptr}, {"user", "Hello", nullptr, nullptr}, {"assistant", "Hi there", nullptr, nullptr} }; tokens = ctx->count_tokens(msgs, 3); // system: 15 ascii /4 = 3 + 4 = 7 // user: 5 ascii /4 = 1 + 4 = 5 // assistant: 8 ascii /4 = 2 + 4 = 6 // total = 7+5+6 = 18 CHECK(tokens == 18, "T7.1: 3 messages token count == 18"); std::cout << " tokens = " << tokens << "\n"; } { dstalk_message_t msgs[2] = { {"user", "hi", nullptr, nullptr}, {"assistant", "ok", nullptr, nullptr} }; tokens = ctx->count_tokens(msgs, 2); // 2/4 + 4 + 2/4 + 4 = 0+4+0+4 = 8 CHECK(tokens == 8, "T7.2: 2 short messages == 8 tokens"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 8: trim — null and edge cases // 测试块 8:trim — null 和边界情况 // ================================================================ std::cout << "\n--- Block 8: trim edge cases ---\n"; { dstalk_message_t* out = nullptr; int out_count = 0; int ret = ctx->trim(nullptr, 0, &out, &out_count, 100); CHECK(ret == -1, "T8.1: trim(nullptr, 0) returns -1"); ret = ctx->trim(nullptr, 0, nullptr, nullptr, 100); CHECK(ret == -1, "T8.2: trim with null output pointers returns -1"); } // ================================================================ // Test Block 9: trim — within limit (no trimming needed) // 测试块 9:trim — 预算内(无需裁剪) // 
================================================================ std::cout << "\n--- Block 9: trim within limit ---\n"; { dstalk_message_t msgs[2] = { {"user", "hi", nullptr, nullptr}, {"assistant", "hello", nullptr, nullptr} }; dstalk_message_t* out = nullptr; int out_count = 0; int ret = ctx->trim(msgs, 2, &out, &out_count, 4096); CHECK(ret == 0, "T9.1: trim within limit returns 0"); CHECK(out != nullptr, "T9.2: trim allocates output"); CHECK(out_count == 2, "T9.3: trim preserves message count"); if (out && out_count >= 2) { CHECK(out[0].role && std::strcmp(out[0].role, "user") == 0, "T9.4: first message role preserved"); CHECK(out[0].content && std::strcmp(out[0].content, "hi") == 0, "T9.5: first message content preserved"); dstalk_free(out); } else if (out) { dstalk_free(out); } } // ================================================================ // Test Block 10: trim — exceeds limit (trimming required) // 测试块 10:trim — 超预算(需要裁剪) // ================================================================ std::cout << "\n--- Block 10: trim exceeds limit ---\n"; { // 4 long messages, each ~70 chars (~18 ASCII tokens + 4 overhead = 22), // total ~88 tokens > 30 limit dstalk_message_t msgs[4] = { {"user", "This is a long message with enough text to consume many tokens", nullptr, nullptr}, {"assistant", "Another long response that also uses up tokens with lots of words", nullptr, nullptr}, {"user", "A third long message pushing us well over the token budget limit", nullptr, nullptr}, {"assistant", "The fourth long message will cause us to exceed the max budget", nullptr, nullptr} }; dstalk_message_t* out = nullptr; int out_count = 0; int ret = ctx->trim(msgs, 4, &out, &out_count, 30); // trim may return -1 if a single message exceeds limit, or 0 with reduced count if (ret == 0 && out) { CHECK(out_count <= 4, "T10.1: trim output count <= input count"); std::cout << " output count = " << out_count << " (in=4, limit=30)\n"; dstalk_free(out); } else { // Single message 
exceeds limit => returns -1 with empty output std::cout << " trim returned " << ret << " (single msg > limit path)\n"; CHECK(ret == -1, "T10.2: expected ret=-1 (single msg exceeds 30 tokens)"); } } // ================================================================ // Test Block 11: trim — system message preservation // 测试块 11:trim — 系统消息保留 // ================================================================ std::cout << "\n--- Block 11: trim preserves system messages ---\n"; { dstalk_message_t msgs[3] = { {"system", "You are a helpful assistant", nullptr, nullptr}, {"user", "Hello this is a very long user message that will push us over the token budget", nullptr, nullptr}, {"assistant", "I am a very long assistant response designed to consume tokens for testing", nullptr, nullptr} }; dstalk_message_t* out = nullptr; int out_count = 0; int ret = ctx->trim(msgs, 3, &out, &out_count, 25); if (ret >= 0 && out && out_count > 0) { CHECK(out[0].role && std::strcmp(out[0].role, "system") == 0, "T11.1: system message preserved as first in output"); std::cout << " output count = " << out_count << "\n"; dstalk_free(out); } else if (out) { dstalk_free(out); } } // ================================================================ // Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary) // 测试块 12:count_tokens — 4 字节 UTF-8(emoji / 补充平面) // ================================================================ std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n"; { // U+1F600 (��) = F0 9F 98 80 char buf[6] = {static_cast(0xF0), static_cast(0x9F), static_cast(0x98), static_cast(0x80), '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); // 1 other_char / 3 + 4 overhead = 0 + 4 = 4 CHECK(tokens == 4, "T12.1: single 4-byte char (emoji) == 4 tokens"); std::cout << " tokens = " << tokens << "\n"; } // ================================================================ // Test Block 13: count_tokens — continuation bytes as lone chars // 测试块 
13:count_tokens — 孤立的续字节 // ================================================================ std::cout << "\n--- Block 13: lone continuation bytes ---\n"; { // 0x80 alone (continuation byte without lead byte) char buf[3] = {static_cast(0x80), 'A', '\0'}; dstalk_message_t msg = {"user", buf, nullptr, nullptr}; tokens = ctx->count_tokens(&msg, 1); CHECK(tokens >= 4, "T13.1: lone continuation byte does not crash"); std::cout << " tokens = " << tokens << "\n"; } dstalk_shutdown(); std::cout << "[OK] dstalk_shutdown succeeded\n"; std::cout << "\n"; if (g_failures == 0) { std::cout << "=== All context plugin tests passed ===\n"; return 0; } else { std::cerr << "=== " << g_failures << " test(s) FAILED ===\n"; return 1; } }