Files
dstalk/tests/endpoint_mgr_plugin_test.cpp
XiuChengWu 4745ce1f1c
Some checks failed
CI / Determine matrix (push) Has been cancelled
CI / Sanitizer (ASan+UBSan) / ubuntu-24.04 (push) Has been cancelled
CI / Coverage (gcovr) / ubuntu-24.04 (push) Has been cancelled
CI / ${{ matrix.os }} / ${{ matrix.build_type }} (push) Has been cancelled
feat: add AI endpoint manager plugin with configuration and routing capabilities
- Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints.
- Added configuration reference documentation for `config.toml`.
- Implemented endpoint loading, active endpoint switching, and model mutation.
- Included error handling for missing endpoints and configuration failures.
- Developed unit tests covering various scenarios including error paths and concurrency.
2026-06-03 21:07:25 +08:00

548 lines
21 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* @file endpoint_mgr_plugin_test.cpp
* @brief AI endpoint manager plugin unit tests: endpoint loading, active/model mutation, routing, secret-safe listing,
* plus error-path coverage (null history, missing endpoint, configure failed, empty/bad active, concurrency).
* AI endpoint 管理器插件单元测试: endpoint 加载、active/model 修改、路由和脱敏列表,
* 以及错误路径覆盖 (空 history、缺失 endpoint、configure 失败、空/错 active、并发).
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "../plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp"
#include <atomic>
#include <cstdarg>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
// Running count of failed CHECK() assertions in this test binary.
static int g_failures = 0;
// Non-fatal assertion: prints "[OK] msg" to stdout when `cond` holds; otherwise
// prints "[FAIL] msg" to stderr and increments g_failures. Execution continues
// either way so every check in the run is reported.
#define CHECK(cond, msg) \
    do { \
        if (!(cond)) { \
            std::cerr << "[FAIL] " << (msg) << "\n"; \
            ++g_failures; \
        } else { \
            std::cout << "[OK] " << (msg) << "\n"; \
        } \
    } while (0)
