Add metadata validation script and module documentation

- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references.
- Added usage instructions and exit codes for the script.
- Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules.
- Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
This commit is contained in:
2026-05-31 00:00:58 +08:00
parent 3cc9ee95e4
commit f2da0f2ed4
43 changed files with 2467 additions and 800 deletions

View File

@@ -8,6 +8,7 @@ set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED ON) set(CMAKE_C_STANDARD_REQUIRED ON)
option(DSTALK_BUILD_GUI "Build the SDL3 GUI frontend" OFF) option(DSTALK_BUILD_GUI "Build the SDL3 GUI frontend" OFF)
option(DSTALK_BUILD_WEB "Build the web UI frontend" OFF)
option(DSTALK_BUILD_TESTS "Build dstalk tests" ON) option(DSTALK_BUILD_TESTS "Build dstalk tests" ON)
add_subdirectory(dstalk-core) add_subdirectory(dstalk-core)
@@ -18,6 +19,10 @@ if(DSTALK_BUILD_GUI)
add_subdirectory(dstalk-gui) add_subdirectory(dstalk-gui)
endif() endif()
if(DSTALK_BUILD_WEB)
add_subdirectory(dstalk-web)
endif()
if(DSTALK_BUILD_TESTS) if(DSTALK_BUILD_TESTS)
enable_testing() enable_testing()
add_subdirectory(tests) add_subdirectory(tests)

View File

@@ -1,8 +1,9 @@
// ============================================================================ /*
// dstalk-cli — 命令行前端 (使用插件化架构) * @file main.cpp
// ============================================================================ * @brief CLI frontend for dstalk: ANSI terminal UI, command parsing, streaming chat, tool calling loop, batch/pipe mode.
// 通过 dstalk_host.h API 初始化核心,然后查询插件服务 vtable 调用功能 * dstalk 命令行前端ANSI 终端界面、命令解析、流式对话、工具调用循环、批处理/管道模式
// ============================================================================ * Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
@@ -28,7 +29,7 @@
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
// ---- ANSI 简写 ---- // ---- ANSI 简写 / ANSI shorthand macros ----
#define CLR_RESET "\033[0m" #define CLR_RESET "\033[0m"
#define CLR_CYAN "\033[36m" #define CLR_CYAN "\033[36m"
#define CLR_YELLOW "\033[33m" #define CLR_YELLOW "\033[33m"
@@ -37,25 +38,36 @@
#define CLR_DIM "\033[2m" #define CLR_DIM "\033[2m"
#define CLR_BOLD "\033[1m" #define CLR_BOLD "\033[1m"
// ---- 退出码 ---- // ---- 退出码 / Exit codes ----
// 0=正常退出 1=用户中断(SIGINT/Ctrl+C) 2=致命错误 3=配置错误 // 0=正常退出 1=用户中断(SIGINT/Ctrl+C) 2=致命错误 3=配置错误
// 0=normal 1=user interrupt (SIGINT/Ctrl+C) 2=fatal error 3=config error
#define EXIT_OK 0 #define EXIT_OK 0
#define EXIT_INTERRUPT 1 #define EXIT_INTERRUPT 1
#define EXIT_FATAL 2 #define EXIT_FATAL 2
#define EXIT_CONFIG 3 #define EXIT_CONFIG 3
// ---- 服务 vtable 指针 ---- // ---- 服务 vtable 指针 / Service vtable pointers ----
// Global pointers to plugin service vtables, queried from the host on startup.
// 插件服务 vtable 的全局指针,在启动时从主机查询获取。
static const dstalk_ai_service_t* g_ai = nullptr; static const dstalk_ai_service_t* g_ai = nullptr;
static const dstalk_session_service_t* g_session = nullptr; static const dstalk_session_service_t* g_session = nullptr;
static const dstalk_file_io_service_t* g_file_io = nullptr; static const dstalk_file_io_service_t* g_file_io = nullptr;
static const dstalk_tools_service_t* g_tools = nullptr; static const dstalk_tools_service_t* g_tools = nullptr;
// ---- 运行时状态 ---- // ---- 运行时状态 / Runtime state ----
// g_current_model tracks the active model name for display in the prompt.
// g_quit_requested signals the main loop to exit (set by /quit or Ctrl+C).
// g_quit_via_signal distinguishes SIGINT-triggered exit from normal /quit.
// g_current_model 记录当前模型名称,用于提示符显示。
// g_quit_requested 通知主循环退出(由 /quit 或 Ctrl+C 设置)。
// g_quit_via_signal 区分 SIGINT 触发的退出和正常的 /quit 退出。
static std::string g_current_model; static std::string g_current_model;
static std::atomic<bool> g_quit_requested{false}; static std::atomic<bool> g_quit_requested{false};
static std::atomic<bool> g_quit_via_signal{false}; static std::atomic<bool> g_quit_via_signal{false};
// ---- Ctrl+C 信号处理 ---- // ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ----
// Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT).
// Windows 控制台事件处理CTRL_C_EVENT / CTRL_BREAK_EVENT
#ifdef _WIN32 #ifdef _WIN32
static BOOL WINAPI on_console_event(DWORD event) static BOOL WINAPI on_console_event(DWORD event)
{ {
@@ -66,6 +78,8 @@ static BOOL WINAPI on_console_event(DWORD event)
} }
return FALSE; return FALSE;
} }
// Unix signal handler (SIGINT).
// Unix 信号处理SIGINT
#else #else
static void on_signal(int /*sig*/) static void on_signal(int /*sig*/)
{ {
@@ -74,7 +88,9 @@ static void on_signal(int /*sig*/)
} }
#endif #endif
// ---- 工具函数 ---- // ---- 工具函数 / Utility functions ----
// 打印启动横幅 / Print the dstalk CLI banner with version, AI indicator, and quick command hints.
static void print_banner() static void print_banner()
{ {
std::printf("%sdstalk v0.1.0%s | %sdstalk AI%s | " std::printf("%sdstalk v0.1.0%s | %sdstalk AI%s | "
@@ -85,6 +101,7 @@ static void print_banner()
CLR_DIM, CLR_RESET); CLR_DIM, CLR_RESET);
} }
// 打印帮助文本 / Print the full help text listing all available slash commands.
static void print_help() static void print_help()
{ {
std::printf("\n%s命令列表:%s\n", CLR_BOLD, CLR_RESET); std::printf("\n%s命令列表:%s\n", CLR_BOLD, CLR_RESET);
@@ -104,6 +121,7 @@ static void print_help()
std::printf("\n直接输入问题即可与 AI 对话。\n\n"); std::printf("\n直接输入问题即可与 AI 对话。\n\n");
} }
// 通过 file_io 服务读取并显示文件内容 / Read and display the contents of the file at the given path via the file_io service.
static void print_file(const char* path) static void print_file(const char* path)
{ {
while (*path == ' ') path++; while (*path == ' ') path++;
@@ -122,6 +140,7 @@ static void print_file(const char* path)
} }
} }
// 列出目录内容,按文件名排序,子目录以青色高亮 / List directory entries sorted by filename, highlighting subdirectories in cyan.
static void list_files(const char* path) static void list_files(const char* path)
{ {
while (*path == ' ') path++; while (*path == ' ') path++;
@@ -155,11 +174,12 @@ static void list_files(const char* path)
} }
} }
// 分发斜杠命令 / Dispatch a slash-command string: /quit, /help, /clear, /context, /status, /model, /file, /history, /save, /load.
static void handle_command(const char* line) static void handle_command(const char* line)
{ {
if (!line || line[0] != '/') return; if (!line || line[0] != '/') return;
// /quit —— 设置退出标志,让控制流自然回到 main 末尾 // /quit —— 设置退出标志,让控制流自然回到 main 末尾 / Set quit flag to let control flow naturally return to end of main
if (std::strcmp(line, "/quit") == 0 || std::strcmp(line, "/q") == 0) { if (std::strcmp(line, "/quit") == 0 || std::strcmp(line, "/q") == 0) {
g_quit_requested = true; g_quit_requested = true;
return; return;
@@ -197,7 +217,7 @@ static void handle_command(const char* line)
return; return;
} }
// /status —— 脱敏显示当前运行状态 // /status —— 脱敏显示当前运行状态 / Display current runtime status (desensitized)
if (std::strcmp(line, "/status") == 0) { if (std::strcmp(line, "/status") == 0) {
const char* provider = dstalk_config_get("ai.provider"); const char* provider = dstalk_config_get("ai.provider");
if (!provider) provider = "ai.deepseek"; if (!provider) provider = "ai.deepseek";
@@ -246,7 +266,7 @@ static void handle_command(const char* line)
return; return;
} }
// /file <subcommand> [args...] —— 统一入口,避免 strncmp 空格匹配遗漏 // /file <subcommand> [args...] —— 统一入口,避免 strncmp 空格匹配遗漏 / Unified entry to avoid strncmp space matching issues
if (std::strncmp(line, "/file", 5) == 0) { if (std::strncmp(line, "/file", 5) == 0) {
const char* rest = line + 5; const char* rest = line + 5;
while (*rest == ' ') rest++; while (*rest == ' ') rest++;
@@ -370,7 +390,8 @@ static void handle_command(const char* line)
std::printf(CLR_RED "未知命令: %s (输入 /help 查看帮助)\n" CLR_RESET, line); std::printf(CLR_RED "未知命令: %s (输入 /help 查看帮助)\n" CLR_RESET, line);
} }
// ---- 流式回调 ---- // ---- 流式回调 / Streaming callback ----
// 流式输出回调:每收到一个 token 打印到 stdout 并刷新 / Callback invoked for each token during streaming chat; prints the token to stdout and flushes.
static int on_stream_token(const char* token, void* userdata) static int on_stream_token(const char* token, void* userdata)
{ {
bool* first = static_cast<bool*>(userdata); bool* first = static_cast<bool*>(userdata);
@@ -383,10 +404,12 @@ static int on_stream_token(const char* token, void* userdata)
return 0; return 0;
} }
// ---- 主程序 ---- // ---- 主程序 / Main entry point ----
// 入口:初始化 dstalk host查询插件服务处理 batch/pipe/交互模式。
// Entry point: initializes dstalk host, queries plugin services, handles batch/pipe/interactive modes.
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
// Windows: 启用 ANSI 转义码支持 // Windows: 启用 ANSI 转义码支持 / Windows: enable ANSI escape code support
#ifdef _WIN32 #ifdef _WIN32
HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
DWORD mode = 0; DWORD mode = 0;
@@ -394,7 +417,7 @@ int main(int argc, char* argv[])
SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
#endif #endif
// ---- C1: batch/pipe 模式检测 ---- // ---- C1: batch/pipe 模式检测 / batch/pipe mode detection ----
#ifdef _WIN32 #ifdef _WIN32
bool pipe_mode = (_isatty(_fileno(stdin)) == 0); bool pipe_mode = (_isatty(_fileno(stdin)) == 0);
#else #else
@@ -421,17 +444,17 @@ int main(int argc, char* argv[])
} }
if (pipe_mode) batch_mode = true; if (pipe_mode) batch_mode = true;
// ---- B1: 安装 Ctrl+C 处理 ---- // ---- B1: 安装 Ctrl+C 处理 / Install Ctrl+C handlers ----
#ifdef _WIN32 #ifdef _WIN32
SetConsoleCtrlHandler(on_console_event, TRUE); SetConsoleCtrlHandler(on_console_event, TRUE);
#else #else
signal(SIGINT, on_signal); signal(SIGINT, on_signal);
#endif #endif
// 查找配置文件 // 查找配置文件 / Locate config file
const char* config_path = nullptr; const char* config_path = nullptr;
if (argc >= 2) { if (argc >= 2) {
// 跳过 --batch / --prompt 标志 // 跳过 --batch / --prompt 标志 / Skip --batch / --prompt flags
for (int i = 1; i < argc; ++i) { for (int i = 1; i < argc; ++i) {
if (std::strcmp(argv[i], "--batch") != 0 && std::strcmp(argv[i], "--prompt") != 0) { if (std::strcmp(argv[i], "--batch") != 0 && std::strcmp(argv[i], "--prompt") != 0) {
config_path = argv[i]; config_path = argv[i];
@@ -457,13 +480,13 @@ int main(int argc, char* argv[])
} }
} }
// 初始化主机(加载配置 + 自动扫描 plugins/ 目录加载插件) // 初始化主机(加载配置 + 自动扫描 plugins/ 目录加载插件) / Init host: load config + auto-scan plugins/ directory
if (dstalk_init(config_path) != 0) { if (dstalk_init(config_path) != 0) {
std::fprintf(stderr, CLR_RED "[dstalk] 初始化失败\n" CLR_RESET); std::fprintf(stderr, CLR_RED "[dstalk] 初始化失败\n" CLR_RESET);
return EXIT_CONFIG; return EXIT_CONFIG;
} }
// 查询插件服务 // 查询插件服务 / Query plugin services
const char* ai_provider = dstalk_config_get("ai.provider"); const char* ai_provider = dstalk_config_get("ai.provider");
if (!ai_provider) ai_provider = "ai.deepseek"; if (!ai_provider) ai_provider = "ai.deepseek";
g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1)); g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
@@ -478,7 +501,7 @@ int main(int argc, char* argv[])
std::fprintf(stderr, CLR_RED "[dstalk] Session 服务未找到\n" CLR_RESET); std::fprintf(stderr, CLR_RED "[dstalk] Session 服务未找到\n" CLR_RESET);
} }
// 自动从配置加载 AI 设置 // 自动从配置加载 AI 设置 / Auto-load AI settings from config
if (g_ai) { if (g_ai) {
const char* base_url = dstalk_config_get("api.base_url"); const char* base_url = dstalk_config_get("api.base_url");
const char* api_key = dstalk_config_get("api.api_key"); const char* api_key = dstalk_config_get("api.api_key");
@@ -486,7 +509,7 @@ int main(int argc, char* argv[])
if (!base_url) base_url = "https://api.deepseek.com/v1"; if (!base_url) base_url = "https://api.deepseek.com/v1";
if (!model) model = "deepseek-v4-pro"; if (!model) model = "deepseek-v4-pro";
g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7); g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7);
g_current_model = model; // A1: 记录当前模型名 g_current_model = model; // A1: 记录当前模型名 / Record current model name
} }
if (!batch_mode) { if (!batch_mode) {
@@ -495,7 +518,7 @@ int main(int argc, char* argv[])
std::printf("\n"); std::printf("\n");
} }
// ---- B3: 管道输入模式 (非交互) ---- // ---- B3: 管道输入模式 (非交互) / Pipe input mode (non-interactive) ----
if (pipe_mode) { if (pipe_mode) {
std::string input; std::string input;
char buf[4096]; char buf[4096];
@@ -529,11 +552,11 @@ int main(int argc, char* argv[])
} }
} }
// ---- --prompt 批处理模式 (非交互) ---- // ---- --prompt 批处理模式 (非交互) / --prompt batch mode (non-interactive) ----
if (prompt_arg) { if (prompt_arg) {
std::string prompt_text; std::string prompt_text;
if (std::strcmp(prompt_arg, "-") == 0) { if (std::strcmp(prompt_arg, "-") == 0) {
// --prompt - or --prompt (no arg): read prompt from stdin // --prompt - or --prompt (no arg): read prompt from stdin / --prompt - 或 --prompt无参数从 stdin 读取提示
char buf[4096]; char buf[4096];
while (std::fgets(buf, sizeof(buf), stdin)) { while (std::fgets(buf, sizeof(buf), stdin)) {
prompt_text += buf; prompt_text += buf;
@@ -575,13 +598,13 @@ int main(int argc, char* argv[])
char buffer[8192]; char buffer[8192];
while (true) { while (true) {
// B1: 检查退出标志 // B1: 检查退出标志 / Check quit flag
if (g_quit_requested) { if (g_quit_requested) {
std::printf("再见!\n"); std::printf("再见!\n");
break; break;
} }
// A1: 提示符带模型名batch 模式不打印) // A1: 提示符带模型名batch 模式不打印) / Prompt shows model name (not printed in batch mode)
if (!batch_mode) { if (!batch_mode) {
std::printf(CLR_CYAN "[%s] " CLR_RESET CLR_YELLOW "> " CLR_RESET, std::printf(CLR_CYAN "[%s] " CLR_RESET CLR_YELLOW "> " CLR_RESET,
g_current_model.empty() ? "?" : g_current_model.c_str()); g_current_model.empty() ? "?" : g_current_model.c_str());
@@ -590,14 +613,14 @@ int main(int argc, char* argv[])
if (!std::fgets(buffer, sizeof(buffer), stdin)) break; if (!std::fgets(buffer, sizeof(buffer), stdin)) break;
// C3: fgets 截断检测 // C3: fgets 截断检测 / fgets truncation detection
if (!std::strchr(buffer, '\n') && !feof(stdin)) { if (!std::strchr(buffer, '\n') && !feof(stdin)) {
std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB已截断。建议用文件方式dstalk --batch < file.txt\n" CLR_RESET); std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB已截断。建议用文件方式dstalk --batch < file.txt\n" CLR_RESET);
int c; int c;
while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {} while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {}
} }
// 去除末尾换行 // 去除末尾换行 / Strip trailing newline
size_t len = std::strlen(buffer); size_t len = std::strlen(buffer);
while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) { while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) {
buffer[--len] = '\0'; buffer[--len] = '\0';
@@ -605,19 +628,19 @@ int main(int argc, char* argv[])
if (len == 0) continue; if (len == 0) continue;
// 命令处理 // 命令处理 / Command dispatch
if (buffer[0] == '/') { if (buffer[0] == '/') {
handle_command(buffer); handle_command(buffer);
continue; continue;
} }
// AI 对话(通过插件服务 vtable // AI 对话(通过插件服务 vtable / AI chat (via plugin service vtable)
if (!g_ai || !g_session) { if (!g_ai || !g_session) {
std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET); std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET);
continue; continue;
} }
// 获取会话历史 // 获取会话历史 / Get session history
int history_count = 0; int history_count = 0;
const dstalk_message_t* history = g_session->history(&history_count); const dstalk_message_t* history = g_session->history(&history_count);
@@ -627,14 +650,14 @@ int main(int argc, char* argv[])
if (result.ok) { if (result.ok) {
std::printf(CLR_RESET "\n\n"); std::printf(CLR_RESET "\n\n");
// 将用户消息和 AI 回复添加到会话 // 将用户消息和 AI 回复添加到会话 / Add user message and AI reply to session
dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr}; dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr};
g_session->add(&user_msg); g_session->add(&user_msg);
dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json}; dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json};
g_session->add(&ai_msg); g_session->add(&ai_msg);
// W20.1: Tool Calling 闭环 // W20.1: Tool Calling 闭环 / Tool calling closed loop
// 若 AI 返回了 tool_calls自动执行工具并将结果追加到 history再调 AI // 若 AI 返回了 tool_calls自动执行工具并将结果追加到 history再调 AI / If AI returns tool_calls, auto-execute tools, append results to history, then call AI again
bool has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0'); bool has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0');
const int MAX_TOOL_ROUNDS = 5; const int MAX_TOOL_ROUNDS = 5;
int tool_round = 0; int tool_round = 0;
@@ -643,15 +666,15 @@ int main(int argc, char* argv[])
tool_round++; tool_round++;
has_tool_calls = false; has_tool_calls = false;
// 保存 tool_calls_jsonfree_result 前必须拷贝) // 保存 tool_calls_jsonfree_result 前必须拷贝) / Save tool_calls_json (must copy before free_result)
std::string tc_json(result.tool_calls_json); std::string tc_json(result.tool_calls_json);
// 解析 [{"id":"...", "function":{"name":"...", "arguments":"..."}}] // 解析 [{"id":"...", "function":{"name":"...", "arguments":"..."}}] / Parse tool calls JSON array
boost::system::error_code ec; boost::system::error_code ec;
auto tc_val = boost::json::parse(tc_json, ec); auto tc_val = boost::json::parse(tc_json, ec);
if (ec.failed() || !tc_val.is_array()) break; if (ec.failed() || !tc_val.is_array()) break;
const auto& tc_array = tc_val.as_array(); const auto& tc_array = tc_val.as_array();
if (tc_array.empty()) break; // 空数组 → 终止 if (tc_array.empty()) break; // 空数组 → 终止 / empty array → stop
bool any_executed = false; bool any_executed = false;
for (const auto& tc : tc_array) { for (const auto& tc : tc_array) {
@@ -675,7 +698,7 @@ int main(int argc, char* argv[])
std::string call_id = (id_j && id_j->is_string()) std::string call_id = (id_j && id_j->is_string())
? boost::json::value_to<std::string>(*id_j) : ""; ? boost::json::value_to<std::string>(*id_j) : "";
// 执行工具 // 执行工具 / Execute tool
std::printf(CLR_DIM "[工具调用] %s...\n" CLR_RESET, tool_name.c_str()); std::printf(CLR_DIM "[工具调用] %s...\n" CLR_RESET, tool_name.c_str());
char* exec_result = g_tools->execute(tool_name.c_str(), tool_args.c_str()); char* exec_result = g_tools->execute(tool_name.c_str(), tool_args.c_str());
if (exec_result) { if (exec_result) {
@@ -691,7 +714,7 @@ int main(int argc, char* argv[])
any_executed = true; any_executed = true;
} else { } else {
std::printf(CLR_DIM "[工具结果] fail\n" CLR_RESET); std::printf(CLR_DIM "[工具结果] fail\n" CLR_RESET);
// 单工具失败log + skip // 单工具失败log + skip / Single tool failure: log + skip
std::fprintf(stderr, CLR_YELLOW "[WARN] tool '%s' returned null, skipping\n" CLR_RESET, std::fprintf(stderr, CLR_YELLOW "[WARN] tool '%s' returned null, skipping\n" CLR_RESET,
tool_name.c_str()); tool_name.c_str());
} }
@@ -699,7 +722,7 @@ int main(int argc, char* argv[])
if (!any_executed) break; if (!any_executed) break;
// 重新调用 AIchat_stream 流式,此时 history 已包含工具结果) // 重新调用 AIchat_stream 流式,此时 history 已包含工具结果) / Re-invoke AI (chat_stream streaming, history now includes tool results)
history_count = 0; history_count = 0;
history = g_session->history(&history_count); history = g_session->history(&history_count);
@@ -728,14 +751,14 @@ int main(int argc, char* argv[])
std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS); std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS);
} }
} else { } else {
// A3: error 路径下需 NULL 保护;当前只取 result.errorcontent 未涉及 // A3: error 路径下需 NULL 保护;当前只取 result.errorcontent 未涉及 / Error path needs NULL guard; currently only reads result.error, content not involved
std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET,
result.error ? result.error : "unknown error"); result.error ? result.error : "unknown error");
} }
g_ai->free_result(&result); g_ai->free_result(&result);
} }
// B2: 单一退出点dstalk_shutdown 只在此调用(交互模式下) // B2: 单一退出点dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode)
dstalk_shutdown(); dstalk_shutdown();
return g_quit_via_signal ? EXIT_INTERRUPT : EXIT_OK; return g_quit_via_signal ? EXIT_INTERRUPT : EXIT_OK;
} }

View File

@@ -1,3 +1,10 @@
/**
* @file dstalk_host.h
* @brief Host API declarations: plugin lifecycle, service registry, event bus, config, logging, memory.
* 主机 API 声明:插件生命周期、服务注册表、事件总线、配置、日志、内存管理。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#ifndef DSTALK_HOST_H #ifndef DSTALK_HOST_H
#define DSTALK_HOST_H #define DSTALK_HOST_H
@@ -8,7 +15,7 @@
extern "C" { extern "C" {
#endif #endif
// === 平台导出宏 === /* ---- 平台导出宏 / Platform export macros ---- */
#ifndef DSTALK_API #ifndef DSTALK_API
#if defined(_WIN32) #if defined(_WIN32)
#ifdef DSTALK_BUILD_DLL #ifdef DSTALK_BUILD_DLL
@@ -21,21 +28,23 @@ extern "C" {
#endif #endif
#endif #endif
// === 插件导出宏 === /* ---- 插件导出宏 / Plugin export macro ---- */
#if defined(_WIN32) #if defined(_WIN32)
#define DSTALK_PLUGIN_EXPORT __declspec(dllexport) #define DSTALK_PLUGIN_EXPORT __declspec(dllexport)
#else #else
#define DSTALK_PLUGIN_EXPORT __attribute__((visibility("default"))) #define DSTALK_PLUGIN_EXPORT __attribute__((visibility("default")))
#endif #endif
// === API 版本 === /* ---- API 版本常量 / API version constants ---- */
#define DSTALK_API_VERSION 1 #define DSTALK_API_VERSION 1 // 当前主机 API 版本,插件必须匹配 / current host API version plugins must match
#define DSTALK_MAX_DEPS 8 #define DSTALK_MAX_DEPS 8 // 插件可声明的最大依赖项数量 / maximum dependency entries a plugin can declare
// === 诊断 === /* ---- 诊断回调 / Diagnostics callback ---- */
/* 主机调用此回调用于断言失败和内部诊断 / Called by the host for assertion failures and internal diagnostics */
typedef void (*dstalk_diag_cb)(int severity, const char* file, typedef void (*dstalk_diag_cb)(int severity, const char* file,
int line, const char* func, const char* message); int line, const char* func, const char* message);
/* 断言宏: 当 expr 为假时记录错误并返回 retval / Assertion macro: logs error and returns retval if expr is false */
#define DSTALK_ERROR_RETURN(expr, retval) do { \ #define DSTALK_ERROR_RETURN(expr, retval) do { \
if (!(expr)) { \ if (!(expr)) { \
dstalk_log(DSTALK_LOG_ERROR, "[%s:%d] %s: assertion '%s' failed", \ dstalk_log(DSTALK_LOG_ERROR, "[%s:%d] %s: assertion '%s' failed", \
@@ -44,85 +53,107 @@ typedef void (*dstalk_diag_cb)(int severity, const char* file,
} \ } \
} while(0) } while(0)
/* 注册诊断回调用于内部错误报告 / Register a diagnostic callback for internal error reporting */
DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb); DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb);
// === 事件处理器 === /* ---- 事件处理器类型 / Event handler type ---- */
/* 当已订阅的事件被触发时由主机调用 / Called by the host when a subscribed event is emitted */
typedef void (*dstalk_event_handler_fn)(int event_type, const void* data, void* userdata); typedef void (*dstalk_event_handler_fn)(int event_type, const void* data, void* userdata);
// === Host 提供给插件的 API 表 === /* ---- 主机 API vtable (传递给插件的 on_init) / Host API vtable (passed to plugin's on_init) ---- */
typedef struct { typedef struct {
// 服务注册/查询 /* --- 服务注册表 / service registry --- */
int (*register_service)(const char* name, int version, void* vtable); int (*register_service)(const char* name, int version, void* vtable);
void*(*query_service)(const char* name, int min_version); void*(*query_service)(const char* name, int min_version);
// 事件 /* --- 事件总线 / event bus --- */
int (*event_subscribe)(int event_type, dstalk_event_handler_fn handler, void* userdata); int (*event_subscribe)(int event_type, dstalk_event_handler_fn handler, void* userdata);
int (*event_emit)(int event_type, const void* data); int (*event_emit)(int event_type, const void* data);
void (*event_unsubscribe)(int sub_id); void (*event_unsubscribe)(int sub_id);
// 配置 /* --- 配置管理 / configuration --- */
const char* (*config_get)(const char* key); const char* (*config_get)(const char* key);
int (*config_set)(const char* key, const char* value); int (*config_set)(const char* key, const char* value);
// 日志 /* --- 日志记录 / logging --- */
void (*log)(int level, const char* fmt, ...); void (*log)(int level, const char* fmt, ...);
// 内存 /* --- 内存管理 / memory management --- */
void* (*alloc)(size_t size); void* (*alloc)(size_t size);
void (*free)(void* ptr); void (*free)(void* ptr);
char* (*strdup)(const char* s); char* (*strdup)(const char* s);
} dstalk_host_api_t; } dstalk_host_api_t;
// === 插件信息结构 === /* ---- 插件描述符 / Plugin descriptor ---- */
/* 每个插件通过 dstalk_plugin_init() 导出此结构体 / Every plugin exports this via dstalk_plugin_init() */
typedef struct { typedef struct {
const char* name; // 插件名称(唯一标识) const char* name; // 唯一插件标识符 / unique plugin identifier
const char* version; // 语义版本号,如 "1.0.0" const char* version; // 语义版本号,如 "1.0.0" / semantic version, e.g. "1.0.0"
const char* description; // 描述 const char* description; // 人类可读的描述信息 / human-readable description
int api_version; // 必须 == DSTALK_API_VERSION int api_version; // 必须等于 DSTALK_API_VERSION / must equal DSTALK_API_VERSION
// 依赖声明(以 NULL 结尾) /* null-terminated 依赖插件名称列表 / null-terminated list of dependency plugin names */
const char* dependencies[DSTALK_MAX_DEPS]; const char* dependencies[DSTALK_MAX_DEPS];
// 生命周期回调 /* 生命周期回调 / lifecycle callbacks */
int (*on_init)(const dstalk_host_api_t* host); int (*on_init)(const dstalk_host_api_t* host);
void (*on_shutdown)(void); void (*on_shutdown)(void);
// 事件处理(可选) /* 可选: 事件总线上每个事件通过时调用 / optional: called for every event passing through the bus */
void (*on_event)(int event_type, const void* data); void (*on_event)(int event_type, const void* data);
} dstalk_plugin_info_t; } dstalk_plugin_info_t;
// === 插件入口函数 === /* ---- 插件入口点 / Plugin entry point ---- */
/* 每个共享库插件必须导出一个与此签名匹配的函数 / Every shared library plugin must export a function with this signature */
typedef dstalk_plugin_info_t* (*dstalk_plugin_init_fn)(void); typedef dstalk_plugin_info_t* (*dstalk_plugin_init_fn)(void);
// === Host 公共 API === /* ========================================================================
* 主机公共 API / Host public API
* ======================================================================== */
// 初始化/销毁 /* 使用给定的配置文件路径初始化 dstalk 主机 / Initialize the dstalk host with the given config file path */
DSTALK_API int dstalk_init(const char* config_path); DSTALK_API int dstalk_init(const char* config_path);
/* 关闭主机: 卸载插件, 释放资源 / Shut down the host: unload plugins, free resources */
DSTALK_API void dstalk_shutdown(void); DSTALK_API void dstalk_shutdown(void);
// 插件管理 /* 从共享库路径加载插件; 返回 plugin_id, 出错返回 -1 / Load a plugin from a shared library path; returns plugin_id or -1 on error */
DSTALK_API int dstalk_plugin_load(const char* path); DSTALK_API int dstalk_plugin_load(const char* path);
/* 按 id 卸载之前加载的插件 / Unload a previously loaded plugin by its id */
DSTALK_API int dstalk_plugin_unload(int plugin_id); DSTALK_API int dstalk_plugin_unload(int plugin_id);
/* 将已加载插件信息的 JSON 数组写入 *output_json (调用方释放) / Write a JSON array of loaded plugin info to *output_json (caller frees) */
DSTALK_API int dstalk_plugin_list(char** output_json); DSTALK_API int dstalk_plugin_list(char** output_json);
// 服务查询 /* 按名称和最低版本号查找已注册的服务 vtable / Look up a registered service vtable by name and minimum version */
DSTALK_API void* dstalk_service_query(const char* service_name, int min_version); DSTALK_API void* dstalk_service_query(const char* service_name, int min_version);
// 事件系统 /* 为特定事件类型订阅处理器; 返回 subscription_id / Subscribe handler to a specific event type; returns subscription_id */
DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata); DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata);
/* 向所有已订阅该类型事件的订阅者发送事件 / Emit an event to all subscribers of the given type */
DSTALK_API int dstalk_event_emit(int event_type, const void* data); DSTALK_API int dstalk_event_emit(int event_type, const void* data);
/* 按 id 移除订阅 / Remove a subscription by its id */
DSTALK_API void dstalk_event_unsubscribe(int subscription_id); DSTALK_API void dstalk_event_unsubscribe(int subscription_id);
// 配置 /* 通过键名获取配置值 (未找到返回 NULL) / Retrieve a config value by key (returns NULL if not found) */
DSTALK_API const char* dstalk_config_get(const char* key); DSTALK_API const char* dstalk_config_get(const char* key);
/* 设置配置键值对; 成功返回 0 / Set a config key/value pair; returns 0 on success */
DSTALK_API int dstalk_config_set(const char* key, const char* value); DSTALK_API int dstalk_config_set(const char* key, const char* value);
// 日志 /* 以给定严重等级记录日志消息 / Log a message at the given severity level */
DSTALK_API void dstalk_log(int level, const char* fmt, ...); DSTALK_API void dstalk_log(int level, const char* fmt, ...);
// 内存 /* 使用主机的内存分配器分配内存 / Allocate memory using the host's allocator */
DSTALK_API void* dstalk_alloc(size_t size); DSTALK_API void* dstalk_alloc(size_t size);
/* 释放之前由主机分配的内存 / Free memory previously allocated by the host */
DSTALK_API void dstalk_free(void* ptr); DSTALK_API void dstalk_free(void* ptr);
/* 使用主机的内存分配器复制 C 字符串 / Duplicate a C-string using the host's allocator */
DSTALK_API char* dstalk_strdup(const char* s); DSTALK_API char* dstalk_strdup(const char* s);
#ifdef __cplusplus #ifdef __cplusplus

View File

@@ -1,3 +1,10 @@
/**
* @file dstalk_lsp.h
* @brief Convenience C API for Language Server Protocol operations (delegates to "lsp" plugin).
* LSP语言服务器协议操作的便捷 C API委托给 "lsp" 插件)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#ifndef DSTALK_LSP_H #ifndef DSTALK_LSP_H
#define DSTALK_LSP_H #define DSTALK_LSP_H
@@ -7,51 +14,51 @@
extern "C" { extern "C" {
#endif #endif
/* ---- LSP 服务器生命周期 ---- */ /* ---- LSP 服务器生命周期 / LSP Server Lifecycle ---- */
/* /*
* 启动语言服务器进程 * 启动语言服务器进程 / Start the language server process
* server_cmd: 命令字符串,例如 "clangd" 或 "pyright --stdio" 或完整路径 * server_cmd: 命令字符串,例如 "clangd" 或 "pyright --stdio" 或完整路径 / command string, e.g. "clangd" or "pyright --stdio" or full path
* language: 语言标识,例如 "c", "cpp", "python", "javascript", "rust" * language: 语言标识,例如 "c", "cpp", "python", "javascript", "rust" / language identifier, e.g. "c", "cpp", "python", "javascript", "rust"
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
*/ */
DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language); DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language);
/* /*
* 停止语言服务器 * 停止语言服务器 / Stop the language server
* 发送 shutdown 请求,然后发送 exit 通知 * 发送 shutdown 请求,然后发送 exit 通知 / sends shutdown request, then exit notification
* 关闭管道,终止子进程 * 关闭管道,终止子进程 / closes pipes, terminates child process
*/ */
DSTALK_API void dstalk_lsp_stop(void); DSTALK_API void dstalk_lsp_stop(void);
/* ---- 文档管理 ---- */ /* ---- 文档管理 / Document Management ---- */
/* /*
* 在语言服务器中打开一个文档 * 在语言服务器中打开一个文档 / Open a document in the language server
* uri: 文件 URI例如 "file:///path/to/file.c" * uri: 文件 URI例如 "file:///path/to/file.c" / file URI, e.g. "file:///path/to/file.c"
* content: 文件内容文本 * content: 文件内容文本 / file content text
* language_id: 语言 ID例如 "c", "cpp", "python", "javascript" * language_id: 语言 ID例如 "c", "cpp", "python", "javascript" / language ID, e.g. "c", "cpp", "python", "javascript"
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
*/ */
DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, DSTALK_API int dstalk_lsp_open(const char* uri, const char* content,
const char* language_id); const char* language_id);
/* /*
* 关闭语言服务器中的文档 * 关闭语言服务器中的文档 / Close a document in the language server
* uri: 文件 URI * uri: 文件 URI / file URI
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
*/ */
DSTALK_API int dstalk_lsp_close(const char* uri); DSTALK_API int dstalk_lsp_close(const char* uri);
/* ---- 查询操作 ---- */ /* ---- 查询操作 / Query Operations ---- */
/* /*
* 获取诊断信息 (编译错误、警告等) * 获取诊断信息 (编译错误、警告等) / Get diagnostics (build errors, warnings, etc.)
* uri: 文件 URI * uri: 文件 URI / file URI
* output: 输出参数JSON 格式的诊断列表 (调用方通过 dstalk_free 释放) * output: 输出参数JSON 格式的诊断列表 (调用方通过 dstalk_free 释放) / output param, JSON list of diagnostics (caller frees via dstalk_free)
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
* *
* JSON 输出格式示例: * JSON 输出格式示例 / JSON output format example:
* [ * [
* { * {
* "range": { "start": {"line":0,"character":0}, "end":{"line":0,"character":5} }, * "range": { "start": {"line":0,"character":0}, "end":{"line":0,"character":5} },
@@ -63,23 +70,23 @@ DSTALK_API int dstalk_lsp_close(const char* uri);
DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output); DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output);
/* /*
* 获取悬停信息 (类型、文档等) * 获取悬停信息 (类型、文档等) / Get hover info (type, documentation, etc.)
* uri: 文件 URI * uri: 文件 URI / file URI
* line: 行号 (0-based) * line: 行号 (0-based) / line number (0-based)
* character: 列号 (0-based, UTF-16 code units) * character: 列号 (0-based, UTF-16 code units) / column number (0-based, UTF-16 code units)
* output: 输出参数JSON 格式的悬停信息 (调用方通过 dstalk_free 释放) * output: 输出参数JSON 格式的悬停信息 (调用方通过 dstalk_free 释放) / output param, JSON hover info (caller frees via dstalk_free)
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
*/ */
DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character,
char** output); char** output);
/* /*
* 获取代码补全建议 * 获取代码补全建议 / Get code completion suggestions
* uri: 文件 URI * uri: 文件 URI / file URI
* line: 行号 (0-based) * line: 行号 (0-based) / line number (0-based)
* character: 列号 (0-based, UTF-16 code units) * character: 列号 (0-based, UTF-16 code units) / column number (0-based, UTF-16 code units)
* output: 输出参数JSON 格式的补全列表 (调用方通过 dstalk_free 释放) * output: 输出参数JSON 格式的补全列表 (调用方通过 dstalk_free 释放) / output param, JSON completion list (caller frees via dstalk_free)
* returns: 0 成功, -1 失败 * returns: 0 成功, -1 失败 / 0 success, -1 failure
*/ */
DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character,
char** output); char** output);

View File

@@ -1,3 +1,10 @@
/**
* @file dstalk_services.h
* @brief Service vtable definitions for all plugin-provided services (AI, Session, HTTP, etc.).
* 所有插件提供的服务 vtable 定义AI、会话、HTTP 等)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#ifndef DSTALK_SERVICES_H #ifndef DSTALK_SERVICES_H
#define DSTALK_SERVICES_H #define DSTALK_SERVICES_H
@@ -7,46 +14,64 @@
extern "C" { extern "C" {
#endif #endif
// === AI 服务 vtable (实际服务名由插件注册: "ai.deepseek" / "ai.anthropic") === /* ---- AI 服务 vtable / AI service vtable ---- */
/* 以名称如 "ai.deepseek" 或 "ai.anthropic" 注册 / Registered under names such as "ai.deepseek" or "ai.anthropic" */
typedef struct { typedef struct {
/* 配置服务商连接 (base_url, api_key, model 等) / Configure provider connection (base_url, api_key, model, etc.) */
int (*configure)(const char* provider, const char* base_url, int (*configure)(const char* provider, const char* base_url,
const char* api_key, const char* model, const char* api_key, const char* model,
int max_tokens, double temperature); int max_tokens, double temperature);
/* 发送单轮聊天补全请求 (阻塞) / Send a single-turn chat completion (blocking) */
dstalk_chat_result_t (*chat)( dstalk_chat_result_t (*chat)(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
const char* tools_json); const char* tools_json);
/* 通过回调实现流式令牌传输的聊天补全 / Send a chat completion with streaming tokens via callback */
dstalk_chat_result_t (*chat_stream)( dstalk_chat_result_t (*chat_stream)(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
dstalk_stream_cb cb, void* userdata); dstalk_stream_cb cb, void* userdata);
/* 释放 dstalk_chat_result_t 持有的资源 / Free resources held by a dstalk_chat_result_t */
void (*free_result)(dstalk_chat_result_t* result); void (*free_result)(dstalk_chat_result_t* result);
} dstalk_ai_service_t; } dstalk_ai_service_t;
// === Session 服务 (service name: "session") === /* ---- 会话服务 vtable / Session service vtable ---- */
/* 以服务名称 "session" 注册 / Registered under service name "session" */
typedef struct { typedef struct {
/* 将消息追加到会话历史 / Append a message to the session history */
void (*add)(const dstalk_message_t* msg); void (*add)(const dstalk_message_t* msg);
/* 清除会话历史中的所有消息 / Clear all messages from the session history */
void (*clear)(void); void (*clear)(void);
/* 将会话历史保存到文件 (JSON); 成功返回 0 / Save session history to a file (JSON); returns 0 on success */
int (*save)(const char* path); int (*save)(const char* path);
/* 从文件 (JSON) 加载会话历史; 成功返回 0 / Load session history from a file (JSON); returns 0 on success */
int (*load)(const char* path); int (*load)(const char* path);
/* 获取完整消息历史; out_count 接收数组长度 / Get the full message history; out_count receives the array length */
const dstalk_message_t* (*history)(int* out_count); const dstalk_message_t* (*history)(int* out_count);
/* 返回当前会话历史的近似令牌数 / Return the approximate token count of the current session history */
int (*token_count)(void); int (*token_count)(void);
} dstalk_session_service_t; } dstalk_session_service_t;
// === Context 服务 (service name: "context") === /* ---- 上下文服务 vtable / Context service vtable ---- */
/* 以服务名称 "context" 注册 / Registered under service name "context" */
typedef struct { typedef struct {
/* 计算消息数组中近似的令牌数 / Count approximate tokens in an array of messages */
size_t (*count_tokens)(const dstalk_message_t* msgs, int count); size_t (*count_tokens)(const dstalk_message_t* msgs, int count);
/* 裁剪消息历史以适应 max_tokens; out/out_count 为新分配 / Trim message history to fit within max_tokens; out/out_count are newly allocated */
int (*trim)(const dstalk_message_t* in, int in_count, int (*trim)(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count, dstalk_message_t** out, int* out_count,
size_t max_tokens); size_t max_tokens);
} dstalk_context_service_t; } dstalk_context_service_t;
// === HTTP 服务 (service name: "http") === /* ---- HTTP 服务 vtable / HTTP service vtable ---- */
/* 以服务名称 "http" 注册 / Registered under service name "http" */
typedef struct { typedef struct {
/* POST JSON 体到主机; 返回响应体和 HTTP 状态码 / POST JSON body to a host; returns response body and HTTP status code */
int (*post_json)(const char* host, const char* port, int (*post_json)(const char* host, const char* port,
const char* target, const char* body, const char* target, const char* body,
const char* headers_json, const char* headers_json,
char** response_body, int* status_code); char** response_body, int* status_code);
/* POST 带流式响应; 令牌通过回调传递 / POST with streaming response; tokens are delivered via callback */
int (*post_stream)(const char* host, const char* port, int (*post_stream)(const char* host, const char* port,
const char* target, const char* body, const char* target, const char* body,
const char* headers_json, const char* headers_json,
@@ -54,38 +79,61 @@ typedef struct {
char** response_body, int* status_code); char** response_body, int* status_code);
} dstalk_http_service_t; } dstalk_http_service_t;
// === File IO 服务 (service name: "file_io") === /* ---- 文件 I/O 服务 vtable / File I/O service vtable ---- */
/* 以服务名称 "file_io" 注册 / Registered under service name "file_io" */
typedef struct { typedef struct {
/* 读取整个文件内容到 *content; 成功返回 0 / Read entire file content into *content; returns 0 on success */
int (*read)(const char* path, char** content); int (*read)(const char* path, char** content);
/* 将内容写入文件 (覆盖已有文件); 成功返回 0 / Write content to a file (overwrites if exists); returns 0 on success */
int (*write)(const char* path, const char* content); int (*write)(const char* path, const char* content);
} dstalk_file_io_service_t; } dstalk_file_io_service_t;
// === Config 服务 (service name: "config") === /* ---- 配置服务 vtable / Config service vtable ---- */
/* 以服务名称 "config" 注册 / Registered under service name "config" */
typedef struct { typedef struct {
/* 通过键名获取配置值; 未找到返回 NULL / Get a config value by key; returns NULL if not found */
const char* (*get)(const char* key); const char* (*get)(const char* key);
/* 设置配置键值对; 成功返回 0 / Set a config key/value pair; returns 0 on success */
int (*set)(const char* key, const char* value); int (*set)(const char* key, const char* value);
/* 从 JSON 配置文件加载并合并键值对 / Load and merge key/value pairs from a JSON config file */
int (*load_file)(const char* path); int (*load_file)(const char* path);
} dstalk_config_service_t; } dstalk_config_service_t;
// === Tools 服务 (service name: "tools") === /* ---- 工具服务 vtable / Tools service vtable ---- */
/* 以服务名称 "tools" 注册 / Registered under service name "tools" */
/* 已注册工具被调用时触发的处理器; 接收 JSON 参数, 返回 JSON 结果 / Handler invoked when a registered tool is called; receives JSON args, returns JSON result */
typedef char* (*dstalk_tool_handler_fn)(const char* args_json); typedef char* (*dstalk_tool_handler_fn)(const char* args_json);
typedef struct { typedef struct {
/* 注册工具,包含名称、描述和 JSON Schema 参数 / Register a tool with name, description, and JSON Schema parameters */
int (*register_tool)(const char* name, const char* desc, int (*register_tool)(const char* name, const char* desc,
const char* params_schema, const char* params_schema,
dstalk_tool_handler_fn handler); dstalk_tool_handler_fn handler);
/* 取消注册之前注册的工具 / Unregister a previously registered tool */
void (*unregister_tool)(const char* name); void (*unregister_tool)(const char* name);
/* 获取所有已注册工具为 JSON 数组 (OpenAI 工具格式) / Get all registered tools as a JSON array (OpenAI tool format) */
char* (*get_tools_json)(void); char* (*get_tools_json)(void);
/* 按名称执行已注册工具,传入 JSON 参数 / Execute a registered tool by name with the given JSON arguments */
char* (*execute)(const char* name, const char* args_json); char* (*execute)(const char* name, const char* args_json);
} dstalk_tools_service_t; } dstalk_tools_service_t;
// === LSP 服务 (service name: "lsp") === /* ---- LSP 服务 vtable / LSP service vtable ---- */
/* 以服务名称 "lsp" 注册 / Registered under service name "lsp" */
typedef struct { typedef struct {
/* 启动指定语言的 LSP 服务器进程 / Start an LSP server process for the given language */
int (*start)(const char* server_cmd, const char* language); int (*start)(const char* server_cmd, const char* language);
/* 停止 LSP 服务器并清理资源 / Stop the LSP server and clean up resources */
void (*stop)(void); void (*stop)(void);
/* 在 LSP 服务器中打开文档 / Open a document in the LSP server */
int (*open_document)(const char* uri, const char* content, const char* lang_id); int (*open_document)(const char* uri, const char* content, const char* lang_id);
/* 在 LSP 服务器中关闭文档 / Close a document in the LSP server */
int (*close_document)(const char* uri); int (*close_document)(const char* uri);
/* 获取文档的诊断信息 (错误、警告) 以 JSON 格式返回 / Retrieve diagnostics (errors, warnings) for a document as JSON */
int (*get_diagnostics)(const char* uri, char** json_out); int (*get_diagnostics)(const char* uri, char** json_out);
/* 获取指定位置的悬停信息以 JSON 格式返回 / Retrieve hover information at a given position as JSON */
int (*get_hover)(const char* uri, int line, int col, char** json_out); int (*get_hover)(const char* uri, int line, int col, char** json_out);
/* 获取指定位置的代码补全建议以 JSON 格式返回 / Retrieve code completion suggestions at a given position as JSON */
int (*get_completion)(const char* uri, int line, int col, char** json_out); int (*get_completion)(const char* uri, int line, int col, char** json_out);
} dstalk_lsp_service_t; } dstalk_lsp_service_t;

View File

@@ -1,3 +1,10 @@
/**
* @file dstalk_types.h
* @brief Shared data types used across the dstalk host and all plugins.
* 跨主机和所有插件共享的数据类型定义。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#ifndef DSTALK_TYPES_H #ifndef DSTALK_TYPES_H
#define DSTALK_TYPES_H #define DSTALK_TYPES_H
@@ -7,42 +14,42 @@
extern "C" { extern "C" {
#endif #endif
// 消息结构(跨插件共享) /* 所有插件共享的消息结构体 / Shared message structure used across plugins */
typedef struct { typedef struct {
const char* role; // "user", "assistant", "system", "tool" const char* role; // 角色标识 / Role identifier ("user", "assistant", "system", "tool")
const char* content; // 消息内容 const char* content; // 消息正文文本 / Message body text
const char* tool_call_id; // tool 响应时必填 const char* tool_call_id; // 工具调用响应消息所需 / Required for tool response messages
const char* tool_calls_json;// assistant 返回的工具调用JSON 数组 const char* tool_calls_json;// 助手工具调用JSON 数组 / JSON array of tool calls from assistant
} dstalk_message_t; } dstalk_message_t;
// 聊天结果 /* 聊天/补全调用返回的结果 / Result returned from a chat / completion call */
typedef struct { typedef struct {
int ok; int ok; // 0=失败, 非零=成功 / 0 = failure, non-zero = success
const char* content; // dstalk_strdup 分配调用方 dstalk_free const char* content; // dstalk_strdup 分配; 调用方用 dstalk_free 释放 / allocated by dstalk_strdup; caller frees with dstalk_free
const char* error; // dstalk_strdup 分配 const char* error; // dstalk_strdup 分配; 成功时为 NULL / allocated by dstalk_strdup; NULL on success
int http_status; int http_status; // 服务商返回的 HTTP 状态码 / HTTP status code from the provider
const char* tool_calls_json;// dstalk_strdup 分配 const char* tool_calls_json;// dstalk_strdup 分配; 工具调用的 JSON 数组 / allocated by dstalk_strdup; JSON array of tool calls
} dstalk_chat_result_t; } dstalk_chat_result_t;
// 流式回调 /* 流式令牌回调: 返回非零值提前中止流传输 / Streaming token callback: return non-zero to abort the stream early */
typedef int (*dstalk_stream_cb)(const char* token, void* userdata); typedef int (*dstalk_stream_cb)(const char* token, void* userdata);
// 事件类型 /* 事件类型代码 (匿名枚举) / Event type codes (anonymous enum) */
enum { enum {
DSTALK_EVENT_MESSAGE = 1, // data = dstalk_message_t* DSTALK_EVENT_MESSAGE = 1, // 数据为 dstalk_message_t* / data = dstalk_message_t*
DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_SESSION_CLEAR, // 会话历史已清除 / session history cleared
DSTALK_EVENT_CONFIG_CHANGED, DSTALK_EVENT_CONFIG_CHANGED, // 配置键/值已更新 / configuration key/value updated
DSTALK_EVENT_PLUGIN_LOADED, // data = plugin info JSON string DSTALK_EVENT_PLUGIN_LOADED, // 数据为插件信息 JSON 字符串 / data = plugin info JSON string
DSTALK_EVENT_PLUGIN_UNLOADED, DSTALK_EVENT_PLUGIN_UNLOADED, // 插件已卸载 / plugin unloaded
DSTALK_EVENT_CUSTOM = 1000, // 插件自定义事件起始值 DSTALK_EVENT_CUSTOM = 1000, // 插件自定义事件的基础值 / base value for plugin-defined custom events
}; };
// 日志级别 /* 日志严重等级 (匿名枚举) / Log severity levels (anonymous enum) */
enum { enum {
DSTALK_LOG_DEBUG = 0, DSTALK_LOG_DEBUG = 0, // 详细调试消息 / verbose debug messages
DSTALK_LOG_INFO = 1, DSTALK_LOG_INFO = 1, // 信息性消息 / informational messages
DSTALK_LOG_WARN = 2, DSTALK_LOG_WARN = 2, // 警告条件 / warning conditions
DSTALK_LOG_ERROR = 3, DSTALK_LOG_ERROR = 3, // 错误条件 / error conditions
}; };
#ifdef __cplusplus #ifdef __cplusplus

View File

@@ -1 +1,7 @@
/* @file boost_json.cpp
* @brief Boost.JSON header-only library compilation unit (single TU inclusion).
* Boost.JSON 仅头文件库的编译单元(单翻译单元包含)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include <boost/json/src.hpp> #include <boost/json/src.hpp>

View File

@@ -1,3 +1,9 @@
/* @file config_store.cpp
* @brief ConfigStore implementation: TOML parsing, thread-safe get/set with thread-local safety.
* ConfigStore 实现TOML 解析、线程安全的 get/set基于 thread-local 安全机制)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "config_store.hpp" #include "config_store.hpp"
#include "../../plugins/config/include/toml_parse.h" #include "../../plugins/config/include/toml_parse.h"
@@ -8,6 +14,7 @@
namespace dstalk { namespace dstalk {
// 在互斥锁下加载并解析 TOML 文件到键值存储 / Load and parse a TOML file into the key-value store under mutex.
int ConfigStore::load_file(const char* path) int ConfigStore::load_file(const char* path)
{ {
if (!path) return -1; if (!path) return -1;
@@ -19,7 +26,7 @@ int ConfigStore::load_file(const char* path)
ss << file.rdbuf(); ss << file.rdbuf();
std::string data = ss.str(); std::string data = ss.str();
// W12.2: Use shared TOML parser (de-duplicated from config_plugin.cpp) // W12.2: 使用共享 TOML 解析器(从 config_plugin.cpp 去重) / Use shared TOML parser (de-duplicated from config_plugin.cpp)
toml::parse(data, [this](const std::string& key, const std::string& value) { toml::parse(data, [this](const std::string& key, const std::string& value) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
data_[key] = value; data_[key] = value;
@@ -28,6 +35,7 @@ int ConfigStore::load_file(const char* path)
return 0; return 0;
} }
// 检索配置值,返回线程本地副本以避免 c_str() 悬空 / Retrieve config value, returning a thread-local copy to avoid dangling c_str().
const char* ConfigStore::get(const char* key) const const char* ConfigStore::get(const char* key) const
{ {
if (!key) return nullptr; if (!key) return nullptr;
@@ -35,7 +43,9 @@ const char* ConfigStore::get(const char* key) const
auto it = data_.find(key); auto it = data_.find(key);
if (it == data_.end()) return nullptr; if (it == data_.end()) return nullptr;
// W12.2: Copy to thread-local buffer before releasing lock. // W12.2: 在释放锁之前复制到线程本地缓冲区 /
// Copy to thread-local buffer before releasing lock.
// 防止当并发 set() 触发 std::string 重新分配时 c_str() 悬空 /
// Prevents c_str() dangling when concurrent set() on the same key // Prevents c_str() dangling when concurrent set() on the same key
// triggers std::string reallocation (W11.2 audit Finding 3). // triggers std::string reallocation (W11.2 audit Finding 3).
thread_local std::string tls_cached; thread_local std::string tls_cached;
@@ -43,15 +53,17 @@ const char* ConfigStore::get(const char* key) const
return tls_cached.c_str(); return tls_cached.c_str();
} }
// 以 std::string 值类型检索配置(安全的值副本)/ Retrieve config value as an owned std::string (safe by-value copy).
std::string ConfigStore::get_copy(const char* key) const std::string ConfigStore::get_copy(const char* key) const
{ {
if (!key) return {}; if (!key) return {};
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto it = data_.find(key); auto it = data_.find(key);
if (it == data_.end()) return {}; if (it == data_.end()) return {};
return it->second; // copy-constructed under lock, always safe return it->second; // 在锁下复制构造,始终安全 / copy-constructed under lock, always safe
} }
// 在锁下设置配置键值对 / Set a config key-value pair under lock.
int ConfigStore::set(const char* key, const char* value) int ConfigStore::set(const char* key, const char* value)
{ {
if (!key || !value) return -1; if (!key || !value) return -1;

View File

@@ -1,3 +1,9 @@
/* @file config_store.hpp
* @brief Thread-safe key-value configuration store with TOML file loading.
* 线程安全键值配置存储,支持 TOML 文件加载。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once #pragma once
#include <mutex> #include <mutex>
@@ -6,32 +12,36 @@
namespace dstalk { namespace dstalk {
// 线程安全的键值存储,支持 TOML 配置文件 / Thread-safe key-value store backed by TOML config files.
// 通过 mutex_ 支持并发读取get() 返回线程本地缓冲区 / Supports concurrent reads via mutex_ and returns thread-local buffers from get().
class ConfigStore { class ConfigStore {
public: public:
ConfigStore() = default; ConfigStore() = default;
~ConfigStore() = default; ~ConfigStore() = default;
// Load key-value pairs from a TOML file. // 从 TOML 文件加载键值对 / Load key-value pairs from a TOML file.
// Returns 0 on success, -1 if file not found or path is null. // 成功返回 0文件未找到或路径为空返回 -1 / Returns 0 on success, -1 if file not found or path is null.
int load_file(const char* path); int load_file(const char* path);
// Get config value (returns internal pointer, thread-safe). // 获取配置值(返回内部指针,线程安全)/ Get config value (returns internal pointer, thread-safe).
// W12.2: Returned pointer is now backed by a thread-local copy; // W12.2: 返回的指针现在由线程局部副本支持,对其他线程对同一键的并发 set() 安全 /
// Returned pointer is now backed by a thread-local copy;
// safe against concurrent set() on the same key from other threads. // safe against concurrent set() on the same key from other threads.
// 调用者仍应立即使用 — 同一线程上的下一次 get() 将覆盖缓冲区 /
// Caller should still consume immediately — next get() on same // Caller should still consume immediately — next get() on same
// thread will overwrite the buffer. // thread will overwrite the buffer.
const char* get(const char* key) const; const char* get(const char* key) const;
// Get a safe by-value copy of a config entry (no dangling risk). // 获取配置项的安全值副本(无悬空风险)/ Get a safe by-value copy of a config entry (no dangling risk).
// Returns empty string if key not found. // 如果键未找到,返回空字符串 / Returns empty string if key not found.
std::string get_copy(const char* key) const; std::string get_copy(const char* key) const;
// Set config value. Returns 0 on success, -1 on null arguments. // 设置配置值 / Set config value. 成功返回 0参数为空返回 -1 / Returns 0 on success, -1 on null arguments.
int set(const char* key, const char* value); int set(const char* key, const char* value);
private: private:
mutable std::mutex mutex_; mutable std::mutex mutex_; // 保护所有 data_ 访问 / Protects all data_ access
std::unordered_map<std::string, std::string> data_; std::unordered_map<std::string, std::string> data_; // 配置键值存储 / Config key-value store
}; };
} // namespace dstalk } // namespace dstalk

View File

@@ -1,9 +1,16 @@
/* @file event_bus.cpp
* @brief EventBus implementation: subscribe, unsubscribe, emit with reader-writer locking.
* EventBus 实现:基于读写锁的 subscribe、unsubscribe、emit。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "event_bus.hpp" #include "event_bus.hpp"
#include <algorithm> #include <algorithm>
namespace dstalk { namespace dstalk {
// 为给定事件类型注册处理器,返回订阅 ID / Register a handler for the given event type, returning a subscription id.
int EventBus::subscribe(int event_type, EventHandler handler) int EventBus::subscribe(int event_type, EventHandler handler)
{ {
std::unique_lock<std::shared_mutex> lock(mutex_); std::unique_lock<std::shared_mutex> lock(mutex_);
@@ -12,6 +19,7 @@ int EventBus::subscribe(int event_type, EventHandler handler)
return id; return id;
} }
// 通过 ID 移除订阅(如果 ID 未找到则无操作)/ Remove a subscription by id (no-op if id not found).
void EventBus::unsubscribe(int subscription_id) void EventBus::unsubscribe(int subscription_id)
{ {
std::unique_lock<std::shared_mutex> lock(mutex_); std::unique_lock<std::shared_mutex> lock(mutex_);
@@ -23,6 +31,8 @@ void EventBus::unsubscribe(int subscription_id)
subscriptions_.end()); subscriptions_.end());
} }
// 在共享锁下将事件分发给所有匹配的订阅者 / Dispatch an event to all matching subscribers under a shared lock.
// 返回被调用的处理器数量 / Returns the count of handlers invoked.
int EventBus::emit(int event_type, const void* data) int EventBus::emit(int event_type, const void* data)
{ {
std::shared_lock<std::shared_mutex> lock(mutex_); std::shared_lock<std::shared_mutex> lock(mutex_);

View File

@@ -1,3 +1,9 @@
/* @file event_bus.hpp
* @brief Publish-subscribe event bus with shared_mutex for concurrent read access.
* 发布-订阅事件总线,使用 shared_mutex 支持并发读访问。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once #pragma once
#include <functional> #include <functional>
@@ -10,18 +16,23 @@ namespace dstalk {
using EventHandler = std::function<void(int event_type, const void* data)>; using EventHandler = std::function<void(int event_type, const void* data)>;
// 轻量级发布-订阅事件总线 / Lightweight pub-sub event bus.
// 读取者使用 shared_lockemit因此多个处理器可以并发分发
// 写入者使用 unique_locksubscribe / unsubscribe
// Readers use shared_lock (emit) so multiple handlers can be dispatched
// concurrently; writers use unique_lock (subscribe / unsubscribe).
class EventBus { class EventBus {
public: public:
EventBus() = default; EventBus() = default;
~EventBus() = default; ~EventBus() = default;
// 订阅事件返回订阅ID // 订阅事件返回订阅ID / Subscribe to an event, returning a subscription id
int subscribe(int event_type, EventHandler handler); int subscribe(int event_type, EventHandler handler);
// 取消订阅 // 取消订阅 / Unsubscribe by subscription id
void unsubscribe(int subscription_id); void unsubscribe(int subscription_id);
// 发布事件 // 发布事件 / Emit an event to all matching subscribers
int emit(int event_type, const void* data); int emit(int event_type, const void* data);
private: private:
@@ -31,9 +42,9 @@ private:
EventHandler handler; EventHandler handler;
}; };
mutable std::shared_mutex mutex_; mutable std::shared_mutex mutex_; // 读写锁emit 用 sharedsubscribe/unsubscribe 用 unique / RW lock: shared for emit, unique for subscribe/unsubscribe
std::vector<Subscription> subscriptions_; std::vector<Subscription> subscriptions_; // emit 时线性扫描;对少量订阅者足够 / Linear scan on emit; ok for small subscriber counts
int next_id_ = 1; int next_id_ = 1; // 单调递增订阅 ID 计数器 / Monotonic subscription id counter
}; };
} // namespace dstalk } // namespace dstalk

View File

@@ -1,3 +1,10 @@
/*
* @file host.cpp
* @brief Core host orchestrator: global singletons, dstalk_host_api_t instantiation, public C API, LSP delegation.
* 核心主机协调器全局单例、dstalk_host_api_t 实例化、公共 C API、LSP 委托。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "config_store.hpp" #include "config_store.hpp"
#include "event_bus.hpp" #include "event_bus.hpp"
@@ -15,7 +22,7 @@
namespace fs = std::filesystem; namespace fs = std::filesystem;
// ============================================================ // ============================================================
// 全局主机上下文 // 全局主机上下文 / Global host context
// ============================================================ // ============================================================
namespace { namespace {
std::mutex g_init_mutex; std::mutex g_init_mutex;
@@ -27,8 +34,10 @@ namespace {
dstalk::PluginLoader* g_plugin_loader = nullptr; dstalk::PluginLoader* g_plugin_loader = nullptr;
static std::atomic<dstalk_diag_cb> g_diag_callback{nullptr}; static std::atomic<dstalk_diag_cb> g_diag_callback{nullptr};
// ---- 内部辅助 ---- // ---- 内部辅助 / Internal helpers ----
// 复制 C 字符串(用 malloc 分配,调用者必须用 api_free/free 释放)
// Duplicate a C string allocated with malloc (caller must free via api_free/free).
char* host_strdup(const char* s) { char* host_strdup(const char* s) {
if (!s) return nullptr; if (!s) return nullptr;
size_t len = strlen(s); size_t len = strlen(s);
@@ -37,6 +46,8 @@ namespace {
return copy; return copy;
} }
// 核心日志实现:格式化消息,写入 stderr并转发到诊断回调如果已设置
// Core logging implementation: formats message, writes to stderr, and forwards to diagnostic callback if set.
void host_log_impl(int level, const char* fmt, va_list args) { void host_log_impl(int level, const char* fmt, va_list args) {
const char* prefix = ""; const char* prefix = "";
switch (level) { switch (level) {
@@ -50,7 +61,7 @@ namespace {
va_copy(args_copy, args); va_copy(args_copy, args);
vfprintf(stderr, fmt, args); vfprintf(stderr, fmt, args);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
// 转发到诊断回调 // 转发到诊断回调 / Forward to diagnostic callback
auto cb = g_diag_callback.load(std::memory_order_acquire); auto cb = g_diag_callback.load(std::memory_order_acquire);
if (cb) { if (cb) {
char buf[1024]; char buf[1024];
@@ -60,6 +71,8 @@ namespace {
va_end(args_copy); va_end(args_copy);
} }
// host_log_impl 的 printf 风格便捷包装。
// Convenience wrapper around host_log_impl for printf-style calls.
void host_log(int level, const char* fmt, ...) { void host_log(int level, const char* fmt, ...) {
va_list args; va_list args;
va_start(args, fmt); va_start(args, fmt);
@@ -67,16 +80,22 @@ namespace {
va_end(args); va_end(args);
} }
// ---- Host API 表回调 ---- // ---- Host API 表回调 / Host API table callbacks ----
// 将服务 vtable 按名称和版本注册到全局注册表。
// Register a service vtable with the given name and version into the global registry.
int api_register_service(const char* name, int version, void* vtable) { int api_register_service(const char* name, int version, void* vtable) {
return g_service_registry ? g_service_registry->register_service(name, version, vtable) : -1; return g_service_registry ? g_service_registry->register_service(name, version, vtable) : -1;
} }
// 按名称和最低版本从全局注册表查询服务 vtable。
// Query a service vtable by name and minimum version from the global registry.
void* api_query_service(const char* name, int min_version) { void* api_query_service(const char* name, int min_version) {
return g_service_registry ? g_service_registry->query_service(name, min_version) : nullptr; return g_service_registry ? g_service_registry->query_service(name, min_version) : nullptr;
} }
// 通过全局事件总线订阅指定事件类型的处理函数。
// Subscribe a handler to a given event type via the global event bus.
int api_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) { int api_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) {
if (!g_event_bus || !handler) return -1; if (!g_event_bus || !handler) return -1;
return g_event_bus->subscribe(event_type, return g_event_bus->subscribe(event_type,
@@ -85,22 +104,32 @@ namespace {
}); });
} }
// 通过全局事件总路线程安全地向所有已注册处理函数发送事件。
// Emit an event to all registered handlers via the global event bus.
int api_event_emit(int event_type, const void* data) { int api_event_emit(int event_type, const void* data) {
return g_event_bus ? g_event_bus->emit(event_type, data) : -1; return g_event_bus ? g_event_bus->emit(event_type, data) : -1;
} }
// 通过订阅 ID 取消注册之前的事件处理函数。
// Unsubscribe a previously registered event handler by subscription ID.
void api_event_unsubscribe(int sub_id) { void api_event_unsubscribe(int sub_id) {
if (g_event_bus) g_event_bus->unsubscribe(sub_id); if (g_event_bus) g_event_bus->unsubscribe(sub_id);
} }
// 从全局配置存储中按键名读取配置值。
// Read a config value by key from the global config store.
const char* api_config_get(const char* key) { const char* api_config_get(const char* key) {
return g_config ? g_config->get(key) : nullptr; return g_config ? g_config->get(key) : nullptr;
} }
// 在全局配置存储中设置配置键值对。
// Set a config key/value pair in the global config store.
int api_config_set(const char* key, const char* value) { int api_config_set(const char* key, const char* value) {
return g_config ? g_config->set(key, value) : -1; return g_config ? g_config->set(key, value) : -1;
} }
// 主机端日志函数host_log_impl 的 varargs 包装)。
// Host-facing log function (varargs wrapper around host_log_impl).
void api_log(int level, const char* fmt, ...) { void api_log(int level, const char* fmt, ...) {
va_list args; va_list args;
va_start(args, fmt); va_start(args, fmt);
@@ -108,11 +137,16 @@ namespace {
va_end(args); va_end(args);
} }
// 内存分配包装 / Memory allocation wrapper (malloc).
void* api_alloc(size_t size) { return malloc(size); } void* api_alloc(size_t size) { return malloc(size); }
// 内存释放包装 / Memory free wrapper (free).
void api_free(void* ptr) { free(ptr); } void api_free(void* ptr) { free(ptr); }
// 字符串复制包装 / String duplication wrapper (host_strdup).
char* api_strdup(const char* s) { return host_strdup(s); } char* api_strdup(const char* s) { return host_strdup(s); }
// 传递给每个插件 on_init 的完整主机 API vtable。
// The complete host API vtable passed to every plugin's on_init.
dstalk_host_api_t g_host_api = { dstalk_host_api_t g_host_api = {
api_register_service, api_register_service,
api_query_service, api_query_service,
@@ -127,8 +161,12 @@ namespace {
api_strdup api_strdup
}; };
// ---- 插件目录扫描 ---- // ---- 插件目录扫描 / Plugin directory scanning ----
// 扫描目录中的插件 DLL 并通过 PluginLoader 加载。
// 返回加载的插件数量,出错返回 -1。
// Scan a directory for plugin DLLs and load them via PluginLoader.
// Returns the number of plugins loaded, or -1 on error.
int load_plugins_from_directory(const char* plugin_dir) { int load_plugins_from_directory(const char* plugin_dir) {
if (!plugin_dir) return -1; if (!plugin_dir) return -1;
@@ -163,9 +201,11 @@ namespace {
} }
// ============================================================ // ============================================================
// 公共 API // 公共 API / Public API
// ============================================================ // ============================================================
// 初始化 dstalk 主机:创建单例、加载配置、扫描插件、初始化所有插件。
// Initialize the dstalk host: create singletons, load config, scan plugins, initialize all plugins.
DSTALK_API int dstalk_init(const char* config_path) DSTALK_API int dstalk_init(const char* config_path)
{ {
std::lock_guard<std::mutex> lock(g_init_mutex); std::lock_guard<std::mutex> lock(g_init_mutex);
@@ -178,14 +218,14 @@ DSTALK_API int dstalk_init(const char* config_path)
g_service_registry = new dstalk::ServiceRegistry(); g_service_registry = new dstalk::ServiceRegistry();
g_plugin_loader = new dstalk::PluginLoader(); g_plugin_loader = new dstalk::PluginLoader();
// 加载配置 // 加载配置 / Load config
if (config_path && config_path[0]) { if (config_path && config_path[0]) {
if (g_config->load_file(config_path) != 0) { if (g_config->load_file(config_path) != 0) {
host_log(DSTALK_LOG_WARN, "Failed to load config: %s", config_path); host_log(DSTALK_LOG_WARN, "Failed to load config: %s", config_path);
} }
} }
// 扫描插件目录 // 扫描插件目录 / Scan plugin directory
const char* plugin_dir = g_config->get("plugin_dir"); const char* plugin_dir = g_config->get("plugin_dir");
if (!plugin_dir) plugin_dir = "plugins"; if (!plugin_dir) plugin_dir = "plugins";
int loaded = load_plugins_from_directory(plugin_dir); int loaded = load_plugins_from_directory(plugin_dir);
@@ -195,7 +235,7 @@ DSTALK_API int dstalk_init(const char* config_path)
loaded = load_plugins_from_directory("../plugins"); loaded = load_plugins_from_directory("../plugins");
} }
// 初始化所有插件 // 初始化所有插件 / Initialize all plugins
if (g_plugin_loader->initialize_all(&g_host_api) != 0) { if (g_plugin_loader->initialize_all(&g_host_api) != 0) {
host_log(DSTALK_LOG_WARN, "Some plugins failed to initialize"); host_log(DSTALK_LOG_WARN, "Some plugins failed to initialize");
} }
@@ -214,6 +254,8 @@ DSTALK_API int dstalk_init(const char* config_path)
} }
} }
// 关闭 dstalk 主机:关闭插件、销毁单例、释放资源。
// Shutdown the dstalk host: shutdown plugins, destroy singletons, release resources.
DSTALK_API void dstalk_shutdown(void) DSTALK_API void dstalk_shutdown(void)
{ {
std::lock_guard<std::mutex> lock(g_init_mutex); std::lock_guard<std::mutex> lock(g_init_mutex);
@@ -234,6 +276,8 @@ DSTALK_API void dstalk_shutdown(void)
g_initialized = false; g_initialized = false;
} }
// 从给定路径加载单个插件 DLL 并初始化。返回插件 ID失败返回 -1。
// Load a single plugin DLL from the given path and initialize it. Returns plugin ID or -1.
DSTALK_API int dstalk_plugin_load(const char* path) DSTALK_API int dstalk_plugin_load(const char* path)
{ {
if (!g_initialized || !g_plugin_loader) return -1; if (!g_initialized || !g_plugin_loader) return -1;
@@ -244,12 +288,16 @@ DSTALK_API int dstalk_plugin_load(const char* path)
return id; return id;
} }
// 按 ID 卸载插件:调用 on_shutdown卸载 DLL从注册表中移除。成功返回 0。
// Unload a plugin by ID: call on_shutdown, unload DLL, remove from registry. Returns 0 on success.
DSTALK_API int dstalk_plugin_unload(int plugin_id) DSTALK_API int dstalk_plugin_unload(int plugin_id)
{ {
if (!g_initialized || !g_plugin_loader) return -1; if (!g_initialized || !g_plugin_loader) return -1;
return g_plugin_loader->unload_plugin(plugin_id); return g_plugin_loader->unload_plugin(plugin_id);
} }
// 以 JSON 字符串列出所有已加载插件。调用者必须用 dstalk_free 释放 *output_json。
// List all loaded plugins as a JSON string. Caller must free *output_json with dstalk_free.
DSTALK_API int dstalk_plugin_list(char** output_json) DSTALK_API int dstalk_plugin_list(char** output_json)
{ {
if (!g_initialized || !g_plugin_loader || !output_json) return -1; if (!g_initialized || !g_plugin_loader || !output_json) return -1;
@@ -257,12 +305,16 @@ DSTALK_API int dstalk_plugin_list(char** output_json)
return *output_json ? 0 : -1; return *output_json ? 0 : -1;
} }
// 按名称和最低版本从全局服务注册表查询服务 vtable。
// Query a service vtable by name and minimum version from the global service registry.
DSTALK_API void* dstalk_service_query(const char* service_name, int min_version) DSTALK_API void* dstalk_service_query(const char* service_name, int min_version)
{ {
if (!g_initialized || !g_service_registry) return nullptr; if (!g_initialized || !g_service_registry) return nullptr;
return g_service_registry->query_service(service_name, min_version); return g_service_registry->query_service(service_name, min_version);
} }
// 订阅回调到事件类型。返回订阅 ID失败返回 -1。
// Subscribe a callback to an event type. Returns a subscription ID or -1.
DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata)
{ {
if (!g_initialized || !g_event_bus || !handler) return -1; if (!g_initialized || !g_event_bus || !handler) return -1;
@@ -270,30 +322,40 @@ DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn ha
[handler, userdata](int type, const void* data) { handler(type, data, userdata); }); [handler, userdata](int type, const void* data) { handler(type, data, userdata); });
} }
// 向订阅了该事件类型的所有处理函数发送事件。
// Emit an event to all handlers subscribed to the given event type.
DSTALK_API int dstalk_event_emit(int event_type, const void* data) DSTALK_API int dstalk_event_emit(int event_type, const void* data)
{ {
if (!g_initialized || !g_event_bus) return -1; if (!g_initialized || !g_event_bus) return -1;
return g_event_bus->emit(event_type, data); return g_event_bus->emit(event_type, data);
} }
// 按订阅 ID 取消注册之前的事件处理函数。
// Unsubscribe a previously registered event handler by subscription ID.
DSTALK_API void dstalk_event_unsubscribe(int subscription_id) DSTALK_API void dstalk_event_unsubscribe(int subscription_id)
{ {
if (!g_initialized || !g_event_bus) return; if (!g_initialized || !g_event_bus) return;
g_event_bus->unsubscribe(subscription_id); g_event_bus->unsubscribe(subscription_id);
} }
// 按键读取配置值。返回指向内部存储的指针(请勿释放)。
// Read a configuration value by key. Returns pointer to internal storage (do not free).
DSTALK_API const char* dstalk_config_get(const char* key) DSTALK_API const char* dstalk_config_get(const char* key)
{ {
if (!g_initialized || !g_config) return nullptr; if (!g_initialized || !g_config) return nullptr;
return g_config->get(key); return g_config->get(key);
} }
// 设置配置键值对。成功返回 0。
// Set a configuration key/value pair. Returns 0 on success.
DSTALK_API int dstalk_config_set(const char* key, const char* value) DSTALK_API int dstalk_config_set(const char* key, const char* value)
{ {
if (!g_initialized || !g_config) return -1; if (!g_initialized || !g_config) return -1;
return g_config->set(key, value); return g_config->set(key, value);
} }
// 在给定级别记录消息printf 风格)。写入 stderr 并转发到诊断回调。
// Log a message at the given level (printf-style). Writes to stderr and forwards to diag callback.
DSTALK_API void dstalk_log(int level, const char* fmt, ...) DSTALK_API void dstalk_log(int level, const char* fmt, ...)
{ {
va_list args; va_list args;
@@ -302,24 +364,33 @@ DSTALK_API void dstalk_log(int level, const char* fmt, ...)
va_end(args); va_end(args);
} }
// 通过 malloc 分配内存(为插件 ABI 一致性提供) / Allocate memory via malloc (provided for plugin ABI consistency).
DSTALK_API void* dstalk_alloc(size_t size) { return malloc(size); } DSTALK_API void* dstalk_alloc(size_t size) { return malloc(size); }
// 释放通过 dstalk_alloc 分配的内存(为插件 ABI 一致性提供) / Free memory allocated via dstalk_alloc (provided for plugin ABI consistency).
DSTALK_API void dstalk_free(void* ptr) { free(ptr); } DSTALK_API void dstalk_free(void* ptr) { free(ptr); }
// 使用 dstalk_alloc 复制 C 字符串(调用者必须 dstalk_free / Duplicate a C string using dstalk_alloc (caller must dstalk_free).
DSTALK_API char* dstalk_strdup(const char* s) { return host_strdup(s); } DSTALK_API char* dstalk_strdup(const char* s) { return host_strdup(s); }
// 注册接收所有日志消息的诊断回调(传入 null 可取消设置)。
// Register a diagnostic callback that receives all log messages (may be null to unset).
DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb) { DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb) {
g_diag_callback.store(cb, std::memory_order_release); g_diag_callback.store(cb, std::memory_order_release);
} }
// ============================================================ // ============================================================
// LSP 便捷函数 (委托给 "lsp" 服务插件) // LSP 便捷函数 (委托给 "lsp" 服务插件) / LSP convenience functions (delegated to "lsp" service plugin)
// ============================================================ // ============================================================
// 从全局服务注册表获取 "lsp" 服务 vtable不可用则返回 null。
// Retrieve the "lsp" service vtable from the global service registry, or null if unavailable.
static const dstalk_lsp_service_t* get_lsp_service() { static const dstalk_lsp_service_t* get_lsp_service() {
if (!g_initialized || !g_service_registry) return nullptr; if (!g_initialized || !g_service_registry) return nullptr;
return static_cast<const dstalk_lsp_service_t*>( return static_cast<const dstalk_lsp_service_t*>(
g_service_registry->query_service("lsp", 1)); g_service_registry->query_service("lsp", 1));
} }
// 为给定的命令和语言启动语言服务器进程。
// Start a language server process for the given command and language.
DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language) DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
@@ -327,12 +398,16 @@ DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language)
return svc->start(server_cmd, language); return svc->start(server_cmd, language);
} }
// 停止当前正在运行的语言服务器进程。
// Stop the currently running language server process.
DSTALK_API void dstalk_lsp_stop(void) DSTALK_API void dstalk_lsp_stop(void)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
if (svc && svc->stop) svc->stop(); if (svc && svc->stop) svc->stop();
} }
// 在 LSP 服务器中打开文档以供分析didOpen 通知)。
// Open a document in the LSP server for analysis (didOpen notification).
DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char* language_id) DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char* language_id)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
@@ -340,6 +415,8 @@ DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char*
return svc->open_document(uri, content, language_id); return svc->open_document(uri, content, language_id);
} }
// 在 LSP 服务器中关闭文档didClose 通知)。
// Close a document in the LSP server (didClose notification).
DSTALK_API int dstalk_lsp_close(const char* uri) DSTALK_API int dstalk_lsp_close(const char* uri)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
@@ -347,6 +424,8 @@ DSTALK_API int dstalk_lsp_close(const char* uri)
return svc->close_document(uri); return svc->close_document(uri);
} }
// 检索文档的当前诊断信息。调用者必须用 dstalk_free 释放 *output。
// Retrieve current diagnostics for a document. Caller must free *output with dstalk_free.
DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output) DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
@@ -354,6 +433,8 @@ DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output)
return svc->get_diagnostics(uri, output); return svc->get_diagnostics(uri, output);
} }
// 请求文档位置处的悬停信息。调用者必须用 dstalk_free 释放 *output。
// Request hover information at a document position. Caller must free *output with dstalk_free.
DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char** output) DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char** output)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();
@@ -361,6 +442,8 @@ DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char**
return svc->get_hover(uri, line, character, output); return svc->get_hover(uri, line, character, output);
} }
// 请求文档位置处的补全项。调用者必须用 dstalk_free 释放 *output。
// Request completion items at a document position. Caller must free *output with dstalk_free.
DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, char** output) DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, char** output)
{ {
auto* svc = get_lsp_service(); auto* svc = get_lsp_service();

View File

@@ -1,3 +1,10 @@
/*
* @file plugin_loader.cpp
* @brief PluginLoader implementation: DLL load/unload, path validation, Kahn topological sort, lifecycle management.
* PluginLoader 实现DLL 加载/卸载、路径验证、Kahn 拓扑排序、生命周期管理。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "plugin_loader.hpp" #include "plugin_loader.hpp"
#include <boost/json.hpp> #include <boost/json.hpp>
@@ -21,20 +28,26 @@ namespace dstalk {
namespace json = boost::json; namespace json = boost::json;
namespace fs = std::filesystem; namespace fs = std::filesystem;
// 析构函数:调用 shutdown_all 释放所有插件并释放 DLL 句柄。
// Destructor: calls shutdown_all to release all plugins and free DLL handles.
PluginLoader::~PluginLoader() PluginLoader::~PluginLoader()
{ {
shutdown_all(); shutdown_all();
} }
// 加载插件 DLL验证路径扩展名、目录遍历、目录加载库
// 解析 dstalk_plugin_init验证 API 版本,解析依赖,分配 ID。
// Load a plugin DLL: validate path (extension, traversal, directory), load library,
// resolve dstalk_plugin_init, verify API version, parse dependencies, assign ID.
int PluginLoader::load_plugin(const char* path) int PluginLoader::load_plugin(const char* path)
{ {
if (!path) return -1; if (!path) return -1;
// === Path validation (F-18.3-3) === // === 路径验证 (F-18.3-3) / Path validation (F-18.3-3) ===
{ {
fs::path p = fs::absolute(fs::path(path)).lexically_normal(); fs::path p = fs::absolute(fs::path(path)).lexically_normal();
// Extension check (case-insensitive) // 扩展名检查(大小写不敏感) / Extension check (case-insensitive)
std::string ext = p.extension().string(); std::string ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), std::transform(ext.begin(), ext.end(), ext.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); }); [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
@@ -57,7 +70,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// Directory traversal check // 目录遍历检查 / Directory traversal check
bool has_dotdot = false; bool has_dotdot = false;
bool in_plugins_dir = false; bool in_plugins_dir = false;
for (const auto& comp : p) { for (const auto& comp : p) {
@@ -78,6 +91,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// 目录约束:必须位于 'plugins' 目录下或为纯文件名
// Directory constraint: must be under a 'plugins' directory or be a plain filename // Directory constraint: must be under a 'plugins' directory or be a plain filename
if (!in_plugins_dir && p.has_parent_path()) { if (!in_plugins_dir && p.has_parent_path()) {
if (host_api_) { if (host_api_) {
@@ -88,7 +102,7 @@ int PluginLoader::load_plugin(const char* path)
} }
} }
// 加载DLL // 加载DLL / Load DLL
#ifdef _WIN32 #ifdef _WIN32
void* handle = LoadLibraryA(path); void* handle = LoadLibraryA(path);
#else #else
@@ -109,7 +123,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// 获取入口函数 // 获取入口函数 / Resolve entry function
#ifdef _WIN32 #ifdef _WIN32
auto init_fn = (dstalk_plugin_init_fn)GetProcAddress( auto init_fn = (dstalk_plugin_init_fn)GetProcAddress(
(HMODULE)handle, "dstalk_plugin_init"); (HMODULE)handle, "dstalk_plugin_init");
@@ -138,7 +152,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// 调用入口函数获取插件信息 // 调用入口函数获取插件信息 / Call entry function to get plugin info
dstalk_plugin_info_t* info = nullptr; dstalk_plugin_info_t* info = nullptr;
try { try {
info = init_fn(); info = init_fn();
@@ -160,7 +174,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// 检查API版本兼容性 // 检查API版本兼容性 / Check API version compatibility
if (info->api_version != DSTALK_API_VERSION) { if (info->api_version != DSTALK_API_VERSION) {
if (host_api_) { if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR, host_api_->log(DSTALK_LOG_ERROR,
@@ -175,7 +189,7 @@ int PluginLoader::load_plugin(const char* path)
return -1; return -1;
} }
// 创建插件信息 // 创建插件信息 / Create plugin info
int id = next_id_++; int id = next_id_++;
PluginInfo plugin; PluginInfo plugin;
plugin.id = id; plugin.id = id;
@@ -187,7 +201,7 @@ int PluginLoader::load_plugin(const char* path)
plugin.info = info; plugin.info = info;
plugin.initialized = false; plugin.initialized = false;
// 解析依赖 // 解析依赖 / Parse dependencies
for (int i = 0; i < DSTALK_MAX_DEPS && info->dependencies[i]; i++) { for (int i = 0; i < DSTALK_MAX_DEPS && info->dependencies[i]; i++) {
plugin.dependencies.push_back(info->dependencies[i]); plugin.dependencies.push_back(info->dependencies[i]);
} }
@@ -196,6 +210,8 @@ int PluginLoader::load_plugin(const char* path)
return id; return id;
} }
// 按 ID 卸载插件:若已初始化则调用 on_shutdown释放 DLL 句柄,从 map 中移除。
// Unload a plugin by ID: call on_shutdown if initialized, free the DLL handle, erase from map.
int PluginLoader::unload_plugin(int plugin_id) int PluginLoader::unload_plugin(int plugin_id)
{ {
auto it = plugins_.find(plugin_id); auto it = plugins_.find(plugin_id);
@@ -203,7 +219,7 @@ int PluginLoader::unload_plugin(int plugin_id)
PluginInfo& plugin = it->second; PluginInfo& plugin = it->second;
// 调用关闭回调 // 调用关闭回调 / Call shutdown callback
if (plugin.initialized && plugin.info->on_shutdown) { if (plugin.initialized && plugin.info->on_shutdown) {
try { try {
plugin.info->on_shutdown(); plugin.info->on_shutdown();
@@ -216,7 +232,7 @@ int PluginLoader::unload_plugin(int plugin_id)
} }
} }
// 卸载DLL // 卸载DLL / Unload DLL
#ifdef _WIN32 #ifdef _WIN32
FreeLibrary((HMODULE)plugin.handle); FreeLibrary((HMODULE)plugin.handle);
#else #else
@@ -227,6 +243,8 @@ int PluginLoader::unload_plugin(int plugin_id)
return 0; return 0;
} }
// 将所有已加载插件序列化为 JSON 数组字符串。
// Serialize all loaded plugins into a JSON array string.
std::string PluginLoader::list_plugins() const std::string PluginLoader::list_plugins() const
{ {
json::array arr; json::array arr;
@@ -250,15 +268,19 @@ std::string PluginLoader::list_plugins() const
return json::serialize(arr); return json::serialize(arr);
} }
// 使用 Kahn 算法计算依赖顺序的插件 ID 列表。
// 若检测到循环依赖则抛出 std::runtime_error。
// Compute dependency-ordered plugin IDs using Kahn's algorithm.
// Throws std::runtime_error if a circular dependency is detected.
std::vector<int> PluginLoader::topological_sort() const std::vector<int> PluginLoader::topological_sort() const
{ {
// 构建名称到ID的映射 // 构建名称到ID的映射 / Build name-to-ID map
std::unordered_map<std::string, int> name_to_id; std::unordered_map<std::string, int> name_to_id;
for (const auto& [id, plugin] : plugins_) { for (const auto& [id, plugin] : plugins_) {
name_to_id[plugin.name] = id; name_to_id[plugin.name] = id;
} }
// 计算入度 // 计算入度 / Calculate in-degrees
std::unordered_map<int, int> in_degree; std::unordered_map<int, int> in_degree;
std::unordered_map<int, std::vector<int>> dependents; std::unordered_map<int, std::vector<int>> dependents;
@@ -277,7 +299,7 @@ std::vector<int> PluginLoader::topological_sort() const
} }
} }
// 拓扑排序Kahn算法 // 拓扑排序Kahn算法 / Topological sort (Kahn's algorithm)
std::queue<int> queue; std::queue<int> queue;
for (const auto& [id, degree] : in_degree) { for (const auto& [id, degree] : in_degree) {
if (degree == 0) { if (degree == 0) {
@@ -298,7 +320,7 @@ std::vector<int> PluginLoader::topological_sort() const
} }
} }
// 检查循环依赖 // 检查循环依赖 / Check for circular dependency
if (sorted.size() != plugins_.size()) { if (sorted.size() != plugins_.size()) {
throw std::runtime_error("Circular dependency detected"); throw std::runtime_error("Circular dependency detected");
} }
@@ -306,17 +328,21 @@ std::vector<int> PluginLoader::topological_sort() const
return sorted; return sorted;
} }
// 验证依赖:检查缺失依赖和循环依赖。
// 成功返回 0发现错误返回 -1错误通过 host_api_ 记录)。
// Validate dependencies: checks for missing dependencies and circular dependencies.
// Returns 0 on success, -1 if any errors found (errors are logged via host_api_).
int PluginLoader::validate_dependencies() const int PluginLoader::validate_dependencies() const
{ {
int error_count = 0; int error_count = 0;
// 构建名称到ID的映射 // 构建名称到ID的映射 / Build name-to-ID map
std::unordered_map<std::string, int> name_to_id; std::unordered_map<std::string, int> name_to_id;
for (const auto& [id, plugin] : plugins_) { for (const auto& [id, plugin] : plugins_) {
name_to_id[plugin.name] = id; name_to_id[plugin.name] = id;
} }
// 检查1缺失依赖deps 引用的插件未加载) // 检查1缺失依赖deps 引用的插件未加载) / Check 1: missing dependencies (deps reference plugins not loaded)
for (const auto& [id, plugin] : plugins_) { for (const auto& [id, plugin] : plugins_) {
for (const auto& dep_name : plugin.dependencies) { for (const auto& dep_name : plugin.dependencies) {
if (name_to_id.find(dep_name) == name_to_id.end()) { if (name_to_id.find(dep_name) == name_to_id.end()) {
@@ -330,7 +356,7 @@ int PluginLoader::validate_dependencies() const
} }
} }
// 检查2循环依赖拓扑排序失败 // 检查2循环依赖拓扑排序失败 / Check 2: circular dependency (topological sort fails)
try { try {
topological_sort(); topological_sort();
} catch (const std::runtime_error&) { } catch (const std::runtime_error&) {
@@ -344,12 +370,19 @@ int PluginLoader::validate_dependencies() const
return error_count > 0 ? -1 : 0; return error_count > 0 ? -1 : 0;
} }
// 按依赖顺序初始化所有未初始化的插件。
// 无效依赖或失败初始化会标记插件名,避免级联失败。
// 返回初始化失败的插件数量,严重错误返回 -1。
// Initialize all uninitialized plugins in dependency order.
// Invalid dependencies or failed inits mark the plugin name, avoiding cascading failures.
// Returns the number of plugins that failed to initialize, or -1 on critical error.
int PluginLoader::initialize_all(const dstalk_host_api_t* host_api) int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
{ {
if (!host_api) return -1; if (!host_api) return -1;
host_api_ = host_api; host_api_ = host_api;
// 依赖合法性校验log 错误但不 crash继续初始化流程 // 依赖合法性校验log 错误但不 crash继续初始化流程
// Validate dependencies (log errors but don't crash, continue initialization)
if (validate_dependencies() != 0) { if (validate_dependencies() != 0) {
host_api->log(DSTALK_LOG_WARN, host_api->log(DSTALK_LOG_WARN,
"[plugin_loader] Dependency validation failed; initialization may be incomplete"); "[plugin_loader] Dependency validation failed; initialization may be incomplete");
@@ -368,7 +401,7 @@ int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
PluginInfo& plugin = it->second; PluginInfo& plugin = it->second;
if (plugin.initialized) continue; if (plugin.initialized) continue;
// 检查依赖是否已失败 // 检查依赖是否已失败 / Check if dependency has already failed
bool dep_unavailable = false; bool dep_unavailable = false;
for (const auto& dep_name : plugin.dependencies) { for (const auto& dep_name : plugin.dependencies) {
if (failed_names.count(dep_name)) { if (failed_names.count(dep_name)) {
@@ -415,13 +448,17 @@ int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
return failed_count; return failed_count;
} catch (const std::runtime_error&) { } catch (const std::runtime_error&) {
// 循环依赖 // 循环依赖 / Circular dependency
return -1; return -1;
} catch (const std::exception&) { } catch (const std::exception&) {
return -1; return -1;
} }
} }
// 仅初始化尚未初始化的插件(用于增量/按需加载)。
// 返回新初始化的插件数量,失败返回 -1。
// Initialize only plugins that haven't been initialized yet (used for incremental/on-demand loading).
// Returns the number of newly initialized plugins, or -1 on failure.
int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api) int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api)
{ {
host_api_ = host_api; host_api_ = host_api;
@@ -463,15 +500,17 @@ int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api)
} }
} }
// 按逆依赖顺序关闭所有插件,然后释放所有 DLL 句柄并清空 map。
// Shutdown all plugins in reverse dependency order, then free all DLL handles and clear the map.
void PluginLoader::shutdown_all() void PluginLoader::shutdown_all()
{ {
// 按逆序关闭 // 按逆序关闭 / Shutdown in reverse order
std::vector<int> order; std::vector<int> order;
try { try {
order = topological_sort(); order = topological_sort();
std::reverse(order.begin(), order.end()); std::reverse(order.begin(), order.end());
} catch (...) { } catch (...) {
// 如果排序失败,按任意顺序关闭 // 如果排序失败,按任意顺序关闭 / If sorting fails, shutdown in arbitrary order
for (const auto& [id, _] : plugins_) { for (const auto& [id, _] : plugins_) {
order.push_back(id); order.push_back(id);
} }
@@ -496,7 +535,7 @@ void PluginLoader::shutdown_all()
plugin.initialized = false; plugin.initialized = false;
} }
// 释放所有 DLL 句柄 // 释放所有 DLL 句柄 / Free all DLL handles
for (auto& [id, plugin] : plugins_) { for (auto& [id, plugin] : plugins_) {
if (plugin.handle) { if (plugin.handle) {
#ifdef _WIN32 #ifdef _WIN32
@@ -510,6 +549,8 @@ void PluginLoader::shutdown_all()
plugins_.clear(); plugins_.clear();
} }
// 按 ID 查找插件。返回 PluginInfo 指针,未找到则返回 nullptr。
// Look up a plugin by ID. Returns pointer to PluginInfo, or nullptr if not found.
const PluginInfo* PluginLoader::get_plugin(int plugin_id) const const PluginInfo* PluginLoader::get_plugin(int plugin_id) const
{ {
auto it = plugins_.find(plugin_id); auto it = plugins_.find(plugin_id);

View File

@@ -1,3 +1,10 @@
/*
* @file plugin_loader.hpp
* @brief DLL plugin loader with topological sort for dependency-ordered initialization.
* DLL 插件加载器,使用拓扑排序实现按依赖顺序初始化。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once #pragma once
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
@@ -8,6 +15,8 @@
namespace dstalk { namespace dstalk {
// 描述单个已加载插件标识、DLL 句柄、信息 vtable 和初始化状态。
// Describes a single loaded plugin: identity, DLL handle, info vtable, and init state.
struct PluginInfo { struct PluginInfo {
int id; int id;
std::string name; std::string name;
@@ -16,42 +25,47 @@ struct PluginInfo {
int api_version; int api_version;
std::vector<std::string> dependencies; std::vector<std::string> dependencies;
void* handle; // DLL handle void* handle; // DLL 句柄 / DLL handle
dstalk_plugin_info_t* info; dstalk_plugin_info_t* info;
bool initialized; bool initialized;
}; };
// 管理基于 DLL 的插件生命周期:加载、卸载、验证依赖、
// 拓扑排序初始化、关闭和 JSON 列表。
// Manages the lifecycle of DLL-based plugins: load, unload, validate dependencies,
// topological-sort initialization, shutdown, and JSON listing.
class PluginLoader { class PluginLoader {
public: public:
PluginLoader() = default; PluginLoader() = default;
~PluginLoader(); ~PluginLoader();
// 加载插件返回插件ID失败返回-1 // 加载插件返回插件ID失败返回-1 / Load plugin (returns plugin ID, -1 on failure)
int load_plugin(const char* path); int load_plugin(const char* path);
// 卸载插件 // 卸载插件 / Unload plugin
int unload_plugin(int plugin_id); int unload_plugin(int plugin_id);
// 获取插件列表JSON格式 // 获取插件列表JSON格式 / Get plugin list (JSON format)
std::string list_plugins() const; std::string list_plugins() const;
// 按依赖顺序初始化所有插件 // 按依赖顺序初始化所有插件 / Initialize all plugins in dependency order
int initialize_all(const dstalk_host_api_t* host_api); int initialize_all(const dstalk_host_api_t* host_api);
// 仅初始化尚未初始化的插件(增量加载场景) // 仅初始化尚未初始化的插件(增量加载场景) / Initialize only uninitialized plugins (incremental loading scenario)
int initialize_pending(const dstalk_host_api_t* host_api); int initialize_pending(const dstalk_host_api_t* host_api);
// 关闭所有插件 // 关闭所有插件 / Shutdown all plugins
void shutdown_all(); void shutdown_all();
// 获取插件信息 // 获取插件信息 / Get plugin info
const PluginInfo* get_plugin(int plugin_id) const; const PluginInfo* get_plugin(int plugin_id) const;
private: private:
// 拓扑排序(按依赖顺序) // 拓扑排序(按依赖顺序) / Topological sort (by dependency order)
std::vector<int> topological_sort() const; std::vector<int> topological_sort() const;
// 依赖合法性校验(缺失依赖 + 循环依赖),返回 0 成功 / -1 失败 // 依赖合法性校验(缺失依赖 + 循环依赖),返回 0 成功 / -1 失败
// Validate dependencies (missing + circular), returns 0 success / -1 failure
int validate_dependencies() const; int validate_dependencies() const;
std::unordered_map<int, PluginInfo> plugins_; std::unordered_map<int, PluginInfo> plugins_;

View File

@@ -1,22 +1,30 @@
/* @file service_registry.cpp
* @brief ServiceRegistry implementation: register, query, unregister with reader-writer locking.
* ServiceRegistry 实现:基于读写锁的 register、query、unregister。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "service_registry.hpp" #include "service_registry.hpp"
namespace dstalk { namespace dstalk {
// 注册指定版本的命名服务 / Register a named service at a given version. 参数为空返回 -1已注册返回 -2 / Returns -1 on null args, -2 if already registered.
int ServiceRegistry::register_service(const char* name, int version, void* vtable) int ServiceRegistry::register_service(const char* name, int version, void* vtable)
{ {
if (!name || !vtable) return -1; if (!name || !vtable) return -1;
std::unique_lock<std::shared_mutex> lock(mutex_); std::unique_lock<std::shared_mutex> lock(mutex_);
// 检查是否已注册 // 检查是否已注册 / Check if already registered
if (services_.find(name) != services_.end()) { if (services_.find(name) != services_.end()) {
return -2; // 已存在 return -2; // 已存在 / already registered
} }
services_[name] = {name, version, vtable}; services_[name] = {name, version, vtable};
return 0; return 0;
} }
// 按名称和最低版本查询服务 / Query a service by name and minimum version. 返回 vtable 指针或 nullptr / Returns vtable pointer or nullptr if not found.
void* ServiceRegistry::query_service(const char* name, int min_version) const void* ServiceRegistry::query_service(const char* name, int min_version) const
{ {
if (!name) return nullptr; if (!name) return nullptr;
@@ -31,6 +39,7 @@ void* ServiceRegistry::query_service(const char* name, int min_version) const
return it->second.vtable; return it->second.vtable;
} }
// 注销指定名称的服务name 为空或未找到时无操作)/ Unregister a named service (no-op if name is null or not found).
void ServiceRegistry::unregister_service(const char* name) void ServiceRegistry::unregister_service(const char* name)
{ {
if (!name) return; if (!name) return;

View File

@@ -1,3 +1,9 @@
/* @file service_registry.hpp
* @brief Name-versioned service registry for decoupled plugin communication.
* 基于名称+版本的服务注册表,用于插件间解耦通信。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once #pragma once
#include <mutex> #include <mutex>
@@ -7,18 +13,23 @@
namespace dstalk { namespace dstalk {
// 名称 + 最低版本服务目录 / Name + minimum-version service directory.
// 插件注册 vtable消费者按名称和版本约束查询 /
// Plugins register vtables; consumers query by name and version constraint.
// 读取query使用 shared_lock写入register/unregister使用 unique_lock /
// Reads (query) use shared_lock; writes (register/unregister) use unique_lock.
class ServiceRegistry { class ServiceRegistry {
public: public:
ServiceRegistry() = default; ServiceRegistry() = default;
~ServiceRegistry() = default; ~ServiceRegistry() = default;
// 注册服务 // 注册服务 / Register a named service at a given version
int register_service(const char* name, int version, void* vtable); int register_service(const char* name, int version, void* vtable);
// 查询服务(返回 vtable 指针,或 nullptr // 查询服务(返回 vtable 指针,或 nullptr/ Query a service by name and minimum version
void* query_service(const char* name, int min_version) const; void* query_service(const char* name, int min_version) const;
// 注销服务 // 注销服务 / Unregister a named service
void unregister_service(const char* name); void unregister_service(const char* name);
private: private:
@@ -28,7 +39,7 @@ private:
void* vtable; void* vtable;
}; };
mutable std::shared_mutex mutex_; mutable std::shared_mutex mutex_; // 读写锁query 用 sharedregister/unregister 用 unique / RW lock: shared for query, unique for register/unregister
std::unordered_map<std::string, ServiceEntry> services_; std::unordered_map<std::string, ServiceEntry> services_;
}; };

View File

@@ -1,11 +1,14 @@
// ============================================================================ /*
// dstalk-gui — SDL3 聊天客户端 * @file main.cpp
// ============================================================================ * @brief SDL3-based GUI frontend for dstalk (stub/minimal implementation).
// 使用 SDL3 内置的 SDL_RenderDebugText() 渲染文本8x8 像素), * dstalk 的 SDL3 图形界面前端(最小化实现)。
// 通过 SDL_SetRenderScale 2 倍缩放至有效的 16x16 像素。 * Copyright (c) 2026 dstalk contributors. GPLv3.
// *
// 该文件是独立的——不需要额外的源文件。 * Uses SDL3's built-in SDL_RenderDebugText() for 8x8 pixel text, scaled 2x to
// ============================================================================ * effective 16x16 pixels via SDL_SetRenderScale. Self-contained single-file GUI.
* 使用 SDL3 内置的 SDL_RenderDebugText() 渲染 8x8 像素文本,通过 SDL_SetRenderScale
* 缩放 2 倍达到 16x16 像素效果。自包含的单文件 GUI。
*/
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
@@ -19,46 +22,48 @@
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
// ---- 服务 vtable 指针 ---- // ---- 服务 vtable 指针 / Service vtable pointers ----
// Global pointers to service vtables queried from the host on startup.
// 在启动时从主机查询获取的服务 vtable 全局指针。
static const dstalk_ai_service_t* g_ai_svc = nullptr; static const dstalk_ai_service_t* g_ai_svc = nullptr;
static const dstalk_session_service_t* g_session_svc = nullptr; static const dstalk_session_service_t* g_session_svc = nullptr;
// ---- 常量 ---- // ---- 常量 / Constants ----
static constexpr int WINDOW_W = 1024; static constexpr int WINDOW_W = 1024;
static constexpr int WINDOW_H = 768; static constexpr int WINDOW_H = 768;
static constexpr float RENDER_SCALE = 2.0f; static constexpr float RENDER_SCALE = 2.0f;
// 逻辑坐标尺寸(物理像素 / RENDER_SCALE // 逻辑坐标尺寸(物理像素 / RENDER_SCALE / Logical coordinate dimensions (physical pixels / RENDER_SCALE)
static constexpr int LOGICAL_W = WINDOW_W / 2; // 512 static constexpr int LOGICAL_W = WINDOW_W / 2; // 512
static constexpr int LOGICAL_H = WINDOW_H / 2; // 384 static constexpr int LOGICAL_H = WINDOW_H / 2; // 384
static constexpr int CHAR_W = 8; // SDL_RenderDebugText 原生字符宽度(逻辑像素) static constexpr int CHAR_W = 8; // SDL_RenderDebugText 原生字符宽度(逻辑像素) / native char width (logical pixels)
static constexpr int CHAR_H = 8; // 原生字符高度(逻辑像素) static constexpr int CHAR_H = 8; // 原生字符高度(逻辑像素) / native char height (logical pixels)
static constexpr int TITLE_H = 16; // 标题栏高度(逻辑像素) static constexpr int TITLE_H = 16; // 标题栏高度(逻辑像素) / title bar height (logical pixels)
static constexpr int PADDING = 4; // 内边距(逻辑像素) static constexpr int PADDING = 4; // 内边距(逻辑像素) / padding (logical pixels)
// 侧边栏 // 侧边栏 / Sidebar
static constexpr int SIDEBAR_W = 80; // 侧边栏宽度(逻辑像素,渲染为 160 物理像素) static constexpr int SIDEBAR_W = 80; // 侧边栏宽度(逻辑像素,渲染为 160 物理像素) / sidebar width (logical, renders as 160 physical px)
// 状态栏 // 状态栏 / Status bar
static constexpr int STATUS_H = 20; // 状态栏高度(逻辑像素,渲染为 40 物理像素) static constexpr int STATUS_H = 20; // 状态栏高度(逻辑像素,渲染为 40 物理像素) / status bar height (logical, renders as 40 physical px)
// 输入区域动态高度 // 输入区域动态高度 / Input area dynamic height
static constexpr int INPUT_H_MIN = 40; // 最小高度(逻辑像素) static constexpr int INPUT_H_MIN = 40; // 最小高度(逻辑像素) / min height (logical pixels)
static constexpr int INPUT_H_MAX = 120; // 最大高度(逻辑像素) static constexpr int INPUT_H_MAX = 120; // 最大高度(逻辑像素) / max height (logical pixels)
// 消息区域Y 起点不变,宽度和高度动态计算) // 消息区域Y 起点不变,宽度和高度动态计算) / Message area (Y origin fixed, width and height calculated dynamically)
static constexpr int MSG_Y = TITLE_H; static constexpr int MSG_Y = TITLE_H;
// 颜色ARGB 格式,用于 SDL_SetRenderDrawColor // 颜色ARGB 格式,用于 SDL_SetRenderDrawColor / Colors (ARGB format, for SDL_SetRenderDrawColor)
static constexpr SDL_Color COL_BG = {0x1E, 0x1E, 0x2E, 0xFF}; static constexpr SDL_Color COL_BG = {0x1E, 0x1E, 0x2E, 0xFF};
static constexpr SDL_Color COL_TITLE_BG = {0x2D, 0x2D, 0x44, 0xFF}; static constexpr SDL_Color COL_TITLE_BG = {0x2D, 0x2D, 0x44, 0xFF};
static constexpr SDL_Color COL_INPUT_BG = {0x2A, 0x2A, 0x3E, 0xFF}; static constexpr SDL_Color COL_INPUT_BG = {0x2A, 0x2A, 0x3E, 0xFF};
static constexpr SDL_Color COL_USER = {0x00, 0xFF, 0xFF, 0xFF}; // 青色 static constexpr SDL_Color COL_USER = {0x00, 0xFF, 0xFF, 0xFF}; // 青色 / cyan
static constexpr SDL_Color COL_AI = {0x00, 0xFF, 0x80, 0xFF}; // 绿色 static constexpr SDL_Color COL_AI = {0x00, 0xFF, 0x80, 0xFF}; // 绿色 / green
static constexpr SDL_Color COL_SYS = {0xFF, 0xFF, 0x00, 0xFF}; // 黄色 static constexpr SDL_Color COL_SYS = {0xFF, 0xFF, 0x00, 0xFF}; // 黄色 / yellow
static constexpr SDL_Color COL_BTN = {0x50, 0x50, 0x80, 0xFF}; // 按钮 static constexpr SDL_Color COL_BTN = {0x50, 0x50, 0x80, 0xFF}; // 按钮 / button
static constexpr SDL_Color COL_WHITE = {0xFF, 0xFF, 0xFF, 0xFF}; static constexpr SDL_Color COL_WHITE = {0xFF, 0xFF, 0xFF, 0xFF};
static constexpr SDL_Color COL_CURSOR = {0xFF, 0xFF, 0xFF, 0xFF}; static constexpr SDL_Color COL_CURSOR = {0xFF, 0xFF, 0xFF, 0xFF};
static constexpr SDL_Color COL_SEP = {0x50, 0x50, 0x70, 0xFF}; static constexpr SDL_Color COL_SEP = {0x50, 0x50, 0x70, 0xFF};
@@ -68,8 +73,9 @@ static constexpr SDL_Color COL_SIDEBAR_BTN = {0x40, 0x40, 0x68, 0xFF};
static constexpr SDL_Color COL_STATUSBAR_BG= {0x2D, 0x2D, 0x44, 0xFF}; static constexpr SDL_Color COL_STATUSBAR_BG= {0x2D, 0x2D, 0x44, 0xFF};
static constexpr SDL_Color COL_DIM = {0x80, 0x80, 0x80, 0xFF}; static constexpr SDL_Color COL_DIM = {0x80, 0x80, 0x80, 0xFF};
// ---- 数据结构 ---- // ---- 数据结构 / Data structures ----
// 单条聊天消息 / Represents a single chat message with role and text content.
struct ChatMessage { struct ChatMessage {
enum Role { USER, ASSISTANT, SYSTEM } role; enum Role { USER, ASSISTANT, SYSTEM } role;
std::string content; std::string content;
@@ -77,62 +83,66 @@ struct ChatMessage {
ChatMessage(Role r, std::string c) : role(r), content(std::move(c)) {} ChatMessage(Role r, std::string c) : role(r), content(std::move(c)) {}
}; };
// 保存所有可变 UI 状态 / Holds all mutable UI state: message list, input buffer, scroll, streaming flag, etc.
struct GuiState { struct GuiState {
std::vector<ChatMessage> messages; std::vector<ChatMessage> messages;
std::string inputBuffer; std::string inputBuffer;
int scrollOffset = 0; // 从底部滚动的逻辑像素 int scrollOffset = 0; // 从底部滚动的逻辑像素 / logical pixels scrolled from bottom
bool streaming = false; bool streaming = false;
bool running = true; bool running = true;
int cursorPos = 0; // 输入缓冲区中的光标位置 int cursorPos = 0; // 输入缓冲区中的光标位置 / cursor position in input buffer
bool cursorVisible = true; bool cursorVisible = true;
Uint64 lastCursorBlink = 0; Uint64 lastCursorBlink = 0;
float maxScroll = 0; // 可用的最大滚动距离(逻辑像素) float maxScroll = 0; // 可用的最大滚动距离(逻辑像素) / max available scroll distance (logical pixels)
// P0 新增字段 // P0 新增字段 / P0 new fields
std::vector<std::string> input_history; // 输入历史(最多 20 条) std::vector<std::string> input_history; // 输入历史(最多 20 条) / input history (max 20 entries)
int history_index = -1; // 当前历史位置(-1 = 新输入) int history_index = -1; // 当前历史位置(-1 = 新输入) / current history position (-1 = new input)
std::string saved_input; // 浏览历史时暂存当前输入 std::string saved_input; // 浏览历史时暂存当前输入 / saved current input while browsing history
bool sidebar_visible = true; // 侧边栏可见性 bool sidebar_visible = true; // 侧边栏可见性 / sidebar visibility
std::string model_name = "deepseek-chat";// 当前模型名 std::string model_name = "deepseek-chat";// 当前模型名 / current model name
}; };
// 持有上下文指针,用于将回调传递给流式 API // 将 GuiState 与 SDL 窗口/渲染器句柄及逐帧标志打包。
// 作为 userdata 传递给流式回调,使其可以更新缓冲区并重新渲染。
// Bundles GuiState with SDL window/renderer handles and per-frame flags.
// Passed as userdata to the streaming callback so it can update the buffer and re-render.
struct AppContext { struct AppContext {
GuiState state; GuiState state;
SDL_Window* window = nullptr; SDL_Window* window = nullptr;
SDL_Renderer* renderer = nullptr; SDL_Renderer* renderer = nullptr;
bool sendPending = false; // 按下 Enter 后设置为 true bool sendPending = false; // 按下 Enter 后设置为 true / set to true after pressing Enter
std::string streamBuffer; // 存储当前流式消息 std::string streamBuffer; // 存储当前流式消息 / stores current streaming message
}; };
// ---- 辅助函数 ---- // ---- 辅助函数 / Helper functions ----
// 获取一个逻辑坐标的 SDL 矩形 // 在逻辑坐标系中创建 SDL_FRect / Create an SDL_FRect in logical coordinates.
static SDL_FRect mkRect(float x, float y, float w, float h) { static SDL_FRect mkRect(float x, float y, float w, float h) {
SDL_FRect r; SDL_FRect r;
r.x = x; r.y = y; r.w = w; r.h = h; r.x = x; r.y = y; r.w = w; r.h = h;
return r; return r;
} }
// 使用给定的颜色设置绘制颜色 // 使用 SDL_Color 设置渲染器的绘制颜色 / Set the renderer's draw color from an SDL_Color.
static void setColor(SDL_Renderer* r, SDL_Color c) { static void setColor(SDL_Renderer* r, SDL_Color c) {
SDL_SetRenderDrawColor(r, c.r, c.g, c.b, c.a); SDL_SetRenderDrawColor(r, c.r, c.g, c.b, c.a);
} }
// 使用颜色绘制填充矩形 // 以纯色填充矩形(逻辑坐标) / Fill a rectangle with a solid color (logical coordinates).
static void fillRect(SDL_Renderer* r, SDL_FRect rect, SDL_Color c) { static void fillRect(SDL_Renderer* r, SDL_FRect rect, SDL_Color c) {
setColor(r, c); setColor(r, c);
SDL_RenderFillRect(r, &rect); SDL_RenderFillRect(r, &rect);
} }
// 在给定位置(逻辑坐标)绘制一个调试文本字符串,并设定颜色 // 在指定逻辑位置以指定颜色绘制调试文本 / Draw a debug-text string at a given logical position with the specified color.
static void drawText(SDL_Renderer* r, float x, float y, static void drawText(SDL_Renderer* r, float x, float y,
const char* text, SDL_Color color) { const char* text, SDL_Color color) {
setColor(r, color); setColor(r, color);
SDL_RenderDebugText(r, x, y, text); SDL_RenderDebugText(r, x, y, text);
} }
// 绘制一个可见的调试文本字符,避免为空字符串调用 SDL_RenderDebugText // 仅当字符串非空时绘制调试文本(避免 SDL_RenderDebugText 问题) / Draw debug text only if the string is non-empty (avoids SDL_RenderDebugText issues).
static void drawTextSafe(SDL_Renderer* r, float x, float y, static void drawTextSafe(SDL_Renderer* r, float x, float y,
const char* text) { const char* text) {
if (text && text[0] != '\0') { if (text && text[0] != '\0') {
@@ -140,7 +150,7 @@ static void drawTextSafe(SDL_Renderer* r, float x, float y,
} }
} }
// 计算输入区域的动态高度(根据输入内容中的换行数) // 根据换行符数量计算输入区域的动态高度 / Compute the dynamic height of the input area based on the number of newlines.
static int calcInputHeight(const std::string& input) { static int calcInputHeight(const std::string& input) {
int lines = 1; int lines = 1;
for (char ch : input) { for (char ch : input) {
@@ -150,14 +160,13 @@ static int calcInputHeight(const std::string& input) {
std::max(INPUT_H_MIN, lines * CHAR_H + PADDING * 2)); std::max(INPUT_H_MIN, lines * CHAR_H + PADDING * 2));
} }
// ---- 文本换行 ---- // ---- 文本换行 / Text wrapping ----
// 将一段文本按字符数换行。保留嵌入的 '\n',并在单词边界处尽可能按字符数换行。 // 按 maxChars 换行文本,保留嵌入的换行符 / Word-wrap text to fit within maxChars per line, respecting embedded newlines.
// 返回逻辑文本行列表。
static std::vector<std::string> wrapText(const std::string& text, int maxChars) { static std::vector<std::string> wrapText(const std::string& text, int maxChars) {
std::vector<std::string> lines; std::vector<std::string> lines;
// 首先按嵌入的换行符分割 // 首先按嵌入的换行符分割 / First split by embedded newlines
std::string remaining = text; std::string remaining = text;
while (!remaining.empty()) { while (!remaining.empty()) {
std::string segment; std::string segment;
@@ -170,13 +179,13 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
remaining.clear(); remaining.clear();
} }
// 将片段按单词换行以适应 maxChars // 将片段按单词换行以适应 maxChars / Wrap segment by word to fit maxChars
while (!segment.empty()) { while (!segment.empty()) {
if (static_cast<int>(segment.size()) <= maxChars) { if (static_cast<int>(segment.size()) <= maxChars) {
lines.push_back(segment); lines.push_back(segment);
break; break;
} }
// 在 maxChars 位置寻找空格/单词边界 // 在 maxChars 位置寻找空格/单词边界 / Find space/word boundary at maxChars position
int splitAt = maxChars; int splitAt = maxChars;
for (int i = maxChars; i > 0; --i) { for (int i = maxChars; i > 0; --i) {
char ch = segment[i]; char ch = segment[i];
@@ -187,7 +196,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
break; break;
} }
if ((ch & 0x80) != 0) { if ((ch & 0x80) != 0) {
// UTF-8 多字节字符——不在中间分割 // UTF-8 多字节字符——不在中间分割 / UTF-8 multi-byte char — don't split in the middle
} }
} }
if (splitAt <= 0 || splitAt > maxChars) { if (splitAt <= 0 || splitAt > maxChars) {
@@ -195,7 +204,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
} }
lines.push_back(segment.substr(0, splitAt)); lines.push_back(segment.substr(0, splitAt));
// 去除下一行的前导空格 // 去除下一行的前导空格 / Trim leading spaces for the next line
size_t start = splitAt; size_t start = splitAt;
while (start < segment.size() && while (start < segment.size() &&
(segment[start] == ' ' || segment[start] == '\t')) { (segment[start] == ' ' || segment[start] == '\t')) {
@@ -207,8 +216,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
return lines; return lines;
} }
// 计算所有消息的总渲染高度(逻辑像素) // 计算所有消息在换行后的总渲染高度(逻辑像素) / Calculate the total rendered height (in logical pixels) of all messages after wrapping.
// 注意:这使用当前的侧边栏状态来决定宽度;调用者应在侧边栏宽度正确时调用。
static int calcTotalMsgHeight(GuiState& state, int charsPerLine) { static int calcTotalMsgHeight(GuiState& state, int charsPerLine) {
int totalH = 0; int totalH = 0;
for (auto& msg : state.messages) { for (auto& msg : state.messages) {
@@ -219,8 +227,10 @@ static int calcTotalMsgHeight(GuiState& state, int charsPerLine) {
return totalH; return totalH;
} }
// ---- 侧边栏渲染 ---- // ---- 侧边栏渲染 / Sidebar rendering ----
// 渲染左侧边栏:背景、会话列表和"+ New Chat"按钮。
// Render the left sidebar: background, session list, and "+ New Chat" button.
static void renderSidebar(AppContext& ctx) { static void renderSidebar(AppContext& ctx) {
GuiState& gs = ctx.state; GuiState& gs = ctx.state;
SDL_Renderer* r = ctx.renderer; SDL_Renderer* r = ctx.renderer;
@@ -228,32 +238,34 @@ static void renderSidebar(AppContext& ctx) {
float sbY = static_cast<float>(TITLE_H); float sbY = static_cast<float>(TITLE_H);
float sbH = static_cast<float>(LOGICAL_H) - TITLE_H - STATUS_H; float sbH = static_cast<float>(LOGICAL_H) - TITLE_H - STATUS_H;
// 背景 // 背景 / Background
fillRect(r, mkRect(0, sbY, sbW, sbH), COL_SIDEBAR_BG); fillRect(r, mkRect(0, sbY, sbW, sbH), COL_SIDEBAR_BG);
// 右侧分隔线 // 右侧分隔线 / Right separator line
setColor(r, COL_SEP); setColor(r, COL_SEP);
SDL_RenderLine(r, sbW, sbY, sbW, sbY + sbH); SDL_RenderLine(r, sbW, sbY, sbW, sbY + sbH);
// "Chats" 标题 // "Chats" 标题 / "Chats" title
drawText(r, static_cast<float>(PADDING), sbY + PADDING, "Chats", COL_WHITE); drawText(r, static_cast<float>(PADDING), sbY + PADDING, "Chats", COL_WHITE);
// 会话列表(当前只有 "default" // 会话列表(当前只有 "default" / Session list (currently only "default")
float listY = sbY + TITLE_H; float listY = sbY + TITLE_H;
// "default" 条目(活动状态高亮) // "default" 条目(活动状态高亮) / "default" entry (active state highlighted)
float itemH = static_cast<float>(CHAR_H + PADDING); float itemH = static_cast<float>(CHAR_H + PADDING);
fillRect(r, mkRect(PADDING, listY, sbW - PADDING * 2, itemH), COL_SIDEBAR_ACT); fillRect(r, mkRect(PADDING, listY, sbW - PADDING * 2, itemH), COL_SIDEBAR_ACT);
drawText(r, PADDING * 2.0f, listY + PADDING / 2.0f, "default", COL_AI); drawText(r, PADDING * 2.0f, listY + PADDING / 2.0f, "default", COL_AI);
// "+ New Chat" 按钮(侧边栏底部) // "+ New Chat" 按钮(侧边栏底部) / "+ New Chat" button (sidebar bottom)
float btnY = sbY + sbH - CHAR_H - PADDING * 2; float btnY = sbY + sbH - CHAR_H - PADDING * 2;
float btnH = static_cast<float>(CHAR_H + PADDING); float btnH = static_cast<float>(CHAR_H + PADDING);
fillRect(r, mkRect(PADDING, btnY, sbW - PADDING * 2, btnH), COL_SIDEBAR_BTN); fillRect(r, mkRect(PADDING, btnY, sbW - PADDING * 2, btnH), COL_SIDEBAR_BTN);
drawText(r, PADDING * 2.0f, btnY + PADDING / 2.0f, "+ New Chat", COL_WHITE); drawText(r, PADDING * 2.0f, btnY + PADDING / 2.0f, "+ New Chat", COL_WHITE);
} }
// ---- 状态栏渲染 ---- // ---- 状态栏渲染 / Status bar rendering ----
// 渲染底部状态栏:模型名、消息数和流式状态。
// Render the bottom status bar: model name, message count, and streaming state.
static void renderStatusBar(AppContext& ctx) { static void renderStatusBar(AppContext& ctx) {
GuiState& gs = ctx.state; GuiState& gs = ctx.state;
SDL_Renderer* r = ctx.renderer; SDL_Renderer* r = ctx.renderer;
@@ -261,20 +273,20 @@ static void renderStatusBar(AppContext& ctx) {
float lh = static_cast<float>(LOGICAL_H); float lh = static_cast<float>(LOGICAL_H);
float barY = lh - STATUS_H; float barY = lh - STATUS_H;
// 背景 // 背景 / Background
fillRect(r, mkRect(0, barY, lw, static_cast<float>(STATUS_H)), COL_STATUSBAR_BG); fillRect(r, mkRect(0, barY, lw, static_cast<float>(STATUS_H)), COL_STATUSBAR_BG);
// 顶部分隔线 // 顶部分隔线 / Top separator line
setColor(r, COL_SEP); setColor(r, COL_SEP);
SDL_RenderLine(r, 0, barY, lw, barY); SDL_RenderLine(r, 0, barY, lw, barY);
// 统计消息数(排除系统消息) // 统计消息数(排除系统消息) / Count messages (excluding system messages)
int msgCount = 0; int msgCount = 0;
for (auto& msg : gs.messages) { for (auto& msg : gs.messages) {
if (msg.role != ChatMessage::SYSTEM) msgCount++; if (msg.role != ChatMessage::SYSTEM) msgCount++;
} }
// 状态文本:模型名 | 消息条数 | 流式状态 // 状态文本:模型名 | 消息条数 | 流式状态 / Status text: model name | message count | streaming state
char buf[256]; char buf[256];
snprintf(buf, sizeof(buf), "%s | %d messages | %s", snprintf(buf, sizeof(buf), "%s | %d messages | %s",
gs.model_name.c_str(), msgCount, gs.model_name.c_str(), msgCount,
@@ -283,8 +295,10 @@ static void renderStatusBar(AppContext& ctx) {
barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE); barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE);
} }
// ---- 主渲染 ---- // ---- 主渲染 / Main rendering ----
// 渲染一帧:标题栏、侧边栏、消息区(滚动)、输入区、光标、发送按钮、状态栏。
// Render one full frame: title bar, sidebar, message area (with scrolling), input area, cursor, send button, status bar.
static void renderFrame(AppContext& ctx) { static void renderFrame(AppContext& ctx) {
GuiState& gs = ctx.state; GuiState& gs = ctx.state;
SDL_Renderer* r = ctx.renderer; SDL_Renderer* r = ctx.renderer;
@@ -301,33 +315,33 @@ static void renderFrame(AppContext& ctx) {
int charsPerLine = std::max(20, int charsPerLine = std::max(20,
static_cast<int>(msgAreaW - PADDING * 2) / CHAR_W); static_cast<int>(msgAreaW - PADDING * 2) / CHAR_W);
// 1. 设置渲染缩放以获得 2 倍文本大小 // 1. 设置渲染缩放以获得 2 倍文本大小 / Set render scale for 2x text size
SDL_SetRenderScale(r, RENDER_SCALE, RENDER_SCALE); SDL_SetRenderScale(r, RENDER_SCALE, RENDER_SCALE);
// 2. 清除背景 // 2. 清除背景 / Clear background
setColor(r, COL_BG); setColor(r, COL_BG);
SDL_RenderClear(r); SDL_RenderClear(r);
// 3. 标题栏(全宽) // 3. 标题栏(全宽)/ Title bar (full width)
fillRect(r, mkRect(0, 0, lw, static_cast<float>(TITLE_H)), COL_TITLE_BG); fillRect(r, mkRect(0, 0, lw, static_cast<float>(TITLE_H)), COL_TITLE_BG);
drawText(r, static_cast<float>(PADDING), static_cast<float>(PADDING), drawText(r, static_cast<float>(PADDING), static_cast<float>(PADDING),
"dstalk - AI Chat", COL_WHITE); "dstalk - AI Chat", COL_WHITE);
// 右侧的状态指示器 // 右侧的状态指示器 / Status indicator on the right
const char* status = gs.streaming ? "[streaming...]" : "[ready]"; const char* status = gs.streaming ? "[streaming...]" : "[ready]";
float statusW = static_cast<float>(strlen(status)) * CHAR_W + PADDING; float statusW = static_cast<float>(strlen(status)) * CHAR_W + PADDING;
drawText(r, lw - statusW, static_cast<float>(PADDING), status, COL_WHITE); drawText(r, lw - statusW, static_cast<float>(PADDING), status, COL_WHITE);
// 4. 标题栏分隔线 // 4. 标题栏分隔线 / Title bar separator line
setColor(r, COL_SEP); setColor(r, COL_SEP);
SDL_RenderLine(r, 0, static_cast<float>(TITLE_H), SDL_RenderLine(r, 0, static_cast<float>(TITLE_H),
lw, static_cast<float>(TITLE_H)); lw, static_cast<float>(TITLE_H));
// 5. 侧边栏(可折叠) // 5. 侧边栏(可折叠)/ Sidebar (collapsible)
if (gs.sidebar_visible) { if (gs.sidebar_visible) {
renderSidebar(ctx); renderSidebar(ctx);
} }
// 6. 消息区域(带滚动) // 6. 消息区域(带滚动)/ Message area (with scrolling)
SDL_Rect msgClip; SDL_Rect msgClip;
msgClip.x = static_cast<int>(msgAreaX * RENDER_SCALE); msgClip.x = static_cast<int>(msgAreaX * RENDER_SCALE);
msgClip.y = static_cast<int>(msgAreaY * RENDER_SCALE); msgClip.y = static_cast<int>(msgAreaY * RENDER_SCALE);
@@ -335,13 +349,13 @@ static void renderFrame(AppContext& ctx) {
msgClip.h = static_cast<int>(msgAreaH * RENDER_SCALE); msgClip.h = static_cast<int>(msgAreaH * RENDER_SCALE);
SDL_SetRenderClipRect(r, &msgClip); SDL_SetRenderClipRect(r, &msgClip);
// 计算总消息高度和滚动限制 // 计算总消息高度和滚动限制 / Calculate total message height and scroll limits
int totalMsgH = calcTotalMsgHeight(gs, charsPerLine); int totalMsgH = calcTotalMsgHeight(gs, charsPerLine);
gs.maxScroll = std::max(0.0f, static_cast<float>(totalMsgH) - msgAreaH); gs.maxScroll = std::max(0.0f, static_cast<float>(totalMsgH) - msgAreaH);
if (gs.scrollOffset < 0) gs.scrollOffset = 0; if (gs.scrollOffset < 0) gs.scrollOffset = 0;
if (gs.scrollOffset > gs.maxScroll) gs.scrollOffset = static_cast<int>(gs.maxScroll); if (gs.scrollOffset > gs.maxScroll) gs.scrollOffset = static_cast<int>(gs.maxScroll);
// 绘制消息:起始 Y 从消息区域顶部减去 scrollOffset // 绘制消息:起始 Y 从消息区域顶部减去 scrollOffset / Draw messages: start Y from message area top minus scrollOffset
float drawY = msgAreaY - gs.scrollOffset; float drawY = msgAreaY - gs.scrollOffset;
float unusedSpace = msgAreaH - static_cast<float>(totalMsgH); float unusedSpace = msgAreaH - static_cast<float>(totalMsgH);
float bottomOffset = std::max(0.0f, unusedSpace); float bottomOffset = std::max(0.0f, unusedSpace);
@@ -359,7 +373,7 @@ static void renderFrame(AppContext& ctx) {
default: col = COL_SYS; prefix = "Sys> "; break; default: col = COL_SYS; prefix = "Sys> "; break;
} }
// 如果该消息可见,则绘制 // 如果该消息可见,则绘制 / Draw if this message is visible
float msgBottom = drawY + msgH; float msgBottom = drawY + msgH;
if (msgBottom > msgAreaY && drawY < msgAreaY + msgAreaH) { if (msgBottom > msgAreaY && drawY < msgAreaY + msgAreaH) {
float lineY = drawY + 2; float lineY = drawY + 2;
@@ -383,14 +397,14 @@ static void renderFrame(AppContext& ctx) {
SDL_SetRenderClipRect(r, nullptr); SDL_SetRenderClipRect(r, nullptr);
// 7. 输入区域分隔线 // 7. 输入区域分隔线 / Input area separator line
setColor(r, COL_SEP); setColor(r, COL_SEP);
SDL_RenderLine(r, msgAreaX, inputY, lw, inputY); SDL_RenderLine(r, msgAreaX, inputY, lw, inputY);
// 8. 输入区域背景 // 8. 输入区域背景 / Input area background
fillRect(r, mkRect(msgAreaX, inputY, msgAreaW, static_cast<float>(inputH)), COL_INPUT_BG); fillRect(r, mkRect(msgAreaX, inputY, msgAreaW, static_cast<float>(inputH)), COL_INPUT_BG);
// 9. 输入文本(支持多行显示) // 9. 输入文本(支持多行显示)/ Input text (multi-line support)
if (!gs.inputBuffer.empty()) { if (!gs.inputBuffer.empty()) {
std::string remaining = gs.inputBuffer; std::string remaining = gs.inputBuffer;
int lineIdx = 0; int lineIdx = 0;
@@ -416,7 +430,7 @@ static void renderFrame(AppContext& ctx) {
textY, "Type here..."); textY, "Type here...");
} }
// 10. 光标(多行感知) // 10. 光标(多行感知)/ Cursor (multi-line aware)
if (!gs.streaming) { if (!gs.streaming) {
Uint64 now = SDL_GetTicks(); Uint64 now = SDL_GetTicks();
if (now - gs.lastCursorBlink > 530) { if (now - gs.lastCursorBlink > 530) {
@@ -424,7 +438,7 @@ static void renderFrame(AppContext& ctx) {
gs.lastCursorBlink = now; gs.lastCursorBlink = now;
} }
if (gs.cursorVisible && gs.cursorPos <= static_cast<int>(gs.inputBuffer.size())) { if (gs.cursorVisible && gs.cursorPos <= static_cast<int>(gs.inputBuffer.size())) {
// 计算光标所在行和列 // 计算光标所在行和列 / Calculate cursor line and column
int curLine = 0; int curLine = 0;
int charsBeforeLine = 0; int charsBeforeLine = 0;
for (int i = 0; i < gs.cursorPos; i++) { for (int i = 0; i < gs.cursorPos; i++) {
@@ -444,7 +458,7 @@ static void renderFrame(AppContext& ctx) {
} }
} }
// 11. 发送/停止按钮 // 11. 发送/停止按钮 / Send/Stop button
float btnW = 5 * CHAR_W + PADDING; float btnW = 5 * CHAR_W + PADDING;
float btnH = CHAR_H + PADDING; float btnH = CHAR_H + PADDING;
float btnX = lw - btnW - PADDING; float btnX = lw - btnW - PADDING;
@@ -458,26 +472,27 @@ static void renderFrame(AppContext& ctx) {
drawText(r, btnTextX, btnTextY, "[Send]", COL_WHITE); drawText(r, btnTextX, btnTextY, "[Send]", COL_WHITE);
} }
// 12. 状态栏 // 12. 状态栏 / Status bar
renderStatusBar(ctx); renderStatusBar(ctx);
// 13. Present // 13. Present / Present
SDL_RenderPresent(r); SDL_RenderPresent(r);
} }
// ---- 事件处理 ---- // ---- 事件处理 / Event handling ----
// 尝试发送当前输入缓冲区的内容;返回 true 表示消息已排队 // 验证当前输入缓冲区并将其作为用户消息加入队列;成功发送则返回 true。
// Validate and queue the current input buffer as a user message; returns true if sent.
static bool trySendMessage(GuiState& gs) { static bool trySendMessage(GuiState& gs) {
std::string text = gs.inputBuffer; std::string text = gs.inputBuffer;
// 去除前导/尾随空白,但保留内容空白 // 去除前导/尾随空白,但保留内容空白 / Trim leading/trailing whitespace but preserve content whitespace
size_t start = text.find_first_not_of(" \t\r\n"); size_t start = text.find_first_not_of(" \t\r\n");
size_t end = text.find_last_not_of(" \t\r\n"); size_t end = text.find_last_not_of(" \t\r\n");
if (start == std::string::npos) return false; // 空输入 if (start == std::string::npos) return false; // 空输入 / empty input
text = text.substr(start, end - start + 1); text = text.substr(start, end - start + 1);
if (text.empty()) return false; if (text.empty()) return false;
// 保存原始输入到历史(最多保留 20 条) // 保存原始输入到历史(最多保留 20 条) / Save original input to history (max 20 entries)
gs.input_history.push_back(gs.inputBuffer); gs.input_history.push_back(gs.inputBuffer);
if (gs.input_history.size() > 20) if (gs.input_history.size() > 20)
gs.input_history.erase(gs.input_history.begin()); gs.input_history.erase(gs.input_history.begin());
@@ -489,7 +504,8 @@ static bool trySendMessage(GuiState& gs) {
return true; return true;
} }
// 如果输入区域中的 Send/Stop 按钮被点击,返回 true // 检查物理像素坐标是否落在发送/停止按钮区域内。
// Return true if the given physical-pixel coordinates fall within the Send/Stop button.
static bool isSendButtonHit(AppContext& ctx, float physX, float physY) { static bool isSendButtonHit(AppContext& ctx, float physX, float physY) {
float lx = physX / RENDER_SCALE; float lx = physX / RENDER_SCALE;
float ly = physY / RENDER_SCALE; float ly = physY / RENDER_SCALE;
@@ -506,8 +522,10 @@ static bool isSendButtonHit(AppContext& ctx, float physX, float physY) {
ly >= btnY && ly <= btnY + btnH; ly >= btnY && ly <= btnY + btnH;
} }
// ---- 流式回调 ---- // ---- 流式回调 / Streaming callback ----
// 流式 token 回调:将 token 追加到 streamBuffer更新最后一条助手消息然后重新渲染。
// Streaming token callback: appends token to streamBuffer, updates last assistant message, then re-renders.
static int streamTokenCallback(const char* token, void* userdata) { static int streamTokenCallback(const char* token, void* userdata) {
AppContext* ctx = static_cast<AppContext*>(userdata); AppContext* ctx = static_cast<AppContext*>(userdata);
GuiState& gs = ctx->state; GuiState& gs = ctx->state;
@@ -520,7 +538,7 @@ static int streamTokenCallback(const char* token, void* userdata) {
} }
} }
// 泵送事件以保持窗口响应 // 泵送事件以保持窗口响应 / Pump events to keep the window responsive
SDL_PumpEvents(); SDL_PumpEvents();
SDL_Event ev; SDL_Event ev;
@@ -547,15 +565,17 @@ static int streamTokenCallback(const char* token, void* userdata) {
} }
} }
// 重新渲染以显示进度的令牌 // 重新渲染以显示进度的令牌 / Re-render to show the token progress
gs.scrollOffset = 0; gs.scrollOffset = 0;
renderFrame(*ctx); renderFrame(*ctx);
return 0; return 0;
} }
// ---- 主事件处理函数 ---- // ---- 主事件处理函数 / Main event processing function ----
// 分发单个 SDL 事件以更新 GuiState键盘输入、鼠标点击、滚动、文本输入
// Dispatch a single SDL event to update GuiState (keyboard input, mouse clicks, scroll, text input).
static void processEvent(AppContext& ctx, SDL_Event& ev) { static void processEvent(AppContext& ctx, SDL_Event& ev) {
GuiState& gs = ctx.state; GuiState& gs = ctx.state;
@@ -571,23 +591,23 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
bool shift = (mod & SDL_KMOD_SHIFT) != 0; bool shift = (mod & SDL_KMOD_SHIFT) != 0;
if (gs.streaming) { if (gs.streaming) {
// 流式传输期间,按 Escape 键取消 // 流式传输期间,按 Escape 键取消 / While streaming, press Escape to cancel
if (key == SDLK_ESCAPE) { if (key == SDLK_ESCAPE) {
gs.streaming = false; gs.streaming = false;
} }
break; break;
} }
// Tab 切换侧边栏显示/隐藏 // Tab 切换侧边栏显示/隐藏 / Tab toggles sidebar visibility
if (key == SDLK_TAB) { if (key == SDLK_TAB) {
gs.sidebar_visible = !gs.sidebar_visible; gs.sidebar_visible = !gs.sidebar_visible;
break; break;
} }
// 输入历史浏览(↑/↓) // 输入历史浏览(↑/↓) / Input history browsing (Up/Down)
if (key == SDLK_UP && !gs.input_history.empty()) { if (key == SDLK_UP && !gs.input_history.empty()) {
if (gs.history_index == -1) { if (gs.history_index == -1) {
// 首次进入历史浏览,保存当前输入 // 首次进入历史浏览,保存当前输入 / First time browsing history, save current input
gs.saved_input = gs.inputBuffer; gs.saved_input = gs.inputBuffer;
gs.history_index = static_cast<int>(gs.input_history.size()) - 1; gs.history_index = static_cast<int>(gs.input_history.size()) - 1;
} else if (gs.history_index > 0) { } else if (gs.history_index > 0) {
@@ -606,7 +626,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
if (gs.history_index >= 0) { if (gs.history_index >= 0) {
gs.inputBuffer = gs.input_history[gs.history_index]; gs.inputBuffer = gs.input_history[gs.history_index];
} else { } else {
// 回到新输入,恢复暂存的输入 // 回到新输入,恢复暂存的输入 / Back to new input, restore saved input
gs.inputBuffer = gs.saved_input; gs.inputBuffer = gs.saved_input;
} }
gs.cursorPos = static_cast<int>(gs.inputBuffer.size()); gs.cursorPos = static_cast<int>(gs.inputBuffer.size());
@@ -620,7 +640,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
case SDLK_RETURN: case SDLK_RETURN:
case SDLK_KP_ENTER: case SDLK_KP_ENTER:
if (shift) { if (shift) {
// Shift+Enter插入换行符不发送 // Shift+Enter插入换行符不发送 / Shift+Enter: insert newline (don't send)
gs.inputBuffer.insert(gs.cursorPos, "\n"); gs.inputBuffer.insert(gs.cursorPos, "\n");
gs.cursorPos++; gs.cursorPos++;
gs.cursorVisible = true; gs.cursorVisible = true;
@@ -670,7 +690,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_V: case SDLK_V:
if (ctrl) { if (ctrl) {
// Ctrl+V从剪贴板粘贴 // Ctrl+V从剪贴板粘贴 / Ctrl+V: paste from clipboard
if (SDL_HasClipboardText()) { if (SDL_HasClipboardText()) {
char* clip = SDL_GetClipboardText(); char* clip = SDL_GetClipboardText();
if (clip) { if (clip) {
@@ -685,7 +705,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_C: case SDLK_C:
if (ctrl) { if (ctrl) {
// Ctrl+C复制到剪贴板复制最后一条助手消息 // Ctrl+C复制到剪贴板复制最后一条助手消息 / Ctrl+C: copy to clipboard (copy last assistant message)
if (!gs.messages.empty()) { if (!gs.messages.empty()) {
for (int i = static_cast<int>(gs.messages.size()) - 1; i >= 0; --i) { for (int i = static_cast<int>(gs.messages.size()) - 1; i >= 0; --i) {
if (gs.messages[i].role != ChatMessage::USER) { if (gs.messages[i].role != ChatMessage::USER) {
@@ -701,7 +721,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_L: case SDLK_L:
if (ctrl) { if (ctrl) {
// Ctrl+L清除聊天 // Ctrl+L清除聊天 / Ctrl+L: clear chat
if (g_session_svc) g_session_svc->clear(); if (g_session_svc) g_session_svc->clear();
gs.messages.clear(); gs.messages.clear();
gs.messages.push_back(ChatMessage( gs.messages.push_back(ChatMessage(
@@ -711,7 +731,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_S: case SDLK_S:
if (ctrl) { if (ctrl) {
// Ctrl+S保存会话 // Ctrl+S保存会话 / Ctrl+S: save session
if (g_session_svc && g_session_svc->save("session.json") == 0) { if (g_session_svc && g_session_svc->save("session.json") == 0) {
gs.messages.push_back(ChatMessage( gs.messages.push_back(ChatMessage(
ChatMessage::SYSTEM, "Session saved to session.json")); ChatMessage::SYSTEM, "Session saved to session.json"));
@@ -724,7 +744,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_O: case SDLK_O:
if (ctrl) { if (ctrl) {
// Ctrl+O加载会话 // Ctrl+O加载会话 / Ctrl+O: load session
if (g_session_svc && g_session_svc->load("session.json") == 0) { if (g_session_svc && g_session_svc->load("session.json") == 0) {
gs.messages.push_back(ChatMessage( gs.messages.push_back(ChatMessage(
ChatMessage::SYSTEM, "Session loaded from session.json")); ChatMessage::SYSTEM, "Session loaded from session.json"));
@@ -743,7 +763,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
case SDL_EVENT_TEXT_INPUT: case SDL_EVENT_TEXT_INPUT:
if (!gs.streaming) { if (!gs.streaming) {
// 将文本插入光标位置 // 将文本插入光标位置 / Insert text at cursor position
gs.inputBuffer.insert(gs.cursorPos, ev.text.text); gs.inputBuffer.insert(gs.cursorPos, ev.text.text);
gs.cursorPos += static_cast<int>(strlen(ev.text.text)); gs.cursorPos += static_cast<int>(strlen(ev.text.text));
gs.cursorVisible = true; gs.cursorVisible = true;
@@ -772,6 +792,8 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
case SDL_EVENT_WINDOW_RESIZED: { case SDL_EVENT_WINDOW_RESIZED: {
// 当窗口大小改变时,不更新我们的常量——保持 1024x768 的逻辑尺寸。 // 当窗口大小改变时,不更新我们的常量——保持 1024x768 的逻辑尺寸。
// SDL 将自动缩放输出。 // SDL 将自动缩放输出。
// When window resizes, don't update our constants — keep 1024x768 logical size.
// SDL will auto-scale the output.
break; break;
} }
@@ -780,10 +802,12 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
} }
} }
// ---- 入口 ---- // ---- 入口 / Entry point ----
// 入口:初始化 dstalk host 和 SDL3运行主事件/渲染循环,然后清理。
// Entry point: initializes dstalk host and SDL3, runs the main event/render loop, then cleans up.
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
// ----- 初始化 dstalk ----- // ----- 初始化 dstalk / Initialize dstalk -----
if (dstalk_init(nullptr) != 0) { if (dstalk_init(nullptr) != 0) {
std::fprintf(stderr, "[dstalk] Init failed\n"); std::fprintf(stderr, "[dstalk] Init failed\n");
return 1; return 1;
@@ -796,7 +820,7 @@ int main(int argc, char* argv[]) {
if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)"); if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)");
if (!g_session_svc) dstalk_log(3, "Session service not found"); if (!g_session_svc) dstalk_log(3, "Session service not found");
// ----- 初始化 SDL ----- // ----- 初始化 SDL / Initialize SDL -----
if (!SDL_Init(SDL_INIT_VIDEO)) { if (!SDL_Init(SDL_INIT_VIDEO)) {
std::fprintf(stderr, "[dstalk] SDL init failed: %s\n", SDL_GetError()); std::fprintf(stderr, "[dstalk] SDL init failed: %s\n", SDL_GetError());
dstalk_shutdown(); dstalk_shutdown();
@@ -822,10 +846,10 @@ int main(int argc, char* argv[]) {
return 1; return 1;
} }
// 启用文本输入事件 // 启用文本输入事件 / Enable text input events
SDL_StartTextInput(window); SDL_StartTextInput(window);
// ----- 应用程序状态 ----- // ----- 应用程序状态 / Application state -----
AppContext ctx; AppContext ctx;
ctx.window = window; ctx.window = window;
ctx.renderer = renderer; ctx.renderer = renderer;
@@ -834,29 +858,29 @@ int main(int argc, char* argv[]) {
"Ctrl+L clear, Ctrl+S save, Ctrl+O load. " "Ctrl+L clear, Ctrl+S save, Ctrl+O load. "
"Shift+Enter for newline, Up/Down for history, Tab toggle sidebar.")); "Shift+Enter for newline, Up/Down for history, Tab toggle sidebar."));
// ----- 主循环 ----- // ----- 主循环 / Main loop -----
SDL_Event event; SDL_Event event;
while (ctx.state.running) { while (ctx.state.running) {
// 处理所有待处理事件 // 处理所有待处理事件 / Process all pending events
while (SDL_PollEvent(&event)) { while (SDL_PollEvent(&event)) {
processEvent(ctx, event); processEvent(ctx, event);
if (!ctx.state.running) break; if (!ctx.state.running) break;
} }
if (!ctx.state.running) break; if (!ctx.state.running) break;
// 检查待发送的消息 // 检查待发送的消息 / Check for pending message to send
if (ctx.sendPending && !ctx.state.streaming) { if (ctx.sendPending && !ctx.state.streaming) {
ctx.sendPending = false; ctx.sendPending = false;
if (trySendMessage(ctx.state)) { if (trySendMessage(ctx.state)) {
// 开始流式传输 // 开始流式传输 / Start streaming
ctx.state.streaming = true; ctx.state.streaming = true;
ctx.streamBuffer.clear(); ctx.streamBuffer.clear();
// 为流式响应添加占位消息 // 为流式响应添加占位消息 / Add placeholder message for streaming response
ctx.state.messages.push_back( ctx.state.messages.push_back(
ChatMessage(ChatMessage::ASSISTANT, "")); ChatMessage(ChatMessage::ASSISTANT, ""));
ctx.state.scrollOffset = 0; ctx.state.scrollOffset = 0;
// 对最后一条消息调用流式 API通过插件服务 vtable // 对最后一条消息调用流式 API通过插件服务 vtable / Call streaming API for the last message (via plugin service vtable)
std::string& userMsg = std::string& userMsg =
ctx.state.messages[ctx.state.messages.size() - 2].content; ctx.state.messages[ctx.state.messages.size() - 2].content;
int rc = -1; int rc = -1;
@@ -871,7 +895,7 @@ int main(int argc, char* argv[]) {
g_ai_svc->free_result(&result); g_ai_svc->free_result(&result);
} }
// 流式传输完成(或被取消) // 流式传输完成(或被取消) / Streaming completed (or cancelled)
if (rc != 0) { if (rc != 0) {
if (!ctx.state.messages.empty() && if (!ctx.state.messages.empty() &&
ctx.state.messages.back().role == ChatMessage::ASSISTANT) { ctx.state.messages.back().role == ChatMessage::ASSISTANT) {
@@ -884,14 +908,14 @@ int main(int argc, char* argv[]) {
} }
} }
// 渲染当前帧 // 渲染当前帧 / Render current frame
renderFrame(ctx); renderFrame(ctx);
// 短暂休眠以降低 CPU 使用率 // 短暂休眠以降低 CPU 使用率 / Brief sleep to reduce CPU usage
SDL_Delay(16); // ~60 FPS SDL_Delay(16); // ~60 FPS
} }
// ----- 清理 ----- // ----- 清理 / Cleanup -----
SDL_StopTextInput(window); SDL_StopTextInput(window);
SDL_DestroyRenderer(renderer); SDL_DestroyRenderer(renderer);
SDL_DestroyWindow(window); SDL_DestroyWindow(window);

27
dstalk-web/CMakeLists.txt Normal file
View File

@@ -0,0 +1,27 @@
# ============================================================
# dstalk-web — Web 前端 / Web frontend (Boost.Beast HTTP + SSE)
# ============================================================
find_package(Boost REQUIRED CONFIG)
add_executable(dstalk-web
src/main.cpp
)
set_target_properties(dstalk-web PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
)
target_compile_features(dstalk-web PRIVATE cxx_std_20)
target_link_libraries(dstalk-web
PRIVATE
dstalk
boost::boost
dstalk_boost_config
)
# Windows: Boost.Asio 需要 Winsock / Boost.Asio requires Winsock
if(WIN32)
target_link_libraries(dstalk-web PRIVATE ws2_32)
endif()

561
dstalk-web/src/main.cpp Normal file
View File

@@ -0,0 +1,561 @@
/*
* @file main.cpp
* @brief Boost.Beast HTTP server frontend for dstalk-web: SSE streaming chat, embedded web UI, CORS support.
* dstalk-web 的 Boost.Beast HTTP 服务端SSE 流式对话、嵌入式网页界面、CORS 支持。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h"
#include "web_ui.hpp"
#include <boost/beast/core.hpp>
#include <boost/beast/http.hpp>
#include <boost/asio.hpp>
#include <boost/json.hpp>
#include <boost/json/src.hpp>
#include <atomic>
#include <cstdio>
#include <cstring>
#include <deque>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#ifdef _WIN32
#include <windows.h>
#else
#include <signal.h>
#endif
// ---- 命名空间别名 / Namespace aliases ----
namespace beast = boost::beast;
namespace http = beast::http;
namespace asio = boost::asio;
using tcp = asio::ip::tcp;
// ---- 前置声明 / Forward declarations ----
class SseSession;
// ---- 服务 vtable 指针 / Service vtable pointers ----
// Global pointers to plugin service vtables, queried from the host on startup.
// 插件服务 vtable 的全局指针,在启动时从主机查询获取。
static const dstalk_ai_service_t* g_ai = nullptr;
static const dstalk_session_service_t* g_session = nullptr;
// ---- 运行时状态 / Runtime state ----
// g_quit signals the main loop to exit (set by Ctrl+C).
// g_ioc is the io_context pointer for use by signal handlers to stop the event loop.
// g_quit 通知主循环退出(由 Ctrl+C 设置)。
// g_ioc 供信号处理函数调用 stop() 的 io_context 指针。
static std::atomic<bool> g_quit{false};
static asio::io_context* g_ioc = nullptr;
// ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ----
// Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT).
// Windows 控制台事件处理CTRL_C_EVENT / CTRL_BREAK_EVENT
#ifdef _WIN32
static BOOL WINAPI on_console_event(DWORD event)
{
if (event == CTRL_C_EVENT || event == CTRL_BREAK_EVENT) {
g_quit = true;
if (g_ioc) g_ioc->stop();
return TRUE;
}
return FALSE;
}
// Unix signal handler (SIGINT).
// Unix 信号处理SIGINT
#else
static void on_signal(int /*sig*/)
{
g_quit = true;
if (g_ioc) g_ioc->stop();
}
#endif
// ========================================================================
// SseSession — 管理一个 SSE 流式响应连接 / Manages one SSE streaming response
// ========================================================================
// 持有从 HttpSession 转移过来的 tcp::socket以 SSE 格式流式发送 AI 回复。
// 所有公开方法均在 io_context 线程上被调用,因此无需互斥锁。
// Owns the tcp::socket transferred from HttpSession; streams AI response as SSE.
// All public methods are called on the io_context thread, so no mutex is needed.
class SseSession : public std::enable_shared_from_this<SseSession> {
public:
// 构造函数:接管已接受的 socket / Constructor: take ownership of the accepted socket.
explicit SseSession(tcp::socket&& s) : socket_(std::move(s)) {}
// 发送 SSE HTTP 响应头并准备接收数据帧 / Send SSE HTTP response headers and prepare for data frames.
void start() {
writing_ = true; // 阻止数据写入,等待头部发送完成 / Block data writes until headers are sent
std::string header =
"HTTP/1.1 200 OK\r\n"
"Content-Type: text/event-stream\r\n"
"Cache-Control: no-cache\r\n"
"Connection: keep-alive\r\n"
"Access-Control-Allow-Origin: *\r\n"
"\r\n";
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(header),
[self](beast::error_code ec, size_t) {
self->writing_ = false;
if (!ec && !self->pending_.empty()) {
self->do_write();
}
// 写入失败则让 socket 随 SseSession 析构自然关闭 / On error, let socket close via destructor
});
}
// 将 token 加入待发送队列,若未在写入则启动写链 / Push token to pending queue; start write chain if idle.
void send_token(const std::string& token) {
if (closed_) return;
// 换行符会破坏 SSE 帧结构,替换为空格 / Newlines break SSE frame structure; replace with spaces
std::string t = token;
for (auto& c : t) if (c == '\n' || c == '\r') c = ' ';
if (t.empty()) return;
pending_.push_back("data: " + t + "\n\n");
if (!writing_) do_write();
}
// 发送完成事件后关闭连接 / Send done event then close the connection.
void send_done(bool ok, const std::string& content) {
if (closed_) return;
closed_ = true;
(void)ok;
(void)content;
// JS 客户端忽略 [DONE] token流自然结束触发最终渲染 / JS client ignores [DONE] token; stream end triggers final render
pending_.push_back("event: done\ndata: [DONE]\n\n");
if (!writing_) do_write();
}
// 发送错误事件后关闭连接 / Send error event then close the connection.
void send_error(const std::string& msg) {
if (closed_) return;
closed_ = true;
std::string m = msg;
for (auto& c : m) if (c == '\n' || c == '\r') c = ' ';
pending_.push_back("event: error\ndata: " + m + "\n\n");
if (!writing_) do_write();
}
private:
// 从待发送队列头部取出并异步写入 / Pop front of pending queue and async-write it.
void do_write() {
if (writing_ || pending_.empty()) return;
writing_ = true;
auto self = shared_from_this();
asio::async_write(socket_, asio::buffer(pending_.front()),
[self](beast::error_code ec, size_t) {
self->writing_ = false;
self->pending_.pop_front();
if (!ec) {
if (!self->pending_.empty()) {
self->do_write();
} else if (self->closed_) {
// 队列已空且会话已关闭 → 关闭 socket / Queue drained and session closed → close socket
beast::error_code ignored;
self->socket_.shutdown(tcp::socket::shutdown_both, ignored);
self->socket_.close(ignored);
}
}
// 写入错误时 socket 随 SseSession 析构 / On write error, socket closes with SseSession
});
}
tcp::socket socket_;
std::deque<std::string> pending_;
bool writing_ = false;
bool closed_ = false;
};
// ========================================================================
// run_chat_worker — 在独立线程中执行流式 AI 聊天 / Execute streaming AI chat in a dedicated thread
// ========================================================================
// 将用户消息加入会话,调用 g_ai->chat_stream(),通过 asio::post 将 token 投递到 io_context。
// Add user message to session, call g_ai->chat_stream(), post tokens to io_context via asio::post.
static void run_chat_worker(
std::string user_input,
std::weak_ptr<SseSession> weak_sse,
asio::io_context& ioc)
{
// 将用户消息加入会话 / Add user message to session
dstalk_message_t user_msg = {"user", user_input.c_str(), nullptr, nullptr};
g_session->add(&user_msg);
// 获取会话历史 / Get session history
int history_count = 0;
const dstalk_message_t* history = g_session->history(&history_count);
// 流式回调上下文 / Streaming callback context
struct CallbackData {
std::weak_ptr<SseSession> sse;
asio::io_context* ioc;
};
CallbackData cb_data{weak_sse, &ioc};
// 流式 token 回调:将 token 投递到 io_context 线程 / Streaming token callback: post token to io_context thread
auto token_cb = [](const char* token, void* userdata) -> int {
auto* data = static_cast<CallbackData*>(userdata);
if (auto sse = data->sse.lock()) {
std::string t(token);
asio::post(*data->ioc, [sse, t = std::move(t)]() {
sse->send_token(t);
});
}
return 0;
};
// 调用流式 AI 聊天 / Call streaming AI chat
dstalk_chat_result_t result = g_ai->chat_stream(
history, history_count, nullptr, token_cb, &cb_data);
// 将 AI 回复加入会话 / Add AI reply to session
if (result.ok) {
dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json};
g_session->add(&ai_msg);
}
// 将完成/错误事件投递到 io_context 线程 / Post completion/error event to io_context thread
bool ok = result.ok;
std::string content_copy = result.content ? result.content : "";
std::string error_copy = result.error ? result.error : "";
g_ai->free_result(&result);
asio::post(ioc, [weak_sse, ok, content_copy, error_copy]() {
if (auto sse = weak_sse.lock()) {
if (ok) {
sse->send_done(true, content_copy);
} else {
sse->send_error(error_copy.empty() ? "unknown error" : error_copy);
}
}
});
}
// ========================================================================
// HttpSession — 处理单个 HTTP 请求 / Handles one HTTP request
// ========================================================================
// 使用 Beast 解析请求,按 method + target 路由到相应处理器。
// Uses Beast to parse the request, routing by method + target to the appropriate handler.
class HttpSession : public std::enable_shared_from_this<HttpSession> {
public:
// 构造函数:接管已接受的 socket / Constructor: take ownership of the accepted socket.
explicit HttpSession(tcp::socket&& s) : socket_(std::move(s)) {}
// 开始读取请求 / Start reading the request.
void start() { do_read(); }
private:
// 异步读取 HTTP 请求 / Asynchronously read the HTTP request.
void do_read() {
auto self = shared_from_this();
http::async_read(socket_, buffer_, request_,
[self](beast::error_code ec, size_t) {
if (ec) return; // 客户端断开或读取错误 / Client disconnected or read error
self->handle_request();
});
}
// 路由 HTTP 请求到相应的处理器 / Route the HTTP request to the appropriate handler.
void handle_request() {
auto const method = request_.method();
auto const target = std::string(request_.target());
// GET / — 返回嵌入式网页界面 / Return embedded web UI
if (method == http::verb::get && target == "/") {
serve_web_ui();
return;
}
// POST /chat — SSE 流式聊天 / SSE streaming chat
if (method == http::verb::post && target == "/chat") {
handle_chat();
return;
}
// OPTIONS /chat — CORS 预检请求 / CORS preflight request
if (method == http::verb::options && target == "/chat") {
serve_cors_preflight();
return;
}
// POST /clear — 清除会话 / Clear session
if (method == http::verb::post && target == "/clear") {
handle_clear();
return;
}
// POST /status — 返回运行状态 / Return runtime status
if (method == http::verb::post && target == "/status") {
handle_status();
return;
}
// 未知路由 — 404 / Unknown route — 404
serve_404();
}
// 返回 HTML 网页界面 / Serve the HTML web UI.
void serve_web_ui() {
auto self = shared_from_this();
http::response<http::string_body> res{http::status::ok, request_.version()};
res.set(http::field::content_type, "text/html; charset=utf-8");
res.set(http::field::cache_control, "no-cache");
res.set("Access-Control-Allow-Origin", "*");
res.body() = kWebUiHtml;
res.prepare_payload();
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
// 解析 JSON body、创建 SseSession、启动工作线程 / Parse JSON body, create SseSession, spawn worker thread.
void handle_chat() {
// 解析 JSON body / Parse JSON body
boost::system::error_code ec;
auto jv = boost::json::parse(request_.body(), ec);
if (ec || !jv.is_object()) {
serve_bad_request("Invalid JSON body");
return;
}
auto const& obj = jv.as_object();
auto it = obj.find("message");
if (it == obj.end() || !it->value().is_string()) {
serve_bad_request("Missing or invalid 'message' field");
return;
}
std::string user_input = boost::json::value_to<std::string>(it->value());
// 创建 SseSession 并转移 socket 所有权 / Create SseSession and transfer socket ownership
auto sse = std::make_shared<SseSession>(std::move(socket_));
sse->start();
// 在独立线程中执行聊天chat_stream 是阻塞调用) / Execute chat in dedicated thread (chat_stream is blocking)
std::thread worker([user_input = std::move(user_input), sse]() {
run_chat_worker(user_input, sse, *g_ioc);
});
worker.detach();
}
// 返回 CORS 预检响应头 / Return CORS preflight response headers.
void serve_cors_preflight() {
auto self = shared_from_this();
http::response<http::empty_body> res{http::status::ok, request_.version()};
res.set("Access-Control-Allow-Origin", "*");
res.set("Access-Control-Allow-Methods", "POST, OPTIONS");
res.set("Access-Control-Allow-Headers", "Content-Type");
res.set("Access-Control-Max-Age", "86400");
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
// 清除当前会话 / Clear the current session.
void handle_clear() {
if (g_session) g_session->clear();
auto self = shared_from_this();
http::response<http::string_body> res{http::status::ok, request_.version()};
res.set("Access-Control-Allow-Origin", "*");
res.set(http::field::content_type, "application/json");
res.body() = "{\"ok\":true}";
res.prepare_payload();
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
// 返回运行状态 JSON / Return runtime status as JSON.
void handle_status() {
boost::json::object st;
if (g_session) {
int count = 0;
g_session->history(&count);
st["messages"] = count;
st["tokens"] = g_session->token_count();
}
const char* model = dstalk_config_get("api.model");
if (model) st["model"] = std::string(model);
st["status"] = "running";
auto self = shared_from_this();
http::response<http::string_body> res{http::status::ok, request_.version()};
res.set("Access-Control-Allow-Origin", "*");
res.set(http::field::content_type, "application/json");
res.body() = boost::json::serialize(st);
res.prepare_payload();
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
// 返回 400 Bad Request / Return 400 Bad Request.
void serve_bad_request(const std::string& msg) {
auto self = shared_from_this();
http::response<http::string_body> res{http::status::bad_request, request_.version()};
res.set(http::field::content_type, "text/plain");
res.set("Access-Control-Allow-Origin", "*");
res.body() = msg;
res.prepare_payload();
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
// 返回 404 Not Found / Return 404 Not Found.
void serve_404() {
auto self = shared_from_this();
http::response<http::string_body> res{http::status::not_found, request_.version()};
res.set(http::field::content_type, "text/plain");
res.set("Access-Control-Allow-Origin", "*");
res.body() = "404 Not Found";
res.prepare_payload();
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
}
tcp::socket socket_;
beast::flat_buffer buffer_;
http::request<http::string_body> request_;
};
// ========================================================================
// Listener — 接受 TCP 连接并创建 HttpSession / Accepts TCP connections and creates HttpSessions
// ========================================================================
// 异步接受循环:每个进入的连接包装为 HttpSession 并由 io_context 驱动其生命周期。
// Async accept loop: each inbound connection is wrapped in an HttpSession driven by the io_context.
class Listener {
public:
// 构造函数:打开 acceptor、绑定地址、开始监听 / Constructor: open acceptor, bind, start listening.
Listener(asio::io_context& ioc, const tcp::endpoint& ep)
: acceptor_(ioc)
{
beast::error_code ec;
acceptor_.open(ep.protocol(), ec);
if (ec) {
std::fprintf(stderr, "[dstalk-web] acceptor.open: %s\n", ec.message().c_str());
return;
}
acceptor_.set_option(asio::socket_base::reuse_address(true), ec);
acceptor_.bind(ep, ec);
if (ec) {
std::fprintf(stderr, "[dstalk-web] acceptor.bind: %s\n", ec.message().c_str());
return;
}
acceptor_.listen(asio::socket_base::max_listen_connections, ec);
if (ec) {
std::fprintf(stderr, "[dstalk-web] acceptor.listen: %s\n", ec.message().c_str());
return;
}
}
// 启动接受循环 / Start the accept loop.
void run() { do_accept(); }
private:
// 异步接受一个连接,创建 HttpSession 并继续监听 / Async-accept one connection, create HttpSession, keep listening.
void do_accept() {
acceptor_.async_accept(
[this](beast::error_code ec, tcp::socket socket) {
if (!ec) {
// 为每个入站连接创建新的 HttpSession / Create a new HttpSession for each inbound connection
std::make_shared<HttpSession>(std::move(socket))->start();
}
// 继续接受下一个连接(除非已发出退出信号) / Keep accepting (unless quit has been signaled)
if (!g_quit) do_accept();
});
}
tcp::acceptor acceptor_;
};
// ========================================================================
// main — 入口点 / Entry point
// ========================================================================
// 初始化 dstalk host查询 AI/Session 服务,配置 HTTP 监听,运行 io_context 事件循环。
// Initialize dstalk host, query AI/Session services, configure HTTP listener, run io_context event loop.
int main(int argc, char* argv[])
{
// Windows: 启用 ANSI 转义码 + 安装 Ctrl+C 处理器 / Windows: enable ANSI escape codes + install Ctrl+C handler
#ifdef _WIN32
HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
DWORD mode = 0;
GetConsoleMode(hOut, &mode);
SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
SetConsoleCtrlHandler(on_console_event, TRUE);
#else
signal(SIGINT, on_signal);
#endif
// 查找配置文件路径 / Locate config file path
const char* config_path = nullptr;
if (argc >= 2) {
config_path = argv[1];
}
if (!config_path) {
const char* default_configs[] = {"config.toml", nullptr};
for (int i = 0; default_configs[i]; i++) {
FILE* f = nullptr;
#ifdef _WIN32
fopen_s(&f, default_configs[i], "r");
#else
f = fopen(default_configs[i], "r");
#endif
if (f) {
fclose(f);
config_path = default_configs[i];
break;
}
}
}
// 初始化 dstalk 主机(加载配置 + 自动扫描 plugins/ 目录) / Init dstalk host (load config + auto-scan plugins/)
if (dstalk_init(config_path) != 0) {
std::fprintf(stderr, "[dstalk-web] dstalk_init failed\n");
return 3;
}
// 查询插件服务 / Query plugin services
const char* ai_provider = dstalk_config_get("ai.provider");
if (!ai_provider) ai_provider = "ai.deepseek";
g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
g_session = static_cast<const dstalk_session_service_t*>(dstalk_service_query("session", 1));
if (!g_ai) {
std::fprintf(stderr, "[dstalk-web] AI service not found (check plugins directory)\n");
}
if (!g_session) {
std::fprintf(stderr, "[dstalk-web] Session service not found\n");
}
// 从配置自动加载 AI 设置 / Auto-load AI settings from config
if (g_ai) {
const char* base_url = dstalk_config_get("api.base_url");
const char* api_key = dstalk_config_get("api.api_key");
const char* model = dstalk_config_get("api.model");
if (!base_url) base_url = "https://api.deepseek.com/v1";
if (!model) model = "deepseek-v4-pro";
g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7);
}
// 读取 web 服务配置 / Read web server config
const char* web_host = dstalk_config_get("web.host");
if (!web_host || !web_host[0]) web_host = "127.0.0.1";
const char* web_port_str = dstalk_config_get("web.port");
unsigned short web_port = 8080;
if (web_port_str && web_port_str[0]) {
web_port = static_cast<unsigned short>(std::strtoul(web_port_str, nullptr, 10));
}
// 创建 io_context 并启动监听 / Create io_context and start listener
asio::io_context ioc;
g_ioc = &ioc;
tcp::endpoint ep(asio::ip::make_address(web_host), web_port);
Listener listener(ioc, ep);
listener.run();
// 打印启动信息 / Print startup message
std::printf("[dstalk-web] running at http://%s:%u\n", web_host, web_port);
std::printf("[dstalk-web] Press Ctrl+C to stop\n");
// 运行事件循环(阻塞直到 g_ioc->stop() 被信号处理函数调用) / Run event loop (blocks until g_ioc->stop() called by signal handler)
ioc.run();
// 清理 / Cleanup
g_ioc = nullptr;
dstalk_shutdown();
std::printf("[dstalk-web] stopped\n");
return 0;
}

226
dstalk-web/src/web_ui.hpp Normal file
View File

@@ -0,0 +1,226 @@
/*
* @file web_ui.hpp
* @brief Embedded HTML/JS chat UI served by dstalk-web.
* 嵌入的 HTML/JS 聊天界面,由 dstalk-web 提供。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#ifndef DSTALK_WEB_UI_HPP
#define DSTALK_WEB_UI_HPP
// 深色主题单页聊天界面 — 通过 fetch ReadableStream 实现 SSE 流式传输
// Dark-themed single-page chat UI — SSE streaming via fetch ReadableStream
static const char kWebUiHtml[] = R"html(<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width,initial-scale=1.0">
<title>dstalk Web</title>
<style>
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
body{font-family:system-ui,-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,Oxygen,Ubuntu,Cantarell,sans-serif;background:#1a1a2e;color:#eaeaea;height:100dvh;display:flex;flex-direction:column;overflow:hidden}
#header{display:flex;align-items:center;justify-content:space-between;padding:8px 18px;background:#16162a;border-bottom:1px solid #2a2a4a;flex-shrink:0}
#header h1{font-size:1rem;font-weight:600;display:flex;align-items:center;gap:8px;color:#a6b8e0}
#header .status-row{display:flex;align-items:center;gap:14px;font-size:.75rem;color:#7a7a9a}
#dot{width:9px;height:9px;border-radius:50%;background:#555;flex-shrink:0;transition:background .3s}
#dot.connected{background:#4caf50}
#dot.streaming{background:#f06292;animation:pulse .8s infinite}
@keyframes pulse{0%,100%{opacity:1}50%{opacity:.35}}
#clearBtn{font-size:.72rem;padding:3px 10px;border:1px solid #f06292;border-radius:4px;color:#f06292;background:transparent;cursor:pointer;transition:background .2s}
#clearBtn:hover{background:#f0629218}
#messages{flex:1;overflow-y:auto;padding:16px 18px;display:flex;flex-direction:column;gap:10px;scroll-behavior:smooth}
.bubble{max-width:82%;padding:10px 14px;border-radius:10px;font-size:.9rem;line-height:1.55;word-wrap:break-word;animation:fadeIn .2s}
.bubble.user{align-self:flex-end;background:#2a3f6e;border-bottom-right-radius:3px}
.bubble.assistant{align-self:flex-start;background:#1e1e38;border:1px solid #2a2a4a;border-bottom-left-radius:3px}
.bubble pre{background:#0d1117;padding:9px 12px;border-radius:6px;overflow-x:auto;margin:6px 0;font-size:.8rem;white-space:pre-wrap}
.bubble code{font-family:"Fira Code","Cascadia Code",Consolas,monospace;font-size:.8rem}
.bubble strong{color:#f06292}
.bubble .lang{display:block;color:#8b949e;font-size:.68rem;margin-bottom:3px;text-transform:uppercase;letter-spacing:.4px}
#typing{display:none;align-self:flex-start;padding:10px 14px;background:#1e1e38;border:1px solid #2a2a4a;border-radius:10px;border-bottom-left-radius:3px}
#typing.active{display:block}
#typing span{display:inline-block;width:6px;height:6px;border-radius:50%;background:#f06292;margin:0 2px;animation:bounce 1.2s infinite}
#typing span:nth-child(2){animation-delay:.15s}
#typing span:nth-child(3){animation-delay:.3s}
@keyframes bounce{0%,60%,100%{transform:translateY(0);opacity:.3}30%{transform:translateY(-6px);opacity:1}}
#inputBar{display:flex;gap:8px;padding:10px 16px;background:#16162a;border-top:1px solid #2a2a4a;flex-shrink:0}
#inputBar textarea{flex:1;resize:none;background:#1a1a2e;color:#eaeaea;border:1px solid #2a2a4a;border-radius:8px;padding:9px 12px;font-size:.88rem;font-family:inherit;outline:none;transition:border .2s,box-shadow .2s;min-height:40px;max-height:120px;rows:1}
#inputBar textarea:focus{border-color:#f06292;box-shadow:0 0 8px #f0629230}
#sendBtn,#stopBtn{padding:9px 16px;border:none;border-radius:8px;font-size:.88rem;font-weight:600;cursor:pointer;transition:background .2s,opacity .2s}
#sendBtn{background:#f06292;color:#fff}
#sendBtn:hover{background:#d4517a}
#sendBtn:disabled{opacity:.45;cursor:not-allowed}
#stopBtn{display:none;background:#4a4a6a;color:#ccc}
#stopBtn:hover{background:#5a5a7a}
#stopBtn.visible{display:inline-block}
.emptyState{text-align:center;color:#4a4a6a;margin-top:20vh;font-size:.92rem;line-height:1.7}
.emptyState .logo{font-size:2rem;margin-bottom:8px}
@keyframes fadeIn{from{opacity:0;transform:translateY(5px)}to{opacity:1;transform:translateY(0)}}
@media(max-width:600px){.bubble{max-width:92%}#inputBar{padding:8px 10px;gap:6px}#sendBtn,#stopBtn{padding:8px 14px;font-size:.82rem}}
</style>
</head>
<body>
<div id="header">
<h1>&#9670; dstalk Web <span id="dot"></span></h1>
<div class="status-row">
<span id="lblModel">-</span>
<button id="clearBtn" title="Clear conversation / 清空对话">Clear</button>
</div>
</div>
<div id="messages"><div class="emptyState"><div class="logo">&#9670;</div>dstalk Web<br>Send a message to begin.<br>发送消息开始对话。</div></div>
<div id="typing"><span></span><span></span><span></span></div>
<div id="inputBar">
<textarea id="msgInput" placeholder="输入消息... (Enter 发送 / Shift+Enter 换行)" rows="1"></textarea>
<button id="sendBtn">Send</button>
<button id="stopBtn">Stop</button>
</div>
<script>
const msgs=document.getElementById('messages'),input=document.getElementById('msgInput'),
sendBtn=document.getElementById('sendBtn'),stopBtn=document.getElementById('stopBtn'),
dot=document.getElementById('dot'),typing=document.getElementById('typing'),
lblModel=document.getElementById('lblModel');
let abortCtrl=null,streaming=false,lastAiBubble=null,tokenBuf='';
function scrollDown(){msgs.scrollTop=msgs.scrollHeight}
function clearEmptyState(){const e=msgs.querySelector('.emptyState');if(e)e.remove()}
function addBubble(role,text){
clearEmptyState();
const d=document.createElement('div');
d.className='bubble '+role;
d.innerHTML=renderMD(text);
msgs.appendChild(d);
scrollDown();
return d;
}
function renderMD(t){
t=t.replace(/&/g,'&amp;').replace(/</g,'&lt;').replace(/>/g,'&gt;');
// 代码块 / code fences
t=t.replace(/```(\w*)\n?([\s\S]*?)```/g,(_,lang,code)=>{
const label=lang?'<span class="lang">'+lang+'</span>':'';
return '<pre><code>'+label+code.trim()+'</code></pre>';
});
// 行内代码 / inline code
t=t.replace(/`([^`]+)`/g,'<code>$1</code>');
// 粗体 / bold
t=t.replace(/\*\*(.+?)\*\*/g,'<strong>$1</strong>');
// 换行 / newlines
t=t.replace(/\n/g,'<br>');
return t;
}
function setStreaming(s){
streaming=s;
dot.classList.toggle('streaming',s);
dot.classList.toggle('connected',!s);
typing.classList.toggle('active',s&&!lastAiBubble);
sendBtn.disabled=s;
stopBtn.classList.toggle('visible',s);
if(s){
if(!lastAiBubble){lastAiBubble=addBubble('assistant','');typing.classList.remove('active')}
tokenBuf='';
}else{
if(lastAiBubble&&tokenBuf)lastAiBubble.innerHTML=renderMD(tokenBuf);
lastAiBubble=null;tokenBuf='';abortCtrl=null;
}
}
// 解析 SSE 帧 / Parse SSE frames (separated by double-newline)
function parseSSE(text,callback){
let idx;
while((idx=text.indexOf('\n\n'))!==-1){
const frame=text.slice(0,idx);
text=text.slice(idx+2);
const lines=frame.split('\n');
let event='message',data='';
for(const line of lines){
if(line.startsWith('event: '))event=line.slice(7).trim();
else if(line.startsWith('data: '))data+=line.slice(6);
}
callback(event,data);
}
return text;
}
async function send(){
const text=input.value.trim();
if(!text||streaming)return;
input.value='';input.style.height='auto';
addBubble('user',text);
setStreaming(true);
abortCtrl=new AbortController();
try{
const res=await fetch('/chat',{
method:'POST',headers:{'Content-Type':'application/json'},
body:JSON.stringify({message:text}),signal:abortCtrl.signal
});
if(!res.ok)throw new Error('HTTP '+res.status);
const reader=res.body.getReader(),decoder=new TextDecoder();
let leftover='';
while(true){
const{value,done}=await reader.read();
if(done)break;
leftover+=decoder.decode(value,{stream:true});
leftover=parseSSE(leftover,(event,data)=>{
if(event==='error'){
if(lastAiBubble)lastAiBubble.innerHTML=renderMD(tokenBuf||'');
const errEl=addBubble('assistant','[Error] '+data);
errEl.style.borderColor='#f06292';
}else if(event==='done'){
/* stream complete */
}else{
tokenBuf+=data;
if(lastAiBubble){lastAiBubble.innerHTML=renderMD(tokenBuf);scrollDown()}
}
});
}
// 处理残留在 leftover 中的非帧数据 / handle leftover non-frame data
if(leftover.trim()){
const trimmed=leftover.trim();
if(trimmed.startsWith('data: '))tokenBuf+=trimmed.slice(6);
}
}catch(e){
if(e.name!=='AbortError'){
if(lastAiBubble&&tokenBuf)lastAiBubble.innerHTML=renderMD(tokenBuf);
}
}
setStreaming(false);
loadStatus();
}
function stop(){if(abortCtrl){abortCtrl.abort();setStreaming(false)}}
input.addEventListener('keydown',e=>{
if(e.key==='Enter'&&!e.shiftKey){e.preventDefault();send()}
setTimeout(()=>{input.style.height='auto';input.style.height=Math.min(input.scrollHeight,120)+'px'},0);
});
sendBtn.addEventListener('click',send);
stopBtn.addEventListener('click',stop);
async function loadStatus(){
try{
const r=await fetch('/status',{method:'POST'});
if(!r.ok)return;
const s=await r.json();
lblModel.textContent=(s.model||'-')+' | '+(s.provider||'-');
dot.classList.toggle('connected',!!s.ai);
}catch(e){}
}
document.getElementById('clearBtn').addEventListener('click',async()=>{
try{await fetch('/clear',{method:'POST'})}catch(e){}
msgs.innerHTML='<div class="emptyState"><div class="logo">&#9670;</div>dstalk Web<br>Send a message to begin.<br>发送消息开始对话。</div>';
lastAiBubble=null;tokenBuf='';
if(streaming)setStreaming(false);
loadStatus();
});
// 页面加载时检查后端状态 / Check backend status on load
loadStatus();
</script>
</body>
</html>)html";
#endif // DSTALK_WEB_UI_HPP

View File

@@ -1,7 +1,10 @@
/* /*
* example_plugin.cpp - Minimal dstalk plugin demonstrating the API contract. * @file example_plugin.cpp
* @brief Example plugin demonstrating the dstalk plugin API contract.
* 示例插件:演示 dstalk 插件 API 契约。
* Copyright (c) 2026 dstalk contributors. GPLv3.
* *
* Build instructions (conceptual): * Build instructions (conceptual) / 构建说明(概念性):
* *
* Linux / macOS: * Linux / macOS:
* g++ -std=c++20 -shared -fPIC -fvisibility=hidden \ * g++ -std=c++20 -shared -fPIC -fvisibility=hidden \
@@ -14,6 +17,7 @@
* /Fe:example_plugin.dll example_plugin.cpp * /Fe:example_plugin.dll example_plugin.cpp
* *
* The resulting `.so` / `.dylib` / `.dll` can be loaded with: * The resulting `.so` / `.dylib` / `.dll` can be loaded with:
* 生成的 .so / .dylib / .dll 可通过以下方式加载:
* *
* int id = dstalk_plugin_load("./example_plugin.so"); * int id = dstalk_plugin_load("./example_plugin.so");
*/ */
@@ -25,11 +29,12 @@
#include <cstring> /* strlen, strcmp */ #include <cstring> /* strlen, strcmp */
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* Private state (one instance per plugin load) * 私有状态(每个插件加载实例一份) / Private state (one instance per plugin load)
* ------------------------------------------------------------------ * ------------------------------------------------------------------
* *
* In a more complex plugin this struct would hold open database * In a more complex plugin this struct would hold open database
* connections, configuration, etc. * connections, configuration, etc.
* 在更复杂的插件中,此结构体可包含打开的数据库连接、配置等。
*/ */
struct ExampleState { struct ExampleState {
@@ -37,31 +42,33 @@ struct ExampleState {
}; };
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* Stored host API table so callbacks can use host services. * 保存主机 API 表,以便回调函数使用主机服务 / Stored host API table so callbacks can use host services.
* ------------------------------------------------------------------ */ * ------------------------------------------------------------------ */
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
static ExampleState g_state; /* not heap-allocated: stays valid static ExampleState g_state; /* 非堆分配:在库映射期间持续有效 / not heap-allocated: stays valid
while the library is mapped */ while the library is mapped */
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* on_init (was on_load) * on_init原 on_load / on_init (was on_load)
* ------------------------------------------------------------------ */ * ------------------------------------------------------------------ */
// 插件初始化:保存主机指针,重置调用计数,记录加载消息 / Plugin init: store host pointer, reset call count, log loaded message.
static int my_on_init(const dstalk_host_api_t* host) static int my_on_init(const dstalk_host_api_t* host)
{ {
g_host = host; g_host = host;
g_state.call_count = 0; g_state.call_count = 0;
/* TODO: real plugins would initialise resources here: /* TODO: 真实插件应在此处初始化资源 / real plugins would initialise resources here:
* - parse a plugin-specific config file via host->config_get * - 通过 host->config_get 解析插件专属配置文件 / parse a plugin-specific config file via host->config_get
* - open a log file * - 打开日志文件 / open a log file
* - connect to a local service * - 连接到本地服务 / connect to a local service
* - register services via host->register_service * - 通过 host->register_service 注册服务 / register services via host->register_service
* *
* Return non-zero to signal a fatal initialisation error to the * Return non-zero to signal a fatal initialisation error to the
* host, which will then unload the plugin immediately. * host, which will then unload the plugin immediately.
* 返回非零值以向主机报告致命初始化错误,主机将立即卸载该插件。
*/ */
if (host) { if (host) {
@@ -73,12 +80,13 @@ static int my_on_init(const dstalk_host_api_t* host)
} }
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* on_shutdown (was on_unload) * on_shutdown原 on_unload / on_shutdown (was on_unload)
* ------------------------------------------------------------------ */ * ------------------------------------------------------------------ */
// 插件关闭:记录调用次数,释放资源 / Plugin shutdown: log call count, release any resources.
static void my_on_shutdown(void) static void my_on_shutdown(void)
{ {
/* TODO: release any resources allocated in on_init. After this /* TODO: 释放 on_init 中分配的所有资源。此函数返回后主机将卸载共享库。 / release any resources allocated in on_init. After this
* function returns the host will unmap the shared library. */ * function returns the host will unmap the shared library. */
if (g_host) { if (g_host) {
@@ -92,20 +100,21 @@ static void my_on_shutdown(void)
} }
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* on_event (was on_message) * on_event原 on_message / on_event (was on_message)
* ------------------------------------------------------------------ */ * ------------------------------------------------------------------ */
// 插件事件处理:记录消息事件,忽略其他事件类型 / Plugin event handler: log message events, ignore other event types.
static void my_on_event(int event_type, const void* data) static void my_on_event(int event_type, const void* data)
{ {
if (event_type == DSTALK_EVENT_MESSAGE && data) { if (event_type == DSTALK_EVENT_MESSAGE && data) {
const auto* msg = static_cast<const dstalk_message_t*>(data); const auto* msg = static_cast<const dstalk_message_t*>(data);
g_state.call_count++; g_state.call_count++;
/* A real plugin might: /* 真实插件可能: / A real plugin might:
* - log the conversation to a file * - 将对话记录到文件 / log the conversation to a file
* - apply content moderation * - 实施内容审核 / apply content moderation
* - translate messages on the fly * - 实时翻译消息 / translate messages on the fly
* - enrich messages with external data * - 用外部数据丰富消息 / enrich messages with external data
*/ */
if (g_host) { if (g_host) {
@@ -117,19 +126,19 @@ static void my_on_event(int event_type, const void* data)
msg->role, std::strlen(msg->content)); msg->role, std::strlen(msg->content));
} }
} }
/* Other event types (DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_CONFIG_CHANGED, /* 其他事件类型 / Other event types (DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_CONFIG_CHANGED,
DSTALK_EVENT_PLUGIN_LOADED, DSTALK_EVENT_PLUGIN_UNLOADED, DSTALK_EVENT_CUSTOM+) DSTALK_EVENT_PLUGIN_LOADED, DSTALK_EVENT_PLUGIN_UNLOADED, DSTALK_EVENT_CUSTOM+)
are silently ignored by this minimal plugin. */ 此最小化插件静默忽略 / are silently ignored by this minimal plugin. */
} }
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* Plugin descriptor (static -- lives for the lifetime of the .so) * 插件描述符(静态 —— 在 .so 的生命周期内有效) / Plugin descriptor (static -- lives for the lifetime of the .so)
* ------------------------------------------------------------------ */ * ------------------------------------------------------------------ */
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
/* .name = */ "example-plugin", /* .name = */ "example-plugin",
/* .version = */ "1.0.0", /* .version = */ "1.0.0",
/* .description = */ "An example plugin for dstalk", /* .description = */ "An example plugin for dstalk / dstalk 示例插件",
/* .api_version = */ DSTALK_API_VERSION, /* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ {nullptr}, /* .dependencies = */ {nullptr},
/* .on_init = */ my_on_init, /* .on_init = */ my_on_init,
@@ -138,13 +147,16 @@ static dstalk_plugin_info_t g_info = {
}; };
/* ------------------------------------------------------------------ /* ------------------------------------------------------------------
* Mandatory entry point * 必须入口点 / Mandatory entry point
* ------------------------------------------------------------------ * ------------------------------------------------------------------
* *
* The host looks for this symbol via dlsym / GetProcAddress. * The host looks for this symbol via dlsym / GetProcAddress.
* 主机通过 dlsym / GetProcAddress 查找此符号。
* It MUST be declared extern "C" so the name is not mangled. * It MUST be declared extern "C" so the name is not mangled.
* 必须声明为 extern "C" 以避免名称修饰。
*/ */
// 返回插件描述符给主机加载器 / Returns the plugin descriptor to the host loader.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{ {
return &g_info; return &g_info;

View File

@@ -1,3 +1,10 @@
/*
* @file anthropic_plugin.cpp
* @brief Anthropic Claude Messages API provider plugin with streaming support.
* Anthropic Claude Messages API 提供者插件,支持流式输出。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -11,14 +18,14 @@
namespace json = boost::json; namespace json = boost::json;
// ============================================================================ // ============================================================================
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 // 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 / Global pointers — W17.4: std::atomic protects concurrent read/write between on_shutdown and service functions
// ============================================================================ // ============================================================================
static std::atomic<const dstalk_host_api_t*> g_host{nullptr}; static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<dstalk_http_service_t*> g_http{nullptr}; static std::atomic<dstalk_http_service_t*> g_http{nullptr};
static dstalk_config_service_t* g_config = nullptr; static dstalk_config_service_t* g_config = nullptr;
// ============================================================================ // ============================================================================
// 配置数据 // 配置数据 / Config data
// ============================================================================ // ============================================================================
struct PluginConfig { struct PluginConfig {
std::string provider; std::string provider;
@@ -29,19 +36,21 @@ struct PluginConfig {
double temperature = 0.7; double temperature = 0.7;
}; };
static PluginConfig g_cfg; static PluginConfig g_cfg;
static std::string g_tools_json; // W21.2: cached by configure(), consumed by chat/chat_stream static std::string g_tools_json; // W21.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
// ============================================================================ // ============================================================================
// 安全擦除:用 volatile 写零循环防止编译器优化 // 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
// ============================================================================ // ============================================================================
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
static void secure_zero(void* p, size_t n) { static void secure_zero(void* p, size_t n) {
volatile char* vp = (volatile char*)p; volatile char* vp = (volatile char*)p;
while (n--) *vp++ = 0; while (n--) *vp++ = 0;
} }
// ============================================================================ // ============================================================================
// 辅助:提取 host / target // 辅助:提取 host / target / Helper: extract host / target
// ============================================================================ // ============================================================================
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
static bool extract_host_port(const std::string& url, static bool extract_host_port(const std::string& url,
std::string& scheme_out, std::string& host_out, std::string& scheme_out, std::string& host_out,
std::string& port_out, std::string& target_out) std::string& port_out, std::string& target_out)
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
} }
// ============================================================================ // ============================================================================
// 构建 Anthropic headers JSON // 构建 Anthropic headers JSON / Build Anthropic headers JSON
// ============================================================================ // ============================================================================
// 构建包含 x-api-key 和 anthropic-version 的 JSON headers 对象 / Build the JSON headers object containing x-api-key and anthropic-version.
static std::string build_headers_json() static std::string build_headers_json()
{ {
json::object h; json::object h;
@@ -76,8 +86,11 @@ static std::string build_headers_json()
} }
// ============================================================================ // ============================================================================
// 构建 Anthropic JSON 请求体 // 构建 Anthropic JSON 请求体 / Build Anthropic JSON request body
// ============================================================================ // ============================================================================
// 构建 Anthropic Messages API 的完整 JSON 请求体。
// 按 Anthropic 规范将 system 消息提取为顶层 system 字段 / Build the full JSON request body for the Anthropic Messages API.
// Extracts system messages as a top-level "system" field per Anthropic spec.
static std::string build_request_json( static std::string build_request_json(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const std::string& user_input, const std::string& user_input,
@@ -89,7 +102,7 @@ static std::string build_request_json(
root["max_tokens"] = g_cfg.max_tokens; root["max_tokens"] = g_cfg.max_tokens;
root["stream"] = stream; root["stream"] = stream;
// 提取 system 消息作为顶层字段 // 提取 system 消息作为顶层字段 / Extract system messages as top-level field
std::string system_prompt; std::string system_prompt;
json::array msgs; json::array msgs;
@@ -106,7 +119,7 @@ static std::string build_request_json(
msgs.push_back(obj); msgs.push_back(obj);
} }
// 追加当前用户输入 // 追加当前用户输入 / Append current user input
{ {
json::object obj; json::object obj;
obj["role"] = "user"; obj["role"] = "user";
@@ -124,7 +137,7 @@ static std::string build_request_json(
root["temperature"] = g_cfg.temperature; root["temperature"] = g_cfg.temperature;
} }
// W21.2: tools 定义传递给 API // W21.2: tools 定义传递给 API / Pass tools definition to API
if (!tools_json.empty()) { if (!tools_json.empty()) {
root["tools"] = json::parse(tools_json); root["tools"] = json::parse(tools_json);
} }
@@ -133,8 +146,11 @@ static std::string build_request_json(
} }
// ============================================================================ // ============================================================================
// 解析非流式响应 // 解析非流式响应 / Parse non-streaming response
// ============================================================================ // ============================================================================
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t。
// 处理 text 和 tool_use content block将 tool_use 转换为 OpenAI 格式 / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
// Handles text and tool_use content blocks, converting tool_use to OpenAI format.
static void parse_response(const char* body, int http_status, static void parse_response(const char* body, int http_status,
dstalk_chat_result_t& r) dstalk_chat_result_t& r)
{ {
@@ -169,7 +185,7 @@ static void parse_response(const char* body, int http_status,
auto obj = jv.as_object(); auto obj = jv.as_object();
auto content = obj["content"].as_array(); auto content = obj["content"].as_array();
if (!content.empty()) { if (!content.empty()) {
// W21.2: 提取 text 和 tool_use content blocks // W21.2: 提取 text 和 tool_use content blocks / Extract text and tool_use content blocks
std::string text_content; std::string text_content;
json::array tool_use_blocks; json::array tool_use_blocks;
@@ -181,7 +197,7 @@ static void parse_response(const char* body, int http_status,
if (btype == "text") { if (btype == "text") {
text_content = json::value_to<std::string>(bobj["text"]); text_content = json::value_to<std::string>(bobj["text"]);
} else if (btype == "tool_use") { } else if (btype == "tool_use") {
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}} // 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}} / Convert to OpenAI-compatible format: {id, type:"function", function:{name, arguments}}
json::object tc; json::object tc;
tc["id"] = bobj["id"]; tc["id"] = bobj["id"];
tc["type"] = "function"; tc["type"] = "function";
@@ -206,7 +222,7 @@ static void parse_response(const char* body, int http_status,
r.error = nullptr; r.error = nullptr;
return; return;
} else if (!tool_use_blocks.empty()) { } else if (!tool_use_blocks.empty()) {
// tool-only 响应 // tool-only 响应 / tool-only response
r.content = nullptr; r.content = nullptr;
r.ok = 1; r.ok = 1;
r.error = nullptr; r.error = nullptr;
@@ -235,15 +251,15 @@ static void parse_response(const char* body, int http_status,
} }
// ============================================================================ // ============================================================================
// SSE 事件解析Anthropic 格式: event/content_block_delta) // SSE 事件解析Anthropic 格式: event/content_block_delta) / SSE event parsing (Anthropic format: event/content_block_delta)
// ============================================================================ // ============================================================================
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量 // W21.2: 按 content_block index 累积 Anthropic tool_use 增量 / Accumulate Anthropic tool_use increments by content_block index
struct ToolCallAccum { struct ToolCallAccum {
int index = -1; int index = -1;
std::string id; std::string id;
std::string name; std::string name;
std::string arguments; // 从 input_json_delta.partial_json 累积 std::string arguments; // 从 input_json_delta.partial_json 累积 / accumulated from input_json_delta.partial_json
}; };
struct StreamContext { struct StreamContext {
@@ -252,10 +268,15 @@ struct StreamContext {
void* userdata; void* userdata;
std::string accumulated; std::string accumulated;
bool saw_data_line = false; bool saw_data_line = false;
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks / accumulate tool_use content blocks by index
}; };
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 // W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 / Parse Anthropic SSE events with tool_use content_block incremental parsing
// 解析单个 Anthropic SSE "data:" JSON 事件。处理 content_block_start、
// content_block_delta (text_delta/input_json_delta) 和 message_stop。
// 如果产生了 content token 则返回 true否则返回 false / Parse a single Anthropic SSE "data:" JSON event. Handles content_block_start,
// content_block_delta (text_delta/input_json_delta), and message_stop.
// Returns true if a content token was produced, false otherwise.
static bool parse_sse_data(const std::string& data, std::string& token_out, static bool parse_sse_data(const std::string& data, std::string& token_out,
StreamContext* ctx) StreamContext* ctx)
{ {
@@ -268,7 +289,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
std::string type = json::value_to<std::string>(*type_ptr); std::string type = json::value_to<std::string>(*type_ptr);
if (type == "content_block_start") { if (type == "content_block_start") {
// content_block_start 可能为 tool_use // content_block_start 可能为 tool_use / content_block_start may be tool_use
auto* cb = obj.if_contains("content_block"); auto* cb = obj.if_contains("content_block");
if (!cb || !cb->is_object()) return false; if (!cb || !cb->is_object()) return false;
auto& cb_obj = cb->as_object(); auto& cb_obj = cb->as_object();
@@ -311,7 +332,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
return true; return true;
} }
} else if (delta_type == "input_json_delta" && ctx) { } else if (delta_type == "input_json_delta" && ctx) {
// W21.2: 累积 tool_use arguments 分片 // W21.2: 累积 tool_use arguments 分片 / Accumulate tool_use arguments fragments
auto* pj = dobj.if_contains("partial_json"); auto* pj = dobj.if_contains("partial_json");
if (pj && pj->is_string()) { if (pj && pj->is_string()) {
auto* idx_ptr = obj.if_contains("index"); auto* idx_ptr = obj.if_contains("index");
@@ -326,18 +347,19 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
} }
} else if (type == "message_stop") { } else if (type == "message_stop") {
token_out.clear(); token_out.clear();
return true; // 流结束 return true; // 流结束 / stream end
} }
// 忽略: message_start, content_block_stop, ping, message_delta // 忽略: message_start, content_block_stop, ping, message_delta / Ignore: message_start, content_block_stop, ping, message_delta
} catch (...) { } catch (...) {
// 解析失败忽略 // 解析失败忽略 / Ignore parse failures
} }
return false; return false;
} }
// ============================================================================ // ============================================================================
// configure // configure / configure
// ============================================================================ // ============================================================================
// 配置插件provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
static int my_configure(const char* provider, const char* base_url, static int my_configure(const char* provider, const char* base_url,
const char* api_key, const char* model, const char* api_key, const char* model,
int max_tokens, double temperature) int max_tokens, double temperature)
@@ -352,7 +374,7 @@ static int my_configure(const char* provider, const char* base_url,
const auto* h = g_host.load(std::memory_order_acquire); const auto* h = g_host.load(std::memory_order_acquire);
if (h) { if (h) {
// W21.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 // W21.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>( auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
h->query_service("tools", 1)); h->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) { if (tools_svc && tools_svc->get_tools_json) {
@@ -381,8 +403,9 @@ static int my_configure(const char* provider, const char* base_url,
} }
// ============================================================================ // ============================================================================
// chat // chat / chat
// ============================================================================ // ============================================================================
// 非流式 chat completion发送 history + user input返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
static dstalk_chat_result_t my_chat( static dstalk_chat_result_t my_chat(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
@@ -447,26 +470,27 @@ static dstalk_chat_result_t my_chat(
} }
// ============================================================================ // ============================================================================
// chat_stream // chat_stream / chat_stream
// ============================================================================ // ============================================================================
// 行回调 // 行回调 / SSE line callback
// SSE 行回调:解析每个 Anthropic SSE 行并将文本 token 转发给用户 / SSE line callback: parses each Anthropic SSE line and forwards text tokens to user.
static int sse_line_callback(const char* line, void* userdata) static int sse_line_callback(const char* line, void* userdata)
{ {
try { try {
auto* ctx = static_cast<StreamContext*>(userdata); auto* ctx = static_cast<StreamContext*>(userdata);
if (!line || !line[0]) return 1; // 空行,继续 if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
std::string line_str(line); std::string line_str(line);
// SSE 格式: "data: <json>" // SSE 格式: "data: <json>" / SSE format: "data: <json>"
if (line_str.rfind("data: ", 0) == 0) { if (line_str.rfind("data: ", 0) == 0) {
std::string data = line_str.substr(6); std::string data = line_str.substr(6);
std::string token; std::string token;
if (parse_sse_data(data, token, ctx)) { if (parse_sse_data(data, token, ctx)) {
ctx->saw_data_line = true; ctx->saw_data_line = true;
if (token.empty()) { if (token.empty()) {
// message_stop // message_stop / message_stop
return 0; return 0;
} }
ctx->accumulated += token; ctx->accumulated += token;
@@ -475,7 +499,7 @@ static int sse_line_callback(const char* line, void* userdata)
} }
} }
} }
// "event: ..." 行和其他 -> 忽略 // "event: ..." 行和其他 -> 忽略 / "event: ..." lines and others -> ignored
return 1; return 1;
} catch (const std::exception& e) { } catch (const std::exception& e) {
const auto* h = g_host.load(std::memory_order_acquire); const auto* h = g_host.load(std::memory_order_acquire);
@@ -488,6 +512,9 @@ static int sse_line_callback(const char* line, void* userdata)
} }
} }
// 流式 chat completion以 stream=true 发送 history + user input通过回调传递 token。
// 累积 tool_use blocks 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
// via callback. Accumulates tool_use blocks and serializes them at end.
static dstalk_chat_result_t my_chat_stream( static dstalk_chat_result_t my_chat_stream(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
@@ -531,7 +558,7 @@ static dstalk_chat_result_t my_chat_stream(
r.http_status = status_code; r.http_status = status_code;
// 检查错误状态 // 检查错误状态 / Check error status
if (status_code < 200 || status_code >= 300) { if (status_code < 200 || status_code >= 300) {
r.ok = 0; r.ok = 0;
if (response_body && response_body[0]) { if (response_body && response_body[0]) {
@@ -560,7 +587,7 @@ static dstalk_chat_result_t my_chat_stream(
if (response_body) host->free(response_body); if (response_body) host->free(response_body);
// W21.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling // W21.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling / Success = has content OR has tool_calls (tool-only responses like function calling)
bool has_content = !ctx.accumulated.empty(); bool has_content = !ctx.accumulated.empty();
bool has_tool_calls = !ctx.tool_calls.empty(); bool has_tool_calls = !ctx.tool_calls.empty();
@@ -575,7 +602,7 @@ static dstalk_chat_result_t my_chat_stream(
r.content = has_content r.content = has_content
? host->strdup(ctx.accumulated.c_str()) : nullptr; ? host->strdup(ctx.accumulated.c_str()) : nullptr;
// W21.2: 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) // W21.2: 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible format)
if (has_tool_calls) { if (has_tool_calls) {
json::array tc_array; json::array tc_array;
for (auto& tc : ctx.tool_calls) { for (auto& tc : ctx.tool_calls) {
@@ -614,8 +641,9 @@ static dstalk_chat_result_t my_chat_stream(
} }
// ============================================================================ // ============================================================================
// free_result // free_result / free_result
// ============================================================================ // ============================================================================
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
static void my_free_result(dstalk_chat_result_t* result) static void my_free_result(dstalk_chat_result_t* result)
{ {
const auto* h = g_host.load(std::memory_order_acquire); const auto* h = g_host.load(std::memory_order_acquire);
@@ -626,7 +654,7 @@ static void my_free_result(dstalk_chat_result_t* result)
} }
// ============================================================================ // ============================================================================
// 服务 vtable // 服务 vtable / Service vtable
// ============================================================================ // ============================================================================
static dstalk_ai_service_t g_service = { static dstalk_ai_service_t g_service = {
&my_configure, &my_configure,
@@ -636,8 +664,9 @@ static dstalk_ai_service_t g_service = {
}; };
// ============================================================================ // ============================================================================
// 生命周期 // 生命周期 / Lifecycle
// ============================================================================ // ============================================================================
// 插件初始化:查询 http 和 config 服务,注册 ai.anthropic 服务 / Plugin init: query http and config services, register ai.anthropic service.
static int on_init(const dstalk_host_api_t* host) static int on_init(const dstalk_host_api_t* host)
{ {
try { try {
@@ -666,6 +695,7 @@ static int on_init(const dstalk_host_api_t* host)
} }
} }
// 插件关闭:从内存安全擦除 API key清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
static void on_shutdown() static void on_shutdown()
{ {
try { try {
@@ -686,12 +716,12 @@ static void on_shutdown()
} }
// ============================================================================ // ============================================================================
// 插件描述符 // 插件描述符 / Plugin descriptor
// ============================================================================ // ============================================================================
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
/* .name = */ "anthropic-ai", /* .name = */ "anthropic-ai",
/* .version = */ "1.0.0", /* .version = */ "1.0.0",
/* .description = */ "Anthropic Claude AI provider (Messages API)", /* .description = */ "Anthropic Claude AI provider (Messages API) / Anthropic Claude AI 提供者 (Messages API)",
/* .api_version = */ DSTALK_API_VERSION, /* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { "http", "config", NULL }, /* .dependencies = */ { "http", "config", NULL },
/* .on_init = */ on_init, /* .on_init = */ on_init,
@@ -699,6 +729,7 @@ static dstalk_plugin_info_t g_info = {
/* .on_event = */ nullptr, /* .on_event = */ nullptr,
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{ {
return &g_info; return &g_info;

View File

@@ -1,16 +1,24 @@
/*
* @file toml_parse.h
* @brief Lightweight single-header TOML parser (subset: flat key-value pairs).
* 轻量级单头文件 TOML 解析器(子集:扁平键值对)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once #pragma once
// Shared TOML parser — used by both ConfigStore (core) and config plugin. // 共享 TOML 解析器 —— 由 ConfigStore核心和 config 插件共同使用 / Shared TOML parser — used by both ConfigStore (core) and config plugin.
// W12.2: Extracted from config_store.cpp:23-61 and config_plugin.cpp:28-66 // W12.2: Extracted from config_store.cpp:23-61 and config_plugin.cpp:28-66
// to eliminate the 74-line code duplication (W11.2 audit Finding 1). // to eliminate the 74-line code duplication (W11.2 audit Finding 1).
// Does NOT support: inline tables, arrays, multi-line strings, escape sequences. // Does NOT support: inline tables, arrays, multi-line strings, escape sequences.
// 不支持:内联表、数组、多行字符串、转义序列。
#include <string> #include <string>
namespace dstalk { namespace dstalk {
namespace toml { namespace toml {
/// Parse a TOML string, calling on_kv(full_key, value) for each key-value pair. /// 解析 TOML 字符串,对每个键值对调用 on_kv(full_key, value) / Parse a TOML string, calling on_kv(full_key, value) for each key-value pair.
/// Supports [section] headers, key = "value" pairs, # comments, blank lines. /// 支持 [section] 标题、key = "value" 键值对、# 注释、空行 / Supports [section] headers, key = "value" pairs, # comments, blank lines.
template<typename F> template<typename F>
inline void parse(const std::string& content, F&& on_kv) inline void parse(const std::string& content, F&& on_kv)
{ {
@@ -18,31 +26,31 @@ inline void parse(const std::string& content, F&& on_kv)
size_t pos = 0; size_t pos = 0;
while (pos < content.size()) { while (pos < content.size()) {
// Trim left whitespace // 去除左侧空白 / Trim left whitespace
while (pos < content.size() && (content[pos] == ' ' || content[pos] == '\t')) while (pos < content.size() && (content[pos] == ' ' || content[pos] == '\t'))
pos++; pos++;
if (pos >= content.size()) break; if (pos >= content.size()) break;
// Extract next line // 提取下一行 / Extract next line
size_t nl = content.find('\n', pos); size_t nl = content.find('\n', pos);
std::string line = (nl != std::string::npos) std::string line = (nl != std::string::npos)
? content.substr(pos, nl - pos) : content.substr(pos); ? content.substr(pos, nl - pos) : content.substr(pos);
pos = (nl != std::string::npos) ? nl + 1 : content.size(); pos = (nl != std::string::npos) ? nl + 1 : content.size();
// Trim right whitespace (including \r) // 去除右侧空白(包括 \r / Trim right whitespace (including \r)
while (!line.empty() && (line.back() == '\r' || line.back() == ' ')) while (!line.empty() && (line.back() == '\r' || line.back() == ' '))
line.pop_back(); line.pop_back();
// Skip empty lines and comments // 跳过空行和注释 / Skip empty lines and comments
if (line.empty() || line[0] == '#') continue; if (line.empty() || line[0] == '#') continue;
// Section header: [section_name] // 节标题: [section_name] / Section header: [section_name]
if (line[0] == '[' && line.back() == ']') { if (line[0] == '[' && line.back() == ']') {
current_section = line.substr(1, line.size() - 2); current_section = line.substr(1, line.size() - 2);
continue; continue;
} }
// Key = value // 键 = 值 / Key = value
size_t eq = line.find('='); size_t eq = line.find('=');
if (eq == std::string::npos) continue; if (eq == std::string::npos) continue;

View File

@@ -1,3 +1,10 @@
/*
* @file config_plugin.cpp
* @brief Config plugin: TOML file parsing and key-value configuration service.
* 配置插件TOML 文件解析和键值配置服务。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
#include "../include/toml_parse.h" #include "../include/toml_parse.h"
@@ -7,12 +14,12 @@
#include <sstream> #include <sstream>
// ============================================================ // ============================================================
// Global state // 全局状态 / Global state
// ============================================================ // ============================================================
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
// ============================================================ // ============================================================
// Service implementations // 服务实现 / Service implementations
// //
// W12.2: Eliminated private ConfigStore (was 90 lines duplicating core). // W12.2: Eliminated private ConfigStore (was 90 lines duplicating core).
// All get/set/load_file now delegate to the host store via g_host->config_get // All get/set/load_file now delegate to the host store via g_host->config_get
@@ -20,16 +27,19 @@ static const dstalk_host_api_t* g_host = nullptr;
// TOML parsing uses the shared dstalk::toml::parse() from toml_parse.h. // TOML parsing uses the shared dstalk::toml::parse() from toml_parse.h.
// ============================================================ // ============================================================
// 从主机存储中按 key 获取配置值 / Retrieve a configuration value by key from the host store.
static const char* config_get(const char* key) { static const char* config_get(const char* key) {
if (!g_host) return nullptr; if (!g_host) return nullptr;
return g_host->config_get(key); return g_host->config_get(key);
} }
// 将键值对存入主机存储 / Store a configuration key-value pair into the host store.
static int config_set(const char* key, const char* value) { static int config_set(const char* key, const char* value) {
if (!g_host) return -1; if (!g_host) return -1;
return g_host->config_set(key, value); return g_host->config_set(key, value);
} }
// 解析指定路径的 TOML 文件,将所有键值对加载到主机存储中 / Parse a TOML file at `path` and load all key-value pairs into the host store.
static int config_load_file(const char* path) { static int config_load_file(const char* path) {
if (!g_host || !path) return -1; if (!g_host || !path) return -1;
@@ -58,12 +68,13 @@ static dstalk_config_service_t g_service = {
}; };
// ============================================================ // ============================================================
// Plugin lifecycle // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// 插件初始化:保存主机指针并注册 config 服务 vtable / Plugin init: store host pointer and register the config service vtable.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
g_host = host; g_host = host;
// W12.2: This service is now a thin wrapper around host->config_get/set. // W12.2: 该服务现为 host->config_get/set 的薄封装,建议直接调用主机 API / This service is now a thin wrapper around host->config_get/set.
// Direct host API calls are preferred. // Direct host API calls are preferred.
host->log(DSTALK_LOG_INFO, host->log(DSTALK_LOG_INFO,
"plugin config service is deprecated, prefer host->config_get/set"); "plugin config service is deprecated, prefer host->config_get/set");
@@ -76,8 +87,10 @@ static int on_init(const dstalk_host_api_t* host) {
return (rc >= 0) ? 0 : -1; return (rc >= 0) ? 0 : -1;
} }
// 插件关闭:无需清理本地存储(所有数据在主机存储中) / Plugin shutdown: no local store to clean up (all data lives in host store).
static void on_shutdown() { static void on_shutdown() {
// W12.2: No local store to clean up — all data lives in host store. // W12.2: No local store to clean up — all data lives in host store.
// 无需清理本地存储——所有数据位于主机存储中。
} }
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
@@ -91,6 +104,7 @@ static dstalk_plugin_info_t g_info = {
nullptr // on_event nullptr // on_event
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,6 +1,13 @@
// plugin-context: 上下文管理服务插件 /*
// 提供 dstalk_context_service_t vtable 实现 * @file context_plugin.cpp
// 依赖: session (获取历史消息做 token 计数) * @brief Context plugin: token counting and context window trimming.
* 上下文插件token 计数和上下文窗口裁剪。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-context: 上下文管理服务插件 / Context management service plugin
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h" #include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -15,21 +22,26 @@
#include <vector> #include <vector>
// ============================================================ // ============================================================
// 全局状态 // 全局状态 / Global state
// ============================================================ // ============================================================
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
static const dstalk_session_service_t* g_session = nullptr; static const dstalk_session_service_t* g_session = nullptr;
// ============================================================ // ============================================================
// 内部 C++ 辅助:共享 UTF-8 token 计数 // 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5) // W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6) // 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6)
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
// ============================================================ // ============================================================
// 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。 // 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。
// overhead: 每条消息的固定开销 tokenrole + separators = 4 // overhead: 每条消息的固定开销 tokenrole + separators = 4
// 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。 // 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。
// Count estimated tokens for UTF-8 byte sequence [text, text+len).
// overhead: fixed token overhead per message (role + separators = 4).
// Multi-byte sequences fall back to single-byte other_chars counting when out-of-bounds or invalid continuation bytes.
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) { static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
if (!text || len == 0) return overhead; if (!text || len == 0) return overhead;
@@ -42,12 +54,12 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
unsigned char c = static_cast<unsigned char>(text[i]); unsigned char c = static_cast<unsigned char>(text[i]);
if (c < 0x80) { if (c < 0x80) {
// ASCII // ASCII / ASCII
ascii_chars++; ascii_chars++;
i += 1; i += 1;
} else if (c >= 0xE4 && c <= 0xE9) { } else if (c >= 0xE4 && c <= 0xE9) {
// CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9 // CJK 统一表意文字 (U+4E00-U+9FFF): 3 字节 UTF-8 0xE4-0xE9 / CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 // W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 / Check if subsequent 2 bytes are in valid range
if (i + 2 >= len || if (i + 2 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 || (static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) { (static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
@@ -58,8 +70,8 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
i += 3; i += 3;
} }
} else if (c >= 0xC2 && c < 0xE0) { } else if (c >= 0xC2 && c < 0xE0) {
// 2-byte sequence (valid range 0xC2-0xDF) // 2 字节序列 (有效范围 0xC2-0xDF) / 2-byte sequence (valid range 0xC2-0xDF)
// W18.1 (F-11.1-4): 检查后续 1 字节 // W18.1 (F-11.1-4): 检查后续 1 字节 / Check subsequent 1 byte
if (i + 1 >= len || if (i + 1 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80) { (static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80) {
other_chars++; other_chars++;
@@ -69,13 +81,13 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
i += 2; i += 2;
} }
} else if (c == 0xC0 || c == 0xC1) { } else if (c == 0xC0 || c == 0xC1) {
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 // W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 / Overlong encoding, invalid UTF-8 start byte
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars // 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars / 0xC0/0xC1 never appear in valid UTF-8; counted as single-byte in other_chars
other_chars++; other_chars++;
i += 1; i += 1;
} else if (c >= 0xE0 && c < 0xF0) { } else if (c >= 0xE0 && c < 0xF0) {
// Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF) // 非 CJK 3 字节序列 (0xE0-0xE3, 0xEA-0xEF) / Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
// CJK 范围 0xE4-0xE9 已在上方分支处理 // CJK 范围 0xE4-0xE9 已在上方分支处理 / CJK range 0xE4-0xE9 handled in branch above
if (i + 2 >= len || if (i + 2 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 || (static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) { (static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
@@ -86,7 +98,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
i += 3; i += 3;
} }
} else if (c >= 0xF0 && c < 0xF8) { } else if (c >= 0xF0 && c < 0xF8) {
// 4-byte sequence // 4 字节序列 / 4-byte sequence
if (i + 3 >= len || if (i + 3 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 || (static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80 || (static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80 ||
@@ -98,7 +110,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
i += 4; i += 4;
} }
} else { } else {
// Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF) // 续字节 (0x80-0xBF) 和其他无效起始字节 (0xF8-0xFF) / Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
other_chars++; other_chars++;
i += 1; i += 1;
} }
@@ -108,15 +120,17 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
} }
// ============================================================ // ============================================================
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) // 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl)
// ============================================================ // ============================================================
// 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct.
static size_t count_tokens_one_message(const dstalk_message_t& msg) { static size_t count_tokens_one_message(const dstalk_message_t& msg) {
const char* text = msg.content; const char* text = msg.content;
if (!text) return 4; // 只有 overhead if (!text) return 4; // 只有 overhead / overhead only
return count_tokens_utf8(text, std::strlen(text), 4); return count_tokens_utf8(text, std::strlen(text), 4);
} }
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) { static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
size_t total = 0; size_t total = 0;
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
@@ -126,10 +140,10 @@ static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
} }
// ============================================================ // ============================================================
// 内部 trim 逻辑 // 内部 trim 逻辑 / Internal trim logic
// ============================================================ // ============================================================
// 为 trim 操作将 C 消息数组复制到内部 struct // 为 trim 操作将 C 消息数组复制到内部 struct / Copy C message array to internal struct for trim operation
struct TrimMessage { struct TrimMessage {
std::string role; std::string role;
std::string content; std::string content;
@@ -148,7 +162,7 @@ static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
return total; return total;
} }
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) // 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
static void free_msg_strs(dstalk_message_t* msg) { static void free_msg_strs(dstalk_message_t* msg) {
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; } if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; } if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
@@ -158,6 +172,8 @@ static void free_msg_strs(dstalk_message_t* msg) {
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。 // 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
// 成功返回 0OOM 时释放当前消息已分配字段并返回 -1。 // 成功返回 0OOM 时释放当前消息已分配字段并返回 -1。
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
// On OOM, frees already-allocated fields and returns -1.
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) { static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
memset(dst, 0, sizeof(dstalk_message_t)); memset(dst, 0, sizeof(dstalk_message_t));
@@ -184,7 +200,10 @@ oom:
return -1; return -1;
} }
// W12.1 修复trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) // W12.1 修复trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
static int trim_impl(const dstalk_message_t* in, int in_count, static int trim_impl(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count, dstalk_message_t** out, int* out_count,
size_t max_tokens) { size_t max_tokens) {
@@ -192,10 +211,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
if (!in || in_count <= 0 || !out || !out_count) return -1; if (!in || in_count <= 0 || !out || !out_count) return -1;
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens // W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens
// 传 0 时使用硬编码默认值 4096 // 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
// when 0 is passed, use hardcoded default 4096.
if (max_tokens == 0) max_tokens = 4096; if (max_tokens == 0) max_tokens = 4096;
// 将 C 数组转换为内部 vector // 将 C 数组转换为内部 vector / Convert C array to internal vector
std::vector<TrimMessage> messages; std::vector<TrimMessage> messages;
messages.reserve(in_count); messages.reserve(in_count);
for (int i = 0; i < in_count; ++i) { for (int i = 0; i < in_count; ++i) {
@@ -207,13 +227,13 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
messages.push_back(std::move(tm)); messages.push_back(std::move(tm));
} }
// 如果已在限制内,直接返回完整副本 // 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
size_t current = count_tokens_trim_vec(messages); size_t current = count_tokens_trim_vec(messages);
if (current <= max_tokens) { if (current <= max_tokens) {
*out_count = in_count; *out_count = in_count;
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count)); *out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
if (!*out) return -1; if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 // W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
for (int i = 0; i < in_count; ++i) { for (int i = 0; i < in_count; ++i) {
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) { if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]); for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
@@ -225,7 +245,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
return 0; return 0;
} }
// 分离 system 消息和非 system 消息 // 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
std::vector<TrimMessage> system_msgs; std::vector<TrimMessage> system_msgs;
std::vector<TrimMessage> non_system_msgs; std::vector<TrimMessage> non_system_msgs;
for (const auto& msg : messages) { for (const auto& msg : messages) {
@@ -243,7 +263,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
system_tokens, max_tokens); system_tokens, max_tokens);
} }
// 检查是否有单条消息超过限制 // 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
for (const auto& msg : non_system_msgs) { for (const auto& msg : non_system_msgs) {
size_t msg_tokens = count_tokens_trim(msg); size_t msg_tokens = count_tokens_trim(msg);
if (msg_tokens > max_tokens) { if (msg_tokens > max_tokens) {
@@ -257,19 +277,19 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
} }
} }
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 // 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
while (!non_system_msgs.empty()) { while (!non_system_msgs.empty()) {
current = system_tokens + count_tokens_trim_vec(non_system_msgs); current = system_tokens + count_tokens_trim_vec(non_system_msgs);
if (current <= max_tokens) break; if (current <= max_tokens) break;
// 找第一个 "user" 消息 // 找第一个 "user" 消息 / Find first "user" message
auto user_it = non_system_msgs.begin(); auto user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") { while (user_it != non_system_msgs.end() && user_it->role != "user") {
++user_it; ++user_it;
} }
if (user_it == non_system_msgs.end()) break; if (user_it == non_system_msgs.end()) break;
// 找下一个 "assistant" // 找下一个 "assistant" / Find next "assistant"
auto assistant_it = user_it + 1; auto assistant_it = user_it + 1;
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") { while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
++assistant_it; ++assistant_it;
@@ -278,7 +298,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
if (assistant_it == non_system_msgs.end()) { if (assistant_it == non_system_msgs.end()) {
non_system_msgs.erase(user_it); non_system_msgs.erase(user_it);
} else { } else {
// 先删 assistant 再删 user 避免迭代器失效 // 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
non_system_msgs.erase(assistant_it); non_system_msgs.erase(assistant_it);
user_it = non_system_msgs.begin(); user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it; while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
@@ -286,7 +306,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
} }
} }
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token使用当前 max_tokens // W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
{ {
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100) size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
if (max_msg_count < 1) max_msg_count = 1; if (max_msg_count < 1) max_msg_count = 1;
@@ -295,7 +315,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
} }
} }
// 组装结果 // 组装结果 / Assemble result
std::vector<TrimMessage> result; std::vector<TrimMessage> result;
result.reserve(system_msgs.size() + non_system_msgs.size()); result.reserve(system_msgs.size() + non_system_msgs.size());
result.insert(result.end(), system_msgs.begin(), system_msgs.end()); result.insert(result.end(), system_msgs.begin(), system_msgs.end());
@@ -306,7 +326,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count)); *out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
if (!*out) return -1; if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 // W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
for (int i = 0; i < result_count; ++i) { for (int i = 0; i < result_count; ++i) {
if (strdup_message_fields(&(*out)[i], result[i]) != 0) { if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]); for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
@@ -318,7 +338,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
return 0; return 0;
} catch (const std::exception& e) { } catch (const std::exception& e) {
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() // W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what()); if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
return -1; return -1;
} catch (...) { } catch (...) {
@@ -328,10 +348,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
} }
// ============================================================ // ============================================================
// Context 服务 vtable 实现 // Context 服务 vtable 实现 / Context service vtable implementation
// ============================================================ // ============================================================
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate()
// 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input.
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) { static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
try { try {
if (!msgs || count <= 0) return 0; if (!msgs || count <= 0) return 0;
@@ -341,7 +362,8 @@ static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
} }
} }
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
// 裁剪消息列表以适应 max_tokens返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array.
static int context_trim(const dstalk_message_t* in, int in_count, static int context_trim(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count, dstalk_message_t** out, int* out_count,
size_t max_tokens) { size_t max_tokens) {
@@ -355,21 +377,24 @@ static int context_trim(const dstalk_message_t* in, int in_count,
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。 // W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。
// max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入; // max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入;
// 传 0 时 trim_impl 使用硬编码默认值 4096。 // 传 0 时 trim_impl 使用硬编码默认值 4096。
// g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly
// by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0.
static dstalk_context_service_t g_context_service = { static dstalk_context_service_t g_context_service = {
context_count_tokens, context_count_tokens,
context_trim context_trim
}; };
// ============================================================ // ============================================================
// 插件生命周期 // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 // W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
// 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
try { try {
g_host = host; g_host = host;
// 查询依赖服务: session // 查询依赖服务: session / Query dependency service: session
void* raw = host->query_service("session", 1); void* raw = host->query_service("session", 1);
if (!raw) { if (!raw) {
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found"); host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
@@ -387,7 +412,8 @@ static int on_init(const dstalk_host_api_t* host) {
} }
} }
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log // W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs
// 插件关闭清空指针。try/catch 保护 ABIvoid 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function).
static void on_shutdown() { static void on_shutdown() {
try { try {
g_session = nullptr; g_session = nullptr;
@@ -406,7 +432,7 @@ static void on_shutdown() {
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
"context", "context",
"1.0.0", "1.0.0",
"Context management plugin with token counting and trim support", "Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
DSTALK_API_VERSION, DSTALK_API_VERSION,
{"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, {"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
on_init, on_init,
@@ -414,6 +440,7 @@ static dstalk_plugin_info_t g_info = {
nullptr nullptr
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,3 +1,10 @@
/*
* @file deepseek_plugin.cpp
* @brief DeepSeek/OpenAI-compatible AI provider plugin with SSE streaming and tool calls.
* DeepSeek/OpenAI 兼容 AI 提供者插件,支持 SSE 流式输出和工具调用。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -11,14 +18,14 @@
namespace json = boost::json; namespace json = boost::json;
// ============================================================================ // ============================================================================
// 全局指针:从 on_init 获取W14.3: atomic acquire/release 保护读写竞态) // 全局指针:从 on_init 获取W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
// ============================================================================ // ============================================================================
static std::atomic<const dstalk_host_api_t*> g_host{nullptr}; static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<dstalk_http_service_t*> g_http{nullptr}; static std::atomic<dstalk_http_service_t*> g_http{nullptr};
static std::atomic<dstalk_config_service_t*> g_config{nullptr}; static std::atomic<dstalk_config_service_t*> g_config{nullptr};
// ============================================================================ // ============================================================================
// 配置数据(由 configure() 设置) // 配置数据(由 configure() 设置) / Config data (set by configure())
// ============================================================================ // ============================================================================
struct PluginConfig { struct PluginConfig {
std::string provider; std::string provider;
@@ -29,19 +36,21 @@ struct PluginConfig {
double temperature = 0.7; double temperature = 0.7;
}; };
static PluginConfig g_cfg; static PluginConfig g_cfg;
static std::string g_tools_json; // W20.2: cached by configure(), consumed by chat/chat_stream static std::string g_tools_json; // W20.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
// ============================================================================ // ============================================================================
// 安全擦除:用 volatile 写零循环防止编译器优化 // 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
// ============================================================================ // ============================================================================
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
static void secure_zero(void* p, size_t n) { static void secure_zero(void* p, size_t n) {
volatile char* vp = (volatile char*)p; volatile char* vp = (volatile char*)p;
while (n--) *vp++ = 0; while (n--) *vp++ = 0;
} }
// ============================================================================ // ============================================================================
// 辅助:从 base_url 提取 host 和 target // 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
// ============================================================================ // ============================================================================
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
static bool extract_host_port(const std::string& url, static bool extract_host_port(const std::string& url,
std::string& scheme_out, std::string& host_out, std::string& scheme_out, std::string& host_out,
std::string& port_out, std::string& target_out) std::string& port_out, std::string& target_out)
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
} }
// ============================================================================ // ============================================================================
// 辅助:构建 headers JSON 字符串 // 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
// ============================================================================ // ============================================================================
// 构建包含 Bearer 授权令牌的 JSON headers 对象 / Build the JSON headers object containing the Bearer authorization token.
static std::string build_headers_json(const std::string& auth_header_value) static std::string build_headers_json(const std::string& auth_header_value)
{ {
json::object h; json::object h;
@@ -75,8 +85,9 @@ static std::string build_headers_json(const std::string& auth_header_value)
} }
// ============================================================================ // ============================================================================
// 辅助dstalk_message_t[] -> boost::json::array // 辅助dstalk_message_t[] -> boost::json::array / Helper: dstalk_message_t[] -> boost::json::array
// ============================================================================ // ============================================================================
// 将 dstalk_message_t 数组转换为 Boost.JSON 数组,用于 API 请求体 / Convert dstalk_message_t array into a Boost.JSON array for the API request body.
static void append_history(json::array& msgs, static void append_history(json::array& msgs,
const dstalk_message_t* history, int history_len) const dstalk_message_t* history, int history_len)
{ {
@@ -100,8 +111,9 @@ static void append_history(json::array& msgs,
} }
// ============================================================================ // ============================================================================
// 构建 DeepSeek JSON 请求体 // 构建 DeepSeek JSON 请求体 / Build DeepSeek JSON request body
// ============================================================================ // ============================================================================
// 构建 DeepSeek/OpenAI chat completions API 的完整 JSON 请求体 / Build the full JSON request body for the DeepSeek/OpenAI chat completions API.
static std::string build_request_json( static std::string build_request_json(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const std::string& user_input, const std::string& user_input,
@@ -117,7 +129,7 @@ static std::string build_request_json(
json::array msgs; json::array msgs;
append_history(msgs, history, history_len); append_history(msgs, history, history_len);
// 追加当前用户输入 // 追加当前用户输入 / Append current user input
if (!user_input.empty()) { if (!user_input.empty()) {
json::object obj; json::object obj;
obj["role"] = "user"; obj["role"] = "user";
@@ -127,7 +139,7 @@ static std::string build_request_json(
root["messages"] = msgs; root["messages"] = msgs;
// tools 定义 // tools 定义 / tools definition
if (!tools_json.empty()) { if (!tools_json.empty()) {
root["tools"] = json::parse(tools_json); root["tools"] = json::parse(tools_json);
} }
@@ -136,8 +148,9 @@ static std::string build_request_json(
} }
// ============================================================================ // ============================================================================
// 解析非流式 JSON 响应 // 解析非流式 JSON 响应 / Parse non-streaming JSON response
// ============================================================================ // ============================================================================
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
static void parse_response(const dstalk_host_api_t* host, static void parse_response(const dstalk_host_api_t* host,
const char* body, int http_status, const char* body, int http_status,
dstalk_chat_result_t& r) dstalk_chat_result_t& r)
@@ -207,13 +220,13 @@ static void parse_response(const dstalk_host_api_t* host,
} }
// ============================================================================ // ============================================================================
// 流式上下文:在 SSE 回调间累积内容和 tool_calls // 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
// ============================================================================ // ============================================================================
struct ToolCallAccum { struct ToolCallAccum {
int index = -1; int index = -1;
std::string id; std::string id;
std::string name; std::string name;
std::string arguments; // 增量拼接的 JSON arguments 字符串 std::string arguments; // 增量拼接的 JSON arguments 字符串 / incrementally concatenated JSON arguments string
}; };
struct StreamContext { struct StreamContext {
@@ -222,12 +235,18 @@ struct StreamContext {
void* userdata; void* userdata;
std::string accumulated; std::string accumulated;
bool streaming_ok = true; bool streaming_ok = true;
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls / accumulate delta tool_calls by index
}; };
// ============================================================================ // ============================================================================
// SSE 行解析OpenAI 兼容格式) // SSE 行解析OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
// ============================================================================ // ============================================================================
// 解析单行 SSE "data:" 行。如果包含 content delta将 token 写入 token_out。
// 如果包含 tool_calls delta累积到 ctx->tool_calls。
// 如果产生了 content token 则返回 true否则返回 falsetool_calls 或未知)。
// Parse a single SSE "data:" line. If it contains a content delta, writes the token
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
static bool parse_sse_line(const std::string& line, std::string& token_out, static bool parse_sse_line(const std::string& line, std::string& token_out,
StreamContext* ctx) StreamContext* ctx)
{ {
@@ -235,7 +254,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
std::string data = line.substr(6); std::string data = line.substr(6);
// F-13.2-3: Trim leading/trailing whitespace before comparing [DONE] sentinel. // F-13.2-3: 比较 [DONE] 哨兵前去除首尾空白 / Trim leading/trailing whitespace before comparing [DONE] sentinel.
const char* ws = " \t\r\n"; const char* ws = " \t\r\n";
size_t start = data.find_first_not_of(ws); size_t start = data.find_first_not_of(ws);
if (start != std::string::npos) { if (start != std::string::npos) {
@@ -244,7 +263,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
} }
if (data == "[DONE]") { if (data == "[DONE]") {
token_out.clear(); token_out.clear();
return true; // 流结束信号 return true; // 流结束信号 / stream end signal
} }
try { try {
@@ -254,12 +273,12 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
if (!choices.empty()) { if (!choices.empty()) {
auto delta = choices[0].as_object()["delta"].as_object(); auto delta = choices[0].as_object()["delta"].as_object();
// W20.2: 处理 delta["tool_calls"] 增量 chunk // W20.2: 处理 delta["tool_calls"] 增量 chunk / Handle delta["tool_calls"] incremental chunks
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输 // DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输 / DeepSeek/OpenAI streaming mode: tool_calls transmitted across multiple SSE event chunks:
// 事件 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}} // 事件 1 / Event 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
// 事件 2: {"index":0, "function":{"arguments":"{\"bar\":"}} // 事件 2 / Event 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
// 事件 3: {"index":0, "function":{"arguments":"1}"}} // 事件 3 / Event 3: {"index":0, "function":{"arguments":"1}"}}
// 需要按 index 累积 id/name/arguments // 需要按 index 累积 id/name/arguments / Need to accumulate id/name/arguments by index.
if (delta.contains("tool_calls") && ctx) { if (delta.contains("tool_calls") && ctx) {
auto tc_array = delta["tool_calls"].as_array(); auto tc_array = delta["tool_calls"].as_array();
for (auto& tc_val : tc_array) { for (auto& tc_val : tc_array) {
@@ -288,7 +307,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
} }
} }
} }
return false; // tool_calls 已处理,无内容 token 给用户回调 return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
} }
if (delta.contains("content")) { if (delta.contains("content")) {
@@ -297,14 +316,15 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
} }
} }
} catch (...) { } catch (...) {
// 忽略解析失败 // 忽略解析失败 / Ignore parse failures
} }
return false; return false;
} }
// ============================================================================ // ============================================================================
// configure 实现 // configure 实现 / configure implementation
// ============================================================================ // ============================================================================
// 配置插件provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
static int my_configure(const char* provider, const char* base_url, static int my_configure(const char* provider, const char* base_url,
const char* api_key, const char* model, const char* api_key, const char* model,
int max_tokens, double temperature) int max_tokens, double temperature)
@@ -319,7 +339,7 @@ static int my_configure(const char* provider, const char* base_url,
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) { if (host) {
// W20.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 // W20.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>( auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
host->query_service("tools", 1)); host->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) { if (tools_svc && tools_svc->get_tools_json) {
@@ -348,8 +368,9 @@ static int my_configure(const char* provider, const char* base_url,
} }
// ============================================================================ // ============================================================================
// chat 实现 // chat 实现 / chat implementation
// ============================================================================ // ============================================================================
// 非流式 chat completion发送 history + user input返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
static dstalk_chat_result_t my_chat( static dstalk_chat_result_t my_chat(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
@@ -412,29 +433,29 @@ static dstalk_chat_result_t my_chat(
} }
// ============================================================================ // ============================================================================
// chat_stream 实现 // chat_stream 实现 / chat_stream implementation
// ============================================================================ // ============================================================================
// 行回调:解析 SSE line将 token 传递给用户回调 // 行回调:解析 SSE line将 token 传递给用户回调 / SSE line callback: parses each line and forwards content tokens to the user callback.
static int sse_line_callback(const char* line, void* userdata) static int sse_line_callback(const char* line, void* userdata)
{ {
try { try {
auto* ctx = static_cast<StreamContext*>(userdata); auto* ctx = static_cast<StreamContext*>(userdata);
if (!line || !line[0]) return 1; // 空行,继续 if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
std::string line_str(line); std::string line_str(line);
std::string token; std::string token;
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
if (token.empty()) return 0; // [DONE],停止 if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
ctx->accumulated += token; ctx->accumulated += token;
if (ctx->user_cb) { if (ctx->user_cb) {
return ctx->user_cb(token.c_str(), ctx->userdata); return ctx->user_cb(token.c_str(), ctx->userdata);
} }
return 1; // 继续 return 1; // 继续 / continue
} catch (const std::exception& e) { } catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[deepseek] sse_line_callback exception: %s", e.what()); if (host && host->log) host->log(DSTALK_LOG_ERROR, "[deepseek] sse_line_callback exception: %s", e.what());
@@ -446,6 +467,9 @@ static int sse_line_callback(const char* line, void* userdata)
} }
} }
// 流式 chat completion以 stream=true 发送 history + user input通过回调传递 token。
// 在 SSE 分片中累积 tool_calls 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
// via callback. Accumulates tool_calls across SSE chunks and serializes them at end.
static dstalk_chat_result_t my_chat_stream( static dstalk_chat_result_t my_chat_stream(
const dstalk_message_t* history, int history_len, const dstalk_message_t* history, int history_len,
const char* user_input, const char* user_input,
@@ -488,10 +512,10 @@ static dstalk_chat_result_t my_chat_stream(
r.http_status = status_code; r.http_status = status_code;
// 检查传输层错误或非 2xx 状态 // 检查传输层错误或非 2xx 状态 / Check transport errors or non-2xx status
if (status_code < 200 || status_code >= 300) { if (status_code < 200 || status_code >= 300) {
r.ok = 0; r.ok = 0;
// 尝试从响应体提取错误信息 // 尝试从响应体提取错误信息 / Try to extract error info from response body
if (response_body && response_body[0]) { if (response_body && response_body[0]) {
try { try {
auto jv = json::parse(response_body); auto jv = json::parse(response_body);
@@ -518,7 +542,7 @@ static dstalk_chat_result_t my_chat_stream(
if (response_body && host) host->free(response_body); if (response_body && host) host->free(response_body);
// W20.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling // W20.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling / Success = has content OR has tool_calls (tool-only responses like function calling)
bool has_content = !ctx.accumulated.empty(); bool has_content = !ctx.accumulated.empty();
bool has_tool_calls = !ctx.tool_calls.empty(); bool has_tool_calls = !ctx.tool_calls.empty();
@@ -533,7 +557,7 @@ static dstalk_chat_result_t my_chat_stream(
r.content = has_content r.content = has_content
? host->strdup(ctx.accumulated.c_str()) : nullptr; ? host->strdup(ctx.accumulated.c_str()) : nullptr;
// 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) // 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
if (has_tool_calls) { if (has_tool_calls) {
json::array tc_array; json::array tc_array;
for (auto& tc : ctx.tool_calls) { for (auto& tc : ctx.tool_calls) {
@@ -572,8 +596,9 @@ static dstalk_chat_result_t my_chat_stream(
} }
// ============================================================================ // ============================================================================
// free_result 实现 // free_result 实现 / free_result implementation
// ============================================================================ // ============================================================================
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
static void my_free_result(dstalk_chat_result_t* result) static void my_free_result(dstalk_chat_result_t* result)
{ {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
@@ -584,7 +609,7 @@ static void my_free_result(dstalk_chat_result_t* result)
} }
// ============================================================================ // ============================================================================
// 服务 vtable // 服务 vtable / Service vtable
// ============================================================================ // ============================================================================
static dstalk_ai_service_t g_service = { static dstalk_ai_service_t g_service = {
&my_configure, &my_configure,
@@ -594,8 +619,9 @@ static dstalk_ai_service_t g_service = {
}; };
// ============================================================================ // ============================================================================
// 生命周期 // 生命周期 / Lifecycle
// ============================================================================ // ============================================================================
// 插件初始化:查询 http 和 config 服务,注册 ai.deepseek 服务 / Plugin init: query http and config services, register ai.deepseek service.
static int on_init(const dstalk_host_api_t* host) static int on_init(const dstalk_host_api_t* host)
{ {
try { try {
@@ -624,6 +650,7 @@ static int on_init(const dstalk_host_api_t* host)
} }
} }
// 插件关闭:从内存安全擦除 API key清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
static void on_shutdown() static void on_shutdown()
{ {
try { try {
@@ -644,12 +671,12 @@ static void on_shutdown()
} }
// ============================================================================ // ============================================================================
// 插件描述符 // 插件描述符 / Plugin descriptor
// ============================================================================ // ============================================================================
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
/* .name = */ "deepseek-ai", /* .name = */ "deepseek-ai",
/* .version = */ "1.0.0", /* .version = */ "1.0.0",
/* .description = */ "DeepSeek AI provider (OpenAI-compatible API)", /* .description = */ "DeepSeek AI provider (OpenAI-compatible API) / DeepSeek AI 提供者 (OpenAI 兼容 API)",
/* .api_version = */ DSTALK_API_VERSION, /* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { "http", "config", NULL }, /* .dependencies = */ { "http", "config", NULL },
/* .on_init = */ on_init, /* .on_init = */ on_init,
@@ -657,6 +684,7 @@ static dstalk_plugin_info_t g_info = {
/* .on_event = */ nullptr, /* .on_event = */ nullptr,
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{ {
return &g_info; return &g_info;

View File

@@ -1,3 +1,10 @@
/*
* @file file_io_plugin.cpp
* @brief File I/O plugin: basic file read/write service.
* 文件 I/O 插件:基本文件读写服务。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -6,20 +13,21 @@
#include <cstring> #include <cstring>
// ============================================================ // ============================================================
// Global state // 全局状态 / Global state
// ============================================================ // ============================================================
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
// ============================================================ // ============================================================
// Service implementations // 服务实现 / Service implementations
// ============================================================ // ============================================================
// 读取文件全部内容到主机分配的缓冲区,调用方须通过 host->free 释放 / Read the entire contents of a file into a host-allocated buffer. Caller must free via host->free.
static int file_read(const char* path, char** content) { static int file_read(const char* path, char** content) {
if (!path || !content) return -1; if (!path || !content) return -1;
FILE* fp = fopen(path, "rb"); FILE* fp = fopen(path, "rb");
if (!fp) return -1; if (!fp) return -1;
// Get file size // 获取文件大小 / Get file size
fseek(fp, 0, SEEK_END); fseek(fp, 0, SEEK_END);
long fsize = ftell(fp); long fsize = ftell(fp);
fseek(fp, 0, SEEK_SET); fseek(fp, 0, SEEK_SET);
@@ -29,7 +37,7 @@ static int file_read(const char* path, char** content) {
return -1; return -1;
} }
// Allocate buffer via host allocator (+1 for null terminator) // 通过主机分配器分配缓冲区(+1 用于空终止符) / Allocate buffer via host allocator (+1 for null terminator)
char* buf = (char*)g_host->alloc((size_t)fsize + 1); char* buf = (char*)g_host->alloc((size_t)fsize + 1);
if (!buf) { if (!buf) {
fclose(fp); fclose(fp);
@@ -49,6 +57,7 @@ static int file_read(const char* path, char** content) {
return 0; return 0;
} }
// 将字符串写入文件,覆盖已有内容 / Write a string to a file, overwriting any existing content.
static int file_write(const char* path, const char* content) { static int file_write(const char* path, const char* content) {
if (!path || !content) return -1; if (!path || !content) return -1;
@@ -68,28 +77,31 @@ static dstalk_file_io_service_t g_service = {
}; };
// ============================================================ // ============================================================
// Plugin lifecycle // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// 插件初始化:保存主机指针并注册 file_io 服务 / Plugin init: store host pointer and register the file_io service.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
g_host = host; g_host = host;
return host->register_service("file_io", 1, &g_service); return host->register_service("file_io", 1, &g_service);
} }
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
static void on_shutdown() { static void on_shutdown() {
// nothing to clean up // 无需清理 / nothing to clean up
} }
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
"file-io", // name "file-io", // name 名称
"1.0.0", // version "1.0.0", // version 版本
"Basic file I/O service", // description "Basic file I/O service", // description 描述
DSTALK_API_VERSION, // api_version DSTALK_API_VERSION, // api_version
{nullptr}, // dependencies (none) {nullptr}, // dependencies 依赖 (none)
on_init, // on_init on_init, // on_init
on_shutdown, // on_shutdown on_shutdown, // on_shutdown
nullptr // on_event nullptr // on_event
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,10 +1,15 @@
/* /*
* plugin-lsp — LSP (Language Server Protocol) 服务 * @file lsp_plugin.cpp
* * @brief LSP plugin: Language Server Protocol JSON-RPC client for diagnostics, hover, completion.
* 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信 * LSP 插件Language Server Protocol JSON-RPC 客户端,用于诊断、悬停、补全
* 无外部服务依赖(不依赖 http/config 等其他插件)。 * Copyright (c) 2026 dstalk contributors. GPLv3.
*/ */
// plugin-lsp — LSP (Language Server Protocol) 服务 / LSP (Language Server Protocol) service
//
// 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信 / Self-manages language server subprocess, communicates via JSON-RPC 2.0 over stdio.
// 无外部服务依赖(不依赖 http/config 等其他插件) / No external service dependencies (does not depend on http/config or other plugins).
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -22,7 +27,7 @@
#include <unordered_map> #include <unordered_map>
// ============================================================================ // ============================================================================
// 平台相关 — 子进程管理 (内嵌 subprocess::Process) // 平台相关 — 子进程管理 (内嵌 subprocess::Process) / Platform specific — subprocess management (embedded subprocess::Process)
// ============================================================================ // ============================================================================
#ifdef _WIN32 #ifdef _WIN32
@@ -45,12 +50,12 @@
namespace json = boost::json; namespace json = boost::json;
// ============================================================================ // ============================================================================
// 全局指针 // 全局指针 / Global pointers
// ============================================================================ // ============================================================================
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
// ============================================================================ // ============================================================================
// 子进程封装 (内嵌 subprocess.hpp) // 子进程封装 (内嵌 subprocess.hpp) / Subprocess wrapper (embedded subprocess.hpp)
// ============================================================================ // ============================================================================
struct Process { struct Process {
#ifdef _WIN32 #ifdef _WIN32
@@ -64,6 +69,7 @@ struct Process {
int stdout_fd = -1; int stdout_fd = -1;
#endif #endif
// 从给定命令行启动子进程。为 stdin/stdout 设置管道 / Start a child process from the given command line. Sets up pipes for stdin/stdout.
bool start(const char* cmd) { bool start(const char* cmd) {
if (!cmd || !cmd[0]) return false; if (!cmd || !cmd[0]) return false;
stop(); stop();
@@ -169,6 +175,7 @@ struct Process {
#endif #endif
} }
// 优雅终止子进程,回退到 SIGKILL/TerminateProcess / Gracefully terminate the child process, with fallback to SIGKILL/TerminateProcess.
void stop() { void stop() {
#ifdef _WIN32 #ifdef _WIN32
if (hProcess != INVALID_HANDLE_VALUE) { if (hProcess != INVALID_HANDLE_VALUE) {
@@ -198,6 +205,7 @@ struct Process {
#endif #endif
} }
// 将数据字符串写入子进程 stdin 管道 / Write a data string to the child's stdin pipe.
bool write(const std::string& data) { bool write(const std::string& data) {
if (data.empty()) return true; if (data.empty()) return true;
#ifdef _WIN32 #ifdef _WIN32
@@ -219,6 +227,7 @@ struct Process {
#endif #endif
} }
// 从子进程 stdout 管道读取一行(到并包括 '\n' / Read one line (up to and including '\n') from the child's stdout pipe.
bool read_line(std::string& line) { bool read_line(std::string& line) {
line.clear(); line.clear();
#ifdef _WIN32 #ifdef _WIN32
@@ -242,6 +251,7 @@ struct Process {
#endif #endif
} }
// 从子进程 stdout 管道读取恰好 count 字节到 buf / Read exactly `count` bytes from the child's stdout pipe into `buf`.
bool read_bytes(std::string& buf, int count) { bool read_bytes(std::string& buf, int count) {
if (count <= 0) { buf.clear(); return true; } if (count <= 0) { buf.clear(); return true; }
#ifdef _WIN32 #ifdef _WIN32
@@ -274,7 +284,7 @@ struct Process {
}; };
// ============================================================================ // ============================================================================
// LSP 状态(静态单例) // LSP 状态(静态单例) / LSP state (static singleton)
// ============================================================================ // ============================================================================
struct LspState { struct LspState {
Process proc; Process proc;
@@ -283,23 +293,24 @@ struct LspState {
std::atomic<int> next_id{1}; std::atomic<int> next_id{1};
// 响应用于同步等待 // 响应用于同步等待 / Responses for synchronous waiting
std::mutex mutex; std::mutex mutex;
std::condition_variable cv; std::condition_variable cv;
std::unordered_map<int, std::string> pending_responses; std::unordered_map<int, std::string> pending_responses;
// 诊断缓存: URI -> JSON 字符串 // 诊断缓存: URI -> JSON 字符串 / Diagnostics cache: URI -> JSON string
std::unordered_map<std::string, std::string> diagnostics; std::unordered_map<std::string, std::string> diagnostics;
// 读取线程 // 读取线程 / Reader thread
std::thread reader_thread; std::thread reader_thread;
}; };
static LspState g_lsp; static LspState g_lsp;
// ============================================================================ // ============================================================================
// 辅助函数 // 辅助函数 / Helper functions
// ============================================================================ // ============================================================================
// 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view.
static std::string_view trim(std::string_view sv) { static std::string_view trim(std::string_view sv) {
while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' || while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' ||
sv.front() == '\r' || sv.front() == '\n')) sv.front() == '\r' || sv.front() == '\n'))
@@ -310,6 +321,7 @@ static std::string_view trim(std::string_view sv) {
return sv; return sv;
} }
// 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n).
static std::string frame_message(const std::string& body) { static std::string frame_message(const std::string& body) {
std::string frame; std::string frame;
frame.reserve(64 + body.size()); frame.reserve(64 + body.size());
@@ -320,6 +332,7 @@ static std::string frame_message(const std::string& body) {
return frame; return frame;
} }
// 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure.
static int parse_content_length(const std::string& line) { static int parse_content_length(const std::string& line) {
auto sv = trim(std::string_view(line)); auto sv = trim(std::string_view(line));
const char prefix[] = "Content-Length:"; const char prefix[] = "Content-Length:";
@@ -341,9 +354,10 @@ static int parse_content_length(const std::string& line) {
} }
// ============================================================================ // ============================================================================
// JSON-RPC 消息发送 // JSON-RPC 消息发送 / JSON-RPC message sending
// ============================================================================ // ============================================================================
// 向 LSP 服务器发送 JSON-RPC 请求并返回分配的请求 id / Send a JSON-RPC request to the LSP server and return the assigned request id.
static int send_request(const std::string& method, const json::object& params) { static int send_request(const std::string& method, const json::object& params) {
int id = g_lsp.next_id.fetch_add(1); int id = g_lsp.next_id.fetch_add(1);
@@ -358,6 +372,7 @@ static int send_request(const std::string& method, const json::object& params) {
return id; return id;
} }
// 向 LSP 服务器发送 JSON-RPC 通知(无 id 字段,不期待响应) / Send a JSON-RPC notification to the LSP server (no id field, no response expected).
static void send_notification(const std::string& method, const json::object& params) { static void send_notification(const std::string& method, const json::object& params) {
json::object msg; json::object msg;
msg["jsonrpc"] = "2.0"; msg["jsonrpc"] = "2.0";
@@ -369,9 +384,12 @@ static void send_notification(const std::string& method, const json::object& par
} }
// ============================================================================ // ============================================================================
// 消息处理 // 消息处理 / Message handling
// ============================================================================ // ============================================================================
// 分发接收到的 JSON-RPC 消息:将响应路由到待处理队列,
// 处理 textDocument/publishDiagnostics 通知并存入诊断缓存 / Dispatch a received JSON-RPC message: route responses to pending queue,
// handle textDocument/publishDiagnostics notifications into diagnostics cache.
static void handle_message(const std::string& body) { static void handle_message(const std::string& body) {
try { try {
json::value val; json::value val;
@@ -383,14 +401,14 @@ static void handle_message(const std::string& body) {
catch (...) { return; } catch (...) { return; }
if (msg.contains("id") && !msg.contains("method")) { if (msg.contains("id") && !msg.contains("method")) {
// 响应 (有 id, 无 method) // 响应 (有 id, 无 method) / Response (has id, no method)
int id = static_cast<int>(msg["id"].as_int64()); int id = static_cast<int>(msg["id"].as_int64());
std::lock_guard<std::mutex> lock(g_lsp.mutex); std::lock_guard<std::mutex> lock(g_lsp.mutex);
g_lsp.pending_responses[id] = body; g_lsp.pending_responses[id] = body;
g_lsp.cv.notify_all(); g_lsp.cv.notify_all();
} else if (msg.contains("method") && !msg.contains("id")) { } else if (msg.contains("method") && !msg.contains("id")) {
// 通知 (有 method, 无 id) // 通知 (有 method, 无 id) / Notification (has method, no id)
std::string method; std::string method;
try { method = json::value_to<std::string>(msg["method"]); } try { method = json::value_to<std::string>(msg["method"]); }
catch (...) { return; } catch (...) { return; }
@@ -419,17 +437,18 @@ static void handle_message(const std::string& body) {
} }
// ============================================================================ // ============================================================================
// 读取线程主循环 // 读取线程主循环 / Reader thread main loop
// ============================================================================ // ============================================================================
// 读取线程主循环:解析 LSP header+body 帧并分发消息 / Main loop for the reader thread: parse LSP header+body frames and dispatch messages.
static void reader_loop() { static void reader_loop() {
try { try {
while (g_lsp.running) { while (g_lsp.running) {
int content_length = -1; int content_length = -1;
bool pipe_ok = true; bool pipe_ok = true;
// 状态机式读取 header 块:循环 read_line 直到读到空行 // 状态机式读取 header 块:循环 read_line 直到读到空行 / State-machine header block read: loop read_line until empty line
// LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header // LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header / LSP 3.17: header block ends with empty line (\r\n), allows other headers like Content-Type
while (pipe_ok) { while (pipe_ok) {
std::string line; std::string line;
if (!g_lsp.proc.read_line(line)) { if (!g_lsp.proc.read_line(line)) {
@@ -437,18 +456,18 @@ static void reader_loop() {
break; break;
} }
// header 块以空行结束 // header 块以空行结束 / header block ends with empty line
auto sv = trim(std::string_view(line)); auto sv = trim(std::string_view(line));
if (sv.empty()) break; if (sv.empty()) break;
// 累积 Content-Length遇到其他 header 不丢弃,继续读取下一行 // 累积 Content-Length遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line
int len = parse_content_length(line); int len = parse_content_length(line);
if (len >= 0) content_length = len; if (len >= 0) content_length = len;
} }
if (!pipe_ok) break; if (!pipe_ok) break;
// 空行前都没读到 Content-Length协议错误——记日志并跳过这一帧 // 空行前都没读到 Content-Length协议错误——记日志并跳过这一帧 / Content-Length not read before empty line, protocol error — log and skip this frame
if (content_length < 0) { if (content_length < 0) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] Invalid LSP frame: missing Content-Length header"); if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] Invalid LSP frame: missing Content-Length header");
continue; continue;
@@ -471,38 +490,39 @@ static void reader_loop() {
} }
// ============================================================================ // ============================================================================
// LSP 服务 vtable 实现 (定义在 vtable 变量之前) // LSP 服务 vtable 实现 (定义在 vtable 变量之前) / LSP service vtable implementation (defined before vtable variable)
// ============================================================================ // ============================================================================
static void g_lsp_impl_stop(); static void g_lsp_impl_stop();
static void g_lsp_impl_stop_nolock(); static void g_lsp_impl_stop_nolock();
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock); static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock);
// 启动 LSP 服务器进程,发送 initialize/initialized 握手,启动读取线程 / Start the LSP server process, send initialize/initialized handshake, start reader thread.
static int g_lsp_impl_start(const char* server_cmd, const char* language) { static int g_lsp_impl_start(const char* server_cmd, const char* language) {
if (!server_cmd || !server_cmd[0]) return -1; if (!server_cmd || !server_cmd[0]) return -1;
try { try {
// 如果已在运行, 先停止 // 如果已在运行, 先停止 / If already running, stop first
if (g_lsp.running) { if (g_lsp.running) {
g_lsp_impl_stop(); g_lsp_impl_stop();
} }
g_lsp.language = language ? language : ""; g_lsp.language = language ? language : "";
// 启动进程 // 启动进程 / Start process
if (!g_lsp.proc.start(server_cmd)) { if (!g_lsp.proc.start(server_cmd)) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] failed to start: %s", server_cmd); if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] failed to start: %s", server_cmd);
return -1; return -1;
} }
// 重置 ID 计数器 // 重置 ID 计数器 / Reset ID counter
g_lsp.next_id = 1; g_lsp.next_id = 1;
// 启动读取线程 // 启动读取线程 / Start reader thread
g_lsp.running = true; g_lsp.running = true;
g_lsp.reader_thread = std::thread(reader_loop); g_lsp.reader_thread = std::thread(reader_loop);
// 构建 initialize 参数 // 构建 initialize 参数 / Build initialize params
json::object text_doc_caps; json::object text_doc_caps;
{ {
json::object hover; json::object hover;
@@ -526,10 +546,10 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
init_params["rootUri"] = nullptr; init_params["rootUri"] = nullptr;
init_params["capabilities"] = capabilities; init_params["capabilities"] = capabilities;
// 发送 initialize 请求 // 发送 initialize 请求 / Send initialize request
int init_id = send_request("initialize", init_params); int init_id = send_request("initialize", init_params);
// 等待 initialize 响应 (最多 10 秒) // 等待 initialize 响应 (最多 10 秒) / Wait for initialize response (max 10 seconds)
{ {
std::unique_lock<std::mutex> lock(g_lsp.mutex); std::unique_lock<std::mutex> lock(g_lsp.mutex);
bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [init_id]() { bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [init_id]() {
@@ -544,7 +564,7 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
g_lsp.pending_responses.erase(init_id); g_lsp.pending_responses.erase(init_id);
} }
// 发送 initialized 通知 // 发送 initialized 通知 / Send initialized notification
send_notification("initialized", json::object{}); send_notification("initialized", json::object{});
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server started: %s", server_cmd); if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server started: %s", server_cmd);
@@ -558,14 +578,15 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
} }
} }
// 停止 LSP 服务器:发送 shutdown 请求,发送 exit 通知,停止进程和线程 / Stop the LSP server: send shutdown request, send exit notification, stop process & thread.
static void g_lsp_impl_stop_nolock() { static void g_lsp_impl_stop_nolock() {
try { try {
if (!g_lsp.running) return; if (!g_lsp.running) return;
// 发送 shutdown 请求 // 发送 shutdown 请求 / Send shutdown request
int shutdown_id = send_request("shutdown", json::object{}); int shutdown_id = send_request("shutdown", json::object{});
// 等待 shutdown 响应 (最多 2 秒) // 等待 shutdown 响应 (最多 2 秒) / Wait for shutdown response (max 2 seconds)
{ {
std::unique_lock<std::mutex> lock(g_lsp.mutex); std::unique_lock<std::mutex> lock(g_lsp.mutex);
g_lsp.cv.wait_for(lock, std::chrono::seconds(2), [shutdown_id]() { g_lsp.cv.wait_for(lock, std::chrono::seconds(2), [shutdown_id]() {
@@ -574,10 +595,10 @@ static void g_lsp_impl_stop_nolock() {
g_lsp.pending_responses.clear(); g_lsp.pending_responses.clear();
} }
// 发送 exit 通知 // 发送 exit 通知 / Send exit notification
send_notification("exit", json::object{}); send_notification("exit", json::object{});
// 停止读取线程 // 停止读取线程 / Stop reader thread
g_lsp.running = false; g_lsp.running = false;
g_lsp.proc.stop(); g_lsp.proc.stop();
@@ -593,15 +614,18 @@ static void g_lsp_impl_stop_nolock() {
} }
} }
// 公开 stop无锁获取委托给 g_lsp_impl_stop_nolock / Public stop: acquires no lock (delegates to g_lsp_impl_stop_nolock).
static void g_lsp_impl_stop() { static void g_lsp_impl_stop() {
g_lsp_impl_stop_nolock(); g_lsp_impl_stop_nolock();
} }
// Stop 辅助函数:在调用 g_lsp_impl_stop_nolock 前解锁给定的 unique_lock / Stop helper: unlocks the given unique_lock before calling g_lsp_impl_stop_nolock.
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock) { static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock) {
lock.unlock(); lock.unlock();
g_lsp_impl_stop_nolock(); g_lsp_impl_stop_nolock();
} }
// 向 LSP 服务器发送 textDocument/didOpen 通知 / Send a textDocument/didOpen notification to the LSP server.
static int g_lsp_impl_open_document(const char* uri, const char* content, static int g_lsp_impl_open_document(const char* uri, const char* content,
const char* lang_id) { const char* lang_id) {
if (!g_lsp.running) return -1; if (!g_lsp.running) return -1;
@@ -628,6 +652,7 @@ static int g_lsp_impl_open_document(const char* uri, const char* content,
} }
} }
// 向 LSP 服务器发送 textDocument/didClose 通知 / Send a textDocument/didClose notification to the LSP server.
static int g_lsp_impl_close_document(const char* uri) { static int g_lsp_impl_close_document(const char* uri) {
if (!g_lsp.running) return -1; if (!g_lsp.running) return -1;
if (!uri) return -1; if (!uri) return -1;
@@ -650,6 +675,7 @@ static int g_lsp_impl_close_document(const char* uri) {
} }
} }
// 返回给定文档 URI 的缓存诊断 JSON / Return the cached diagnostics JSON for the given document URI.
static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) { static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
if (!g_lsp.running) return -1; if (!g_lsp.running) return -1;
if (!uri || !json_out) return -1; if (!uri || !json_out) return -1;
@@ -674,6 +700,7 @@ static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
} }
} }
// 发送 textDocument/hover 请求并以 JSON 返回悬停结果 / Send a textDocument/hover request and return the hover result as JSON.
static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_out) { static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_out) {
if (!g_lsp.running) return -1; if (!g_lsp.running) return -1;
if (!uri || !json_out) return -1; if (!uri || !json_out) return -1;
@@ -727,6 +754,7 @@ static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_
} }
} }
// 发送 textDocument/completion 请求并以 JSON 返回补全列表 / Send a textDocument/completion request and return the completion list as JSON.
static int g_lsp_impl_get_completion(const char* uri, int line, int col, char** json_out) { static int g_lsp_impl_get_completion(const char* uri, int line, int col, char** json_out) {
if (!g_lsp.running) return -1; if (!g_lsp.running) return -1;
if (!uri || !json_out) return -1; if (!uri || !json_out) return -1;
@@ -781,7 +809,7 @@ static int g_lsp_impl_get_completion(const char* uri, int line, int col, char**
} }
// ============================================================================ // ============================================================================
// 服务 vtable // 服务 vtable / Service vtable
// ============================================================================ // ============================================================================
static dstalk_lsp_service_t g_service_vtable = { static dstalk_lsp_service_t g_service_vtable = {
@@ -795,15 +823,17 @@ static dstalk_lsp_service_t g_service_vtable = {
}; };
// ============================================================================ // ============================================================================
// 生命周期回调 // 生命周期回调 / Lifecycle callbacks
// ============================================================================ // ============================================================================
// 插件初始化:保存主机指针并注册 lsp 服务 / Plugin init: store host pointer and register the lsp service.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
g_host = host; g_host = host;
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] initializing LSP service plugin"); if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] initializing LSP service plugin");
return host->register_service("lsp", 1, &g_service_vtable); return host->register_service("lsp", 1, &g_service_vtable);
} }
// 插件关闭:如果运行中则停止 LSP 服务器,清空主机指针 / Plugin shutdown: stop LSP server if running, null out host pointer.
static void on_shutdown() { static void on_shutdown() {
try { try {
if (g_lsp.running) { if (g_lsp.running) {
@@ -821,20 +851,21 @@ static void on_shutdown() {
} }
// ============================================================================ // ============================================================================
// 插件描述符 // 插件描述符 / Plugin descriptor
// ============================================================================ // ============================================================================
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
/* .name = */ "lsp", /* .name = */ "lsp",
/* .version = */ "1.0.0", /* .version = */ "1.0.0",
/* .description = */ "Language Server Protocol client (subprocess manager)", /* .description = */ "Language Server Protocol client (subprocess manager) / Language Server Protocol 客户端(子进程管理器)",
/* .api_version = */ DSTALK_API_VERSION, /* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { NULL }, // 无依赖,自行管理子进程 /* .dependencies = */ { NULL }, // 无依赖,自行管理子进程 / No dependencies, self-manages subprocess
/* .on_init = */ on_init, /* .on_init = */ on_init,
/* .on_shutdown = */ on_shutdown, /* .on_shutdown = */ on_shutdown,
/* .on_event = */ nullptr, /* .on_event = */ nullptr,
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,4 +1,11 @@
// MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20) /*
* @file network_plugin.cpp
* @brief Network plugin: HTTP/HTTPS POST and streaming via Boost.Beast + OpenSSL.
* 网络插件:基于 Boost.Beast + OpenSSL 的 HTTP/HTTPS POST 和流式传输。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// MSVC 14.16 (VS 2017) 不提供 std::to_address (C++20) / MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS #define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
@@ -29,21 +36,22 @@ namespace ssl = boost::asio::ssl;
using tcp = asio::ip::tcp; using tcp = asio::ip::tcp;
// ============================================================ // ============================================================
// Global state // 全局状态 / Global state
// ============================================================ // ============================================================
static const dstalk_host_api_t* g_host = nullptr; static const dstalk_host_api_t* g_host = nullptr;
static dstalk_config_service_t* g_config_svc = nullptr; static dstalk_config_service_t* g_config_svc = nullptr;
// ============================================================ // ============================================================
// Minimal JSON header parser // 极简 JSON 头解析器 / Minimal JSON header parser
// Parses {"key1":"value1","key2":"value2"} into unordered_map // 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map
// ============================================================ // ============================================================
// 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map.
static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) { static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) {
std::unordered_map<std::string, std::string> headers; std::unordered_map<std::string, std::string> headers;
if (!json || !*json) return headers; if (!json || !*json) return headers;
std::string s(json); std::string s(json);
// Very simple state-machine parser for flat string-key/value objects // 极简状态机解析器,处理扁平的字符串键值对象 / Very simple state-machine parser for flat string-key/value objects
enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE; enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE;
std::string current_key; std::string current_key;
std::string current_value; std::string current_value;
@@ -64,7 +72,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
break; break;
case IN_VALUE: case IN_VALUE:
if (c == '"') { if (c == '"') {
// Read until closing quote // 读取到闭合引号 / Read until closing quote
++i; ++i;
while (i < s.size() && s[i] != '"') { while (i < s.size() && s[i] != '"') {
if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; } if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; }
@@ -81,7 +89,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
} }
// ============================================================ // ============================================================
// HTTP Client implementation (adapted from dstalk-core HttpClient) // HTTP 客户端实现(改编自 dstalk-core HttpClient / HTTP Client implementation (adapted from dstalk-core HttpClient)
// ============================================================ // ============================================================
struct HttpClientCtx { struct HttpClientCtx {
asio::io_context ioc; asio::io_context ioc;
@@ -91,15 +99,22 @@ struct HttpClientCtx {
HttpClientCtx() { HttpClientCtx() {
ssl_ctx.set_default_verify_paths(); ssl_ctx.set_default_verify_paths();
// Enable peer certificate verification (CVSS 7.4 fix). // 启用对等证书验证 (CVSS 7.4 修复) / Enable peer certificate verification (CVSS 7.4 fix).
// set_default_verify_paths() loads system CA bundle; without verify_peer // set_default_verify_paths() 加载系统 CA 包;没有 verify_peer
// CA 存储不会被查询——任何证书(自签名/过期)都将被接受 / set_default_verify_paths() loads system CA bundle; without verify_peer
// the CA store is never consulted — any cert (self-signed/expired) is accepted. // the CA store is never consulted — any cert (self-signed/expired) is accepted.
// TODO: Windows: set_default_verify_paths() may not locate system CAs; // TODO: Windows: set_default_verify_paths() 可能无法定位系统 CA
// 如果验证失败,设置 SSL_CERT_FILE 环境变量或捆绑 cacert.pem / Windows: set_default_verify_paths() may not locate system CAs;
// if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem. // if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem.
ssl_ctx.set_verify_mode(ssl::verify_peer); ssl_ctx.set_verify_mode(ssl::verify_peer);
} }
}; };
// 核心 HTTP/HTTPS POST支持可选 SSE 流式传输。执行 DNS 解析、
// TLS 握手(含 SNI 和主机名验证),然后发送请求。
// 如果 cb 非空,响应体将逐行解析用于流式传输 / Core HTTP/HTTPS POST with optional SSE streaming. Performs DNS resolve,
// TLS handshake with SNI and hostname verification, then sends the request.
// If `cb` is non-null, response body is parsed line-by-line for streaming.
static int do_post_stream( static int do_post_stream(
const char* host, const char* host,
const char* port, const char* port,
@@ -117,11 +132,11 @@ static int do_post_stream(
return -1; return -1;
} }
// Initialize output // 初始化输出 / Initialize output
*response_body = nullptr; *response_body = nullptr;
*status_code = -1; *status_code = -1;
// Build C++ lambda from C callback // 从 C 回调构建 C++ lambda / Build C++ lambda from C callback
std::function<bool(const std::string&)> on_line; std::function<bool(const std::string&)> on_line;
if (cb) { if (cb) {
on_line = [cb, userdata](const std::string& line) -> bool { on_line = [cb, userdata](const std::string& line) -> bool {
@@ -131,7 +146,7 @@ static int do_post_stream(
HttpClientCtx ctx; HttpClientCtx ctx;
// Read timeouts from config if available // 从配置读取超时设置 / Read timeouts from config if available
if (g_config_svc) { if (g_config_svc) {
const char* ct = g_config_svc->get("http.connect_timeout"); const char* ct = g_config_svc->get("http.connect_timeout");
const char* rt = g_config_svc->get("http.request_timeout"); const char* rt = g_config_svc->get("http.request_timeout");
@@ -147,7 +162,9 @@ static int do_post_stream(
try { try {
tcp::resolver resolver(ctx.ioc); tcp::resolver resolver(ctx.ioc);
// DNS resolve with 10-second timeout. Boost.Asio's synchronous // DNS 解析10 秒超时。Boost.Asio 的同步 resolve()
// 在内部运行 io_context因此定时器的 async_wait 回调在 resolve() 期间执行,
// 并在超时触发时调用 resolver.cancel() / DNS resolve with 10-second timeout. Boost.Asio's synchronous
// resolve() runs the io_context internally, so the timer's async_wait // resolve() runs the io_context internally, so the timer's async_wait
// callback executes during resolve() and calls resolver.cancel() when // callback executes during resolve() and calls resolver.cancel() when
// the deadline fires. // the deadline fires.
@@ -172,7 +189,7 @@ static int do_post_stream(
beast::ssl_stream<beast::tcp_stream> stream(ctx.ioc, ctx.ssl_ctx); beast::ssl_stream<beast::tcp_stream> stream(ctx.ioc, ctx.ssl_ctx);
beast::flat_buffer buffer; beast::flat_buffer buffer;
// SNI hostname // SNI 主机名 / SNI hostname
if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) { if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, if (g_host) g_host->log(DSTALK_LOG_ERROR,
"do_post_stream: SNI hostname set failed for %s", host); "do_post_stream: SNI hostname set failed for %s", host);
@@ -180,7 +197,9 @@ static int do_post_stream(
goto done; goto done;
} }
// Hostname verification: require server certificate CN/SAN to match // 主机名验证:要求服务器证书 CN/SAN 匹配 'host'。
// 与 ssl::verify_peer 协同工作——没有它的话,
// 使用不同主机名的有效 CA 签名证书进行 MITM 攻击仍可通过 / Hostname verification: require server certificate CN/SAN to match
// 'host'. This works in conjunction with ssl::verify_peer on the // 'host'. This works in conjunction with ssl::verify_peer on the
// context — without it MITM with a valid CA-signed cert for a // context — without it MITM with a valid CA-signed cert for a
// different hostname would still pass. // different hostname would still pass.
@@ -191,19 +210,19 @@ static int do_post_stream(
goto done; goto done;
} }
// Connect // 连接 / Connect
beast::get_lowest_layer(stream).expires_after( beast::get_lowest_layer(stream).expires_after(
std::chrono::seconds(ctx.connect_timeout)); std::chrono::seconds(ctx.connect_timeout));
beast::get_lowest_layer(stream).connect(endpoints); beast::get_lowest_layer(stream).connect(endpoints);
beast::get_lowest_layer(stream).expires_never(); beast::get_lowest_layer(stream).expires_never();
// SSL handshake // SSL 握手 / SSL handshake
beast::get_lowest_layer(stream).expires_after( beast::get_lowest_layer(stream).expires_after(
std::chrono::seconds(ctx.connect_timeout)); std::chrono::seconds(ctx.connect_timeout));
stream.handshake(ssl::stream_base::client); stream.handshake(ssl::stream_base::client);
beast::get_lowest_layer(stream).expires_never(); beast::get_lowest_layer(stream).expires_never();
// Build HTTP POST request // 构建 HTTP POST 请求 / Build HTTP POST request
http::request<http::string_body> req{http::verb::post, target, 11}; http::request<http::string_body> req{http::verb::post, target, 11};
req.set(http::field::host, host); req.set(http::field::host, host);
req.set(http::field::user_agent, "dstalk/0.1"); req.set(http::field::user_agent, "dstalk/0.1");
@@ -211,19 +230,19 @@ static int do_post_stream(
req.body() = body; req.body() = body;
req.prepare_payload(); req.prepare_payload();
// Add extra headers from JSON // 从 JSON 添加额外的头 / Add extra headers from JSON
auto extra_headers = parse_headers_json(headers_json); auto extra_headers = parse_headers_json(headers_json);
for (const auto& h : extra_headers) { for (const auto& h : extra_headers) {
req.set(h.first, h.second); req.set(h.first, h.second);
} }
// Send // 发送 / Send
beast::get_lowest_layer(stream).expires_after( beast::get_lowest_layer(stream).expires_after(
std::chrono::seconds(ctx.request_timeout)); std::chrono::seconds(ctx.request_timeout));
http::write(stream, req); http::write(stream, req);
beast::get_lowest_layer(stream).expires_never(); beast::get_lowest_layer(stream).expires_never();
// Read response // 读取响应 / Read response
http::response_parser<http::string_body> parser; http::response_parser<http::string_body> parser;
parser.body_limit(16 * 1024 * 1024); parser.body_limit(16 * 1024 * 1024);
beast::get_lowest_layer(stream).expires_after( beast::get_lowest_layer(stream).expires_after(
@@ -310,8 +329,9 @@ done:
} }
// ============================================================ // ============================================================
// Service implementations // 服务实现 / Service implementations
// ============================================================ // ============================================================
// 同步 HTTP POST返回完整响应体 / Synchronous HTTP POST returning the complete response body.
static int http_post_json( static int http_post_json(
const char* host, const char* port, const char* host, const char* port,
const char* target, const char* body, const char* target, const char* body,
@@ -322,6 +342,7 @@ static int http_post_json(
nullptr, nullptr, response_body, status_code); nullptr, nullptr, response_body, status_code);
} }
// HTTP POST 带 SSE 流式传输:响应行到达时通过 cb 回调传递 / HTTP POST with SSE streaming: response lines are delivered to `cb` as they arrive.
static int http_post_stream( static int http_post_stream(
const char* host, const char* port, const char* host, const char* port,
const char* target, const char* body, const char* target, const char* body,
@@ -339,32 +360,35 @@ static dstalk_http_service_t g_service = {
}; };
// ============================================================ // ============================================================
// Plugin lifecycle // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// 插件初始化:保存主机指针,查询 config 服务,注册 http 服务 / Plugin init: store host pointer, query config service, register http service.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
g_host = host; g_host = host;
// Query config service (declared dependency) // 查询 config 服务(声明的依赖) / Query config service (declared dependency)
g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1); g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1);
return host->register_service("http", 1, &g_service); return host->register_service("http", 1, &g_service);
} }
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
static void on_shutdown() { static void on_shutdown() {
// nothing to clean up // 无需清理 / nothing to clean up
} }
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
"http", // name "http", // name 名称
"1.0.0", // version "1.0.0", // version 版本
"HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description "HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description 描述
DSTALK_API_VERSION, // api_version DSTALK_API_VERSION, // api_version
{"config", nullptr}, // dependencies {"config", nullptr}, // dependencies 依赖
on_init, // on_init on_init, // on_init
on_shutdown, // on_shutdown on_shutdown, // on_shutdown
nullptr // on_event nullptr // on_event
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,6 +1,13 @@
// plugin-session: 会话管理服务插件 /*
// 提供 dstalk_session_service_t vtable 实现 * @file session_plugin.cpp
// 依赖: file_io (save/load 需要文件操作) * @brief Session plugin: conversation message history management with save/load.
* 会话插件:对话消息历史管理,支持保存/加载。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-session: 会话管理服务插件 / Session management service plugin
// 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation
// 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations)
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h" #include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -24,14 +31,14 @@
namespace json = boost::json; namespace json = boost::json;
// ============================================================ // ============================================================
// 内部 C++ 数据结构 // 内部 C++ 数据结构 / Internal C++ data structures
// ============================================================ // ============================================================
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release读取无锁 // W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
static std::atomic<const dstalk_host_api_t*> g_host{nullptr}; static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr}; static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
// 内部消息结构C++ 易用,外部暴露 C struct // 内部消息结构C++ 易用,外部暴露 C struct / Internal message struct (C++ friendly, externally exposed as C struct)
struct InternalMessage { struct InternalMessage {
std::string role; std::string role;
std::string content; std::string content;
@@ -39,21 +46,24 @@ struct InternalMessage {
std::string tool_calls_json; std::string tool_calls_json;
}; };
// 会话历史 + 缓存 —— W14.3: mutex 保护读写 // 会话历史 + 缓存 —— W14.3: mutex 保护读写 / Session history + cache — W14.3: mutex protects read/write
static std::vector<InternalMessage> g_history; static std::vector<InternalMessage> g_history;
static std::vector<dstalk_message_t> g_cached_history; static std::vector<dstalk_message_t> g_cached_history;
static std::mutex g_session_mutex; static std::mutex g_session_mutex;
// ============================================================ // ============================================================
// Token 计数工具(内联,避免硬依赖 context 头文件) // Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers)
// ============================================================ // ============================================================
// 如果字节是 ASCII (0x000x7F) 则返回 true / Returns true if the byte is ASCII (0x000x7F).
static bool is_ascii(unsigned char c) { return c < 0x80; } static bool is_ascii(unsigned char c) { return c < 0x80; }
// 启发式判断:如果字节起始一个 UTF-8 CJK 统一表意文字 (0xE40xE9) 则返回 true / Heuristic: returns true if the byte starts a CJK Unified Ideograph in UTF-8 (0xE40xE9).
static bool starts_cjk(unsigned char c) { static bool starts_cjk(unsigned char c) {
return c >= 0xE4 && c <= 0xE9; return c >= 0xE4 && c <= 0xE9;
} }
// 使用启发式 UTF-8 字节计数估算单条消息的 token 数 / Estimate token count for a single message using heuristic UTF-8 byte counting.
static size_t count_tokens_one(const std::string& text) { static size_t count_tokens_one(const std::string& text) {
size_t ascii_chars = 0; size_t ascii_chars = 0;
size_t chinese_chars = 0; size_t chinese_chars = 0;
@@ -85,9 +95,10 @@ static size_t count_tokens_one(const std::string& text) {
} }
size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3); size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3);
return content_tokens + 4; // +4 per message overhead return content_tokens + 4; // +4 每条消息开销 / +4 per message overhead
} }
// 估算所有消息的总 token 数 / Estimate total token count across all messages.
static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) { static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
size_t total = 0; size_t total = 0;
for (const auto& m : msgs) { for (const auto& m : msgs) {
@@ -97,13 +108,15 @@ static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
} }
// ============================================================ // ============================================================
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex // 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex / Helper: rebuild C cached array (caller must hold g_session_mutex)
// ============================================================ // ============================================================
// 从内部消息 vector 重建 C 兼容的缓存历史数组。调用方必须持有 g_session_mutex / Rebuild the C-compatible cached history array from the internal message vector.
// Caller must hold g_session_mutex.
static void rebuild_cached_history_locked() { static void rebuild_cached_history_locked() {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
// 释放旧的字符串 // 释放旧的字符串 / Free old strings
for (auto& m : g_cached_history) { for (auto& m : g_cached_history) {
if (m.role) { host->free(const_cast<char*>(m.role)); } if (m.role) { host->free(const_cast<char*>(m.role)); }
if (m.content) { host->free(const_cast<char*>(m.content)); } if (m.content) { host->free(const_cast<char*>(m.content)); }
@@ -112,7 +125,7 @@ static void rebuild_cached_history_locked() {
} }
g_cached_history.clear(); g_cached_history.clear();
// 重建 // 重建 / Rebuild
g_cached_history.reserve(g_history.size()); g_cached_history.reserve(g_history.size());
for (const auto& im : g_history) { for (const auto& im : g_history) {
dstalk_message_t cm; dstalk_message_t cm;
@@ -125,9 +138,10 @@ static void rebuild_cached_history_locked() {
} }
// ============================================================ // ============================================================
// Session 服务 vtable 实现 (W14.3: try/catch + mutex) // Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex)
// ============================================================ // ============================================================
// 向对话历史追加一条消息 / Append a message to the conversation history.
static void session_add(const dstalk_message_t* msg) { static void session_add(const dstalk_message_t* msg) {
try { try {
if (!msg) return; if (!msg) return;
@@ -148,11 +162,13 @@ static void session_add(const dstalk_message_t* msg) {
} }
} }
// 清空对话历史中的所有消息 / Clear all messages from the conversation history.
static void session_clear() { static void session_clear() {
std::lock_guard<std::mutex> lock(g_session_mutex); std::lock_guard<std::mutex> lock(g_session_mutex);
g_history.clear(); g_history.clear();
} }
// 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`.
static int session_save(const char* path) { static int session_save(const char* path) {
try { try {
if (!path) return -1; if (!path) return -1;
@@ -187,6 +203,7 @@ static int session_save(const char* path) {
} }
} }
// 从 JSON 行文件中加载对话历史,替换当前历史 / Load conversation history from a JSON lines file at `path`, replacing current history.
static int session_load(const char* path) { static int session_load(const char* path) {
try { try {
if (!path) return -1; if (!path) return -1;
@@ -246,6 +263,7 @@ static int session_load(const char* path) {
} }
} }
// 返回指向缓存 C 消息数组的指针,并将 *out_count 设置为数组大小 / Return a pointer to the cached C-array of messages and set *out_count to its size.
static const dstalk_message_t* session_history(int* out_count) { static const dstalk_message_t* session_history(int* out_count) {
try { try {
std::lock_guard<std::mutex> lock(g_session_mutex); std::lock_guard<std::mutex> lock(g_session_mutex);
@@ -265,6 +283,7 @@ static const dstalk_message_t* session_history(int* out_count) {
} }
} }
// 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history.
static int session_token_count() { static int session_token_count() {
try { try {
std::lock_guard<std::mutex> lock(g_session_mutex); std::lock_guard<std::mutex> lock(g_session_mutex);
@@ -290,11 +309,12 @@ static dstalk_session_service_t g_session_service = {
}; };
// ============================================================ // ============================================================
// W20.6: 默认会话保存路径(平台标准目录) // W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory)
// ============================================================ // ============================================================
// 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed.
static std::string get_default_session_path() { static std::string get_default_session_path() {
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 // W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure
static std::string cached_path = []() -> std::string { static std::string cached_path = []() -> std::string {
#ifdef _WIN32 #ifdef _WIN32
char* buf = nullptr; char* buf = nullptr;
@@ -323,14 +343,17 @@ static std::string get_default_session_path() {
} }
// ============================================================ // ============================================================
// 插件生命周期 // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// 插件初始化:保存主机指针,查询 file_io 依赖,注册 session 服务,
// 并从默认路径自动加载已有会话 / Plugin init: store host pointer, query file_io dependency, register session service,
// and auto-load any existing session from the default path.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
try { try {
g_host.store(host, std::memory_order_release); g_host.store(host, std::memory_order_release);
// 查询依赖服务: file_io // 查询依赖服务: file_io / Query dependency service: file_io
void* raw = host->query_service("file_io", 1); void* raw = host->query_service("file_io", 1);
if (!raw) { if (!raw) {
host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found"); host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found");
@@ -338,11 +361,11 @@ static int on_init(const dstalk_host_api_t* host) {
} }
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release); g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
// 注册自身服务 // 注册自身服务 / Register own service
int ret = host->register_service("session", 1, &g_session_service); int ret = host->register_service("session", 1, &g_session_service);
if (ret != 0) return ret; if (ret != 0) return ret;
// W20.6: 从默认路径恢复会话(文件不存在则静默失败) // W20.6: 从默认路径恢复会话(文件不存在则静默失败) / Restore session from default path (silent fail if file missing)
session_load(get_default_session_path().c_str()); session_load(get_default_session_path().c_str());
return 0; return 0;
@@ -357,10 +380,13 @@ static int on_init(const dstalk_host_api_t* host) {
} }
} }
// 插件关闭:自动保存会话到默认路径,失败时回退到当前目录,
// 然后释放缓存历史和清空状态 / Plugin shutdown: auto-save session to default path, fallback to current dir on failure,
// then release cached history and clear state.
static void on_shutdown() { static void on_shutdown() {
try { try {
// W20.6: 清空前自动保存到默认路径 // W20.6: 清空前自动保存到默认路径 / Auto-save to default path before clearing
// W21.4: 失败告警 + 当前目录 fallback // W21.4: 失败告警 + 当前目录 fallback / Failure warning + current dir fallback
int ret = session_save(get_default_session_path().c_str()); int ret = session_save(get_default_session_path().c_str());
if (ret != 0) { if (ret != 0) {
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
@@ -389,7 +415,7 @@ static void on_shutdown() {
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
"session", "session",
"1.0.0", "1.0.0",
"Session management plugin with save/load support", "Session management plugin with save/load support / 支持保存/加载的会话管理插件",
DSTALK_API_VERSION, DSTALK_API_VERSION,
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, {"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
on_init, on_init,
@@ -397,6 +423,7 @@ static dstalk_plugin_info_t g_info = {
nullptr nullptr
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,6 +1,13 @@
// plugin-tools: 工具注册服务插件 /*
// 提供 dstalk_tools_service_t vtable 实现 * @file tools_plugin.cpp
// 依赖: file_io (内置 file_read / file_write 工具) * @brief Tools plugin: tool registration, schema management, and execution registry.
* 工具插件工具注册、schema 管理和执行注册表。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-tools: 工具注册服务插件 / Tool registration service plugin
// 提供 dstalk_tools_service_t vtable 实现 / Provides dstalk_tools_service_t vtable implementation
// 依赖: file_io (内置 file_read / file_write 工具) / Depends on: file_io (built-in file_read / file_write tools)
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h" #include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
@@ -20,21 +27,22 @@
namespace json = boost::json; namespace json = boost::json;
// ============================================================ // ============================================================
// 路径安全校验 (W14.3: 防止路径遍历攻击) // 路径安全校验 (W14.3: 防止路径遍历攻击) / Path safety validation (W14.3: prevent path traversal attacks)
// ============================================================ // ============================================================
// 验证文件路径是否安全(无绝对路径、无 ".." 遍历、非空) / Validate that a file path is safe (no absolute paths, no ".." traversal, no empty).
static bool is_safe_path(const std::string& path) { static bool is_safe_path(const std::string& path) {
// 拒绝空路径 // 拒绝空路径 / Reject empty path
if (path.empty()) return false; if (path.empty()) return false;
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':') // 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':') / Reject absolute paths: Unix '/' prefix or Windows drive letter (second char ':')
if (path[0] == '/' || path[0] == '\\') return false; if (path[0] == '/' || path[0] == '\\') return false;
if (path.size() >= 2 && path[1] == ':') return false; if (path.size() >= 2 && path[1] == ':') return false;
// 拒绝含 ".." 段的目录遍历 // 拒绝含 ".." 段的目录遍历 / Reject directory traversal with ".." segments
if (path.find("..") != std::string::npos) return false; if (path.find("..") != std::string::npos) return false;
// lexical_normal 消解相对组件后再次校验 // lexical_normal 消解相对组件后再次校验 / Re-validate after resolving relative components with lexical_normal
std::string norm = std::filesystem::path(path).lexically_normal().string(); std::string norm = std::filesystem::path(path).lexically_normal().string();
if (norm.empty()) return false; if (norm.empty()) return false;
if (norm[0] == '/' || norm[0] == '\\') return false; if (norm[0] == '/' || norm[0] == '\\') return false;
@@ -45,10 +53,10 @@ static bool is_safe_path(const std::string& path) {
} }
// ============================================================ // ============================================================
// 内部数据结构 // 内部数据结构 / Internal data structures
// ============================================================ // ============================================================
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release读取无锁 // W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
static std::atomic<const dstalk_host_api_t*> g_host{nullptr}; static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr}; static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
@@ -59,14 +67,15 @@ struct ToolDef {
dstalk_tool_handler_fn handler; dstalk_tool_handler_fn handler;
}; };
// W14.3: g_tools 使用 mutex 保护读写 // W14.3: g_tools 使用 mutex 保护读写 / g_tools uses mutex to protect read/write
static std::vector<ToolDef> g_tools; static std::vector<ToolDef> g_tools;
static std::mutex g_tools_mutex; static std::mutex g_tools_mutex;
// ============================================================ // ============================================================
// 内置工具: file_read, file_write // 内置工具: file_read, file_write / Built-in tools: file_read, file_write
// ============================================================ // ============================================================
// 内置工具处理器:读取文件并以 JSON 字符串返回内容 / Built-in tool handler: read a file and return its contents as a JSON string.
static char* builtin_file_read(const char* args_json) { static char* builtin_file_read(const char* args_json) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire); const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
@@ -83,7 +92,7 @@ static char* builtin_file_read(const char* args_json) {
} }
std::string path = json::value_to<std::string>(*path_j); std::string path = json::value_to<std::string>(*path_j);
// W14.3: 路径遍历防护 // W14.3: 路径遍历防护 / Path traversal protection
if (!is_safe_path(path)) { if (!is_safe_path(path)) {
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unsafe path rejected"); if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unsafe path rejected");
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr; return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
@@ -110,6 +119,7 @@ static char* builtin_file_read(const char* args_json) {
} }
} }
// 内置工具处理器:将内容写入文件,返回成功/错误 JSON 对象 / Built-in tool handler: write content to a file, returning a success/error JSON object.
static char* builtin_file_write(const char* args_json) { static char* builtin_file_write(const char* args_json) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire); const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
@@ -132,7 +142,7 @@ static char* builtin_file_write(const char* args_json) {
std::string path = json::value_to<std::string>(*path_j); std::string path = json::value_to<std::string>(*path_j);
std::string content = json::value_to<std::string>(*content_j); std::string content = json::value_to<std::string>(*content_j);
// W14.3: 路径遍历防护 // W14.3: 路径遍历防护 / Path traversal protection
if (!is_safe_path(path)) { if (!is_safe_path(path)) {
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unsafe path rejected"); if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unsafe path rejected");
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr; return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
@@ -155,18 +165,19 @@ static char* builtin_file_write(const char* args_json) {
} }
// ============================================================ // ============================================================
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex) // Tools 服务 vtable 实现 (W14.3: try/catch + mutex) / Tools service vtable implementation (W14.3: try/catch + mutex)
// ============================================================ // ============================================================
static void tools_unregister_tool(const char* name); static void tools_unregister_tool(const char* name);
// 注册命名工具及其描述、JSON Schema 参数和处理函数 / Register a named tool with its description, JSON Schema parameters, and handler function.
static int tools_register_tool(const char* name, const char* desc, static int tools_register_tool(const char* name, const char* desc,
const char* params_schema, const char* params_schema,
dstalk_tool_handler_fn handler) { dstalk_tool_handler_fn handler) {
try { try {
if (!name || !handler) return -1; if (!name || !handler) return -1;
// 如果已存在同名工具,先注销 // 如果已存在同名工具,先注销 / If a tool with the same name exists, unregister first
tools_unregister_tool(name); tools_unregister_tool(name);
ToolDef td; ToolDef td;
@@ -189,6 +200,7 @@ static int tools_register_tool(const char* name, const char* desc,
} }
} }
// 按名称注销之前注册的工具 / Unregister a previously registered tool by name.
static void tools_unregister_tool(const char* name) { static void tools_unregister_tool(const char* name) {
try { try {
if (!name) return; if (!name) return;
@@ -207,6 +219,7 @@ static void tools_unregister_tool(const char* name) {
} }
} }
// 将所有已注册工具序列化为 OpenAI function-calling 格式的 JSON 数组 / Serialize all registered tools into a JSON array in OpenAI function-calling format.
static char* tools_get_tools_json() { static char* tools_get_tools_json() {
try { try {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
@@ -249,6 +262,7 @@ static char* tools_get_tools_json() {
} }
} }
// 按名称查找工具并分派执行到注册的处理器 / Look up a tool by name and dispatch execution to its registered handler.
static char* tools_execute(const char* name, const char* args_json) { static char* tools_execute(const char* name, const char* args_json) {
try { try {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
@@ -298,14 +312,15 @@ static dstalk_tools_service_t g_tools_service = {
}; };
// ============================================================ // ============================================================
// 插件生命周期 // 插件生命周期 / Plugin lifecycle
// ============================================================ // ============================================================
// 插件初始化:查询 file_io 依赖,注册内置文件工具,注册 tools 服务 / Plugin init: query file_io dependency, register built-in file tools, register tools service.
static int on_init(const dstalk_host_api_t* host) { static int on_init(const dstalk_host_api_t* host) {
try { try {
g_host.store(host, std::memory_order_release); g_host.store(host, std::memory_order_release);
// 查询依赖服务: file_io // 查询依赖服务: file_io / Query dependency service: file_io
void* raw = host->query_service("file_io", 1); void* raw = host->query_service("file_io", 1);
if (!raw) { if (!raw) {
host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found"); host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found");
@@ -313,7 +328,7 @@ static int on_init(const dstalk_host_api_t* host) {
} }
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release); g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
// 向自身注册内置工具 // 向自身注册内置工具 / Register built-in tools with self
tools_register_tool( tools_register_tool(
"file_read", "file_read",
"Read the contents of a file at the given path", "Read the contents of a file at the given path",
@@ -340,6 +355,7 @@ static int on_init(const dstalk_host_api_t* host) {
} }
} }
// 插件关闭:清空所有已注册工具并清空服务指针 / Plugin shutdown: clear all registered tools and null out service pointers.
static void on_shutdown() { static void on_shutdown() {
try { try {
std::lock_guard<std::mutex> lock(g_tools_mutex); std::lock_guard<std::mutex> lock(g_tools_mutex);
@@ -358,7 +374,7 @@ static void on_shutdown() {
static dstalk_plugin_info_t g_info = { static dstalk_plugin_info_t g_info = {
"tools", "tools",
"1.0.0", "1.0.0",
"Tool registration and execution plugin with built-in file tools", "Tool registration and execution plugin with built-in file tools / 内置文件工具的工具注册和执行插件",
DSTALK_API_VERSION, DSTALK_API_VERSION,
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, {"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
on_init, on_init,
@@ -366,6 +382,7 @@ static dstalk_plugin_info_t g_info = {
nullptr nullptr
}; };
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info; return &g_info;
} }

View File

@@ -1,8 +1,12 @@
// ============================================================================ /*
// anthropic_plugin_test.cpp — Anthropic AI 插件单元测试 * @file anthropic_plugin_test.cpp
// W21.6 (qa-wang): 覆盖 SSE 解析 / JSON 请求构建 / URL 解析 / 安全擦除 * @brief Anthropic AI plugin unit tests: SSE parsing (parse_sse_data edge cases),
// 通过 #include plugin source 访问 file-scope static 函数 * request building (build_request_json), header construction, URL parsing
// ============================================================================ * (extract_host_port), secure_zero, and null-safety for free_result/configure.
* Anthropic AI 插件单元测试SSE 解析parse_sse_data 边界情况、请求构建build_request_json
* 头部构造、URL 解析extract_host_port、secure_zero、free_result/configure 空指针安全。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#define BOOST_JSON_HEADER_ONLY #define BOOST_JSON_HEADER_ONLY
#define BOOST_ALL_NO_LIB #define BOOST_ALL_NO_LIB
#include "../plugins/anthropic/src/anthropic_plugin.cpp" #include "../plugins/anthropic/src/anthropic_plugin.cpp"
@@ -12,6 +16,7 @@
#include <string> #include <string>
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion macro: increments g_failures counter on failure
#define CHECK(cond, msg) do { \ #define CHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -22,6 +27,8 @@ static int g_failures = 0;
} while (0) } while (0)
// Test helper: populate g_cfg for build functions // Test helper: populate g_cfg for build functions
// Test helper: populate g_cfg with valid anthropic defaults before build_* tests
// 测试辅助函数:为 build_* 测试填充 g_cfg 的有效 anthropic 默认值
static void setup_config() { static void setup_config() {
g_cfg.provider = "anthropic"; g_cfg.provider = "anthropic";
g_cfg.base_url = "https://api.anthropic.com"; g_cfg.base_url = "https://api.anthropic.com";
@@ -31,10 +38,18 @@ static void setup_config() {
g_cfg.temperature = 0.7; g_cfg.temperature = 0.7;
} }
// Anthropic 插件测试 (W21.6)parse_sse_data 畸形/无效 JSON、content_block_delta 文本提取、
// message_stop/忽略类型、深层/边界结构、build_request_json 基础+边界、build_headers_json、
// extract_host_port、secure_zero、my_free_result 空指针安全、my_configure 空指针安全。
// Anthropic plugin tests (W21.6): parse_sse_data for malformed/invalid JSON,
// content_block_delta text extraction, message_stop/ignored types, deep/edge structures,
// build_request_json basics+edges, build_headers_json, extract_host_port,
// secure_zero, my_free_result null-safety, and my_configure null-safety.
int main() int main()
{ {
// ================================================================ // ================================================================
// Test Block 1: parse_sse_data — invalid/malformed inputs // Test Block 1: parse_sse_data — invalid/malformed inputs
// 测试块 1parse_sse_data — 无效/畸形输入
// ================================================================ // ================================================================
std::cout << "\n--- Block 1: parse_sse_data invalid/malformed ---\n"; std::cout << "\n--- Block 1: parse_sse_data invalid/malformed ---\n";
@@ -69,14 +84,14 @@ int main()
} }
{ {
// Malformed JSON: unclosed brace // Malformed JSON: unclosed brace / 畸形 JSON未闭合的花括号
std::string token; std::string token;
bool ret = parse_sse_data("{\"type\":\"ping\"", token, nullptr); bool ret = parse_sse_data("{\"type\":\"ping\"", token, nullptr);
CHECK(!ret, "T1.6: malformed JSON (unclosed brace) returns false (no crash)"); CHECK(!ret, "T1.6: malformed JSON (unclosed brace) returns false (no crash)");
} }
{ {
// Random garbage bytes // Random garbage bytes / 随机垃圾字节
std::string token; std::string token;
bool ret = parse_sse_data("\x00\x01\xFF\xFE", token, nullptr); bool ret = parse_sse_data("\x00\x01\xFF\xFE", token, nullptr);
CHECK(!ret, "T1.7: binary garbage returns false (no crash)"); CHECK(!ret, "T1.7: binary garbage returns false (no crash)");
@@ -84,6 +99,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 2: parse_sse_data — content_block_delta // Test Block 2: parse_sse_data — content_block_delta
// 测试块 2parse_sse_data — content_block_delta
// ================================================================ // ================================================================
std::cout << "\n--- Block 2: parse_sse_data content_block_delta ---\n"; std::cout << "\n--- Block 2: parse_sse_data content_block_delta ---\n";
@@ -146,6 +162,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 3: parse_sse_data — message_stop / ignored types // Test Block 3: parse_sse_data — message_stop / ignored types
// 测试块 3parse_sse_data — message_stop / 忽略的类型
// ================================================================ // ================================================================
std::cout << "\n--- Block 3: parse_sse_data message_stop / ignored types ---\n"; std::cout << "\n--- Block 3: parse_sse_data message_stop / ignored types ---\n";
@@ -194,11 +211,12 @@ int main()
// ================================================================ // ================================================================
// Test Block 4: parse_sse_data — deeply nested / edge structures // Test Block 4: parse_sse_data — deeply nested / edge structures
// 测试块 4parse_sse_data — 深层嵌套 / 边界结构
// ================================================================ // ================================================================
std::cout << "\n--- Block 4: parse_sse_data deep/edge structures ---\n"; std::cout << "\n--- Block 4: parse_sse_data deep/edge structures ---\n";
{ {
// Unrecognized event type should just be ignored // Unrecognized event type should just be ignored / 未识别的事件类型应被忽略
std::string token; std::string token;
const char* json = "{\"type\":\"some_unknown_future_type\"}"; const char* json = "{\"type\":\"some_unknown_future_type\"}";
bool ret = parse_sse_data(json, token, nullptr); bool ret = parse_sse_data(json, token, nullptr);
@@ -206,7 +224,7 @@ int main()
} }
{ {
// text_delta with unicode content (Japanese) // text_delta with unicode content (Japanese) / 含 unicode 内容的 text_delta日语
std::string token; std::string token;
const char* json = const char* json =
"{\"type\":\"content_block_delta\"," "{\"type\":\"content_block_delta\","
@@ -221,6 +239,7 @@ int main()
{ {
// Realistic Anthropic SSE chunk (content_block_delta + text_delta) // Realistic Anthropic SSE chunk (content_block_delta + text_delta)
// 真实的 Anthropic SSE 数据块content_block_delta + text_delta
std::string token; std::string token;
const char* json = const char* json =
"{\"type\":\"content_block_delta\"," "{\"type\":\"content_block_delta\","
@@ -233,12 +252,13 @@ int main()
// ================================================================ // ================================================================
// Test Block 5: build_request_json — basic cases // Test Block 5: build_request_json — basic cases
// 测试块 5build_request_json — 基础用例
// ================================================================ // ================================================================
setup_config(); setup_config();
std::cout << "\n--- Block 5: build_request_json basic ---\n"; std::cout << "\n--- Block 5: build_request_json basic ---\n";
{ {
// Single user input, no history, stream=false // Single user input, no history, stream=false / 单一用户输入无历史stream=false
std::string json = build_request_json(nullptr, 0, "Hello", "", false); std::string json = build_request_json(nullptr, 0, "Hello", "", false);
CHECK(!json.empty(), "T5.1: non-empty JSON produced"); CHECK(!json.empty(), "T5.1: non-empty JSON produced");
CHECK(json.find("\"messages\"") != std::string::npos, CHECK(json.find("\"messages\"") != std::string::npos,
@@ -257,7 +277,7 @@ int main()
} }
{ {
// With system message in history // With system message in history / 历史中包含系统消息
dstalk_message_t msgs[1] = { dstalk_message_t msgs[1] = {
{"system", "You are a helpful assistant", nullptr, nullptr} {"system", "You are a helpful assistant", nullptr, nullptr}
}; };
@@ -268,6 +288,7 @@ int main()
"T5.9: system prompt content present"); "T5.9: system prompt content present");
// messages should NOT contain the system role // messages should NOT contain the system role
// (since system messages are stripped from messages[] and put in system field) // (since system messages are stripped from messages[] and put in system field)
// messages 不应包含 system 角色(系统消息从 messages[] 中提取出来,放入 system 字段)
// Actually, the code puts non-system into msgs. Let me check if system is in messages... // Actually, the code puts non-system into msgs. Let me check if system is in messages...
// The loop skips system: `if (m.role && strcmp(m.role, "system")==0) { ... continue; }` // The loop skips system: `if (m.role && strcmp(m.role, "system")==0) { ... continue; }`
// So system should NOT be in the messages array. // So system should NOT be in the messages array.
@@ -276,7 +297,7 @@ int main()
} }
{ {
// With user+assistant history // With user+assistant history / 包含 user+assistant 历史
dstalk_message_t msgs[2] = { dstalk_message_t msgs[2] = {
{"user", "What is 2+2?", nullptr, nullptr}, {"user", "What is 2+2?", nullptr, nullptr},
{"assistant", "It is 4.", nullptr, nullptr} {"assistant", "It is 4.", nullptr, nullptr}
@@ -292,11 +313,12 @@ int main()
// ================================================================ // ================================================================
// Test Block 6: build_request_json — edge cases // Test Block 6: build_request_json — edge cases
// 测试块 6build_request_json — 边界情况
// ================================================================ // ================================================================
std::cout << "\n--- Block 6: build_request_json edge cases ---\n"; std::cout << "\n--- Block 6: build_request_json edge cases ---\n";
{ {
// Empty user input // Empty user input / 空用户输入
std::string json = build_request_json(nullptr, 0, "", "", false); std::string json = build_request_json(nullptr, 0, "", "", false);
CHECK(!json.empty(), "T6.1: empty user input produces valid JSON"); CHECK(!json.empty(), "T6.1: empty user input produces valid JSON");
CHECK(json.find("\"role\":\"user\"") != std::string::npos, CHECK(json.find("\"role\":\"user\"") != std::string::npos,
@@ -314,7 +336,7 @@ int main()
} }
{ {
// Temperature in valid range -> should be included // Temperature in valid range -> should be included / 有效范围内的 temperature -> 应包含
g_cfg.temperature = 1.0; g_cfg.temperature = 1.0;
std::string json = build_request_json(nullptr, 0, "Hi", "", false); std::string json = build_request_json(nullptr, 0, "Hi", "", false);
CHECK(json.find("\"temperature\"") != std::string::npos, CHECK(json.find("\"temperature\"") != std::string::npos,
@@ -323,7 +345,7 @@ int main()
} }
{ {
// Temperature out of range -> should NOT be included // Temperature out of range -> should NOT be included / 超出范围的 temperature -> 不应包含
g_cfg.temperature = 1.5; g_cfg.temperature = 1.5;
std::string json = build_request_json(nullptr, 0, "Hi", "", false); std::string json = build_request_json(nullptr, 0, "Hi", "", false);
CHECK(json.find("\"temperature\"") == std::string::npos, CHECK(json.find("\"temperature\"") == std::string::npos,
@@ -336,7 +358,7 @@ int main()
} }
{ {
// History with null role (should default to "") // History with null role (should default to "") / null 角色的历史(应默认为 ""
dstalk_message_t msgs[1] = { dstalk_message_t msgs[1] = {
{nullptr, "some content", nullptr, nullptr} {nullptr, "some content", nullptr, nullptr}
}; };
@@ -345,7 +367,7 @@ int main()
} }
{ {
// History with null content // History with null content / null 内容的历史
dstalk_message_t msgs[1] = { dstalk_message_t msgs[1] = {
{"user", nullptr, nullptr, nullptr} {"user", nullptr, nullptr, nullptr}
}; };
@@ -355,6 +377,7 @@ int main()
{ {
// Very long message (>2000 chars) — validate no truncation / crash // Very long message (>2000 chars) — validate no truncation / crash
// 超长消息 (>2000 字符) — 验证无截断/崩溃
std::string long_input(5000, 'A'); std::string long_input(5000, 'A');
std::string json = build_request_json(nullptr, 0, long_input, "", false); std::string json = build_request_json(nullptr, 0, long_input, "", false);
CHECK(!json.empty(), "T6.10: 5000-char input produces valid JSON"); CHECK(!json.empty(), "T6.10: 5000-char input produces valid JSON");
@@ -362,7 +385,7 @@ int main()
} }
{ {
// Multiple system messages concatenated // Multiple system messages concatenated / 多条系统消息拼接
dstalk_message_t msgs[2] = { dstalk_message_t msgs[2] = {
{"system", "Rule 1: be polite", nullptr, nullptr}, {"system", "Rule 1: be polite", nullptr, nullptr},
{"system", "Rule 2: be concise", nullptr, nullptr} {"system", "Rule 2: be concise", nullptr, nullptr}
@@ -376,6 +399,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 7: build_headers_json // Test Block 7: build_headers_json
// 测试块 7build_headers_json
// ================================================================ // ================================================================
std::cout << "\n--- Block 7: build_headers_json ---\n"; std::cout << "\n--- Block 7: build_headers_json ---\n";
@@ -392,7 +416,7 @@ int main()
} }
{ {
// With empty API key // With empty API key / 空 API key
std::string saved = g_cfg.api_key; std::string saved = g_cfg.api_key;
g_cfg.api_key = ""; g_cfg.api_key = "";
std::string headers = build_headers_json(); std::string headers = build_headers_json();
@@ -405,6 +429,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 8: extract_host_port // Test Block 8: extract_host_port
// 测试块 8extract_host_port
// ================================================================ // ================================================================
std::cout << "\n--- Block 8: extract_host_port ---\n"; std::cout << "\n--- Block 8: extract_host_port ---\n";
@@ -473,6 +498,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 9: secure_zero // Test Block 9: secure_zero
// 测试块 9secure_zero
// ================================================================ // ================================================================
std::cout << "\n--- Block 9: secure_zero ---\n"; std::cout << "\n--- Block 9: secure_zero ---\n";
@@ -488,7 +514,7 @@ int main()
} }
{ {
// Zero-length should not crash // Zero-length should not crash / 零长度不应崩溃
char buf[4] = {1,2,3,4}; char buf[4] = {1,2,3,4};
secure_zero(buf, 0); secure_zero(buf, 0);
CHECK(buf[0] == 1 && buf[3] == 4, CHECK(buf[0] == 1 && buf[3] == 4,
@@ -496,18 +522,20 @@ int main()
} }
{ {
// Null pointer + zero length = no-op // Null pointer + zero length = no-op / 空指针 + 零长度 = 空操作
secure_zero(nullptr, 0); secure_zero(nullptr, 0);
CHECK(true, "T9.3: secure_zero(nullptr, 0) does not crash"); CHECK(true, "T9.3: secure_zero(nullptr, 0) does not crash");
} }
// ================================================================ // ================================================================
// Test Block 10: my_free_result — null safety // Test Block 10: my_free_result — null safety
// 测试块 10my_free_result — 空指针安全
// ================================================================ // ================================================================
std::cout << "\n--- Block 10: my_free_result null safety ---\n"; std::cout << "\n--- Block 10: my_free_result null safety ---\n";
{ {
// g_host is nullptr, so free_result should early-return // g_host is nullptr, so free_result should early-return
// g_host 为 nullptrfree_result 应提前返回
my_free_result(nullptr); my_free_result(nullptr);
CHECK(true, "T10.1: free_result(nullptr) does not crash (null host)"); CHECK(true, "T10.1: free_result(nullptr) does not crash (null host)");
} }
@@ -524,11 +552,13 @@ int main()
// ================================================================ // ================================================================
// Test Block 11: my_configure — null host safety // Test Block 11: my_configure — null host safety
// 测试块 11my_configure — null host 安全
// ================================================================ // ================================================================
std::cout << "\n--- Block 11: my_configure null host safety ---\n"; std::cout << "\n--- Block 11: my_configure null host safety ---\n";
{ {
// g_host is nullptr, configure should still return 0 (log skipped) // g_host is nullptr, configure should still return 0 (log skipped)
// g_host 为 nullptrconfigure 仍应返回 0跳过日志
int ret = my_configure( int ret = my_configure(
"anthropic", "https://api.anthropic.com", "anthropic", "https://api.anthropic.com",
"sk-key", "claude-sonnet", 2048, 0.5); "sk-key", "claude-sonnet", 2048, 0.5);
@@ -539,13 +569,13 @@ int main()
} }
{ {
// Null string params — should not crash // Null string params — should not crash / null 字符串参数 — 不应崩溃
int ret = my_configure(nullptr, nullptr, nullptr, nullptr, 4096, 0.7); int ret = my_configure(nullptr, nullptr, nullptr, nullptr, 4096, 0.7);
CHECK(ret == 0, "T11.5: my_configure with all-null strings returns 0"); CHECK(ret == 0, "T11.5: my_configure with all-null strings returns 0");
} }
// ================================================================ // ================================================================
// Summary // Summary / 总结
// ================================================================ // ================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,9 +1,10 @@
// ============================================================================ /*
// context_plugin_test.cpp — 上下文插件单元测试 * @file context_plugin_test.cpp
// ============================================================================ * @brief Context plugin unit tests: token counting (ASCII, CJK, mixed, emoji),
// W18.1 (qa-wang + architect-lin): 覆盖 token 计数、trim、UTF-8 边界、 * UTF-8 truncation safety, trim edge cases, and system message preservation.
// 0xC0/0xC1 过短编码检测。修复 F-11.1-3/4/5/6 后补充测试 * Context 插件单元测试token 计数ASCII、CJK、混合、emoji、UTF-8 截断安全、trim 边界情况、系统消息保留
// ============================================================================ * Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
@@ -14,6 +15,7 @@
#include <string> #include <string>
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion macro: increments g_failures counter on failure
#define CHECK(cond, msg) do { \ #define CHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -23,6 +25,12 @@ static int g_failures = 0;
} \ } \
} while (0) } while (0)
// Context 插件测试token 计数边界null、空、ASCII、CJK、混合、截断 UTF-8 边界保护 (F-11.1-4)、
// 0xC0/0xC1 超长编码 (F-11.1-6)、多消息 token、trim 的各种场景、系统消息保留、4 字节 emoji、孤立的续字节。
// Context plugin tests: token counting edge cases (null, empty, ASCII, CJK, mixed),
// truncated UTF-8 bounds protection (F-11.1-4), 0xC0/0xC1 overlong encoding (F-11.1-6),
// multiple-message tokens, trim null/edge/within-limit/exceeds-limit scenarios,
// system message preservation, 4-byte emoji, and lone continuation bytes.
int main() int main()
{ {
const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test"; const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test";
@@ -54,6 +62,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 1: count_tokens edge cases (null / empty) // Test Block 1: count_tokens edge cases (null / empty)
// 测试块 1count_tokens 边界情况null / 空)
// ================================================================ // ================================================================
std::cout << "\n--- Block 1: count_tokens edge cases ---\n"; std::cout << "\n--- Block 1: count_tokens edge cases ---\n";
@@ -77,6 +86,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 2: count_tokens — ASCII // Test Block 2: count_tokens — ASCII
// 测试块 2count_tokens — ASCII
// ================================================================ // ================================================================
std::cout << "\n--- Block 2: count_tokens ASCII ---\n"; std::cout << "\n--- Block 2: count_tokens ASCII ---\n";
@@ -107,6 +117,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF) // Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF)
// 测试块 3count_tokens — 中文 (CJK U+4E00-U+9FFF)
// ================================================================ // ================================================================
std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n"; std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n";
@@ -132,6 +143,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 4: count_tokens — Mixed content // Test Block 4: count_tokens — Mixed content
// 测试块 4count_tokens — 混合内容
// ================================================================ // ================================================================
std::cout << "\n--- Block 4: count_tokens mixed content ---\n"; std::cout << "\n--- Block 4: count_tokens mixed content ---\n";
@@ -146,6 +158,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4) // Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4)
// 测试块 5截断 UTF-8 边界保护 (F-11.1-4)
// ================================================================ // ================================================================
std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n"; std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n";
@@ -197,6 +210,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6) // Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6)
// 测试块 60xC0/0xC1 超长编码 (F-11.1-6)
// ================================================================ // ================================================================
std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n"; std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n";
@@ -230,6 +244,7 @@ int main()
{ {
// Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences // Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences
// They should each count as 1 other_char, not as 2-byte sequence // They should each count as 1 other_char, not as 2-byte sequence
// 验证 0xC0/0xC1 不被视为合法的 2 字节序列 / 它们每个应计为 1 个 other_char而非 2 字节序列
// 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii // 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii
// = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4 // = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4
// Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other // Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other
@@ -244,6 +259,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 7: count_tokens — multiple messages // Test Block 7: count_tokens — multiple messages
// 测试块 7count_tokens — 多消息
// ================================================================ // ================================================================
std::cout << "\n--- Block 7: multiple messages ---\n"; std::cout << "\n--- Block 7: multiple messages ---\n";
@@ -275,6 +291,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 8: trim — null and edge cases // Test Block 8: trim — null and edge cases
// 测试块 8trim — null 和边界情况
// ================================================================ // ================================================================
std::cout << "\n--- Block 8: trim edge cases ---\n"; std::cout << "\n--- Block 8: trim edge cases ---\n";
@@ -291,6 +308,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 9: trim — within limit (no trimming needed) // Test Block 9: trim — within limit (no trimming needed)
// 测试块 9trim — 预算内(无需裁剪)
// ================================================================ // ================================================================
std::cout << "\n--- Block 9: trim within limit ---\n"; std::cout << "\n--- Block 9: trim within limit ---\n";
@@ -320,6 +338,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 10: trim — exceeds limit (trimming required) // Test Block 10: trim — exceeds limit (trimming required)
// 测试块 10trim — 超预算(需要裁剪)
// ================================================================ // ================================================================
std::cout << "\n--- Block 10: trim exceeds limit ---\n"; std::cout << "\n--- Block 10: trim exceeds limit ---\n";
@@ -358,6 +377,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 11: trim — system message preservation // Test Block 11: trim — system message preservation
// 测试块 11trim — 系统消息保留
// ================================================================ // ================================================================
std::cout << "\n--- Block 11: trim preserves system messages ---\n"; std::cout << "\n--- Block 11: trim preserves system messages ---\n";
@@ -387,11 +407,12 @@ int main()
// ================================================================ // ================================================================
// Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary) // Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary)
// 测试块 12count_tokens — 4 字节 UTF-8emoji / 补充平面)
// ================================================================ // ================================================================
std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n"; std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n";
{ {
// U+1F600 (😀) = F0 9F 98 80 // U+1F600 (<EFBFBD><EFBFBD>) = F0 9F 98 80
char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F), char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F),
static_cast<char>(0x98), static_cast<char>(0x80), '\0'}; static_cast<char>(0x98), static_cast<char>(0x80), '\0'};
dstalk_message_t msg = {"user", buf, nullptr, nullptr}; dstalk_message_t msg = {"user", buf, nullptr, nullptr};
@@ -403,6 +424,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 13: count_tokens — continuation bytes as lone chars // Test Block 13: count_tokens — continuation bytes as lone chars
// 测试块 13count_tokens — 孤立的续字节
// ================================================================ // ================================================================
std::cout << "\n--- Block 13: lone continuation bytes ---\n"; std::cout << "\n--- Block 13: lone continuation bytes ---\n";

View File

@@ -1,8 +1,12 @@
// ============================================================================ /*
// deepseek_plugin_test.cpp — DeepSeek AI 插件单元测试 * @file deepseek_plugin_test.cpp
// W21.6 (qa-wang): 覆盖 SSE 解析 / [DONE] 匹配 / JSON 请求构建 / tool_calls * @brief DeepSeek AI plugin unit tests: SSE parsing (parse_sse_line edge cases),
// 通过 #include plugin source 访问 file-scope static 函数 * [DONE] sentinel matching, tool_calls delta extraction, request building,
// ============================================================================ * append_history, extract_host_port, secure_zero, and null-safety.
* DeepSeek AI 插件单元测试SSE 解析parse_sse_line 边界情况)、[DONE] 标记匹配、
* tool_calls delta 提取、请求构建、append_history、extract_host_port、secure_zero、空指针安全。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#define BOOST_JSON_HEADER_ONLY #define BOOST_JSON_HEADER_ONLY
#define BOOST_ALL_NO_LIB #define BOOST_ALL_NO_LIB
#include "../plugins/deepseek/src/deepseek_plugin.cpp" #include "../plugins/deepseek/src/deepseek_plugin.cpp"
@@ -12,6 +16,7 @@
#include <string> #include <string>
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion macro: increments g_failures counter on failure
#define CHECK(cond, msg) do { \ #define CHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -22,6 +27,8 @@ static int g_failures = 0;
} while (0) } while (0)
// Test helper: populate g_cfg for build functions // Test helper: populate g_cfg for build functions
// Test helper: populate g_cfg with valid deepseek defaults before build_* tests
// 测试辅助函数:为 build_* 测试填充 g_cfg 的有效 deepseek 默认值
static void setup_config() { static void setup_config() {
g_cfg.provider = "deepseek"; g_cfg.provider = "deepseek";
g_cfg.base_url = "https://api.deepseek.com/v1"; g_cfg.base_url = "https://api.deepseek.com/v1";
@@ -31,10 +38,19 @@ static void setup_config() {
g_cfg.temperature = 0.7; g_cfg.temperature = 0.7;
} }
// DeepSeek 插件测试 (W21.6)parse_sse_line 无效/畸形输入、[DONE] 标记及空白变体、
// content delta 提取、tool_calls delta 累积、build_request_json基础、tools、边界
// build_headers_json、extract_host_port、secure_zero、append_history所有消息类型
// my_free_result、my_configure。
// DeepSeek plugin tests (W21.6): parse_sse_line invalid/malformed inputs, [DONE] sentinel
// with whitespace variants, content delta extraction, tool_calls delta accumulation,
// build_request_json (basic, tools, edge cases), build_headers_json, extract_host_port,
// secure_zero, append_history (all message types), my_free_result, and my_configure.
int main() int main()
{ {
// ================================================================ // ================================================================
// Test Block 1: parse_sse_line — invalid/malformed inputs // Test Block 1: parse_sse_line — invalid/malformed inputs
// 测试块 1parse_sse_line — 无效/畸形输入
// ================================================================ // ================================================================
std::cout << "\n--- Block 1: parse_sse_line invalid/malformed ---\n"; std::cout << "\n--- Block 1: parse_sse_line invalid/malformed ---\n";
@@ -58,27 +74,28 @@ int main()
{ {
// "data:" without space — rfind("data: ", 0) should fail // "data:" without space — rfind("data: ", 0) should fail
// "data:" 无空格 — rfind("data: ", 0) 应失败
std::string token; std::string token;
bool ret = parse_sse_line("data:{\"x\":1}", token, nullptr); bool ret = parse_sse_line("data:{\"x\":1}", token, nullptr);
CHECK(!ret, "T1.4: 'data:' without trailing space returns false (rfind mismatch)"); CHECK(!ret, "T1.4: 'data:' without trailing space returns false (rfind mismatch)");
} }
{ {
// "data: " followed by invalid JSON // "data: " followed by invalid JSON / "data: " 后跟无效 JSON
std::string token; std::string token;
bool ret = parse_sse_line("data: not valid json!!!", token, nullptr); bool ret = parse_sse_line("data: not valid json!!!", token, nullptr);
CHECK(!ret, "T1.5: 'data: ' + invalid JSON returns false (no crash)"); CHECK(!ret, "T1.5: 'data: ' + invalid JSON returns false (no crash)");
} }
{ {
// "data: " followed by binary garbage // "data: " followed by binary garbage / "data: " 后跟二进制垃圾
std::string token; std::string token;
bool ret = parse_sse_line("data: \x00\x01\xFF\xFE", token, nullptr); bool ret = parse_sse_line("data: \x00\x01\xFF\xFE", token, nullptr);
CHECK(!ret, "T1.6: 'data: ' + binary garbage returns false (no crash)"); CHECK(!ret, "T1.6: 'data: ' + binary garbage returns false (no crash)");
} }
{ {
// Empty data after "data: " // Empty data after "data: " / "data: " 后数据为空
std::string token; std::string token;
bool ret = parse_sse_line("data: ", token, nullptr); bool ret = parse_sse_line("data: ", token, nullptr);
CHECK(!ret, "T1.7: 'data: ' with empty payload returns false"); CHECK(!ret, "T1.7: 'data: ' with empty payload returns false");
@@ -86,6 +103,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 2: parse_sse_line — [DONE] sentinel // Test Block 2: parse_sse_line — [DONE] sentinel
// 测试块 2parse_sse_line — [DONE] 标记
// ================================================================ // ================================================================
std::cout << "\n--- Block 2: parse_sse_line [DONE] sentinel ---\n"; std::cout << "\n--- Block 2: parse_sse_line [DONE] sentinel ---\n";
@@ -97,7 +115,7 @@ int main()
} }
{ {
// [DONE] with leading whitespace // [DONE] with leading whitespace / [DONE] 前导空白
std::string token; std::string token;
bool ret = parse_sse_line("data: [DONE]", token, nullptr); bool ret = parse_sse_line("data: [DONE]", token, nullptr);
CHECK(ret, "T2.3: 'data: [DONE]' (leading spaces) returns true"); CHECK(ret, "T2.3: 'data: [DONE]' (leading spaces) returns true");
@@ -105,7 +123,7 @@ int main()
} }
{ {
// [DONE] with trailing whitespace // [DONE] with trailing whitespace / [DONE] 尾部空白
std::string token; std::string token;
bool ret = parse_sse_line("data: [DONE] ", token, nullptr); bool ret = parse_sse_line("data: [DONE] ", token, nullptr);
CHECK(ret, "T2.5: 'data: [DONE] ' (trailing spaces) returns true"); CHECK(ret, "T2.5: 'data: [DONE] ' (trailing spaces) returns true");
@@ -113,7 +131,7 @@ int main()
} }
{ {
// [DONE] with tabs and newlines around it // [DONE] with tabs and newlines around it / [DONE] 周围有制表符和换行符
std::string token; std::string token;
bool ret = parse_sse_line("data: \t [DONE] \t\r\n", token, nullptr); bool ret = parse_sse_line("data: \t [DONE] \t\r\n", token, nullptr);
CHECK(ret, "T2.7: '[DONE]' with mixed whitespace returns true"); CHECK(ret, "T2.7: '[DONE]' with mixed whitespace returns true");
@@ -121,7 +139,7 @@ int main()
} }
{ {
// [DONE] without spaces — exact match // [DONE] without spaces — exact match / [DONE] 精确匹配(无空格)
std::string token; std::string token;
bool ret = parse_sse_line("data: [DONE]", token, nullptr); bool ret = parse_sse_line("data: [DONE]", token, nullptr);
CHECK(ret, "T2.9: '[DONE]' exact match returns true"); CHECK(ret, "T2.9: '[DONE]' exact match returns true");
@@ -129,13 +147,14 @@ int main()
{ {
// "[done]" lowercase — should NOT match (case-sensitive) // "[done]" lowercase — should NOT match (case-sensitive)
// "[done]" 小写 — 不应匹配(大小写敏感)
std::string token; std::string token;
bool ret = parse_sse_line("data: [done]", token, nullptr); bool ret = parse_sse_line("data: [done]", token, nullptr);
CHECK(!ret, "T2.10: '[done]' lowercase NOT treated as DONE (case-sensitive)"); CHECK(!ret, "T2.10: '[done]' lowercase NOT treated as DONE (case-sensitive)");
} }
{ {
// "[DONE" without closing bracket // "[DONE" without closing bracket / "[DONE" 缺少闭括号
std::string token; std::string token;
bool ret = parse_sse_line("data: [DONE", token, nullptr); bool ret = parse_sse_line("data: [DONE", token, nullptr);
CHECK(!ret, "T2.11: '[DONE' (no closing bracket) not treated as DONE"); CHECK(!ret, "T2.11: '[DONE' (no closing bracket) not treated as DONE");
@@ -143,6 +162,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 3: parse_sse_line — content delta // Test Block 3: parse_sse_line — content delta
// 测试块 3parse_sse_line — content delta
// ================================================================ // ================================================================
std::cout << "\n--- Block 3: parse_sse_line content delta ---\n"; std::cout << "\n--- Block 3: parse_sse_line content delta ---\n";
@@ -166,7 +186,7 @@ int main()
} }
{ {
// Delta with no content field // Delta with no content field / delta 不含 content 字段
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[{\"delta\":{},\"index\":0}]}"; "data: {\"choices\":[{\"delta\":{},\"index\":0}]}";
@@ -175,7 +195,7 @@ int main()
} }
{ {
// Empty choices array // Empty choices array / 空 choices 数组
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[]}"; "data: {\"choices\":[]}";
@@ -184,7 +204,7 @@ int main()
} }
{ {
// Single character token (typical streaming) // Single character token (typical streaming) / 单字符 token典型流式
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[{\"delta\":{\"content\":\"H\"},\"index\":0}]}"; "data: {\"choices\":[{\"delta\":{\"content\":\"H\"},\"index\":0}]}";
@@ -194,7 +214,7 @@ int main()
} }
{ {
// Multi-byte UTF-8 content (emoji) in delta // Multi-byte UTF-8 content (emoji) in delta / delta 中的多字节 UTF-8 内容emoji
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[{\"delta\":{\"content\":\"\\uD83D\\uDE00\"}," "data: {\"choices\":[{\"delta\":{\"content\":\"\\uD83D\\uDE00\"},"
@@ -207,7 +227,7 @@ int main()
} }
{ {
// Malformed JSON structure — no "delta" key // Malformed JSON structure — no "delta" key / 畸形 JSON 结构 — 无 "delta" key
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[{\"no_delta\":{},\"index\":0}]}"; "data: {\"choices\":[{\"no_delta\":{},\"index\":0}]}";
@@ -217,6 +237,7 @@ int main()
{ {
// Realistic DeepSeek streaming chunk (with finish_reason) // Realistic DeepSeek streaming chunk (with finish_reason)
// 真实的 DeepSeek 流式数据块(含 finish_reason
std::string token; std::string token;
const char* json = const char* json =
"data: {\"id\":\"chatcmpl-xxx\"," "data: {\"id\":\"chatcmpl-xxx\","
@@ -233,11 +254,13 @@ int main()
// ================================================================ // ================================================================
// Test Block 4: parse_sse_line — tool_calls delta // Test Block 4: parse_sse_line — tool_calls delta
// 测试块 4parse_sse_line — tool_calls delta
// ================================================================ // ================================================================
std::cout << "\n--- Block 4: parse_sse_line tool_calls delta ---\n"; std::cout << "\n--- Block 4: parse_sse_line tool_calls delta ---\n";
{ {
// tool_calls chunk with id + function name (first chunk) // tool_calls chunk with id + function name (first chunk)
// tool_calls 数据块含 id + function name首个数据块
StreamContext ctx = {}; StreamContext ctx = {};
std::string token; std::string token;
const char* json = const char* json =
@@ -258,8 +281,9 @@ int main()
{ {
// tool_calls arguments chunk (second chunk, same index) // tool_calls arguments chunk (second chunk, same index)
// tool_calls arguments 数据块(第二个数据块,相同 index
StreamContext ctx; StreamContext ctx;
// First, set up the initial state // First, set up the initial state / 先设置初始状态
ctx.tool_calls.push_back({0, "call_abc123", "get_weather", ""}); ctx.tool_calls.push_back({0, "call_abc123", "get_weather", ""});
std::string token; std::string token;
@@ -276,7 +300,7 @@ int main()
} }
{ {
// tool_calls final arguments chunk // tool_calls final arguments chunk / tool_calls 最终 arguments 数据块
StreamContext ctx; StreamContext ctx;
ctx.tool_calls.push_back({0, "call_abc123", "get_weather", "{\"city\":\""}); ctx.tool_calls.push_back({0, "call_abc123", "get_weather", "{\"city\":\""});
@@ -295,6 +319,7 @@ int main()
{ {
// tool_calls with null ctx — should skip tool_calls processing // tool_calls with null ctx — should skip tool_calls processing
// tool_calls 配合 null ctx — 应跳过 tool_calls 处理
std::string token; std::string token;
const char* json = const char* json =
"data: {\"choices\":[{\"index\":0," "data: {\"choices\":[{\"index\":0,"
@@ -306,6 +331,7 @@ int main()
{ {
// Multiple tool_calls in single chunk (unusual but valid) // Multiple tool_calls in single chunk (unusual but valid)
// 单个数据块中有多个 tool_calls不常见但合法
StreamContext ctx; StreamContext ctx;
std::string token; std::string token;
const char* json = const char* json =
@@ -325,6 +351,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 5: build_request_json — basic cases // Test Block 5: build_request_json — basic cases
// 测试块 5build_request_json — 基础用例
// ================================================================ // ================================================================
setup_config(); setup_config();
std::cout << "\n--- Block 5: build_request_json basic ---\n"; std::cout << "\n--- Block 5: build_request_json basic ---\n";
@@ -351,7 +378,7 @@ int main()
} }
{ {
// With user+assistant history // With user+assistant history / 包含 user+assistant 历史
dstalk_message_t msgs[2] = { dstalk_message_t msgs[2] = {
{"user", "What is 2+2?", nullptr, nullptr}, {"user", "What is 2+2?", nullptr, nullptr},
{"assistant", "It is 4.", nullptr, nullptr} {"assistant", "It is 4.", nullptr, nullptr}
@@ -376,22 +403,26 @@ int main()
{ {
// Empty user input — no user message appended // Empty user input — no user message appended
// 空用户输入 — 不追加 user 消息
std::string json = build_request_json( std::string json = build_request_json(
nullptr, 0, "", "", false); nullptr, 0, "", "", false);
CHECK(!json.empty(), "T5.13: empty user input produces valid JSON"); CHECK(!json.empty(), "T5.13: empty user input produces valid JSON");
// DeepSeek's build_request_json checks `if (!user_input.empty())` before adding // DeepSeek's build_request_json checks `if (!user_input.empty())` before adding
// So there should be no user message for empty input // So there should be no user message for empty input
// DeepSeek 的 build_request_json 在添加前检查 `if (!user_input.empty())`
// 因此空输入时不应有 user 消息
CHECK(json.find("\"role\":\"user\"") == std::string::npos, CHECK(json.find("\"role\":\"user\"") == std::string::npos,
"T5.14: empty user input NOT added to messages (DeepSeek guard)"); "T5.14: empty user input NOT added to messages (DeepSeek guard)");
} }
// ================================================================ // ================================================================
// Test Block 6: build_request_json — tools / edge cases // Test Block 6: build_request_json — tools / edge cases
// 测试块 6build_request_json — tools / 边界情况
// ================================================================ // ================================================================
std::cout << "\n--- Block 6: build_request_json tools / edges ---\n"; std::cout << "\n--- Block 6: build_request_json tools / edges ---\n";
{ {
// With tools_json // With tools_json / 含 tools_json
std::string tools = "[{\"type\":\"function\"," std::string tools = "[{\"type\":\"function\","
"\"function\":{\"name\":\"get_weather\"," "\"function\":{\"name\":\"get_weather\","
"\"description\":\"Get current weather\"," "\"description\":\"Get current weather\","
@@ -407,7 +438,7 @@ int main()
} }
{ {
// Empty tools_json — no tools field // Empty tools_json — no tools field / 空 tools_json — 无 tools 字段
std::string json = build_request_json( std::string json = build_request_json(
nullptr, 0, "Hello", "", false); nullptr, 0, "Hello", "", false);
CHECK(json.find("\"tools\"") == std::string::npos, CHECK(json.find("\"tools\"") == std::string::npos,
@@ -418,6 +449,8 @@ int main()
// Malformed tools_json — build_request_json calls json::parse() // Malformed tools_json — build_request_json calls json::parse()
// without try/catch, so it will throw std::exception. // without try/catch, so it will throw std::exception.
// This test verifies that the exception is thrown (rather than crashing). // This test verifies that the exception is thrown (rather than crashing).
// 畸形 tools_json — build_request_json 调用 json::parse() 不含 try/catch
// 因此会抛出 std::exception。本测试验证异常被抛出而非崩溃
bool threw = false; bool threw = false;
try { try {
build_request_json(nullptr, 0, "Hello", "NOT JSON", false); build_request_json(nullptr, 0, "Hello", "NOT JSON", false);
@@ -430,7 +463,7 @@ int main()
} }
{ {
// History with null role // History with null role / null 角色的历史
dstalk_message_t msgs[1] = { dstalk_message_t msgs[1] = {
{nullptr, "some content", nullptr, nullptr} {nullptr, "some content", nullptr, nullptr}
}; };
@@ -439,7 +472,7 @@ int main()
} }
{ {
// History with null content // History with null content / null 内容的历史
dstalk_message_t msgs[1] = { dstalk_message_t msgs[1] = {
{"user", nullptr, nullptr, nullptr} {"user", nullptr, nullptr, nullptr}
}; };
@@ -448,7 +481,7 @@ int main()
} }
{ {
// Very long message // Very long message / 超长消息
std::string long_input(5000, 'A'); std::string long_input(5000, 'A');
std::string json = build_request_json( std::string json = build_request_json(
nullptr, 0, long_input, "", false); nullptr, 0, long_input, "", false);
@@ -458,6 +491,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 7: build_headers_json // Test Block 7: build_headers_json
// 测试块 7build_headers_json
// ================================================================ // ================================================================
std::cout << "\n--- Block 7: build_headers_json ---\n"; std::cout << "\n--- Block 7: build_headers_json ---\n";
@@ -470,7 +504,7 @@ int main()
} }
{ {
// Empty API key // Empty API key / 空 API key
std::string headers = build_headers_json(""); std::string headers = build_headers_json("");
CHECK(headers.find("Authorization") != std::string::npos, CHECK(headers.find("Authorization") != std::string::npos,
"T7.3: Authorization header present with empty key"); "T7.3: Authorization header present with empty key");
@@ -480,6 +514,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 8: extract_host_port (same logic as anthropic) // Test Block 8: extract_host_port (same logic as anthropic)
// 测试块 8extract_host_port逻辑同 anthropic
// ================================================================ // ================================================================
std::cout << "\n--- Block 8: extract_host_port ---\n"; std::cout << "\n--- Block 8: extract_host_port ---\n";
@@ -525,6 +560,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 9: secure_zero // Test Block 9: secure_zero
// 测试块 9secure_zero
// ================================================================ // ================================================================
std::cout << "\n--- Block 9: secure_zero ---\n"; std::cout << "\n--- Block 9: secure_zero ---\n";
@@ -546,6 +582,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 10: append_history // Test Block 10: append_history
// 测试块 10append_history
// ================================================================ // ================================================================
std::cout << "\n--- Block 10: append_history ---\n"; std::cout << "\n--- Block 10: append_history ---\n";
@@ -561,7 +598,7 @@ int main()
} }
{ {
// Tool message (should include tool_call_id) // Tool message (should include tool_call_id) / Tool 消息(应包含 tool_call_id
json::array msgs; json::array msgs;
dstalk_message_t m = {"tool", "result data", "call_xyz", nullptr}; dstalk_message_t m = {"tool", "result data", "call_xyz", nullptr};
append_history(msgs, &m, 1); append_history(msgs, &m, 1);
@@ -575,7 +612,7 @@ int main()
} }
{ {
// Assistant with tool_calls_json // Assistant with tool_calls_json / Assistant 含 tool_calls_json
json::array msgs; json::array msgs;
const char* tc_json = "[{\"id\":\"call_1\",\"type\":\"function\"," const char* tc_json = "[{\"id\":\"call_1\",\"type\":\"function\","
"\"function\":{\"name\":\"get_weather\",\"arguments\":\"{}\"}}]"; "\"function\":{\"name\":\"get_weather\",\"arguments\":\"{}\"}}]";
@@ -589,14 +626,14 @@ int main()
} }
{ {
// Empty history (0 messages) // Empty history (0 messages) / 空历史0 条消息)
json::array msgs; json::array msgs;
append_history(msgs, nullptr, 0); append_history(msgs, nullptr, 0);
CHECK(msgs.size() == 0, "T10.12: empty history produces empty array"); CHECK(msgs.size() == 0, "T10.12: empty history produces empty array");
} }
{ {
// Multiple messages // Multiple messages / 多条消息
json::array msgs; json::array msgs;
dstalk_message_t ms[2] = { dstalk_message_t ms[2] = {
{"user", "Q1", nullptr, nullptr}, {"user", "Q1", nullptr, nullptr},
@@ -608,6 +645,7 @@ int main()
{ {
// Null role and null content — default to empty strings // Null role and null content — default to empty strings
// null 角色与 null 内容 — 默认为空字符串
json::array msgs; json::array msgs;
dstalk_message_t m = {nullptr, nullptr, nullptr, nullptr}; dstalk_message_t m = {nullptr, nullptr, nullptr, nullptr};
append_history(msgs, &m, 1); append_history(msgs, &m, 1);
@@ -619,11 +657,13 @@ int main()
// ================================================================ // ================================================================
// Test Block 11: my_free_result — null safety // Test Block 11: my_free_result — null safety
// 测试块 11my_free_result — 空指针安全
// ================================================================ // ================================================================
std::cout << "\n--- Block 11: my_free_result null safety ---\n"; std::cout << "\n--- Block 11: my_free_result null safety ---\n";
{ {
// g_host is nullptr, so free_result should early-return // g_host is nullptr, so free_result should early-return
// g_host 为 nullptrfree_result 应提前返回
my_free_result(nullptr); my_free_result(nullptr);
CHECK(true, "T11.1: free_result(nullptr) does not crash (null host)"); CHECK(true, "T11.1: free_result(nullptr) does not crash (null host)");
} }
@@ -637,6 +677,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 12: my_configure — null host safety // Test Block 12: my_configure — null host safety
// 测试块 12my_configure — null host 安全
// ================================================================ // ================================================================
std::cout << "\n--- Block 12: my_configure null host safety ---\n"; std::cout << "\n--- Block 12: my_configure null host safety ---\n";
@@ -656,7 +697,7 @@ int main()
} }
// ================================================================ // ================================================================
// Summary // Summary / 总结
// ================================================================ // ================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,8 +1,10 @@
// ============================================================================ /*
// event_bus_test.cpp — EventBus 单元测试 * @file event_bus_test.cpp
// ============================================================================ * @brief EventBus unit tests: subscribe, emit, unsubscribe, multi-handler
// 测试: subscribe / unsubscribe / emit / 多订阅者 / 空总线 * dispatch order, independent event types.
// ============================================================================ * EventBus 单元测试:订阅、发布、取消订阅、多处理器分发顺序、独立事件类型。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
@@ -13,6 +15,7 @@
// ---- 轻量断言 ---- // ---- 轻量断言 ----
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion helper: increments g_failures counter on failure
#define TCHECK(cond, msg) do { \ #define TCHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -22,13 +25,16 @@ static int g_failures = 0;
} \ } \
} while (0) } while (0)
// ============================================================ // EventBus 单元测试:订阅+发布、取消订阅、多处理器分发顺序、空总线、独立事件类型路由、取消不存在的订阅。
// EventBus unit tests: subscribe+emit, unsubscribe, multi-handler dispatch order,
// empty bus, independent event type routing, and non-existent unsubscribe safety.
int main() int main()
{ {
std::cout << "=== dstalk event_bus unit tests ===\n\n"; std::cout << "=== dstalk event_bus unit tests ===\n\n";
// ==================================================================== // ====================================================================
// Test 1: subscribe + emit — 基本发布订阅流程 // Test 1: subscribe + emit — 基本发布订阅流程
// Test 1: subscribe + emit — basic pub/sub flow
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
@@ -49,6 +55,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 2: unsubscribe — 取消订阅后 handler 不再被调用 // Test 2: unsubscribe — 取消订阅后 handler 不再被调用
// Test 2: unsubscribe — handler NOT called after unsubscription
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
@@ -64,6 +71,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 3: 多订阅者 — 同一事件多个 handler 按订阅顺序全部调用 // Test 3: 多订阅者 — 同一事件多个 handler 按订阅顺序全部调用
// Test 3: multi-subscriber — all handlers for same event invoked in subscription order
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
@@ -77,13 +85,14 @@ int main()
TCHECK(emitted == 3, "emit returns 3 handlers called"); TCHECK(emitted == 3, "emit returns 3 handlers called");
TCHECK(order.size() == 3, "all 3 handlers invoked"); TCHECK(order.size() == 3, "all 3 handlers invoked");
// 验证订阅顺序 (FIFO: 按 subscribe 顺序触发) // 验证订阅顺序 (FIFO: 按 subscribe 顺序触发) / Verify subscription order (FIFO: in subscribe order)
bool ordered = (order[0] == 1 && order[1] == 2 && order[2] == 3); bool ordered = (order[0] == 1 && order[1] == 2 && order[2] == 3);
TCHECK(ordered, "handlers invoked in subscription order (1,2,3)"); TCHECK(ordered, "handlers invoked in subscription order (1,2,3)");
} }
// ==================================================================== // ====================================================================
// Test 4: 空总线 emit 不崩溃,返回 0 // Test 4: 空总线 emit 不崩溃,返回 0
// Test 4: emit on empty bus no crash, returns 0
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
@@ -93,6 +102,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 5: 不同 event_type 独立分发 — 只触发匹配的 handler // Test 5: 不同 event_type 独立分发 — 只触发匹配的 handler
// Test 5: independent event_type dispatch — only matching handler triggered
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
@@ -112,15 +122,16 @@ int main()
// ==================================================================== // ====================================================================
// Test 6: 退订不存在的 ID 不崩溃 // Test 6: 退订不存在的 ID 不崩溃
// Test 6: unsubscribe non-existent ID does not crash
// ==================================================================== // ====================================================================
{ {
dstalk::EventBus bus; dstalk::EventBus bus;
bus.unsubscribe(99999); // 不存在的 ID bus.unsubscribe(99999); // 不存在的 ID / non-existent ID
std::cout << "[OK] unsubscribe non-existent ID (99999) did not crash\n"; std::cout << "[OK] unsubscribe non-existent ID (99999) did not crash\n";
} }
// ==================================================================== // ====================================================================
// 结果 // 结果 / Result
// ==================================================================== // ====================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,8 +1,10 @@
// ============================================================================ /*
// host_api_test.cpp — host API 单元测试 (独立于 smoke_test) * @file host_api_test.cpp
// ============================================================================ * @brief Host API unit tests: service registration, event bus, config store,
// 测试: register_service / query_service / alloc / free / log / init / shutdown * alloc/free, logging, init/shutdown lifecycle.
// ============================================================================ * Host API 单元测试服务注册、事件总线、配置存储、alloc/free、日志、init/shutdown 生命周期。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include <cstdarg> #include <cstdarg>
#include <cstdio> #include <cstdio>
@@ -13,13 +15,14 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
// 引入 ServiceRegistry 实现做纯单元测试 // 引入 ServiceRegistry 实现做纯单元测试 / Include ServiceRegistry impl for pure unit tests
#include "service_registry.hpp" #include "service_registry.hpp"
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
// ---- 轻量断言 ---- // ---- 轻量断言 ----
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion helper: increments g_failures counter on failure
#define TCHECK(cond, msg) do { \ #define TCHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -29,26 +32,32 @@ static int g_failures = 0;
} \ } \
} while (0) } while (0)
// ---- 辅助: 创建临时配置文件 ----
// Helper: creates a temporary config.toml pointing to a non-existent plugin dir,
// so dstalk_init loads no external plugins during tests.
// 辅助函数:创建临时 config.toml 指向不存在的插件目录,使 dstalk_init 在测试时不加载任何外部插件。
static std::string make_temp_config(const std::string& tag) { static std::string make_temp_config(const std::string& tag) {
auto dir = std::filesystem::temp_directory_path() / ("dstalk-host-api-" + tag); auto dir = std::filesystem::temp_directory_path() / ("dstalk-host-api-" + tag);
std::filesystem::create_directories(dir); std::filesystem::create_directories(dir);
auto config_path = dir / "config.toml"; auto config_path = dir / "config.toml";
{ {
std::ofstream c(config_path); std::ofstream c(config_path);
// 指向不存在的插件目录,避免加载任何 .dll // 指向不存在的插件目录,避免加载任何 .dll / Point to nonexistent plugin dir, avoid loading any .dll
c << "plugin_dir = \"__no_such_plugins_dir__\"\n"; c << "plugin_dir = \"__no_such_plugins_dir__\"\n";
} }
return config_path.string(); return config_path.string();
} }
// ============================================================ // Host API 单元测试:覆盖注册/查询重复、版本不匹配、双重 init 防护、alloc/free 边界、日志级别、shutdown 后查询。
// Host API unit tests: covers register/query duplicates, version mismatch,
// double-init guard, alloc/free edge cases, logging levels, and post-shutdown query.
int main() int main()
{ {
std::cout << "=== dstalk host_api unit tests ===\n\n"; std::cout << "=== dstalk host_api unit tests ===\n\n";
// ==================================================================== // ====================================================================
// Test 1: register_service 重复注册 同名+同版本 → 应返回 -2 // Test 1: register_service 重复注册 同名+同版本 → 应返回 -2
// Test 1: register_service duplicate same-name+same-version -> should return -2
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -64,6 +73,8 @@ int main()
// ==================================================================== // ====================================================================
// Test 2: register_service 同名+不同版本 → 应返回 -2 // Test 2: register_service 同名+不同版本 → 应返回 -2
// 名称已占用,与版本无关 // 名称已占用,与版本无关
// Test 2: register_service same-name+different-version -> should return -2
// Name already taken, regardless of version
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -78,6 +89,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 3: query_service 不存在的 name → nullptr // Test 3: query_service 不存在的 name → nullptr
// Test 3: query_service nonexistent name -> nullptr
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -88,6 +100,8 @@ int main()
// ==================================================================== // ====================================================================
// Test 4: query_service 错误版本号 → nullptr // Test 4: query_service 错误版本号 → nullptr
// 注册 v=1, 查询 min_version=2 → 不满足 → nullptr // 注册 v=1, 查询 min_version=2 → 不满足 → nullptr
// Test 4: query_service wrong version -> nullptr
// Registered v=1, query min_version=2 -> unsatisfied -> nullptr
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -97,13 +111,14 @@ int main()
void* q = reg.query_service("solo", 2); void* q = reg.query_service("solo", 2);
TCHECK(q == nullptr, "query_service(\"solo\",2) with only v1 available returns nullptr"); TCHECK(q == nullptr, "query_service(\"solo\",2) with only v1 available returns nullptr");
// 确证以正确版本查询能拿到 // 确证以正确版本查询能拿到 / Confirm correct version query works
void* q2 = reg.query_service("solo", 1); void* q2 = reg.query_service("solo", 1);
TCHECK(q2 == dummy_vtable, "query_service(\"solo\",1) with v1 available returns vtable"); TCHECK(q2 == dummy_vtable, "query_service(\"solo\",1) with v1 available returns vtable");
} }
// ==================================================================== // ====================================================================
// Test 5: dstalk_init 多次调用 → 第二次应返回 -1 (幂等拒绝) // Test 5: dstalk_init 多次调用 → 第二次应返回 -1 (幂等拒绝)
// Test 5: dstalk_init multiple calls -> second should return -1 (idempotent guard)
// ==================================================================== // ====================================================================
{ {
std::string cfg = make_temp_config("init-twice"); std::string cfg = make_temp_config("init-twice");
@@ -120,6 +135,9 @@ int main()
// Test 6: alloc(0) / free(nullptr) 行为 // Test 6: alloc(0) / free(nullptr) 行为
// malloc(0) 可返回 null 或合法指针; 两者都可 free // malloc(0) 可返回 null 或合法指针; 两者都可 free
// free(nullptr) 是安全空操作 // free(nullptr) 是安全空操作
// Test 6: alloc(0) / free(nullptr) behavior
// malloc(0) may return null or valid pointer; both are free-able
// free(nullptr) is a safe no-op
// ==================================================================== // ====================================================================
{ {
void* p = dstalk_alloc(0); void* p = dstalk_alloc(0);
@@ -134,6 +152,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 7: log 各 level 不崩溃 (DEBUG / INFO / WARN / ERROR) // Test 7: log 各 level 不崩溃 (DEBUG / INFO / WARN / ERROR)
// Test 7: log at each level no crash (DEBUG / INFO / WARN / ERROR)
// ==================================================================== // ====================================================================
{ {
dstalk_log(DSTALK_LOG_DEBUG, "host_api_test: debug level message"); dstalk_log(DSTALK_LOG_DEBUG, "host_api_test: debug level message");
@@ -148,7 +167,7 @@ int main()
dstalk_log(DSTALK_LOG_ERROR, "host_api_test: error level message"); dstalk_log(DSTALK_LOG_ERROR, "host_api_test: error level message");
std::cout << "[OK] dstalk_log(ERROR) no crash\n"; std::cout << "[OK] dstalk_log(ERROR) no crash\n";
// 带格式参数 // 带格式参数 / With format args
dstalk_log(DSTALK_LOG_INFO, "formatted: %s %d", "answer", 42); dstalk_log(DSTALK_LOG_INFO, "formatted: %s %d", "answer", 42);
std::cout << "[OK] dstalk_log with format args no crash\n"; std::cout << "[OK] dstalk_log with format args no crash\n";
} }
@@ -156,6 +175,8 @@ int main()
// ==================================================================== // ====================================================================
// Test 8: dstalk_shutdown 后 query_service → nullptr // Test 8: dstalk_shutdown 后 query_service → nullptr
// g_service_registry 已被 delete 置空 // g_service_registry 已被 delete 置空
// Test 8: query_service after dstalk_shutdown -> nullptr
// g_service_registry has been deleted and nulled
// ==================================================================== // ====================================================================
{ {
std::string cfg = make_temp_config("after-shutdown"); std::string cfg = make_temp_config("after-shutdown");
@@ -167,7 +188,7 @@ int main()
} }
// ==================================================================== // ====================================================================
// 结果 // 结果 / Result
// ==================================================================== // ====================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,8 +1,12 @@
// ============================================================================ /*
// network_plugin_test.cpp — Network 插件单元测试 * @file network_plugin_test.cpp
// W22.2 (qa-xu): 覆盖 parse_headers_json / SSE 行解析 / 参数校验 * @brief Network plugin unit tests (W22.2): parse_headers_json (normal, empty,
// 通过 #include plugin source 访问 file-scope static 函数 * malformed, long values), SSE line splitting boundaries, and
// ============================================================================ * http_post_json/http_post_stream parameter validation.
* Network 插件单元测试 (W22.2)parse_headers_json正常、空、畸形、长值、SSE 行解析边界、
* http_post_json/http_post_stream 参数校验。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_WARNINGS
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS #define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
#include "../plugins/network/src/network_plugin.cpp" #include "../plugins/network/src/network_plugin.cpp"
@@ -15,6 +19,7 @@
#include <vector> #include <vector>
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion macro: increments g_failures counter on failure
#define CHECK(cond, msg) do { \ #define CHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -24,9 +29,9 @@ static int g_failures = 0;
} \ } \
} while (0) } while (0)
// ================================================================ // SSE line-split helper: mirrors do_post_stream's emit_lines logic for unit-testing
// SSE 行分割 helper (复刻 do_post_stream 的 emit_lines 逻辑) // SSE chunk parsing without a live network connection.
// ================================================================ // SSE 行分割辅助函数:镜像 do_post_stream 的 emit_lines 逻辑,无需实时网络连接即可单元测试 SSE 数据块解析。
static std::vector<std::string> split_sse_lines(std::string fragment) { static std::vector<std::string> split_sse_lines(std::string fragment) {
std::vector<std::string> lines; std::vector<std::string> lines;
size_t pos = 0; size_t pos = 0;
@@ -55,11 +60,17 @@ static std::vector<std::string> split_sse_lines(std::string fragment) {
return lines; return lines;
} }
// ================================================================ // Network 插件测试 (W22.2)parse_headers_json 正常/空/畸形/长值、
// SSE 行分割LF、CRLF、空、null 字节、尾部 CR
// http_post_json/http_post_stream 参数校验(空指针)。
// Network plugin tests (W22.2): parse_headers_json normal/empty/malformed/long,
// SSE line splitting (LF, CRLF, empty, null-bytes, trailing CR),
// and http_post_json/http_post_stream parameter validation (null pointers).
int main() int main()
{ {
// ================================================================ // ================================================================
// Test Block 1: parse_headers_json — 正常 JSON // Test Block 1: parse_headers_json — 正常 JSON
// Test Block 1: parse_headers_json — normal JSON
// ================================================================ // ================================================================
std::cout << "\n--- Block 1: parse_headers_json normal JSON ---\n"; std::cout << "\n--- Block 1: parse_headers_json normal JSON ---\n";
@@ -98,6 +109,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 2: parse_headers_json — 空 / null 输入 // Test Block 2: parse_headers_json — 空 / null 输入
// Test Block 2: parse_headers_json — empty/null input
// ================================================================ // ================================================================
std::cout << "\n--- Block 2: parse_headers_json empty/null input ---\n"; std::cout << "\n--- Block 2: parse_headers_json empty/null input ---\n";
@@ -124,6 +136,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 3: parse_headers_json — 畸形 JSON // Test Block 3: parse_headers_json — 畸形 JSON
// Test Block 3: parse_headers_json — malformed JSON
// ================================================================ // ================================================================
std::cout << "\n--- Block 3: parse_headers_json malformed JSON ---\n"; std::cout << "\n--- Block 3: parse_headers_json malformed JSON ---\n";
@@ -177,6 +190,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 4: parse_headers_json — 超长 header 值 // Test Block 4: parse_headers_json — 超长 header 值
// Test Block 4: parse_headers_json — long values
// ================================================================ // ================================================================
std::cout << "\n--- Block 4: parse_headers_json long values ---\n"; std::cout << "\n--- Block 4: parse_headers_json long values ---\n";
@@ -209,6 +223,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 5: SSE 行解析边界 // Test Block 5: SSE 行解析边界
// Test Block 5: SSE line splitting boundaries
// ================================================================ // ================================================================
std::cout << "\n--- Block 5: SSE line splitting boundaries ---\n"; std::cout << "\n--- Block 5: SSE line splitting boundaries ---\n";
@@ -274,6 +289,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 6: http_post_json — 参数校验 (null ptr, early return) // Test Block 6: http_post_json — 参数校验 (null ptr, early return)
// Test Block 6: http_post_json — parameter validation (null ptr, early return)
// ================================================================ // ================================================================
std::cout << "\n--- Block 6: http_post_json parameter validation ---\n"; std::cout << "\n--- Block 6: http_post_json parameter validation ---\n";
@@ -319,6 +335,7 @@ int main()
// ================================================================ // ================================================================
// Test Block 7: http_post_stream — 参数校验 // Test Block 7: http_post_stream — 参数校验
// Test Block 7: http_post_stream — parameter validation
// ================================================================ // ================================================================
std::cout << "\n--- Block 7: http_post_stream parameter validation ---\n"; std::cout << "\n--- Block 7: http_post_stream parameter validation ---\n";
@@ -338,7 +355,7 @@ int main()
} }
// ================================================================ // ================================================================
// Summary // Summary / 总结
// ================================================================ // ================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,12 +1,10 @@
// ============================================================================ /*
// plugin_loader_test.cpp — PluginLoader 安全回归测试 * @file plugin_loader_test.cpp
// ============================================================================ * @brief PluginLoader safety regression tests (W20.3): path validation,
// W20.3 (qa-xu 徐磊): 覆盖 W19 修复的 5 条发现 (F-18.3-1~5) * ABI checks, next_id_ atomicity, failure-path logging with mock host API.
// - F-18.3-3: 路径验证 (lexically_normal + 扩展名 + 目录约束) * PluginLoader 安全回归测试 (W20.3)路径验证、ABI 检查、next_id_ 原子性、失败路径日志(使用 mock host API
// - F-18.3-4: next_id_ atomic 唯一性 + 单调递增 * Copyright (c) 2026 dstalk contributors. GPLv3.
// - F-18.3-2: host_api_->log 调用 (mock 验证) */
// - F-18.3-1: try/catch 异常安全边界 (间接: 注入 mock 不崩溃)
// ============================================================================
#include "plugin_loader.hpp" #include "plugin_loader.hpp"
@@ -24,6 +22,7 @@ namespace fs = std::filesystem;
// ---- 轻量断言 ---- // ---- 轻量断言 ----
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion macro: increments g_failures counter on failure
#define CHECK(cond, msg) do { \ #define CHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -35,11 +34,14 @@ static int g_failures = 0;
// ============================================================================ // ============================================================================
// Mock host_api — 捕获 log 调用以验证失败路径日志 (F-18.3-2) // Mock host_api — 捕获 log 调用以验证失败路径日志 (F-18.3-2)
// Mock host_api — captures log calls to verify failure-path logging (F-18.3-2)
// ============================================================================ // ============================================================================
static int g_log_call_count = 0; static int g_log_call_count = 0;
static int g_last_severity = 0; static int g_last_severity = 0;
static char g_last_log_msg[1024] = {0}; static char g_last_log_msg[1024] = {0};
// Mock host_api::log implementation: counts calls and captures last severity+message
// Mock host_api::log 实现:计数调用并捕获最后的 severity+message
static void mock_log(int level, const char* fmt, ...) { static void mock_log(int level, const char* fmt, ...) {
g_log_call_count++; g_log_call_count++;
g_last_severity = level; g_last_severity = level;
@@ -49,6 +51,8 @@ static void mock_log(int level, const char* fmt, ...) {
va_end(args); va_end(args);
} }
// Stub host_api functions: return failure/default for all operations except log
// Stub host_api 函数:除 log 外所有操作均返回失败/默认值
static int stub_reg(const char*, int, void*) { return -1; } static int stub_reg(const char*, int, void*) { return -1; }
static void* stub_query(const char*, int) { return nullptr; } static void* stub_query(const char*, int) { return nullptr; }
static int stub_sub(int, dstalk_event_handler_fn, void*) { return -1; } static int stub_sub(int, dstalk_event_handler_fn, void*) { return -1; }
@@ -60,6 +64,8 @@ static void* stub_alloc(size_t) { return nullptr; }
static void stub_free(void*) {} static void stub_free(void*) {}
static char* stub_strdup(const char*) { return nullptr; } static char* stub_strdup(const char*) { return nullptr; }
// Mock host_api vtable: all stubs except mock_log for capturing error-path diagnostics
// Mock host_api 虚表:除 mock_log 外全部 stub用于捕获错误路径诊断
static dstalk_host_api_t g_mock_host_api = { static dstalk_host_api_t g_mock_host_api = {
stub_reg, stub_query, stub_reg, stub_query,
stub_sub, stub_emit, stub_unsub, stub_sub, stub_emit, stub_unsub,
@@ -68,15 +74,16 @@ static dstalk_host_api_t g_mock_host_api = {
stub_alloc, stub_free, stub_strdup stub_alloc, stub_free, stub_strdup
}; };
// Reset log capture state between tests
// 重置日志捕获状态(测试间使用)
static void reset_log_state() { static void reset_log_state() {
g_log_call_count = 0; g_log_call_count = 0;
g_last_severity = 0; g_last_severity = 0;
g_last_log_msg[0] = '\0'; g_last_log_msg[0] = '\0';
} }
// ============================================================================ // Get the absolute path to the build output plugins/ directory
// Helper: 获取构建 plugins/ 目录绝对路径 // 获取构建输出 plugins/ 目录绝对路径
// ============================================================================
static fs::path get_plugins_dir() { static fs::path get_plugins_dir() {
#ifdef DSTALK_TEST_PLUGINS_DIR #ifdef DSTALK_TEST_PLUGINS_DIR
return fs::path(DSTALK_TEST_PLUGINS_DIR); return fs::path(DSTALK_TEST_PLUGINS_DIR);
@@ -85,50 +92,56 @@ static fs::path get_plugins_dir() {
#endif #endif
} }
// ============================================================================ // PluginLoader 回归测试 (W20.3)F-18.3-3 路径验证拒绝、F-18.3-4 next_id_ 唯一性+单调性+并发、
// F-18.3-2 失败路径日志,以及边界情况(空 loader、无效操作
// PluginLoader regression tests (W20.3): F-18.3-3 path validation rejection,
// F-18.3-4 next_id_ uniqueness+monotonic+concurrent, F-18.3-2 failure-path logging,
// and edge cases (empty loader, invalid operations).
int main() int main()
{ {
std::cout << "=== dstalk plugin_loader regression tests (W20.3) ===\n\n"; std::cout << "=== dstalk plugin_loader regression tests (W20.3) ===\n\n";
// ======================================================================== // ========================================================================
// Block 1: 路径验证 — 拒绝非法路径 (F-18.3-3) // Block 1: 路径验证 — 拒绝非法路径 (F-18.3-3)
// Block 1: Path validation — reject illegal paths (F-18.3-3)
// ======================================================================== // ========================================================================
std::cout << "--- Block 1: Path validation — rejection ---\n"; std::cout << "--- Block 1: Path validation — rejection ---\n";
{ {
dstalk::PluginLoader loader; dstalk::PluginLoader loader;
// T1.1: nullptr // T1.1: nullptr / null pointer
CHECK(loader.load_plugin(nullptr) == -1, CHECK(loader.load_plugin(nullptr) == -1,
"T1.1: nullptr path returns -1"); "T1.1: nullptr path returns -1");
// T1.2: 非法扩展名 .txt // T1.2: 非法扩展名 .txt / illegal .txt extension
CHECK(loader.load_plugin("plugins/test.txt") == -1, CHECK(loader.load_plugin("plugins/test.txt") == -1,
"T1.2: .txt extension rejected"); "T1.2: .txt extension rejected");
// T1.3: 路径含 .. 遍历 // T1.3: 路径含 .. 遍历 / path contains .. traversal
CHECK(loader.load_plugin("../plugins/test.dll") == -1, CHECK(loader.load_plugin("../plugins/test.dll") == -1,
"T1.3: ../ traversal rejected"); "T1.3: ../ traversal rejected");
// T1.4: 不在 plugins/ 目录下 // T1.4: 不在 plugins/ 目录下 / not under plugins/ dir
auto tmp = fs::temp_directory_path() / "dstalk_test_no_plugins" / "test.dll"; auto tmp = fs::temp_directory_path() / "dstalk_test_no_plugins" / "test.dll";
CHECK(loader.load_plugin(tmp.string().c_str()) == -1, CHECK(loader.load_plugin(tmp.string().c_str()) == -1,
"T1.4: path not under plugins/ dir rejected"); "T1.4: path not under plugins/ dir rejected");
// T1.5: 路径中间的 .. 段 // T1.5: 路径中间的 .. 段 / .. segment in middle of path
CHECK(loader.load_plugin("plugins/../secret/test.dll") == -1, CHECK(loader.load_plugin("plugins/../secret/test.dll") == -1,
"T1.5: .. in middle of path rejected"); "T1.5: .. in middle of path rejected");
// T1.6: 无扩展名 // T1.6: 无扩展名 / no extension
CHECK(loader.load_plugin("plugins/test") == -1, CHECK(loader.load_plugin("plugins/test") == -1,
"T1.6: no extension rejected"); "T1.6: no extension rejected");
// T1.7: 合法扩展名但不在 plugins/ 下 // T1.7: 合法扩展名但不在 plugins/ 下 / valid extension but not under plugins/
CHECK(loader.load_plugin("/etc/someconfig.so") == -1, CHECK(loader.load_plugin("/etc/someconfig.so") == -1,
"T1.7: .so extension but not under plugins/ rejected"); "T1.7: .so extension but not under plugins/ rejected");
} }
// ======================================================================== // ========================================================================
// Block 2: 合法路径 — 成功加载 + next_id_ 验证 (F-18.3-4) // Block 2: 合法路径 — 成功加载 + next_id_ 验证 (F-18.3-4)
// Block 2: Valid path — successful load + ID uniqueness (F-18.3-4)
// ======================================================================== // ========================================================================
std::cout << "\n--- Block 2: Valid path — successful load + ID uniqueness ---\n"; std::cout << "\n--- Block 2: Valid path — successful load + ID uniqueness ---\n";
{ {
@@ -144,23 +157,23 @@ int main()
std::cout << "[WARN] Plugin DLLs not found at " << plugins_dir.string() std::cout << "[WARN] Plugin DLLs not found at " << plugins_dir.string()
<< " — skipping Block 2\n"; << " — skipping Block 2\n";
} else { } else {
// T2.1: 加载第一个插件 // T2.1: 加载第一个插件 / load first plugin
int id1 = loader.load_plugin(dll_config.string().c_str()); int id1 = loader.load_plugin(dll_config.string().c_str());
CHECK(id1 >= 1, "T2.1: first plugin loaded with positive ID"); CHECK(id1 >= 1, "T2.1: first plugin loaded with positive ID");
std::cout << " id1 = " << id1 << "\n"; std::cout << " id1 = " << id1 << "\n";
// T2.2: 加载第二个不同插件 // T2.2: 加载第二个不同插件 / load second (different) plugin
int id2 = loader.load_plugin(dll_fileio.string().c_str()); int id2 = loader.load_plugin(dll_fileio.string().c_str());
CHECK(id2 >= 1, "T2.2: second plugin loaded with positive ID"); CHECK(id2 >= 1, "T2.2: second plugin loaded with positive ID");
std::cout << " id2 = " << id2 << "\n"; std::cout << " id2 = " << id2 << "\n";
// T2.3: ID 唯一 // T2.3: ID 唯一 / IDs are unique
CHECK(id1 != id2, "T2.3: IDs are unique (next_id_ atomicity)"); CHECK(id1 != id2, "T2.3: IDs are unique (next_id_ atomicity)");
// T2.4: ID 单调递增 // T2.4: ID 单调递增 / IDs monotonically increasing
CHECK(id2 > id1, "T2.4: IDs monotonically increasing"); CHECK(id2 > id1, "T2.4: IDs monotonically increasing");
// T2.5: get_plugin 可查询到已加载插件 // T2.5: get_plugin 可查询到已加载插件 / get_plugin can find loaded plugin
const dstalk::PluginInfo* info1 = loader.get_plugin(id1); const dstalk::PluginInfo* info1 = loader.get_plugin(id1);
CHECK(info1 != nullptr, "T2.5: get_plugin(id1) returns non-null"); CHECK(info1 != nullptr, "T2.5: get_plugin(id1) returns non-null");
if (info1) { if (info1) {
@@ -168,23 +181,24 @@ int main()
std::cout << " plugin1 name: " << info1->name << "\n"; std::cout << " plugin1 name: " << info1->name << "\n";
} }
// T2.7: get_plugin 对无效 ID 返回 nullptr // T2.7: get_plugin 对无效 ID 返回 nullptr / get_plugin returns nullptr for invalid ID
CHECK(loader.get_plugin(99999) == nullptr, CHECK(loader.get_plugin(99999) == nullptr,
"T2.7: get_plugin(invalid_id) returns nullptr"); "T2.7: get_plugin(invalid_id) returns nullptr");
// T2.8: 卸载后 get_plugin 返回 nullptr // T2.8: 卸载后 get_plugin 返回 nullptr / get_plugin returns nullptr after unload
int ret = loader.unload_plugin(id1); int ret = loader.unload_plugin(id1);
CHECK(ret == 0, "T2.8: unload_plugin returns 0"); CHECK(ret == 0, "T2.8: unload_plugin returns 0");
CHECK(loader.get_plugin(id1) == nullptr, CHECK(loader.get_plugin(id1) == nullptr,
"T2.9: get_plugin returns nullptr after unload"); "T2.9: get_plugin returns nullptr after unload");
// 清理 // 清理 / cleanup
loader.unload_plugin(id2); loader.unload_plugin(id2);
} }
} }
// ======================================================================== // ========================================================================
// Block 3: next_id_ 原子性 — 多线程并发加载 (F-18.3-4) // Block 3: next_id_ 原子性 — 多线程并发加载 (F-18.3-4)
// Block 3: next_id_ atomicity — concurrent loads (F-18.3-4)
// ======================================================================== // ========================================================================
std::cout << "\n--- Block 3: next_id_ atomicity — concurrent loads ---\n"; std::cout << "\n--- Block 3: next_id_ atomicity — concurrent loads ---\n";
{ {
@@ -213,7 +227,7 @@ int main()
for (auto& t : threads) t.join(); for (auto& t : threads) t.join();
// 验证: 所有 load 成功, ID 唯一且 > 0 // 验证: 所有 load 成功, ID 唯一且 > 0 / Verify: all loads succeed, IDs unique and > 0
std::vector<int> valid_ids; std::vector<int> valid_ids;
for (size_t i = 0; i < ids.size(); ++i) { for (size_t i = 0; i < ids.size(); ++i) {
CHECK(ids[i] >= 1, "T3." + std::to_string(i) CHECK(ids[i] >= 1, "T3." + std::to_string(i)
@@ -222,7 +236,7 @@ int main()
if (ids[i] >= 1) valid_ids.push_back(ids[i]); if (ids[i] >= 1) valid_ids.push_back(ids[i]);
} }
// 去重后大小应等于成功加载数 // 去重后大小应等于成功加载数 / dedup size should equal successful load count
std::sort(valid_ids.begin(), valid_ids.end()); std::sort(valid_ids.begin(), valid_ids.end());
auto dup = std::unique(valid_ids.begin(), valid_ids.end()); auto dup = std::unique(valid_ids.begin(), valid_ids.end());
size_t unique_count = std::distance(valid_ids.begin(), dup); size_t unique_count = std::distance(valid_ids.begin(), dup);
@@ -231,26 +245,27 @@ int main()
+ std::to_string(unique_count) + "/" + std::to_string(unique_count) + "/"
+ std::to_string(valid_ids.size()) + ")"); + std::to_string(valid_ids.size()) + ")");
// 清理 // 清理 / cleanup
for (int id : valid_ids) loader.unload_plugin(id); for (int id : valid_ids) loader.unload_plugin(id);
} }
} }
// ======================================================================== // ========================================================================
// Block 4: 失败路径日志 — host_api->log 被调用 (F-18.3-2) // Block 4: 失败路径日志 — host_api->log 被调用 (F-18.3-2)
// Block 4: Failure-path logging — host_api->log is called (F-18.3-2)
// ======================================================================== // ========================================================================
std::cout << "\n--- Block 4: Failure-path logging (host_api->log) ---\n"; std::cout << "\n--- Block 4: Failure-path logging (host_api->log) ---\n";
{ {
dstalk::PluginLoader loader; dstalk::PluginLoader loader;
// 4.1: 无 host_api 时 load_plugin 失败不崩溃 // 4.1: 无 host_api 时 load_plugin 失败不崩溃 / load_plugin fails without crash when no host_api
reset_log_state(); reset_log_state();
int id = loader.load_plugin("bad_ext.noext"); int id = loader.load_plugin("bad_ext.noext");
CHECK(id == -1, "T4.1: load_plugin with invalid ext returns -1 (no host_api)"); CHECK(id == -1, "T4.1: load_plugin with invalid ext returns -1 (no host_api)");
CHECK(g_log_call_count == 0, CHECK(g_log_call_count == 0,
"T4.2: log NOT called when host_api_ is null"); "T4.2: log NOT called when host_api_ is null");
// 4.2: 设置 mock host_api 后验证 log 被调用 // 4.2: 设置 mock host_api 后验证 log 被调用 / set mock host_api and verify log is called
int init_ret = loader.initialize_all(&g_mock_host_api); int init_ret = loader.initialize_all(&g_mock_host_api);
CHECK(init_ret == 0, "T4.3: initialize_all with mock host_api returns 0"); CHECK(init_ret == 0, "T4.3: initialize_all with mock host_api returns 0");
@@ -263,7 +278,7 @@ int main()
"T4.6: log severity is DSTALK_LOG_ERROR"); "T4.6: log severity is DSTALK_LOG_ERROR");
std::cout << " log msg: " << g_last_log_msg << "\n"; std::cout << " log msg: " << g_last_log_msg << "\n";
// 4.3: LoadLibrary 失败也触发 log (文件不存在) // 4.3: LoadLibrary 失败也触发 log (文件不存在) / LoadLibrary failure also triggers log (file missing)
reset_log_state(); reset_log_state();
fs::path missing = get_plugins_dir() / "nonexistent_plugin.dll"; fs::path missing = get_plugins_dir() / "nonexistent_plugin.dll";
id = loader.load_plugin(missing.string().c_str()); id = loader.load_plugin(missing.string().c_str());
@@ -275,28 +290,29 @@ int main()
// ======================================================================== // ========================================================================
// Block 5: 边界 — 空 loader / 无效操作 // Block 5: 边界 — 空 loader / 无效操作
// Block 5: Edge cases — empty loader / invalid operations
// ======================================================================== // ========================================================================
std::cout << "\n--- Block 5: Edge cases — empty loader / invalid op ---\n"; std::cout << "\n--- Block 5: Edge cases — empty loader / invalid op ---\n";
{ {
dstalk::PluginLoader loader; dstalk::PluginLoader loader;
// T5.1: unload 不存在的 ID 返回 -1 // T5.1: unload 不存在的 ID 返回 -1 / unload non-existent ID returns -1
CHECK(loader.unload_plugin(42) == -1, CHECK(loader.unload_plugin(42) == -1,
"T5.1: unload_plugin(nonexistent) returns -1"); "T5.1: unload_plugin(nonexistent) returns -1");
// T5.2: 空 PluginLoader 的 list_plugins 返回 "[]" // T5.2: 空 PluginLoader 的 list_plugins 返回 "[]" / empty PluginLoader list_plugins returns "[]"
std::string json = loader.list_plugins(); std::string json = loader.list_plugins();
CHECK(!json.empty(), "T5.2: list_plugins returns non-empty string"); CHECK(!json.empty(), "T5.2: list_plugins returns non-empty string");
CHECK(json == "[]", "T5.3: empty loader produces empty JSON array"); CHECK(json == "[]", "T5.3: empty loader produces empty JSON array");
std::cout << " list_plugins (empty): " << json << "\n"; std::cout << " list_plugins (empty): " << json << "\n";
// T5.3: get_plugin 在空 loader 上返回 nullptr // T5.3: get_plugin 在空 loader 上返回 nullptr / get_plugin on empty loader returns nullptr
CHECK(loader.get_plugin(1) == nullptr, CHECK(loader.get_plugin(1) == nullptr,
"T5.4: get_plugin on empty loader returns nullptr"); "T5.4: get_plugin on empty loader returns nullptr");
} }
// ======================================================================== // ========================================================================
// 结果 // 结果 / Result
// ======================================================================== // ========================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,9 +1,10 @@
// ============================================================================ /*
// service_registry_test.cpp — ServiceRegistry 单元测试(补充覆盖,不与 host_api_test 重叠) * @file service_registry_test.cpp
// ============================================================================ * @brief ServiceRegistry unit tests (supplement to host_api_test): register,
// host_api_test 已覆盖: 重复注册(同名同版/同名异版)、查询不存在服务、版本不满足、 * query, version check, unregister, null-pointer safety, re-registration.
// shutdown 后查询。本测试补充边界与生命周期路径 * ServiceRegistry 单元测试host_api_test 补充):注册、查询、版本检查、取消注册、空指针安全、重新注册
// ============================================================================ * Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include <cstring> #include <cstring>
#include <iostream> #include <iostream>
@@ -12,6 +13,7 @@
// ---- 轻量断言 ---- // ---- 轻量断言 ----
static int g_failures = 0; static int g_failures = 0;
// Lightweight assertion helper: increments g_failures counter on failure
#define TCHECK(cond, msg) do { \ #define TCHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
std::cout << "[OK] " << (msg) << "\n"; \ std::cout << "[OK] " << (msg) << "\n"; \
@@ -21,7 +23,11 @@ static int g_failures = 0;
} \ } \
} while (0) } while (0)
// ============================================================ // ServiceRegistry 补充测试:空名称/虚表拒绝、完整生命周期(注册→查询→取消注册→查询为空)、
// 取消注册空指针安全、取消注册后重新注册、空名称查询。
// ServiceRegistry supplement tests: null-name/vtable rejection, full lifecycle
// (register->query->unregister->query nullptr), unregister nullptr safety,
// re-registration after unregister, and query with nullptr name.
int main() int main()
{ {
std::cout << "=== dstalk service_registry unit tests (supplement) ===\n\n"; std::cout << "=== dstalk service_registry unit tests (supplement) ===\n\n";
@@ -47,6 +53,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 3: 完整生命周期 — register → query → unregister → query(nullptr) // Test 3: 完整生命周期 — register → query → unregister → query(nullptr)
// Test 3: full lifecycle — register → query → unregister → query(nullptr)
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -66,6 +73,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 4: unregister_service(nullptr name) 不崩溃(安全空操作) // Test 4: unregister_service(nullptr name) 不崩溃(安全空操作)
// Test 4: unregister_service(nullptr name) does not crash (safe no-op)
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -75,6 +83,7 @@ int main()
// ==================================================================== // ====================================================================
// Test 5: 注册后重新注册同名 → 先 unregister 再 register 成功 // Test 5: 注册后重新注册同名 → 先 unregister 再 register 成功
// Test 5: re-register same name after unregister → succeeds
// ==================================================================== // ====================================================================
{ {
dstalk::ServiceRegistry reg; dstalk::ServiceRegistry reg;
@@ -101,7 +110,7 @@ int main()
} }
// ==================================================================== // ====================================================================
// 结果 // 结果 / Result
// ==================================================================== // ====================================================================
std::cout << "\n"; std::cout << "\n";
if (g_failures == 0) { if (g_failures == 0) {

View File

@@ -1,9 +1,12 @@
// ============================================================================ /*
// smoke_test.cpp — 插件化架构烟雾测试 * @file smoke_test.cpp
// ============================================================================ * @brief Basic smoke test: verifies dstalk_init/shutdown cycle, service queries,
// 测试: 核心初始化、插件加载、服务查询、file_iosession 功能 * file_io, session, null-safety, escape boundaries, tool chain, and
// W13.6 (qa-xu 徐磊): 新增 R1-R4 回归保护点,覆盖 W11.7/W12 已修 bug * regression protections R1-R4 (W13.6 qa-xu).
// ============================================================================ * 基础冒烟测试:验证 dstalk_init/shutdown 生命周期、服务查询、file_io、session、
* 空指针安全、转义边界、工具链调用,以及回归保护 R1-R4 (W13.6 qa-xu)。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
@@ -14,6 +17,7 @@
#include <string> #include <string>
// ---- 回归测试断言 (W13.6 qa-xu) ---- // ---- 回归测试断言 (W13.6 qa-xu) ----
// Regression test assertion macro (W13.6 qa-xu): prints [OK]/[FAIL] and tracks failures
static int g_regression_failures = 0; static int g_regression_failures = 0;
#define REGCHECK(cond, msg) do { \ #define REGCHECK(cond, msg) do { \
if (cond) { \ if (cond) { \
@@ -24,19 +28,26 @@ static int g_regression_failures = 0;
} \ } \
} while (0) } while (0)
// ---- W21.5 mock tool handler (qa-xu) ---- // W21.5 mock tool handler (qa-xu): increments call counter and returns mock result JSON
// W21.5 模拟工具处理函数 (qa-xu):递增调用计数器并返回模拟结果 JSON
static int g_mock_tool_called = 0; static int g_mock_tool_called = 0;
static char* mock_tool_handler(const char* /*args_json*/) { static char* mock_tool_handler(const char* /*args_json*/) {
g_mock_tool_called++; g_mock_tool_called++;
return dstalk_strdup("{\"mock_result\":\"ok\"}"); return dstalk_strdup("{\"mock_result\":\"ok\"}");
} }
// 冒烟测试主流程init → 服务查询 → file_io → session → ai → config
// 然后是扩展测试空指针安全、转义边界、工具链、session 健壮性),
// 接着是回归保护 R1-R3、W21.5 工具调用边界和 R4 生命周期循环。
// Smoke test main: init -> service queries -> file_io -> session -> ai -> config,
// then extended tests (null-safety, escape, tool chain, session robustness),
// then regression protections R1-R3, W21.5 tool-call boundaries, and R4 lifecycle cycles.
int main() int main()
{ {
const auto dir = std::filesystem::temp_directory_path() / "dstalk-smoke-test"; const auto dir = std::filesystem::temp_directory_path() / "dstalk-smoke-test";
std::filesystem::create_directories(dir); std::filesystem::create_directories(dir);
// 写一个配置文件用于初始化 // 写一个配置文件用于初始化 / Write a config file for initialization
const auto config_path = dir / "config.toml"; const auto config_path = dir / "config.toml";
{ {
std::ofstream config(config_path); std::ofstream config(config_path);
@@ -47,14 +58,14 @@ int main()
<< "model = \"deepseek-v4-pro\"\n"; << "model = \"deepseek-v4-pro\"\n";
} }
// 初始化主机(会自动扫描 plugins/ 加载插件) // 初始化主机(会自动扫描 plugins/ 加载插件)/ Init host (auto-scans plugins/ to load plugins)
if (dstalk_init(config_path.string().c_str()) != 0) { if (dstalk_init(config_path.string().c_str()) != 0) {
std::cerr << "dstalk_init failed\n"; std::cerr << "dstalk_init failed\n";
return 1; return 1;
} }
std::cout << "[OK] dstalk_init succeeded\n"; std::cout << "[OK] dstalk_init succeeded\n";
// 验证插件列表 // 验证插件列表 / Verify plugin list
{ {
char* list_json = nullptr; char* list_json = nullptr;
int ret = dstalk_plugin_list(&list_json); int ret = dstalk_plugin_list(&list_json);
@@ -66,13 +77,13 @@ int main()
} }
} }
// 测试服务查询: file_io // 测试服务查询: file_io / Test service query: file_io
auto* file_io = static_cast<const dstalk_file_io_service_t*>( auto* file_io = static_cast<const dstalk_file_io_service_t*>(
dstalk_service_query("file_io", 1)); dstalk_service_query("file_io", 1));
if (file_io) { if (file_io) {
std::cout << "[OK] file_io service found\n"; std::cout << "[OK] file_io service found\n";
// 测试写入 // 测试写入 / Test write
const auto file_path = dir / "sample.txt"; const auto file_path = dir / "sample.txt";
constexpr const char* sample_content = "hello dstalk\nquote=\"yes\" tab=\t slash=\\"; constexpr const char* sample_content = "hello dstalk\nquote=\"yes\" tab=\t slash=\\";
if (file_io->write(file_path.string().c_str(), sample_content) == 0) { if (file_io->write(file_path.string().c_str(), sample_content) == 0) {
@@ -83,7 +94,7 @@ int main()
return 1; return 1;
} }
// 测试读取 // 测试读取 / Test read
char* content = nullptr; char* content = nullptr;
if (file_io->read(file_path.string().c_str(), &content) == 0 && content) { if (file_io->read(file_path.string().c_str(), &content) == 0 && content) {
bool ok = std::strcmp(content, sample_content) == 0; bool ok = std::strcmp(content, sample_content) == 0;
@@ -104,13 +115,13 @@ int main()
std::cerr << "[WARN] file_io service not found (plugin may not be in plugins/ dir)\n"; std::cerr << "[WARN] file_io service not found (plugin may not be in plugins/ dir)\n";
} }
// 测试服务查询: session // 测试服务查询: session / Test service query: session
auto* session = static_cast<const dstalk_session_service_t*>( auto* session = static_cast<const dstalk_session_service_t*>(
dstalk_service_query("session", 1)); dstalk_service_query("session", 1));
if (session) { if (session) {
std::cout << "[OK] session service found\n"; std::cout << "[OK] session service found\n";
// 测试 session save/load // 测试 session save/load / Test session save/load
const auto session_path = dir / "session.jsonl"; const auto session_path = dir / "session.jsonl";
const auto saved_path = dir / "session-saved.jsonl"; const auto saved_path = dir / "session-saved.jsonl";
constexpr const char* session_content = constexpr const char* session_content =
@@ -137,7 +148,7 @@ int main()
return 1; return 1;
} }
// 验证保存的内容 // 验证保存的内容 / Verify saved content
if (file_io) { if (file_io) {
char* saved = nullptr; char* saved = nullptr;
if (file_io->read(saved_path.string().c_str(), &saved) == 0 && saved) { if (file_io->read(saved_path.string().c_str(), &saved) == 0 && saved) {
@@ -153,16 +164,16 @@ int main()
} }
} }
// 测试 token 计数 // 测试 token 计数 / Test token count
int tokens = session->token_count(); int tokens = session->token_count();
std::cout << "[OK] session->token_count: " << tokens << "\n"; std::cout << "[OK] session->token_count: " << tokens << "\n";
// 测试 history // 测试 history / Test history
int count = 0; int count = 0;
session->history(&count); session->history(&count);
std::cout << "[OK] session->history count: " << count << "\n"; std::cout << "[OK] session->history count: " << count << "\n";
// 测试 clear // 测试 clear / Test clear
session->clear(); session->clear();
session->history(&count); session->history(&count);
if (count == 0) { if (count == 0) {
@@ -173,6 +184,7 @@ int main()
} }
// 测试服务查询: ai可能因为没有真实 API key 而失败,但服务应存在) // 测试服务查询: ai可能因为没有真实 API key 而失败,但服务应存在)
// Test service query: ai (may fail without real API key, but service should exist)
const char* ai_provider = dstalk_config_get("ai.provider"); const char* ai_provider = dstalk_config_get("ai.provider");
if (!ai_provider) ai_provider = "ai.deepseek"; if (!ai_provider) ai_provider = "ai.deepseek";
auto* ai = static_cast<const dstalk_ai_service_t*>( auto* ai = static_cast<const dstalk_ai_service_t*>(
@@ -183,7 +195,7 @@ int main()
std::cerr << "[WARN] ai service not found\n"; std::cerr << "[WARN] ai service not found\n";
} }
// 测试服务查询: config // 测试服务查询: config / Test service query: config
auto* config_svc = static_cast<const dstalk_config_service_t*>( auto* config_svc = static_cast<const dstalk_config_service_t*>(
dstalk_service_query("config", 1)); dstalk_service_query("config", 1));
if (config_svc) { if (config_svc) {
@@ -196,21 +208,22 @@ int main()
std::cerr << "[WARN] config service not found\n"; std::cerr << "[WARN] config service not found\n";
} }
// 测试 dstalk_config_get主机级配置 API // 测试 dstalk_config_get主机级配置 API/ Test dstalk_config_get (host-level config API)
const char* model = dstalk_config_get("api.model"); const char* model = dstalk_config_get("api.model");
if (model) { if (model) {
std::cout << "[OK] dstalk_config_get(\"api.model\"): " << model << "\n"; std::cout << "[OK] dstalk_config_get(\"api.model\"): " << model << "\n";
} }
// 测试 dstalk_log // 测试 dstalk_log / Test dstalk_log
dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully"); dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully");
// ======================================================================== // ========================================================================
// 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性 // 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性
// Extended test block C2: null-safety / escape boundaries / tools chain / session robustness
// ======================================================================== // ========================================================================
std::cout << "\n--- Extended Smoke Tests (C2) ---\n"; std::cout << "\n--- Extended Smoke Tests (C2) ---\n";
// 提前查询 tools 服务,供后续测试块使用 // 提前查询 tools 服务,供后续测试块使用 / Pre-query tools service for subsequent test blocks
auto* tools = static_cast<const dstalk_tools_service_t*>( auto* tools = static_cast<const dstalk_tools_service_t*>(
dstalk_service_query("tools", 1)); dstalk_service_query("tools", 1));
@@ -234,7 +247,7 @@ int main()
std::cerr << "[FAIL] file_io->write(nullptr, ...) should return error\n"; std::cerr << "[FAIL] file_io->write(nullptr, ...) should return error\n";
} }
// read 的 content 参数也为 null // read 的 content 参数也为 null / read's content param also null
ret = file_io->read("dummy_path", nullptr); ret = file_io->read("dummy_path", nullptr);
if (ret != 0) { if (ret != 0) {
std::cout << "[OK] file_io->read(path, nullptr) returned error (" << ret << ")\n"; std::cout << "[OK] file_io->read(path, nullptr) returned error (" << ret << ")\n";
@@ -242,7 +255,7 @@ int main()
std::cerr << "[FAIL] file_io->read(path, nullptr) should return error\n"; std::cerr << "[FAIL] file_io->read(path, nullptr) should return error\n";
} }
// write 的 content 参数为 null // write 的 content 参数为 null / write's content param is null
ret = file_io->write("dummy_path", nullptr); ret = file_io->write("dummy_path", nullptr);
if (ret != 0) { if (ret != 0) {
std::cout << "[OK] file_io->write(path, nullptr) returned error (" << ret << ")\n"; std::cout << "[OK] file_io->write(path, nullptr) returned error (" << ret << ")\n";
@@ -278,6 +291,7 @@ int main()
char* result = tools->execute(nullptr, nullptr); char* result = tools->execute(nullptr, nullptr);
if (result) { if (result) {
// 实现返回了错误字符串(如 {"error":"tool name is null"}),未崩溃 // 实现返回了错误字符串(如 {"error":"tool name is null"}),未崩溃
// Implementation returned error string (e.g. {"error":"tool name is null"}), no crash
std::cout << "[OK] tools->execute(nullptr, nullptr) did not crash" std::cout << "[OK] tools->execute(nullptr, nullptr) did not crash"
<< " (returned: " << result << ")\n"; << " (returned: " << result << ")\n";
dstalk_free(result); dstalk_free(result);
@@ -303,7 +317,7 @@ int main()
std::cerr << "[FAIL] config->set(nullptr, nullptr) should return error\n"; std::cerr << "[FAIL] config->set(nullptr, nullptr) should return error\n";
} }
// set 的 value 为 null // set 的 value 为 null / set's value is null
ret = config_svc->set("some.key", nullptr); ret = config_svc->set("some.key", nullptr);
if (ret != 0) { if (ret != 0) {
std::cout << "[OK] config->set(key, nullptr) returned error (" << ret << ")\n"; std::cout << "[OK] config->set(key, nullptr) returned error (" << ret << ")\n";
@@ -316,6 +330,8 @@ int main()
// ---- 2. 转义边界测试 ---- // ---- 2. 转义边界测试 ----
// 写入含特殊字符的内容,读回后验证内容一致 // 写入含特殊字符的内容,读回后验证内容一致
// ---- Escape boundary tests ----
// Write content with special chars, verify round-trip integrity
std::cout << "\n[Block] Escape boundary tests\n"; std::cout << "\n[Block] Escape boundary tests\n";
if (file_io) { if (file_io) {
@@ -325,6 +341,12 @@ int main()
// - 实际反斜杠 (0x5C) // - 实际反斜杠 (0x5C)
// - 实际制表符 (0x09) // - 实际制表符 (0x09)
// - 以及字面上的 \n \" \\ \t 转义序列文本 // - 以及字面上的 \n \" \\ \t 转义序列文本
// Build content with various special bytes:
// - literal newline (0x0A)
// - literal double-quote (0x22)
// - literal backslash (0x5C)
// - literal tab (0x09)
// - plus textual \n \" \\ \t escape sequences
constexpr const char* escape_content = constexpr const char* escape_content =
"line1\nline2\n" "line1\nline2\n"
"quote=\"yes\"\n" "quote=\"yes\"\n"
@@ -363,22 +385,25 @@ int main()
// ---- 3. Tools 调用链测试 ---- // ---- 3. Tools 调用链测试 ----
// 通过 tools->execute("file_read", ...) 验证内置工具可正确调用 file_io // 通过 tools->execute("file_read", ...) 验证内置工具可正确调用 file_io
// ---- Tools call chain tests ----
// Verify built-in tools correctly call file_io via tools->execute("file_read", ...)
std::cout << "\n[Block] Tools call chain tests\n"; std::cout << "\n[Block] Tools call chain tests\n";
if (tools && file_io) { if (tools && file_io) {
// 准备测试文件 // 准备测试文件 / Prepare test file
const auto chain_path = dir / "tool_chain_test.txt"; const auto chain_path = dir / "tool_chain_test.txt";
constexpr const char* chain_content = "tools-chain-ok\n"; constexpr const char* chain_content = "tools-chain-ok\n";
file_io->write(chain_path.string().c_str(), chain_content); file_io->write(chain_path.string().c_str(), chain_content);
// 用 generic_string() 获取正斜杠路径,避免 JSON 中反斜杠转义问题 // 用 generic_string() 获取正斜杠路径,避免 JSON 中反斜杠转义问题
// Use generic_string() for forward-slash paths to avoid backslash escaping in JSON
std::string generic_path = chain_path.generic_string(); std::string generic_path = chain_path.generic_string();
std::string args_json = "{\"path\":\"" + generic_path + "\"}"; std::string args_json = "{\"path\":\"" + generic_path + "\"}";
char* result = tools->execute("file_read", args_json.c_str()); char* result = tools->execute("file_read", args_json.c_str());
if (result) { if (result) {
std::cout << "[OK] tools->execute(\"file_read\", ...) returned result\n"; std::cout << "[OK] tools->execute(\"file_read\", ...) returned result\n";
// 验证返回的 JSON 中包含原始文件内容 // 验证返回的 JSON 中包含原始文件内容 / Verify returned JSON contains original file content
if (std::strstr(result, "tools-chain-ok")) { if (std::strstr(result, "tools-chain-ok")) {
std::cout << "[OK] tools->execute chain correctly called file_io\n"; std::cout << "[OK] tools->execute chain correctly called file_io\n";
} else { } else {
@@ -391,7 +416,7 @@ int main()
<< " (tool may not be registered)\n"; << " (tool may not be registered)\n";
} }
// 额外测试:查询 tools 返回的工具列表 // 额外测试:查询 tools 返回的工具列表 / Additional test: query tools list
char* tools_json = tools->get_tools_json(); char* tools_json = tools->get_tools_json();
if (tools_json) { if (tools_json) {
std::cout << "[OK] tools->get_tools_json() returned: " << tools_json << "\n"; std::cout << "[OK] tools->get_tools_json() returned: " << tools_json << "\n";
@@ -406,14 +431,17 @@ int main()
// ---- 4. Session 健壮性测试 ---- // ---- 4. Session 健壮性测试 ----
// session->add(nullptr) 后验证 history 不变 // session->add(nullptr) 后验证 history 不变
// session->clear 后验证 token_count 为 0 // session->clear 后验证 token_count 为 0
// ---- Session robustness tests ----
// Verify history unchanged after session->add(nullptr)
// Verify token_count == 0 after session->clear
std::cout << "\n[Block] Session robustness tests\n"; std::cout << "\n[Block] Session robustness tests\n";
if (session) { if (session) {
// 记录 add(nullptr) 前的 history 计数 // 记录 add(nullptr) 前的 history 计数 / Record history count before add(nullptr)
int count_before = 0; int count_before = 0;
session->history(&count_before); session->history(&count_before);
// 传 null 不应改变 history // 传 null 不应改变 history / Passing null should not change history
session->add(nullptr); session->add(nullptr);
int count_after = 0; int count_after = 0;
@@ -427,7 +455,7 @@ int main()
<< count_before << " -> " << count_after << "\n"; << count_before << " -> " << count_after << "\n";
} }
// clear 后 token_count 应为 0 // clear 后 token_count 应为 0 / token_count should be 0 after clear
session->clear(); session->clear();
int tokens = session->token_count(); int tokens = session->token_count();
if (tokens == 0) { if (tokens == 0) {
@@ -443,6 +471,8 @@ int main()
// ======================================================================== // ========================================================================
// W13.6 回归保护点 R1-R3 (qa-xu 徐磊) // W13.6 回归保护点 R1-R3 (qa-xu 徐磊)
// 覆盖: W11.7 BUG-2/3/4 + W11.1 Discovery 2/3 + W12.2/W12.3 修复 // 覆盖: W11.7 BUG-2/3/4 + W11.1 Discovery 2/3 + W12.2/W12.3 修复
// W13.6 regression protections R1-R3 (qa-xu)
// Covers: W11.7 BUG-2/3/4 + W11.1 Discovery 2/3 + W12.2/W12.3 fixes
// ======================================================================== // ========================================================================
std::cout << "\n--- Regression Tests (R1-R3: W11.7/W12 bug protection) ---\n"; std::cout << "\n--- Regression Tests (R1-R3: W11.7/W12 bug protection) ---\n";
@@ -450,6 +480,10 @@ int main()
// 回归: W11.1 Discovery 3 (g_max_tokens 死变量 — W12.3 已修, W18.1 彻底移除) // 回归: W11.1 Discovery 3 (g_max_tokens 死变量 — W12.3 已修, W18.1 彻底移除)
// W11.7 BUG-3 (/context 静默 — W12.3 已修) // W11.7 BUG-3 (/context 静默 — W12.3 已修)
// 验证: trim 能正确裁剪消息数,调用链完整不崩溃 // 验证: trim 能正确裁剪消息数,调用链完整不崩溃
// ---- R1: context max_tokens takes effect ----
// Regression: W11.1 Discovery 3 (g_max_tokens dead var — fixed W12.3, removed W18.1)
// W11.7 BUG-3 (/context silent — fixed W12.3)
// Verify: trim reduces message count correctly, full call chain without crash
{ {
auto* ctx = static_cast<const dstalk_context_service_t*>( auto* ctx = static_cast<const dstalk_context_service_t*>(
dstalk_service_query("context", 1)); dstalk_service_query("context", 1));
@@ -457,6 +491,7 @@ int main()
std::cout << "[OK] R1: context service found\n"; std::cout << "[OK] R1: context service found\n";
// 构造 5 条消息,每条 ~50 字符 / ~15 token总计 ~75 token > 50 max // 构造 5 条消息,每条 ~50 字符 / ~15 token总计 ~75 token > 50 max
// Build 5 messages, each ~50 chars / ~15 tokens, total ~75 tokens > 50 max
dstalk_message_t msgs[5]; dstalk_message_t msgs[5];
msgs[0] = {"user", "Hello this is message one with enough text to count tokens", nullptr, nullptr}; msgs[0] = {"user", "Hello this is message one with enough text to count tokens", nullptr, nullptr};
msgs[1] = {"assistant", "Message two also has sufficient length for token counting", nullptr, nullptr}; msgs[1] = {"assistant", "Message two also has sufficient length for token counting", nullptr, nullptr};
@@ -476,6 +511,7 @@ int main()
dstalk_free(out); dstalk_free(out);
} else if (ret >= 0) { } else if (ret >= 0) {
// 首条消息即超 max_tokens 时 trim 可能返回空,这也是合法路径 // 首条消息即超 max_tokens 时 trim 可能返回空,这也是合法路径
// When first message exceeds max_tokens, trim may return empty; also valid
std::cout << "[WARN] R1: trim returned null output (single msg exceeds max?)\n"; std::cout << "[WARN] R1: trim returned null output (single msg exceeds max?)\n";
} }
} else { } else {
@@ -487,15 +523,19 @@ int main()
// 回归: W11.2 Discovery 2 (双 ConfigStore 数据孤岛 — W12.2 已修) // 回归: W11.2 Discovery 2 (双 ConfigStore 数据孤岛 — W12.2 已修)
// W11.2 Discovery 3 (c_str() 悬垂 — W12.2 已修) // W11.2 Discovery 3 (c_str() 悬垂 — W12.2 已修)
// 验证: dstalk_config_set 写入后dstalk_config_get 和 config_service->get 返回一致值 // 验证: dstalk_config_set 写入后dstalk_config_get 和 config_service->get 返回一致值
// ---- R2: config dual-store consistency ----
// Regression: W11.2 Discovery 2 (dual ConfigStore islands — fixed W12.2)
// W11.2 Discovery 3 (c_str() dangling — fixed W12.2)
// Verify: after dstalk_config_set write, dstalk_config_get and config_service->get return same value
{ {
constexpr const char* k = "__regr_w13_6_dual"; constexpr const char* k = "__regr_w13_6_dual";
constexpr const char* v = "dual_ok_42"; constexpr const char* v = "dual_ok_42";
// 通过 host API 写入 // 通过 host API 写入 / Write via host API
int set_ret = dstalk_config_set(k, v); int set_ret = dstalk_config_set(k, v);
REGCHECK(set_ret == 0, "R2: dstalk_config_set returned 0"); REGCHECK(set_ret == 0, "R2: dstalk_config_set returned 0");
// 通过 host API 读回 // 通过 host API 读回 / Read back via host API
const char* host_val = dstalk_config_get(k); const char* host_val = dstalk_config_get(k);
REGCHECK(host_val && std::strcmp(host_val, v) == 0, REGCHECK(host_val && std::strcmp(host_val, v) == 0,
"R2: dstalk_config_get matches written value"); "R2: dstalk_config_get matches written value");
@@ -503,6 +543,9 @@ int main()
// 通过 plugin config 服务读回 — 验证双 store 整合后数据可见性一致 // 通过 plugin config 服务读回 — 验证双 store 整合后数据可见性一致
// 注: W12.2 双 store 整合尚未部署,跨 store 可见性当前为已知 gap // 注: W12.2 双 store 整合尚未部署,跨 store 可见性当前为已知 gap
// 本检查用 WARN 记录现状,待 W12.2 fix 落地后改为 REGCHECK // 本检查用 WARN 记录现状,待 W12.2 fix 落地后改为 REGCHECK
// Read back via plugin config service — verify visibility after dual-store merge
// Note: W12.2 dual-store merge not yet deployed; cross-store visibility is a known gap;
// this check uses WARN to record status, upgrade to REGCHECK after W12.2 lands
auto* cfg_svc = static_cast<const dstalk_config_service_t*>( auto* cfg_svc = static_cast<const dstalk_config_service_t*>(
dstalk_service_query("config", 1)); dstalk_service_query("config", 1));
if (cfg_svc) { if (cfg_svc) {
@@ -520,7 +563,7 @@ int main()
std::cerr << "[WARN] R2: config service not found, partial skip\n"; std::cerr << "[WARN] R2: config service not found, partial skip\n";
} }
// 清理测试 key // 清理测试 key / Clean up test key
dstalk_config_set(k, ""); dstalk_config_set(k, "");
} }
@@ -529,6 +572,11 @@ int main()
// W11.7 BUG-4 (/file write 落空) 同类的错误路径静默问题 // W11.7 BUG-4 (/file write 落空) 同类的错误路径静默问题
// 验证: http post_json 到不可达目标返回错误而不崩溃; // 验证: http post_json 到不可达目标返回错误而不崩溃;
// 若 http 服务不可用,回退测 ai 服务错误路径 // 若 http 服务不可用,回退测 ai 服务错误路径
// ---- R3: HTTP / AI service error paths do not crash ----
// Regression: W12.1 removed TLS/http_client code (removed rewritten network layer)
// W11.7 BUG-4 (/file write miss) similar error-path silent issues
// Verify: http post_json to unreachable target returns error without crash;
// fall back to ai service error path if http unavailable
{ {
auto* http = static_cast<const dstalk_http_service_t*>( auto* http = static_cast<const dstalk_http_service_t*>(
dstalk_service_query("http", 1)); dstalk_service_query("http", 1));
@@ -536,6 +584,8 @@ int main()
std::cout << "[OK] R3: http service found\n"; std::cout << "[OK] R3: http service found\n";
// 向 127.0.0.1:1 发请求 — 端口 1 在 Windows 上几乎肯定无服务监听 // 向 127.0.0.1:1 发请求 — 端口 1 在 Windows 上几乎肯定无服务监听
// 连接拒绝应立即返回错误而非崩溃 // 连接拒绝应立即返回错误而非崩溃
// Send request to 127.0.0.1:1 — port 1 on Windows almost certainly has no listener
// Connection refused should return error immediately, not crash
char* body = nullptr; char* body = nullptr;
int status = 0; int status = 0;
int ret = http->post_json("127.0.0.1", "1", "/", int ret = http->post_json("127.0.0.1", "1", "/",
@@ -549,6 +599,7 @@ int main()
} }
} else { } else {
// 回退:测 AI 服务 (ai.deepseek) 错误路径 // 回退:测 AI 服务 (ai.deepseek) 错误路径
// Fallback: test AI service (ai.deepseek) error path
auto* ai_svc = static_cast<const dstalk_ai_service_t*>( auto* ai_svc = static_cast<const dstalk_ai_service_t*>(
dstalk_service_query("ai.deepseek", 1)); dstalk_service_query("ai.deepseek", 1));
if (ai_svc) { if (ai_svc) {
@@ -556,6 +607,7 @@ int main()
dstalk_message_t msg = {"user", "hi", nullptr, nullptr}; dstalk_message_t msg = {"user", "hi", nullptr, nullptr};
dstalk_chat_result_t r = ai_svc->chat(&msg, 1, "", nullptr); dstalk_chat_result_t r = ai_svc->chat(&msg, 1, "", nullptr);
// api_key="test-key" 为无效 key应返回 error result 而非崩溃 // api_key="test-key" 为无效 key应返回 error result 而非崩溃
// api_key="test-key" is invalid, should return error result, not crash
REGCHECK(r.ok == 0 || r.error != nullptr, REGCHECK(r.ok == 0 || r.error != nullptr,
"R3: ai->chat with invalid key returned error result (no crash)"); "R3: ai->chat with invalid key returned error result (no crash)");
if (r.content) dstalk_free((void*)r.content); if (r.content) dstalk_free((void*)r.content);
@@ -570,11 +622,14 @@ int main()
// ======================================================================== // ========================================================================
// W21.5 Tool Calls 边界测试 (qa-xu 徐磊) // W21.5 Tool Calls 边界测试 (qa-xu 徐磊)
// 覆盖: null tool_calls_json / 空数组 "[]" / 有效 tool_calls mock 验证 // 覆盖: null tool_calls_json / 空数组 "[]" / 有效 tool_calls mock 验证
// W21.5 Tool Calls boundary tests (qa-xu)
// Covers: null tool_calls_json / empty array "[]" / valid tool_calls mock verification
// ======================================================================== // ========================================================================
std::cout << "\n--- Tool Calls Boundary Tests (W21.5) ---\n"; std::cout << "\n--- Tool Calls Boundary Tests (W21.5) ---\n";
if (tools && session) { if (tools && session) {
// ---- W21.5-1: null tool_calls_json → 正常处理(不崩溃)---- // ---- W21.5-1: null tool_calls_json → 正常处理(不崩溃)----
// ---- W21.5-1: null tool_calls_json → handle normally (no crash) ----
{ {
int before = 0; int before = 0;
session->history(&before); session->history(&before);
@@ -595,6 +650,7 @@ int main()
} }
// ---- W21.5-2: 空 JSON 数组 "[]" → 正常处理(不崩溃)---- // ---- W21.5-2: 空 JSON 数组 "[]" → 正常处理(不崩溃)----
// ---- W21.5-2: empty JSON array "[]" → handle normally (no crash) ----
{ {
int before = 0; int before = 0;
session->history(&before); session->history(&before);
@@ -616,6 +672,7 @@ int main()
} }
// ---- W21.5-3: 有效 tool_calls JSON → 验证 execute 被调用 (mock) ---- // ---- W21.5-3: 有效 tool_calls JSON → 验证 execute 被调用 (mock) ----
// ---- W21.5-3: valid tool_calls JSON → verify execute is called (mock) ----
{ {
g_mock_tool_called = 0; g_mock_tool_called = 0;
int reg = tools->register_tool( int reg = tools->register_tool(
@@ -638,6 +695,7 @@ int main()
tools->unregister_tool("__w21_5_mock"); tools->unregister_tool("__w21_5_mock");
// 验证已注销的工具返回 error 而非崩溃 // 验证已注销的工具返回 error 而非崩溃
// Verify unregistered tool returns error, not crash
char* err_result = tools->execute("__w21_5_mock", "{}"); char* err_result = tools->execute("__w21_5_mock", "{}");
REGCHECK(err_result && std::strstr(err_result, "error") != nullptr, REGCHECK(err_result && std::strstr(err_result, "error") != nullptr,
"W21.5-3d: unregistered tool returns error (not crash)"); "W21.5-3d: unregistered tool returns error (not crash)");
@@ -645,6 +703,7 @@ int main()
} }
// ---- W21.5-4: save/load 往返保留 tool_calls_json ---- // ---- W21.5-4: save/load 往返保留 tool_calls_json ----
// ---- W21.5-4: save/load round-trip preserves tool_calls_json ----
if (file_io) { if (file_io) {
const auto rtt_path = dir / "w21_5_tc_rtt.jsonl"; const auto rtt_path = dir / "w21_5_tc_rtt.jsonl";
int ret = session->save(rtt_path.string().c_str()); int ret = session->save(rtt_path.string().c_str());
@@ -664,23 +723,28 @@ int main()
std::cerr << "[WARN] W21.5: tools or session service not available\n"; std::cerr << "[WARN] W21.5: tools or session service not available\n";
} }
// 清理 // 清理 / Cleanup
dstalk_shutdown(); dstalk_shutdown();
std::cout << "[OK] dstalk_shutdown succeeded\n"; std::cout << "[OK] dstalk_shutdown succeeded\n";
// ======================================================================== // ========================================================================
// W13.6 回归保护点 R4 (qa-xu 徐磊) // W13.6 回归保护点 R4 (qa-xu 徐磊)
// W13.6 regression protection R4 (qa-xu)
// ======================================================================== // ========================================================================
// ---- R4: 重复 init / shutdown 生命周期 ---- // ---- R4: 重复 init / shutdown 生命周期 ----
// 回归: W9.8 initialize_all 容错 (插件生命周期健壮性) // 回归: W9.8 initialize_all 容错 (插件生命周期健壮性)
// W11.7 BUG-1 [CRITICAL] build/bin/ 损坏副本 (stale state 残留) // W11.7 BUG-1 [CRITICAL] build/bin/ 损坏副本 (stale state 残留)
// 验证: 多次 dstalk_init/dstalk_shutdown 循环不崩溃,每次 reload 正常 // 验证: 多次 dstalk_init/dstalk_shutdown 循环不崩溃,每次 reload 正常
// ---- R4: repeat init/shutdown lifecycle ----
// Regression: W9.8 initialize_all fault tolerance (plugin lifecycle robustness)
// W11.7 BUG-1 [CRITICAL] build/bin/ corrupt copy (stale state residue)
// Verify: multiple dstalk_init/dstalk_shutdown cycles without crash, each reload ok
{ {
std::cout << "\n[Block] R4: Repeat init/shutdown lifecycle\n"; std::cout << "\n[Block] R4: Repeat init/shutdown lifecycle\n";
constexpr int cycles = 3; constexpr int cycles = 3;
for (int i = 0; i < cycles; i++) { for (int i = 0; i < cycles; i++) {
// 每轮重写配置(模拟独立启动) // 每轮重写配置(模拟独立启动)/ Rewrite config each cycle (simulate independent start)
{ {
std::ofstream c(config_path); std::ofstream c(config_path);
c << "[api]\n" c << "[api]\n"
@@ -700,7 +764,7 @@ int main()
break; break;
} }
// 快速验证服务可用 // 快速验证服务可用 / Quick verify service is available
void* q = dstalk_service_query("config", 1); void* q = dstalk_service_query("config", 1);
REGCHECK(q != nullptr, "R4: service query ok after init"); REGCHECK(q != nullptr, "R4: service query ok after init");
@@ -710,7 +774,7 @@ int main()
} }
} }
// ---- 最终结果 ---- // ---- 最终结果 / Final result ----
std::cout << "\n"; std::cout << "\n";
if (g_regression_failures == 0) { if (g_regression_failures == 0) {
std::cout << "=== All smoke tests passed ===\n"; std::cout << "=== All smoke tests passed ===\n";

View File

View File

@@ -0,0 +1,39 @@
此文件不可AI修改
此文件不可AI修改
此文件不可AI修改
说明此文件包含重要信息禁止使用AI进行修改。
dstalk名称中的ds来源于deepseek和display希望能够实现交流的简洁为人类提供更好的服务。
dstalk是基于C/C++开发的基础cli接口,提供高性能接口以方便开发者进行二次开发,同时提供了丰富的接口以满足不同的需求。
dstalk基于多模块的自动加载和自我依赖完全自下而上设计
所有基于dstalk的模块均可以跟随dstalk实现跨平台接入和相互依赖自动加载等。
dstalk更像是一种编程框架和基础系统支持。
dstalk本质上只有dstalk网关是核心所有模块的沟通和相互依赖以及使用都是通过dstalk网关来实现的。
dstalk提供以下默认安装的模块方便大多数模块的开发和使用
1. 基础模块:提供了基础的输入输出、文件操作、网络通信等接口,方便开发者进行二次开发。
2. 模块管理模块:提供了模块化的接口,方便开发者进行模块化开发和管理。
3. AI接入Openai兼容格式模块提供了AI接入openai兼容格式的接口方便开发者进行AI接入和使用。
4. AI接入Anthropic兼容格式模块提供了AI接入Anthropic兼容格式的接口方便开发者进行AI接入和使用。
5. Anthropic和Openai兼容格式转换模块提供了Anthropic和Openai兼容格式转换的接口方便开发者进行格式转换和使用。
依赖于基础模块的模块:
1. AI接入自动识别接口依赖于基础模块提供了AI接入所有格式的接口方便开发者进行AI接入和使用。
dstalk类似openclaw的geteway可以实现相互的沟通和控制一切操作基于dstalk源于模块用于模块服务于模块。
dstalk的设计理念是模块化、自动加载、自我依赖、跨平台、易用性和高性能旨在为开发者提供一个强大而灵活的基础系统支持
以便他们能够专注于开发自己的应用程序,而不必担心底层的细节。
dstalk的目标是成为一个强大而灵活的操作系统帮助开发者更高效地开发应用程序以满足不同的需求。
newapi测试数据
"key":"sk-DWiHMg4T3cIxWUSwRGtjLuPe1c8FuwM0FiGyoyuNFWGpkhjY"
"url":"https://api.ai.pulsareon.com"