feat: add AI endpoint manager plugin with configuration and routing capabilities
Some checks failed
CI / Determine matrix (push) Has been cancelled
CI / Sanitizer (ASan+UBSan) / ubuntu-24.04 (push) Has been cancelled
CI / Coverage (gcovr) / ubuntu-24.04 (push) Has been cancelled
CI / ${{ matrix.os }} / ${{ matrix.build_type }} (push) Has been cancelled

- Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints.
- Added configuration reference documentation for `config.toml`.
- Implemented endpoint loading, active endpoint switching, and model mutation.
- Included error handling for missing endpoints and configuration failures.
- Developed unit tests covering various scenarios including error paths and concurrency.
This commit is contained in:
2026-06-03 21:07:25 +08:00
parent 28ae90a6cc
commit 4745ce1f1c
18 changed files with 1570 additions and 34 deletions

View File

@@ -0,0 +1,547 @@
/*
* @file endpoint_mgr_plugin_test.cpp
* @brief AI endpoint manager plugin unit tests: endpoint loading, active/model mutation, routing, secret-safe listing,
* plus error-path coverage (null history, missing endpoint, configure failed, empty/bad active, concurrency).
* AI endpoint 管理器插件单元测试endpoint 加载、active/model 修改、路由和脱敏列表、
* 以及错误路径覆盖(空 history、缺失 endpoint、configure 失败、空/错 active、并发
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "../plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp"
#include <cstdarg>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
static int g_failures = 0;
#define CHECK(cond, msg) do { \
if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \
} else { \
std::cerr << "[FAIL] " << (msg) << "\n"; \
g_failures++; \
} \
} while (0)
struct ConfigureRecord {
std::string provider;
std::string base_url;
std::string api_key;
std::string model;
int max_tokens = 0;
double temperature = 0.0;
int configure_calls = 0;
int chat_calls = 0;
int stream_calls = 0;
};
static std::unordered_map<std::string, std::string> g_config_values;
static const dstalk_ai_endpoint_mgr_t* g_registered_mgr = nullptr;
static ConfigureRecord g_last_configure;
static int g_stream_cb_count = 0;
static void* fake_alloc(size_t size)
{
return std::malloc(size);
}
static void fake_free(void* ptr)
{
std::free(ptr);
}
static char* fake_strdup(const char* s)
{
if (!s) return nullptr;
size_t n = std::strlen(s) + 1;
char* p = static_cast<char*>(std::malloc(n));
if (p) std::memcpy(p, s, n);
return p;
}
static int fake_register_service(const char* name, int version, void* vtable)
{
if (!name || !vtable || version != 1) return -1;
if (std::strcmp(name, "ai_endpoint_mgr") == 0) {
g_registered_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(vtable);
return 0;
}
return -2;
}
static void fake_unregister_service(const char*)
{
}
static const char* fake_config_get(const char* key)
{
if (!key) return nullptr;
auto it = g_config_values.find(key);
if (it == g_config_values.end()) return nullptr;
static thread_local std::string tls_value;
tls_value = it->second;
return tls_value.c_str();
}
static int fake_config_set(const char* key, const char* value)
{
if (!key || !value) return -1;
g_config_values[key] = value;
return 0;
}
static void fake_log(int, const char*, ...)
{
}
static int fake_configure(const char* provider, const char* base_url,
const char* api_key, const char* model,
int max_tokens, double temperature)
{
g_last_configure.provider = provider ? provider : "";
g_last_configure.base_url = base_url ? base_url : "";
g_last_configure.api_key = api_key ? api_key : "";
g_last_configure.model = model ? model : "";
g_last_configure.max_tokens = max_tokens;
g_last_configure.temperature = temperature;
g_last_configure.configure_calls++;
return 0;
}
// 模拟 configure 失败的 provider service / Fake provider service whose configure always fails
static int fake_configure_fail(const char*, const char*, const char*, const char*, int, double)
{
g_last_configure.configure_calls++;
return -1;
}
static dstalk_chat_result_t fake_chat(const dstalk_message_t*, int,
const char*, const char*)
{
g_last_configure.chat_calls++;
dstalk_chat_result_t r = {};
r.ok = 1;
r.content = fake_strdup("ok");
return r;
}
static int test_stream_cb(const char*, void* userdata)
{
int* count = static_cast<int*>(userdata);
if (count) (*count)++;
return 0;
}
static dstalk_chat_result_t fake_chat_stream(const dstalk_message_t*, int,
const char*, dstalk_stream_cb cb,
void* userdata)
{
g_last_configure.stream_calls++;
if (cb) cb("tok", userdata);
dstalk_chat_result_t r = {};
r.ok = 1;
r.content = fake_strdup("stream-ok");
return r;
}
static void fake_free_result(dstalk_chat_result_t* result)
{
if (!result) return;
if (result->content) fake_free((void*)result->content);
if (result->error) fake_free((void*)result->error);
if (result->tool_calls_json) fake_free((void*)result->tool_calls_json);
result->content = nullptr;
result->error = nullptr;
result->tool_calls_json = nullptr;
}
static dstalk_ai_service_t g_fake_openai_service = {
&fake_configure,
&fake_chat,
&fake_chat_stream,
&fake_free_result,
};
static dstalk_ai_service_t g_fake_anthropic_service = {
&fake_configure,
&fake_chat,
&fake_chat_stream,
&fake_free_result,
};
// configure 总是失败的 provider 服务 / Provider service whose configure always fails
static dstalk_ai_service_t g_fake_failing_service = {
&fake_configure_fail,
&fake_chat,
&fake_chat_stream,
&fake_free_result,
};
static void* fake_query_service(const char* name, int min_version)
{
if (!name || min_version > 1) return nullptr;
if (std::strcmp(name, "ai_openai") == 0) return &g_fake_openai_service;
if (std::strcmp(name, "ai_anthropic") == 0) return &g_fake_anthropic_service;
if (std::strcmp(name, "ai_failing") == 0) return &g_fake_failing_service;
return nullptr;
}
static dstalk_host_api_t make_fake_host()
{
dstalk_host_api_t host = {};
host.register_service = fake_register_service;
host.query_service = fake_query_service;
host.unregister_service = fake_unregister_service;
host.config_get = fake_config_get;
host.config_set = fake_config_set;
host.log = fake_log;
host.alloc = fake_alloc;
host.free = fake_free;
host.strdup = fake_strdup;
return host;
}
static void setup_endpoint_config()
{
g_config_values.clear();
g_config_values["endpoints.names"] = "openai_main, anthropic_alt, missing_provider, openai_main";
g_config_values["endpoints.active"] = "anthropic_alt";
g_config_values["endpoint.openai_main.provider"] = "ai_openai";
g_config_values["endpoint.openai_main.api_key"] = "sk-openai-test";
g_config_values["endpoint.openai_main.model"] = "gpt-4o";
g_config_values["endpoint.openai_main.max_tokens"] = "1234";
g_config_values["endpoint.openai_main.temperature"] = "0.25";
g_config_values["endpoint.anthropic_alt.provider"] = "ai_anthropic";
g_config_values["endpoint.anthropic_alt.api_key"] = "sk-ant-test";
g_config_values["endpoint.anthropic_alt.model"] = "claude-sonnet-test";
g_config_values["endpoint.missing_provider.provider"] = "ai_missing";
g_config_values["endpoint.missing_provider.model"] = "missing-model";
}
// 设置单 endpoint 配置用于错误路径测试 / Set up single-endpoint config for error-path testing
static void setup_single_endpoint(const char* name, const char* provider, const char* model,
const char* base_url = nullptr)
{
g_config_values.clear();
std::string names_key = "endpoints.names";
g_config_values[names_key] = name;
std::string prefix = std::string("endpoint.") + name + ".";
g_config_values[prefix + "provider"] = provider;
g_config_values[prefix + "api_key"] = "sk-test";
g_config_values[prefix + "model"] = model;
if (base_url && *base_url) {
g_config_values[prefix + "base_url"] = base_url;
}
}
static void reset_test_state()
{
on_shutdown();
g_registered_mgr = nullptr;
g_last_configure = {};
g_config_values.clear();
}
int main()
{
// ================================================================
// 主测试流程 / Main test flow (existing)
// ================================================================
on_shutdown();
setup_endpoint_config();
g_registered_mgr = nullptr;
g_last_configure = {};
dstalk_host_api_t host = make_fake_host();
int init_rc = on_init(&host);
CHECK(init_rc == 0, "on_init registers endpoint manager service");
CHECK(g_registered_mgr != nullptr, "registered service pointer captured");
const dstalk_ai_endpoint_mgr_t* mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 2, "loads valid endpoints, skips missing provider and duplicate");
const char* active = mgr ? mgr->get_active() : nullptr;
CHECK(active && std::strcmp(active, "anthropic_alt") == 0,
"configured active endpoint is selected");
char* list = mgr ? mgr->list_json() : nullptr;
std::string list_json = list ? list : "";
if (list) fake_free(list);
CHECK(list_json.find("openai_main") != std::string::npos,
"list_json includes OpenAI endpoint");
CHECK(list_json.find("anthropic_alt") != std::string::npos,
"list_json includes Anthropic endpoint");
CHECK(list_json.find("https://api.anthropic.com") != std::string::npos,
"Anthropic endpoint uses Anthropic default base_url");
CHECK(list_json.find("sk-openai-test") == std::string::npos &&
list_json.find("sk-ant-test") == std::string::npos,
"list_json does not expose API keys");
CHECK(mgr->set_active("missing") == -2, "set_active rejects unknown endpoint");
CHECK(mgr->set_active("openai_main") == 0, "set_active accepts known endpoint");
CHECK(std::strcmp(mgr->get_active(), "openai_main") == 0,
"get_active reflects set_active change");
CHECK(mgr->set_model(nullptr, "gpt-4.1-mini") == 0,
"set_model(nullptr, model) updates active endpoint");
CHECK(g_config_values["endpoint.openai_main.model"] == "gpt-4.1-mini",
"set_model mirrors model to host config store");
CHECK(mgr->set_model("missing", "model") == -2,
"set_model rejects unknown endpoint");
CHECK(mgr->set_model("openai_main", "") == -1,
"set_model rejects empty model");
dstalk_message_t msg = {"user", "hello", nullptr, nullptr};
dstalk_chat_result_t r = mgr->chat(nullptr, &msg, 1, "hi", "[]");
CHECK(r.ok == 1 && r.content && std::strcmp(r.content, "ok") == 0,
"chat routes to active endpoint service");
CHECK(g_last_configure.provider == "ai_openai" &&
g_last_configure.base_url == "https://api.openai.com/v1" &&
g_last_configure.model == "gpt-4.1-mini" &&
g_last_configure.max_tokens == 1234 &&
g_last_configure.temperature == 0.25,
"chat configures OpenAI endpoint before routing");
mgr->free_result(&r);
CHECK(mgr->set_active("anthropic_alt") == 0, "switch active endpoint to Anthropic");
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
CHECK(r.ok == 1 && g_last_configure.provider == "ai_anthropic" &&
g_last_configure.base_url == "https://api.anthropic.com" &&
g_last_configure.model == "claude-sonnet-test",
"chat configures Anthropic endpoint before routing");
mgr->free_result(&r);
g_stream_cb_count = 0;
r = mgr->chat_stream("anthropic_alt", &msg, 1, "hi", test_stream_cb, &g_stream_cb_count);
CHECK(r.ok == 1 && g_stream_cb_count == 1,
"chat_stream routes callback through selected endpoint");
mgr->free_result(&r);
on_shutdown();
CHECK(mgr->count() == 0, "on_shutdown clears endpoint cache");
// ================================================================
// 错误路径测试 / Error-path tests (P3)
// ================================================================
// --- E1: null history with non-zero length / 空指针 history ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E1: init with single endpoint");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E1: one endpoint loaded");
// null history + non-zero len -> 应返回错误 / should return error
r = mgr->chat("test_ep", nullptr, 3, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "null history with non-zero length") == 0,
"E1: null history with len>0 returns error for chat");
mgr->free_result(&r);
r = mgr->chat_stream("test_ep", nullptr, 2, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "null history with non-zero length") == 0,
"E1: null history with len>0 returns error for chat_stream");
mgr->free_result(&r);
// null history + zero len -> 应正常 (无历史) / should pass (empty history)
r = mgr->chat("test_ep", nullptr, 0, "hi", nullptr);
CHECK(r.ok == 1, "E1: null history with len=0 passes for chat");
mgr->free_result(&r);
on_shutdown();
}
// --- E2: missing endpoint chat / 缺失 endpoint ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E2: init");
mgr = g_registered_mgr;
r = mgr->chat("nonexistent_ep", &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E2: chat with nonexistent endpoint returns error");
mgr->free_result(&r);
r = mgr->chat_stream("nonexistent_ep", &msg, 1, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E2: chat_stream with nonexistent endpoint returns error");
mgr->free_result(&r);
on_shutdown();
}
// --- E3: configure failed / configure 失败 ---
{
reset_test_state();
// ai_failing provider 的 configure 总是返回 -1 / ai_failing provider's configure always returns -1
setup_single_endpoint("fail_ep", "ai_failing", "fail-model", "https://fail.example.com/api");
host = make_fake_host();
int rc = on_init(&host);
CHECK(rc == 0, "E3: init with failing endpoint");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E3: failing endpoint loaded (service exists)");
r = mgr->chat("fail_ep", &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint configure failed") == 0,
"E3: chat returns error when configure fails");
mgr->free_result(&r);
r = mgr->chat_stream("fail_ep", &msg, 1, "hi", nullptr, nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint configure failed") == 0,
"E3: chat_stream returns error when configure fails");
mgr->free_result(&r);
on_shutdown();
}
// --- E4: empty endpoints (no endpoints.names config key) / 空 endpoint 列表 ---
{
reset_test_state();
// 不设置任何 endpoint 配置 / No endpoint config at all
g_config_values.clear();
host = make_fake_host();
CHECK(on_init(&host) == 0, "E4: init with no endpoint config");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 0, "E4: zero endpoints loaded");
CHECK(mgr->get_active() == nullptr, "E4: get_active returns nullptr when no endpoints");
// 在没有 endpoint 的情况下chat 应报错 / chat should error when no endpoints
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
CHECK(r.ok == 0 && r.error != nullptr &&
std::strcmp(r.error, "endpoint not found") == 0,
"E4: chat with no endpoints returns error");
mgr->free_result(&r);
on_shutdown();
}
// --- E5: bad active (active set to nonexistent endpoint) / 无效 active ---
{
reset_test_state();
g_config_values["endpoints.names"] = "ep1";
g_config_values["endpoints.active"] = "does_not_exist"; // 不存在的 active / nonexistent active
g_config_values["endpoint.ep1.provider"] = "ai_openai";
g_config_values["endpoint.ep1.api_key"] = "sk-test";
g_config_values["endpoint.ep1.model"] = "gpt-test";
host = make_fake_host();
CHECK(on_init(&host) == 0, "E5: init with bad active config");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "E5: endpoint loaded despite bad active");
// 当 active 指向不存在的 endpoint 时get_active 应返回第一个加载有效的 endpoint
// When active points to nonexistent endpoint, get_active should return the first valid loaded endpoint
active = mgr->get_active();
CHECK(active != nullptr && std::strcmp(active, "ep1") == 0,
"E5: get_active falls back to first loaded endpoint when configured active is invalid");
on_shutdown();
}
// --- E6: set_active with null/empty name / set_active 空名称 ---
{
reset_test_state();
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
host = make_fake_host();
CHECK(on_init(&host) == 0, "E6: init");
mgr = g_registered_mgr;
CHECK(mgr->set_active(nullptr) == -1, "E6: set_active(nullptr) returns -1");
CHECK(mgr->set_active("") == -1, "E6: set_active(\"\") returns -1");
on_shutdown();
}
// ================================================================
// secure_zero 基础测试 / Basic secure_zero test
// ================================================================
{
// 分配缓冲区并填充可识别模式 / Allocate buffer and fill with recognizable pattern
char buf[64];
std::memset(buf, 0xAB, sizeof(buf));
dstalk_ai::secure_zero(buf, sizeof(buf));
bool all_zero = true;
for (int i = 0; i < (int)sizeof(buf); ++i) {
if (buf[i] != 0) { all_zero = false; break; }
}
CHECK(all_zero, "secure_zero wipes all bytes to zero");
// 空 size / zero size — 不崩溃 / should not crash
dstalk_ai::secure_zero(buf, 0);
CHECK(true, "secure_zero with size=0 does not crash");
// nullptr + zero size — 不崩溃 / should not crash
dstalk_ai::secure_zero(nullptr, 0);
CHECK(true, "secure_zero(nullptr, 0) does not crash");
}
// ================================================================
// 轻量并发读写测试 / Lightweight concurrent read/write test
// ================================================================
{
reset_test_state();
setup_single_endpoint("conc_ep", "ai_openai", "gpt-concurrent");
host = make_fake_host();
CHECK(on_init(&host) == 0, "concurrency setup: init");
mgr = g_registered_mgr;
CHECK(mgr && mgr->count() == 1, "concurrency setup: endpoint loaded");
const int kReaders = 4;
const int kIters = 100;
std::vector<std::thread> threads;
std::atomic<int> errors{0};
// 读者线程: 反复调用 get_active / list_json / count / set_active
// Reader threads: repeatedly call get_active / list_json / count / set_active
for (int t = 0; t < kReaders; ++t) {
threads.emplace_back([mgr, &errors, t]() {
for (int i = 0; i < kIters; ++i) {
// 读操作 / read operations
const char* a = mgr->get_active();
if (a && std::strcmp(a, "conc_ep") != 0) errors++;
int c = mgr->count();
if (c != 1) errors++;
char* l = mgr->list_json();
if (l) fake_free(l);
else errors++;
// 轻量写操作: set_active 到同一个 endpoint / lightweight write
if (i % 10 == 0) {
int rc = mgr->set_active("conc_ep");
if (rc != 0) errors++;
}
}
});
}
for (auto& th : threads) th.join();
CHECK(errors.load() == 0, "concurrent read/write: no errors across threads");
on_shutdown();
}
// ================================================================
// 总结 / Summary
// ================================================================
if (g_failures == 0) {
std::cout << "\nendpoint_mgr_plugin_test: all checks passed\n";
} else {
std::cerr << "\nendpoint_mgr_plugin_test: " << g_failures << " failure(s)\n";
}
return g_failures == 0 ? 0 : 1;
}