// Snapshot of the last arguments passed to a fake provider's configure hook,
// plus per-hook call counters bumped by the fake provider services.
struct ConfigureRecord {
    std::string provider;
    std::string base_url;
    std::string api_key;
    std::string model;
    int max_tokens{0};
    double temperature{0.0};
    int configure_calls{0};  // configure hook invocations (success or failure)
    int chat_calls{0};       // fake chat invocations
    int stream_calls{0};     // fake chat_stream invocations
};
// Fake host configuration store, read by fake_config_get and written by fake_config_set.
static std::unordered_map<std::string, std::string> g_config_values{};
// Manager vtable captured by fake_register_service when "ai_endpoint_mgr" registers.
static const dstalk_ai_endpoint_mgr_t* g_registered_mgr{nullptr};
// Arguments and call counters recorded by the fake provider services.
static ConfigureRecord g_last_configure{};
// Counter handed to test_stream_cb as userdata during the chat_stream check.
static int g_stream_cb_count{0};
// Host allocator shim: forwards straight to std::malloc (nullptr on failure).
static void* fake_alloc(size_t size)
{
    void* block = std::malloc(size);
    return block;
}
// Host deallocator shim paired with fake_alloc / fake_strdup; null is a no-op.
static void fake_free(void* ptr)
{
    if (ptr != nullptr) {
        std::free(ptr);
    }
}
// Host strdup shim: returns a malloc'd copy of `s` (release with fake_free),
// or nullptr when `s` is null or the allocation fails.
static char* fake_strdup(const char* s)
{
    if (s == nullptr) {
        return nullptr;
    }
    const size_t bytes = std::strlen(s) + 1;  // include the terminating NUL
    char* copy = static_cast<char*>(std::malloc(bytes));
    if (copy == nullptr) {
        return nullptr;
    }
    return static_cast<char*>(std::memcpy(copy, s, bytes));
}
// Host register_service shim. Accepts only version 1 with non-null name/vtable (-1 otherwise).
// Captures the "ai_endpoint_mgr" vtable into g_registered_mgr and returns 0; any other
// service name is refused with -2.
static int fake_register_service(const char* name, int version, void* vtable)
{
    const bool well_formed = name != nullptr && vtable != nullptr && version == 1;
    if (!well_formed) return -1;
    if (std::strcmp(name, "ai_endpoint_mgr") != 0) return -2;
    g_registered_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(vtable);
    return 0;
}
// Host unregister_service shim: accepts any name and does nothing.
static void fake_unregister_service(const char* /*name*/)
{
    // Intentionally empty.
}
// Host config_get shim over g_config_values. Returns nullptr for a null or unknown key.
// The value is copied into per-thread storage, so the returned pointer stays valid even if
// the map changes afterwards, but only until the next lookup on the same thread.
static const char* fake_config_get(const char* key)
{
    if (key == nullptr) return nullptr;
    const auto found = g_config_values.find(key);
    if (found == g_config_values.end()) return nullptr;
    static thread_local std::string tls_value;
    tls_value = found->second;
    return tls_value.c_str();
}
// Host config_set shim: stores (or overwrites) key -> value in g_config_values.
// Returns -1 when either argument is null, 0 on success.
static int fake_config_set(const char* key, const char* value)
{
    if (key == nullptr || value == nullptr) {
        return -1;
    }
    g_config_values[std::string(key)] = std::string(value);
    return 0;
}
// Host log shim: discards every message (the variadic arguments are never read).
static void fake_log(int /*level*/, const char* /*fmt*/, ...)
{
    // Intentionally empty.
}
static int fake_configure(const char* provider, const char* base_url,
const char* api_key, const char* model,
int max_tokens, double temperature)
{
g_last_configure.provider = provider ? provider : "";
g_last_configure.base_url = base_url ? base_url : "";
g_last_configure.api_key = api_key ? api_key : "";
g_last_configure.model = model ? model : "";
g_last_configure.max_tokens = max_tokens;
g_last_configure.temperature = temperature;
g_last_configure.configure_calls++;
return 0;
}
// Provider configure hook that always fails with -1. It ignores all arguments and
// only bumps g_last_configure.configure_calls.
static int fake_configure_fail(const char*, const char*, const char*, const char*, int, double)
{
    ++g_last_configure.configure_calls;
    const int kConfigureFailed = -1;
    return kConfigureFailed;
}
// Provider chat hook: counts the call and returns a successful result whose
// content is a fake_strdup'd "ok" (released by fake_free_result).
static dstalk_chat_result_t fake_chat(const dstalk_message_t*, int,
                                      const char*, const char*)
{
    ++g_last_configure.chat_calls;
    dstalk_chat_result_t result = {};
    result.content = fake_strdup("ok");
    result.ok = 1;
    return result;
}
// Stream callback used by the tests: treats userdata as an int* counter and
// increments it once per token (null userdata is tolerated). Always returns 0.
static int test_stream_cb(const char*, void* userdata)
{
    if (userdata != nullptr) {
        ++*static_cast<int*>(userdata);
    }
    return 0;
}
// Provider chat_stream hook: counts the call, emits exactly one "tok" chunk to `cb`
// (when provided; its return value is ignored), then returns a successful result
// with fake_strdup'd content "stream-ok".
static dstalk_chat_result_t fake_chat_stream(const dstalk_message_t*, int,
                                             const char*, dstalk_stream_cb cb,
                                             void* userdata)
{
    ++g_last_configure.stream_calls;
    if (cb != nullptr) {
        cb("tok", userdata);
    }
    dstalk_chat_result_t result = {};
    result.ok = 1;
    result.content = fake_strdup("stream-ok");
    return result;
}
// Provider free_result hook paired with fake_chat / fake_chat_stream. Frees the
// content, error and tool_calls_json strings with fake_free and nulls the pointers,
// so calling it twice on the same result is harmless. A null `result` is ignored.
static void fake_free_result(dstalk_chat_result_t* result)
{
    if (!result) return;
    // The string fields may be declared const (not visible here). Routing them through
    // const void* gives one named-cast path for any pointer type instead of C-style casts.
    auto release = [](const void* owned) {
        if (owned) fake_free(const_cast<void*>(owned));
    };
    release(result->content);
    release(result->error);
    release(result->tool_calls_json);
    result->content = nullptr;
    result->error = nullptr;
    result->tool_calls_json = nullptr;
}
// Fake provider vtables handed out by fake_query_service. The initializer order is
// configure, chat, chat_stream, free_result, presumably matching the
// dstalk_ai_service_t declaration (not visible here).
static dstalk_ai_service_t g_fake_openai_service = {
    fake_configure, fake_chat, fake_chat_stream, fake_free_result,
};
static dstalk_ai_service_t g_fake_anthropic_service = {
    fake_configure, fake_chat, fake_chat_stream, fake_free_result,
};
// Same as the services above, except that configure always fails.
static dstalk_ai_service_t g_fake_failing_service = {
    fake_configure_fail, fake_chat, fake_chat_stream, fake_free_result,
};
static void* fake_query_service(const char* name, int min_version)
{
if (!name || min_version > 1) return nullptr;
if (std::strcmp(name, "ai_openai") == 0) return &g_fake_openai_service;
if (std::strcmp(name, "ai_anthropic") == 0) return &g_fake_anthropic_service;
if (std::strcmp(name, "ai_failing") == 0) return &g_fake_failing_service;
return nullptr;
}
// Builds a host API table wired to the fake shims above. Fields not assigned here
// stay value-initialized.
static dstalk_host_api_t make_fake_host()
{
    dstalk_host_api_t host = {};
    // Memory helpers.
    host.alloc = &fake_alloc;
    host.free = &fake_free;
    host.strdup = &fake_strdup;
    // Service registry.
    host.register_service = &fake_register_service;
    host.unregister_service = &fake_unregister_service;
    host.query_service = &fake_query_service;
    // Configuration store and logging.
    host.config_get = &fake_config_get;
    host.config_set = &fake_config_set;
    host.log = &fake_log;
    return host;
}
// Replaces the fake config store with a four-name endpoint list. "openai_main" appears
// twice, and "missing_provider" names a provider ("ai_missing") that fake_query_service
// does not know. Only openai_main and anthropic_alt should load; anthropic_alt is active.
static void setup_endpoint_config()
{
    g_config_values = {
        {"endpoints.names", "openai_main, anthropic_alt, missing_provider, openai_main"},
        {"endpoints.active", "anthropic_alt"},
        {"endpoint.openai_main.provider", "ai_openai"},
        {"endpoint.openai_main.api_key", "sk-openai-test"},
        {"endpoint.openai_main.model", "gpt-4o"},
        {"endpoint.openai_main.max_tokens", "1234"},
        {"endpoint.openai_main.temperature", "0.25"},
        {"endpoint.anthropic_alt.provider", "ai_anthropic"},
        {"endpoint.anthropic_alt.api_key", "sk-ant-test"},
        {"endpoint.anthropic_alt.model", "claude-sonnet-test"},
        {"endpoint.missing_provider.provider", "ai_missing"},
        {"endpoint.missing_provider.model", "missing-model"},
    };
}
// Replaces the fake config store with a single endpoint `name` (provider, fixed api_key
// "sk-test", model, and base_url only when non-empty), for error-path tests.
// No endpoints.active key is written. All string arguments except base_url must be non-null.
static void setup_single_endpoint(const char* name, const char* provider, const char* model,
                                  const char* base_url = nullptr)
{
    const std::string prefix = "endpoint." + std::string(name) + ".";
    g_config_values.clear();
    g_config_values["endpoints.names"] = name;
    g_config_values[prefix + "provider"] = provider;
    g_config_values[prefix + "api_key"] = "sk-test";
    g_config_values[prefix + "model"] = model;
    const bool has_base_url = base_url != nullptr && base_url[0] != '\0';
    if (has_base_url) {
        g_config_values[prefix + "base_url"] = base_url;
    }
}
// Tears the plugin down, then wipes every piece of fake-host state so the next
// scenario starts from scratch.
static void reset_test_state()
{
    on_shutdown();
    g_config_values.clear();
    g_last_configure = {};
    g_registered_mgr = nullptr;
}
int main()
{
// ================================================================
// 主测试流程 / Main test flow (existing)
// ================================================================
on_shutdown();
setup_endpoint_config();
g_registered_mgr = nullptr;
g_last_configure = {};
dstalk_host_api_t host = make_fake_host();
int init_rc = on_init(&host);
CHECK(init_rc == 0, "on_init registers endpoint manager service");
CHECK(g_registered_mgr != nullptr, "registered service pointer captured");
const dstalk_ai_endpoint_mgr_t* mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 2, "loads valid endpoints, skips missing provider and duplicate");
const char* active = mgr ? mgr->get_active() : nullptr;
CHECK(active && std::strcmp(active, "anthropic_alt") == 0,
"configured active endpoint is selected");
char* list = mgr ? mgr->list_json() : nullptr;
std::string list_json = list ? list : "";
if (list) fake_free(list);
CHECK(list_json.find("openai_main") != std::string::npos,
"list_json includes OpenAI endpoint");
CHECK(list_json.find("anthropic_alt") != std::string::npos,
"list_json includes Anthropic endpoint");
CHECK(list_json.find("https://api.anthropic.com") != std::string::npos,
"Anthropic endpoint uses Anthropic default base_url");
CHECK(list_json.find("sk-openai-test") == std::string::npos &&
list_json.find("sk-ant-test") == std::string::npos,
"list_json does not expose API keys");
CHECK(mgr->set_active("missing") == -2, "set_active rejects unknown endpoint");
CHECK(mgr->set_active("openai_main") == 0, "set_active accepts known endpoint");
CHECK(std::strcmp(mgr->get_active(), "openai_main") == 0,
"get_active reflects set_active change");
CHECK(mgr->set_model(nullptr, "gpt-4.1-mini") == 0,
"set_model(nullptr, model) updates active endpoint");
CHECK(g_config_values["endpoint.openai_main.model"] == "gpt-4.1-mini",
"set_model mirrors model to host config store");
CHECK(mgr->set_model("missing", "model") == -2,
"set_model rejects unknown endpoint");
CHECK(mgr->set_model("openai_main", "") == -1,
"set_model rejects empty model");
dstalk_message_t msg = {"user", "hello", nullptr, nullptr};
dstalk_chat_result_t r = mgr->chat(nullptr, &msg, 1, "hi", "[]");
CHECK(r.ok == 1 && r.content && std::strcmp(r.content, "ok") == 0,
"chat routes to active endpoint service");
CHECK(g_last_configure.provider == "ai_openai" &&
g_last_configure.base_url == "https://api.openai.com/v1" &&
g_last_configure.model == "gpt-4.1-mini" &&
g_last_configure.max_tokens == 1234 &&
g_last_configure.temperature == 0.25,
"chat configures OpenAI endpoint before routing");
mgr->free_result(&r);
CHECK(mgr->set_active("anthropic_alt") == 0, "switch active endpoint to Anthropic");
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
CHECK(r.ok == 1 && g_last_configure.provider == "ai_anthropic" &&
g_last_configure.base_url == "https://api.anthropic.com" &&
g_last_configure.model == "claude-sonnet-test",
"chat configures Anthropic endpoint before routing");
mgr->free_result(&r);
g_stream_cb_count = 0;
r = mgr->chat_stream("anthropic_alt", &msg, 1, "hi", test_stream_cb, &g_stream_cb_count);
CHECK(r.ok == 1 && g_stream_cb_count == 1,
"chat_stream routes callback through selected endpoint");
mgr->free_result(&r);
on_shutdown();
CHECK(mgr->count() == 0, "on_shutdown clears endpoint cache");
// ================================================================
// 错误路径测试 / Error-path tests (P3)
// ================================================================
// --- E1: null history with non-zero length / 空指针 history ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E1: init with single endpoint");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E1: one endpoint loaded");
// null history + non-zero len -> 应返回错误 / should return error
r = mgr->chat("test_ep", nullptr, 3, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "null history with non-zero length") == 0,
"E1: null history with len>0 returns error for chat");
mgr->free_result(&r);
r = mgr->chat_stream("test_ep", nullptr, 2, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "null history with non-zero length") == 0,
"E1: null history with len>0 returns error for chat_stream");
mgr->free_result(&r);
// null history + zero len -> 应正常 (无历史) / should pass (empty history)
r = mgr->chat("test_ep", nullptr, 0, "hi", nullptr);
CHECK(r.ok == 1, "E1: null history with len=0 passes for chat");
mgr->free_result(&r);
on_shutdown();
}
// --- E2: missing endpoint chat / 缺失 endpoint ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E2: init");
mgr = g_registered_mgr;
r = mgr->chat("nonexistent_ep", &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E2: chat with nonexistent endpoint returns error");
mgr->free_result(&r);
r = mgr->chat_stream("nonexistent_ep", &msg, 1, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E2: chat_stream with nonexistent endpoint returns error");
mgr->free_result(&r);
on_shutdown();
}
// --- E3: configure failed / configure 失败 ---
{
reset_test_state();
// ai_failing provider 的 configure 总是返回 -1 / ai_failing provider's configure always returns -1
setup_single_endpoint("fail_ep", "ai_failing", "fail-model", "https://fail.example.com/api");
host = make_fake_host();
int rc = on_init(&host);
CHECK(rc == 0, "E3: init with failing endpoint");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E3: failing endpoint loaded (service exists)");
r = mgr->chat("fail_ep", &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint configure failed") == 0,
"E3: chat returns error when configure fails");
mgr->free_result(&r);
r = mgr->chat_stream("fail_ep", &msg, 1, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint configure failed") == 0,
"E3: chat_stream returns error when configure fails");
mgr->free_result(&r);
on_shutdown();
}
// --- E4: empty endpoints (no endpoints.names config key) / 空 endpoint 列表 ---
{
reset_test_state();
// 不设置任何 endpoint 配置 / No endpoint config at all
g_config_values.clear();
host = make_fake_host();
CHECK(on_init(&host) == 0, "E4: init with no endpoint config");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 0, "E4: zero endpoints loaded");
CHECK(mgr->get_active() == nullptr, "E4: get_active returns nullptr when no endpoints");
// 在没有 endpoint 的情况下chat 应报错 / chat should error when no endpoints
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E4: chat with no endpoints returns error");
mgr->free_result(&r);
on_shutdown();
}
// --- E5: bad active (active set to nonexistent endpoint) / 无效 active ---
{
reset_test_state();
g_config_values["endpoints.names"] = "ep1";
g_config_values["endpoints.active"] = "does_not_exist"; // 不存在的 active / nonexistent active
g_config_values["endpoint.ep1.provider"] = "ai_openai";
g_config_values["endpoint.ep1.api_key"] = "sk-test";
g_config_values["endpoint.ep1.model"] = "gpt-test";
host = make_fake_host();
CHECK(on_init(&host) == 0, "E5: init with bad active config");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E5: endpoint loaded despite bad active");
// 当 active 指向不存在的 endpoint 时get_active 应返回第一个加载有效的 endpoint
// When active points to nonexistent endpoint, get_active should return the first valid loaded endpoint
active = mgr->get_active();
CHECK(active != nullptr && std::strcmp(active, "ep1") == 0,
"E5: get_active falls back to first loaded endpoint when configured active is invalid");
on_shutdown();
}
// --- E6: set_active with null/empty name / set_active 空名称 ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E6: init");
mgr = g_registered_mgr;
CHECK(mgr->set_active(nullptr) == -1, "E6: set_active(nullptr) returns -1");
CHECK(mgr->set_active("") == -1, "E6: set_active(\"\") returns -1");
on_shutdown();
}
// ================================================================
// secure_zero 基础测试 / Basic secure_zero test
// ================================================================
{
// 分配缓冲区并填充可识别模式 / Allocate buffer and fill with recognizable pattern
char buf[64];
std::memset(buf, 0xAB, sizeof(buf));
dstalk_ai::secure_zero(buf, sizeof(buf));
bool all_zero = true;
for (int i = 0; i < (int)sizeof(buf); ++i) {
if (buf[i] != 0) { all_zero = false; break; }
}
CHECK(all_zero, "secure_zero wipes all bytes to zero");
// 空 size / zero size — 不崩溃 / should not crash
dstalk_ai::secure_zero(buf, 0);
CHECK(true, "secure_zero with size=0 does not crash");
// nullptr + zero size — 不崩溃 / should not crash
dstalk_ai::secure_zero(nullptr, 0);
CHECK(true, "secure_zero(nullptr, 0) does not crash");
}
// ================================================================
// 轻量并发读写测试 / Lightweight concurrent read/write test
// ================================================================
{
reset_test_state();
setup_single_endpoint("conc_ep", "ai_openai", "gpt-concurrent");
host = make_fake_host();
CHECK(on_init(&host) == 0, "concurrency setup: init");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "concurrency setup: endpoint loaded");
const int kReaders = 4;
const int kIters = 100;
std::vector<std::thread> threads;
std::atomic<int> errors{0};
// 读者线程: 反复调用 get_active / list_json / count / set_active
// Reader threads: repeatedly call get_active / list_json / count / set_active
for (int t = 0; t < kReaders; ++t) {
threads.emplace_back([mgr, &errors, t]() {
for (int i = 0; i < kIters; ++i) {
// 读操作 / read operations
const char* a = mgr->get_active();
if (a && std::strcmp(a, "conc_ep") != 0) errors++;
int c = mgr->count();
if (c != 1) errors++;
char* l = mgr->list_json();
if (l) fake_free(l);
else errors++;
// 轻量写操作: set_active 到同一个 endpoint / lightweight write
if (i % 10 == 0) {
int rc = mgr->set_active("conc_ep");
if (rc != 0) errors++;
}
}
});
}
for (auto& th : threads) th.join();
CHECK(errors.load() == 0, "concurrent read/write: no errors across threads");
on_shutdown();
}
// ================================================================
// 总结 / Summary
// ================================================================
if (g_failures == 0) {
std::cout << "\nendpoint_mgr_plugin_test: all checks passed\n";
} else {
std::cerr << "\nendpoint_mgr_plugin_test: " << g_failures << " failure(s)\n";
}
return g_failures == 0 ? 0 : 1;
}