Files
dstalk/tests/context_plugin_test.cpp
XiuChengWu f6cb51b40a Add unit tests for OpenAI plugin and establish coding standards
- Introduced comprehensive unit tests for the OpenAI plugin, covering SSE parsing, sentinel matching, delta extraction, request building, and more.
- Created a new markdown file detailing coding and naming conventions for the dstalk project, including guidelines for comments, naming rules, code organization, and memory management practices.
2026-05-31 00:51:59 +08:00

452 lines
19 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* @file context_plugin_test.cpp
* @brief Context plugin unit tests: token counting (ASCII, CJK, mixed, emoji),
* UTF-8 truncation safety, trim edge cases, and system message preservation.
* Context 插件单元测试token 计数ASCII、CJK、混合、emoji、UTF-8 截断安全、trim 边界情况、系统消息保留。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h"
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>
#include <system_error>
// Number of CHECK assertions that have failed so far in this test binary.
static int g_failures = 0;

// Minimal assertion helper for this test binary.
// On failure: writes "[FAIL] <msg>" to stderr and bumps g_failures.
// On success: writes "[OK] <msg>" to stdout.
// Wrapped in do/while(0) so it behaves as a single statement after if/else.
#define CHECK(cond, msg) do { \
    if (!(cond)) { \
        std::cerr << "[FAIL] " << (msg) << "\n"; \
        ++g_failures; \
    } else { \
        std::cout << "[OK] " << (msg) << "\n"; \
    } \
} while (0)
namespace {

// Removes the given directory tree when it goes out of scope, so the test
// leaves nothing behind in the system temp directory on any exit path.
// Removal is best-effort: errors are captured in an error_code and ignored,
// because a failed cleanup must not turn a passing test run into a failure.
struct TempDirGuard {
    std::filesystem::path path;

    explicit TempDirGuard(const std::filesystem::path& p) : path(p) {}
    TempDirGuard(const TempDirGuard&) = delete;
    TempDirGuard& operator=(const TempDirGuard&) = delete;

    ~TempDirGuard()
    {
        std::error_code ec;
        std::filesystem::remove_all(path, ec);
    }
};

} // namespace

