From 4745ce1f1c6ec8654be740eeef6f70c674b934b8 Mon Sep 17 00:00:00 2001 From: XiuChengWu <732857315@qq.com> Date: Wed, 3 Jun 2026 21:07:25 +0800 Subject: [PATCH] feat: add AI endpoint manager plugin with configuration and routing capabilities - Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints. - Added configuration reference documentation for `config.toml`. - Implemented endpoint loading, active endpoint switching, and model mutation. - Included error handling for missing endpoints and configuration failures. - Developed unit tests covering various scenarios including error paths and concurrency. --- config.example.toml | 90 +++ docs/README.md | 4 +- docs/reference/config.md | 156 +++++ docs/tutorial/quick-start.md | 3 + dstalk_cli/src/main.cpp | 121 +++- dstalk_core/include/dstalk/dstalk_services.h | 36 ++ .../include/dstalk_frontend_common.hpp | 7 +- .../src/frontend_common.cpp | 20 +- dstalk_gui/src/main.cpp | 39 +- dstalk_web/src/main.cpp | 42 +- plugins_upper/CMakeLists.txt | 1 + plugins_upper/ai_endpoint_mgr/CMakeLists.txt | 32 + .../src/endpoint_mgr_plugin.cpp | 400 +++++++++++++ .../anthropic/src/anthropic_plugin.cpp | 2 + plugins_upper/openai/src/openai_plugin.cpp | 1 + tests/CMakeLists.txt | 39 ++ tests/endpoint_mgr_plugin_test.cpp | 547 ++++++++++++++++++ tests/smoke_test.cpp | 64 ++ 18 files changed, 1570 insertions(+), 34 deletions(-) create mode 100644 config.example.toml create mode 100644 docs/reference/config.md create mode 100644 plugins_upper/ai_endpoint_mgr/CMakeLists.txt create mode 100644 plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp create mode 100644 tests/endpoint_mgr_plugin_test.cpp diff --git a/config.example.toml b/config.example.toml new file mode 100644 index 0000000..e350378 --- /dev/null +++ b/config.example.toml @@ -0,0 +1,90 @@ +# ============================================================================ +# dstalk config.example.toml — 配置模板 / Configuration template +# ============================================================================ +# 复制为 config.toml 并修改即可使用。 +# Copy to config.toml and edit before use. +# +# 两种配置方式(互斥,选择一种即可): +# Two configuration modes (mutually exclusive, pick one): +# +# 1) 单 Provider 模式(简单,一个 AI 后端) +# Single-provider mode (simple, one AI backend) +# 使用 keys: ai.provider / api.base_url / api.api_key / api.model +# +# 2) 多 Endpoint 模式(高级,同时配置多个 AI 后端并通过 /endpoint 切换) +# Multi-endpoint mode (advanced, multiple AI backends switchable via /endpoint) +# 使用 keys: endpoints.names / endpoints.active / endpoint..* +# +# 如果同时配置了两种方式,ai_endpoint_mgr 的 chat 调用优先走 endpoints.*。 +# If both are configured, ai_endpoint_mgr chat calls prefer endpoints.*. +# ============================================================================ + +# ---------------------------------------------------------------------------- +# 方式 1: 单 Provider 模式 / Mode 1: Single Provider +# ---------------------------------------------------------------------------- +# 取消下方注释即可使用 / Uncomment to use + +# ai.provider = "ai_openai" # ai_openai 或 ai_anthropic / ai_openai or ai_anthropic +# +# api.base_url = "https://api.openai.com/v1" +# api.api_key = "sk-your-key-here" +# api.model = "gpt-4o" + +# 或者用 Anthropic / Or use Anthropic: +# ai.provider = "ai_anthropic" +# api.base_url = "https://api.anthropic.com" +# api.api_key = "sk-ant-your-key-here" +# api.model = "claude-sonnet-4-20250514" + + +# ---------------------------------------------------------------------------- +# 方式 2: 多 Endpoint 模式 / Mode 2: Multi-Endpoint +# ---------------------------------------------------------------------------- +# 取消下方注释即可使用,多个 endpoint 在对话中通过 /endpoint 命令切换。 +# Uncomment to use. Switch endpoints during a session via the /endpoint command. + +# 逗号分隔的 endpoint 名称列表(至少一个) / Comma-separated endpoint names (at least one) +# endpoints.names = "openai_main, anthropic_alt" + +# 默认激活的 endpoint(可选,不设置则取列表第一个) / Default active endpoint (optional, defaults to first in list) +# endpoints.active = "openai_main" + +# --- openai_main --- +# provider 必须 / required +# endpoint.openai_main.provider = "ai_openai" + +# base_url 可选 / optional (默认按 provider 自动填 / defaults by provider: OpenAI→api.openai.com/v1, Anthropic→api.anthropic.com) +# endpoint.openai_main.base_url = "https://api.openai.com/v1" + +# api_key 必须 / required +# endpoint.openai_main.api_key = "sk-your-key-here" + +# model 必须 / required +# endpoint.openai_main.model = "gpt-4o" + +# max_tokens 可选 / optional (默认 4096) / default 4096 +# endpoint.openai_main.max_tokens = 4096 + +# temperature 可选 / optional (默认 0.7, 范围 0.0~2.0) / default 0.7, range 0.0~2.0 +# endpoint.openai_main.temperature = 0.7 + + +# --- anthropic_alt --- +# endpoint.anthropic_alt.provider = "ai_anthropic" +# base_url 未设置时自动使用 https://api.anthropic.com / defaults to https://api.anthropic.com when unset +# endpoint.anthropic_alt.api_key = "sk-ant-your-key-here" +# endpoint.anthropic_alt.model = "claude-sonnet-4-20250514" +# endpoint.anthropic_alt.max_tokens = 8192 + + +# --- deepseek (自定义 base_url 示例 / custom base_url example) --- +# endpoint.deepseek.provider = "ai_openai" +# endpoint.deepseek.base_url = "https://api.deepseek.com/v1" +# endpoint.deepseek.api_key = "sk-deepseek-your-key-here" +# endpoint.deepseek.model = "deepseek-chat" + +# 提示 / Tips: +# - list_json 输出不含 api_key(安全脱敏)。 +# list_json output excludes api_key (security de-identification). +# - 未知 provider 必须显式配置 base_url,否则 endpoint 加载失败。 +# Unknown providers must explicitly set base_url or the endpoint won't load. diff --git a/docs/README.md b/docs/README.md index c5577f6..2740a4d 100644 --- a/docs/README.md +++ b/docs/README.md @@ -24,6 +24,7 @@ | 文档 | 说明 | |------|------| | [CLI 命令速查](reference/commands.md) | 全部 CLI 命令的别名、作用与示例 | +| [配置参考](reference/config.md) | config.toml 单 provider 与多 endpoint 全部字段说明、默认值与安全提示 | | [Plugin ABI 契约](reference/plugin-abi.md) | 跨 DLL 通信的 C ABI 规范:内存所有权、堆纪律、回调线程安全 | --- @@ -53,7 +54,7 @@ ### 参考 - [ ] **API 参考** (`reference/api.md`) — TODO: dstalk_host.h 完整 API 说明与调用示例 -- [ ] **配置参考** (`reference/config.md`) — TODO: config.toml 所有字段的详细说明 +- [x] **配置参考** (`reference/config.md`) — config.toml 所有字段的详细说明、默认值与安全提示 - [ ] **服务接口参考** (`reference/services.md`) — TODO: dstalk_services.h 中所有 vtable 接口定义 --- @@ -63,6 +64,7 @@ - 命令以 `$ ` 前缀表示, 在终端中运行 - 代码块标注语言 (```c, ```toml, ```bash 等) - [ ] 表示计划中未完成的文档 +- 配置文件模板见项目根目录 `config.example.toml` --- diff --git a/docs/reference/config.md b/docs/reference/config.md new file mode 100644 index 0000000..9e78fad --- /dev/null +++ b/docs/reference/config.md @@ -0,0 +1,156 @@ +# 配置参考 / Configuration Reference + +`config.toml` 是 dstalk 的唯一配置文件,放在项目根目录。本文档列出所有支持字段、类型、默认值与使用说明。 + +--- + +## 配置概览 + +配置分为两种模式: + +| 模式 | 适用场景 | 核心 key | +|------|----------|----------| +| **单 Provider** | 只需一个 AI 后端 | `ai.provider`, `api.*` | +| **多 Endpoint** | 同时配置多个 AI 后端,运行时切换 | `endpoints.names`, `endpoint..*` | + +两种模式可以同时配置,但 `ai_endpoint_mgr` 的 `chat` / `chat_stream` 调用优先走 `endpoints.*`。 + +--- + +## 一、单 Provider 模式 (Legacy) + +直接在 `[global]` 下声明。这些 key 无默认值——不配置则无法启动 AI 对话。 + +### ai.provider + +- **类型**: string +- **默认**: 无(必填) +- **值**: `"ai_openai"` 或 `"ai_anthropic"` +- **说明**: 指定要加载的 AI provider 插件。该名称对应插件 DLL 注册的服务名。 + +### api.base_url + +- **类型**: string +- **默认**: 无(必填) +- **说明**: AI API 的基础 URL。例如: + - OpenAI: `https://api.openai.com/v1` + - Anthropic: `https://api.anthropic.com` + - DeepSeek (OpenAI 兼容): `https://api.deepseek.com/v1` + +### api.api_key + +- **类型**: string +- **默认**: 无(必填) +- **说明**: API 密钥。注意保管,不要提交到版本控制。 + +### api.model + +- **类型**: string +- **默认**: 无(必填) +- **说明**: 要使用的模型名称。例如 `gpt-4o`、`claude-sonnet-4-20250514`、`deepseek-chat`。 + +--- + +## 二、多 Endpoint 模式 (推荐) + +由 `ai_endpoint_mgr` 插件管理。每个 endpoint 是一个命名的 AI 后端配置,运行时可通过 `/endpoint` 命令切换。 + +### endpoints.names + +- **类型**: string (逗号分隔) +- **默认**: 无 +- **说明**: 所有要启用的 endpoint 名称列表。例如 `"openai_main, anthropic_alt"`。重复名称会被跳过并输出警告。 + +### endpoints.active + +- **类型**: string +- **默认**: 取 `endpoints.names` 中第一个成功加载的 endpoint +- **说明**: 启动时默认使用的 endpoint 名称。如果指定的名称不在 `endpoints.names` 中或对应配置无效,则回退到第一个成功加载的 endpoint。 + +--- + +### endpoint.\.provider + +- **类型**: string +- **默认**: 无(必填) +- **值**: `"ai_openai"` 或 `"ai_anthropic"` +- **说明**: 该 endpoint 使用的 AI provider 服务名称。 + +### endpoint.\.base_url + +- **类型**: string +- **默认**: 按 provider 自动推导 +- **说明**: + - 未配置时,已知 provider 自动使用默认 base_url: + - `ai_anthropic` → `https://api.anthropic.com` + - `ai_openai` → `https://api.openai.com/v1` + - 如果 provider 不在已知列表中,**必须**显式配置 `base_url`,否则该 endpoint 加载失败。 + - 显式配置的值优先于默认值,可用于自定义端点(如代理、DeepSeek 兼容 API 等)。 + +### endpoint.\.api_key + +- **类型**: string +- **默认**: 无(必填) +- **说明**: 该 endpoint 的 API 密钥。 + +> **安全提示**: `api_key` **不会**出现在 `list_json()` 的输出中,也不进入日志。这是有意为之的安全策略。 + +### endpoint.\.model + +- **类型**: string +- **默认**: 无(必填) +- **说明**: 该 endpoint 默认使用的模型名称。运行时可通过 `set_model()` 修改并同步回 `config.toml`。 + +### endpoint.\.max_tokens + +- **类型**: integer +- **默认**: `4096` +- **范围**: `1` ~ `1000000`(超出范围的值视为无效,回退到默认值) +- **说明**: 每次请求的最大输出 token 数。 + +### endpoint.\.temperature + +- **类型**: double +- **默认**: `0.7` +- **范围**: `0.0` ~ `2.0`(超出范围的值视为无效,回退到默认值) +- **说明**: 生成温度。越高随机性越强,越低越确定性。 + +--- + +## 三、ui 配置 (可选 / optional) + +以下 key 目前为 CLI 前端预留,后续 GUI/Web 前端也会使用。 + +### ui.prompt + +- **类型**: string +- **默认**: `"> "` +- **说明**: 命令行提示符字符串。 + +### ui.multiline + +- **类型**: string +- **默认**: `"/"` (即不支持多行输入) +- **说明**: 多行输入结束符。发送该字符串结束多行输入模式。 + +--- + +## 四、完整配置示例 + +见项目根目录 `config.example.toml`。 + +--- + +## 五、运行时 endpoint 操作 + +`ai_endpoint_mgr` 服务提供以下接口(通过 C ABI)供前端调用: + +| 操作 | 说明 | +|------|------| +| `count()` | 返回已加载的 endpoint 数量 | +| `list_json()` | 返回 JSON 数组,每项含 `name`, `provider`, `base_url`, `model`, `active` (不含 `api_key`) | +| `get_active()` | 返回当前激活的 endpoint 名称 | +| `set_active(name)` | 切换激活 endpoint;返回 0 成功,-2 表示不存在 | +| `set_model(name, model)` | 修改 endpoint 的模型并同步到 config;name 为 null 时作用于当前 active endpoint | +| `chat(ep, history, len, input, tools)` | 路由对话到指定 endpoint(ep 为 null 时用 active) | +| `chat_stream(ep, history, len, input, cb, userdata)` | 流式路由对话 | diff --git a/docs/tutorial/quick-start.md b/docs/tutorial/quick-start.md index 5b89cc7..3f16391 100644 --- a/docs/tutorial/quick-start.md +++ b/docs/tutorial/quick-start.md @@ -96,6 +96,8 @@ api.model = "gpt-4o" > **关键**: 修改 `ai.provider` 字段即可在不同后端间切换, 无需改动代码。 > > API Key 可从 [OpenAI-compatible 开放平台](https://platform.openai.com/) 或 [Anthropic Console](https://console.anthropic.com/) 获取。 +> +> **多 Endpoint 配置** (同时使用多个 AI 后端并在运行时通过 `/endpoint` 切换): 参见 [配置参考](../reference/config.md) 和项目根目录 `config.example.toml`。 --- @@ -172,5 +174,6 @@ dstalk v0.1.0 | dstalk AI | /help 查看帮助 | /quit 退出 ## 下一步 - 查看 [CLI 命令速查表](../reference/commands.md) 了解全部命令 +- 查看 [配置参考](../reference/config.md) 了解 config.toml 所有字段 - 输入 `/help` 在 dstalk 内查看命令列表 - 输入 `/status` 查看当前运行状态 diff --git a/dstalk_cli/src/main.cpp b/dstalk_cli/src/main.cpp index 8d1844a..218952c 100644 --- a/dstalk_cli/src/main.cpp +++ b/dstalk_cli/src/main.cpp @@ -56,6 +56,7 @@ static const dstalk_ai_service_t* g_ai = nullptr; static const dstalk_session_service_t* g_session = nullptr; static const dstalk_file_io_service_t* g_file_io = nullptr; static const dstalk_tools_service_t* g_tools = nullptr; +static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional // ---- 运行时状态 / Runtime state ---- // g_current_model tracks the active model name for display in the prompt. @@ -134,6 +135,61 @@ static void spinner_join() } } +// ---- AI 调用路由(endpoint_mgr 优先,g_ai fallback)/ AI call routing (endpoint_mgr preferred, g_ai fallback) ---- +// 当 endpoint_mgr 可用且至少有一个已配置 endpoint 时,通过 endpoint_mgr 路由调用; +// 否则回退到直接使用 g_ai 服务(保持旧配置兼容)。 +// When endpoint_mgr is available with >=1 configured endpoints, route through it; +// otherwise fall back to direct g_ai service (keeping old config compatible). + +// 是否有可用的 endpoint_mgr / Whether endpoint_mgr is usable +static inline bool has_endpoint_mgr() +{ + return g_endpoint_mgr != nullptr && g_endpoint_mgr->count() > 0; +} + +// 是否有任一 AI 后端 / Whether any AI backend is usable +static inline bool has_ai_backend() +{ + return has_endpoint_mgr() || g_ai != nullptr; +} + +// 阻塞 chat 路由 / Blocking chat routing +static dstalk_chat_result_t do_chat( + const dstalk_message_t* history, int history_len, + const char* user_input, const char* tools_json) +{ + if (has_endpoint_mgr()) + return g_endpoint_mgr->chat(nullptr, history, history_len, user_input, tools_json); + return g_ai->chat(history, history_len, user_input, tools_json); +} + +// 流式 chat 路由 / Streaming chat routing +static dstalk_chat_result_t do_chat_stream( + const dstalk_message_t* history, int history_len, + const char* user_input, dstalk_stream_cb cb, void* userdata) +{ + if (has_endpoint_mgr()) + return g_endpoint_mgr->chat_stream(nullptr, history, history_len, user_input, cb, userdata); + return g_ai->chat_stream(history, history_len, user_input, cb, userdata); +} + +// 释放 chat result(使用对应服务) / Free chat result (use corresponding service) +static void do_free_result(dstalk_chat_result_t* result) +{ + if (has_endpoint_mgr()) + g_endpoint_mgr->free_result(result); + else + g_ai->free_result(result); +} + +// 设置模型(endpoint_mgr 优先) / Set model (endpoint_mgr preferred) +static int do_set_model(const char* model) +{ + if (has_endpoint_mgr()) + return g_endpoint_mgr->set_model(nullptr, model); + return g_ai->configure(nullptr, nullptr, nullptr, model, 0, 0.0); +} + // ---- 错误分类与友好提示 / Error classification and user-friendly messages ---- // 根据 HTTP 状态码和错误消息字符串匹配,将常见错误归类为认证/频率限制/网络/配额问题,并给出中文建议。 // Classifies common errors into auth/rate-limit/network/quota categories based on HTTP status and string matching, with Chinese suggestions. @@ -382,6 +438,21 @@ static void handle_command(const char* line) const dstalk_tools_service_t* tools = static_cast( dstalk_service_query("tools", 1)); std::printf(" Tools 服务: %s\n", tools ? "就绪" : "不可用"); + + // I08/I09: endpoint manager 状态 / endpoint manager status + if (g_endpoint_mgr) { + std::printf(" --- Endpoint Manager ---\n"); + std::printf(" 状态: 就绪 (%d endpoint(s))\n", g_endpoint_mgr->count()); + const char* active = g_endpoint_mgr->get_active(); + std::printf(" Active Endpoint: %s\n", active ? active : "(无)"); + char* list_json = g_endpoint_mgr->list_json(); + if (list_json) { + std::printf(" Endpoints: %s\n", list_json); // JSON 不含 api_key,已脱敏 / no api_key in JSON, already desensitized + dstalk_free(list_json); + } + } else { + std::printf(" Endpoint Manager: 不可用\n"); + } return; } @@ -393,7 +464,15 @@ static void handle_command(const char* line) std::printf(CLR_RED "[ERROR] /model 需要模型名\n" CLR_RESET); return; } - if (g_ai) { + // I08: 优先通过 endpoint_mgr 设置模型,fallback 到 g_ai->configure / prefer endpoint_mgr, fallback to g_ai + if (g_endpoint_mgr && g_endpoint_mgr->count() > 0) { + if (g_endpoint_mgr->set_model(nullptr, model) == 0) { + g_current_model = model; + std::printf(CLR_GREEN "[OK] 模型已切换: %s (via endpoint_mgr)\n" CLR_RESET, model); + } else { + std::printf(CLR_RED "[ERROR] 模型切换失败(endpoint 不存在或未配置)\n" CLR_RESET); + } + } else if (g_ai) { g_ai->configure(nullptr, nullptr, nullptr, model, 0, 0.0); g_current_model = model; std::printf(CLR_GREEN "[OK] 模型已切换: %s\n" CLR_RESET, model); @@ -645,6 +724,9 @@ int main(int argc, char* argv[]) g_session = static_cast(dstalk_service_query("session", 1)); g_file_io = static_cast(dstalk_service_query("file_io", 1)); g_tools = static_cast(dstalk_service_query("tools", 1)); + // I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service) + g_endpoint_mgr = static_cast( + dstalk_service_query("ai_endpoint_mgr", 1)); if (!g_ai) { std::fprintf(stderr, CLR_RED "[dstalk] AI 服务未找到(请检查插件目录)\n" CLR_RESET); @@ -663,6 +745,12 @@ int main(int argc, char* argv[]) g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7); g_current_model = model; // A1: 记录当前模型名 / Record current model name } + // I08: 记录 endpoint_mgr 可用性 / log endpoint_mgr availability + if (g_endpoint_mgr && g_endpoint_mgr->count() > 0) { + const char* active = g_endpoint_mgr->get_active(); + std::fprintf(stderr, "[dstalk] endpoint_mgr: %d endpoint(s), active=%s\n", + g_endpoint_mgr->count(), active ? active : "(none)"); + } if (!batch_mode) { std::printf("\n"); @@ -678,22 +766,23 @@ int main(int argc, char* argv[]) dstalk_shutdown(); return EXIT_FATAL; } - if (!g_ai || !g_session) { + if (!has_ai_backend() || !g_session) { std::fprintf(stderr, CLR_RED "[ERROR] AI or session service unavailable\n" CLR_RESET); dstalk_shutdown(); return EXIT_CONFIG; } int history_count = 0; const dstalk_message_t* history = g_session->history(&history_count); - dstalk_chat_result_t result = g_ai->chat(history, history_count, input.c_str(), nullptr); + // I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai + dstalk_chat_result_t result = do_chat(history, history_count, input.c_str(), nullptr); if (result.ok) { std::printf("%s\n", result.content ? result.content : ""); - g_ai->free_result(&result); + do_free_result(&result); dstalk_shutdown(); return EXIT_OK; } else { print_error(result.error, result.http_status); - g_ai->free_result(&result); + do_free_result(&result); dstalk_shutdown(); return EXIT_FATAL; } @@ -718,22 +807,23 @@ int main(int argc, char* argv[]) } prompt_text = prompt_arg; } - if (!g_ai || !g_session) { + if (!has_ai_backend() || !g_session) { std::fprintf(stderr, CLR_RED "[ERROR] AI or session service unavailable\n" CLR_RESET); dstalk_shutdown(); return EXIT_CONFIG; } int history_count = 0; const dstalk_message_t* history = g_session->history(&history_count); - dstalk_chat_result_t result = g_ai->chat(history, history_count, prompt_text.c_str(), nullptr); + // I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai + dstalk_chat_result_t result = do_chat(history, history_count, prompt_text.c_str(), nullptr); if (result.ok) { std::printf("%s\n", result.content ? result.content : ""); - g_ai->free_result(&result); + do_free_result(&result); dstalk_shutdown(); return EXIT_OK; } else { print_error(result.error, result.http_status); - g_ai->free_result(&result); + do_free_result(&result); dstalk_shutdown(); return EXIT_FATAL; } @@ -770,7 +860,7 @@ int main(int argc, char* argv[]) } // AI 对话(通过插件服务 vtable) / AI chat (via plugin service vtable) - if (!g_ai || !g_session) { + if (!has_ai_backend() || !g_session) { std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET); continue; } @@ -782,7 +872,8 @@ int main(int argc, char* argv[]) // 启动 spinner,等待 AI 响应 / Start spinner while waiting for AI response spinner_start(); bool first = true; - dstalk_chat_result_t result = g_ai->chat_stream( + // I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai + dstalk_chat_result_t result = do_chat_stream( history, history_count, line.c_str(), on_stream_token, &first); // 确保 spinner 已停止(处理无流式输出的情况) / Ensure spinner is stopped (handles no-stream-output case) @@ -866,10 +957,12 @@ int main(int argc, char* argv[]) history_count = 0; history = g_session->history(&history_count); - g_ai->free_result(&result); + // I08: 通过 endpoint_mgr 路由 free_result / route free_result via endpoint_mgr + do_free_result(&result); spinner_start(); bool tool_stream_first = true; - result = g_ai->chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first); + // I08: 通过 endpoint_mgr 路由 chat_stream / route chat_stream via endpoint_mgr + result = do_chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first); spinner_stop(); if (result.ok) { @@ -896,7 +989,7 @@ int main(int argc, char* argv[]) std::printf(CLR_RESET "\n"); print_error(result.error, result.http_status); } - g_ai->free_result(&result); + do_free_result(&result); } // B2: 单一退出点,dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode) diff --git a/dstalk_core/include/dstalk/dstalk_services.h b/dstalk_core/include/dstalk/dstalk_services.h index aa1f771..3770118 100644 --- a/dstalk_core/include/dstalk/dstalk_services.h +++ b/dstalk_core/include/dstalk/dstalk_services.h @@ -35,6 +35,42 @@ typedef struct { void (*free_result)(dstalk_chat_result_t* result); } dstalk_ai_service_t; +/* ---- AI endpoint manager 服务 vtable / AI endpoint manager service vtable ---- */ +/* 以服务名称 "ai_endpoint_mgr" 注册;用于按名称路由到多个 provider/model endpoint。 / Registered as "ai_endpoint_mgr"; routes named endpoints to provider/model configs. */ +typedef struct { + /* 返回已配置 endpoint 数量 / Return configured endpoint count. */ + int (*count)(void); + /* 返回 endpoint 列表 JSON,调用方用 dstalk_free 释放 / Return endpoint list JSON; caller frees with dstalk_free. */ + char* (*list_json)(void); + /* 获取当前 active endpoint 名称 / Get current active endpoint name. + 返回的指针指向 thread_local 存储区,调用方不需要释放。该指针在同线程下一次 + get_active 调用前或 active endpoint 状态变化前有效;跨线程并发调用各自拥有 + 独立的 thread_local 副本,互不干扰。 + The returned pointer points to thread-local storage; caller must not free it. + Valid until the next get_active call on the same thread or until the active + endpoint changes. Concurrent calls from different threads each have their own + independent thread-local copy. */ + const char* (*get_active)(void); + /* 设置当前 active endpoint / Set current active endpoint. */ + int (*set_active)(const char* endpoint_name); + /* 设置 endpoint 模型名;endpoint_name 为空时修改 active endpoint / Set endpoint model; null endpoint_name updates active endpoint. */ + int (*set_model)(const char* endpoint_name, const char* model); + /* 在指定 endpoint 上执行阻塞 chat / Run blocking chat on a named endpoint. */ + dstalk_chat_result_t (*chat)( + const char* endpoint_name, + const dstalk_message_t* history, int history_len, + const char* user_input, + const char* tools_json); + /* 在指定 endpoint 上执行流式 chat / Run streaming chat on a named endpoint. */ + dstalk_chat_result_t (*chat_stream)( + const char* endpoint_name, + const dstalk_message_t* history, int history_len, + const char* user_input, + dstalk_stream_cb cb, void* userdata); + /* 释放 endpoint manager 返回的 chat result / Free chat result returned by endpoint manager. */ + void (*free_result)(dstalk_chat_result_t* result); +} dstalk_ai_endpoint_mgr_t; + /* ---- 会话服务 vtable / Session service vtable ---- */ /* 以服务名称 "session" 注册 / Registered under service name "session" */ typedef struct { diff --git a/dstalk_frontend_common/include/dstalk_frontend_common.hpp b/dstalk_frontend_common/include/dstalk_frontend_common.hpp index 703af21..60603f5 100644 --- a/dstalk_frontend_common/include/dstalk_frontend_common.hpp +++ b/dstalk_frontend_common/include/dstalk_frontend_common.hpp @@ -4,7 +4,7 @@ // 提供所有前端(CLI / GUI / Web)共享的启动逻辑: // - 配置文件发现(argv / 默认路径 / 平台 fopen) // - dstalk_init() 调用 -// - 常用服务查询(ai, session, file_io, tools, context) +// - 常用服务查询(ai / endpoint_mgr / session / file_io / tools / context) // - AI 服务默认配置(从 config 读取,带 fallback) // ============================================================================ @@ -20,6 +20,7 @@ struct FrontendServices { const dstalk_session_service_t* session = nullptr; const dstalk_file_io_service_t* file_io = nullptr; const dstalk_tools_service_t* tools = nullptr; + const dstalk_ai_endpoint_mgr_t* endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional std::string provider; // "ai.deepseek" / "ai.openai" / "ai.anthropic" std::string model; // e.g. "deepseek-v4-pro" @@ -35,8 +36,8 @@ struct FrontendServices { // 功能: // 1. 发现配置文件:优先 argv[1](跳过已知标志),其次 default_config(如 "config.toml") // 2. 调用 dstalk_init(config_path) -// 3. 查询常用插件服务(ai / session / file_io / tools) -// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置默认值 +// 3. 查询常用插件服务(ai / endpoint_mgr / session / file_io / tools) +// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置旧单 provider 默认值 // // 参数: // svc - [out] 填入查询到的服务指针和配置信息 diff --git a/dstalk_frontend_common/src/frontend_common.cpp b/dstalk_frontend_common/src/frontend_common.cpp index 1bb3c09..aed0f85 100644 --- a/dstalk_frontend_common/src/frontend_common.cpp +++ b/dstalk_frontend_common/src/frontend_common.cpp @@ -79,12 +79,19 @@ int dstalk_frontend_init(FrontendServices& svc, svc.tools = static_cast( dstalk_service_query("tools", 1)); + // I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service) + svc.endpoint_mgr = static_cast( + dstalk_service_query("ai_endpoint_mgr", 1)); + const dstalk_context_service_t* ctx_svc = static_cast( dstalk_service_query("context", 1)); (void)ctx_svc; // 不强制使用,保留以备将来使用 - if (!svc.ai) { + // endpoint_mgr 可作为 AI 后端;缺少旧 ai provider 时仍允许多 endpoint 配置工作。 + // endpoint_mgr can serve as the AI backend; allow multi-endpoint config even when legacy ai provider is absent. + const bool endpoint_mgr_ready = svc.endpoint_mgr && svc.endpoint_mgr->count() > 0; + if (!svc.ai && !endpoint_mgr_ready) { std::fprintf(stderr, "[dstalk] AI 服务未找到(请检查插件目录)\n"); return 2; } @@ -93,7 +100,8 @@ int dstalk_frontend_init(FrontendServices& svc, return 3; } - // (3) 配置 AI 服务的默认值 + // (3) 配置 AI 服务的默认值;endpoint_mgr 路径由 endpoint..* 配置驱动。 + // Configure legacy AI defaults; endpoint_mgr is driven by endpoint..* config. const char* base_url = dstalk_config_get("api.base_url"); const char* api_key = dstalk_config_get("api.api_key"); const char* model = dstalk_config_get("api.model"); @@ -105,9 +113,11 @@ int dstalk_frontend_init(FrontendServices& svc, svc.api_key = api_key ? api_key : ""; svc.model = model; - svc.ai->configure(provider, base_url, - api_key ? api_key : "", - model, 4096, 0.7); + if (svc.ai) { + svc.ai->configure(provider, base_url, + api_key ? api_key : "", + model, 4096, 0.7); + } svc.initialized = true; return 0; diff --git a/dstalk_gui/src/main.cpp b/dstalk_gui/src/main.cpp index dd2df1e..ad85044 100644 --- a/dstalk_gui/src/main.cpp +++ b/dstalk_gui/src/main.cpp @@ -27,6 +27,7 @@ // 在启动时从主机查询获取的服务 vtable 全局指针。 static const dstalk_ai_service_t* g_ai_svc = nullptr; static const dstalk_session_service_t* g_session_svc = nullptr; +static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional // ---- 常量 / Constants ---- @@ -287,10 +288,19 @@ static void renderStatusBar(AppContext& ctx) { } // 状态文本:模型名 | 消息条数 | 流式状态 / Status text: model name | message count | streaming state + // I08: 添加 endpoint manager 信息(如果可用)/ add endpoint manager info (if available) char buf[256]; - snprintf(buf, sizeof(buf), "%s | %d messages | %s", - gs.model_name.c_str(), msgCount, - gs.streaming ? "streaming" : "ready"); + const char* active_ep = nullptr; + if (g_endpoint_mgr) active_ep = g_endpoint_mgr->get_active(); + if (active_ep && active_ep[0]) { + snprintf(buf, sizeof(buf), "%s | %d messages | %s | ep:%s", + gs.model_name.c_str(), msgCount, + gs.streaming ? "streaming" : "ready", active_ep); + } else { + snprintf(buf, sizeof(buf), "%s | %d messages | %s", + gs.model_name.c_str(), msgCount, + gs.streaming ? "streaming" : "ready"); + } drawText(r, static_cast(PADDING), barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE); } @@ -832,6 +842,9 @@ int main(int argc, char* argv[]) { if (!ai_provider) ai_provider = "ai_openai"; g_ai_svc = static_cast(dstalk_service_query(ai_provider, 1)); g_session_svc = static_cast(dstalk_service_query("session", 1)); + // I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service) + g_endpoint_mgr = static_cast( + dstalk_service_query("ai_endpoint_mgr", 1)); if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)"); if (!g_session_svc) dstalk_log(3, "Session service not found"); @@ -899,15 +912,25 @@ int main(int argc, char* argv[]) { std::string& userMsg = ctx.state.messages[ctx.state.messages.size() - 2].content; int rc = -1; - if (g_ai_svc) { + // I08: 优先通过 endpoint_mgr 路由,fallback 到 g_ai_svc / prefer endpoint_mgr, fallback to g_ai_svc + const bool use_mgr = (g_endpoint_mgr && g_endpoint_mgr->count() > 0); + if (use_mgr || g_ai_svc) { int hcount = 0; const dstalk_message_t* history = g_session_svc ? g_session_svc->history(&hcount) : nullptr; - dstalk_chat_result_t result = g_ai_svc->chat_stream( - history, hcount, userMsg.c_str(), - streamTokenCallback, &ctx); + dstalk_chat_result_t result; + if (use_mgr) { + result = g_endpoint_mgr->chat_stream( + nullptr, history, hcount, userMsg.c_str(), + streamTokenCallback, &ctx); + } else { + result = g_ai_svc->chat_stream( + history, hcount, userMsg.c_str(), + streamTokenCallback, &ctx); + } rc = result.ok ? 0 : -1; - g_ai_svc->free_result(&result); + if (use_mgr) g_endpoint_mgr->free_result(&result); + else g_ai_svc->free_result(&result); } // 流式传输完成(或被取消) diff --git a/dstalk_web/src/main.cpp b/dstalk_web/src/main.cpp index 66d155e..461829c 100644 --- a/dstalk_web/src/main.cpp +++ b/dstalk_web/src/main.cpp @@ -43,6 +43,7 @@ class SseSession; // 插件服务 vtable 的全局指针,在启动时从主机查询获取。 static const dstalk_ai_service_t* g_ai = nullptr; static const dstalk_session_service_t* g_session = nullptr; +static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional // ---- 运行时状态 / Runtime state ---- // g_quit signals the main loop to exit (set by Ctrl+C). @@ -208,8 +209,19 @@ static void run_chat_worker( }; // 调用流式 AI 聊天 / Call streaming AI chat - dstalk_chat_result_t result = g_ai->chat_stream( - history, history_count, nullptr, token_cb, &cb_data); + // I08: 优先通过 endpoint_mgr 路由,fallback 到 g_ai / prefer endpoint_mgr, fallback to g_ai + dstalk_chat_result_t result = {}; + const bool use_mgr = (g_endpoint_mgr && g_endpoint_mgr->count() > 0); + if (use_mgr) { + result = g_endpoint_mgr->chat_stream( + nullptr, history, history_count, nullptr, token_cb, &cb_data); + } else if (g_ai) { + result = g_ai->chat_stream( + history, history_count, nullptr, token_cb, &cb_data); + } else { + result.ok = 0; + result.error = dstalk_strdup("AI service unavailable"); + } // 将 AI 回复加入会话 / Add AI reply to session if (result.ok) { @@ -221,7 +233,11 @@ static void run_chat_worker( bool ok = result.ok; std::string content_copy = result.content ? result.content : ""; std::string error_copy = result.error ? result.error : ""; - g_ai->free_result(&result); + + // I08: 根据路由来源释放 result / free result based on routing source + if (use_mgr) g_endpoint_mgr->free_result(&result); + else if (g_ai) g_ai->free_result(&result); + else if (result.error) dstalk_free((void*)result.error); asio::post(ioc, [weak_sse, ok, content_copy, error_copy]() { if (auto sse = weak_sse.lock()) { @@ -373,6 +389,23 @@ private: if (model) st["model"] = std::string(model); st["status"] = "running"; + // I08/I09: endpoint manager 状态 / endpoint manager status + if (g_endpoint_mgr) { + st["endpoint_mgr_available"] = true; + st["endpoint_count"] = g_endpoint_mgr->count(); + const char* active = g_endpoint_mgr->get_active(); + if (active) st["active_endpoint"] = std::string(active); + char* list_json = g_endpoint_mgr->list_json(); + if (list_json) { + boost::system::error_code ec; + auto jv = boost::json::parse(list_json, ec); + if (!ec) st["endpoints"] = std::move(jv); + dstalk_free(list_json); + } + } else { + st["endpoint_mgr_available"] = false; + } + auto self = shared_from_this(); http::response res{http::status::ok, request_.version()}; res.set("Access-Control-Allow-Origin", "*"); @@ -510,6 +543,9 @@ int main(int argc, char* argv[]) if (!ai_provider) ai_provider = "ai_openai"; g_ai = static_cast(dstalk_service_query(ai_provider, 1)); g_session = static_cast(dstalk_service_query("session", 1)); + // I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service) + g_endpoint_mgr = static_cast( + dstalk_service_query("ai_endpoint_mgr", 1)); if (!g_ai) { std::fprintf(stderr, "[dstalk_web] AI service not found (check plugins directory)\n"); diff --git a/plugins_upper/CMakeLists.txt b/plugins_upper/CMakeLists.txt index f274066..9bb077d 100644 --- a/plugins_upper/CMakeLists.txt +++ b/plugins_upper/CMakeLists.txt @@ -6,3 +6,4 @@ add_subdirectory(ai_common) # 共享 AI 工具库(静态库)/ shared AI u add_subdirectory(context) # 依赖 session / depends on session add_subdirectory(openai) # 依赖 http, config, ai_common / depends on http, config, ai_common add_subdirectory(anthropic) # 依赖 http, config, ai_common / depends on http, config, ai_common +add_subdirectory(ai_endpoint_mgr) # 路由多个 AI endpoint / routes multiple AI endpoints diff --git a/plugins_upper/ai_endpoint_mgr/CMakeLists.txt b/plugins_upper/ai_endpoint_mgr/CMakeLists.txt new file mode 100644 index 0000000..d1f22b3 --- /dev/null +++ b/plugins_upper/ai_endpoint_mgr/CMakeLists.txt @@ -0,0 +1,32 @@ +# ============================================================ +# AI endpoint manager plugin / AI endpoint manager 插件 +# ============================================================ + +find_package(Boost REQUIRED CONFIG) + +add_library(plugin_ai_endpoint_mgr SHARED + src/endpoint_mgr_plugin.cpp +) + +target_include_directories(plugin_ai_endpoint_mgr + PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk_core/include + ${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include +) + +target_link_libraries(plugin_ai_endpoint_mgr + PRIVATE + dstalk + ai_common + dstalk_boost_config + boost::boost +) + +# cxx_std_20 已由 dstalk 和 ai_common (PUBLIC) 传播,无需重复声明 +# cxx_std_20 is already propagated by dstalk and ai_common (PUBLIC); no need to redeclare + +set_target_properties(plugin_ai_endpoint_mgr PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp b/plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp new file mode 100644 index 0000000..bac80a9 --- /dev/null +++ b/plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp @@ -0,0 +1,400 @@ +/* + * @file endpoint_mgr_plugin.cpp + * @brief AI endpoint manager: routes named endpoint configs to AI provider services. + * AI endpoint 管理器:把命名 endpoint 配置路由到具体 AI provider 服务。 + * Copyright (c) 2026 dstalk contributors. GPLv3. + */ + +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" +#include "ai_common.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace json = boost::json; + +struct EndpointConfig { + std::string name; + std::string provider; + std::string base_url; + std::string api_key; + std::string model; + int max_tokens = 4096; + double temperature = 0.7; +}; + +static std::atomic g_host{nullptr}; +static std::unordered_map g_endpoints; +static std::string g_active_endpoint; +static std::shared_mutex g_endpoints_mutex; + +// 按 provider 名称动态分配互斥锁;避免未知 provider 错误共享 OpenAI/Anthropic 专用锁 +// Dynamically allocate mutex per provider name; prevents unknown providers from incorrectly sharing the OpenAI/Anthropic-specific locks +static std::shared_mutex g_provider_mutexes_mutex; +static std::unordered_map> g_provider_mutexes; + +static std::mutex& provider_mutex(const std::string& provider) +{ + // 快速路径:读锁查找已有 mutex / Fast path: read lock to find existing mutex + { + std::shared_lock lock(g_provider_mutexes_mutex); + auto it = g_provider_mutexes.find(provider); + if (it != g_provider_mutexes.end()) return *it->second; + } + // 慢速路径:写锁创建新 mutex (双重检查) / Slow path: write lock to create new mutex (double-check) + std::unique_lock lock(g_provider_mutexes_mutex); + auto it = g_provider_mutexes.find(provider); + if (it != g_provider_mutexes.end()) return *it->second; + auto [new_it, _] = g_provider_mutexes.emplace(provider, std::make_unique()); + return *new_it->second; +} + +static std::string trim_copy(std::string s) +{ + auto is_space = [](unsigned char c) { return c == ' ' || c == '\t' || c == '\r' || c == '\n'; }; + while (!s.empty() && is_space(static_cast(s.front()))) s.erase(s.begin()); + while (!s.empty() && is_space(static_cast(s.back()))) s.pop_back(); + return s; +} + +static std::vector split_csv(const char* raw) +{ + std::vector out; + if (!raw || !*raw) return out; + std::stringstream ss(raw); + std::string item; + while (std::getline(ss, item, ',')) { + item = trim_copy(item); + if (!item.empty()) out.push_back(item); + } + return out; +} + +static int parse_int_or_default(const char* raw, int fallback) +{ + if (!raw || !*raw) return fallback; + char* end = nullptr; + long v = std::strtol(raw, &end, 10); + if (!end || *end != '\0' || v <= 0 || v > 1000000) return fallback; + return static_cast(v); +} + +static double parse_double_or_default(const char* raw, double fallback) +{ + if (!raw || !*raw) return fallback; + char* end = nullptr; + double v = std::strtod(raw, &end); + if (!end || *end != '\0' || v < 0.0 || v > 2.0) return fallback; + return v; +} + +static const char* cfg_get(const dstalk_host_api_t* host, const std::string& key) +{ + if (!host || !host->config_get) return nullptr; + return host->config_get(key.c_str()); +} + +static std::string cfg_get_copy(const dstalk_host_api_t* host, const std::string& key) +{ + const char* value = cfg_get(host, key); + return value ? std::string(value) : std::string(); +} + +static const char* default_base_url_for_provider(const std::string& provider) +{ + if (provider == "ai_anthropic") return "https://api.anthropic.com"; + if (provider == "ai_openai") return "https://api.openai.com/v1"; + return nullptr; // 未知 provider 必须显式配置 base_url / unknown provider must explicitly configure base_url +} + +static void clear_endpoints_locked() +{ + for (auto& kv : g_endpoints) { + dstalk_ai::secure_zero(kv.second.api_key.data(), kv.second.api_key.size()); + kv.second.api_key.clear(); + } + g_endpoints.clear(); + g_active_endpoint.clear(); +} + +static const dstalk_ai_service_t* lookup_ai_service(const EndpointConfig& ep) +{ + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + if (!host || !host->query_service || ep.provider.empty()) return nullptr; + return static_cast(host->query_service(ep.provider.c_str(), 1)); +} + +static dstalk_chat_result_t make_error(const char* msg, int status = 0) +{ + dstalk_chat_result_t r = {}; + r.ok = 0; + r.http_status = status; + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + r.error = (host && host->strdup) ? host->strdup(msg ? msg : "endpoint manager error") : nullptr; + return r; +} + +static bool load_endpoint(const dstalk_host_api_t* host, const std::string& name, EndpointConfig& out) +{ + std::string prefix = "endpoint." + name + "."; + std::string provider = cfg_get_copy(host, prefix + "provider"); + std::string base_url = cfg_get_copy(host, prefix + "base_url"); + std::string api_key = cfg_get_copy(host, prefix + "api_key"); + std::string model = cfg_get_copy(host, prefix + "model"); + if (provider.empty() || model.empty()) return false; + + out.name = name; + out.provider = provider; + + // 设定 base_url: 显式配置优先,其次用已知 provider 的默认值;未知 provider 必须显式配置 + // Determine base_url: explicit config first, then known provider default; unknown providers must configure explicitly + if (!base_url.empty()) { + out.base_url = base_url; + } else { + const char* default_url = default_base_url_for_provider(out.provider); + if (default_url) { + out.base_url = default_url; + } else { + return false; // 未知 provider 且未配置 base_url / unknown provider without explicit base_url + } + } + + out.api_key = api_key; + out.model = model; + out.max_tokens = parse_int_or_default(cfg_get(host, prefix + "max_tokens"), 4096); + out.temperature = parse_double_or_default(cfg_get(host, prefix + "temperature"), 0.7); + return host && host->query_service && host->query_service(out.provider.c_str(), 1) != nullptr; +} + +static int reload_endpoints_locked(const dstalk_host_api_t* host) +{ + clear_endpoints_locked(); + if (!host) return 0; + + std::vector names = split_csv(cfg_get(host, "endpoints.names")); + if (names.empty()) return 0; + + for (const std::string& name : names) { + if (g_endpoints.find(name) != g_endpoints.end()) { + if (host->log) host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping duplicate endpoint '%s'", name.c_str()); + continue; + } + EndpointConfig ep; + if (load_endpoint(host, name, ep)) { + if (g_active_endpoint.empty()) g_active_endpoint = name; + g_endpoints.emplace(name, std::move(ep)); + } else if (host->log) { + host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping invalid endpoint '%s'", name.c_str()); + } + } + + const char* active = cfg_get(host, "endpoints.active"); + if (active && g_endpoints.count(active)) { + g_active_endpoint = active; + } + return static_cast(g_endpoints.size()); +} + +static int mgr_count() +{ + std::shared_lock lock(g_endpoints_mutex); + return static_cast(g_endpoints.size()); +} + +static char* mgr_list_json() +{ + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + if (!host || !host->strdup) return nullptr; + json::array arr; + { + std::shared_lock lock(g_endpoints_mutex); + std::vector names; + names.reserve(g_endpoints.size()); + for (const auto& kv : g_endpoints) names.push_back(kv.first); + std::sort(names.begin(), names.end()); + for (const auto& name : names) { + const auto& ep = g_endpoints.at(name); + json::object o; + o["name"] = ep.name; + o["provider"] = ep.provider; + o["base_url"] = ep.base_url; + o["model"] = ep.model; + o["active"] = (ep.name == g_active_endpoint); + arr.emplace_back(std::move(o)); + } + } + return host->strdup(json::serialize(arr).c_str()); +} + +static const char* mgr_get_active() +{ + static thread_local std::string active; + std::shared_lock lock(g_endpoints_mutex); + active = g_active_endpoint; + return active.empty() ? nullptr : active.c_str(); +} + +static int mgr_set_active(const char* endpoint_name) +{ + if (!endpoint_name || !*endpoint_name) return -1; + std::unique_lock lock(g_endpoints_mutex); + auto it = g_endpoints.find(endpoint_name); + if (it == g_endpoints.end()) return -2; + g_active_endpoint = endpoint_name; + return 0; +} + +static EndpointConfig lookup_endpoint(const char* endpoint_name) +{ + std::shared_lock lock(g_endpoints_mutex); + std::string name; + if (endpoint_name && *endpoint_name) name = endpoint_name; + else name = g_active_endpoint; + auto it = g_endpoints.find(name); + if (it == g_endpoints.end()) return {}; + return it->second; +} + +static int mgr_set_model(const char* endpoint_name, const char* model) +{ + if (!model || !*model) return -1; + + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + std::string selected; + { + std::unique_lock lock(g_endpoints_mutex); + selected = (endpoint_name && *endpoint_name) ? endpoint_name : g_active_endpoint; + auto it = g_endpoints.find(selected); + if (it == g_endpoints.end()) return -2; + it->second.model = model; + } + + if (host && host->config_set) { + std::string key = "endpoint." + selected + ".model"; + host->config_set(key.c_str(), model); + } + return 0; +} + +static dstalk_chat_result_t mgr_chat(const char* endpoint_name, + const dstalk_message_t* history, + int history_len, + const char* user_input, + const char* tools_json) +{ + // 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0 + if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length"); + EndpointConfig ep = lookup_endpoint(endpoint_name); + const dstalk_ai_service_t* service = lookup_ai_service(ep); + if (!service) return make_error("endpoint not found"); + std::lock_guard guard(provider_mutex(ep.provider)); + int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(), + ep.model.c_str(), ep.max_tokens, ep.temperature); + if (rc != 0) return make_error("endpoint configure failed"); + return service->chat(history, history_len, user_input, tools_json); +} + +static dstalk_chat_result_t mgr_chat_stream(const char* endpoint_name, + const dstalk_message_t* history, + int history_len, + const char* user_input, + dstalk_stream_cb cb, + void* userdata) +{ + // 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0 + if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length"); + EndpointConfig ep = lookup_endpoint(endpoint_name); + const dstalk_ai_service_t* service = lookup_ai_service(ep); + if (!service) return make_error("endpoint not found"); + std::lock_guard guard(provider_mutex(ep.provider)); + int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(), + ep.model.c_str(), ep.max_tokens, ep.temperature); + if (rc != 0) return make_error("endpoint configure failed"); + return service->chat_stream(history, history_len, user_input, cb, userdata); +} + +static void mgr_free_result(dstalk_chat_result_t* result) +{ + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + dstalk_ai::free_chat_result(host, result); +} + +static dstalk_ai_endpoint_mgr_t g_service = { + &mgr_count, + &mgr_list_json, + &mgr_get_active, + &mgr_set_active, + &mgr_set_model, + &mgr_chat, + &mgr_chat_stream, + &mgr_free_result, +}; + +static int on_init(const dstalk_host_api_t* host) +{ + try { + if (!host) return -1; + g_host.store(host, std::memory_order_release); + { + std::unique_lock lock(g_endpoints_mutex); + reload_endpoints_locked(host); + } + if (host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] loaded %d endpoint(s)", mgr_count()); + return host->register_service("ai_endpoint_mgr", 1, &g_service); + } catch (const std::exception& e) { + if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init exception: %s", e.what()); + return -1; + } catch (...) { + if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init unknown exception"); + return -1; + } +} + +static void on_shutdown() +{ + try { + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + if (host && host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] shutdown"); + std::unique_lock lock(g_endpoints_mutex); + clear_endpoints_locked(); + g_host.store(nullptr, std::memory_order_release); + } catch (const std::exception& e) { + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown exception: %s", e.what()); + g_host.store(nullptr, std::memory_order_release); + } catch (...) { + const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); + if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown unknown exception"); + g_host.store(nullptr, std::memory_order_release); + } +} + +static dstalk_plugin_info_t g_info = { + /* .name = */ "ai_endpoint_mgr", + /* .version = */ "1.0.0", + /* .description = */ "AI endpoint manager for multiple named provider/model endpoints / 多命名 AI endpoint 管理器", + /* .api_version = */ DSTALK_API_VERSION, + /* .dependencies = */ { "openai_compat", "anthropic_ai", NULL }, + /* .on_init = */ on_init, + /* .on_shutdown = */ on_shutdown, + /* .on_event = */ nullptr, +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) +{ + return &g_info; +} diff --git a/plugins_upper/anthropic/src/anthropic_plugin.cpp b/plugins_upper/anthropic/src/anthropic_plugin.cpp index afbff5b..5c6edf5 100644 --- a/plugins_upper/anthropic/src/anthropic_plugin.cpp +++ b/plugins_upper/anthropic/src/anthropic_plugin.cpp @@ -65,6 +65,7 @@ static std::string build_request_json( std::string system_prompt; json::array msgs; + if (history) { // 防御: history 为空时跳过历史遍历 / Defensive: skip history iteration when null for (int i = 0; i < history_len; ++i) { const auto& m = history[i]; if (m.role && std::strcmp(m.role, "system") == 0) { @@ -77,6 +78,7 @@ static std::string build_request_json( obj["content"] = m.content ? m.content : ""; msgs.push_back(obj); } + } // if (history) // 追加当前用户输入 / Append current user input { diff --git a/plugins_upper/openai/src/openai_plugin.cpp b/plugins_upper/openai/src/openai_plugin.cpp index ce2e5ff..16556fe 100644 --- a/plugins_upper/openai/src/openai_plugin.cpp +++ b/plugins_upper/openai/src/openai_plugin.cpp @@ -49,6 +49,7 @@ static std::string build_headers_json(const std::string& auth_header_value) static void append_history(json::array& msgs, const dstalk_message_t* history, int history_len) { + if (!history) return; // 防御: history 为空时直接返回 / Defensive: return early when history is null for (int i = 0; i < history_len; ++i) { const auto& m = history[i]; json::object obj; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1a2388c..b3c9b26 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -145,6 +145,8 @@ target_compile_definitions(dstalk_anthropic_plugin_test BOOST_ALL_NO_LIB ) +find_package(Boost REQUIRED CONFIG) + target_link_libraries(dstalk_anthropic_plugin_test PRIVATE dstalk @@ -174,6 +176,8 @@ target_compile_definitions(dstalk_openai_plugin_test BOOST_ALL_NO_LIB ) +find_package(Boost REQUIRED CONFIG) + target_link_libraries(dstalk_openai_plugin_test PRIVATE dstalk @@ -183,6 +187,37 @@ target_link_libraries(dstalk_openai_plugin_test add_test(NAME dstalk_openai_plugin_test COMMAND dstalk_openai_plugin_test) +# ============================================================ +# dstalk_endpoint_mgr_plugin_test — AI endpoint manager 插件单元测试 +# 覆盖: endpoint 加载/列表脱敏/active/model 修改/路由 +# ============================================================ + +add_executable(dstalk_endpoint_mgr_plugin_test + endpoint_mgr_plugin_test.cpp +) + +target_include_directories(dstalk_endpoint_mgr_plugin_test + PRIVATE ${CMAKE_SOURCE_DIR}/dstalk_core/include + PRIVATE ${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include +) + +target_compile_definitions(dstalk_endpoint_mgr_plugin_test + PRIVATE + BOOST_JSON_HEADER_ONLY + BOOST_ALL_NO_LIB +) + +find_package(Boost REQUIRED CONFIG) + +target_link_libraries(dstalk_endpoint_mgr_plugin_test + PRIVATE + dstalk + ai_common + boost::boost +) + +add_test(NAME dstalk_endpoint_mgr_plugin_test COMMAND dstalk_endpoint_mgr_plugin_test) + # ============================================================ # dstalk_network_plugin_test — Network 插件单元测试 # W22.2 (qa-xu): 通过 #include source 访问 static 函数 @@ -203,6 +238,8 @@ target_compile_definitions(dstalk_network_plugin_test BOOST_ALL_NO_LIB ) +find_package(Boost REQUIRED CONFIG) + target_link_libraries(dstalk_network_plugin_test PRIVATE dstalk @@ -238,6 +275,8 @@ target_compile_features(dstalk_lsp_plugin_test PRIVATE cxx_std_17 ) +find_package(Boost REQUIRED CONFIG) + target_link_libraries(dstalk_lsp_plugin_test PRIVATE dstalk diff --git a/tests/endpoint_mgr_plugin_test.cpp b/tests/endpoint_mgr_plugin_test.cpp new file mode 100644 index 0000000..5b98a23 --- /dev/null +++ b/tests/endpoint_mgr_plugin_test.cpp @@ -0,0 +1,547 @@ +/* + * @file endpoint_mgr_plugin_test.cpp + * @brief AI endpoint manager plugin unit tests: endpoint loading, active/model mutation, routing, secret-safe listing, + * plus error-path coverage (null history, missing endpoint, configure failed, empty/bad active, concurrency). + * AI endpoint 管理器插件单元测试:endpoint 加载、active/model 修改、路由和脱敏列表、 + * 以及错误路径覆盖(空 history、缺失 endpoint、configure 失败、空/错 active、并发)。 + * Copyright (c) 2026 dstalk contributors. GPLv3. + */ +#include "../plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +static int g_failures = 0; +#define CHECK(cond, msg) do { \ + if (cond) { \ + std::cout << "[OK] " << (msg) << "\n"; \ + } else { \ + std::cerr << "[FAIL] " << (msg) << "\n"; \ + g_failures++; \ + } \ +} while (0) + +struct ConfigureRecord { + std::string provider; + std::string base_url; + std::string api_key; + std::string model; + int max_tokens = 0; + double temperature = 0.0; + int configure_calls = 0; + int chat_calls = 0; + int stream_calls = 0; +}; + +static std::unordered_map g_config_values; +static const dstalk_ai_endpoint_mgr_t* g_registered_mgr = nullptr; +static ConfigureRecord g_last_configure; +static int g_stream_cb_count = 0; + +static void* fake_alloc(size_t size) +{ + return std::malloc(size); +} + +static void fake_free(void* ptr) +{ + std::free(ptr); +} + +static char* fake_strdup(const char* s) +{ + if (!s) return nullptr; + size_t n = std::strlen(s) + 1; + char* p = static_cast(std::malloc(n)); + if (p) std::memcpy(p, s, n); + return p; +} + +static int fake_register_service(const char* name, int version, void* vtable) +{ + if (!name || !vtable || version != 1) return -1; + if (std::strcmp(name, "ai_endpoint_mgr") == 0) { + g_registered_mgr = static_cast(vtable); + return 0; + } + return -2; +} + +static void fake_unregister_service(const char*) +{ +} + +static const char* fake_config_get(const char* key) +{ + if (!key) return nullptr; + auto it = g_config_values.find(key); + if (it == g_config_values.end()) return nullptr; + static thread_local std::string tls_value; + tls_value = it->second; + return tls_value.c_str(); +} + +static int fake_config_set(const char* key, const char* value) +{ + if (!key || !value) return -1; + g_config_values[key] = value; + return 0; +} + +static void fake_log(int, const char*, ...) +{ +} + +static int fake_configure(const char* provider, const char* base_url, + const char* api_key, const char* model, + int max_tokens, double temperature) +{ + g_last_configure.provider = provider ? provider : ""; + g_last_configure.base_url = base_url ? base_url : ""; + g_last_configure.api_key = api_key ? api_key : ""; + g_last_configure.model = model ? model : ""; + g_last_configure.max_tokens = max_tokens; + g_last_configure.temperature = temperature; + g_last_configure.configure_calls++; + return 0; +} + +// 模拟 configure 失败的 provider service / Fake provider service whose configure always fails +static int fake_configure_fail(const char*, const char*, const char*, const char*, int, double) +{ + g_last_configure.configure_calls++; + return -1; +} + +static dstalk_chat_result_t fake_chat(const dstalk_message_t*, int, + const char*, const char*) +{ + g_last_configure.chat_calls++; + dstalk_chat_result_t r = {}; + r.ok = 1; + r.content = fake_strdup("ok"); + return r; +} + +static int test_stream_cb(const char*, void* userdata) +{ + int* count = static_cast(userdata); + if (count) (*count)++; + return 0; +} + +static dstalk_chat_result_t fake_chat_stream(const dstalk_message_t*, int, + const char*, dstalk_stream_cb cb, + void* userdata) +{ + g_last_configure.stream_calls++; + if (cb) cb("tok", userdata); + dstalk_chat_result_t r = {}; + r.ok = 1; + r.content = fake_strdup("stream-ok"); + return r; +} + +static void fake_free_result(dstalk_chat_result_t* result) +{ + if (!result) return; + if (result->content) fake_free((void*)result->content); + if (result->error) fake_free((void*)result->error); + if (result->tool_calls_json) fake_free((void*)result->tool_calls_json); + result->content = nullptr; + result->error = nullptr; + result->tool_calls_json = nullptr; +} + +static dstalk_ai_service_t g_fake_openai_service = { + &fake_configure, + &fake_chat, + &fake_chat_stream, + &fake_free_result, +}; + +static dstalk_ai_service_t g_fake_anthropic_service = { + &fake_configure, + &fake_chat, + &fake_chat_stream, + &fake_free_result, +}; + +// configure 总是失败的 provider 服务 / Provider service whose configure always fails +static dstalk_ai_service_t g_fake_failing_service = { + &fake_configure_fail, + &fake_chat, + &fake_chat_stream, + &fake_free_result, +}; + +static void* fake_query_service(const char* name, int min_version) +{ + if (!name || min_version > 1) return nullptr; + if (std::strcmp(name, "ai_openai") == 0) return &g_fake_openai_service; + if (std::strcmp(name, "ai_anthropic") == 0) return &g_fake_anthropic_service; + if (std::strcmp(name, "ai_failing") == 0) return &g_fake_failing_service; + return nullptr; +} + +static dstalk_host_api_t make_fake_host() +{ + dstalk_host_api_t host = {}; + host.register_service = fake_register_service; + host.query_service = fake_query_service; + host.unregister_service = fake_unregister_service; + host.config_get = fake_config_get; + host.config_set = fake_config_set; + host.log = fake_log; + host.alloc = fake_alloc; + host.free = fake_free; + host.strdup = fake_strdup; + return host; +} + +static void setup_endpoint_config() +{ + g_config_values.clear(); + g_config_values["endpoints.names"] = "openai_main, anthropic_alt, missing_provider, openai_main"; + g_config_values["endpoints.active"] = "anthropic_alt"; + + g_config_values["endpoint.openai_main.provider"] = "ai_openai"; + g_config_values["endpoint.openai_main.api_key"] = "sk-openai-test"; + g_config_values["endpoint.openai_main.model"] = "gpt-4o"; + g_config_values["endpoint.openai_main.max_tokens"] = "1234"; + g_config_values["endpoint.openai_main.temperature"] = "0.25"; + + g_config_values["endpoint.anthropic_alt.provider"] = "ai_anthropic"; + g_config_values["endpoint.anthropic_alt.api_key"] = "sk-ant-test"; + g_config_values["endpoint.anthropic_alt.model"] = "claude-sonnet-test"; + + g_config_values["endpoint.missing_provider.provider"] = "ai_missing"; + g_config_values["endpoint.missing_provider.model"] = "missing-model"; +} + +// 设置单 endpoint 配置用于错误路径测试 / Set up single-endpoint config for error-path testing +static void setup_single_endpoint(const char* name, const char* provider, const char* model, + const char* base_url = nullptr) +{ + g_config_values.clear(); + std::string names_key = "endpoints.names"; + g_config_values[names_key] = name; + std::string prefix = std::string("endpoint.") + name + "."; + g_config_values[prefix + "provider"] = provider; + g_config_values[prefix + "api_key"] = "sk-test"; + g_config_values[prefix + "model"] = model; + if (base_url && *base_url) { + g_config_values[prefix + "base_url"] = base_url; + } +} + +static void reset_test_state() +{ + on_shutdown(); + g_registered_mgr = nullptr; + g_last_configure = {}; + g_config_values.clear(); +} + +int main() +{ + // ================================================================ + // 主测试流程 / Main test flow (existing) + // ================================================================ + on_shutdown(); + setup_endpoint_config(); + g_registered_mgr = nullptr; + g_last_configure = {}; + + dstalk_host_api_t host = make_fake_host(); + int init_rc = on_init(&host); + CHECK(init_rc == 0, "on_init registers endpoint manager service"); + CHECK(g_registered_mgr != nullptr, "registered service pointer captured"); + + const dstalk_ai_endpoint_mgr_t* mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 2, "loads valid endpoints, skips missing provider and duplicate"); + + const char* active = mgr ? mgr->get_active() : nullptr; + CHECK(active && std::strcmp(active, "anthropic_alt") == 0, + "configured active endpoint is selected"); + + char* list = mgr ? mgr->list_json() : nullptr; + std::string list_json = list ? list : ""; + if (list) fake_free(list); + CHECK(list_json.find("openai_main") != std::string::npos, + "list_json includes OpenAI endpoint"); + CHECK(list_json.find("anthropic_alt") != std::string::npos, + "list_json includes Anthropic endpoint"); + CHECK(list_json.find("https://api.anthropic.com") != std::string::npos, + "Anthropic endpoint uses Anthropic default base_url"); + CHECK(list_json.find("sk-openai-test") == std::string::npos && + list_json.find("sk-ant-test") == std::string::npos, + "list_json does not expose API keys"); + + CHECK(mgr->set_active("missing") == -2, "set_active rejects unknown endpoint"); + CHECK(mgr->set_active("openai_main") == 0, "set_active accepts known endpoint"); + CHECK(std::strcmp(mgr->get_active(), "openai_main") == 0, + "get_active reflects set_active change"); + + CHECK(mgr->set_model(nullptr, "gpt-4.1-mini") == 0, + "set_model(nullptr, model) updates active endpoint"); + CHECK(g_config_values["endpoint.openai_main.model"] == "gpt-4.1-mini", + "set_model mirrors model to host config store"); + CHECK(mgr->set_model("missing", "model") == -2, + "set_model rejects unknown endpoint"); + CHECK(mgr->set_model("openai_main", "") == -1, + "set_model rejects empty model"); + + dstalk_message_t msg = {"user", "hello", nullptr, nullptr}; + dstalk_chat_result_t r = mgr->chat(nullptr, &msg, 1, "hi", "[]"); + CHECK(r.ok == 1 && r.content && std::strcmp(r.content, "ok") == 0, + "chat routes to active endpoint service"); + CHECK(g_last_configure.provider == "ai_openai" && + g_last_configure.base_url == "https://api.openai.com/v1" && + g_last_configure.model == "gpt-4.1-mini" && + g_last_configure.max_tokens == 1234 && + g_last_configure.temperature == 0.25, + "chat configures OpenAI endpoint before routing"); + mgr->free_result(&r); + + CHECK(mgr->set_active("anthropic_alt") == 0, "switch active endpoint to Anthropic"); + r = mgr->chat(nullptr, &msg, 1, "hi", nullptr); + CHECK(r.ok == 1 && g_last_configure.provider == "ai_anthropic" && + g_last_configure.base_url == "https://api.anthropic.com" && + g_last_configure.model == "claude-sonnet-test", + "chat configures Anthropic endpoint before routing"); + mgr->free_result(&r); + + g_stream_cb_count = 0; + r = mgr->chat_stream("anthropic_alt", &msg, 1, "hi", test_stream_cb, &g_stream_cb_count); + CHECK(r.ok == 1 && g_stream_cb_count == 1, + "chat_stream routes callback through selected endpoint"); + mgr->free_result(&r); + + on_shutdown(); + CHECK(mgr->count() == 0, "on_shutdown clears endpoint cache"); + + // ================================================================ + // 错误路径测试 / Error-path tests (P3) + // ================================================================ + + // --- E1: null history with non-zero length / 空指针 history --- + { + reset_test_state(); + setup_single_endpoint("test_ep", "ai_openai", "gpt-test"); + host = make_fake_host(); + CHECK(on_init(&host) == 0, "E1: init with single endpoint"); + mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 1, "E1: one endpoint loaded"); + + // null history + non-zero len -> 应返回错误 / should return error + r = mgr->chat("test_ep", nullptr, 3, "hi", nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "null history with non-zero length") == 0, + "E1: null history with len>0 returns error for chat"); + mgr->free_result(&r); + + r = mgr->chat_stream("test_ep", nullptr, 2, "hi", nullptr, nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "null history with non-zero length") == 0, + "E1: null history with len>0 returns error for chat_stream"); + mgr->free_result(&r); + + // null history + zero len -> 应正常 (无历史) / should pass (empty history) + r = mgr->chat("test_ep", nullptr, 0, "hi", nullptr); + CHECK(r.ok == 1, "E1: null history with len=0 passes for chat"); + mgr->free_result(&r); + + on_shutdown(); + } + + // --- E2: missing endpoint chat / 缺失 endpoint --- + { + reset_test_state(); + setup_single_endpoint("test_ep", "ai_openai", "gpt-test"); + host = make_fake_host(); + CHECK(on_init(&host) == 0, "E2: init"); + mgr = g_registered_mgr; + + r = mgr->chat("nonexistent_ep", &msg, 1, "hi", nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "endpoint not found") == 0, + "E2: chat with nonexistent endpoint returns error"); + mgr->free_result(&r); + + r = mgr->chat_stream("nonexistent_ep", &msg, 1, "hi", nullptr, nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "endpoint not found") == 0, + "E2: chat_stream with nonexistent endpoint returns error"); + mgr->free_result(&r); + + on_shutdown(); + } + + // --- E3: configure failed / configure 失败 --- + { + reset_test_state(); + // ai_failing provider 的 configure 总是返回 -1 / ai_failing provider's configure always returns -1 + setup_single_endpoint("fail_ep", "ai_failing", "fail-model", "https://fail.example.com/api"); + host = make_fake_host(); + int rc = on_init(&host); + CHECK(rc == 0, "E3: init with failing endpoint"); + mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 1, "E3: failing endpoint loaded (service exists)"); + + r = mgr->chat("fail_ep", &msg, 1, "hi", nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "endpoint configure failed") == 0, + "E3: chat returns error when configure fails"); + mgr->free_result(&r); + + r = mgr->chat_stream("fail_ep", &msg, 1, "hi", nullptr, nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "endpoint configure failed") == 0, + "E3: chat_stream returns error when configure fails"); + mgr->free_result(&r); + + on_shutdown(); + } + + // --- E4: empty endpoints (no endpoints.names config key) / 空 endpoint 列表 --- + { + reset_test_state(); + // 不设置任何 endpoint 配置 / No endpoint config at all + g_config_values.clear(); + host = make_fake_host(); + CHECK(on_init(&host) == 0, "E4: init with no endpoint config"); + mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 0, "E4: zero endpoints loaded"); + CHECK(mgr->get_active() == nullptr, "E4: get_active returns nullptr when no endpoints"); + + // 在没有 endpoint 的情况下,chat 应报错 / chat should error when no endpoints + r = mgr->chat(nullptr, &msg, 1, "hi", nullptr); + CHECK(r.ok == 0 && r.error != nullptr && + std::strcmp(r.error, "endpoint not found") == 0, + "E4: chat with no endpoints returns error"); + mgr->free_result(&r); + + on_shutdown(); + } + + // --- E5: bad active (active set to nonexistent endpoint) / 无效 active --- + { + reset_test_state(); + g_config_values["endpoints.names"] = "ep1"; + g_config_values["endpoints.active"] = "does_not_exist"; // 不存在的 active / nonexistent active + g_config_values["endpoint.ep1.provider"] = "ai_openai"; + g_config_values["endpoint.ep1.api_key"] = "sk-test"; + g_config_values["endpoint.ep1.model"] = "gpt-test"; + host = make_fake_host(); + CHECK(on_init(&host) == 0, "E5: init with bad active config"); + mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 1, "E5: endpoint loaded despite bad active"); + + // 当 active 指向不存在的 endpoint 时,get_active 应返回第一个加载有效的 endpoint + // When active points to nonexistent endpoint, get_active should return the first valid loaded endpoint + active = mgr->get_active(); + CHECK(active != nullptr && std::strcmp(active, "ep1") == 0, + "E5: get_active falls back to first loaded endpoint when configured active is invalid"); + + on_shutdown(); + } + + // --- E6: set_active with null/empty name / set_active 空名称 --- + { + reset_test_state(); + setup_single_endpoint("test_ep", "ai_openai", "gpt-test"); + host = make_fake_host(); + CHECK(on_init(&host) == 0, "E6: init"); + mgr = g_registered_mgr; + + CHECK(mgr->set_active(nullptr) == -1, "E6: set_active(nullptr) returns -1"); + CHECK(mgr->set_active("") == -1, "E6: set_active(\"\") returns -1"); + + on_shutdown(); + } + + // ================================================================ + // secure_zero 基础测试 / Basic secure_zero test + // ================================================================ + { + // 分配缓冲区并填充可识别模式 / Allocate buffer and fill with recognizable pattern + char buf[64]; + std::memset(buf, 0xAB, sizeof(buf)); + dstalk_ai::secure_zero(buf, sizeof(buf)); + bool all_zero = true; + for (int i = 0; i < (int)sizeof(buf); ++i) { + if (buf[i] != 0) { all_zero = false; break; } + } + CHECK(all_zero, "secure_zero wipes all bytes to zero"); + + // 空 size / zero size — 不崩溃 / should not crash + dstalk_ai::secure_zero(buf, 0); + CHECK(true, "secure_zero with size=0 does not crash"); + + // nullptr + zero size — 不崩溃 / should not crash + dstalk_ai::secure_zero(nullptr, 0); + CHECK(true, "secure_zero(nullptr, 0) does not crash"); + } + + // ================================================================ + // 轻量并发读写测试 / Lightweight concurrent read/write test + // ================================================================ + { + reset_test_state(); + setup_single_endpoint("conc_ep", "ai_openai", "gpt-concurrent"); + host = make_fake_host(); + CHECK(on_init(&host) == 0, "concurrency setup: init"); + mgr = g_registered_mgr; + CHECK(mgr && mgr->count() == 1, "concurrency setup: endpoint loaded"); + + const int kReaders = 4; + const int kIters = 100; + std::vector threads; + std::atomic errors{0}; + + // 读者线程: 反复调用 get_active / list_json / count / set_active + // Reader threads: repeatedly call get_active / list_json / count / set_active + for (int t = 0; t < kReaders; ++t) { + threads.emplace_back([mgr, &errors, t]() { + for (int i = 0; i < kIters; ++i) { + // 读操作 / read operations + const char* a = mgr->get_active(); + if (a && std::strcmp(a, "conc_ep") != 0) errors++; + int c = mgr->count(); + if (c != 1) errors++; + char* l = mgr->list_json(); + if (l) fake_free(l); + else errors++; + + // 轻量写操作: set_active 到同一个 endpoint / lightweight write + if (i % 10 == 0) { + int rc = mgr->set_active("conc_ep"); + if (rc != 0) errors++; + } + } + }); + } + + for (auto& th : threads) th.join(); + CHECK(errors.load() == 0, "concurrent read/write: no errors across threads"); + + on_shutdown(); + } + + // ================================================================ + // 总结 / Summary + // ================================================================ + if (g_failures == 0) { + std::cout << "\nendpoint_mgr_plugin_test: all checks passed\n"; + } else { + std::cerr << "\nendpoint_mgr_plugin_test: " << g_failures << " failure(s)\n"; + } + return g_failures == 0 ? 0 : 1; +} diff --git a/tests/smoke_test.cpp b/tests/smoke_test.cpp index dd7b129..c14553e 100644 --- a/tests/smoke_test.cpp +++ b/tests/smoke_test.cpp @@ -55,6 +55,15 @@ int main() << "provider = \"openai\"\n" << "base_url = \"https://api.openai.com/v1\"\n" << "api_key = \"test-key\"\n" + << "model = \"gpt-4o\"\n" + << "\n" + << "# minimal endpoint config for ai_endpoint_mgr / 最小 endpoint 配置供 ai_endpoint_mgr 加载\n" + << "[endpoints]\n" + << "names = \"gpt4o\"\n" + << "\n" + << "[endpoint.gpt4o]\n" + << "provider = \"ai_openai\"\n" + << "api_key = \"test-key\"\n" << "model = \"gpt-4o\"\n"; } @@ -217,6 +226,52 @@ int main() // 测试 dstalk_log / Test dstalk_log dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully"); + // ======================================================================== + // 测试服务查询: ai_endpoint_mgr / Test service query: ai_endpoint_mgr + // 验证 endpoint 加载 / 列表脱敏 / count / get_active / Verify endpoint load / list sanitization / count / get_active + // 不调用真实 chat 网络 / No real chat network calls + // ======================================================================== + { + auto* ep_mgr = static_cast( + dstalk_service_query("ai_endpoint_mgr", 1)); + if (ep_mgr) { + std::cout << "[OK] ai_endpoint_mgr service found\n"; + + // 验证 endpoint 数量 >= 1 / Verify endpoint count >= 1 + int n = ep_mgr->count(); + if (n >= 1) { + std::cout << "[OK] ai_endpoint_mgr count = " << n << "\n"; + } else { + std::cerr << "[FAIL] ai_endpoint_mgr count = " << n << " (expected >= 1)\n"; + } + + // 验证 list_json 非空且不泄露 api_key / Verify list_json non-null and no api_key leak + char* list = ep_mgr->list_json(); + if (list) { + std::string list_str(list); + if (list_str.find("test-key") == std::string::npos) { + std::cout << "[OK] ai_endpoint_mgr list_json sanitized (no api_key leak): " + << list << "\n"; + } else { + std::cerr << "[FAIL] ai_endpoint_mgr list_json leaks api_key: " << list << "\n"; + } + dstalk_free(list); + } else { + std::cerr << "[FAIL] ai_endpoint_mgr list_json returned null\n"; + } + + // 验证 get_active 非空 / Verify get_active non-null + const char* active = ep_mgr->get_active(); + if (active) { + std::cout << "[OK] ai_endpoint_mgr get_active = " << active << "\n"; + } else { + std::cerr << "[FAIL] ai_endpoint_mgr get_active returned null\n"; + } + } else { + std::cerr << "[WARN] ai_endpoint_mgr service not found\n"; + } + } + // ======================================================================== // 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性 // Extended test block C2: null-safety / escape boundaries / tools chain / session robustness @@ -751,6 +806,15 @@ int main() << "provider = \"openai\"\n" << "base_url = \"https://api.openai.com/v1\"\n" << "api_key = \"test-key\"\n" + << "model = \"gpt-4o\"\n" + << "\n" + << "# minimal endpoint config for ai_endpoint_mgr / 最小 endpoint 配置供 ai_endpoint_mgr 加载\n" + << "[endpoints]\n" + << "names = \"gpt4o\"\n" + << "\n" + << "[endpoint.gpt4o]\n" + << "provider = \"ai_openai\"\n" + << "api_key = \"test-key\"\n" << "model = \"gpt-4o\"\n"; }