feat: add AI endpoint manager plugin with configuration and routing capabilities
Some checks failed
Some checks failed
- Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints. - Added configuration reference documentation for `config.toml`. - Implemented endpoint loading, active endpoint switching, and model mutation. - Included error handling for missing endpoints and configuration failures. - Developed unit tests covering various scenarios including error paths and concurrency.
This commit is contained in:
90
config.example.toml
Normal file
90
config.example.toml
Normal file
@@ -0,0 +1,90 @@
|
||||
# ============================================================================
|
||||
# dstalk config.example.toml — 配置模板 / Configuration template
|
||||
# ============================================================================
|
||||
# 复制为 config.toml 并修改即可使用。
|
||||
# Copy to config.toml and edit before use.
|
||||
#
|
||||
# 两种配置方式(互斥,选择一种即可):
|
||||
# Two configuration modes (mutually exclusive, pick one):
|
||||
#
|
||||
# 1) 单 Provider 模式(简单,一个 AI 后端)
|
||||
# Single-provider mode (simple, one AI backend)
|
||||
# 使用 keys: ai.provider / api.base_url / api.api_key / api.model
|
||||
#
|
||||
# 2) 多 Endpoint 模式(高级,同时配置多个 AI 后端并通过 /endpoint 切换)
|
||||
# Multi-endpoint mode (advanced, multiple AI backends switchable via /endpoint)
|
||||
# 使用 keys: endpoints.names / endpoints.active / endpoint.<name>.*
|
||||
#
|
||||
# 如果同时配置了两种方式,ai_endpoint_mgr 的 chat 调用优先走 endpoints.*。
|
||||
# If both are configured, ai_endpoint_mgr chat calls prefer endpoints.*.
|
||||
# ============================================================================
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 方式 1: 单 Provider 模式 / Mode 1: Single Provider
|
||||
# ----------------------------------------------------------------------------
|
||||
# 取消下方注释即可使用 / Uncomment to use
|
||||
|
||||
# ai.provider = "ai_openai" # ai_openai 或 ai_anthropic / ai_openai or ai_anthropic
|
||||
#
|
||||
# api.base_url = "https://api.openai.com/v1"
|
||||
# api.api_key = "sk-your-key-here"
|
||||
# api.model = "gpt-4o"
|
||||
|
||||
# 或者用 Anthropic / Or use Anthropic:
|
||||
# ai.provider = "ai_anthropic"
|
||||
# api.base_url = "https://api.anthropic.com"
|
||||
# api.api_key = "sk-ant-your-key-here"
|
||||
# api.model = "claude-sonnet-4-20250514"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 方式 2: 多 Endpoint 模式 / Mode 2: Multi-Endpoint
|
||||
# ----------------------------------------------------------------------------
|
||||
# 取消下方注释即可使用,多个 endpoint 在对话中通过 /endpoint 命令切换。
|
||||
# Uncomment to use. Switch endpoints during a session via the /endpoint command.
|
||||
|
||||
# 逗号分隔的 endpoint 名称列表(至少一个) / Comma-separated endpoint names (at least one)
|
||||
# endpoints.names = "openai_main, anthropic_alt"
|
||||
|
||||
# 默认激活的 endpoint(可选,不设置则取列表第一个) / Default active endpoint (optional, defaults to first in list)
|
||||
# endpoints.active = "openai_main"
|
||||
|
||||
# --- openai_main ---
|
||||
# provider 必须 / required
|
||||
# endpoint.openai_main.provider = "ai_openai"
|
||||
|
||||
# base_url 可选 / optional (默认按 provider 自动填 / defaults by provider: OpenAI→api.openai.com/v1, Anthropic→api.anthropic.com)
|
||||
# endpoint.openai_main.base_url = "https://api.openai.com/v1"
|
||||
|
||||
# api_key 必须 / required
|
||||
# endpoint.openai_main.api_key = "sk-your-key-here"
|
||||
|
||||
# model 必须 / required
|
||||
# endpoint.openai_main.model = "gpt-4o"
|
||||
|
||||
# max_tokens 可选 / optional (默认 4096) / default 4096
|
||||
# endpoint.openai_main.max_tokens = 4096
|
||||
|
||||
# temperature 可选 / optional (默认 0.7, 范围 0.0~2.0) / default 0.7, range 0.0~2.0
|
||||
# endpoint.openai_main.temperature = 0.7
|
||||
|
||||
|
||||
# --- anthropic_alt ---
|
||||
# endpoint.anthropic_alt.provider = "ai_anthropic"
|
||||
# base_url 未设置时自动使用 https://api.anthropic.com / defaults to https://api.anthropic.com when unset
|
||||
# endpoint.anthropic_alt.api_key = "sk-ant-your-key-here"
|
||||
# endpoint.anthropic_alt.model = "claude-sonnet-4-20250514"
|
||||
# endpoint.anthropic_alt.max_tokens = 8192
|
||||
|
||||
|
||||
# --- deepseek (自定义 base_url 示例 / custom base_url example) ---
|
||||
# endpoint.deepseek.provider = "ai_openai"
|
||||
# endpoint.deepseek.base_url = "https://api.deepseek.com/v1"
|
||||
# endpoint.deepseek.api_key = "sk-deepseek-your-key-here"
|
||||
# endpoint.deepseek.model = "deepseek-chat"
|
||||
|
||||
# 提示 / Tips:
|
||||
# - list_json 输出不含 api_key(安全脱敏)。
|
||||
# list_json output excludes api_key (security de-identification).
|
||||
# - 未知 provider 必须显式配置 base_url,否则 endpoint 加载失败。
|
||||
# Unknown providers must explicitly set base_url or the endpoint won't load.
|
||||
@@ -24,6 +24,7 @@
|
||||
| 文档 | 说明 |
|
||||
|------|------|
|
||||
| [CLI 命令速查](reference/commands.md) | 全部 CLI 命令的别名、作用与示例 |
|
||||
| [配置参考](reference/config.md) | config.toml 单 provider 与多 endpoint 全部字段说明、默认值与安全提示 |
|
||||
| [Plugin ABI 契约](reference/plugin-abi.md) | 跨 DLL 通信的 C ABI 规范:内存所有权、堆纪律、回调线程安全 |
|
||||
|
||||
---
|
||||
@@ -53,7 +54,7 @@
|
||||
### 参考
|
||||
|
||||
- [ ] **API 参考** (`reference/api.md`) — TODO: dstalk_host.h 完整 API 说明与调用示例
|
||||
- [ ] **配置参考** (`reference/config.md`) — TODO: config.toml 所有字段的详细说明
|
||||
- [x] **配置参考** (`reference/config.md`) — config.toml 所有字段的详细说明、默认值与安全提示
|
||||
- [ ] **服务接口参考** (`reference/services.md`) — TODO: dstalk_services.h 中所有 vtable 接口定义
|
||||
|
||||
---
|
||||
@@ -63,6 +64,7 @@
|
||||
- 命令以 `$ ` 前缀表示, 在终端中运行
|
||||
- 代码块标注语言 (```c, ```toml, ```bash 等)
|
||||
- [ ] 表示计划中未完成的文档
|
||||
- 配置文件模板见项目根目录 `config.example.toml`
|
||||
|
||||
---
|
||||
|
||||
|
||||
156
docs/reference/config.md
Normal file
156
docs/reference/config.md
Normal file
@@ -0,0 +1,156 @@
|
||||
# 配置参考 / Configuration Reference
|
||||
|
||||
`config.toml` 是 dstalk 的唯一配置文件,放在项目根目录。本文档列出所有支持字段、类型、默认值与使用说明。
|
||||
|
||||
---
|
||||
|
||||
## 配置概览
|
||||
|
||||
配置分为两种模式:
|
||||
|
||||
| 模式 | 适用场景 | 核心 key |
|
||||
|------|----------|----------|
|
||||
| **单 Provider** | 只需一个 AI 后端 | `ai.provider`, `api.*` |
|
||||
| **多 Endpoint** | 同时配置多个 AI 后端,运行时切换 | `endpoints.names`, `endpoint.<name>.*` |
|
||||
|
||||
两种模式可以同时配置,但 `ai_endpoint_mgr` 的 `chat` / `chat_stream` 调用优先走 `endpoints.*`。
|
||||
|
||||
---
|
||||
|
||||
## 一、单 Provider 模式 (Legacy)
|
||||
|
||||
直接在 `[global]` 下声明。这些 key 无默认值——不配置则无法启动 AI 对话。
|
||||
|
||||
### ai.provider
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **值**: `"ai_openai"` 或 `"ai_anthropic"`
|
||||
- **说明**: 指定要加载的 AI provider 插件。该名称对应插件 DLL 注册的服务名。
|
||||
|
||||
### api.base_url
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **说明**: AI API 的基础 URL。例如:
|
||||
- OpenAI: `https://api.openai.com/v1`
|
||||
- Anthropic: `https://api.anthropic.com`
|
||||
- DeepSeek (OpenAI 兼容): `https://api.deepseek.com/v1`
|
||||
|
||||
### api.api_key
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **说明**: API 密钥。注意保管,不要提交到版本控制。
|
||||
|
||||
### api.model
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **说明**: 要使用的模型名称。例如 `gpt-4o`、`claude-sonnet-4-20250514`、`deepseek-chat`。
|
||||
|
||||
---
|
||||
|
||||
## 二、多 Endpoint 模式 (推荐)
|
||||
|
||||
由 `ai_endpoint_mgr` 插件管理。每个 endpoint 是一个命名的 AI 后端配置,运行时可通过 `/endpoint` 命令切换。
|
||||
|
||||
### endpoints.names
|
||||
|
||||
- **类型**: string (逗号分隔)
|
||||
- **默认**: 无
|
||||
- **说明**: 所有要启用的 endpoint 名称列表。例如 `"openai_main, anthropic_alt"`。重复名称会被跳过并输出警告。
|
||||
|
||||
### endpoints.active
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 取 `endpoints.names` 中第一个成功加载的 endpoint
|
||||
- **说明**: 启动时默认使用的 endpoint 名称。如果指定的名称不在 `endpoints.names` 中或对应配置无效,则回退到第一个成功加载的 endpoint。
|
||||
|
||||
---
|
||||
|
||||
### endpoint.\<name\>.provider
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **值**: `"ai_openai"` 或 `"ai_anthropic"`
|
||||
- **说明**: 该 endpoint 使用的 AI provider 服务名称。
|
||||
|
||||
### endpoint.\<name\>.base_url
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 按 provider 自动推导
|
||||
- **说明**:
|
||||
- 未配置时,已知 provider 自动使用默认 base_url:
|
||||
- `ai_anthropic` → `https://api.anthropic.com`
|
||||
- `ai_openai` → `https://api.openai.com/v1`
|
||||
- 如果 provider 不在已知列表中,**必须**显式配置 `base_url`,否则该 endpoint 加载失败。
|
||||
- 显式配置的值优先于默认值,可用于自定义端点(如代理、DeepSeek 兼容 API 等)。
|
||||
|
||||
### endpoint.\<name\>.api_key
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **说明**: 该 endpoint 的 API 密钥。
|
||||
|
||||
> **安全提示**: `api_key` **不会**出现在 `list_json()` 的输出中,也不进入日志。这是有意为之的安全策略。
|
||||
|
||||
### endpoint.\<name\>.model
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: 无(必填)
|
||||
- **说明**: 该 endpoint 默认使用的模型名称。运行时可通过 `set_model()` 修改并同步回 `config.toml`。
|
||||
|
||||
### endpoint.\<name\>.max_tokens
|
||||
|
||||
- **类型**: integer
|
||||
- **默认**: `4096`
|
||||
- **范围**: `1` ~ `1000000`(超出范围的值视为无效,回退到默认值)
|
||||
- **说明**: 每次请求的最大输出 token 数。
|
||||
|
||||
### endpoint.\<name\>.temperature
|
||||
|
||||
- **类型**: double
|
||||
- **默认**: `0.7`
|
||||
- **范围**: `0.0` ~ `2.0`(超出范围的值视为无效,回退到默认值)
|
||||
- **说明**: 生成温度。越高随机性越强,越低越确定性。
|
||||
|
||||
---
|
||||
|
||||
## 三、ui 配置 (可选 / optional)
|
||||
|
||||
以下 key 目前为 CLI 前端预留,后续 GUI/Web 前端也会使用。
|
||||
|
||||
### ui.prompt
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: `"> "`
|
||||
- **说明**: 命令行提示符字符串。
|
||||
|
||||
### ui.multiline
|
||||
|
||||
- **类型**: string
|
||||
- **默认**: `"/"` (即不支持多行输入)
|
||||
- **说明**: 多行输入结束符。发送该字符串结束多行输入模式。
|
||||
|
||||
---
|
||||
|
||||
## 四、完整配置示例
|
||||
|
||||
见项目根目录 `config.example.toml`。
|
||||
|
||||
---
|
||||
|
||||
## 五、运行时 endpoint 操作
|
||||
|
||||
`ai_endpoint_mgr` 服务提供以下接口(通过 C ABI)供前端调用:
|
||||
|
||||
| 操作 | 说明 |
|
||||
|------|------|
|
||||
| `count()` | 返回已加载的 endpoint 数量 |
|
||||
| `list_json()` | 返回 JSON 数组,每项含 `name`, `provider`, `base_url`, `model`, `active` (不含 `api_key`) |
|
||||
| `get_active()` | 返回当前激活的 endpoint 名称 |
|
||||
| `set_active(name)` | 切换激活 endpoint;返回 0 成功,-2 表示不存在 |
|
||||
| `set_model(name, model)` | 修改 endpoint 的模型并同步到 config;name 为 null 时作用于当前 active endpoint |
|
||||
| `chat(ep, history, len, input, tools)` | 路由对话到指定 endpoint(ep 为 null 时用 active) |
|
||||
| `chat_stream(ep, history, len, input, cb, userdata)` | 流式路由对话 |
|
||||
@@ -96,6 +96,8 @@ api.model = "gpt-4o"
|
||||
> **关键**: 修改 `ai.provider` 字段即可在不同后端间切换, 无需改动代码。
|
||||
>
|
||||
> API Key 可从 [OpenAI-compatible 开放平台](https://platform.openai.com/) 或 [Anthropic Console](https://console.anthropic.com/) 获取。
|
||||
>
|
||||
> **多 Endpoint 配置** (同时使用多个 AI 后端并在运行时通过 `/endpoint` 切换): 参见 [配置参考](../reference/config.md) 和项目根目录 `config.example.toml`。
|
||||
|
||||
---
|
||||
|
||||
@@ -172,5 +174,6 @@ dstalk v0.1.0 | dstalk AI | /help 查看帮助 | /quit 退出
|
||||
## 下一步
|
||||
|
||||
- 查看 [CLI 命令速查表](../reference/commands.md) 了解全部命令
|
||||
- 查看 [配置参考](../reference/config.md) 了解 config.toml 所有字段
|
||||
- 输入 `/help` 在 dstalk 内查看命令列表
|
||||
- 输入 `/status` 查看当前运行状态
|
||||
|
||||
@@ -56,6 +56,7 @@ static const dstalk_ai_service_t* g_ai = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
static const dstalk_file_io_service_t* g_file_io = nullptr;
|
||||
static const dstalk_tools_service_t* g_tools = nullptr;
|
||||
static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional
|
||||
|
||||
// ---- 运行时状态 / Runtime state ----
|
||||
// g_current_model tracks the active model name for display in the prompt.
|
||||
@@ -134,6 +135,61 @@ static void spinner_join()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- AI 调用路由(endpoint_mgr 优先,g_ai fallback)/ AI call routing (endpoint_mgr preferred, g_ai fallback) ----
|
||||
// 当 endpoint_mgr 可用且至少有一个已配置 endpoint 时,通过 endpoint_mgr 路由调用;
|
||||
// 否则回退到直接使用 g_ai 服务(保持旧配置兼容)。
|
||||
// When endpoint_mgr is available with >=1 configured endpoints, route through it;
|
||||
// otherwise fall back to direct g_ai service (keeping old config compatible).
|
||||
|
||||
// 是否有可用的 endpoint_mgr / Whether endpoint_mgr is usable
|
||||
static inline bool has_endpoint_mgr()
|
||||
{
|
||||
return g_endpoint_mgr != nullptr && g_endpoint_mgr->count() > 0;
|
||||
}
|
||||
|
||||
// 是否有任一 AI 后端 / Whether any AI backend is usable
|
||||
static inline bool has_ai_backend()
|
||||
{
|
||||
return has_endpoint_mgr() || g_ai != nullptr;
|
||||
}
|
||||
|
||||
// 阻塞 chat 路由 / Blocking chat routing
|
||||
static dstalk_chat_result_t do_chat(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input, const char* tools_json)
|
||||
{
|
||||
if (has_endpoint_mgr())
|
||||
return g_endpoint_mgr->chat(nullptr, history, history_len, user_input, tools_json);
|
||||
return g_ai->chat(history, history_len, user_input, tools_json);
|
||||
}
|
||||
|
||||
// 流式 chat 路由 / Streaming chat routing
|
||||
static dstalk_chat_result_t do_chat_stream(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input, dstalk_stream_cb cb, void* userdata)
|
||||
{
|
||||
if (has_endpoint_mgr())
|
||||
return g_endpoint_mgr->chat_stream(nullptr, history, history_len, user_input, cb, userdata);
|
||||
return g_ai->chat_stream(history, history_len, user_input, cb, userdata);
|
||||
}
|
||||
|
||||
// 释放 chat result(使用对应服务) / Free chat result (use corresponding service)
|
||||
static void do_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
if (has_endpoint_mgr())
|
||||
g_endpoint_mgr->free_result(result);
|
||||
else
|
||||
g_ai->free_result(result);
|
||||
}
|
||||
|
||||
// 设置模型(endpoint_mgr 优先) / Set model (endpoint_mgr preferred)
|
||||
static int do_set_model(const char* model)
|
||||
{
|
||||
if (has_endpoint_mgr())
|
||||
return g_endpoint_mgr->set_model(nullptr, model);
|
||||
return g_ai->configure(nullptr, nullptr, nullptr, model, 0, 0.0);
|
||||
}
|
||||
|
||||
// ---- 错误分类与友好提示 / Error classification and user-friendly messages ----
|
||||
// 根据 HTTP 状态码和错误消息字符串匹配,将常见错误归类为认证/频率限制/网络/配额问题,并给出中文建议。
|
||||
// Classifies common errors into auth/rate-limit/network/quota categories based on HTTP status and string matching, with Chinese suggestions.
|
||||
@@ -382,6 +438,21 @@ static void handle_command(const char* line)
|
||||
const dstalk_tools_service_t* tools = static_cast<const dstalk_tools_service_t*>(
|
||||
dstalk_service_query("tools", 1));
|
||||
std::printf(" Tools 服务: %s\n", tools ? "就绪" : "不可用");
|
||||
|
||||
// I08/I09: endpoint manager 状态 / endpoint manager status
|
||||
if (g_endpoint_mgr) {
|
||||
std::printf(" --- Endpoint Manager ---\n");
|
||||
std::printf(" 状态: 就绪 (%d endpoint(s))\n", g_endpoint_mgr->count());
|
||||
const char* active = g_endpoint_mgr->get_active();
|
||||
std::printf(" Active Endpoint: %s\n", active ? active : "(无)");
|
||||
char* list_json = g_endpoint_mgr->list_json();
|
||||
if (list_json) {
|
||||
std::printf(" Endpoints: %s\n", list_json); // JSON 不含 api_key,已脱敏 / no api_key in JSON, already desensitized
|
||||
dstalk_free(list_json);
|
||||
}
|
||||
} else {
|
||||
std::printf(" Endpoint Manager: 不可用\n");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -393,7 +464,15 @@ static void handle_command(const char* line)
|
||||
std::printf(CLR_RED "[ERROR] /model 需要模型名\n" CLR_RESET);
|
||||
return;
|
||||
}
|
||||
if (g_ai) {
|
||||
// I08: 优先通过 endpoint_mgr 设置模型,fallback 到 g_ai->configure / prefer endpoint_mgr, fallback to g_ai
|
||||
if (g_endpoint_mgr && g_endpoint_mgr->count() > 0) {
|
||||
if (g_endpoint_mgr->set_model(nullptr, model) == 0) {
|
||||
g_current_model = model;
|
||||
std::printf(CLR_GREEN "[OK] 模型已切换: %s (via endpoint_mgr)\n" CLR_RESET, model);
|
||||
} else {
|
||||
std::printf(CLR_RED "[ERROR] 模型切换失败(endpoint 不存在或未配置)\n" CLR_RESET);
|
||||
}
|
||||
} else if (g_ai) {
|
||||
g_ai->configure(nullptr, nullptr, nullptr, model, 0, 0.0);
|
||||
g_current_model = model;
|
||||
std::printf(CLR_GREEN "[OK] 模型已切换: %s\n" CLR_RESET, model);
|
||||
@@ -645,6 +724,9 @@ int main(int argc, char* argv[])
|
||||
g_session = static_cast<const dstalk_session_service_t*>(dstalk_service_query("session", 1));
|
||||
g_file_io = static_cast<const dstalk_file_io_service_t*>(dstalk_service_query("file_io", 1));
|
||||
g_tools = static_cast<const dstalk_tools_service_t*>(dstalk_service_query("tools", 1));
|
||||
// I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service)
|
||||
g_endpoint_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(
|
||||
dstalk_service_query("ai_endpoint_mgr", 1));
|
||||
|
||||
if (!g_ai) {
|
||||
std::fprintf(stderr, CLR_RED "[dstalk] AI 服务未找到(请检查插件目录)\n" CLR_RESET);
|
||||
@@ -663,6 +745,12 @@ int main(int argc, char* argv[])
|
||||
g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7);
|
||||
g_current_model = model; // A1: 记录当前模型名 / Record current model name
|
||||
}
|
||||
// I08: 记录 endpoint_mgr 可用性 / log endpoint_mgr availability
|
||||
if (g_endpoint_mgr && g_endpoint_mgr->count() > 0) {
|
||||
const char* active = g_endpoint_mgr->get_active();
|
||||
std::fprintf(stderr, "[dstalk] endpoint_mgr: %d endpoint(s), active=%s\n",
|
||||
g_endpoint_mgr->count(), active ? active : "(none)");
|
||||
}
|
||||
|
||||
if (!batch_mode) {
|
||||
std::printf("\n");
|
||||
@@ -678,22 +766,23 @@ int main(int argc, char* argv[])
|
||||
dstalk_shutdown();
|
||||
return EXIT_FATAL;
|
||||
}
|
||||
if (!g_ai || !g_session) {
|
||||
if (!has_ai_backend() || !g_session) {
|
||||
std::fprintf(stderr, CLR_RED "[ERROR] AI or session service unavailable\n" CLR_RESET);
|
||||
dstalk_shutdown();
|
||||
return EXIT_CONFIG;
|
||||
}
|
||||
int history_count = 0;
|
||||
const dstalk_message_t* history = g_session->history(&history_count);
|
||||
dstalk_chat_result_t result = g_ai->chat(history, history_count, input.c_str(), nullptr);
|
||||
// I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai
|
||||
dstalk_chat_result_t result = do_chat(history, history_count, input.c_str(), nullptr);
|
||||
if (result.ok) {
|
||||
std::printf("%s\n", result.content ? result.content : "");
|
||||
g_ai->free_result(&result);
|
||||
do_free_result(&result);
|
||||
dstalk_shutdown();
|
||||
return EXIT_OK;
|
||||
} else {
|
||||
print_error(result.error, result.http_status);
|
||||
g_ai->free_result(&result);
|
||||
do_free_result(&result);
|
||||
dstalk_shutdown();
|
||||
return EXIT_FATAL;
|
||||
}
|
||||
@@ -718,22 +807,23 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
prompt_text = prompt_arg;
|
||||
}
|
||||
if (!g_ai || !g_session) {
|
||||
if (!has_ai_backend() || !g_session) {
|
||||
std::fprintf(stderr, CLR_RED "[ERROR] AI or session service unavailable\n" CLR_RESET);
|
||||
dstalk_shutdown();
|
||||
return EXIT_CONFIG;
|
||||
}
|
||||
int history_count = 0;
|
||||
const dstalk_message_t* history = g_session->history(&history_count);
|
||||
dstalk_chat_result_t result = g_ai->chat(history, history_count, prompt_text.c_str(), nullptr);
|
||||
// I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai
|
||||
dstalk_chat_result_t result = do_chat(history, history_count, prompt_text.c_str(), nullptr);
|
||||
if (result.ok) {
|
||||
std::printf("%s\n", result.content ? result.content : "");
|
||||
g_ai->free_result(&result);
|
||||
do_free_result(&result);
|
||||
dstalk_shutdown();
|
||||
return EXIT_OK;
|
||||
} else {
|
||||
print_error(result.error, result.http_status);
|
||||
g_ai->free_result(&result);
|
||||
do_free_result(&result);
|
||||
dstalk_shutdown();
|
||||
return EXIT_FATAL;
|
||||
}
|
||||
@@ -770,7 +860,7 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// AI 对话(通过插件服务 vtable) / AI chat (via plugin service vtable)
|
||||
if (!g_ai || !g_session) {
|
||||
if (!has_ai_backend() || !g_session) {
|
||||
std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET);
|
||||
continue;
|
||||
}
|
||||
@@ -782,7 +872,8 @@ int main(int argc, char* argv[])
|
||||
// 启动 spinner,等待 AI 响应 / Start spinner while waiting for AI response
|
||||
spinner_start();
|
||||
bool first = true;
|
||||
dstalk_chat_result_t result = g_ai->chat_stream(
|
||||
// I08: 通过 endpoint_mgr 路由(优先),或 fallback 到 g_ai / route via endpoint_mgr (preferred), or fallback to g_ai
|
||||
dstalk_chat_result_t result = do_chat_stream(
|
||||
history, history_count, line.c_str(), on_stream_token, &first);
|
||||
|
||||
// 确保 spinner 已停止(处理无流式输出的情况) / Ensure spinner is stopped (handles no-stream-output case)
|
||||
@@ -866,10 +957,12 @@ int main(int argc, char* argv[])
|
||||
history_count = 0;
|
||||
history = g_session->history(&history_count);
|
||||
|
||||
g_ai->free_result(&result);
|
||||
// I08: 通过 endpoint_mgr 路由 free_result / route free_result via endpoint_mgr
|
||||
do_free_result(&result);
|
||||
spinner_start();
|
||||
bool tool_stream_first = true;
|
||||
result = g_ai->chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first);
|
||||
// I08: 通过 endpoint_mgr 路由 chat_stream / route chat_stream via endpoint_mgr
|
||||
result = do_chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first);
|
||||
spinner_stop();
|
||||
|
||||
if (result.ok) {
|
||||
@@ -896,7 +989,7 @@ int main(int argc, char* argv[])
|
||||
std::printf(CLR_RESET "\n");
|
||||
print_error(result.error, result.http_status);
|
||||
}
|
||||
g_ai->free_result(&result);
|
||||
do_free_result(&result);
|
||||
}
|
||||
|
||||
// B2: 单一退出点,dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode)
|
||||
|
||||
@@ -35,6 +35,42 @@ typedef struct {
|
||||
void (*free_result)(dstalk_chat_result_t* result);
|
||||
} dstalk_ai_service_t;
|
||||
|
||||
/* ---- AI endpoint manager 服务 vtable / AI endpoint manager service vtable ---- */
|
||||
/* 以服务名称 "ai_endpoint_mgr" 注册;用于按名称路由到多个 provider/model endpoint。 / Registered as "ai_endpoint_mgr"; routes named endpoints to provider/model configs. */
|
||||
typedef struct {
|
||||
/* 返回已配置 endpoint 数量 / Return configured endpoint count. */
|
||||
int (*count)(void);
|
||||
/* 返回 endpoint 列表 JSON,调用方用 dstalk_free 释放 / Return endpoint list JSON; caller frees with dstalk_free. */
|
||||
char* (*list_json)(void);
|
||||
/* 获取当前 active endpoint 名称 / Get current active endpoint name.
|
||||
返回的指针指向 thread_local 存储区,调用方不需要释放。该指针在同线程下一次
|
||||
get_active 调用前或 active endpoint 状态变化前有效;跨线程并发调用各自拥有
|
||||
独立的 thread_local 副本,互不干扰。
|
||||
The returned pointer points to thread-local storage; caller must not free it.
|
||||
Valid until the next get_active call on the same thread or until the active
|
||||
endpoint changes. Concurrent calls from different threads each have their own
|
||||
independent thread-local copy. */
|
||||
const char* (*get_active)(void);
|
||||
/* 设置当前 active endpoint / Set current active endpoint. */
|
||||
int (*set_active)(const char* endpoint_name);
|
||||
/* 设置 endpoint 模型名;endpoint_name 为空时修改 active endpoint / Set endpoint model; null endpoint_name updates active endpoint. */
|
||||
int (*set_model)(const char* endpoint_name, const char* model);
|
||||
/* 在指定 endpoint 上执行阻塞 chat / Run blocking chat on a named endpoint. */
|
||||
dstalk_chat_result_t (*chat)(
|
||||
const char* endpoint_name,
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
const char* tools_json);
|
||||
/* 在指定 endpoint 上执行流式 chat / Run streaming chat on a named endpoint. */
|
||||
dstalk_chat_result_t (*chat_stream)(
|
||||
const char* endpoint_name,
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
dstalk_stream_cb cb, void* userdata);
|
||||
/* 释放 endpoint manager 返回的 chat result / Free chat result returned by endpoint manager. */
|
||||
void (*free_result)(dstalk_chat_result_t* result);
|
||||
} dstalk_ai_endpoint_mgr_t;
|
||||
|
||||
/* ---- 会话服务 vtable / Session service vtable ---- */
|
||||
/* 以服务名称 "session" 注册 / Registered under service name "session" */
|
||||
typedef struct {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// 提供所有前端(CLI / GUI / Web)共享的启动逻辑:
|
||||
// - 配置文件发现(argv / 默认路径 / 平台 fopen)
|
||||
// - dstalk_init() 调用
|
||||
// - 常用服务查询(ai, session, file_io, tools, context)
|
||||
// - 常用服务查询(ai / endpoint_mgr / session / file_io / tools / context)
|
||||
// - AI 服务默认配置(从 config 读取,带 fallback)
|
||||
// ============================================================================
|
||||
|
||||
@@ -20,6 +20,7 @@ struct FrontendServices {
|
||||
const dstalk_session_service_t* session = nullptr;
|
||||
const dstalk_file_io_service_t* file_io = nullptr;
|
||||
const dstalk_tools_service_t* tools = nullptr;
|
||||
const dstalk_ai_endpoint_mgr_t* endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional
|
||||
|
||||
std::string provider; // "ai.deepseek" / "ai.openai" / "ai.anthropic"
|
||||
std::string model; // e.g. "deepseek-v4-pro"
|
||||
@@ -35,8 +36,8 @@ struct FrontendServices {
|
||||
// 功能:
|
||||
// 1. 发现配置文件:优先 argv[1](跳过已知标志),其次 default_config(如 "config.toml")
|
||||
// 2. 调用 dstalk_init(config_path)
|
||||
// 3. 查询常用插件服务(ai / session / file_io / tools)
|
||||
// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置默认值
|
||||
// 3. 查询常用插件服务(ai / endpoint_mgr / session / file_io / tools)
|
||||
// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置旧单 provider 默认值
|
||||
//
|
||||
// 参数:
|
||||
// svc - [out] 填入查询到的服务指针和配置信息
|
||||
|
||||
@@ -79,12 +79,19 @@ int dstalk_frontend_init(FrontendServices& svc,
|
||||
svc.tools = static_cast<const dstalk_tools_service_t*>(
|
||||
dstalk_service_query("tools", 1));
|
||||
|
||||
// I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service)
|
||||
svc.endpoint_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(
|
||||
dstalk_service_query("ai_endpoint_mgr", 1));
|
||||
|
||||
const dstalk_context_service_t* ctx_svc =
|
||||
static_cast<const dstalk_context_service_t*>(
|
||||
dstalk_service_query("context", 1));
|
||||
(void)ctx_svc; // 不强制使用,保留以备将来使用
|
||||
|
||||
if (!svc.ai) {
|
||||
// endpoint_mgr 可作为 AI 后端;缺少旧 ai provider 时仍允许多 endpoint 配置工作。
|
||||
// endpoint_mgr can serve as the AI backend; allow multi-endpoint config even when legacy ai provider is absent.
|
||||
const bool endpoint_mgr_ready = svc.endpoint_mgr && svc.endpoint_mgr->count() > 0;
|
||||
if (!svc.ai && !endpoint_mgr_ready) {
|
||||
std::fprintf(stderr, "[dstalk] AI 服务未找到(请检查插件目录)\n");
|
||||
return 2;
|
||||
}
|
||||
@@ -93,7 +100,8 @@ int dstalk_frontend_init(FrontendServices& svc,
|
||||
return 3;
|
||||
}
|
||||
|
||||
// (3) 配置 AI 服务的默认值
|
||||
// (3) 配置 AI 服务的默认值;endpoint_mgr 路径由 endpoint.<name>.* 配置驱动。
|
||||
// Configure legacy AI defaults; endpoint_mgr is driven by endpoint.<name>.* config.
|
||||
const char* base_url = dstalk_config_get("api.base_url");
|
||||
const char* api_key = dstalk_config_get("api.api_key");
|
||||
const char* model = dstalk_config_get("api.model");
|
||||
@@ -105,9 +113,11 @@ int dstalk_frontend_init(FrontendServices& svc,
|
||||
svc.api_key = api_key ? api_key : "";
|
||||
svc.model = model;
|
||||
|
||||
svc.ai->configure(provider, base_url,
|
||||
api_key ? api_key : "",
|
||||
model, 4096, 0.7);
|
||||
if (svc.ai) {
|
||||
svc.ai->configure(provider, base_url,
|
||||
api_key ? api_key : "",
|
||||
model, 4096, 0.7);
|
||||
}
|
||||
|
||||
svc.initialized = true;
|
||||
return 0;
|
||||
|
||||
@@ -27,6 +27,7 @@
|
||||
// 在启动时从主机查询获取的服务 vtable 全局指针。
|
||||
static const dstalk_ai_service_t* g_ai_svc = nullptr;
|
||||
static const dstalk_session_service_t* g_session_svc = nullptr;
|
||||
static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional
|
||||
|
||||
// ---- 常量 / Constants ----
|
||||
|
||||
@@ -287,10 +288,19 @@ static void renderStatusBar(AppContext& ctx) {
|
||||
}
|
||||
|
||||
// 状态文本:模型名 | 消息条数 | 流式状态 / Status text: model name | message count | streaming state
|
||||
// I08: 添加 endpoint manager 信息(如果可用)/ add endpoint manager info (if available)
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf), "%s | %d messages | %s",
|
||||
gs.model_name.c_str(), msgCount,
|
||||
gs.streaming ? "streaming" : "ready");
|
||||
const char* active_ep = nullptr;
|
||||
if (g_endpoint_mgr) active_ep = g_endpoint_mgr->get_active();
|
||||
if (active_ep && active_ep[0]) {
|
||||
snprintf(buf, sizeof(buf), "%s | %d messages | %s | ep:%s",
|
||||
gs.model_name.c_str(), msgCount,
|
||||
gs.streaming ? "streaming" : "ready", active_ep);
|
||||
} else {
|
||||
snprintf(buf, sizeof(buf), "%s | %d messages | %s",
|
||||
gs.model_name.c_str(), msgCount,
|
||||
gs.streaming ? "streaming" : "ready");
|
||||
}
|
||||
drawText(r, static_cast<float>(PADDING),
|
||||
barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE);
|
||||
}
|
||||
@@ -832,6 +842,9 @@ int main(int argc, char* argv[]) {
|
||||
if (!ai_provider) ai_provider = "ai_openai";
|
||||
g_ai_svc = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
|
||||
g_session_svc = static_cast<const dstalk_session_service_t*>(dstalk_service_query("session", 1));
|
||||
// I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service)
|
||||
g_endpoint_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(
|
||||
dstalk_service_query("ai_endpoint_mgr", 1));
|
||||
if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)");
|
||||
if (!g_session_svc) dstalk_log(3, "Session service not found");
|
||||
|
||||
@@ -899,15 +912,25 @@ int main(int argc, char* argv[]) {
|
||||
std::string& userMsg =
|
||||
ctx.state.messages[ctx.state.messages.size() - 2].content;
|
||||
int rc = -1;
|
||||
if (g_ai_svc) {
|
||||
// I08: 优先通过 endpoint_mgr 路由,fallback 到 g_ai_svc / prefer endpoint_mgr, fallback to g_ai_svc
|
||||
const bool use_mgr = (g_endpoint_mgr && g_endpoint_mgr->count() > 0);
|
||||
if (use_mgr || g_ai_svc) {
|
||||
int hcount = 0;
|
||||
const dstalk_message_t* history = g_session_svc
|
||||
? g_session_svc->history(&hcount) : nullptr;
|
||||
dstalk_chat_result_t result = g_ai_svc->chat_stream(
|
||||
history, hcount, userMsg.c_str(),
|
||||
streamTokenCallback, &ctx);
|
||||
dstalk_chat_result_t result;
|
||||
if (use_mgr) {
|
||||
result = g_endpoint_mgr->chat_stream(
|
||||
nullptr, history, hcount, userMsg.c_str(),
|
||||
streamTokenCallback, &ctx);
|
||||
} else {
|
||||
result = g_ai_svc->chat_stream(
|
||||
history, hcount, userMsg.c_str(),
|
||||
streamTokenCallback, &ctx);
|
||||
}
|
||||
rc = result.ok ? 0 : -1;
|
||||
g_ai_svc->free_result(&result);
|
||||
if (use_mgr) g_endpoint_mgr->free_result(&result);
|
||||
else g_ai_svc->free_result(&result);
|
||||
}
|
||||
|
||||
// 流式传输完成(或被取消)
|
||||
|
||||
@@ -43,6 +43,7 @@ class SseSession;
|
||||
// 插件服务 vtable 的全局指针,在启动时从主机查询获取。
|
||||
static const dstalk_ai_service_t* g_ai = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
static const dstalk_ai_endpoint_mgr_t* g_endpoint_mgr = nullptr; // I08: AI endpoint manager(可选)/ optional
|
||||
|
||||
// ---- 运行时状态 / Runtime state ----
|
||||
// g_quit signals the main loop to exit (set by Ctrl+C).
|
||||
@@ -208,8 +209,19 @@ static void run_chat_worker(
|
||||
};
|
||||
|
||||
// 调用流式 AI 聊天 / Call streaming AI chat
|
||||
dstalk_chat_result_t result = g_ai->chat_stream(
|
||||
history, history_count, nullptr, token_cb, &cb_data);
|
||||
// I08: 优先通过 endpoint_mgr 路由,fallback 到 g_ai / prefer endpoint_mgr, fallback to g_ai
|
||||
dstalk_chat_result_t result = {};
|
||||
const bool use_mgr = (g_endpoint_mgr && g_endpoint_mgr->count() > 0);
|
||||
if (use_mgr) {
|
||||
result = g_endpoint_mgr->chat_stream(
|
||||
nullptr, history, history_count, nullptr, token_cb, &cb_data);
|
||||
} else if (g_ai) {
|
||||
result = g_ai->chat_stream(
|
||||
history, history_count, nullptr, token_cb, &cb_data);
|
||||
} else {
|
||||
result.ok = 0;
|
||||
result.error = dstalk_strdup("AI service unavailable");
|
||||
}
|
||||
|
||||
// 将 AI 回复加入会话 / Add AI reply to session
|
||||
if (result.ok) {
|
||||
@@ -221,7 +233,11 @@ static void run_chat_worker(
|
||||
bool ok = result.ok;
|
||||
std::string content_copy = result.content ? result.content : "";
|
||||
std::string error_copy = result.error ? result.error : "";
|
||||
g_ai->free_result(&result);
|
||||
|
||||
// I08: 根据路由来源释放 result / free result based on routing source
|
||||
if (use_mgr) g_endpoint_mgr->free_result(&result);
|
||||
else if (g_ai) g_ai->free_result(&result);
|
||||
else if (result.error) dstalk_free((void*)result.error);
|
||||
|
||||
asio::post(ioc, [weak_sse, ok, content_copy, error_copy]() {
|
||||
if (auto sse = weak_sse.lock()) {
|
||||
@@ -373,6 +389,23 @@ private:
|
||||
if (model) st["model"] = std::string(model);
|
||||
st["status"] = "running";
|
||||
|
||||
// I08/I09: endpoint manager 状态 / endpoint manager status
|
||||
if (g_endpoint_mgr) {
|
||||
st["endpoint_mgr_available"] = true;
|
||||
st["endpoint_count"] = g_endpoint_mgr->count();
|
||||
const char* active = g_endpoint_mgr->get_active();
|
||||
if (active) st["active_endpoint"] = std::string(active);
|
||||
char* list_json = g_endpoint_mgr->list_json();
|
||||
if (list_json) {
|
||||
boost::system::error_code ec;
|
||||
auto jv = boost::json::parse(list_json, ec);
|
||||
if (!ec) st["endpoints"] = std::move(jv);
|
||||
dstalk_free(list_json);
|
||||
}
|
||||
} else {
|
||||
st["endpoint_mgr_available"] = false;
|
||||
}
|
||||
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::ok, request_.version()};
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
@@ -510,6 +543,9 @@ int main(int argc, char* argv[])
|
||||
if (!ai_provider) ai_provider = "ai_openai";
|
||||
g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
|
||||
g_session = static_cast<const dstalk_session_service_t*>(dstalk_service_query("session", 1));
|
||||
// I08: 查询 AI endpoint manager(可选服务)/ query AI endpoint manager (optional service)
|
||||
g_endpoint_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(
|
||||
dstalk_service_query("ai_endpoint_mgr", 1));
|
||||
|
||||
if (!g_ai) {
|
||||
std::fprintf(stderr, "[dstalk_web] AI service not found (check plugins directory)\n");
|
||||
|
||||
@@ -6,3 +6,4 @@ add_subdirectory(ai_common) # 共享 AI 工具库(静态库)/ shared AI u
|
||||
add_subdirectory(context) # 依赖 session / depends on session
|
||||
add_subdirectory(openai) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
add_subdirectory(anthropic) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
add_subdirectory(ai_endpoint_mgr) # 路由多个 AI endpoint / routes multiple AI endpoints
|
||||
|
||||
32
plugins_upper/ai_endpoint_mgr/CMakeLists.txt
Normal file
32
plugins_upper/ai_endpoint_mgr/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
# ============================================================
|
||||
# AI endpoint manager plugin / AI endpoint manager 插件
|
||||
# ============================================================
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
add_library(plugin_ai_endpoint_mgr SHARED
|
||||
src/endpoint_mgr_plugin.cpp
|
||||
)
|
||||
|
||||
target_include_directories(plugin_ai_endpoint_mgr
|
||||
PRIVATE
|
||||
${CMAKE_SOURCE_DIR}/dstalk_core/include
|
||||
${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include
|
||||
)
|
||||
|
||||
target_link_libraries(plugin_ai_endpoint_mgr
|
||||
PRIVATE
|
||||
dstalk
|
||||
ai_common
|
||||
dstalk_boost_config
|
||||
boost::boost
|
||||
)
|
||||
|
||||
# cxx_std_20 已由 dstalk 和 ai_common (PUBLIC) 传播,无需重复声明
|
||||
# cxx_std_20 is already propagated by dstalk and ai_common (PUBLIC); no need to redeclare
|
||||
|
||||
set_target_properties(plugin_ai_endpoint_mgr PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
)
|
||||
400
plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp
Normal file
400
plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp
Normal file
@@ -0,0 +1,400 @@
|
||||
/*
|
||||
* @file endpoint_mgr_plugin.cpp
|
||||
* @brief AI endpoint manager: routes named endpoint configs to AI provider services.
|
||||
* AI endpoint 管理器:把命名 endpoint 配置路由到具体 AI provider 服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "ai_common.hpp"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <exception>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
struct EndpointConfig {
|
||||
std::string name;
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 4096;
|
||||
double temperature = 0.7;
|
||||
};
|
||||
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::unordered_map<std::string, EndpointConfig> g_endpoints;
|
||||
static std::string g_active_endpoint;
|
||||
static std::shared_mutex g_endpoints_mutex;
|
||||
|
||||
// 按 provider 名称动态分配互斥锁;避免未知 provider 错误共享 OpenAI/Anthropic 专用锁
|
||||
// Dynamically allocate mutex per provider name; prevents unknown providers from incorrectly sharing the OpenAI/Anthropic-specific locks
|
||||
static std::shared_mutex g_provider_mutexes_mutex;
|
||||
static std::unordered_map<std::string, std::unique_ptr<std::mutex>> g_provider_mutexes;
|
||||
|
||||
static std::mutex& provider_mutex(const std::string& provider)
|
||||
{
|
||||
// 快速路径:读锁查找已有 mutex / Fast path: read lock to find existing mutex
|
||||
{
|
||||
std::shared_lock lock(g_provider_mutexes_mutex);
|
||||
auto it = g_provider_mutexes.find(provider);
|
||||
if (it != g_provider_mutexes.end()) return *it->second;
|
||||
}
|
||||
// 慢速路径:写锁创建新 mutex (双重检查) / Slow path: write lock to create new mutex (double-check)
|
||||
std::unique_lock lock(g_provider_mutexes_mutex);
|
||||
auto it = g_provider_mutexes.find(provider);
|
||||
if (it != g_provider_mutexes.end()) return *it->second;
|
||||
auto [new_it, _] = g_provider_mutexes.emplace(provider, std::make_unique<std::mutex>());
|
||||
return *new_it->second;
|
||||
}
|
||||
|
||||
static std::string trim_copy(std::string s)
|
||||
{
|
||||
auto is_space = [](unsigned char c) { return c == ' ' || c == '\t' || c == '\r' || c == '\n'; };
|
||||
while (!s.empty() && is_space(static_cast<unsigned char>(s.front()))) s.erase(s.begin());
|
||||
while (!s.empty() && is_space(static_cast<unsigned char>(s.back()))) s.pop_back();
|
||||
return s;
|
||||
}
|
||||
|
||||
static std::vector<std::string> split_csv(const char* raw)
|
||||
{
|
||||
std::vector<std::string> out;
|
||||
if (!raw || !*raw) return out;
|
||||
std::stringstream ss(raw);
|
||||
std::string item;
|
||||
while (std::getline(ss, item, ',')) {
|
||||
item = trim_copy(item);
|
||||
if (!item.empty()) out.push_back(item);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static int parse_int_or_default(const char* raw, int fallback)
|
||||
{
|
||||
if (!raw || !*raw) return fallback;
|
||||
char* end = nullptr;
|
||||
long v = std::strtol(raw, &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > 1000000) return fallback;
|
||||
return static_cast<int>(v);
|
||||
}
|
||||
|
||||
static double parse_double_or_default(const char* raw, double fallback)
|
||||
{
|
||||
if (!raw || !*raw) return fallback;
|
||||
char* end = nullptr;
|
||||
double v = std::strtod(raw, &end);
|
||||
if (!end || *end != '\0' || v < 0.0 || v > 2.0) return fallback;
|
||||
return v;
|
||||
}
|
||||
|
||||
static const char* cfg_get(const dstalk_host_api_t* host, const std::string& key)
|
||||
{
|
||||
if (!host || !host->config_get) return nullptr;
|
||||
return host->config_get(key.c_str());
|
||||
}
|
||||
|
||||
static std::string cfg_get_copy(const dstalk_host_api_t* host, const std::string& key)
|
||||
{
|
||||
const char* value = cfg_get(host, key);
|
||||
return value ? std::string(value) : std::string();
|
||||
}
|
||||
|
||||
static const char* default_base_url_for_provider(const std::string& provider)
|
||||
{
|
||||
if (provider == "ai_anthropic") return "https://api.anthropic.com";
|
||||
if (provider == "ai_openai") return "https://api.openai.com/v1";
|
||||
return nullptr; // 未知 provider 必须显式配置 base_url / unknown provider must explicitly configure base_url
|
||||
}
|
||||
|
||||
static void clear_endpoints_locked()
|
||||
{
|
||||
for (auto& kv : g_endpoints) {
|
||||
dstalk_ai::secure_zero(kv.second.api_key.data(), kv.second.api_key.size());
|
||||
kv.second.api_key.clear();
|
||||
}
|
||||
g_endpoints.clear();
|
||||
g_active_endpoint.clear();
|
||||
}
|
||||
|
||||
static const dstalk_ai_service_t* lookup_ai_service(const EndpointConfig& ep)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!host || !host->query_service || ep.provider.empty()) return nullptr;
|
||||
return static_cast<const dstalk_ai_service_t*>(host->query_service(ep.provider.c_str(), 1));
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t make_error(const char* msg, int status = 0)
|
||||
{
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.http_status = status;
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
r.error = (host && host->strdup) ? host->strdup(msg ? msg : "endpoint manager error") : nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
static bool load_endpoint(const dstalk_host_api_t* host, const std::string& name, EndpointConfig& out)
|
||||
{
|
||||
std::string prefix = "endpoint." + name + ".";
|
||||
std::string provider = cfg_get_copy(host, prefix + "provider");
|
||||
std::string base_url = cfg_get_copy(host, prefix + "base_url");
|
||||
std::string api_key = cfg_get_copy(host, prefix + "api_key");
|
||||
std::string model = cfg_get_copy(host, prefix + "model");
|
||||
if (provider.empty() || model.empty()) return false;
|
||||
|
||||
out.name = name;
|
||||
out.provider = provider;
|
||||
|
||||
// 设定 base_url: 显式配置优先,其次用已知 provider 的默认值;未知 provider 必须显式配置
|
||||
// Determine base_url: explicit config first, then known provider default; unknown providers must configure explicitly
|
||||
if (!base_url.empty()) {
|
||||
out.base_url = base_url;
|
||||
} else {
|
||||
const char* default_url = default_base_url_for_provider(out.provider);
|
||||
if (default_url) {
|
||||
out.base_url = default_url;
|
||||
} else {
|
||||
return false; // 未知 provider 且未配置 base_url / unknown provider without explicit base_url
|
||||
}
|
||||
}
|
||||
|
||||
out.api_key = api_key;
|
||||
out.model = model;
|
||||
out.max_tokens = parse_int_or_default(cfg_get(host, prefix + "max_tokens"), 4096);
|
||||
out.temperature = parse_double_or_default(cfg_get(host, prefix + "temperature"), 0.7);
|
||||
return host && host->query_service && host->query_service(out.provider.c_str(), 1) != nullptr;
|
||||
}
|
||||
|
||||
static int reload_endpoints_locked(const dstalk_host_api_t* host)
|
||||
{
|
||||
clear_endpoints_locked();
|
||||
if (!host) return 0;
|
||||
|
||||
std::vector<std::string> names = split_csv(cfg_get(host, "endpoints.names"));
|
||||
if (names.empty()) return 0;
|
||||
|
||||
for (const std::string& name : names) {
|
||||
if (g_endpoints.find(name) != g_endpoints.end()) {
|
||||
if (host->log) host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping duplicate endpoint '%s'", name.c_str());
|
||||
continue;
|
||||
}
|
||||
EndpointConfig ep;
|
||||
if (load_endpoint(host, name, ep)) {
|
||||
if (g_active_endpoint.empty()) g_active_endpoint = name;
|
||||
g_endpoints.emplace(name, std::move(ep));
|
||||
} else if (host->log) {
|
||||
host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping invalid endpoint '%s'", name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const char* active = cfg_get(host, "endpoints.active");
|
||||
if (active && g_endpoints.count(active)) {
|
||||
g_active_endpoint = active;
|
||||
}
|
||||
return static_cast<int>(g_endpoints.size());
|
||||
}
|
||||
|
||||
static int mgr_count()
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
return static_cast<int>(g_endpoints.size());
|
||||
}
|
||||
|
||||
static char* mgr_list_json()
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!host || !host->strdup) return nullptr;
|
||||
json::array arr;
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
std::vector<std::string> names;
|
||||
names.reserve(g_endpoints.size());
|
||||
for (const auto& kv : g_endpoints) names.push_back(kv.first);
|
||||
std::sort(names.begin(), names.end());
|
||||
for (const auto& name : names) {
|
||||
const auto& ep = g_endpoints.at(name);
|
||||
json::object o;
|
||||
o["name"] = ep.name;
|
||||
o["provider"] = ep.provider;
|
||||
o["base_url"] = ep.base_url;
|
||||
o["model"] = ep.model;
|
||||
o["active"] = (ep.name == g_active_endpoint);
|
||||
arr.emplace_back(std::move(o));
|
||||
}
|
||||
}
|
||||
return host->strdup(json::serialize(arr).c_str());
|
||||
}
|
||||
|
||||
static const char* mgr_get_active()
|
||||
{
|
||||
static thread_local std::string active;
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
active = g_active_endpoint;
|
||||
return active.empty() ? nullptr : active.c_str();
|
||||
}
|
||||
|
||||
static int mgr_set_active(const char* endpoint_name)
|
||||
{
|
||||
if (!endpoint_name || !*endpoint_name) return -1;
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
auto it = g_endpoints.find(endpoint_name);
|
||||
if (it == g_endpoints.end()) return -2;
|
||||
g_active_endpoint = endpoint_name;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static EndpointConfig lookup_endpoint(const char* endpoint_name)
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
std::string name;
|
||||
if (endpoint_name && *endpoint_name) name = endpoint_name;
|
||||
else name = g_active_endpoint;
|
||||
auto it = g_endpoints.find(name);
|
||||
if (it == g_endpoints.end()) return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static int mgr_set_model(const char* endpoint_name, const char* model)
|
||||
{
|
||||
if (!model || !*model) return -1;
|
||||
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
std::string selected;
|
||||
{
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
selected = (endpoint_name && *endpoint_name) ? endpoint_name : g_active_endpoint;
|
||||
auto it = g_endpoints.find(selected);
|
||||
if (it == g_endpoints.end()) return -2;
|
||||
it->second.model = model;
|
||||
}
|
||||
|
||||
if (host && host->config_set) {
|
||||
std::string key = "endpoint." + selected + ".model";
|
||||
host->config_set(key.c_str(), model);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t mgr_chat(const char* endpoint_name,
|
||||
const dstalk_message_t* history,
|
||||
int history_len,
|
||||
const char* user_input,
|
||||
const char* tools_json)
|
||||
{
|
||||
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
|
||||
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
|
||||
EndpointConfig ep = lookup_endpoint(endpoint_name);
|
||||
const dstalk_ai_service_t* service = lookup_ai_service(ep);
|
||||
if (!service) return make_error("endpoint not found");
|
||||
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
|
||||
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
|
||||
ep.model.c_str(), ep.max_tokens, ep.temperature);
|
||||
if (rc != 0) return make_error("endpoint configure failed");
|
||||
return service->chat(history, history_len, user_input, tools_json);
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t mgr_chat_stream(const char* endpoint_name,
|
||||
const dstalk_message_t* history,
|
||||
int history_len,
|
||||
const char* user_input,
|
||||
dstalk_stream_cb cb,
|
||||
void* userdata)
|
||||
{
|
||||
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
|
||||
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
|
||||
EndpointConfig ep = lookup_endpoint(endpoint_name);
|
||||
const dstalk_ai_service_t* service = lookup_ai_service(ep);
|
||||
if (!service) return make_error("endpoint not found");
|
||||
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
|
||||
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
|
||||
ep.model.c_str(), ep.max_tokens, ep.temperature);
|
||||
if (rc != 0) return make_error("endpoint configure failed");
|
||||
return service->chat_stream(history, history_len, user_input, cb, userdata);
|
||||
}
|
||||
|
||||
static void mgr_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
dstalk_ai::free_chat_result(host, result);
|
||||
}
|
||||
|
||||
static dstalk_ai_endpoint_mgr_t g_service = {
|
||||
&mgr_count,
|
||||
&mgr_list_json,
|
||||
&mgr_get_active,
|
||||
&mgr_set_active,
|
||||
&mgr_set_model,
|
||||
&mgr_chat,
|
||||
&mgr_chat_stream,
|
||||
&mgr_free_result,
|
||||
};
|
||||
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
if (!host) return -1;
|
||||
g_host.store(host, std::memory_order_release);
|
||||
{
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
reload_endpoints_locked(host);
|
||||
}
|
||||
if (host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] loaded %d endpoint(s)", mgr_count());
|
||||
return host->register_service("ai_endpoint_mgr", 1, &g_service);
|
||||
} catch (const std::exception& e) {
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] shutdown");
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
clear_endpoints_locked();
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown exception: %s", e.what());
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown unknown exception");
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
}
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "ai_endpoint_mgr",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "AI endpoint manager for multiple named provider/model endpoints / 多命名 AI endpoint 管理器",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "openai_compat", "anthropic_ai", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
/* .on_shutdown = */ on_shutdown,
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
}
|
||||
@@ -65,6 +65,7 @@ static std::string build_request_json(
|
||||
std::string system_prompt;
|
||||
json::array msgs;
|
||||
|
||||
if (history) { // 防御: history 为空时跳过历史遍历 / Defensive: skip history iteration when null
|
||||
for (int i = 0; i < history_len; ++i) {
|
||||
const auto& m = history[i];
|
||||
if (m.role && std::strcmp(m.role, "system") == 0) {
|
||||
@@ -77,6 +78,7 @@ static std::string build_request_json(
|
||||
obj["content"] = m.content ? m.content : "";
|
||||
msgs.push_back(obj);
|
||||
}
|
||||
} // if (history)
|
||||
|
||||
// 追加当前用户输入 / Append current user input
|
||||
{
|
||||
|
||||
@@ -49,6 +49,7 @@ static std::string build_headers_json(const std::string& auth_header_value)
|
||||
static void append_history(json::array& msgs,
|
||||
const dstalk_message_t* history, int history_len)
|
||||
{
|
||||
if (!history) return; // 防御: history 为空时直接返回 / Defensive: return early when history is null
|
||||
for (int i = 0; i < history_len; ++i) {
|
||||
const auto& m = history[i];
|
||||
json::object obj;
|
||||
|
||||
@@ -145,6 +145,8 @@ target_compile_definitions(dstalk_anthropic_plugin_test
|
||||
BOOST_ALL_NO_LIB
|
||||
)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
target_link_libraries(dstalk_anthropic_plugin_test
|
||||
PRIVATE
|
||||
dstalk
|
||||
@@ -174,6 +176,8 @@ target_compile_definitions(dstalk_openai_plugin_test
|
||||
BOOST_ALL_NO_LIB
|
||||
)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
target_link_libraries(dstalk_openai_plugin_test
|
||||
PRIVATE
|
||||
dstalk
|
||||
@@ -183,6 +187,37 @@ target_link_libraries(dstalk_openai_plugin_test
|
||||
|
||||
add_test(NAME dstalk_openai_plugin_test COMMAND dstalk_openai_plugin_test)
|
||||
|
||||
# ============================================================
|
||||
# dstalk_endpoint_mgr_plugin_test — AI endpoint manager 插件单元测试
|
||||
# 覆盖: endpoint 加载/列表脱敏/active/model 修改/路由
|
||||
# ============================================================
|
||||
|
||||
add_executable(dstalk_endpoint_mgr_plugin_test
|
||||
endpoint_mgr_plugin_test.cpp
|
||||
)
|
||||
|
||||
target_include_directories(dstalk_endpoint_mgr_plugin_test
|
||||
PRIVATE ${CMAKE_SOURCE_DIR}/dstalk_core/include
|
||||
PRIVATE ${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include
|
||||
)
|
||||
|
||||
target_compile_definitions(dstalk_endpoint_mgr_plugin_test
|
||||
PRIVATE
|
||||
BOOST_JSON_HEADER_ONLY
|
||||
BOOST_ALL_NO_LIB
|
||||
)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
target_link_libraries(dstalk_endpoint_mgr_plugin_test
|
||||
PRIVATE
|
||||
dstalk
|
||||
ai_common
|
||||
boost::boost
|
||||
)
|
||||
|
||||
add_test(NAME dstalk_endpoint_mgr_plugin_test COMMAND dstalk_endpoint_mgr_plugin_test)
|
||||
|
||||
# ============================================================
|
||||
# dstalk_network_plugin_test — Network 插件单元测试
|
||||
# W22.2 (qa-xu): 通过 #include source 访问 static 函数
|
||||
@@ -203,6 +238,8 @@ target_compile_definitions(dstalk_network_plugin_test
|
||||
BOOST_ALL_NO_LIB
|
||||
)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
target_link_libraries(dstalk_network_plugin_test
|
||||
PRIVATE
|
||||
dstalk
|
||||
@@ -238,6 +275,8 @@ target_compile_features(dstalk_lsp_plugin_test
|
||||
PRIVATE cxx_std_17
|
||||
)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
target_link_libraries(dstalk_lsp_plugin_test
|
||||
PRIVATE
|
||||
dstalk
|
||||
|
||||
547
tests/endpoint_mgr_plugin_test.cpp
Normal file
547
tests/endpoint_mgr_plugin_test.cpp
Normal file
@@ -0,0 +1,547 @@
|
||||
/*
|
||||
* @file endpoint_mgr_plugin_test.cpp
|
||||
* @brief AI endpoint manager plugin unit tests: endpoint loading, active/model mutation, routing, secret-safe listing,
|
||||
* plus error-path coverage (null history, missing endpoint, configure failed, empty/bad active, concurrency).
|
||||
* AI endpoint 管理器插件单元测试:endpoint 加载、active/model 修改、路由和脱敏列表、
|
||||
* 以及错误路径覆盖(空 history、缺失 endpoint、configure 失败、空/错 active、并发)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
#include "../plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp"
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
static int g_failures = 0;
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
} else { \
|
||||
std::cerr << "[FAIL] " << (msg) << "\n"; \
|
||||
g_failures++; \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
struct ConfigureRecord {
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 0;
|
||||
double temperature = 0.0;
|
||||
int configure_calls = 0;
|
||||
int chat_calls = 0;
|
||||
int stream_calls = 0;
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, std::string> g_config_values;
|
||||
static const dstalk_ai_endpoint_mgr_t* g_registered_mgr = nullptr;
|
||||
static ConfigureRecord g_last_configure;
|
||||
static int g_stream_cb_count = 0;
|
||||
|
||||
static void* fake_alloc(size_t size)
|
||||
{
|
||||
return std::malloc(size);
|
||||
}
|
||||
|
||||
static void fake_free(void* ptr)
|
||||
{
|
||||
std::free(ptr);
|
||||
}
|
||||
|
||||
static char* fake_strdup(const char* s)
|
||||
{
|
||||
if (!s) return nullptr;
|
||||
size_t n = std::strlen(s) + 1;
|
||||
char* p = static_cast<char*>(std::malloc(n));
|
||||
if (p) std::memcpy(p, s, n);
|
||||
return p;
|
||||
}
|
||||
|
||||
static int fake_register_service(const char* name, int version, void* vtable)
|
||||
{
|
||||
if (!name || !vtable || version != 1) return -1;
|
||||
if (std::strcmp(name, "ai_endpoint_mgr") == 0) {
|
||||
g_registered_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(vtable);
|
||||
return 0;
|
||||
}
|
||||
return -2;
|
||||
}
|
||||
|
||||
static void fake_unregister_service(const char*)
|
||||
{
|
||||
}
|
||||
|
||||
static const char* fake_config_get(const char* key)
|
||||
{
|
||||
if (!key) return nullptr;
|
||||
auto it = g_config_values.find(key);
|
||||
if (it == g_config_values.end()) return nullptr;
|
||||
static thread_local std::string tls_value;
|
||||
tls_value = it->second;
|
||||
return tls_value.c_str();
|
||||
}
|
||||
|
||||
static int fake_config_set(const char* key, const char* value)
|
||||
{
|
||||
if (!key || !value) return -1;
|
||||
g_config_values[key] = value;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void fake_log(int, const char*, ...)
|
||||
{
|
||||
}
|
||||
|
||||
static int fake_configure(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature)
|
||||
{
|
||||
g_last_configure.provider = provider ? provider : "";
|
||||
g_last_configure.base_url = base_url ? base_url : "";
|
||||
g_last_configure.api_key = api_key ? api_key : "";
|
||||
g_last_configure.model = model ? model : "";
|
||||
g_last_configure.max_tokens = max_tokens;
|
||||
g_last_configure.temperature = temperature;
|
||||
g_last_configure.configure_calls++;
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 模拟 configure 失败的 provider service / Fake provider service whose configure always fails
|
||||
static int fake_configure_fail(const char*, const char*, const char*, const char*, int, double)
|
||||
{
|
||||
g_last_configure.configure_calls++;
|
||||
return -1;
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t fake_chat(const dstalk_message_t*, int,
|
||||
const char*, const char*)
|
||||
{
|
||||
g_last_configure.chat_calls++;
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 1;
|
||||
r.content = fake_strdup("ok");
|
||||
return r;
|
||||
}
|
||||
|
||||
static int test_stream_cb(const char*, void* userdata)
|
||||
{
|
||||
int* count = static_cast<int*>(userdata);
|
||||
if (count) (*count)++;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t fake_chat_stream(const dstalk_message_t*, int,
|
||||
const char*, dstalk_stream_cb cb,
|
||||
void* userdata)
|
||||
{
|
||||
g_last_configure.stream_calls++;
|
||||
if (cb) cb("tok", userdata);
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 1;
|
||||
r.content = fake_strdup("stream-ok");
|
||||
return r;
|
||||
}
|
||||
|
||||
static void fake_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
if (!result) return;
|
||||
if (result->content) fake_free((void*)result->content);
|
||||
if (result->error) fake_free((void*)result->error);
|
||||
if (result->tool_calls_json) fake_free((void*)result->tool_calls_json);
|
||||
result->content = nullptr;
|
||||
result->error = nullptr;
|
||||
result->tool_calls_json = nullptr;
|
||||
}
|
||||
|
||||
static dstalk_ai_service_t g_fake_openai_service = {
|
||||
&fake_configure,
|
||||
&fake_chat,
|
||||
&fake_chat_stream,
|
||||
&fake_free_result,
|
||||
};
|
||||
|
||||
static dstalk_ai_service_t g_fake_anthropic_service = {
|
||||
&fake_configure,
|
||||
&fake_chat,
|
||||
&fake_chat_stream,
|
||||
&fake_free_result,
|
||||
};
|
||||
|
||||
// configure 总是失败的 provider 服务 / Provider service whose configure always fails
|
||||
static dstalk_ai_service_t g_fake_failing_service = {
|
||||
&fake_configure_fail,
|
||||
&fake_chat,
|
||||
&fake_chat_stream,
|
||||
&fake_free_result,
|
||||
};
|
||||
|
||||
static void* fake_query_service(const char* name, int min_version)
|
||||
{
|
||||
if (!name || min_version > 1) return nullptr;
|
||||
if (std::strcmp(name, "ai_openai") == 0) return &g_fake_openai_service;
|
||||
if (std::strcmp(name, "ai_anthropic") == 0) return &g_fake_anthropic_service;
|
||||
if (std::strcmp(name, "ai_failing") == 0) return &g_fake_failing_service;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static dstalk_host_api_t make_fake_host()
|
||||
{
|
||||
dstalk_host_api_t host = {};
|
||||
host.register_service = fake_register_service;
|
||||
host.query_service = fake_query_service;
|
||||
host.unregister_service = fake_unregister_service;
|
||||
host.config_get = fake_config_get;
|
||||
host.config_set = fake_config_set;
|
||||
host.log = fake_log;
|
||||
host.alloc = fake_alloc;
|
||||
host.free = fake_free;
|
||||
host.strdup = fake_strdup;
|
||||
return host;
|
||||
}
|
||||
|
||||
static void setup_endpoint_config()
|
||||
{
|
||||
g_config_values.clear();
|
||||
g_config_values["endpoints.names"] = "openai_main, anthropic_alt, missing_provider, openai_main";
|
||||
g_config_values["endpoints.active"] = "anthropic_alt";
|
||||
|
||||
g_config_values["endpoint.openai_main.provider"] = "ai_openai";
|
||||
g_config_values["endpoint.openai_main.api_key"] = "sk-openai-test";
|
||||
g_config_values["endpoint.openai_main.model"] = "gpt-4o";
|
||||
g_config_values["endpoint.openai_main.max_tokens"] = "1234";
|
||||
g_config_values["endpoint.openai_main.temperature"] = "0.25";
|
||||
|
||||
g_config_values["endpoint.anthropic_alt.provider"] = "ai_anthropic";
|
||||
g_config_values["endpoint.anthropic_alt.api_key"] = "sk-ant-test";
|
||||
g_config_values["endpoint.anthropic_alt.model"] = "claude-sonnet-test";
|
||||
|
||||
g_config_values["endpoint.missing_provider.provider"] = "ai_missing";
|
||||
g_config_values["endpoint.missing_provider.model"] = "missing-model";
|
||||
}
|
||||
|
||||
// 设置单 endpoint 配置用于错误路径测试 / Set up single-endpoint config for error-path testing
|
||||
static void setup_single_endpoint(const char* name, const char* provider, const char* model,
|
||||
const char* base_url = nullptr)
|
||||
{
|
||||
g_config_values.clear();
|
||||
std::string names_key = "endpoints.names";
|
||||
g_config_values[names_key] = name;
|
||||
std::string prefix = std::string("endpoint.") + name + ".";
|
||||
g_config_values[prefix + "provider"] = provider;
|
||||
g_config_values[prefix + "api_key"] = "sk-test";
|
||||
g_config_values[prefix + "model"] = model;
|
||||
if (base_url && *base_url) {
|
||||
g_config_values[prefix + "base_url"] = base_url;
|
||||
}
|
||||
}
|
||||
|
||||
static void reset_test_state()
|
||||
{
|
||||
on_shutdown();
|
||||
g_registered_mgr = nullptr;
|
||||
g_last_configure = {};
|
||||
g_config_values.clear();
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
// ================================================================
|
||||
// 主测试流程 / Main test flow (existing)
|
||||
// ================================================================
|
||||
on_shutdown();
|
||||
setup_endpoint_config();
|
||||
g_registered_mgr = nullptr;
|
||||
g_last_configure = {};
|
||||
|
||||
dstalk_host_api_t host = make_fake_host();
|
||||
int init_rc = on_init(&host);
|
||||
CHECK(init_rc == 0, "on_init registers endpoint manager service");
|
||||
CHECK(g_registered_mgr != nullptr, "registered service pointer captured");
|
||||
|
||||
const dstalk_ai_endpoint_mgr_t* mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 2, "loads valid endpoints, skips missing provider and duplicate");
|
||||
|
||||
const char* active = mgr ? mgr->get_active() : nullptr;
|
||||
CHECK(active && std::strcmp(active, "anthropic_alt") == 0,
|
||||
"configured active endpoint is selected");
|
||||
|
||||
char* list = mgr ? mgr->list_json() : nullptr;
|
||||
std::string list_json = list ? list : "";
|
||||
if (list) fake_free(list);
|
||||
CHECK(list_json.find("openai_main") != std::string::npos,
|
||||
"list_json includes OpenAI endpoint");
|
||||
CHECK(list_json.find("anthropic_alt") != std::string::npos,
|
||||
"list_json includes Anthropic endpoint");
|
||||
CHECK(list_json.find("https://api.anthropic.com") != std::string::npos,
|
||||
"Anthropic endpoint uses Anthropic default base_url");
|
||||
CHECK(list_json.find("sk-openai-test") == std::string::npos &&
|
||||
list_json.find("sk-ant-test") == std::string::npos,
|
||||
"list_json does not expose API keys");
|
||||
|
||||
CHECK(mgr->set_active("missing") == -2, "set_active rejects unknown endpoint");
|
||||
CHECK(mgr->set_active("openai_main") == 0, "set_active accepts known endpoint");
|
||||
CHECK(std::strcmp(mgr->get_active(), "openai_main") == 0,
|
||||
"get_active reflects set_active change");
|
||||
|
||||
CHECK(mgr->set_model(nullptr, "gpt-4.1-mini") == 0,
|
||||
"set_model(nullptr, model) updates active endpoint");
|
||||
CHECK(g_config_values["endpoint.openai_main.model"] == "gpt-4.1-mini",
|
||||
"set_model mirrors model to host config store");
|
||||
CHECK(mgr->set_model("missing", "model") == -2,
|
||||
"set_model rejects unknown endpoint");
|
||||
CHECK(mgr->set_model("openai_main", "") == -1,
|
||||
"set_model rejects empty model");
|
||||
|
||||
dstalk_message_t msg = {"user", "hello", nullptr, nullptr};
|
||||
dstalk_chat_result_t r = mgr->chat(nullptr, &msg, 1, "hi", "[]");
|
||||
CHECK(r.ok == 1 && r.content && std::strcmp(r.content, "ok") == 0,
|
||||
"chat routes to active endpoint service");
|
||||
CHECK(g_last_configure.provider == "ai_openai" &&
|
||||
g_last_configure.base_url == "https://api.openai.com/v1" &&
|
||||
g_last_configure.model == "gpt-4.1-mini" &&
|
||||
g_last_configure.max_tokens == 1234 &&
|
||||
g_last_configure.temperature == 0.25,
|
||||
"chat configures OpenAI endpoint before routing");
|
||||
mgr->free_result(&r);
|
||||
|
||||
CHECK(mgr->set_active("anthropic_alt") == 0, "switch active endpoint to Anthropic");
|
||||
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
|
||||
CHECK(r.ok == 1 && g_last_configure.provider == "ai_anthropic" &&
|
||||
g_last_configure.base_url == "https://api.anthropic.com" &&
|
||||
g_last_configure.model == "claude-sonnet-test",
|
||||
"chat configures Anthropic endpoint before routing");
|
||||
mgr->free_result(&r);
|
||||
|
||||
g_stream_cb_count = 0;
|
||||
r = mgr->chat_stream("anthropic_alt", &msg, 1, "hi", test_stream_cb, &g_stream_cb_count);
|
||||
CHECK(r.ok == 1 && g_stream_cb_count == 1,
|
||||
"chat_stream routes callback through selected endpoint");
|
||||
mgr->free_result(&r);
|
||||
|
||||
on_shutdown();
|
||||
CHECK(mgr->count() == 0, "on_shutdown clears endpoint cache");
|
||||
|
||||
// ================================================================
|
||||
// 错误路径测试 / Error-path tests (P3)
|
||||
// ================================================================
|
||||
|
||||
// --- E1: null history with non-zero length / 空指针 history ---
|
||||
{
|
||||
reset_test_state();
|
||||
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "E1: init with single endpoint");
|
||||
mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 1, "E1: one endpoint loaded");
|
||||
|
||||
// null history + non-zero len -> 应返回错误 / should return error
|
||||
r = mgr->chat("test_ep", nullptr, 3, "hi", nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "null history with non-zero length") == 0,
|
||||
"E1: null history with len>0 returns error for chat");
|
||||
mgr->free_result(&r);
|
||||
|
||||
r = mgr->chat_stream("test_ep", nullptr, 2, "hi", nullptr, nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "null history with non-zero length") == 0,
|
||||
"E1: null history with len>0 returns error for chat_stream");
|
||||
mgr->free_result(&r);
|
||||
|
||||
// null history + zero len -> 应正常 (无历史) / should pass (empty history)
|
||||
r = mgr->chat("test_ep", nullptr, 0, "hi", nullptr);
|
||||
CHECK(r.ok == 1, "E1: null history with len=0 passes for chat");
|
||||
mgr->free_result(&r);
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// --- E2: missing endpoint chat / 缺失 endpoint ---
|
||||
{
|
||||
reset_test_state();
|
||||
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "E2: init");
|
||||
mgr = g_registered_mgr;
|
||||
|
||||
r = mgr->chat("nonexistent_ep", &msg, 1, "hi", nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "endpoint not found") == 0,
|
||||
"E2: chat with nonexistent endpoint returns error");
|
||||
mgr->free_result(&r);
|
||||
|
||||
r = mgr->chat_stream("nonexistent_ep", &msg, 1, "hi", nullptr, nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "endpoint not found") == 0,
|
||||
"E2: chat_stream with nonexistent endpoint returns error");
|
||||
mgr->free_result(&r);
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// --- E3: configure failed / configure 失败 ---
|
||||
{
|
||||
reset_test_state();
|
||||
// ai_failing provider 的 configure 总是返回 -1 / ai_failing provider's configure always returns -1
|
||||
setup_single_endpoint("fail_ep", "ai_failing", "fail-model", "https://fail.example.com/api");
|
||||
host = make_fake_host();
|
||||
int rc = on_init(&host);
|
||||
CHECK(rc == 0, "E3: init with failing endpoint");
|
||||
mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 1, "E3: failing endpoint loaded (service exists)");
|
||||
|
||||
r = mgr->chat("fail_ep", &msg, 1, "hi", nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "endpoint configure failed") == 0,
|
||||
"E3: chat returns error when configure fails");
|
||||
mgr->free_result(&r);
|
||||
|
||||
r = mgr->chat_stream("fail_ep", &msg, 1, "hi", nullptr, nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "endpoint configure failed") == 0,
|
||||
"E3: chat_stream returns error when configure fails");
|
||||
mgr->free_result(&r);
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// --- E4: empty endpoints (no endpoints.names config key) / 空 endpoint 列表 ---
|
||||
{
|
||||
reset_test_state();
|
||||
// 不设置任何 endpoint 配置 / No endpoint config at all
|
||||
g_config_values.clear();
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "E4: init with no endpoint config");
|
||||
mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 0, "E4: zero endpoints loaded");
|
||||
CHECK(mgr->get_active() == nullptr, "E4: get_active returns nullptr when no endpoints");
|
||||
|
||||
// 在没有 endpoint 的情况下,chat 应报错 / chat should error when no endpoints
|
||||
r = mgr->chat(nullptr, &msg, 1, "hi", nullptr);
|
||||
CHECK(r.ok == 0 && r.error != nullptr &&
|
||||
std::strcmp(r.error, "endpoint not found") == 0,
|
||||
"E4: chat with no endpoints returns error");
|
||||
mgr->free_result(&r);
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// --- E5: bad active (active set to nonexistent endpoint) / 无效 active ---
|
||||
{
|
||||
reset_test_state();
|
||||
g_config_values["endpoints.names"] = "ep1";
|
||||
g_config_values["endpoints.active"] = "does_not_exist"; // 不存在的 active / nonexistent active
|
||||
g_config_values["endpoint.ep1.provider"] = "ai_openai";
|
||||
g_config_values["endpoint.ep1.api_key"] = "sk-test";
|
||||
g_config_values["endpoint.ep1.model"] = "gpt-test";
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "E5: init with bad active config");
|
||||
mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 1, "E5: endpoint loaded despite bad active");
|
||||
|
||||
// 当 active 指向不存在的 endpoint 时,get_active 应返回第一个加载有效的 endpoint
|
||||
// When active points to nonexistent endpoint, get_active should return the first valid loaded endpoint
|
||||
active = mgr->get_active();
|
||||
CHECK(active != nullptr && std::strcmp(active, "ep1") == 0,
|
||||
"E5: get_active falls back to first loaded endpoint when configured active is invalid");
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// --- E6: set_active with null/empty name / set_active 空名称 ---
|
||||
{
|
||||
reset_test_state();
|
||||
setup_single_endpoint("test_ep", "ai_openai", "gpt-test");
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "E6: init");
|
||||
mgr = g_registered_mgr;
|
||||
|
||||
CHECK(mgr->set_active(nullptr) == -1, "E6: set_active(nullptr) returns -1");
|
||||
CHECK(mgr->set_active("") == -1, "E6: set_active(\"\") returns -1");
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// secure_zero 基础测试 / Basic secure_zero test
|
||||
// ================================================================
|
||||
{
|
||||
// 分配缓冲区并填充可识别模式 / Allocate buffer and fill with recognizable pattern
|
||||
char buf[64];
|
||||
std::memset(buf, 0xAB, sizeof(buf));
|
||||
dstalk_ai::secure_zero(buf, sizeof(buf));
|
||||
bool all_zero = true;
|
||||
for (int i = 0; i < (int)sizeof(buf); ++i) {
|
||||
if (buf[i] != 0) { all_zero = false; break; }
|
||||
}
|
||||
CHECK(all_zero, "secure_zero wipes all bytes to zero");
|
||||
|
||||
// 空 size / zero size — 不崩溃 / should not crash
|
||||
dstalk_ai::secure_zero(buf, 0);
|
||||
CHECK(true, "secure_zero with size=0 does not crash");
|
||||
|
||||
// nullptr + zero size — 不崩溃 / should not crash
|
||||
dstalk_ai::secure_zero(nullptr, 0);
|
||||
CHECK(true, "secure_zero(nullptr, 0) does not crash");
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// 轻量并发读写测试 / Lightweight concurrent read/write test
|
||||
// ================================================================
|
||||
{
|
||||
reset_test_state();
|
||||
setup_single_endpoint("conc_ep", "ai_openai", "gpt-concurrent");
|
||||
host = make_fake_host();
|
||||
CHECK(on_init(&host) == 0, "concurrency setup: init");
|
||||
mgr = g_registered_mgr;
|
||||
CHECK(mgr && mgr->count() == 1, "concurrency setup: endpoint loaded");
|
||||
|
||||
const int kReaders = 4;
|
||||
const int kIters = 100;
|
||||
std::vector<std::thread> threads;
|
||||
std::atomic<int> errors{0};
|
||||
|
||||
// 读者线程: 反复调用 get_active / list_json / count / set_active
|
||||
// Reader threads: repeatedly call get_active / list_json / count / set_active
|
||||
for (int t = 0; t < kReaders; ++t) {
|
||||
threads.emplace_back([mgr, &errors, t]() {
|
||||
for (int i = 0; i < kIters; ++i) {
|
||||
// 读操作 / read operations
|
||||
const char* a = mgr->get_active();
|
||||
if (a && std::strcmp(a, "conc_ep") != 0) errors++;
|
||||
int c = mgr->count();
|
||||
if (c != 1) errors++;
|
||||
char* l = mgr->list_json();
|
||||
if (l) fake_free(l);
|
||||
else errors++;
|
||||
|
||||
// 轻量写操作: set_active 到同一个 endpoint / lightweight write
|
||||
if (i % 10 == 0) {
|
||||
int rc = mgr->set_active("conc_ep");
|
||||
if (rc != 0) errors++;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& th : threads) th.join();
|
||||
CHECK(errors.load() == 0, "concurrent read/write: no errors across threads");
|
||||
|
||||
on_shutdown();
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// 总结 / Summary
|
||||
// ================================================================
|
||||
if (g_failures == 0) {
|
||||
std::cout << "\nendpoint_mgr_plugin_test: all checks passed\n";
|
||||
} else {
|
||||
std::cerr << "\nendpoint_mgr_plugin_test: " << g_failures << " failure(s)\n";
|
||||
}
|
||||
return g_failures == 0 ? 0 : 1;
|
||||
}
|
||||
@@ -55,6 +55,15 @@ int main()
|
||||
<< "provider = \"openai\"\n"
|
||||
<< "base_url = \"https://api.openai.com/v1\"\n"
|
||||
<< "api_key = \"test-key\"\n"
|
||||
<< "model = \"gpt-4o\"\n"
|
||||
<< "\n"
|
||||
<< "# minimal endpoint config for ai_endpoint_mgr / 最小 endpoint 配置供 ai_endpoint_mgr 加载\n"
|
||||
<< "[endpoints]\n"
|
||||
<< "names = \"gpt4o\"\n"
|
||||
<< "\n"
|
||||
<< "[endpoint.gpt4o]\n"
|
||||
<< "provider = \"ai_openai\"\n"
|
||||
<< "api_key = \"test-key\"\n"
|
||||
<< "model = \"gpt-4o\"\n";
|
||||
}
|
||||
|
||||
@@ -217,6 +226,52 @@ int main()
|
||||
// 测试 dstalk_log / Test dstalk_log
|
||||
dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully");
|
||||
|
||||
// ========================================================================
|
||||
// 测试服务查询: ai_endpoint_mgr / Test service query: ai_endpoint_mgr
|
||||
// 验证 endpoint 加载 / 列表脱敏 / count / get_active / Verify endpoint load / list sanitization / count / get_active
|
||||
// 不调用真实 chat 网络 / No real chat network calls
|
||||
// ========================================================================
|
||||
{
|
||||
auto* ep_mgr = static_cast<const dstalk_ai_endpoint_mgr_t*>(
|
||||
dstalk_service_query("ai_endpoint_mgr", 1));
|
||||
if (ep_mgr) {
|
||||
std::cout << "[OK] ai_endpoint_mgr service found\n";
|
||||
|
||||
// 验证 endpoint 数量 >= 1 / Verify endpoint count >= 1
|
||||
int n = ep_mgr->count();
|
||||
if (n >= 1) {
|
||||
std::cout << "[OK] ai_endpoint_mgr count = " << n << "\n";
|
||||
} else {
|
||||
std::cerr << "[FAIL] ai_endpoint_mgr count = " << n << " (expected >= 1)\n";
|
||||
}
|
||||
|
||||
// 验证 list_json 非空且不泄露 api_key / Verify list_json non-null and no api_key leak
|
||||
char* list = ep_mgr->list_json();
|
||||
if (list) {
|
||||
std::string list_str(list);
|
||||
if (list_str.find("test-key") == std::string::npos) {
|
||||
std::cout << "[OK] ai_endpoint_mgr list_json sanitized (no api_key leak): "
|
||||
<< list << "\n";
|
||||
} else {
|
||||
std::cerr << "[FAIL] ai_endpoint_mgr list_json leaks api_key: " << list << "\n";
|
||||
}
|
||||
dstalk_free(list);
|
||||
} else {
|
||||
std::cerr << "[FAIL] ai_endpoint_mgr list_json returned null\n";
|
||||
}
|
||||
|
||||
// 验证 get_active 非空 / Verify get_active non-null
|
||||
const char* active = ep_mgr->get_active();
|
||||
if (active) {
|
||||
std::cout << "[OK] ai_endpoint_mgr get_active = " << active << "\n";
|
||||
} else {
|
||||
std::cerr << "[FAIL] ai_endpoint_mgr get_active returned null\n";
|
||||
}
|
||||
} else {
|
||||
std::cerr << "[WARN] ai_endpoint_mgr service not found\n";
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性
|
||||
// Extended test block C2: null-safety / escape boundaries / tools chain / session robustness
|
||||
@@ -751,6 +806,15 @@ int main()
|
||||
<< "provider = \"openai\"\n"
|
||||
<< "base_url = \"https://api.openai.com/v1\"\n"
|
||||
<< "api_key = \"test-key\"\n"
|
||||
<< "model = \"gpt-4o\"\n"
|
||||
<< "\n"
|
||||
<< "# minimal endpoint config for ai_endpoint_mgr / 最小 endpoint 配置供 ai_endpoint_mgr 加载\n"
|
||||
<< "[endpoints]\n"
|
||||
<< "names = \"gpt4o\"\n"
|
||||
<< "\n"
|
||||
<< "[endpoint.gpt4o]\n"
|
||||
<< "provider = \"ai_openai\"\n"
|
||||
<< "api_key = \"test-key\"\n"
|
||||
<< "model = \"gpt-4o\"\n";
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user