// Context plugin tests: token counting edge cases (null, empty, ASCII, CJK, mixed),
// truncated UTF-8 bounds protection (F-11.1-4), 0xC0/0xC1 overlong encoding (F-11.1-6),
// multiple-message tokens, trim null/edge/within-limit/exceeds-limit scenarios,
// system message preservation, 4-byte emoji, and lone continuation bytes.
//
// Flow: write a throwaway config into a temp directory, initialise the host,
// look up the "context" service, run the test blocks, shut the host down.
// Returns 0 when every CHECK passed, 1 on setup failure or any failed CHECK.
int main()
{
    const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test";
    std::filesystem::create_directories(dir);
    // Declared right after creation so every return below cleans the directory up.
    const TempDirGuard cleanup(dir);

    const auto config_path = dir / "config.toml";
    {
        std::ofstream config(config_path);
        config << "[api]\n"
               << "provider = \"openai\"\n"
               << "base_url = \"https://api.openai.com/v1\"\n"
               << "api_key = \"test-key\"\n"
               << "model = \"gpt-4o\"\n";
        // Close explicitly so buffered-write errors surface here instead of
        // later as an unexplained dstalk_init failure.
        config.close();
        if (!config) {
            std::cerr << "failed to write test config: "
                      << config_path.string() << "\n";
            return 1;
        }
    }

    if (dstalk_init(config_path.string().c_str()) != 0) {
        std::cerr << "dstalk_init failed\n";
        return 1;
    }

    auto* ctx = static_cast<const dstalk_context_service_t*>(
        dstalk_service_query("context", 1));
    if (!ctx) {
        std::cerr << "context service not found\n";
        dstalk_shutdown();
        return 1;
    }
    std::cout << "[OK] context service found\n";

    // ================================================================
    // Test Block 1: count_tokens edge cases (null / empty)
    // ================================================================
    std::cout << "\n--- Block 1: count_tokens edge cases ---\n";
    size_t tokens = ctx->count_tokens(nullptr, 0);
    CHECK(tokens == 0, "T1.1: count_tokens(nullptr, 0) == 0");
    tokens = ctx->count_tokens(nullptr, 5);
    CHECK(tokens == 0, "T1.2: count_tokens(nullptr, 5) == 0");
    {
        dstalk_message_t empty_msg = {nullptr, nullptr, nullptr, nullptr};
        tokens = ctx->count_tokens(&empty_msg, 1);
        CHECK(tokens == 4, "T1.3: null-content message == 4 (overhead only)");
    }
    {
        dstalk_message_t empty_str_msg = {"user", "", nullptr, nullptr};
        tokens = ctx->count_tokens(&empty_str_msg, 1);
        CHECK(tokens == 4, "T1.4: empty-string content == 4 (overhead only)");
    }

    // ================================================================
    // Test Block 2: count_tokens — ASCII
    // ================================================================
    std::cout << "\n--- Block 2: count_tokens ASCII ---\n";
    {
        dstalk_message_t msg = {"user", "Hello World", nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        // 11 ascii chars / 4 = 2 + 4 overhead = 6
        CHECK(tokens == 6, "T2.1: 'Hello World' (11 ASCII) == 6 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        dstalk_message_t msg = {"user", "abcd", nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        // 4 ascii chars / 4 = 1 + 4 overhead = 5
        CHECK(tokens == 5, "T2.2: 'abcd' (4 ASCII) == 5 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        dstalk_message_t msg = {"user",
            "This is a longer ASCII sentence for testing token counts",
            nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        // Only the per-message overhead floor is asserted here.
        CHECK(tokens >= 4, "T2.3: long ASCII sentence returns valid count");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF)
    // ================================================================
    std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n";
    {
        // U+4E2D U+6587 = E4 B8 AD E6 96 87 (2 CJK chars)
        dstalk_message_t msg = {"user",
            "\xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        // 2 chinese / 2 = 1 + 4 overhead = 5
        CHECK(tokens == 5, "T3.1: 2 Chinese chars == 5 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // U+4F60 U+597D U+4E16 U+754C = 4 CJK chars = 4/2 + 4 = 6
        dstalk_message_t msg = {"user",
            "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c",
            nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens == 6, "T3.2: 4 Chinese chars == 6 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 4: count_tokens — Mixed content
    // ================================================================
    std::cout << "\n--- Block 4: count_tokens mixed content ---\n";
    {
        // "Hi " + 2 CJK = 3 ASCII + 2 CJK = 3/4 + 2/2 + 4 = 0+1+4 = 5
        dstalk_message_t msg = {"user",
            "Hi \xe4\xb8\xad\xe6\x96\x87", nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens == 5, "T4.1: 'Hi ' + 2 CJK == 5 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4)
    // These cases only assert the overhead floor; their purpose is to
    // exercise malformed input without crashing or reading out of bounds.
    // ================================================================
    std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n";
    {
        // Lone 0xE4 (3-byte sequence lead byte alone)
        char buf[3] = {static_cast<char>(0xE4), 'A', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T5.1: lone 0xE4 does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // 0xE4 + 0x80 (3-byte missing last continuation byte)
        char buf[4] = {static_cast<char>(0xE4), static_cast<char>(0x80),
                       'B', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T5.2: 0xE4 0x80 (2/3 bytes) does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // Lone 0xF0 (4-byte sequence lead byte alone)
        char buf[3] = {static_cast<char>(0xF0), 'X', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T5.3: lone 0xF0 (4-byte lead) does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // 0xC2 alone (2-byte sequence missing continuation byte)
        char buf[3] = {static_cast<char>(0xC2), 'Y', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T5.4: 0xC2 alone (missing cont.) does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // 2-byte lead (0xC3) followed by 0x00 instead of a 0x80-0xBF continuation.
        // NOTE(review): if content is read as a NUL-terminated C string, the
        // scan ends at this embedded 0x00 and 'Z' is never examined, so this
        // case is effectively "lead byte as the last byte of the content" —
        // confirm against the plugin implementation.
        char buf[4] = {static_cast<char>(0xC3), '\x00', 'Z', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T5.5: invalid continuation byte does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6)
    // ================================================================
    std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n";
    {
        // 0xC0 0x80 = overlong encoding of NUL (U+0000)
        char buf[4] = {static_cast<char>(0xC0), static_cast<char>(0x80), '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T6.1: 0xC0 0x80 overlong does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // 0xC1 0xBF = overlong encoding of U+007F
        char buf[4] = {static_cast<char>(0xC1), static_cast<char>(0xBF), '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T6.2: 0xC1 0xBF overlong does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // 0xC0 alone (overlong lead without continuation)
        char buf[3] = {static_cast<char>(0xC0), 'Q', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T6.3: lone 0xC0 does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        // Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences:
        // each should count as 1 "other" char rather than a 2-byte sequence.
        // 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii
        //   = (2/3) + (2/4) + 4 overhead   (integer division)
        //   = 0 + 0 + 4 = 4 tokens
        char buf[6] = {static_cast<char>(0xC0), static_cast<char>(0xC1),
                       'a', 'b', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens == 4, "T6.4: 0xC0+0xC1+2 ascii token count as expected");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 7: count_tokens — multiple messages
    // ================================================================
    std::cout << "\n--- Block 7: multiple messages ---\n";
    {
        dstalk_message_t msgs[3] = {
            {"system", "You are helpful", nullptr, nullptr},
            {"user", "Hello", nullptr, nullptr},
            {"assistant", "Hi there", nullptr, nullptr}
        };
        tokens = ctx->count_tokens(msgs, 3);
        // system:    15 ascii /4 = 3 + 4 = 7
        // user:       5 ascii /4 = 1 + 4 = 5
        // assistant:  8 ascii /4 = 2 + 4 = 6
        // total = 7+5+6 = 18
        CHECK(tokens == 18, "T7.1: 3 messages token count == 18");
        std::cout << " tokens = " << tokens << "\n";
    }
    {
        dstalk_message_t msgs[2] = {
            {"user", "hi", nullptr, nullptr},
            {"assistant", "ok", nullptr, nullptr}
        };
        tokens = ctx->count_tokens(msgs, 2);
        // 2/4 + 4 + 2/4 + 4 = 0+4+0+4 = 8
        CHECK(tokens == 8, "T7.2: 2 short messages == 8 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 8: trim — null and edge cases
    // ================================================================
    std::cout << "\n--- Block 8: trim edge cases ---\n";
    {
        dstalk_message_t* out = nullptr;
        int out_count = 0;
        int ret = ctx->trim(nullptr, 0, &out, &out_count, 100);
        CHECK(ret == -1, "T8.1: trim(nullptr, 0) returns -1");
        ret = ctx->trim(nullptr, 0, nullptr, nullptr, 100);
        CHECK(ret == -1, "T8.2: trim with null output pointers returns -1");
    }

    // ================================================================
    // Test Block 9: trim — within limit (no trimming needed)
    // ================================================================
    std::cout << "\n--- Block 9: trim within limit ---\n";
    {
        dstalk_message_t msgs[2] = {
            {"user", "hi", nullptr, nullptr},
            {"assistant", "hello", nullptr, nullptr}
        };
        dstalk_message_t* out = nullptr;
        int out_count = 0;
        int ret = ctx->trim(msgs, 2, &out, &out_count, 4096);
        CHECK(ret == 0, "T9.1: trim within limit returns 0");
        CHECK(out != nullptr, "T9.2: trim allocates output");
        CHECK(out_count == 2, "T9.3: trim preserves message count");
        // Only dereference the output when it is present and large enough.
        if (out && out_count >= 2) {
            CHECK(out[0].role && std::strcmp(out[0].role, "user") == 0,
                  "T9.4: first message role preserved");
            CHECK(out[0].content && std::strcmp(out[0].content, "hi") == 0,
                  "T9.5: first message content preserved");
        }
        if (out) {
            dstalk_free(out);
        }
    }

    // ================================================================
    // Test Block 10: trim — exceeds limit (trimming required)
    // ================================================================
    std::cout << "\n--- Block 10: trim exceeds limit ---\n";
    {
        // 4 long messages, each ~70 chars (~18 ASCII tokens + 4 overhead = 22),
        // total ~88 tokens > 30 limit
        dstalk_message_t msgs[4] = {
            {"user",
             "This is a long message with enough text to consume many tokens",
             nullptr, nullptr},
            {"assistant",
             "Another long response that also uses up tokens with lots of words",
             nullptr, nullptr},
            {"user",
             "A third long message pushing us well over the token budget limit",
             nullptr, nullptr},
            {"assistant",
             "The fourth long message will cause us to exceed the max budget",
             nullptr, nullptr}
        };
        dstalk_message_t* out = nullptr;
        int out_count = 0;
        int ret = ctx->trim(msgs, 4, &out, &out_count, 30);
        // trim may return -1 if a single message exceeds limit, or 0 with reduced count
        if (ret == 0 && out) {
            CHECK(out_count <= 4, "T10.1: trim output count <= input count");
            std::cout << " output count = " << out_count << " (in=4, limit=30)\n";
            dstalk_free(out);
        } else {
            // Single message exceeds limit => returns -1 with empty output
            std::cout << " trim returned " << ret << " (single msg > limit path)\n";
            CHECK(ret == -1, "T10.2: expected ret=-1 (single msg exceeds 30 tokens)");
            // Release any buffer handed back alongside a failure code, matching
            // how Block 11 treats the same situation.
            if (out) {
                dstalk_free(out);
            }
        }
    }

    // ================================================================
    // Test Block 11: trim — system message preservation
    // ================================================================
    std::cout << "\n--- Block 11: trim preserves system messages ---\n";
    {
        dstalk_message_t msgs[3] = {
            {"system", "You are a helpful assistant", nullptr, nullptr},
            {"user",
             "Hello this is a very long user message that will push us over the token budget",
             nullptr, nullptr},
            {"assistant",
             "I am a very long assistant response designed to consume tokens for testing",
             nullptr, nullptr}
        };
        dstalk_message_t* out = nullptr;
        int out_count = 0;
        int ret = ctx->trim(msgs, 3, &out, &out_count, 25);
        if (ret >= 0 && out && out_count > 0) {
            CHECK(out[0].role && std::strcmp(out[0].role, "system") == 0,
                  "T11.1: system message preserved as first in output");
            std::cout << " output count = " << out_count << "\n";
            dstalk_free(out);
        } else {
            // Make it visible in the log that T11.1 was not evaluated, instead
            // of letting this block pass silently without checking anything.
            std::cout << " trim returned " << ret
                      << " with no usable output; T11.1 not evaluated\n";
            if (out) {
                dstalk_free(out);
            }
        }
    }

    // ================================================================
    // Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary)
    // ================================================================
    std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n";
    {
        // U+1F600 (grinning face emoji) = F0 9F 98 80
        char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F),
                       static_cast<char>(0x98), static_cast<char>(0x80), '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        // 1 other_char / 3 + 4 overhead = 0 + 4 = 4
        CHECK(tokens == 4, "T12.1: single 4-byte char (emoji) == 4 tokens");
        std::cout << " tokens = " << tokens << "\n";
    }

    // ================================================================
    // Test Block 13: count_tokens — continuation bytes as lone chars
    // ================================================================
    std::cout << "\n--- Block 13: lone continuation bytes ---\n";
    {
        // 0x80 alone (continuation byte without lead byte)
        char buf[3] = {static_cast<char>(0x80), 'A', '\0'};
        dstalk_message_t msg = {"user", buf, nullptr, nullptr};
        tokens = ctx->count_tokens(&msg, 1);
        CHECK(tokens >= 4, "T13.1: lone continuation byte does not crash");
        std::cout << " tokens = " << tokens << "\n";
    }

    dstalk_shutdown();
    std::cout << "[OK] dstalk_shutdown succeeded\n";
    std::cout << "\n";

    if (g_failures == 0) {
        std::cout << "=== All context plugin tests passed ===\n";
        return 0;
    }
    std::cerr << "=== " << g_failures << " test(s) FAILED ===\n";
    return 1;
}