Add metadata validation script and module documentation
- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references. - Added usage instructions and exit codes for the script. - Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules. - Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
This commit is contained in:
@@ -8,6 +8,7 @@ set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED ON)
|
||||
|
||||
option(DSTALK_BUILD_GUI "Build the SDL3 GUI frontend" OFF)
|
||||
option(DSTALK_BUILD_WEB "Build the web UI frontend" OFF)
|
||||
option(DSTALK_BUILD_TESTS "Build dstalk tests" ON)
|
||||
|
||||
add_subdirectory(dstalk-core)
|
||||
@@ -18,6 +19,10 @@ if(DSTALK_BUILD_GUI)
|
||||
add_subdirectory(dstalk-gui)
|
||||
endif()
|
||||
|
||||
if(DSTALK_BUILD_WEB)
|
||||
add_subdirectory(dstalk-web)
|
||||
endif()
|
||||
|
||||
if(DSTALK_BUILD_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
// ============================================================================
|
||||
// dstalk-cli — 命令行前端 (使用插件化架构)
|
||||
// ============================================================================
|
||||
// 通过 dstalk_host.h API 初始化核心,然后查询插件服务 vtable 调用功能。
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file main.cpp
|
||||
* @brief CLI frontend for dstalk: ANSI terminal UI, command parsing, streaming chat, tool calling loop, batch/pipe mode.
|
||||
* dstalk 命令行前端:ANSI 终端界面、命令解析、流式对话、工具调用循环、批处理/管道模式。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
@@ -28,7 +29,7 @@
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
|
||||
// ---- ANSI 简写 ----
|
||||
// ---- ANSI 简写 / ANSI shorthand macros ----
|
||||
#define CLR_RESET "\033[0m"
|
||||
#define CLR_CYAN "\033[36m"
|
||||
#define CLR_YELLOW "\033[33m"
|
||||
@@ -37,25 +38,36 @@
|
||||
#define CLR_DIM "\033[2m"
|
||||
#define CLR_BOLD "\033[1m"
|
||||
|
||||
// ---- 退出码 ----
|
||||
// ---- 退出码 / Exit codes ----
|
||||
// 0=正常退出 1=用户中断(SIGINT/Ctrl+C) 2=致命错误 3=配置错误
|
||||
// 0=normal 1=user interrupt (SIGINT/Ctrl+C) 2=fatal error 3=config error
|
||||
#define EXIT_OK 0
|
||||
#define EXIT_INTERRUPT 1
|
||||
#define EXIT_FATAL 2
|
||||
#define EXIT_CONFIG 3
|
||||
|
||||
// ---- 服务 vtable 指针 ----
|
||||
// ---- 服务 vtable 指针 / Service vtable pointers ----
|
||||
// Global pointers to plugin service vtables, queried from the host on startup.
|
||||
// 插件服务 vtable 的全局指针,在启动时从主机查询获取。
|
||||
static const dstalk_ai_service_t* g_ai = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
static const dstalk_file_io_service_t* g_file_io = nullptr;
|
||||
static const dstalk_tools_service_t* g_tools = nullptr;
|
||||
|
||||
// ---- 运行时状态 ----
|
||||
// ---- 运行时状态 / Runtime state ----
|
||||
// g_current_model tracks the active model name for display in the prompt.
|
||||
// g_quit_requested signals the main loop to exit (set by /quit or Ctrl+C).
|
||||
// g_quit_via_signal distinguishes SIGINT-triggered exit from normal /quit.
|
||||
// g_current_model 记录当前模型名称,用于提示符显示。
|
||||
// g_quit_requested 通知主循环退出(由 /quit 或 Ctrl+C 设置)。
|
||||
// g_quit_via_signal 区分 SIGINT 触发的退出和正常的 /quit 退出。
|
||||
static std::string g_current_model;
|
||||
static std::atomic<bool> g_quit_requested{false};
|
||||
static std::atomic<bool> g_quit_via_signal{false};
|
||||
|
||||
// ---- Ctrl+C 信号处理 ----
|
||||
// ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ----
|
||||
// Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT).
|
||||
// Windows 控制台事件处理(CTRL_C_EVENT / CTRL_BREAK_EVENT)。
|
||||
#ifdef _WIN32
|
||||
static BOOL WINAPI on_console_event(DWORD event)
|
||||
{
|
||||
@@ -66,6 +78,8 @@ static BOOL WINAPI on_console_event(DWORD event)
|
||||
}
|
||||
return FALSE;
|
||||
}
|
||||
// Unix signal handler (SIGINT).
|
||||
// Unix 信号处理(SIGINT)。
|
||||
#else
|
||||
static void on_signal(int /*sig*/)
|
||||
{
|
||||
@@ -74,7 +88,9 @@ static void on_signal(int /*sig*/)
|
||||
}
|
||||
#endif
|
||||
|
||||
// ---- 工具函数 ----
|
||||
// ---- 工具函数 / Utility functions ----
|
||||
|
||||
// 打印启动横幅 / Print the dstalk CLI banner with version, AI indicator, and quick command hints.
|
||||
static void print_banner()
|
||||
{
|
||||
std::printf("%sdstalk v0.1.0%s | %sdstalk AI%s | "
|
||||
@@ -85,6 +101,7 @@ static void print_banner()
|
||||
CLR_DIM, CLR_RESET);
|
||||
}
|
||||
|
||||
// 打印帮助文本 / Print the full help text listing all available slash commands.
|
||||
static void print_help()
|
||||
{
|
||||
std::printf("\n%s命令列表:%s\n", CLR_BOLD, CLR_RESET);
|
||||
@@ -104,6 +121,7 @@ static void print_help()
|
||||
std::printf("\n直接输入问题即可与 AI 对话。\n\n");
|
||||
}
|
||||
|
||||
// 通过 file_io 服务读取并显示文件内容 / Read and display the contents of the file at the given path via the file_io service.
|
||||
static void print_file(const char* path)
|
||||
{
|
||||
while (*path == ' ') path++;
|
||||
@@ -122,6 +140,7 @@ static void print_file(const char* path)
|
||||
}
|
||||
}
|
||||
|
||||
// 列出目录内容,按文件名排序,子目录以青色高亮 / List directory entries sorted by filename, highlighting subdirectories in cyan.
|
||||
static void list_files(const char* path)
|
||||
{
|
||||
while (*path == ' ') path++;
|
||||
@@ -155,11 +174,12 @@ static void list_files(const char* path)
|
||||
}
|
||||
}
|
||||
|
||||
// 分发斜杠命令 / Dispatch a slash-command string: /quit, /help, /clear, /context, /status, /model, /file, /history, /save, /load.
|
||||
static void handle_command(const char* line)
|
||||
{
|
||||
if (!line || line[0] != '/') return;
|
||||
|
||||
// /quit —— 设置退出标志,让控制流自然回到 main 末尾
|
||||
// /quit —— 设置退出标志,让控制流自然回到 main 末尾 / Set quit flag to let control flow naturally return to end of main
|
||||
if (std::strcmp(line, "/quit") == 0 || std::strcmp(line, "/q") == 0) {
|
||||
g_quit_requested = true;
|
||||
return;
|
||||
@@ -197,7 +217,7 @@ static void handle_command(const char* line)
|
||||
return;
|
||||
}
|
||||
|
||||
// /status —— 脱敏显示当前运行状态
|
||||
// /status —— 脱敏显示当前运行状态 / Display current runtime status (desensitized)
|
||||
if (std::strcmp(line, "/status") == 0) {
|
||||
const char* provider = dstalk_config_get("ai.provider");
|
||||
if (!provider) provider = "ai.deepseek";
|
||||
@@ -246,7 +266,7 @@ static void handle_command(const char* line)
|
||||
return;
|
||||
}
|
||||
|
||||
// /file <subcommand> [args...] —— 统一入口,避免 strncmp 空格匹配遗漏
|
||||
// /file <subcommand> [args...] —— 统一入口,避免 strncmp 空格匹配遗漏 / Unified entry to avoid strncmp space matching issues
|
||||
if (std::strncmp(line, "/file", 5) == 0) {
|
||||
const char* rest = line + 5;
|
||||
while (*rest == ' ') rest++;
|
||||
@@ -370,7 +390,8 @@ static void handle_command(const char* line)
|
||||
std::printf(CLR_RED "未知命令: %s (输入 /help 查看帮助)\n" CLR_RESET, line);
|
||||
}
|
||||
|
||||
// ---- 流式回调 ----
|
||||
// ---- 流式回调 / Streaming callback ----
|
||||
// 流式输出回调:每收到一个 token 打印到 stdout 并刷新 / Callback invoked for each token during streaming chat; prints the token to stdout and flushes.
|
||||
static int on_stream_token(const char* token, void* userdata)
|
||||
{
|
||||
bool* first = static_cast<bool*>(userdata);
|
||||
@@ -383,10 +404,12 @@ static int on_stream_token(const char* token, void* userdata)
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ---- 主程序 ----
|
||||
// ---- 主程序 / Main entry point ----
|
||||
// 入口:初始化 dstalk host,查询插件服务,处理 batch/pipe/交互模式。
|
||||
// Entry point: initializes dstalk host, queries plugin services, handles batch/pipe/interactive modes.
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// Windows: 启用 ANSI 转义码支持
|
||||
// Windows: 启用 ANSI 转义码支持 / Windows: enable ANSI escape code support
|
||||
#ifdef _WIN32
|
||||
HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
|
||||
DWORD mode = 0;
|
||||
@@ -394,7 +417,7 @@ int main(int argc, char* argv[])
|
||||
SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
|
||||
#endif
|
||||
|
||||
// ---- C1: batch/pipe 模式检测 ----
|
||||
// ---- C1: batch/pipe 模式检测 / batch/pipe mode detection ----
|
||||
#ifdef _WIN32
|
||||
bool pipe_mode = (_isatty(_fileno(stdin)) == 0);
|
||||
#else
|
||||
@@ -421,17 +444,17 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
if (pipe_mode) batch_mode = true;
|
||||
|
||||
// ---- B1: 安装 Ctrl+C 处理 ----
|
||||
// ---- B1: 安装 Ctrl+C 处理 / Install Ctrl+C handlers ----
|
||||
#ifdef _WIN32
|
||||
SetConsoleCtrlHandler(on_console_event, TRUE);
|
||||
#else
|
||||
signal(SIGINT, on_signal);
|
||||
#endif
|
||||
|
||||
// 查找配置文件
|
||||
// 查找配置文件 / Locate config file
|
||||
const char* config_path = nullptr;
|
||||
if (argc >= 2) {
|
||||
// 跳过 --batch / --prompt 标志
|
||||
// 跳过 --batch / --prompt 标志 / Skip --batch / --prompt flags
|
||||
for (int i = 1; i < argc; ++i) {
|
||||
if (std::strcmp(argv[i], "--batch") != 0 && std::strcmp(argv[i], "--prompt") != 0) {
|
||||
config_path = argv[i];
|
||||
@@ -457,13 +480,13 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化主机(加载配置 + 自动扫描 plugins/ 目录加载插件)
|
||||
// 初始化主机(加载配置 + 自动扫描 plugins/ 目录加载插件) / Init host: load config + auto-scan plugins/ directory
|
||||
if (dstalk_init(config_path) != 0) {
|
||||
std::fprintf(stderr, CLR_RED "[dstalk] 初始化失败\n" CLR_RESET);
|
||||
return EXIT_CONFIG;
|
||||
}
|
||||
|
||||
// 查询插件服务
|
||||
// 查询插件服务 / Query plugin services
|
||||
const char* ai_provider = dstalk_config_get("ai.provider");
|
||||
if (!ai_provider) ai_provider = "ai.deepseek";
|
||||
g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
|
||||
@@ -478,7 +501,7 @@ int main(int argc, char* argv[])
|
||||
std::fprintf(stderr, CLR_RED "[dstalk] Session 服务未找到\n" CLR_RESET);
|
||||
}
|
||||
|
||||
// 自动从配置加载 AI 设置
|
||||
// 自动从配置加载 AI 设置 / Auto-load AI settings from config
|
||||
if (g_ai) {
|
||||
const char* base_url = dstalk_config_get("api.base_url");
|
||||
const char* api_key = dstalk_config_get("api.api_key");
|
||||
@@ -486,7 +509,7 @@ int main(int argc, char* argv[])
|
||||
if (!base_url) base_url = "https://api.deepseek.com/v1";
|
||||
if (!model) model = "deepseek-v4-pro";
|
||||
g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7);
|
||||
g_current_model = model; // A1: 记录当前模型名
|
||||
g_current_model = model; // A1: 记录当前模型名 / Record current model name
|
||||
}
|
||||
|
||||
if (!batch_mode) {
|
||||
@@ -495,7 +518,7 @@ int main(int argc, char* argv[])
|
||||
std::printf("\n");
|
||||
}
|
||||
|
||||
// ---- B3: 管道输入模式 (非交互) ----
|
||||
// ---- B3: 管道输入模式 (非交互) / Pipe input mode (non-interactive) ----
|
||||
if (pipe_mode) {
|
||||
std::string input;
|
||||
char buf[4096];
|
||||
@@ -529,11 +552,11 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
// ---- --prompt 批处理模式 (非交互) ----
|
||||
// ---- --prompt 批处理模式 (非交互) / --prompt batch mode (non-interactive) ----
|
||||
if (prompt_arg) {
|
||||
std::string prompt_text;
|
||||
if (std::strcmp(prompt_arg, "-") == 0) {
|
||||
// --prompt - or --prompt (no arg): read prompt from stdin
|
||||
// --prompt - or --prompt (no arg): read prompt from stdin / --prompt - 或 --prompt(无参数):从 stdin 读取提示
|
||||
char buf[4096];
|
||||
while (std::fgets(buf, sizeof(buf), stdin)) {
|
||||
prompt_text += buf;
|
||||
@@ -575,13 +598,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
char buffer[8192];
|
||||
while (true) {
|
||||
// B1: 检查退出标志
|
||||
// B1: 检查退出标志 / Check quit flag
|
||||
if (g_quit_requested) {
|
||||
std::printf("再见!\n");
|
||||
break;
|
||||
}
|
||||
|
||||
// A1: 提示符带模型名(batch 模式不打印)
|
||||
// A1: 提示符带模型名(batch 模式不打印) / Prompt shows model name (not printed in batch mode)
|
||||
if (!batch_mode) {
|
||||
std::printf(CLR_CYAN "[%s] " CLR_RESET CLR_YELLOW "> " CLR_RESET,
|
||||
g_current_model.empty() ? "?" : g_current_model.c_str());
|
||||
@@ -590,14 +613,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
if (!std::fgets(buffer, sizeof(buffer), stdin)) break;
|
||||
|
||||
// C3: fgets 截断检测
|
||||
// C3: fgets 截断检测 / fgets truncation detection
|
||||
if (!std::strchr(buffer, '\n') && !feof(stdin)) {
|
||||
std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB,已截断。建议用文件方式:dstalk --batch < file.txt\n" CLR_RESET);
|
||||
int c;
|
||||
while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {}
|
||||
}
|
||||
|
||||
// 去除末尾换行
|
||||
// 去除末尾换行 / Strip trailing newline
|
||||
size_t len = std::strlen(buffer);
|
||||
while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) {
|
||||
buffer[--len] = '\0';
|
||||
@@ -605,19 +628,19 @@ int main(int argc, char* argv[])
|
||||
|
||||
if (len == 0) continue;
|
||||
|
||||
// 命令处理
|
||||
// 命令处理 / Command dispatch
|
||||
if (buffer[0] == '/') {
|
||||
handle_command(buffer);
|
||||
continue;
|
||||
}
|
||||
|
||||
// AI 对话(通过插件服务 vtable)
|
||||
// AI 对话(通过插件服务 vtable) / AI chat (via plugin service vtable)
|
||||
if (!g_ai || !g_session) {
|
||||
std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 获取会话历史
|
||||
// 获取会话历史 / Get session history
|
||||
int history_count = 0;
|
||||
const dstalk_message_t* history = g_session->history(&history_count);
|
||||
|
||||
@@ -627,14 +650,14 @@ int main(int argc, char* argv[])
|
||||
|
||||
if (result.ok) {
|
||||
std::printf(CLR_RESET "\n\n");
|
||||
// 将用户消息和 AI 回复添加到会话
|
||||
// 将用户消息和 AI 回复添加到会话 / Add user message and AI reply to session
|
||||
dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr};
|
||||
g_session->add(&user_msg);
|
||||
dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json};
|
||||
g_session->add(&ai_msg);
|
||||
|
||||
// W20.1: Tool Calling 闭环
|
||||
// 若 AI 返回了 tool_calls,自动执行工具并将结果追加到 history,再调 AI
|
||||
// W20.1: Tool Calling 闭环 / Tool calling closed loop
|
||||
// 若 AI 返回了 tool_calls,自动执行工具并将结果追加到 history,再调 AI / If AI returns tool_calls, auto-execute tools, append results to history, then call AI again
|
||||
bool has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0');
|
||||
const int MAX_TOOL_ROUNDS = 5;
|
||||
int tool_round = 0;
|
||||
@@ -643,15 +666,15 @@ int main(int argc, char* argv[])
|
||||
tool_round++;
|
||||
has_tool_calls = false;
|
||||
|
||||
// 保存 tool_calls_json(free_result 前必须拷贝)
|
||||
// 保存 tool_calls_json(free_result 前必须拷贝) / Save tool_calls_json (must copy before free_result)
|
||||
std::string tc_json(result.tool_calls_json);
|
||||
|
||||
// 解析 [{"id":"...", "function":{"name":"...", "arguments":"..."}}]
|
||||
// 解析 [{"id":"...", "function":{"name":"...", "arguments":"..."}}] / Parse tool calls JSON array
|
||||
boost::system::error_code ec;
|
||||
auto tc_val = boost::json::parse(tc_json, ec);
|
||||
if (ec.failed() || !tc_val.is_array()) break;
|
||||
const auto& tc_array = tc_val.as_array();
|
||||
if (tc_array.empty()) break; // 空数组 → 终止
|
||||
if (tc_array.empty()) break; // 空数组 → 终止 / empty array → stop
|
||||
|
||||
bool any_executed = false;
|
||||
for (const auto& tc : tc_array) {
|
||||
@@ -675,7 +698,7 @@ int main(int argc, char* argv[])
|
||||
std::string call_id = (id_j && id_j->is_string())
|
||||
? boost::json::value_to<std::string>(*id_j) : "";
|
||||
|
||||
// 执行工具
|
||||
// 执行工具 / Execute tool
|
||||
std::printf(CLR_DIM "[工具调用] %s...\n" CLR_RESET, tool_name.c_str());
|
||||
char* exec_result = g_tools->execute(tool_name.c_str(), tool_args.c_str());
|
||||
if (exec_result) {
|
||||
@@ -691,7 +714,7 @@ int main(int argc, char* argv[])
|
||||
any_executed = true;
|
||||
} else {
|
||||
std::printf(CLR_DIM "[工具结果] fail\n" CLR_RESET);
|
||||
// 单工具失败:log + skip
|
||||
// 单工具失败:log + skip / Single tool failure: log + skip
|
||||
std::fprintf(stderr, CLR_YELLOW "[WARN] tool '%s' returned null, skipping\n" CLR_RESET,
|
||||
tool_name.c_str());
|
||||
}
|
||||
@@ -699,7 +722,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
if (!any_executed) break;
|
||||
|
||||
// 重新调用 AI(chat_stream 流式,此时 history 已包含工具结果)
|
||||
// 重新调用 AI(chat_stream 流式,此时 history 已包含工具结果) / Re-invoke AI (chat_stream streaming, history now includes tool results)
|
||||
history_count = 0;
|
||||
history = g_session->history(&history_count);
|
||||
|
||||
@@ -728,14 +751,14 @@ int main(int argc, char* argv[])
|
||||
std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS);
|
||||
}
|
||||
} else {
|
||||
// A3: error 路径下需 NULL 保护;当前只取 result.error,content 未涉及
|
||||
// A3: error 路径下需 NULL 保护;当前只取 result.error,content 未涉及 / Error path needs NULL guard; currently only reads result.error, content not involved
|
||||
std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET,
|
||||
result.error ? result.error : "unknown error");
|
||||
}
|
||||
g_ai->free_result(&result);
|
||||
}
|
||||
|
||||
// B2: 单一退出点,dstalk_shutdown 只在此调用(交互模式下)
|
||||
// B2: 单一退出点,dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode)
|
||||
dstalk_shutdown();
|
||||
return g_quit_via_signal ? EXIT_INTERRUPT : EXIT_OK;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/**
|
||||
* @file dstalk_host.h
|
||||
* @brief Host API declarations: plugin lifecycle, service registry, event bus, config, logging, memory.
|
||||
* 主机 API 声明:插件生命周期、服务注册表、事件总线、配置、日志、内存管理。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#ifndef DSTALK_HOST_H
|
||||
#define DSTALK_HOST_H
|
||||
|
||||
@@ -8,7 +15,7 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// === 平台导出宏 ===
|
||||
/* ---- 平台导出宏 / Platform export macros ---- */
|
||||
#ifndef DSTALK_API
|
||||
#if defined(_WIN32)
|
||||
#ifdef DSTALK_BUILD_DLL
|
||||
@@ -21,21 +28,23 @@ extern "C" {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// === 插件导出宏 ===
|
||||
/* ---- 插件导出宏 / Plugin export macro ---- */
|
||||
#if defined(_WIN32)
|
||||
#define DSTALK_PLUGIN_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
#define DSTALK_PLUGIN_EXPORT __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
// === API 版本 ===
|
||||
#define DSTALK_API_VERSION 1
|
||||
#define DSTALK_MAX_DEPS 8
|
||||
/* ---- API 版本常量 / API version constants ---- */
|
||||
#define DSTALK_API_VERSION 1 // 当前主机 API 版本,插件必须匹配 / current host API version plugins must match
|
||||
#define DSTALK_MAX_DEPS 8 // 插件可声明的最大依赖项数量 / maximum dependency entries a plugin can declare
|
||||
|
||||
// === 诊断 ===
|
||||
/* ---- 诊断回调 / Diagnostics callback ---- */
|
||||
/* 主机调用此回调用于断言失败和内部诊断 / Called by the host for assertion failures and internal diagnostics */
|
||||
typedef void (*dstalk_diag_cb)(int severity, const char* file,
|
||||
int line, const char* func, const char* message);
|
||||
|
||||
/* 断言宏: 当 expr 为假时记录错误并返回 retval / Assertion macro: logs error and returns retval if expr is false */
|
||||
#define DSTALK_ERROR_RETURN(expr, retval) do { \
|
||||
if (!(expr)) { \
|
||||
dstalk_log(DSTALK_LOG_ERROR, "[%s:%d] %s: assertion '%s' failed", \
|
||||
@@ -44,85 +53,107 @@ typedef void (*dstalk_diag_cb)(int severity, const char* file,
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
/* 注册诊断回调用于内部错误报告 / Register a diagnostic callback for internal error reporting */
|
||||
DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb);
|
||||
|
||||
// === 事件处理器 ===
|
||||
/* ---- 事件处理器类型 / Event handler type ---- */
|
||||
/* 当已订阅的事件被触发时由主机调用 / Called by the host when a subscribed event is emitted */
|
||||
typedef void (*dstalk_event_handler_fn)(int event_type, const void* data, void* userdata);
|
||||
|
||||
// === Host 提供给插件的 API 表 ===
|
||||
/* ---- 主机 API vtable (传递给插件的 on_init) / Host API vtable (passed to plugin's on_init) ---- */
|
||||
typedef struct {
|
||||
// 服务注册/查询
|
||||
/* --- 服务注册表 / service registry --- */
|
||||
int (*register_service)(const char* name, int version, void* vtable);
|
||||
void*(*query_service)(const char* name, int min_version);
|
||||
|
||||
// 事件
|
||||
/* --- 事件总线 / event bus --- */
|
||||
int (*event_subscribe)(int event_type, dstalk_event_handler_fn handler, void* userdata);
|
||||
int (*event_emit)(int event_type, const void* data);
|
||||
void (*event_unsubscribe)(int sub_id);
|
||||
|
||||
// 配置
|
||||
/* --- 配置管理 / configuration --- */
|
||||
const char* (*config_get)(const char* key);
|
||||
int (*config_set)(const char* key, const char* value);
|
||||
|
||||
// 日志
|
||||
/* --- 日志记录 / logging --- */
|
||||
void (*log)(int level, const char* fmt, ...);
|
||||
|
||||
// 内存
|
||||
/* --- 内存管理 / memory management --- */
|
||||
void* (*alloc)(size_t size);
|
||||
void (*free)(void* ptr);
|
||||
char* (*strdup)(const char* s);
|
||||
} dstalk_host_api_t;
|
||||
|
||||
// === 插件信息结构 ===
|
||||
/* ---- 插件描述符 / Plugin descriptor ---- */
|
||||
/* 每个插件通过 dstalk_plugin_init() 导出此结构体 / Every plugin exports this via dstalk_plugin_init() */
|
||||
typedef struct {
|
||||
const char* name; // 插件名称(唯一标识)
|
||||
const char* version; // 语义化版本号,如 "1.0.0"
|
||||
const char* description; // 描述
|
||||
int api_version; // 必须 == DSTALK_API_VERSION
|
||||
const char* name; // 唯一插件标识符 / unique plugin identifier
|
||||
const char* version; // 语义版本号,如 "1.0.0" / semantic version, e.g. "1.0.0"
|
||||
const char* description; // 人类可读的描述信息 / human-readable description
|
||||
int api_version; // 必须等于 DSTALK_API_VERSION / must equal DSTALK_API_VERSION
|
||||
|
||||
// 依赖声明(以 NULL 结尾)
|
||||
/* null-terminated 依赖插件名称列表 / null-terminated list of dependency plugin names */
|
||||
const char* dependencies[DSTALK_MAX_DEPS];
|
||||
|
||||
// 生命周期回调
|
||||
/* 生命周期回调 / lifecycle callbacks */
|
||||
int (*on_init)(const dstalk_host_api_t* host);
|
||||
void (*on_shutdown)(void);
|
||||
|
||||
// 事件处理(可选)
|
||||
/* 可选: 事件总线上每个事件通过时调用 / optional: called for every event passing through the bus */
|
||||
void (*on_event)(int event_type, const void* data);
|
||||
} dstalk_plugin_info_t;
|
||||
|
||||
// === 插件入口函数 ===
|
||||
/* ---- 插件入口点 / Plugin entry point ---- */
|
||||
/* 每个共享库插件必须导出一个与此签名匹配的函数 / Every shared library plugin must export a function with this signature */
|
||||
typedef dstalk_plugin_info_t* (*dstalk_plugin_init_fn)(void);
|
||||
|
||||
// === Host 公共 API ===
|
||||
/* ========================================================================
|
||||
* 主机公共 API / Host public API
|
||||
* ======================================================================== */
|
||||
|
||||
// 初始化/销毁
|
||||
/* 使用给定的配置文件路径初始化 dstalk 主机 / Initialize the dstalk host with the given config file path */
|
||||
DSTALK_API int dstalk_init(const char* config_path);
|
||||
|
||||
/* 关闭主机: 卸载插件, 释放资源 / Shut down the host: unload plugins, free resources */
|
||||
DSTALK_API void dstalk_shutdown(void);
|
||||
|
||||
// 插件管理
|
||||
/* 从共享库路径加载插件; 返回 plugin_id, 出错返回 -1 / Load a plugin from a shared library path; returns plugin_id or -1 on error */
|
||||
DSTALK_API int dstalk_plugin_load(const char* path);
|
||||
|
||||
/* 按 id 卸载之前加载的插件 / Unload a previously loaded plugin by its id */
|
||||
DSTALK_API int dstalk_plugin_unload(int plugin_id);
|
||||
|
||||
/* 将已加载插件信息的 JSON 数组写入 *output_json (调用方释放) / Write a JSON array of loaded plugin info to *output_json (caller frees) */
|
||||
DSTALK_API int dstalk_plugin_list(char** output_json);
|
||||
|
||||
// 服务查询
|
||||
/* 按名称和最低版本号查找已注册的服务 vtable / Look up a registered service vtable by name and minimum version */
|
||||
DSTALK_API void* dstalk_service_query(const char* service_name, int min_version);
|
||||
|
||||
// 事件系统
|
||||
/* 为特定事件类型订阅处理器; 返回 subscription_id / Subscribe handler to a specific event type; returns subscription_id */
|
||||
DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata);
|
||||
|
||||
/* 向所有已订阅该类型事件的订阅者发送事件 / Emit an event to all subscribers of the given type */
|
||||
DSTALK_API int dstalk_event_emit(int event_type, const void* data);
|
||||
|
||||
/* 按 id 移除订阅 / Remove a subscription by its id */
|
||||
DSTALK_API void dstalk_event_unsubscribe(int subscription_id);
|
||||
|
||||
// 配置
|
||||
/* 通过键名获取配置值 (未找到返回 NULL) / Retrieve a config value by key (returns NULL if not found) */
|
||||
DSTALK_API const char* dstalk_config_get(const char* key);
|
||||
|
||||
/* 设置配置键值对; 成功返回 0 / Set a config key/value pair; returns 0 on success */
|
||||
DSTALK_API int dstalk_config_set(const char* key, const char* value);
|
||||
|
||||
// 日志
|
||||
/* 以给定严重等级记录日志消息 / Log a message at the given severity level */
|
||||
DSTALK_API void dstalk_log(int level, const char* fmt, ...);
|
||||
|
||||
// 内存
|
||||
/* 使用主机的内存分配器分配内存 / Allocate memory using the host's allocator */
|
||||
DSTALK_API void* dstalk_alloc(size_t size);
|
||||
|
||||
/* 释放之前由主机分配的内存 / Free memory previously allocated by the host */
|
||||
DSTALK_API void dstalk_free(void* ptr);
|
||||
|
||||
/* 使用主机的内存分配器复制 C 字符串 / Duplicate a C-string using the host's allocator */
|
||||
DSTALK_API char* dstalk_strdup(const char* s);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/**
|
||||
* @file dstalk_lsp.h
|
||||
* @brief Convenience C API for Language Server Protocol operations (delegates to "lsp" plugin).
|
||||
* LSP(语言服务器协议)操作的便捷 C API(委托给 "lsp" 插件)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#ifndef DSTALK_LSP_H
|
||||
#define DSTALK_LSP_H
|
||||
|
||||
@@ -7,51 +14,51 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* ---- LSP 服务器生命周期 ---- */
|
||||
/* ---- LSP 服务器生命周期 / LSP Server Lifecycle ---- */
|
||||
|
||||
/*
|
||||
* 启动语言服务器进程
|
||||
* server_cmd: 命令字符串,例如 "clangd" 或 "pyright --stdio" 或完整路径
|
||||
* language: 语言标识,例如 "c", "cpp", "python", "javascript", "rust"
|
||||
* returns: 0 成功, -1 失败
|
||||
* 启动语言服务器进程 / Start the language server process
|
||||
* server_cmd: 命令字符串,例如 "clangd" 或 "pyright --stdio" 或完整路径 / command string, e.g. "clangd" or "pyright --stdio" or full path
|
||||
* language: 语言标识,例如 "c", "cpp", "python", "javascript", "rust" / language identifier, e.g. "c", "cpp", "python", "javascript", "rust"
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*/
|
||||
DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language);
|
||||
|
||||
/*
|
||||
* 停止语言服务器
|
||||
* 发送 shutdown 请求,然后发送 exit 通知
|
||||
* 关闭管道,终止子进程
|
||||
* 停止语言服务器 / Stop the language server
|
||||
* 发送 shutdown 请求,然后发送 exit 通知 / sends shutdown request, then exit notification
|
||||
* 关闭管道,终止子进程 / closes pipes, terminates child process
|
||||
*/
|
||||
DSTALK_API void dstalk_lsp_stop(void);
|
||||
|
||||
/* ---- 文档管理 ---- */
|
||||
/* ---- 文档管理 / Document Management ---- */
|
||||
|
||||
/*
|
||||
* 在语言服务器中打开一个文档
|
||||
* uri: 文件 URI,例如 "file:///path/to/file.c"
|
||||
* content: 文件内容文本
|
||||
* language_id: 语言 ID,例如 "c", "cpp", "python", "javascript"
|
||||
* returns: 0 成功, -1 失败
|
||||
* 在语言服务器中打开一个文档 / Open a document in the language server
|
||||
* uri: 文件 URI,例如 "file:///path/to/file.c" / file URI, e.g. "file:///path/to/file.c"
|
||||
* content: 文件内容文本 / file content text
|
||||
* language_id: 语言 ID,例如 "c", "cpp", "python", "javascript" / language ID, e.g. "c", "cpp", "python", "javascript"
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*/
|
||||
DSTALK_API int dstalk_lsp_open(const char* uri, const char* content,
|
||||
const char* language_id);
|
||||
|
||||
/*
|
||||
* 关闭语言服务器中的文档
|
||||
* uri: 文件 URI
|
||||
* returns: 0 成功, -1 失败
|
||||
* 关闭语言服务器中的文档 / Close a document in the language server
|
||||
* uri: 文件 URI / file URI
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*/
|
||||
DSTALK_API int dstalk_lsp_close(const char* uri);
|
||||
|
||||
/* ---- 查询操作 ---- */
|
||||
/* ---- 查询操作 / Query Operations ---- */
|
||||
|
||||
/*
|
||||
* 获取诊断信息 (编译错误、警告等)
|
||||
* uri: 文件 URI
|
||||
* output: 输出参数,JSON 格式的诊断列表 (调用方通过 dstalk_free 释放)
|
||||
* returns: 0 成功, -1 失败
|
||||
* 获取诊断信息 (编译错误、警告等) / Get diagnostics (build errors, warnings, etc.)
|
||||
* uri: 文件 URI / file URI
|
||||
* output: 输出参数,JSON 格式的诊断列表 (调用方通过 dstalk_free 释放) / output param, JSON list of diagnostics (caller frees via dstalk_free)
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*
|
||||
* JSON 输出格式示例:
|
||||
* JSON 输出格式示例 / JSON output format example:
|
||||
* [
|
||||
* {
|
||||
* "range": { "start": {"line":0,"character":0}, "end":{"line":0,"character":5} },
|
||||
@@ -63,23 +70,23 @@ DSTALK_API int dstalk_lsp_close(const char* uri);
|
||||
DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output);
|
||||
|
||||
/*
|
||||
* 获取悬停信息 (类型、文档等)
|
||||
* uri: 文件 URI
|
||||
* line: 行号 (0-based)
|
||||
* character: 列号 (0-based, UTF-16 code units)
|
||||
* output: 输出参数,JSON 格式的悬停信息 (调用方通过 dstalk_free 释放)
|
||||
* returns: 0 成功, -1 失败
|
||||
* 获取悬停信息 (类型、文档等) / Get hover info (type, documentation, etc.)
|
||||
* uri: 文件 URI / file URI
|
||||
* line: 行号 (0-based) / line number (0-based)
|
||||
* character: 列号 (0-based, UTF-16 code units) / column number (0-based, UTF-16 code units)
|
||||
* output: 输出参数,JSON 格式的悬停信息 (调用方通过 dstalk_free 释放) / output param, JSON hover info (caller frees via dstalk_free)
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*/
|
||||
DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character,
|
||||
char** output);
|
||||
|
||||
/*
|
||||
* 获取代码补全建议
|
||||
* uri: 文件 URI
|
||||
* line: 行号 (0-based)
|
||||
* character: 列号 (0-based, UTF-16 code units)
|
||||
* output: 输出参数,JSON 格式的补全列表 (调用方通过 dstalk_free 释放)
|
||||
* returns: 0 成功, -1 失败
|
||||
* 获取代码补全建议 / Get code completion suggestions
|
||||
* uri: 文件 URI / file URI
|
||||
* line: 行号 (0-based) / line number (0-based)
|
||||
* character: 列号 (0-based, UTF-16 code units) / column number (0-based, UTF-16 code units)
|
||||
* output: 输出参数,JSON 格式的补全列表 (调用方通过 dstalk_free 释放) / output param, JSON completion list (caller frees via dstalk_free)
|
||||
* returns: 0 成功, -1 失败 / 0 success, -1 failure
|
||||
*/
|
||||
DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character,
|
||||
char** output);
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/**
|
||||
* @file dstalk_services.h
|
||||
* @brief Service vtable definitions for all plugin-provided services (AI, Session, HTTP, etc.).
|
||||
* 所有插件提供的服务 vtable 定义(AI、会话、HTTP 等)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#ifndef DSTALK_SERVICES_H
|
||||
#define DSTALK_SERVICES_H
|
||||
|
||||
@@ -7,46 +14,64 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// === AI 服务 vtable (实际服务名由插件注册: "ai.deepseek" / "ai.anthropic") ===
|
||||
/* ---- AI 服务 vtable / AI service vtable ---- */
|
||||
/* 以名称如 "ai.deepseek" 或 "ai.anthropic" 注册 / Registered under names such as "ai.deepseek" or "ai.anthropic" */
|
||||
typedef struct {
|
||||
/* 配置服务商连接 (base_url, api_key, model 等) / Configure provider connection (base_url, api_key, model, etc.) */
|
||||
int (*configure)(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature);
|
||||
/* 发送单轮聊天补全请求 (阻塞) / Send a single-turn chat completion (blocking) */
|
||||
dstalk_chat_result_t (*chat)(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
const char* tools_json);
|
||||
/* 通过回调实现流式令牌传输的聊天补全 / Send a chat completion with streaming tokens via callback */
|
||||
dstalk_chat_result_t (*chat_stream)(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
dstalk_stream_cb cb, void* userdata);
|
||||
/* 释放 dstalk_chat_result_t 持有的资源 / Free resources held by a dstalk_chat_result_t */
|
||||
void (*free_result)(dstalk_chat_result_t* result);
|
||||
} dstalk_ai_service_t;
|
||||
|
||||
// === Session 服务 (service name: "session") ===
|
||||
/* ---- 会话服务 vtable / Session service vtable ---- */
|
||||
/* 以服务名称 "session" 注册 / Registered under service name "session" */
|
||||
typedef struct {
|
||||
/* 将消息追加到会话历史 / Append a message to the session history */
|
||||
void (*add)(const dstalk_message_t* msg);
|
||||
/* 清除会话历史中的所有消息 / Clear all messages from the session history */
|
||||
void (*clear)(void);
|
||||
/* 将会话历史保存到文件 (JSON); 成功返回 0 / Save session history to a file (JSON); returns 0 on success */
|
||||
int (*save)(const char* path);
|
||||
/* 从文件 (JSON) 加载会话历史; 成功返回 0 / Load session history from a file (JSON); returns 0 on success */
|
||||
int (*load)(const char* path);
|
||||
/* 获取完整消息历史; out_count 接收数组长度 / Get the full message history; out_count receives the array length */
|
||||
const dstalk_message_t* (*history)(int* out_count);
|
||||
/* 返回当前会话历史的近似令牌数 / Return the approximate token count of the current session history */
|
||||
int (*token_count)(void);
|
||||
} dstalk_session_service_t;
|
||||
|
||||
// === Context 服务 (service name: "context") ===
|
||||
/* ---- 上下文服务 vtable / Context service vtable ---- */
|
||||
/* 以服务名称 "context" 注册 / Registered under service name "context" */
|
||||
typedef struct {
|
||||
/* 计算消息数组中近似的令牌数 / Count approximate tokens in an array of messages */
|
||||
size_t (*count_tokens)(const dstalk_message_t* msgs, int count);
|
||||
/* 裁剪消息历史以适应 max_tokens; out/out_count 为新分配 / Trim message history to fit within max_tokens; out/out_count are newly allocated */
|
||||
int (*trim)(const dstalk_message_t* in, int in_count,
|
||||
dstalk_message_t** out, int* out_count,
|
||||
size_t max_tokens);
|
||||
} dstalk_context_service_t;
|
||||
|
||||
// === HTTP 服务 (service name: "http") ===
|
||||
/* ---- HTTP 服务 vtable / HTTP service vtable ---- */
|
||||
/* 以服务名称 "http" 注册 / Registered under service name "http" */
|
||||
typedef struct {
|
||||
/* POST JSON 体到主机; 返回响应体和 HTTP 状态码 / POST JSON body to a host; returns response body and HTTP status code */
|
||||
int (*post_json)(const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
const char* headers_json,
|
||||
char** response_body, int* status_code);
|
||||
/* POST 带流式响应; 令牌通过回调传递 / POST with streaming response; tokens are delivered via callback */
|
||||
int (*post_stream)(const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
const char* headers_json,
|
||||
@@ -54,38 +79,61 @@ typedef struct {
|
||||
char** response_body, int* status_code);
|
||||
} dstalk_http_service_t;
|
||||
|
||||
// === File IO 服务 (service name: "file_io") ===
|
||||
/* ---- 文件 I/O 服务 vtable / File I/O service vtable ---- */
|
||||
/* 以服务名称 "file_io" 注册 / Registered under service name "file_io" */
|
||||
typedef struct {
|
||||
/* 读取整个文件内容到 *content; 成功返回 0 / Read entire file content into *content; returns 0 on success */
|
||||
int (*read)(const char* path, char** content);
|
||||
/* 将内容写入文件 (覆盖已有文件); 成功返回 0 / Write content to a file (overwrites if exists); returns 0 on success */
|
||||
int (*write)(const char* path, const char* content);
|
||||
} dstalk_file_io_service_t;
|
||||
|
||||
// === Config 服务 (service name: "config") ===
|
||||
/* ---- 配置服务 vtable / Config service vtable ---- */
|
||||
/* 以服务名称 "config" 注册 / Registered under service name "config" */
|
||||
typedef struct {
|
||||
/* 通过键名获取配置值; 未找到返回 NULL / Get a config value by key; returns NULL if not found */
|
||||
const char* (*get)(const char* key);
|
||||
/* 设置配置键值对; 成功返回 0 / Set a config key/value pair; returns 0 on success */
|
||||
int (*set)(const char* key, const char* value);
|
||||
/* 从 JSON 配置文件加载并合并键值对 / Load and merge key/value pairs from a JSON config file */
|
||||
int (*load_file)(const char* path);
|
||||
} dstalk_config_service_t;
|
||||
|
||||
// === Tools 服务 (service name: "tools") ===
|
||||
/* ---- 工具服务 vtable / Tools service vtable ---- */
|
||||
/* 以服务名称 "tools" 注册 / Registered under service name "tools" */
|
||||
|
||||
/* 已注册工具被调用时触发的处理器; 接收 JSON 参数, 返回 JSON 结果 / Handler invoked when a registered tool is called; receives JSON args, returns JSON result */
|
||||
typedef char* (*dstalk_tool_handler_fn)(const char* args_json);
|
||||
|
||||
typedef struct {
|
||||
/* 注册工具,包含名称、描述和 JSON Schema 参数 / Register a tool with name, description, and JSON Schema parameters */
|
||||
int (*register_tool)(const char* name, const char* desc,
|
||||
const char* params_schema,
|
||||
dstalk_tool_handler_fn handler);
|
||||
/* 取消注册之前注册的工具 / Unregister a previously registered tool */
|
||||
void (*unregister_tool)(const char* name);
|
||||
/* 获取所有已注册工具为 JSON 数组 (OpenAI 工具格式) / Get all registered tools as a JSON array (OpenAI tool format) */
|
||||
char* (*get_tools_json)(void);
|
||||
/* 按名称执行已注册工具,传入 JSON 参数 / Execute a registered tool by name with the given JSON arguments */
|
||||
char* (*execute)(const char* name, const char* args_json);
|
||||
} dstalk_tools_service_t;
|
||||
|
||||
// === LSP 服务 (service name: "lsp") ===
|
||||
/* ---- LSP 服务 vtable / LSP service vtable ---- */
|
||||
/* 以服务名称 "lsp" 注册 / Registered under service name "lsp" */
|
||||
typedef struct {
|
||||
/* 启动指定语言的 LSP 服务器进程 / Start an LSP server process for the given language */
|
||||
int (*start)(const char* server_cmd, const char* language);
|
||||
/* 停止 LSP 服务器并清理资源 / Stop the LSP server and clean up resources */
|
||||
void (*stop)(void);
|
||||
/* 在 LSP 服务器中打开文档 / Open a document in the LSP server */
|
||||
int (*open_document)(const char* uri, const char* content, const char* lang_id);
|
||||
/* 在 LSP 服务器中关闭文档 / Close a document in the LSP server */
|
||||
int (*close_document)(const char* uri);
|
||||
/* 获取文档的诊断信息 (错误、警告) 以 JSON 格式返回 / Retrieve diagnostics (errors, warnings) for a document as JSON */
|
||||
int (*get_diagnostics)(const char* uri, char** json_out);
|
||||
/* 获取指定位置的悬停信息以 JSON 格式返回 / Retrieve hover information at a given position as JSON */
|
||||
int (*get_hover)(const char* uri, int line, int col, char** json_out);
|
||||
/* 获取指定位置的代码补全建议以 JSON 格式返回 / Retrieve code completion suggestions at a given position as JSON */
|
||||
int (*get_completion)(const char* uri, int line, int col, char** json_out);
|
||||
} dstalk_lsp_service_t;
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/**
|
||||
* @file dstalk_types.h
|
||||
* @brief Shared data types used across the dstalk host and all plugins.
|
||||
* 跨主机和所有插件共享的数据类型定义。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#ifndef DSTALK_TYPES_H
|
||||
#define DSTALK_TYPES_H
|
||||
|
||||
@@ -7,42 +14,42 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// 消息结构(跨插件共享)
|
||||
/* 所有插件共享的消息结构体 / Shared message structure used across plugins */
|
||||
typedef struct {
|
||||
const char* role; // "user", "assistant", "system", "tool"
|
||||
const char* content; // 消息内容
|
||||
const char* tool_call_id; // tool 响应时必填
|
||||
const char* tool_calls_json;// assistant 返回的工具调用(JSON 数组)
|
||||
const char* role; // 角色标识 / Role identifier ("user", "assistant", "system", "tool")
|
||||
const char* content; // 消息正文文本 / Message body text
|
||||
const char* tool_call_id; // 工具调用响应消息所需 / Required for tool response messages
|
||||
const char* tool_calls_json;// 助手工具调用的 JSON 数组 / JSON array of tool calls from assistant
|
||||
} dstalk_message_t;
|
||||
|
||||
// 聊天结果
|
||||
/* 聊天/补全调用返回的结果 / Result returned from a chat / completion call */
|
||||
typedef struct {
|
||||
int ok;
|
||||
const char* content; // dstalk_strdup 分配,调用方 dstalk_free
|
||||
const char* error; // dstalk_strdup 分配
|
||||
int http_status;
|
||||
const char* tool_calls_json;// dstalk_strdup 分配
|
||||
int ok; // 0=失败, 非零=成功 / 0 = failure, non-zero = success
|
||||
const char* content; // dstalk_strdup 分配; 调用方用 dstalk_free 释放 / allocated by dstalk_strdup; caller frees with dstalk_free
|
||||
const char* error; // dstalk_strdup 分配; 成功时为 NULL / allocated by dstalk_strdup; NULL on success
|
||||
int http_status; // 服务商返回的 HTTP 状态码 / HTTP status code from the provider
|
||||
const char* tool_calls_json;// dstalk_strdup 分配; 工具调用的 JSON 数组 / allocated by dstalk_strdup; JSON array of tool calls
|
||||
} dstalk_chat_result_t;
|
||||
|
||||
// 流式回调
|
||||
/* 流式令牌回调: 返回非零值提前中止流传输 / Streaming token callback: return non-zero to abort the stream early */
|
||||
typedef int (*dstalk_stream_cb)(const char* token, void* userdata);
|
||||
|
||||
// 事件类型
|
||||
/* 事件类型代码 (匿名枚举) / Event type codes (anonymous enum) */
|
||||
enum {
|
||||
DSTALK_EVENT_MESSAGE = 1, // data = dstalk_message_t*
|
||||
DSTALK_EVENT_SESSION_CLEAR,
|
||||
DSTALK_EVENT_CONFIG_CHANGED,
|
||||
DSTALK_EVENT_PLUGIN_LOADED, // data = plugin info JSON string
|
||||
DSTALK_EVENT_PLUGIN_UNLOADED,
|
||||
DSTALK_EVENT_CUSTOM = 1000, // 插件自定义事件起始值
|
||||
DSTALK_EVENT_MESSAGE = 1, // 数据为 dstalk_message_t* / data = dstalk_message_t*
|
||||
DSTALK_EVENT_SESSION_CLEAR, // 会话历史已清除 / session history cleared
|
||||
DSTALK_EVENT_CONFIG_CHANGED, // 配置键/值已更新 / configuration key/value updated
|
||||
DSTALK_EVENT_PLUGIN_LOADED, // 数据为插件信息 JSON 字符串 / data = plugin info JSON string
|
||||
DSTALK_EVENT_PLUGIN_UNLOADED, // 插件已卸载 / plugin unloaded
|
||||
DSTALK_EVENT_CUSTOM = 1000, // 插件自定义事件的基础值 / base value for plugin-defined custom events
|
||||
};
|
||||
|
||||
// 日志级别
|
||||
/* 日志严重等级 (匿名枚举) / Log severity levels (anonymous enum) */
|
||||
enum {
|
||||
DSTALK_LOG_DEBUG = 0,
|
||||
DSTALK_LOG_INFO = 1,
|
||||
DSTALK_LOG_WARN = 2,
|
||||
DSTALK_LOG_ERROR = 3,
|
||||
DSTALK_LOG_DEBUG = 0, // 详细调试消息 / verbose debug messages
|
||||
DSTALK_LOG_INFO = 1, // 信息性消息 / informational messages
|
||||
DSTALK_LOG_WARN = 2, // 警告条件 / warning conditions
|
||||
DSTALK_LOG_ERROR = 3, // 错误条件 / error conditions
|
||||
};
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
@@ -1 +1,7 @@
|
||||
/* @file boost_json.cpp
|
||||
* @brief Boost.JSON header-only library compilation unit (single TU inclusion).
|
||||
* Boost.JSON 仅头文件库的编译单元(单翻译单元包含)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
/* @file config_store.cpp
|
||||
* @brief ConfigStore implementation: TOML parsing, thread-safe get/set with thread-local safety.
|
||||
* ConfigStore 实现:TOML 解析、线程安全的 get/set(基于 thread-local 安全机制)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "config_store.hpp"
|
||||
#include "../../plugins/config/include/toml_parse.h"
|
||||
|
||||
@@ -8,6 +14,7 @@
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 在互斥锁下加载并解析 TOML 文件到键值存储 / Load and parse a TOML file into the key-value store under mutex.
|
||||
int ConfigStore::load_file(const char* path)
|
||||
{
|
||||
if (!path) return -1;
|
||||
@@ -19,7 +26,7 @@ int ConfigStore::load_file(const char* path)
|
||||
ss << file.rdbuf();
|
||||
std::string data = ss.str();
|
||||
|
||||
// W12.2: Use shared TOML parser (de-duplicated from config_plugin.cpp)
|
||||
// W12.2: 使用共享 TOML 解析器(从 config_plugin.cpp 去重) / Use shared TOML parser (de-duplicated from config_plugin.cpp)
|
||||
toml::parse(data, [this](const std::string& key, const std::string& value) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
data_[key] = value;
|
||||
@@ -28,6 +35,7 @@ int ConfigStore::load_file(const char* path)
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 检索配置值,返回线程本地副本以避免 c_str() 悬空 / Retrieve config value, returning a thread-local copy to avoid dangling c_str().
|
||||
const char* ConfigStore::get(const char* key) const
|
||||
{
|
||||
if (!key) return nullptr;
|
||||
@@ -35,7 +43,9 @@ const char* ConfigStore::get(const char* key) const
|
||||
auto it = data_.find(key);
|
||||
if (it == data_.end()) return nullptr;
|
||||
|
||||
// W12.2: Copy to thread-local buffer before releasing lock.
|
||||
// W12.2: 在释放锁之前复制到线程本地缓冲区 /
|
||||
// Copy to thread-local buffer before releasing lock.
|
||||
// 防止当并发 set() 触发 std::string 重新分配时 c_str() 悬空 /
|
||||
// Prevents c_str() dangling when concurrent set() on the same key
|
||||
// triggers std::string reallocation (W11.2 audit Finding 3).
|
||||
thread_local std::string tls_cached;
|
||||
@@ -43,15 +53,17 @@ const char* ConfigStore::get(const char* key) const
|
||||
return tls_cached.c_str();
|
||||
}
|
||||
|
||||
// 以 std::string 值类型检索配置(安全的值副本)/ Retrieve config value as an owned std::string (safe by-value copy).
|
||||
std::string ConfigStore::get_copy(const char* key) const
|
||||
{
|
||||
if (!key) return {};
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
auto it = data_.find(key);
|
||||
if (it == data_.end()) return {};
|
||||
return it->second; // copy-constructed under lock, always safe
|
||||
return it->second; // 在锁下复制构造,始终安全 / copy-constructed under lock, always safe
|
||||
}
|
||||
|
||||
// 在锁下设置配置键值对 / Set a config key-value pair under lock.
|
||||
int ConfigStore::set(const char* key, const char* value)
|
||||
{
|
||||
if (!key || !value) return -1;
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
/* @file config_store.hpp
|
||||
* @brief Thread-safe key-value configuration store with TOML file loading.
|
||||
* 线程安全键值配置存储,支持 TOML 文件加载。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
@@ -6,32 +12,36 @@
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 线程安全的键值存储,支持 TOML 配置文件 / Thread-safe key-value store backed by TOML config files.
|
||||
// 通过 mutex_ 支持并发读取,get() 返回线程本地缓冲区 / Supports concurrent reads via mutex_ and returns thread-local buffers from get().
|
||||
class ConfigStore {
|
||||
public:
|
||||
ConfigStore() = default;
|
||||
~ConfigStore() = default;
|
||||
|
||||
// Load key-value pairs from a TOML file.
|
||||
// Returns 0 on success, -1 if file not found or path is null.
|
||||
// 从 TOML 文件加载键值对 / Load key-value pairs from a TOML file.
|
||||
// 成功返回 0,文件未找到或路径为空返回 -1 / Returns 0 on success, -1 if file not found or path is null.
|
||||
int load_file(const char* path);
|
||||
|
||||
// Get config value (returns internal pointer, thread-safe).
|
||||
// W12.2: Returned pointer is now backed by a thread-local copy;
|
||||
// 获取配置值(返回内部指针,线程安全)/ Get config value (returns internal pointer, thread-safe).
|
||||
// W12.2: 返回的指针现在由线程局部副本支持,对其他线程对同一键的并发 set() 安全 /
|
||||
// Returned pointer is now backed by a thread-local copy;
|
||||
// safe against concurrent set() on the same key from other threads.
|
||||
// 调用者仍应立即使用 — 同一线程上的下一次 get() 将覆盖缓冲区 /
|
||||
// Caller should still consume immediately — next get() on same
|
||||
// thread will overwrite the buffer.
|
||||
const char* get(const char* key) const;
|
||||
|
||||
// Get a safe by-value copy of a config entry (no dangling risk).
|
||||
// Returns empty string if key not found.
|
||||
// 获取配置项的安全值副本(无悬空风险)/ Get a safe by-value copy of a config entry (no dangling risk).
|
||||
// 如果键未找到,返回空字符串 / Returns empty string if key not found.
|
||||
std::string get_copy(const char* key) const;
|
||||
|
||||
// Set config value. Returns 0 on success, -1 on null arguments.
|
||||
// 设置配置值 / Set config value. 成功返回 0,参数为空返回 -1 / Returns 0 on success, -1 on null arguments.
|
||||
int set(const char* key, const char* value);
|
||||
|
||||
private:
|
||||
mutable std::mutex mutex_;
|
||||
std::unordered_map<std::string, std::string> data_;
|
||||
mutable std::mutex mutex_; // 保护所有 data_ 访问 / Protects all data_ access
|
||||
std::unordered_map<std::string, std::string> data_; // 配置键值存储 / Config key-value store
|
||||
};
|
||||
|
||||
} // namespace dstalk
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
/* @file event_bus.cpp
|
||||
* @brief EventBus implementation: subscribe, unsubscribe, emit with reader-writer locking.
|
||||
* EventBus 实现:基于读写锁的 subscribe、unsubscribe、emit。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "event_bus.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 为给定事件类型注册处理器,返回订阅 ID / Register a handler for the given event type, returning a subscription id.
|
||||
int EventBus::subscribe(int event_type, EventHandler handler)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
@@ -12,6 +19,7 @@ int EventBus::subscribe(int event_type, EventHandler handler)
|
||||
return id;
|
||||
}
|
||||
|
||||
// 通过 ID 移除订阅(如果 ID 未找到则无操作)/ Remove a subscription by id (no-op if id not found).
|
||||
void EventBus::unsubscribe(int subscription_id)
|
||||
{
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
@@ -23,6 +31,8 @@ void EventBus::unsubscribe(int subscription_id)
|
||||
subscriptions_.end());
|
||||
}
|
||||
|
||||
// 在共享锁下将事件分发给所有匹配的订阅者 / Dispatch an event to all matching subscribers under a shared lock.
|
||||
// 返回被调用的处理器数量 / Returns the count of handlers invoked.
|
||||
int EventBus::emit(int event_type, const void* data)
|
||||
{
|
||||
std::shared_lock<std::shared_mutex> lock(mutex_);
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
/* @file event_bus.hpp
|
||||
* @brief Publish-subscribe event bus with shared_mutex for concurrent read access.
|
||||
* 发布-订阅事件总线,使用 shared_mutex 支持并发读访问。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
@@ -10,18 +16,23 @@ namespace dstalk {
|
||||
|
||||
using EventHandler = std::function<void(int event_type, const void* data)>;
|
||||
|
||||
// 轻量级发布-订阅事件总线 / Lightweight pub-sub event bus.
|
||||
// 读取者使用 shared_lock(emit),因此多个处理器可以并发分发;
|
||||
// 写入者使用 unique_lock(subscribe / unsubscribe)。
|
||||
// Readers use shared_lock (emit) so multiple handlers can be dispatched
|
||||
// concurrently; writers use unique_lock (subscribe / unsubscribe).
|
||||
class EventBus {
|
||||
public:
|
||||
EventBus() = default;
|
||||
~EventBus() = default;
|
||||
|
||||
// 订阅事件,返回订阅ID
|
||||
// 订阅事件,返回订阅ID / Subscribe to an event, returning a subscription id
|
||||
int subscribe(int event_type, EventHandler handler);
|
||||
|
||||
// 取消订阅
|
||||
// 取消订阅 / Unsubscribe by subscription id
|
||||
void unsubscribe(int subscription_id);
|
||||
|
||||
// 发布事件
|
||||
// 发布事件 / Emit an event to all matching subscribers
|
||||
int emit(int event_type, const void* data);
|
||||
|
||||
private:
|
||||
@@ -31,9 +42,9 @@ private:
|
||||
EventHandler handler;
|
||||
};
|
||||
|
||||
mutable std::shared_mutex mutex_;
|
||||
std::vector<Subscription> subscriptions_;
|
||||
int next_id_ = 1;
|
||||
mutable std::shared_mutex mutex_; // 读写锁:emit 用 shared,subscribe/unsubscribe 用 unique / RW lock: shared for emit, unique for subscribe/unsubscribe
|
||||
std::vector<Subscription> subscriptions_; // emit 时线性扫描;对少量订阅者足够 / Linear scan on emit; ok for small subscriber counts
|
||||
int next_id_ = 1; // 单调递增订阅 ID 计数器 / Monotonic subscription id counter
|
||||
};
|
||||
|
||||
} // namespace dstalk
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file host.cpp
|
||||
* @brief Core host orchestrator: global singletons, dstalk_host_api_t instantiation, public C API, LSP delegation.
|
||||
* 核心主机协调器:全局单例、dstalk_host_api_t 实例化、公共 C API、LSP 委托。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "config_store.hpp"
|
||||
#include "event_bus.hpp"
|
||||
@@ -15,7 +22,7 @@
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// ============================================================
|
||||
// 全局主机上下文
|
||||
// 全局主机上下文 / Global host context
|
||||
// ============================================================
|
||||
namespace {
|
||||
std::mutex g_init_mutex;
|
||||
@@ -27,8 +34,10 @@ namespace {
|
||||
dstalk::PluginLoader* g_plugin_loader = nullptr;
|
||||
static std::atomic<dstalk_diag_cb> g_diag_callback{nullptr};
|
||||
|
||||
// ---- 内部辅助 ----
|
||||
// ---- 内部辅助 / Internal helpers ----
|
||||
|
||||
// 复制 C 字符串(用 malloc 分配,调用者必须用 api_free/free 释放)
|
||||
// Duplicate a C string allocated with malloc (caller must free via api_free/free).
|
||||
char* host_strdup(const char* s) {
|
||||
if (!s) return nullptr;
|
||||
size_t len = strlen(s);
|
||||
@@ -37,6 +46,8 @@ namespace {
|
||||
return copy;
|
||||
}
|
||||
|
||||
// 核心日志实现:格式化消息,写入 stderr,并转发到诊断回调(如果已设置)。
|
||||
// Core logging implementation: formats message, writes to stderr, and forwards to diagnostic callback if set.
|
||||
void host_log_impl(int level, const char* fmt, va_list args) {
|
||||
const char* prefix = "";
|
||||
switch (level) {
|
||||
@@ -50,7 +61,7 @@ namespace {
|
||||
va_copy(args_copy, args);
|
||||
vfprintf(stderr, fmt, args);
|
||||
fprintf(stderr, "\n");
|
||||
// 转发到诊断回调
|
||||
// 转发到诊断回调 / Forward to diagnostic callback
|
||||
auto cb = g_diag_callback.load(std::memory_order_acquire);
|
||||
if (cb) {
|
||||
char buf[1024];
|
||||
@@ -60,6 +71,8 @@ namespace {
|
||||
va_end(args_copy);
|
||||
}
|
||||
|
||||
// host_log_impl 的 printf 风格便捷包装。
|
||||
// Convenience wrapper around host_log_impl for printf-style calls.
|
||||
void host_log(int level, const char* fmt, ...) {
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
@@ -67,16 +80,22 @@ namespace {
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
// ---- Host API 表回调 ----
|
||||
// ---- Host API 表回调 / Host API table callbacks ----
|
||||
|
||||
// 将服务 vtable 按名称和版本注册到全局注册表。
|
||||
// Register a service vtable with the given name and version into the global registry.
|
||||
int api_register_service(const char* name, int version, void* vtable) {
|
||||
return g_service_registry ? g_service_registry->register_service(name, version, vtable) : -1;
|
||||
}
|
||||
|
||||
// 按名称和最低版本从全局注册表查询服务 vtable。
|
||||
// Query a service vtable by name and minimum version from the global registry.
|
||||
void* api_query_service(const char* name, int min_version) {
|
||||
return g_service_registry ? g_service_registry->query_service(name, min_version) : nullptr;
|
||||
}
|
||||
|
||||
// 通过全局事件总线订阅指定事件类型的处理函数。
|
||||
// Subscribe a handler to a given event type via the global event bus.
|
||||
int api_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) {
|
||||
if (!g_event_bus || !handler) return -1;
|
||||
return g_event_bus->subscribe(event_type,
|
||||
@@ -85,22 +104,32 @@ namespace {
|
||||
});
|
||||
}
|
||||
|
||||
// 通过全局事件总路线程安全地向所有已注册处理函数发送事件。
|
||||
// Emit an event to all registered handlers via the global event bus.
|
||||
int api_event_emit(int event_type, const void* data) {
|
||||
return g_event_bus ? g_event_bus->emit(event_type, data) : -1;
|
||||
}
|
||||
|
||||
// 通过订阅 ID 取消注册之前的事件处理函数。
|
||||
// Unsubscribe a previously registered event handler by subscription ID.
|
||||
void api_event_unsubscribe(int sub_id) {
|
||||
if (g_event_bus) g_event_bus->unsubscribe(sub_id);
|
||||
}
|
||||
|
||||
// 从全局配置存储中按键名读取配置值。
|
||||
// Read a config value by key from the global config store.
|
||||
const char* api_config_get(const char* key) {
|
||||
return g_config ? g_config->get(key) : nullptr;
|
||||
}
|
||||
|
||||
// 在全局配置存储中设置配置键值对。
|
||||
// Set a config key/value pair in the global config store.
|
||||
int api_config_set(const char* key, const char* value) {
|
||||
return g_config ? g_config->set(key, value) : -1;
|
||||
}
|
||||
|
||||
// 主机端日志函数(host_log_impl 的 varargs 包装)。
|
||||
// Host-facing log function (varargs wrapper around host_log_impl).
|
||||
void api_log(int level, const char* fmt, ...) {
|
||||
va_list args;
|
||||
va_start(args, fmt);
|
||||
@@ -108,11 +137,16 @@ namespace {
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
// 内存分配包装 / Memory allocation wrapper (malloc).
|
||||
void* api_alloc(size_t size) { return malloc(size); }
|
||||
// 内存释放包装 / Memory free wrapper (free).
|
||||
void api_free(void* ptr) { free(ptr); }
|
||||
|
||||
// 字符串复制包装 / String duplication wrapper (host_strdup).
|
||||
char* api_strdup(const char* s) { return host_strdup(s); }
|
||||
|
||||
// 传递给每个插件 on_init 的完整主机 API vtable。
|
||||
// The complete host API vtable passed to every plugin's on_init.
|
||||
dstalk_host_api_t g_host_api = {
|
||||
api_register_service,
|
||||
api_query_service,
|
||||
@@ -127,8 +161,12 @@ namespace {
|
||||
api_strdup
|
||||
};
|
||||
|
||||
// ---- 插件目录扫描 ----
|
||||
// ---- 插件目录扫描 / Plugin directory scanning ----
|
||||
|
||||
// 扫描目录中的插件 DLL 并通过 PluginLoader 加载。
|
||||
// 返回加载的插件数量,出错返回 -1。
|
||||
// Scan a directory for plugin DLLs and load them via PluginLoader.
|
||||
// Returns the number of plugins loaded, or -1 on error.
|
||||
int load_plugins_from_directory(const char* plugin_dir) {
|
||||
if (!plugin_dir) return -1;
|
||||
|
||||
@@ -163,9 +201,11 @@ namespace {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 公共 API
|
||||
// 公共 API / Public API
|
||||
// ============================================================
|
||||
|
||||
// 初始化 dstalk 主机:创建单例、加载配置、扫描插件、初始化所有插件。
|
||||
// Initialize the dstalk host: create singletons, load config, scan plugins, initialize all plugins.
|
||||
DSTALK_API int dstalk_init(const char* config_path)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_init_mutex);
|
||||
@@ -178,14 +218,14 @@ DSTALK_API int dstalk_init(const char* config_path)
|
||||
g_service_registry = new dstalk::ServiceRegistry();
|
||||
g_plugin_loader = new dstalk::PluginLoader();
|
||||
|
||||
// 加载配置
|
||||
// 加载配置 / Load config
|
||||
if (config_path && config_path[0]) {
|
||||
if (g_config->load_file(config_path) != 0) {
|
||||
host_log(DSTALK_LOG_WARN, "Failed to load config: %s", config_path);
|
||||
}
|
||||
}
|
||||
|
||||
// 扫描插件目录
|
||||
// 扫描插件目录 / Scan plugin directory
|
||||
const char* plugin_dir = g_config->get("plugin_dir");
|
||||
if (!plugin_dir) plugin_dir = "plugins";
|
||||
int loaded = load_plugins_from_directory(plugin_dir);
|
||||
@@ -195,7 +235,7 @@ DSTALK_API int dstalk_init(const char* config_path)
|
||||
loaded = load_plugins_from_directory("../plugins");
|
||||
}
|
||||
|
||||
// 初始化所有插件
|
||||
// 初始化所有插件 / Initialize all plugins
|
||||
if (g_plugin_loader->initialize_all(&g_host_api) != 0) {
|
||||
host_log(DSTALK_LOG_WARN, "Some plugins failed to initialize");
|
||||
}
|
||||
@@ -214,6 +254,8 @@ DSTALK_API int dstalk_init(const char* config_path)
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭 dstalk 主机:关闭插件、销毁单例、释放资源。
|
||||
// Shutdown the dstalk host: shutdown plugins, destroy singletons, release resources.
|
||||
DSTALK_API void dstalk_shutdown(void)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_init_mutex);
|
||||
@@ -234,6 +276,8 @@ DSTALK_API void dstalk_shutdown(void)
|
||||
g_initialized = false;
|
||||
}
|
||||
|
||||
// 从给定路径加载单个插件 DLL 并初始化。返回插件 ID,失败返回 -1。
|
||||
// Load a single plugin DLL from the given path and initialize it. Returns plugin ID or -1.
|
||||
DSTALK_API int dstalk_plugin_load(const char* path)
|
||||
{
|
||||
if (!g_initialized || !g_plugin_loader) return -1;
|
||||
@@ -244,12 +288,16 @@ DSTALK_API int dstalk_plugin_load(const char* path)
|
||||
return id;
|
||||
}
|
||||
|
||||
// 按 ID 卸载插件:调用 on_shutdown,卸载 DLL,从注册表中移除。成功返回 0。
|
||||
// Unload a plugin by ID: call on_shutdown, unload DLL, remove from registry. Returns 0 on success.
|
||||
DSTALK_API int dstalk_plugin_unload(int plugin_id)
|
||||
{
|
||||
if (!g_initialized || !g_plugin_loader) return -1;
|
||||
return g_plugin_loader->unload_plugin(plugin_id);
|
||||
}
|
||||
|
||||
// 以 JSON 字符串列出所有已加载插件。调用者必须用 dstalk_free 释放 *output_json。
|
||||
// List all loaded plugins as a JSON string. Caller must free *output_json with dstalk_free.
|
||||
DSTALK_API int dstalk_plugin_list(char** output_json)
|
||||
{
|
||||
if (!g_initialized || !g_plugin_loader || !output_json) return -1;
|
||||
@@ -257,12 +305,16 @@ DSTALK_API int dstalk_plugin_list(char** output_json)
|
||||
return *output_json ? 0 : -1;
|
||||
}
|
||||
|
||||
// 按名称和最低版本从全局服务注册表查询服务 vtable。
|
||||
// Query a service vtable by name and minimum version from the global service registry.
|
||||
DSTALK_API void* dstalk_service_query(const char* service_name, int min_version)
|
||||
{
|
||||
if (!g_initialized || !g_service_registry) return nullptr;
|
||||
return g_service_registry->query_service(service_name, min_version);
|
||||
}
|
||||
|
||||
// 订阅回调到事件类型。返回订阅 ID,失败返回 -1。
|
||||
// Subscribe a callback to an event type. Returns a subscription ID or -1.
|
||||
DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata)
|
||||
{
|
||||
if (!g_initialized || !g_event_bus || !handler) return -1;
|
||||
@@ -270,30 +322,40 @@ DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn ha
|
||||
[handler, userdata](int type, const void* data) { handler(type, data, userdata); });
|
||||
}
|
||||
|
||||
// 向订阅了该事件类型的所有处理函数发送事件。
|
||||
// Emit an event to all handlers subscribed to the given event type.
|
||||
DSTALK_API int dstalk_event_emit(int event_type, const void* data)
|
||||
{
|
||||
if (!g_initialized || !g_event_bus) return -1;
|
||||
return g_event_bus->emit(event_type, data);
|
||||
}
|
||||
|
||||
// 按订阅 ID 取消注册之前的事件处理函数。
|
||||
// Unsubscribe a previously registered event handler by subscription ID.
|
||||
DSTALK_API void dstalk_event_unsubscribe(int subscription_id)
|
||||
{
|
||||
if (!g_initialized || !g_event_bus) return;
|
||||
g_event_bus->unsubscribe(subscription_id);
|
||||
}
|
||||
|
||||
// 按键读取配置值。返回指向内部存储的指针(请勿释放)。
|
||||
// Read a configuration value by key. Returns pointer to internal storage (do not free).
|
||||
DSTALK_API const char* dstalk_config_get(const char* key)
|
||||
{
|
||||
if (!g_initialized || !g_config) return nullptr;
|
||||
return g_config->get(key);
|
||||
}
|
||||
|
||||
// 设置配置键值对。成功返回 0。
|
||||
// Set a configuration key/value pair. Returns 0 on success.
|
||||
DSTALK_API int dstalk_config_set(const char* key, const char* value)
|
||||
{
|
||||
if (!g_initialized || !g_config) return -1;
|
||||
return g_config->set(key, value);
|
||||
}
|
||||
|
||||
// 在给定级别记录消息(printf 风格)。写入 stderr 并转发到诊断回调。
|
||||
// Log a message at the given level (printf-style). Writes to stderr and forwards to diag callback.
|
||||
DSTALK_API void dstalk_log(int level, const char* fmt, ...)
|
||||
{
|
||||
va_list args;
|
||||
@@ -302,24 +364,33 @@ DSTALK_API void dstalk_log(int level, const char* fmt, ...)
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
// 通过 malloc 分配内存(为插件 ABI 一致性提供) / Allocate memory via malloc (provided for plugin ABI consistency).
|
||||
DSTALK_API void* dstalk_alloc(size_t size) { return malloc(size); }
|
||||
// 释放通过 dstalk_alloc 分配的内存(为插件 ABI 一致性提供) / Free memory allocated via dstalk_alloc (provided for plugin ABI consistency).
|
||||
DSTALK_API void dstalk_free(void* ptr) { free(ptr); }
|
||||
// 使用 dstalk_alloc 复制 C 字符串(调用者必须 dstalk_free) / Duplicate a C string using dstalk_alloc (caller must dstalk_free).
|
||||
DSTALK_API char* dstalk_strdup(const char* s) { return host_strdup(s); }
|
||||
|
||||
// 注册接收所有日志消息的诊断回调(传入 null 可取消设置)。
|
||||
// Register a diagnostic callback that receives all log messages (may be null to unset).
|
||||
DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb) {
|
||||
g_diag_callback.store(cb, std::memory_order_release);
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// LSP 便捷函数 (委托给 "lsp" 服务插件)
|
||||
// LSP 便捷函数 (委托给 "lsp" 服务插件) / LSP convenience functions (delegated to "lsp" service plugin)
|
||||
// ============================================================
|
||||
|
||||
// 从全局服务注册表获取 "lsp" 服务 vtable,不可用则返回 null。
|
||||
// Retrieve the "lsp" service vtable from the global service registry, or null if unavailable.
|
||||
static const dstalk_lsp_service_t* get_lsp_service() {
|
||||
if (!g_initialized || !g_service_registry) return nullptr;
|
||||
return static_cast<const dstalk_lsp_service_t*>(
|
||||
g_service_registry->query_service("lsp", 1));
|
||||
}
|
||||
|
||||
// 为给定的命令和语言启动语言服务器进程。
|
||||
// Start a language server process for the given command and language.
|
||||
DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
@@ -327,12 +398,16 @@ DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language)
|
||||
return svc->start(server_cmd, language);
|
||||
}
|
||||
|
||||
// 停止当前正在运行的语言服务器进程。
|
||||
// Stop the currently running language server process.
|
||||
DSTALK_API void dstalk_lsp_stop(void)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
if (svc && svc->stop) svc->stop();
|
||||
}
|
||||
|
||||
// 在 LSP 服务器中打开文档以供分析(didOpen 通知)。
|
||||
// Open a document in the LSP server for analysis (didOpen notification).
|
||||
DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char* language_id)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
@@ -340,6 +415,8 @@ DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char*
|
||||
return svc->open_document(uri, content, language_id);
|
||||
}
|
||||
|
||||
// 在 LSP 服务器中关闭文档(didClose 通知)。
|
||||
// Close a document in the LSP server (didClose notification).
|
||||
DSTALK_API int dstalk_lsp_close(const char* uri)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
@@ -347,6 +424,8 @@ DSTALK_API int dstalk_lsp_close(const char* uri)
|
||||
return svc->close_document(uri);
|
||||
}
|
||||
|
||||
// 检索文档的当前诊断信息。调用者必须用 dstalk_free 释放 *output。
|
||||
// Retrieve current diagnostics for a document. Caller must free *output with dstalk_free.
|
||||
DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
@@ -354,6 +433,8 @@ DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output)
|
||||
return svc->get_diagnostics(uri, output);
|
||||
}
|
||||
|
||||
// 请求文档位置处的悬停信息。调用者必须用 dstalk_free 释放 *output。
|
||||
// Request hover information at a document position. Caller must free *output with dstalk_free.
|
||||
DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char** output)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
@@ -361,6 +442,8 @@ DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char**
|
||||
return svc->get_hover(uri, line, character, output);
|
||||
}
|
||||
|
||||
// 请求文档位置处的补全项。调用者必须用 dstalk_free 释放 *output。
|
||||
// Request completion items at a document position. Caller must free *output with dstalk_free.
|
||||
DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, char** output)
|
||||
{
|
||||
auto* svc = get_lsp_service();
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file plugin_loader.cpp
|
||||
* @brief PluginLoader implementation: DLL load/unload, path validation, Kahn topological sort, lifecycle management.
|
||||
* PluginLoader 实现:DLL 加载/卸载、路径验证、Kahn 拓扑排序、生命周期管理。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "plugin_loader.hpp"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
@@ -21,20 +28,26 @@ namespace dstalk {
|
||||
namespace json = boost::json;
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
// 析构函数:调用 shutdown_all 释放所有插件并释放 DLL 句柄。
|
||||
// Destructor: calls shutdown_all to release all plugins and free DLL handles.
|
||||
PluginLoader::~PluginLoader()
|
||||
{
|
||||
shutdown_all();
|
||||
}
|
||||
|
||||
// 加载插件 DLL:验证路径(扩展名、目录遍历、目录),加载库,
|
||||
// 解析 dstalk_plugin_init,验证 API 版本,解析依赖,分配 ID。
|
||||
// Load a plugin DLL: validate path (extension, traversal, directory), load library,
|
||||
// resolve dstalk_plugin_init, verify API version, parse dependencies, assign ID.
|
||||
int PluginLoader::load_plugin(const char* path)
|
||||
{
|
||||
if (!path) return -1;
|
||||
|
||||
// === Path validation (F-18.3-3) ===
|
||||
// === 路径验证 (F-18.3-3) / Path validation (F-18.3-3) ===
|
||||
{
|
||||
fs::path p = fs::absolute(fs::path(path)).lexically_normal();
|
||||
|
||||
// Extension check (case-insensitive)
|
||||
// 扩展名检查(大小写不敏感) / Extension check (case-insensitive)
|
||||
std::string ext = p.extension().string();
|
||||
std::transform(ext.begin(), ext.end(), ext.begin(),
|
||||
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
||||
@@ -57,7 +70,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Directory traversal check
|
||||
// 目录遍历检查 / Directory traversal check
|
||||
bool has_dotdot = false;
|
||||
bool in_plugins_dir = false;
|
||||
for (const auto& comp : p) {
|
||||
@@ -78,6 +91,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 目录约束:必须位于 'plugins' 目录下或为纯文件名
|
||||
// Directory constraint: must be under a 'plugins' directory or be a plain filename
|
||||
if (!in_plugins_dir && p.has_parent_path()) {
|
||||
if (host_api_) {
|
||||
@@ -88,7 +102,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
}
|
||||
}
|
||||
|
||||
// 加载DLL
|
||||
// 加载DLL / Load DLL
|
||||
#ifdef _WIN32
|
||||
void* handle = LoadLibraryA(path);
|
||||
#else
|
||||
@@ -109,7 +123,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 获取入口函数
|
||||
// 获取入口函数 / Resolve entry function
|
||||
#ifdef _WIN32
|
||||
auto init_fn = (dstalk_plugin_init_fn)GetProcAddress(
|
||||
(HMODULE)handle, "dstalk_plugin_init");
|
||||
@@ -138,7 +152,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 调用入口函数获取插件信息
|
||||
// 调用入口函数获取插件信息 / Call entry function to get plugin info
|
||||
dstalk_plugin_info_t* info = nullptr;
|
||||
try {
|
||||
info = init_fn();
|
||||
@@ -160,7 +174,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 检查API版本兼容性
|
||||
// 检查API版本兼容性 / Check API version compatibility
|
||||
if (info->api_version != DSTALK_API_VERSION) {
|
||||
if (host_api_) {
|
||||
host_api_->log(DSTALK_LOG_ERROR,
|
||||
@@ -175,7 +189,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 创建插件信息
|
||||
// 创建插件信息 / Create plugin info
|
||||
int id = next_id_++;
|
||||
PluginInfo plugin;
|
||||
plugin.id = id;
|
||||
@@ -187,7 +201,7 @@ int PluginLoader::load_plugin(const char* path)
|
||||
plugin.info = info;
|
||||
plugin.initialized = false;
|
||||
|
||||
// 解析依赖
|
||||
// 解析依赖 / Parse dependencies
|
||||
for (int i = 0; i < DSTALK_MAX_DEPS && info->dependencies[i]; i++) {
|
||||
plugin.dependencies.push_back(info->dependencies[i]);
|
||||
}
|
||||
@@ -196,6 +210,8 @@ int PluginLoader::load_plugin(const char* path)
|
||||
return id;
|
||||
}
|
||||
|
||||
// 按 ID 卸载插件:若已初始化则调用 on_shutdown,释放 DLL 句柄,从 map 中移除。
|
||||
// Unload a plugin by ID: call on_shutdown if initialized, free the DLL handle, erase from map.
|
||||
int PluginLoader::unload_plugin(int plugin_id)
|
||||
{
|
||||
auto it = plugins_.find(plugin_id);
|
||||
@@ -203,7 +219,7 @@ int PluginLoader::unload_plugin(int plugin_id)
|
||||
|
||||
PluginInfo& plugin = it->second;
|
||||
|
||||
// 调用关闭回调
|
||||
// 调用关闭回调 / Call shutdown callback
|
||||
if (plugin.initialized && plugin.info->on_shutdown) {
|
||||
try {
|
||||
plugin.info->on_shutdown();
|
||||
@@ -216,7 +232,7 @@ int PluginLoader::unload_plugin(int plugin_id)
|
||||
}
|
||||
}
|
||||
|
||||
// 卸载DLL
|
||||
// 卸载DLL / Unload DLL
|
||||
#ifdef _WIN32
|
||||
FreeLibrary((HMODULE)plugin.handle);
|
||||
#else
|
||||
@@ -227,6 +243,8 @@ int PluginLoader::unload_plugin(int plugin_id)
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 将所有已加载插件序列化为 JSON 数组字符串。
|
||||
// Serialize all loaded plugins into a JSON array string.
|
||||
std::string PluginLoader::list_plugins() const
|
||||
{
|
||||
json::array arr;
|
||||
@@ -250,15 +268,19 @@ std::string PluginLoader::list_plugins() const
|
||||
return json::serialize(arr);
|
||||
}
|
||||
|
||||
// 使用 Kahn 算法计算依赖顺序的插件 ID 列表。
|
||||
// 若检测到循环依赖则抛出 std::runtime_error。
|
||||
// Compute dependency-ordered plugin IDs using Kahn's algorithm.
|
||||
// Throws std::runtime_error if a circular dependency is detected.
|
||||
std::vector<int> PluginLoader::topological_sort() const
|
||||
{
|
||||
// 构建名称到ID的映射
|
||||
// 构建名称到ID的映射 / Build name-to-ID map
|
||||
std::unordered_map<std::string, int> name_to_id;
|
||||
for (const auto& [id, plugin] : plugins_) {
|
||||
name_to_id[plugin.name] = id;
|
||||
}
|
||||
|
||||
// 计算入度
|
||||
// 计算入度 / Calculate in-degrees
|
||||
std::unordered_map<int, int> in_degree;
|
||||
std::unordered_map<int, std::vector<int>> dependents;
|
||||
|
||||
@@ -277,7 +299,7 @@ std::vector<int> PluginLoader::topological_sort() const
|
||||
}
|
||||
}
|
||||
|
||||
// 拓扑排序(Kahn算法)
|
||||
// 拓扑排序(Kahn算法) / Topological sort (Kahn's algorithm)
|
||||
std::queue<int> queue;
|
||||
for (const auto& [id, degree] : in_degree) {
|
||||
if (degree == 0) {
|
||||
@@ -298,7 +320,7 @@ std::vector<int> PluginLoader::topological_sort() const
|
||||
}
|
||||
}
|
||||
|
||||
// 检查循环依赖
|
||||
// 检查循环依赖 / Check for circular dependency
|
||||
if (sorted.size() != plugins_.size()) {
|
||||
throw std::runtime_error("Circular dependency detected");
|
||||
}
|
||||
@@ -306,17 +328,21 @@ std::vector<int> PluginLoader::topological_sort() const
|
||||
return sorted;
|
||||
}
|
||||
|
||||
// 验证依赖:检查缺失依赖和循环依赖。
|
||||
// 成功返回 0,发现错误返回 -1(错误通过 host_api_ 记录)。
|
||||
// Validate dependencies: checks for missing dependencies and circular dependencies.
|
||||
// Returns 0 on success, -1 if any errors found (errors are logged via host_api_).
|
||||
int PluginLoader::validate_dependencies() const
|
||||
{
|
||||
int error_count = 0;
|
||||
|
||||
// 构建名称到ID的映射
|
||||
// 构建名称到ID的映射 / Build name-to-ID map
|
||||
std::unordered_map<std::string, int> name_to_id;
|
||||
for (const auto& [id, plugin] : plugins_) {
|
||||
name_to_id[plugin.name] = id;
|
||||
}
|
||||
|
||||
// 检查1:缺失依赖(deps 引用的插件未加载)
|
||||
// 检查1:缺失依赖(deps 引用的插件未加载) / Check 1: missing dependencies (deps reference plugins not loaded)
|
||||
for (const auto& [id, plugin] : plugins_) {
|
||||
for (const auto& dep_name : plugin.dependencies) {
|
||||
if (name_to_id.find(dep_name) == name_to_id.end()) {
|
||||
@@ -330,7 +356,7 @@ int PluginLoader::validate_dependencies() const
|
||||
}
|
||||
}
|
||||
|
||||
// 检查2:循环依赖(拓扑排序失败)
|
||||
// 检查2:循环依赖(拓扑排序失败) / Check 2: circular dependency (topological sort fails)
|
||||
try {
|
||||
topological_sort();
|
||||
} catch (const std::runtime_error&) {
|
||||
@@ -344,12 +370,19 @@ int PluginLoader::validate_dependencies() const
|
||||
return error_count > 0 ? -1 : 0;
|
||||
}
|
||||
|
||||
// 按依赖顺序初始化所有未初始化的插件。
|
||||
// 无效依赖或失败初始化会标记插件名,避免级联失败。
|
||||
// 返回初始化失败的插件数量,严重错误返回 -1。
|
||||
// Initialize all uninitialized plugins in dependency order.
|
||||
// Invalid dependencies or failed inits mark the plugin name, avoiding cascading failures.
|
||||
// Returns the number of plugins that failed to initialize, or -1 on critical error.
|
||||
int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
|
||||
{
|
||||
if (!host_api) return -1;
|
||||
host_api_ = host_api;
|
||||
|
||||
// 依赖合法性校验(log 错误但不 crash,继续初始化流程)
|
||||
// Validate dependencies (log errors but don't crash, continue initialization)
|
||||
if (validate_dependencies() != 0) {
|
||||
host_api->log(DSTALK_LOG_WARN,
|
||||
"[plugin_loader] Dependency validation failed; initialization may be incomplete");
|
||||
@@ -368,7 +401,7 @@ int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
|
||||
PluginInfo& plugin = it->second;
|
||||
if (plugin.initialized) continue;
|
||||
|
||||
// 检查依赖是否已失败
|
||||
// 检查依赖是否已失败 / Check if dependency has already failed
|
||||
bool dep_unavailable = false;
|
||||
for (const auto& dep_name : plugin.dependencies) {
|
||||
if (failed_names.count(dep_name)) {
|
||||
@@ -415,13 +448,17 @@ int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
|
||||
|
||||
return failed_count;
|
||||
} catch (const std::runtime_error&) {
|
||||
// 循环依赖
|
||||
// 循环依赖 / Circular dependency
|
||||
return -1;
|
||||
} catch (const std::exception&) {
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 仅初始化尚未初始化的插件(用于增量/按需加载)。
|
||||
// 返回新初始化的插件数量,失败返回 -1。
|
||||
// Initialize only plugins that haven't been initialized yet (used for incremental/on-demand loading).
|
||||
// Returns the number of newly initialized plugins, or -1 on failure.
|
||||
int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api)
|
||||
{
|
||||
host_api_ = host_api;
|
||||
@@ -463,15 +500,17 @@ int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api)
|
||||
}
|
||||
}
|
||||
|
||||
// 按逆依赖顺序关闭所有插件,然后释放所有 DLL 句柄并清空 map。
|
||||
// Shutdown all plugins in reverse dependency order, then free all DLL handles and clear the map.
|
||||
void PluginLoader::shutdown_all()
|
||||
{
|
||||
// 按逆序关闭
|
||||
// 按逆序关闭 / Shutdown in reverse order
|
||||
std::vector<int> order;
|
||||
try {
|
||||
order = topological_sort();
|
||||
std::reverse(order.begin(), order.end());
|
||||
} catch (...) {
|
||||
// 如果排序失败,按任意顺序关闭
|
||||
// 如果排序失败,按任意顺序关闭 / If sorting fails, shutdown in arbitrary order
|
||||
for (const auto& [id, _] : plugins_) {
|
||||
order.push_back(id);
|
||||
}
|
||||
@@ -496,7 +535,7 @@ void PluginLoader::shutdown_all()
|
||||
plugin.initialized = false;
|
||||
}
|
||||
|
||||
// 释放所有 DLL 句柄
|
||||
// 释放所有 DLL 句柄 / Free all DLL handles
|
||||
for (auto& [id, plugin] : plugins_) {
|
||||
if (plugin.handle) {
|
||||
#ifdef _WIN32
|
||||
@@ -510,6 +549,8 @@ void PluginLoader::shutdown_all()
|
||||
plugins_.clear();
|
||||
}
|
||||
|
||||
// 按 ID 查找插件。返回 PluginInfo 指针,未找到则返回 nullptr。
|
||||
// Look up a plugin by ID. Returns pointer to PluginInfo, or nullptr if not found.
|
||||
const PluginInfo* PluginLoader::get_plugin(int plugin_id) const
|
||||
{
|
||||
auto it = plugins_.find(plugin_id);
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file plugin_loader.hpp
|
||||
* @brief DLL plugin loader with topological sort for dependency-ordered initialization.
|
||||
* DLL 插件加载器,使用拓扑排序实现按依赖顺序初始化。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
@@ -8,6 +15,8 @@
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 描述单个已加载插件:标识、DLL 句柄、信息 vtable 和初始化状态。
|
||||
// Describes a single loaded plugin: identity, DLL handle, info vtable, and init state.
|
||||
struct PluginInfo {
|
||||
int id;
|
||||
std::string name;
|
||||
@@ -16,42 +25,47 @@ struct PluginInfo {
|
||||
int api_version;
|
||||
std::vector<std::string> dependencies;
|
||||
|
||||
void* handle; // DLL handle
|
||||
void* handle; // DLL 句柄 / DLL handle
|
||||
dstalk_plugin_info_t* info;
|
||||
bool initialized;
|
||||
};
|
||||
|
||||
// 管理基于 DLL 的插件生命周期:加载、卸载、验证依赖、
|
||||
// 拓扑排序初始化、关闭和 JSON 列表。
|
||||
// Manages the lifecycle of DLL-based plugins: load, unload, validate dependencies,
|
||||
// topological-sort initialization, shutdown, and JSON listing.
|
||||
class PluginLoader {
|
||||
public:
|
||||
PluginLoader() = default;
|
||||
~PluginLoader();
|
||||
|
||||
// 加载插件(返回插件ID,失败返回-1)
|
||||
// 加载插件(返回插件ID,失败返回-1) / Load plugin (returns plugin ID, -1 on failure)
|
||||
int load_plugin(const char* path);
|
||||
|
||||
// 卸载插件
|
||||
// 卸载插件 / Unload plugin
|
||||
int unload_plugin(int plugin_id);
|
||||
|
||||
// 获取插件列表(JSON格式)
|
||||
// 获取插件列表(JSON格式) / Get plugin list (JSON format)
|
||||
std::string list_plugins() const;
|
||||
|
||||
// 按依赖顺序初始化所有插件
|
||||
// 按依赖顺序初始化所有插件 / Initialize all plugins in dependency order
|
||||
int initialize_all(const dstalk_host_api_t* host_api);
|
||||
|
||||
// 仅初始化尚未初始化的插件(增量加载场景)
|
||||
// 仅初始化尚未初始化的插件(增量加载场景) / Initialize only uninitialized plugins (incremental loading scenario)
|
||||
int initialize_pending(const dstalk_host_api_t* host_api);
|
||||
|
||||
// 关闭所有插件
|
||||
// 关闭所有插件 / Shutdown all plugins
|
||||
void shutdown_all();
|
||||
|
||||
// 获取插件信息
|
||||
// 获取插件信息 / Get plugin info
|
||||
const PluginInfo* get_plugin(int plugin_id) const;
|
||||
|
||||
private:
|
||||
// 拓扑排序(按依赖顺序)
|
||||
// 拓扑排序(按依赖顺序) / Topological sort (by dependency order)
|
||||
std::vector<int> topological_sort() const;
|
||||
|
||||
// 依赖合法性校验(缺失依赖 + 循环依赖),返回 0 成功 / -1 失败
|
||||
// Validate dependencies (missing + circular), returns 0 success / -1 failure
|
||||
int validate_dependencies() const;
|
||||
|
||||
std::unordered_map<int, PluginInfo> plugins_;
|
||||
|
||||
@@ -1,22 +1,30 @@
|
||||
/* @file service_registry.cpp
|
||||
* @brief ServiceRegistry implementation: register, query, unregister with reader-writer locking.
|
||||
* ServiceRegistry 实现:基于读写锁的 register、query、unregister。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "service_registry.hpp"
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 注册指定版本的命名服务 / Register a named service at a given version. 参数为空返回 -1,已注册返回 -2 / Returns -1 on null args, -2 if already registered.
|
||||
int ServiceRegistry::register_service(const char* name, int version, void* vtable)
|
||||
{
|
||||
if (!name || !vtable) return -1;
|
||||
|
||||
std::unique_lock<std::shared_mutex> lock(mutex_);
|
||||
|
||||
// 检查是否已注册
|
||||
// 检查是否已注册 / Check if already registered
|
||||
if (services_.find(name) != services_.end()) {
|
||||
return -2; // 已存在
|
||||
return -2; // 已存在 / already registered
|
||||
}
|
||||
|
||||
services_[name] = {name, version, vtable};
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 按名称和最低版本查询服务 / Query a service by name and minimum version. 返回 vtable 指针或 nullptr / Returns vtable pointer or nullptr if not found.
|
||||
void* ServiceRegistry::query_service(const char* name, int min_version) const
|
||||
{
|
||||
if (!name) return nullptr;
|
||||
@@ -31,6 +39,7 @@ void* ServiceRegistry::query_service(const char* name, int min_version) const
|
||||
return it->second.vtable;
|
||||
}
|
||||
|
||||
// 注销指定名称的服务(name 为空或未找到时无操作)/ Unregister a named service (no-op if name is null or not found).
|
||||
void ServiceRegistry::unregister_service(const char* name)
|
||||
{
|
||||
if (!name) return;
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
/* @file service_registry.hpp
|
||||
* @brief Name-versioned service registry for decoupled plugin communication.
|
||||
* 基于名称+版本的服务注册表,用于插件间解耦通信。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
@@ -7,18 +13,23 @@
|
||||
|
||||
namespace dstalk {
|
||||
|
||||
// 名称 + 最低版本服务目录 / Name + minimum-version service directory.
|
||||
// 插件注册 vtable;消费者按名称和版本约束查询 /
|
||||
// Plugins register vtables; consumers query by name and version constraint.
|
||||
// 读取(query)使用 shared_lock;写入(register/unregister)使用 unique_lock /
|
||||
// Reads (query) use shared_lock; writes (register/unregister) use unique_lock.
|
||||
class ServiceRegistry {
|
||||
public:
|
||||
ServiceRegistry() = default;
|
||||
~ServiceRegistry() = default;
|
||||
|
||||
// 注册服务
|
||||
// 注册服务 / Register a named service at a given version
|
||||
int register_service(const char* name, int version, void* vtable);
|
||||
|
||||
// 查询服务(返回 vtable 指针,或 nullptr)
|
||||
// 查询服务(返回 vtable 指针,或 nullptr)/ Query a service by name and minimum version
|
||||
void* query_service(const char* name, int min_version) const;
|
||||
|
||||
// 注销服务
|
||||
// 注销服务 / Unregister a named service
|
||||
void unregister_service(const char* name);
|
||||
|
||||
private:
|
||||
@@ -28,7 +39,7 @@ private:
|
||||
void* vtable;
|
||||
};
|
||||
|
||||
mutable std::shared_mutex mutex_;
|
||||
mutable std::shared_mutex mutex_; // 读写锁:query 用 shared,register/unregister 用 unique / RW lock: shared for query, unique for register/unregister
|
||||
std::unordered_map<std::string, ServiceEntry> services_;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
// ============================================================================
|
||||
// dstalk-gui — SDL3 聊天客户端
|
||||
// ============================================================================
|
||||
// 使用 SDL3 内置的 SDL_RenderDebugText() 渲染文本(8x8 像素),
|
||||
// 通过 SDL_SetRenderScale 2 倍缩放至有效的 16x16 像素。
|
||||
//
|
||||
// 该文件是独立的——不需要额外的源文件。
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file main.cpp
|
||||
* @brief SDL3-based GUI frontend for dstalk (stub/minimal implementation).
|
||||
* dstalk 的 SDL3 图形界面前端(最小化实现)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*
|
||||
* Uses SDL3's built-in SDL_RenderDebugText() for 8x8 pixel text, scaled 2x to
|
||||
* effective 16x16 pixels via SDL_SetRenderScale. Self-contained single-file GUI.
|
||||
* 使用 SDL3 内置的 SDL_RenderDebugText() 渲染 8x8 像素文本,通过 SDL_SetRenderScale
|
||||
* 缩放 2 倍达到 16x16 像素效果。自包含的单文件 GUI。
|
||||
*/
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
@@ -19,46 +22,48 @@
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
|
||||
// ---- 服务 vtable 指针 ----
|
||||
// ---- 服务 vtable 指针 / Service vtable pointers ----
|
||||
// Global pointers to service vtables queried from the host on startup.
|
||||
// 在启动时从主机查询获取的服务 vtable 全局指针。
|
||||
static const dstalk_ai_service_t* g_ai_svc = nullptr;
|
||||
static const dstalk_session_service_t* g_session_svc = nullptr;
|
||||
|
||||
// ---- 常量 ----
|
||||
// ---- 常量 / Constants ----
|
||||
|
||||
static constexpr int WINDOW_W = 1024;
|
||||
static constexpr int WINDOW_H = 768;
|
||||
static constexpr float RENDER_SCALE = 2.0f;
|
||||
|
||||
// 逻辑坐标尺寸(物理像素 / RENDER_SCALE)
|
||||
// 逻辑坐标尺寸(物理像素 / RENDER_SCALE) / Logical coordinate dimensions (physical pixels / RENDER_SCALE)
|
||||
static constexpr int LOGICAL_W = WINDOW_W / 2; // 512
|
||||
static constexpr int LOGICAL_H = WINDOW_H / 2; // 384
|
||||
|
||||
static constexpr int CHAR_W = 8; // SDL_RenderDebugText 原生字符宽度(逻辑像素)
|
||||
static constexpr int CHAR_H = 8; // 原生字符高度(逻辑像素)
|
||||
static constexpr int TITLE_H = 16; // 标题栏高度(逻辑像素)
|
||||
static constexpr int PADDING = 4; // 内边距(逻辑像素)
|
||||
static constexpr int CHAR_W = 8; // SDL_RenderDebugText 原生字符宽度(逻辑像素) / native char width (logical pixels)
|
||||
static constexpr int CHAR_H = 8; // 原生字符高度(逻辑像素) / native char height (logical pixels)
|
||||
static constexpr int TITLE_H = 16; // 标题栏高度(逻辑像素) / title bar height (logical pixels)
|
||||
static constexpr int PADDING = 4; // 内边距(逻辑像素) / padding (logical pixels)
|
||||
|
||||
// 侧边栏
|
||||
static constexpr int SIDEBAR_W = 80; // 侧边栏宽度(逻辑像素,渲染为 160 物理像素)
|
||||
// 侧边栏 / Sidebar
|
||||
static constexpr int SIDEBAR_W = 80; // 侧边栏宽度(逻辑像素,渲染为 160 物理像素) / sidebar width (logical, renders as 160 physical px)
|
||||
|
||||
// 状态栏
|
||||
static constexpr int STATUS_H = 20; // 状态栏高度(逻辑像素,渲染为 40 物理像素)
|
||||
// 状态栏 / Status bar
|
||||
static constexpr int STATUS_H = 20; // 状态栏高度(逻辑像素,渲染为 40 物理像素) / status bar height (logical, renders as 40 physical px)
|
||||
|
||||
// 输入区域动态高度
|
||||
static constexpr int INPUT_H_MIN = 40; // 最小高度(逻辑像素)
|
||||
static constexpr int INPUT_H_MAX = 120; // 最大高度(逻辑像素)
|
||||
// 输入区域动态高度 / Input area dynamic height
|
||||
static constexpr int INPUT_H_MIN = 40; // 最小高度(逻辑像素) / min height (logical pixels)
|
||||
static constexpr int INPUT_H_MAX = 120; // 最大高度(逻辑像素) / max height (logical pixels)
|
||||
|
||||
// 消息区域(Y 起点不变,宽度和高度动态计算)
|
||||
// 消息区域(Y 起点不变,宽度和高度动态计算) / Message area (Y origin fixed, width and height calculated dynamically)
|
||||
static constexpr int MSG_Y = TITLE_H;
|
||||
|
||||
// 颜色(ARGB 格式,用于 SDL_SetRenderDrawColor)
|
||||
// 颜色(ARGB 格式,用于 SDL_SetRenderDrawColor) / Colors (ARGB format, for SDL_SetRenderDrawColor)
|
||||
static constexpr SDL_Color COL_BG = {0x1E, 0x1E, 0x2E, 0xFF};
|
||||
static constexpr SDL_Color COL_TITLE_BG = {0x2D, 0x2D, 0x44, 0xFF};
|
||||
static constexpr SDL_Color COL_INPUT_BG = {0x2A, 0x2A, 0x3E, 0xFF};
|
||||
static constexpr SDL_Color COL_USER = {0x00, 0xFF, 0xFF, 0xFF}; // 青色
|
||||
static constexpr SDL_Color COL_AI = {0x00, 0xFF, 0x80, 0xFF}; // 绿色
|
||||
static constexpr SDL_Color COL_SYS = {0xFF, 0xFF, 0x00, 0xFF}; // 黄色
|
||||
static constexpr SDL_Color COL_BTN = {0x50, 0x50, 0x80, 0xFF}; // 按钮
|
||||
static constexpr SDL_Color COL_USER = {0x00, 0xFF, 0xFF, 0xFF}; // 青色 / cyan
|
||||
static constexpr SDL_Color COL_AI = {0x00, 0xFF, 0x80, 0xFF}; // 绿色 / green
|
||||
static constexpr SDL_Color COL_SYS = {0xFF, 0xFF, 0x00, 0xFF}; // 黄色 / yellow
|
||||
static constexpr SDL_Color COL_BTN = {0x50, 0x50, 0x80, 0xFF}; // 按钮 / button
|
||||
static constexpr SDL_Color COL_WHITE = {0xFF, 0xFF, 0xFF, 0xFF};
|
||||
static constexpr SDL_Color COL_CURSOR = {0xFF, 0xFF, 0xFF, 0xFF};
|
||||
static constexpr SDL_Color COL_SEP = {0x50, 0x50, 0x70, 0xFF};
|
||||
@@ -68,8 +73,9 @@ static constexpr SDL_Color COL_SIDEBAR_BTN = {0x40, 0x40, 0x68, 0xFF};
|
||||
static constexpr SDL_Color COL_STATUSBAR_BG= {0x2D, 0x2D, 0x44, 0xFF};
|
||||
static constexpr SDL_Color COL_DIM = {0x80, 0x80, 0x80, 0xFF};
|
||||
|
||||
// ---- 数据结构 ----
|
||||
// ---- 数据结构 / Data structures ----
|
||||
|
||||
// 单条聊天消息 / Represents a single chat message with role and text content.
|
||||
struct ChatMessage {
|
||||
enum Role { USER, ASSISTANT, SYSTEM } role;
|
||||
std::string content;
|
||||
@@ -77,62 +83,66 @@ struct ChatMessage {
|
||||
ChatMessage(Role r, std::string c) : role(r), content(std::move(c)) {}
|
||||
};
|
||||
|
||||
// 保存所有可变 UI 状态 / Holds all mutable UI state: message list, input buffer, scroll, streaming flag, etc.
|
||||
struct GuiState {
|
||||
std::vector<ChatMessage> messages;
|
||||
std::string inputBuffer;
|
||||
int scrollOffset = 0; // 从底部滚动的逻辑像素
|
||||
int scrollOffset = 0; // 从底部滚动的逻辑像素 / logical pixels scrolled from bottom
|
||||
bool streaming = false;
|
||||
bool running = true;
|
||||
int cursorPos = 0; // 输入缓冲区中的光标位置
|
||||
int cursorPos = 0; // 输入缓冲区中的光标位置 / cursor position in input buffer
|
||||
bool cursorVisible = true;
|
||||
Uint64 lastCursorBlink = 0;
|
||||
float maxScroll = 0; // 可用的最大滚动距离(逻辑像素)
|
||||
float maxScroll = 0; // 可用的最大滚动距离(逻辑像素) / max available scroll distance (logical pixels)
|
||||
|
||||
// P0 新增字段
|
||||
std::vector<std::string> input_history; // 输入历史(最多 20 条)
|
||||
int history_index = -1; // 当前历史位置(-1 = 新输入)
|
||||
std::string saved_input; // 浏览历史时暂存当前输入
|
||||
bool sidebar_visible = true; // 侧边栏可见性
|
||||
std::string model_name = "deepseek-chat";// 当前模型名
|
||||
// P0 新增字段 / P0 new fields
|
||||
std::vector<std::string> input_history; // 输入历史(最多 20 条) / input history (max 20 entries)
|
||||
int history_index = -1; // 当前历史位置(-1 = 新输入) / current history position (-1 = new input)
|
||||
std::string saved_input; // 浏览历史时暂存当前输入 / saved current input while browsing history
|
||||
bool sidebar_visible = true; // 侧边栏可见性 / sidebar visibility
|
||||
std::string model_name = "deepseek-chat";// 当前模型名 / current model name
|
||||
};
|
||||
|
||||
// 持有上下文指针,用于将回调传递给流式 API
|
||||
// 将 GuiState 与 SDL 窗口/渲染器句柄及逐帧标志打包。
|
||||
// 作为 userdata 传递给流式回调,使其可以更新缓冲区并重新渲染。
|
||||
// Bundles GuiState with SDL window/renderer handles and per-frame flags.
|
||||
// Passed as userdata to the streaming callback so it can update the buffer and re-render.
|
||||
struct AppContext {
|
||||
GuiState state;
|
||||
SDL_Window* window = nullptr;
|
||||
SDL_Renderer* renderer = nullptr;
|
||||
bool sendPending = false; // 按下 Enter 后设置为 true
|
||||
std::string streamBuffer; // 存储当前流式消息
|
||||
bool sendPending = false; // 按下 Enter 后设置为 true / set to true after pressing Enter
|
||||
std::string streamBuffer; // 存储当前流式消息 / stores current streaming message
|
||||
};
|
||||
|
||||
// ---- 辅助函数 ----
|
||||
// ---- 辅助函数 / Helper functions ----
|
||||
|
||||
// 获取一个逻辑坐标的 SDL 矩形
|
||||
// 在逻辑坐标系中创建 SDL_FRect / Create an SDL_FRect in logical coordinates.
|
||||
static SDL_FRect mkRect(float x, float y, float w, float h) {
|
||||
SDL_FRect r;
|
||||
r.x = x; r.y = y; r.w = w; r.h = h;
|
||||
return r;
|
||||
}
|
||||
|
||||
// 使用给定的颜色设置绘制颜色
|
||||
// 使用 SDL_Color 设置渲染器的绘制颜色 / Set the renderer's draw color from an SDL_Color.
|
||||
static void setColor(SDL_Renderer* r, SDL_Color c) {
|
||||
SDL_SetRenderDrawColor(r, c.r, c.g, c.b, c.a);
|
||||
}
|
||||
|
||||
// 使用颜色绘制填充矩形
|
||||
// 以纯色填充矩形(逻辑坐标) / Fill a rectangle with a solid color (logical coordinates).
|
||||
static void fillRect(SDL_Renderer* r, SDL_FRect rect, SDL_Color c) {
|
||||
setColor(r, c);
|
||||
SDL_RenderFillRect(r, &rect);
|
||||
}
|
||||
|
||||
// 在给定位置(逻辑坐标)绘制一个调试文本字符串,并设定颜色
|
||||
// 在指定逻辑位置以指定颜色绘制调试文本 / Draw a debug-text string at a given logical position with the specified color.
|
||||
static void drawText(SDL_Renderer* r, float x, float y,
|
||||
const char* text, SDL_Color color) {
|
||||
setColor(r, color);
|
||||
SDL_RenderDebugText(r, x, y, text);
|
||||
}
|
||||
|
||||
// 绘制一个可见的调试文本字符,避免为空字符串调用 SDL_RenderDebugText
|
||||
// 仅当字符串非空时绘制调试文本(避免 SDL_RenderDebugText 问题) / Draw debug text only if the string is non-empty (avoids SDL_RenderDebugText issues).
|
||||
static void drawTextSafe(SDL_Renderer* r, float x, float y,
|
||||
const char* text) {
|
||||
if (text && text[0] != '\0') {
|
||||
@@ -140,7 +150,7 @@ static void drawTextSafe(SDL_Renderer* r, float x, float y,
|
||||
}
|
||||
}
|
||||
|
||||
// 计算输入区域的动态高度(根据输入内容中的换行数)
|
||||
// 根据换行符数量计算输入区域的动态高度 / Compute the dynamic height of the input area based on the number of newlines.
|
||||
static int calcInputHeight(const std::string& input) {
|
||||
int lines = 1;
|
||||
for (char ch : input) {
|
||||
@@ -150,14 +160,13 @@ static int calcInputHeight(const std::string& input) {
|
||||
std::max(INPUT_H_MIN, lines * CHAR_H + PADDING * 2));
|
||||
}
|
||||
|
||||
// ---- 文本换行 ----
|
||||
// ---- 文本换行 / Text wrapping ----
|
||||
|
||||
// 将一段文本按字符数换行。保留嵌入的 '\n',并在单词边界处尽可能按字符数换行。
|
||||
// 返回逻辑文本行列表。
|
||||
// 按 maxChars 换行文本,保留嵌入的换行符 / Word-wrap text to fit within maxChars per line, respecting embedded newlines.
|
||||
static std::vector<std::string> wrapText(const std::string& text, int maxChars) {
|
||||
std::vector<std::string> lines;
|
||||
|
||||
// 首先按嵌入的换行符分割
|
||||
// 首先按嵌入的换行符分割 / First split by embedded newlines
|
||||
std::string remaining = text;
|
||||
while (!remaining.empty()) {
|
||||
std::string segment;
|
||||
@@ -170,13 +179,13 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
|
||||
remaining.clear();
|
||||
}
|
||||
|
||||
// 将片段按单词换行以适应 maxChars
|
||||
// 将片段按单词换行以适应 maxChars / Wrap segment by word to fit maxChars
|
||||
while (!segment.empty()) {
|
||||
if (static_cast<int>(segment.size()) <= maxChars) {
|
||||
lines.push_back(segment);
|
||||
break;
|
||||
}
|
||||
// 在 maxChars 位置寻找空格/单词边界
|
||||
// 在 maxChars 位置寻找空格/单词边界 / Find space/word boundary at maxChars position
|
||||
int splitAt = maxChars;
|
||||
for (int i = maxChars; i > 0; --i) {
|
||||
char ch = segment[i];
|
||||
@@ -187,7 +196,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
|
||||
break;
|
||||
}
|
||||
if ((ch & 0x80) != 0) {
|
||||
// UTF-8 多字节字符——不在中间分割
|
||||
// UTF-8 多字节字符——不在中间分割 / UTF-8 multi-byte char — don't split in the middle
|
||||
}
|
||||
}
|
||||
if (splitAt <= 0 || splitAt > maxChars) {
|
||||
@@ -195,7 +204,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
|
||||
}
|
||||
|
||||
lines.push_back(segment.substr(0, splitAt));
|
||||
// 去除下一行的前导空格
|
||||
// 去除下一行的前导空格 / Trim leading spaces for the next line
|
||||
size_t start = splitAt;
|
||||
while (start < segment.size() &&
|
||||
(segment[start] == ' ' || segment[start] == '\t')) {
|
||||
@@ -207,8 +216,7 @@ static std::vector<std::string> wrapText(const std::string& text, int maxChars)
|
||||
return lines;
|
||||
}
|
||||
|
||||
// 计算所有消息的总渲染高度(逻辑像素)。
|
||||
// 注意:这使用当前的侧边栏状态来决定宽度;调用者应在侧边栏宽度正确时调用。
|
||||
// 计算所有消息在换行后的总渲染高度(逻辑像素) / Calculate the total rendered height (in logical pixels) of all messages after wrapping.
|
||||
static int calcTotalMsgHeight(GuiState& state, int charsPerLine) {
|
||||
int totalH = 0;
|
||||
for (auto& msg : state.messages) {
|
||||
@@ -219,8 +227,10 @@ static int calcTotalMsgHeight(GuiState& state, int charsPerLine) {
|
||||
return totalH;
|
||||
}
|
||||
|
||||
// ---- 侧边栏渲染 ----
|
||||
// ---- 侧边栏渲染 / Sidebar rendering ----
|
||||
|
||||
// 渲染左侧边栏:背景、会话列表和"+ New Chat"按钮。
|
||||
// Render the left sidebar: background, session list, and "+ New Chat" button.
|
||||
static void renderSidebar(AppContext& ctx) {
|
||||
GuiState& gs = ctx.state;
|
||||
SDL_Renderer* r = ctx.renderer;
|
||||
@@ -228,32 +238,34 @@ static void renderSidebar(AppContext& ctx) {
|
||||
float sbY = static_cast<float>(TITLE_H);
|
||||
float sbH = static_cast<float>(LOGICAL_H) - TITLE_H - STATUS_H;
|
||||
|
||||
// 背景
|
||||
// 背景 / Background
|
||||
fillRect(r, mkRect(0, sbY, sbW, sbH), COL_SIDEBAR_BG);
|
||||
|
||||
// 右侧分隔线
|
||||
// 右侧分隔线 / Right separator line
|
||||
setColor(r, COL_SEP);
|
||||
SDL_RenderLine(r, sbW, sbY, sbW, sbY + sbH);
|
||||
|
||||
// "Chats" 标题
|
||||
// "Chats" 标题 / "Chats" title
|
||||
drawText(r, static_cast<float>(PADDING), sbY + PADDING, "Chats", COL_WHITE);
|
||||
|
||||
// 会话列表(当前只有 "default")
|
||||
// 会话列表(当前只有 "default") / Session list (currently only "default")
|
||||
float listY = sbY + TITLE_H;
|
||||
// "default" 条目(活动状态高亮)
|
||||
// "default" 条目(活动状态高亮) / "default" entry (active state highlighted)
|
||||
float itemH = static_cast<float>(CHAR_H + PADDING);
|
||||
fillRect(r, mkRect(PADDING, listY, sbW - PADDING * 2, itemH), COL_SIDEBAR_ACT);
|
||||
drawText(r, PADDING * 2.0f, listY + PADDING / 2.0f, "default", COL_AI);
|
||||
|
||||
// "+ New Chat" 按钮(侧边栏底部)
|
||||
// "+ New Chat" 按钮(侧边栏底部) / "+ New Chat" button (sidebar bottom)
|
||||
float btnY = sbY + sbH - CHAR_H - PADDING * 2;
|
||||
float btnH = static_cast<float>(CHAR_H + PADDING);
|
||||
fillRect(r, mkRect(PADDING, btnY, sbW - PADDING * 2, btnH), COL_SIDEBAR_BTN);
|
||||
drawText(r, PADDING * 2.0f, btnY + PADDING / 2.0f, "+ New Chat", COL_WHITE);
|
||||
}
|
||||
|
||||
// ---- 状态栏渲染 ----
|
||||
// ---- 状态栏渲染 / Status bar rendering ----
|
||||
|
||||
// 渲染底部状态栏:模型名、消息数和流式状态。
|
||||
// Render the bottom status bar: model name, message count, and streaming state.
|
||||
static void renderStatusBar(AppContext& ctx) {
|
||||
GuiState& gs = ctx.state;
|
||||
SDL_Renderer* r = ctx.renderer;
|
||||
@@ -261,20 +273,20 @@ static void renderStatusBar(AppContext& ctx) {
|
||||
float lh = static_cast<float>(LOGICAL_H);
|
||||
float barY = lh - STATUS_H;
|
||||
|
||||
// 背景
|
||||
// 背景 / Background
|
||||
fillRect(r, mkRect(0, barY, lw, static_cast<float>(STATUS_H)), COL_STATUSBAR_BG);
|
||||
|
||||
// 顶部分隔线
|
||||
// 顶部分隔线 / Top separator line
|
||||
setColor(r, COL_SEP);
|
||||
SDL_RenderLine(r, 0, barY, lw, barY);
|
||||
|
||||
// 统计消息数(排除系统消息)
|
||||
// 统计消息数(排除系统消息) / Count messages (excluding system messages)
|
||||
int msgCount = 0;
|
||||
for (auto& msg : gs.messages) {
|
||||
if (msg.role != ChatMessage::SYSTEM) msgCount++;
|
||||
}
|
||||
|
||||
// 状态文本:模型名 | 消息条数 | 流式状态
|
||||
// 状态文本:模型名 | 消息条数 | 流式状态 / Status text: model name | message count | streaming state
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf), "%s | %d messages | %s",
|
||||
gs.model_name.c_str(), msgCount,
|
||||
@@ -283,8 +295,10 @@ static void renderStatusBar(AppContext& ctx) {
|
||||
barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE);
|
||||
}
|
||||
|
||||
// ---- 主渲染 ----
|
||||
// ---- 主渲染 / Main rendering ----
|
||||
|
||||
// 渲染一帧:标题栏、侧边栏、消息区(滚动)、输入区、光标、发送按钮、状态栏。
|
||||
// Render one full frame: title bar, sidebar, message area (with scrolling), input area, cursor, send button, status bar.
|
||||
static void renderFrame(AppContext& ctx) {
|
||||
GuiState& gs = ctx.state;
|
||||
SDL_Renderer* r = ctx.renderer;
|
||||
@@ -301,33 +315,33 @@ static void renderFrame(AppContext& ctx) {
|
||||
int charsPerLine = std::max(20,
|
||||
static_cast<int>(msgAreaW - PADDING * 2) / CHAR_W);
|
||||
|
||||
// 1. 设置渲染缩放以获得 2 倍文本大小
|
||||
// 1. 设置渲染缩放以获得 2 倍文本大小 / Set render scale for 2x text size
|
||||
SDL_SetRenderScale(r, RENDER_SCALE, RENDER_SCALE);
|
||||
|
||||
// 2. 清除背景
|
||||
// 2. 清除背景 / Clear background
|
||||
setColor(r, COL_BG);
|
||||
SDL_RenderClear(r);
|
||||
|
||||
// 3. 标题栏(全宽)
|
||||
// 3. 标题栏(全宽)/ Title bar (full width)
|
||||
fillRect(r, mkRect(0, 0, lw, static_cast<float>(TITLE_H)), COL_TITLE_BG);
|
||||
drawText(r, static_cast<float>(PADDING), static_cast<float>(PADDING),
|
||||
"dstalk - AI Chat", COL_WHITE);
|
||||
// 右侧的状态指示器
|
||||
// 右侧的状态指示器 / Status indicator on the right
|
||||
const char* status = gs.streaming ? "[streaming...]" : "[ready]";
|
||||
float statusW = static_cast<float>(strlen(status)) * CHAR_W + PADDING;
|
||||
drawText(r, lw - statusW, static_cast<float>(PADDING), status, COL_WHITE);
|
||||
|
||||
// 4. 标题栏分隔线
|
||||
// 4. 标题栏分隔线 / Title bar separator line
|
||||
setColor(r, COL_SEP);
|
||||
SDL_RenderLine(r, 0, static_cast<float>(TITLE_H),
|
||||
lw, static_cast<float>(TITLE_H));
|
||||
|
||||
// 5. 侧边栏(可折叠)
|
||||
// 5. 侧边栏(可折叠)/ Sidebar (collapsible)
|
||||
if (gs.sidebar_visible) {
|
||||
renderSidebar(ctx);
|
||||
}
|
||||
|
||||
// 6. 消息区域(带滚动)
|
||||
// 6. 消息区域(带滚动)/ Message area (with scrolling)
|
||||
SDL_Rect msgClip;
|
||||
msgClip.x = static_cast<int>(msgAreaX * RENDER_SCALE);
|
||||
msgClip.y = static_cast<int>(msgAreaY * RENDER_SCALE);
|
||||
@@ -335,13 +349,13 @@ static void renderFrame(AppContext& ctx) {
|
||||
msgClip.h = static_cast<int>(msgAreaH * RENDER_SCALE);
|
||||
SDL_SetRenderClipRect(r, &msgClip);
|
||||
|
||||
// 计算总消息高度和滚动限制
|
||||
// 计算总消息高度和滚动限制 / Calculate total message height and scroll limits
|
||||
int totalMsgH = calcTotalMsgHeight(gs, charsPerLine);
|
||||
gs.maxScroll = std::max(0.0f, static_cast<float>(totalMsgH) - msgAreaH);
|
||||
if (gs.scrollOffset < 0) gs.scrollOffset = 0;
|
||||
if (gs.scrollOffset > gs.maxScroll) gs.scrollOffset = static_cast<int>(gs.maxScroll);
|
||||
|
||||
// 绘制消息:起始 Y 从消息区域顶部减去 scrollOffset
|
||||
// 绘制消息:起始 Y 从消息区域顶部减去 scrollOffset / Draw messages: start Y from message area top minus scrollOffset
|
||||
float drawY = msgAreaY - gs.scrollOffset;
|
||||
float unusedSpace = msgAreaH - static_cast<float>(totalMsgH);
|
||||
float bottomOffset = std::max(0.0f, unusedSpace);
|
||||
@@ -359,7 +373,7 @@ static void renderFrame(AppContext& ctx) {
|
||||
default: col = COL_SYS; prefix = "Sys> "; break;
|
||||
}
|
||||
|
||||
// 如果该消息可见,则绘制
|
||||
// 如果该消息可见,则绘制 / Draw if this message is visible
|
||||
float msgBottom = drawY + msgH;
|
||||
if (msgBottom > msgAreaY && drawY < msgAreaY + msgAreaH) {
|
||||
float lineY = drawY + 2;
|
||||
@@ -383,14 +397,14 @@ static void renderFrame(AppContext& ctx) {
|
||||
|
||||
SDL_SetRenderClipRect(r, nullptr);
|
||||
|
||||
// 7. 输入区域分隔线
|
||||
// 7. 输入区域分隔线 / Input area separator line
|
||||
setColor(r, COL_SEP);
|
||||
SDL_RenderLine(r, msgAreaX, inputY, lw, inputY);
|
||||
|
||||
// 8. 输入区域背景
|
||||
// 8. 输入区域背景 / Input area background
|
||||
fillRect(r, mkRect(msgAreaX, inputY, msgAreaW, static_cast<float>(inputH)), COL_INPUT_BG);
|
||||
|
||||
// 9. 输入文本(支持多行显示)
|
||||
// 9. 输入文本(支持多行显示)/ Input text (multi-line support)
|
||||
if (!gs.inputBuffer.empty()) {
|
||||
std::string remaining = gs.inputBuffer;
|
||||
int lineIdx = 0;
|
||||
@@ -416,7 +430,7 @@ static void renderFrame(AppContext& ctx) {
|
||||
textY, "Type here...");
|
||||
}
|
||||
|
||||
// 10. 光标(多行感知)
|
||||
// 10. 光标(多行感知)/ Cursor (multi-line aware)
|
||||
if (!gs.streaming) {
|
||||
Uint64 now = SDL_GetTicks();
|
||||
if (now - gs.lastCursorBlink > 530) {
|
||||
@@ -424,7 +438,7 @@ static void renderFrame(AppContext& ctx) {
|
||||
gs.lastCursorBlink = now;
|
||||
}
|
||||
if (gs.cursorVisible && gs.cursorPos <= static_cast<int>(gs.inputBuffer.size())) {
|
||||
// 计算光标所在行和列
|
||||
// 计算光标所在行和列 / Calculate cursor line and column
|
||||
int curLine = 0;
|
||||
int charsBeforeLine = 0;
|
||||
for (int i = 0; i < gs.cursorPos; i++) {
|
||||
@@ -444,7 +458,7 @@ static void renderFrame(AppContext& ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// 11. 发送/停止按钮
|
||||
// 11. 发送/停止按钮 / Send/Stop button
|
||||
float btnW = 5 * CHAR_W + PADDING;
|
||||
float btnH = CHAR_H + PADDING;
|
||||
float btnX = lw - btnW - PADDING;
|
||||
@@ -458,26 +472,27 @@ static void renderFrame(AppContext& ctx) {
|
||||
drawText(r, btnTextX, btnTextY, "[Send]", COL_WHITE);
|
||||
}
|
||||
|
||||
// 12. 状态栏
|
||||
// 12. 状态栏 / Status bar
|
||||
renderStatusBar(ctx);
|
||||
|
||||
// 13. Present
|
||||
// 13. Present / Present
|
||||
SDL_RenderPresent(r);
|
||||
}
|
||||
|
||||
// ---- 事件处理 ----
|
||||
// ---- 事件处理 / Event handling ----
|
||||
|
||||
// 尝试发送当前输入缓冲区的内容;返回 true 表示消息已排队
|
||||
// 验证当前输入缓冲区并将其作为用户消息加入队列;成功发送则返回 true。
|
||||
// Validate and queue the current input buffer as a user message; returns true if sent.
|
||||
static bool trySendMessage(GuiState& gs) {
|
||||
std::string text = gs.inputBuffer;
|
||||
// 去除前导/尾随空白,但保留内容空白
|
||||
// 去除前导/尾随空白,但保留内容空白 / Trim leading/trailing whitespace but preserve content whitespace
|
||||
size_t start = text.find_first_not_of(" \t\r\n");
|
||||
size_t end = text.find_last_not_of(" \t\r\n");
|
||||
if (start == std::string::npos) return false; // 空输入
|
||||
if (start == std::string::npos) return false; // 空输入 / empty input
|
||||
text = text.substr(start, end - start + 1);
|
||||
if (text.empty()) return false;
|
||||
|
||||
// 保存原始输入到历史(最多保留 20 条)
|
||||
// 保存原始输入到历史(最多保留 20 条) / Save original input to history (max 20 entries)
|
||||
gs.input_history.push_back(gs.inputBuffer);
|
||||
if (gs.input_history.size() > 20)
|
||||
gs.input_history.erase(gs.input_history.begin());
|
||||
@@ -489,7 +504,8 @@ static bool trySendMessage(GuiState& gs) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 如果输入区域中的 Send/Stop 按钮被点击,返回 true
|
||||
// 检查物理像素坐标是否落在发送/停止按钮区域内。
|
||||
// Return true if the given physical-pixel coordinates fall within the Send/Stop button.
|
||||
static bool isSendButtonHit(AppContext& ctx, float physX, float physY) {
|
||||
float lx = physX / RENDER_SCALE;
|
||||
float ly = physY / RENDER_SCALE;
|
||||
@@ -506,8 +522,10 @@ static bool isSendButtonHit(AppContext& ctx, float physX, float physY) {
|
||||
ly >= btnY && ly <= btnY + btnH;
|
||||
}
|
||||
|
||||
// ---- 流式回调 ----
|
||||
// ---- 流式回调 / Streaming callback ----
|
||||
|
||||
// 流式 token 回调:将 token 追加到 streamBuffer,更新最后一条助手消息,然后重新渲染。
|
||||
// Streaming token callback: appends token to streamBuffer, updates last assistant message, then re-renders.
|
||||
static int streamTokenCallback(const char* token, void* userdata) {
|
||||
AppContext* ctx = static_cast<AppContext*>(userdata);
|
||||
GuiState& gs = ctx->state;
|
||||
@@ -520,7 +538,7 @@ static int streamTokenCallback(const char* token, void* userdata) {
|
||||
}
|
||||
}
|
||||
|
||||
// 泵送事件以保持窗口响应
|
||||
// 泵送事件以保持窗口响应 / Pump events to keep the window responsive
|
||||
SDL_PumpEvents();
|
||||
|
||||
SDL_Event ev;
|
||||
@@ -547,15 +565,17 @@ static int streamTokenCallback(const char* token, void* userdata) {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新渲染以显示进度的令牌
|
||||
// 重新渲染以显示进度的令牌 / Re-render to show the token progress
|
||||
gs.scrollOffset = 0;
|
||||
renderFrame(*ctx);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// ---- 主事件处理函数 ----
|
||||
// ---- 主事件处理函数 / Main event processing function ----
|
||||
|
||||
// 分发单个 SDL 事件以更新 GuiState(键盘输入、鼠标点击、滚动、文本输入)。
|
||||
// Dispatch a single SDL event to update GuiState (keyboard input, mouse clicks, scroll, text input).
|
||||
static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
GuiState& gs = ctx.state;
|
||||
|
||||
@@ -571,23 +591,23 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
bool shift = (mod & SDL_KMOD_SHIFT) != 0;
|
||||
|
||||
if (gs.streaming) {
|
||||
// 流式传输期间,按 Escape 键取消
|
||||
// 流式传输期间,按 Escape 键取消 / While streaming, press Escape to cancel
|
||||
if (key == SDLK_ESCAPE) {
|
||||
gs.streaming = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// Tab 切换侧边栏显示/隐藏
|
||||
// Tab 切换侧边栏显示/隐藏 / Tab toggles sidebar visibility
|
||||
if (key == SDLK_TAB) {
|
||||
gs.sidebar_visible = !gs.sidebar_visible;
|
||||
break;
|
||||
}
|
||||
|
||||
// 输入历史浏览(↑/↓)
|
||||
// 输入历史浏览(↑/↓) / Input history browsing (Up/Down)
|
||||
if (key == SDLK_UP && !gs.input_history.empty()) {
|
||||
if (gs.history_index == -1) {
|
||||
// 首次进入历史浏览,保存当前输入
|
||||
// 首次进入历史浏览,保存当前输入 / First time browsing history, save current input
|
||||
gs.saved_input = gs.inputBuffer;
|
||||
gs.history_index = static_cast<int>(gs.input_history.size()) - 1;
|
||||
} else if (gs.history_index > 0) {
|
||||
@@ -606,7 +626,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
if (gs.history_index >= 0) {
|
||||
gs.inputBuffer = gs.input_history[gs.history_index];
|
||||
} else {
|
||||
// 回到新输入,恢复暂存的输入
|
||||
// 回到新输入,恢复暂存的输入 / Back to new input, restore saved input
|
||||
gs.inputBuffer = gs.saved_input;
|
||||
}
|
||||
gs.cursorPos = static_cast<int>(gs.inputBuffer.size());
|
||||
@@ -620,7 +640,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
case SDLK_RETURN:
|
||||
case SDLK_KP_ENTER:
|
||||
if (shift) {
|
||||
// Shift+Enter:插入换行符(不发送)
|
||||
// Shift+Enter:插入换行符(不发送) / Shift+Enter: insert newline (don't send)
|
||||
gs.inputBuffer.insert(gs.cursorPos, "\n");
|
||||
gs.cursorPos++;
|
||||
gs.cursorVisible = true;
|
||||
@@ -670,7 +690,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
break;
|
||||
case SDLK_V:
|
||||
if (ctrl) {
|
||||
// Ctrl+V:从剪贴板粘贴
|
||||
// Ctrl+V:从剪贴板粘贴 / Ctrl+V: paste from clipboard
|
||||
if (SDL_HasClipboardText()) {
|
||||
char* clip = SDL_GetClipboardText();
|
||||
if (clip) {
|
||||
@@ -685,7 +705,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
break;
|
||||
case SDLK_C:
|
||||
if (ctrl) {
|
||||
// Ctrl+C:复制到剪贴板(复制最后一条助手消息)
|
||||
// Ctrl+C:复制到剪贴板(复制最后一条助手消息) / Ctrl+C: copy to clipboard (copy last assistant message)
|
||||
if (!gs.messages.empty()) {
|
||||
for (int i = static_cast<int>(gs.messages.size()) - 1; i >= 0; --i) {
|
||||
if (gs.messages[i].role != ChatMessage::USER) {
|
||||
@@ -701,7 +721,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
break;
|
||||
case SDLK_L:
|
||||
if (ctrl) {
|
||||
// Ctrl+L:清除聊天
|
||||
// Ctrl+L:清除聊天 / Ctrl+L: clear chat
|
||||
if (g_session_svc) g_session_svc->clear();
|
||||
gs.messages.clear();
|
||||
gs.messages.push_back(ChatMessage(
|
||||
@@ -711,7 +731,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
break;
|
||||
case SDLK_S:
|
||||
if (ctrl) {
|
||||
// Ctrl+S:保存会话
|
||||
// Ctrl+S:保存会话 / Ctrl+S: save session
|
||||
if (g_session_svc && g_session_svc->save("session.json") == 0) {
|
||||
gs.messages.push_back(ChatMessage(
|
||||
ChatMessage::SYSTEM, "Session saved to session.json"));
|
||||
@@ -724,7 +744,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
break;
|
||||
case SDLK_O:
|
||||
if (ctrl) {
|
||||
// Ctrl+O:加载会话
|
||||
// Ctrl+O:加载会话 / Ctrl+O: load session
|
||||
if (g_session_svc && g_session_svc->load("session.json") == 0) {
|
||||
gs.messages.push_back(ChatMessage(
|
||||
ChatMessage::SYSTEM, "Session loaded from session.json"));
|
||||
@@ -743,7 +763,7 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
|
||||
case SDL_EVENT_TEXT_INPUT:
|
||||
if (!gs.streaming) {
|
||||
// 将文本插入光标位置
|
||||
// 将文本插入光标位置 / Insert text at cursor position
|
||||
gs.inputBuffer.insert(gs.cursorPos, ev.text.text);
|
||||
gs.cursorPos += static_cast<int>(strlen(ev.text.text));
|
||||
gs.cursorVisible = true;
|
||||
@@ -772,6 +792,8 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
case SDL_EVENT_WINDOW_RESIZED: {
|
||||
// 当窗口大小改变时,不更新我们的常量——保持 1024x768 的逻辑尺寸。
|
||||
// SDL 将自动缩放输出。
|
||||
// When window resizes, don't update our constants — keep 1024x768 logical size.
|
||||
// SDL will auto-scale the output.
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -780,10 +802,12 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 入口 ----
|
||||
// ---- 入口 / Entry point ----
|
||||
|
||||
// 入口:初始化 dstalk host 和 SDL3,运行主事件/渲染循环,然后清理。
|
||||
// Entry point: initializes dstalk host and SDL3, runs the main event/render loop, then cleans up.
|
||||
int main(int argc, char* argv[]) {
|
||||
// ----- 初始化 dstalk -----
|
||||
// ----- 初始化 dstalk / Initialize dstalk -----
|
||||
if (dstalk_init(nullptr) != 0) {
|
||||
std::fprintf(stderr, "[dstalk] Init failed\n");
|
||||
return 1;
|
||||
@@ -796,7 +820,7 @@ int main(int argc, char* argv[]) {
|
||||
if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)");
|
||||
if (!g_session_svc) dstalk_log(3, "Session service not found");
|
||||
|
||||
// ----- 初始化 SDL -----
|
||||
// ----- 初始化 SDL / Initialize SDL -----
|
||||
if (!SDL_Init(SDL_INIT_VIDEO)) {
|
||||
std::fprintf(stderr, "[dstalk] SDL init failed: %s\n", SDL_GetError());
|
||||
dstalk_shutdown();
|
||||
@@ -822,10 +846,10 @@ int main(int argc, char* argv[]) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// 启用文本输入事件
|
||||
// 启用文本输入事件 / Enable text input events
|
||||
SDL_StartTextInput(window);
|
||||
|
||||
// ----- 应用程序状态 -----
|
||||
// ----- 应用程序状态 / Application state -----
|
||||
AppContext ctx;
|
||||
ctx.window = window;
|
||||
ctx.renderer = renderer;
|
||||
@@ -834,29 +858,29 @@ int main(int argc, char* argv[]) {
|
||||
"Ctrl+L clear, Ctrl+S save, Ctrl+O load. "
|
||||
"Shift+Enter for newline, Up/Down for history, Tab toggle sidebar."));
|
||||
|
||||
// ----- 主循环 -----
|
||||
// ----- 主循环 / Main loop -----
|
||||
SDL_Event event;
|
||||
while (ctx.state.running) {
|
||||
// 处理所有待处理事件
|
||||
// 处理所有待处理事件 / Process all pending events
|
||||
while (SDL_PollEvent(&event)) {
|
||||
processEvent(ctx, event);
|
||||
if (!ctx.state.running) break;
|
||||
}
|
||||
if (!ctx.state.running) break;
|
||||
|
||||
// 检查待发送的消息
|
||||
// 检查待发送的消息 / Check for pending message to send
|
||||
if (ctx.sendPending && !ctx.state.streaming) {
|
||||
ctx.sendPending = false;
|
||||
if (trySendMessage(ctx.state)) {
|
||||
// 开始流式传输
|
||||
// 开始流式传输 / Start streaming
|
||||
ctx.state.streaming = true;
|
||||
ctx.streamBuffer.clear();
|
||||
// 为流式响应添加占位消息
|
||||
// 为流式响应添加占位消息 / Add placeholder message for streaming response
|
||||
ctx.state.messages.push_back(
|
||||
ChatMessage(ChatMessage::ASSISTANT, ""));
|
||||
ctx.state.scrollOffset = 0;
|
||||
|
||||
// 对最后一条消息调用流式 API(通过插件服务 vtable)
|
||||
// 对最后一条消息调用流式 API(通过插件服务 vtable) / Call streaming API for the last message (via plugin service vtable)
|
||||
std::string& userMsg =
|
||||
ctx.state.messages[ctx.state.messages.size() - 2].content;
|
||||
int rc = -1;
|
||||
@@ -871,7 +895,7 @@ int main(int argc, char* argv[]) {
|
||||
g_ai_svc->free_result(&result);
|
||||
}
|
||||
|
||||
// 流式传输完成(或被取消)
|
||||
// 流式传输完成(或被取消) / Streaming completed (or cancelled)
|
||||
if (rc != 0) {
|
||||
if (!ctx.state.messages.empty() &&
|
||||
ctx.state.messages.back().role == ChatMessage::ASSISTANT) {
|
||||
@@ -884,14 +908,14 @@ int main(int argc, char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
// 渲染当前帧
|
||||
// 渲染当前帧 / Render current frame
|
||||
renderFrame(ctx);
|
||||
|
||||
// 短暂休眠以降低 CPU 使用率
|
||||
// 短暂休眠以降低 CPU 使用率 / Brief sleep to reduce CPU usage
|
||||
SDL_Delay(16); // ~60 FPS
|
||||
}
|
||||
|
||||
// ----- 清理 -----
|
||||
// ----- 清理 / Cleanup -----
|
||||
SDL_StopTextInput(window);
|
||||
SDL_DestroyRenderer(renderer);
|
||||
SDL_DestroyWindow(window);
|
||||
|
||||
27
dstalk-web/CMakeLists.txt
Normal file
27
dstalk-web/CMakeLists.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# ============================================================
|
||||
# dstalk-web — Web 前端 / Web frontend (Boost.Beast HTTP + SSE)
|
||||
# ============================================================
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
add_executable(dstalk-web
|
||||
src/main.cpp
|
||||
)
|
||||
|
||||
set_target_properties(dstalk-web PROPERTIES
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
|
||||
)
|
||||
|
||||
target_compile_features(dstalk-web PRIVATE cxx_std_20)
|
||||
|
||||
target_link_libraries(dstalk-web
|
||||
PRIVATE
|
||||
dstalk
|
||||
boost::boost
|
||||
dstalk_boost_config
|
||||
)
|
||||
|
||||
# Windows: Boost.Asio 需要 Winsock / Boost.Asio requires Winsock
|
||||
if(WIN32)
|
||||
target_link_libraries(dstalk-web PRIVATE ws2_32)
|
||||
endif()
|
||||
561
dstalk-web/src/main.cpp
Normal file
561
dstalk-web/src/main.cpp
Normal file
@@ -0,0 +1,561 @@
|
||||
/*
|
||||
* @file main.cpp
|
||||
* @brief Boost.Beast HTTP server frontend for dstalk-web: SSE streaming chat, embedded web UI, CORS support.
|
||||
* dstalk-web 的 Boost.Beast HTTP 服务端:SSE 流式对话、嵌入式网页界面、CORS 支持。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "web_ui.hpp"
|
||||
|
||||
#include <boost/beast/core.hpp>
|
||||
#include <boost/beast/http.hpp>
|
||||
#include <boost/asio.hpp>
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <windows.h>
|
||||
#else
|
||||
#include <signal.h>
|
||||
#endif
|
||||
|
||||
// ---- 命名空间别名 / Namespace aliases ----
|
||||
namespace beast = boost::beast;
|
||||
namespace http = beast::http;
|
||||
namespace asio = boost::asio;
|
||||
using tcp = asio::ip::tcp;
|
||||
|
||||
// ---- 前置声明 / Forward declarations ----
|
||||
class SseSession;
|
||||
|
||||
// ---- 服务 vtable 指针 / Service vtable pointers ----
|
||||
// Global pointers to plugin service vtables, queried from the host on startup.
|
||||
// 插件服务 vtable 的全局指针,在启动时从主机查询获取。
|
||||
static const dstalk_ai_service_t* g_ai = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
|
||||
// ---- 运行时状态 / Runtime state ----
|
||||
// g_quit signals the main loop to exit (set by Ctrl+C).
|
||||
// g_ioc is the io_context pointer for use by signal handlers to stop the event loop.
|
||||
// g_quit 通知主循环退出(由 Ctrl+C 设置)。
|
||||
// g_ioc 供信号处理函数调用 stop() 的 io_context 指针。
|
||||
static std::atomic<bool> g_quit{false};
|
||||
static asio::io_context* g_ioc = nullptr;
|
||||
|
||||
// ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ----
|
||||
// Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT).
|
||||
// Windows 控制台事件处理(CTRL_C_EVENT / CTRL_BREAK_EVENT)。
|
||||
#ifdef _WIN32
|
||||
static BOOL WINAPI on_console_event(DWORD event)
|
||||
{
|
||||
if (event == CTRL_C_EVENT || event == CTRL_BREAK_EVENT) {
|
||||
g_quit = true;
|
||||
if (g_ioc) g_ioc->stop();
|
||||
return TRUE;
|
||||
}
|
||||
return FALSE;
|
||||
}
|
||||
// Unix signal handler (SIGINT).
|
||||
// Unix 信号处理(SIGINT)。
|
||||
#else
|
||||
static void on_signal(int /*sig*/)
|
||||
{
|
||||
g_quit = true;
|
||||
if (g_ioc) g_ioc->stop();
|
||||
}
|
||||
#endif
|
||||
|
||||
// ========================================================================
|
||||
// SseSession — 管理一个 SSE 流式响应连接 / Manages one SSE streaming response
|
||||
// ========================================================================
|
||||
// 持有从 HttpSession 转移过来的 tcp::socket,以 SSE 格式流式发送 AI 回复。
|
||||
// 所有公开方法均在 io_context 线程上被调用,因此无需互斥锁。
|
||||
// Owns the tcp::socket transferred from HttpSession; streams AI response as SSE.
|
||||
// All public methods are called on the io_context thread, so no mutex is needed.
|
||||
class SseSession : public std::enable_shared_from_this<SseSession> {
|
||||
public:
|
||||
// 构造函数:接管已接受的 socket / Constructor: take ownership of the accepted socket.
|
||||
explicit SseSession(tcp::socket&& s) : socket_(std::move(s)) {}
|
||||
|
||||
// 发送 SSE HTTP 响应头并准备接收数据帧 / Send SSE HTTP response headers and prepare for data frames.
|
||||
void start() {
|
||||
writing_ = true; // 阻止数据写入,等待头部发送完成 / Block data writes until headers are sent
|
||||
std::string header =
|
||||
"HTTP/1.1 200 OK\r\n"
|
||||
"Content-Type: text/event-stream\r\n"
|
||||
"Cache-Control: no-cache\r\n"
|
||||
"Connection: keep-alive\r\n"
|
||||
"Access-Control-Allow-Origin: *\r\n"
|
||||
"\r\n";
|
||||
auto self = shared_from_this();
|
||||
asio::async_write(socket_, asio::buffer(header),
|
||||
[self](beast::error_code ec, size_t) {
|
||||
self->writing_ = false;
|
||||
if (!ec && !self->pending_.empty()) {
|
||||
self->do_write();
|
||||
}
|
||||
// 写入失败则让 socket 随 SseSession 析构自然关闭 / On error, let socket close via destructor
|
||||
});
|
||||
}
|
||||
|
||||
// 将 token 加入待发送队列,若未在写入则启动写链 / Push token to pending queue; start write chain if idle.
|
||||
void send_token(const std::string& token) {
|
||||
if (closed_) return;
|
||||
// 换行符会破坏 SSE 帧结构,替换为空格 / Newlines break SSE frame structure; replace with spaces
|
||||
std::string t = token;
|
||||
for (auto& c : t) if (c == '\n' || c == '\r') c = ' ';
|
||||
if (t.empty()) return;
|
||||
pending_.push_back("data: " + t + "\n\n");
|
||||
if (!writing_) do_write();
|
||||
}
|
||||
|
||||
// 发送完成事件后关闭连接 / Send done event then close the connection.
|
||||
void send_done(bool ok, const std::string& content) {
|
||||
if (closed_) return;
|
||||
closed_ = true;
|
||||
(void)ok;
|
||||
(void)content;
|
||||
// JS 客户端忽略 [DONE] token,流自然结束触发最终渲染 / JS client ignores [DONE] token; stream end triggers final render
|
||||
pending_.push_back("event: done\ndata: [DONE]\n\n");
|
||||
if (!writing_) do_write();
|
||||
}
|
||||
|
||||
// 发送错误事件后关闭连接 / Send error event then close the connection.
|
||||
void send_error(const std::string& msg) {
|
||||
if (closed_) return;
|
||||
closed_ = true;
|
||||
std::string m = msg;
|
||||
for (auto& c : m) if (c == '\n' || c == '\r') c = ' ';
|
||||
pending_.push_back("event: error\ndata: " + m + "\n\n");
|
||||
if (!writing_) do_write();
|
||||
}
|
||||
|
||||
private:
|
||||
// 从待发送队列头部取出并异步写入 / Pop front of pending queue and async-write it.
|
||||
void do_write() {
|
||||
if (writing_ || pending_.empty()) return;
|
||||
writing_ = true;
|
||||
auto self = shared_from_this();
|
||||
asio::async_write(socket_, asio::buffer(pending_.front()),
|
||||
[self](beast::error_code ec, size_t) {
|
||||
self->writing_ = false;
|
||||
self->pending_.pop_front();
|
||||
if (!ec) {
|
||||
if (!self->pending_.empty()) {
|
||||
self->do_write();
|
||||
} else if (self->closed_) {
|
||||
// 队列已空且会话已关闭 → 关闭 socket / Queue drained and session closed → close socket
|
||||
beast::error_code ignored;
|
||||
self->socket_.shutdown(tcp::socket::shutdown_both, ignored);
|
||||
self->socket_.close(ignored);
|
||||
}
|
||||
}
|
||||
// 写入错误时 socket 随 SseSession 析构 / On write error, socket closes with SseSession
|
||||
});
|
||||
}
|
||||
|
||||
tcp::socket socket_;
|
||||
std::deque<std::string> pending_;
|
||||
bool writing_ = false;
|
||||
bool closed_ = false;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
// run_chat_worker — 在独立线程中执行流式 AI 聊天 / Execute streaming AI chat in a dedicated thread
|
||||
// ========================================================================
|
||||
// 将用户消息加入会话,调用 g_ai->chat_stream(),通过 asio::post 将 token 投递到 io_context。
|
||||
// Add user message to session, call g_ai->chat_stream(), post tokens to io_context via asio::post.
|
||||
static void run_chat_worker(
|
||||
std::string user_input,
|
||||
std::weak_ptr<SseSession> weak_sse,
|
||||
asio::io_context& ioc)
|
||||
{
|
||||
// 将用户消息加入会话 / Add user message to session
|
||||
dstalk_message_t user_msg = {"user", user_input.c_str(), nullptr, nullptr};
|
||||
g_session->add(&user_msg);
|
||||
|
||||
// 获取会话历史 / Get session history
|
||||
int history_count = 0;
|
||||
const dstalk_message_t* history = g_session->history(&history_count);
|
||||
|
||||
// 流式回调上下文 / Streaming callback context
|
||||
struct CallbackData {
|
||||
std::weak_ptr<SseSession> sse;
|
||||
asio::io_context* ioc;
|
||||
};
|
||||
CallbackData cb_data{weak_sse, &ioc};
|
||||
|
||||
// 流式 token 回调:将 token 投递到 io_context 线程 / Streaming token callback: post token to io_context thread
|
||||
auto token_cb = [](const char* token, void* userdata) -> int {
|
||||
auto* data = static_cast<CallbackData*>(userdata);
|
||||
if (auto sse = data->sse.lock()) {
|
||||
std::string t(token);
|
||||
asio::post(*data->ioc, [sse, t = std::move(t)]() {
|
||||
sse->send_token(t);
|
||||
});
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
// 调用流式 AI 聊天 / Call streaming AI chat
|
||||
dstalk_chat_result_t result = g_ai->chat_stream(
|
||||
history, history_count, nullptr, token_cb, &cb_data);
|
||||
|
||||
// 将 AI 回复加入会话 / Add AI reply to session
|
||||
if (result.ok) {
|
||||
dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json};
|
||||
g_session->add(&ai_msg);
|
||||
}
|
||||
|
||||
// 将完成/错误事件投递到 io_context 线程 / Post completion/error event to io_context thread
|
||||
bool ok = result.ok;
|
||||
std::string content_copy = result.content ? result.content : "";
|
||||
std::string error_copy = result.error ? result.error : "";
|
||||
g_ai->free_result(&result);
|
||||
|
||||
asio::post(ioc, [weak_sse, ok, content_copy, error_copy]() {
|
||||
if (auto sse = weak_sse.lock()) {
|
||||
if (ok) {
|
||||
sse->send_done(true, content_copy);
|
||||
} else {
|
||||
sse->send_error(error_copy.empty() ? "unknown error" : error_copy);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// HttpSession — 处理单个 HTTP 请求 / Handles one HTTP request
|
||||
// ========================================================================
|
||||
// 使用 Beast 解析请求,按 method + target 路由到相应处理器。
|
||||
// Uses Beast to parse the request, routing by method + target to the appropriate handler.
|
||||
class HttpSession : public std::enable_shared_from_this<HttpSession> {
|
||||
public:
|
||||
// 构造函数:接管已接受的 socket / Constructor: take ownership of the accepted socket.
|
||||
explicit HttpSession(tcp::socket&& s) : socket_(std::move(s)) {}
|
||||
|
||||
// 开始读取请求 / Start reading the request.
|
||||
void start() { do_read(); }
|
||||
|
||||
private:
|
||||
// 异步读取 HTTP 请求 / Asynchronously read the HTTP request.
|
||||
void do_read() {
|
||||
auto self = shared_from_this();
|
||||
http::async_read(socket_, buffer_, request_,
|
||||
[self](beast::error_code ec, size_t) {
|
||||
if (ec) return; // 客户端断开或读取错误 / Client disconnected or read error
|
||||
self->handle_request();
|
||||
});
|
||||
}
|
||||
|
||||
// 路由 HTTP 请求到相应的处理器 / Route the HTTP request to the appropriate handler.
|
||||
void handle_request() {
|
||||
auto const method = request_.method();
|
||||
auto const target = std::string(request_.target());
|
||||
|
||||
// GET / — 返回嵌入式网页界面 / Return embedded web UI
|
||||
if (method == http::verb::get && target == "/") {
|
||||
serve_web_ui();
|
||||
return;
|
||||
}
|
||||
|
||||
// POST /chat — SSE 流式聊天 / SSE streaming chat
|
||||
if (method == http::verb::post && target == "/chat") {
|
||||
handle_chat();
|
||||
return;
|
||||
}
|
||||
|
||||
// OPTIONS /chat — CORS 预检请求 / CORS preflight request
|
||||
if (method == http::verb::options && target == "/chat") {
|
||||
serve_cors_preflight();
|
||||
return;
|
||||
}
|
||||
|
||||
// POST /clear — 清除会话 / Clear session
|
||||
if (method == http::verb::post && target == "/clear") {
|
||||
handle_clear();
|
||||
return;
|
||||
}
|
||||
|
||||
// POST /status — 返回运行状态 / Return runtime status
|
||||
if (method == http::verb::post && target == "/status") {
|
||||
handle_status();
|
||||
return;
|
||||
}
|
||||
|
||||
// 未知路由 — 404 / Unknown route — 404
|
||||
serve_404();
|
||||
}
|
||||
|
||||
// 返回 HTML 网页界面 / Serve the HTML web UI.
|
||||
void serve_web_ui() {
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::ok, request_.version()};
|
||||
res.set(http::field::content_type, "text/html; charset=utf-8");
|
||||
res.set(http::field::cache_control, "no-cache");
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.body() = kWebUiHtml;
|
||||
res.prepare_payload();
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
// 解析 JSON body、创建 SseSession、启动工作线程 / Parse JSON body, create SseSession, spawn worker thread.
|
||||
void handle_chat() {
|
||||
// 解析 JSON body / Parse JSON body
|
||||
boost::system::error_code ec;
|
||||
auto jv = boost::json::parse(request_.body(), ec);
|
||||
if (ec || !jv.is_object()) {
|
||||
serve_bad_request("Invalid JSON body");
|
||||
return;
|
||||
}
|
||||
auto const& obj = jv.as_object();
|
||||
auto it = obj.find("message");
|
||||
if (it == obj.end() || !it->value().is_string()) {
|
||||
serve_bad_request("Missing or invalid 'message' field");
|
||||
return;
|
||||
}
|
||||
std::string user_input = boost::json::value_to<std::string>(it->value());
|
||||
|
||||
// 创建 SseSession 并转移 socket 所有权 / Create SseSession and transfer socket ownership
|
||||
auto sse = std::make_shared<SseSession>(std::move(socket_));
|
||||
sse->start();
|
||||
|
||||
// 在独立线程中执行聊天(chat_stream 是阻塞调用) / Execute chat in dedicated thread (chat_stream is blocking)
|
||||
std::thread worker([user_input = std::move(user_input), sse]() {
|
||||
run_chat_worker(user_input, sse, *g_ioc);
|
||||
});
|
||||
worker.detach();
|
||||
}
|
||||
|
||||
// 返回 CORS 预检响应头 / Return CORS preflight response headers.
|
||||
void serve_cors_preflight() {
|
||||
auto self = shared_from_this();
|
||||
http::response<http::empty_body> res{http::status::ok, request_.version()};
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.set("Access-Control-Allow-Methods", "POST, OPTIONS");
|
||||
res.set("Access-Control-Allow-Headers", "Content-Type");
|
||||
res.set("Access-Control-Max-Age", "86400");
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
// 清除当前会话 / Clear the current session.
|
||||
void handle_clear() {
|
||||
if (g_session) g_session->clear();
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::ok, request_.version()};
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.set(http::field::content_type, "application/json");
|
||||
res.body() = "{\"ok\":true}";
|
||||
res.prepare_payload();
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
// 返回运行状态 JSON / Return runtime status as JSON.
|
||||
void handle_status() {
|
||||
boost::json::object st;
|
||||
if (g_session) {
|
||||
int count = 0;
|
||||
g_session->history(&count);
|
||||
st["messages"] = count;
|
||||
st["tokens"] = g_session->token_count();
|
||||
}
|
||||
const char* model = dstalk_config_get("api.model");
|
||||
if (model) st["model"] = std::string(model);
|
||||
st["status"] = "running";
|
||||
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::ok, request_.version()};
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.set(http::field::content_type, "application/json");
|
||||
res.body() = boost::json::serialize(st);
|
||||
res.prepare_payload();
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
// 返回 400 Bad Request / Return 400 Bad Request.
|
||||
void serve_bad_request(const std::string& msg) {
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::bad_request, request_.version()};
|
||||
res.set(http::field::content_type, "text/plain");
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.body() = msg;
|
||||
res.prepare_payload();
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
// 返回 404 Not Found / Return 404 Not Found.
|
||||
void serve_404() {
|
||||
auto self = shared_from_this();
|
||||
http::response<http::string_body> res{http::status::not_found, request_.version()};
|
||||
res.set(http::field::content_type, "text/plain");
|
||||
res.set("Access-Control-Allow-Origin", "*");
|
||||
res.body() = "404 Not Found";
|
||||
res.prepare_payload();
|
||||
http::async_write(socket_, res, [self](beast::error_code, size_t) {});
|
||||
}
|
||||
|
||||
tcp::socket socket_;
|
||||
beast::flat_buffer buffer_;
|
||||
http::request<http::string_body> request_;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
// Listener — 接受 TCP 连接并创建 HttpSession / Accepts TCP connections and creates HttpSessions
|
||||
// ========================================================================
|
||||
// 异步接受循环:每个进入的连接包装为 HttpSession 并由 io_context 驱动其生命周期。
|
||||
// Async accept loop: each inbound connection is wrapped in an HttpSession driven by the io_context.
|
||||
class Listener {
|
||||
public:
|
||||
// 构造函数:打开 acceptor、绑定地址、开始监听 / Constructor: open acceptor, bind, start listening.
|
||||
Listener(asio::io_context& ioc, const tcp::endpoint& ep)
|
||||
: acceptor_(ioc)
|
||||
{
|
||||
beast::error_code ec;
|
||||
acceptor_.open(ep.protocol(), ec);
|
||||
if (ec) {
|
||||
std::fprintf(stderr, "[dstalk-web] acceptor.open: %s\n", ec.message().c_str());
|
||||
return;
|
||||
}
|
||||
acceptor_.set_option(asio::socket_base::reuse_address(true), ec);
|
||||
acceptor_.bind(ep, ec);
|
||||
if (ec) {
|
||||
std::fprintf(stderr, "[dstalk-web] acceptor.bind: %s\n", ec.message().c_str());
|
||||
return;
|
||||
}
|
||||
acceptor_.listen(asio::socket_base::max_listen_connections, ec);
|
||||
if (ec) {
|
||||
std::fprintf(stderr, "[dstalk-web] acceptor.listen: %s\n", ec.message().c_str());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 启动接受循环 / Start the accept loop.
|
||||
void run() { do_accept(); }
|
||||
|
||||
private:
|
||||
// 异步接受一个连接,创建 HttpSession 并继续监听 / Async-accept one connection, create HttpSession, keep listening.
|
||||
void do_accept() {
|
||||
acceptor_.async_accept(
|
||||
[this](beast::error_code ec, tcp::socket socket) {
|
||||
if (!ec) {
|
||||
// 为每个入站连接创建新的 HttpSession / Create a new HttpSession for each inbound connection
|
||||
std::make_shared<HttpSession>(std::move(socket))->start();
|
||||
}
|
||||
// 继续接受下一个连接(除非已发出退出信号) / Keep accepting (unless quit has been signaled)
|
||||
if (!g_quit) do_accept();
|
||||
});
|
||||
}
|
||||
|
||||
tcp::acceptor acceptor_;
|
||||
};
|
||||
|
||||
// ========================================================================
|
||||
// main — 入口点 / Entry point
|
||||
// ========================================================================
|
||||
// 初始化 dstalk host,查询 AI/Session 服务,配置 HTTP 监听,运行 io_context 事件循环。
|
||||
// Initialize dstalk host, query AI/Session services, configure HTTP listener, run io_context event loop.
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// Windows: 启用 ANSI 转义码 + 安装 Ctrl+C 处理器 / Windows: enable ANSI escape codes + install Ctrl+C handler
|
||||
#ifdef _WIN32
|
||||
HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE);
|
||||
DWORD mode = 0;
|
||||
GetConsoleMode(hOut, &mode);
|
||||
SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING);
|
||||
SetConsoleCtrlHandler(on_console_event, TRUE);
|
||||
#else
|
||||
signal(SIGINT, on_signal);
|
||||
#endif
|
||||
|
||||
// 查找配置文件路径 / Locate config file path
|
||||
const char* config_path = nullptr;
|
||||
if (argc >= 2) {
|
||||
config_path = argv[1];
|
||||
}
|
||||
if (!config_path) {
|
||||
const char* default_configs[] = {"config.toml", nullptr};
|
||||
for (int i = 0; default_configs[i]; i++) {
|
||||
FILE* f = nullptr;
|
||||
#ifdef _WIN32
|
||||
fopen_s(&f, default_configs[i], "r");
|
||||
#else
|
||||
f = fopen(default_configs[i], "r");
|
||||
#endif
|
||||
if (f) {
|
||||
fclose(f);
|
||||
config_path = default_configs[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化 dstalk 主机(加载配置 + 自动扫描 plugins/ 目录) / Init dstalk host (load config + auto-scan plugins/)
|
||||
if (dstalk_init(config_path) != 0) {
|
||||
std::fprintf(stderr, "[dstalk-web] dstalk_init failed\n");
|
||||
return 3;
|
||||
}
|
||||
|
||||
// 查询插件服务 / Query plugin services
|
||||
const char* ai_provider = dstalk_config_get("ai.provider");
|
||||
if (!ai_provider) ai_provider = "ai.deepseek";
|
||||
g_ai = static_cast<const dstalk_ai_service_t*>(dstalk_service_query(ai_provider, 1));
|
||||
g_session = static_cast<const dstalk_session_service_t*>(dstalk_service_query("session", 1));
|
||||
|
||||
if (!g_ai) {
|
||||
std::fprintf(stderr, "[dstalk-web] AI service not found (check plugins directory)\n");
|
||||
}
|
||||
if (!g_session) {
|
||||
std::fprintf(stderr, "[dstalk-web] Session service not found\n");
|
||||
}
|
||||
|
||||
// 从配置自动加载 AI 设置 / Auto-load AI settings from config
|
||||
if (g_ai) {
|
||||
const char* base_url = dstalk_config_get("api.base_url");
|
||||
const char* api_key = dstalk_config_get("api.api_key");
|
||||
const char* model = dstalk_config_get("api.model");
|
||||
if (!base_url) base_url = "https://api.deepseek.com/v1";
|
||||
if (!model) model = "deepseek-v4-pro";
|
||||
g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7);
|
||||
}
|
||||
|
||||
// 读取 web 服务配置 / Read web server config
|
||||
const char* web_host = dstalk_config_get("web.host");
|
||||
if (!web_host || !web_host[0]) web_host = "127.0.0.1";
|
||||
const char* web_port_str = dstalk_config_get("web.port");
|
||||
unsigned short web_port = 8080;
|
||||
if (web_port_str && web_port_str[0]) {
|
||||
web_port = static_cast<unsigned short>(std::strtoul(web_port_str, nullptr, 10));
|
||||
}
|
||||
|
||||
// 创建 io_context 并启动监听 / Create io_context and start listener
|
||||
asio::io_context ioc;
|
||||
g_ioc = &ioc;
|
||||
|
||||
tcp::endpoint ep(asio::ip::make_address(web_host), web_port);
|
||||
Listener listener(ioc, ep);
|
||||
listener.run();
|
||||
|
||||
// 打印启动信息 / Print startup message
|
||||
std::printf("[dstalk-web] running at http://%s:%u\n", web_host, web_port);
|
||||
std::printf("[dstalk-web] Press Ctrl+C to stop\n");
|
||||
|
||||
// 运行事件循环(阻塞直到 g_ioc->stop() 被信号处理函数调用) / Run event loop (blocks until g_ioc->stop() called by signal handler)
|
||||
ioc.run();
|
||||
|
||||
// 清理 / Cleanup
|
||||
g_ioc = nullptr;
|
||||
dstalk_shutdown();
|
||||
|
||||
std::printf("[dstalk-web] stopped\n");
|
||||
return 0;
|
||||
}
|
||||
226
dstalk-web/src/web_ui.hpp
Normal file
226
dstalk-web/src/web_ui.hpp
Normal file
@@ -0,0 +1,226 @@
|
||||
/*
|
||||
* @file web_ui.hpp
|
||||
* @brief Embedded HTML/JS chat UI served by dstalk-web.
|
||||
* 嵌入的 HTML/JS 聊天界面,由 dstalk-web 提供。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#ifndef DSTALK_WEB_UI_HPP
|
||||
#define DSTALK_WEB_UI_HPP
|
||||
|
||||
// 深色主题单页聊天界面 — 通过 fetch ReadableStream 实现 SSE 流式传输
|
||||
// Dark-themed single-page chat UI — SSE streaming via fetch ReadableStream
|
||||
static const char kWebUiHtml[] = R"html(<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width,initial-scale=1.0">
|
||||
<title>dstalk Web</title>
|
||||
<style>
|
||||
*,*::before,*::after{box-sizing:border-box;margin:0;padding:0}
|
||||
body{font-family:system-ui,-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,Oxygen,Ubuntu,Cantarell,sans-serif;background:#1a1a2e;color:#eaeaea;height:100dvh;display:flex;flex-direction:column;overflow:hidden}
|
||||
#header{display:flex;align-items:center;justify-content:space-between;padding:8px 18px;background:#16162a;border-bottom:1px solid #2a2a4a;flex-shrink:0}
|
||||
#header h1{font-size:1rem;font-weight:600;display:flex;align-items:center;gap:8px;color:#a6b8e0}
|
||||
#header .status-row{display:flex;align-items:center;gap:14px;font-size:.75rem;color:#7a7a9a}
|
||||
#dot{width:9px;height:9px;border-radius:50%;background:#555;flex-shrink:0;transition:background .3s}
|
||||
#dot.connected{background:#4caf50}
|
||||
#dot.streaming{background:#f06292;animation:pulse .8s infinite}
|
||||
@keyframes pulse{0%,100%{opacity:1}50%{opacity:.35}}
|
||||
#clearBtn{font-size:.72rem;padding:3px 10px;border:1px solid #f06292;border-radius:4px;color:#f06292;background:transparent;cursor:pointer;transition:background .2s}
|
||||
#clearBtn:hover{background:#f0629218}
|
||||
#messages{flex:1;overflow-y:auto;padding:16px 18px;display:flex;flex-direction:column;gap:10px;scroll-behavior:smooth}
|
||||
.bubble{max-width:82%;padding:10px 14px;border-radius:10px;font-size:.9rem;line-height:1.55;word-wrap:break-word;animation:fadeIn .2s}
|
||||
.bubble.user{align-self:flex-end;background:#2a3f6e;border-bottom-right-radius:3px}
|
||||
.bubble.assistant{align-self:flex-start;background:#1e1e38;border:1px solid #2a2a4a;border-bottom-left-radius:3px}
|
||||
.bubble pre{background:#0d1117;padding:9px 12px;border-radius:6px;overflow-x:auto;margin:6px 0;font-size:.8rem;white-space:pre-wrap}
|
||||
.bubble code{font-family:"Fira Code","Cascadia Code",Consolas,monospace;font-size:.8rem}
|
||||
.bubble strong{color:#f06292}
|
||||
.bubble .lang{display:block;color:#8b949e;font-size:.68rem;margin-bottom:3px;text-transform:uppercase;letter-spacing:.4px}
|
||||
#typing{display:none;align-self:flex-start;padding:10px 14px;background:#1e1e38;border:1px solid #2a2a4a;border-radius:10px;border-bottom-left-radius:3px}
|
||||
#typing.active{display:block}
|
||||
#typing span{display:inline-block;width:6px;height:6px;border-radius:50%;background:#f06292;margin:0 2px;animation:bounce 1.2s infinite}
|
||||
#typing span:nth-child(2){animation-delay:.15s}
|
||||
#typing span:nth-child(3){animation-delay:.3s}
|
||||
@keyframes bounce{0%,60%,100%{transform:translateY(0);opacity:.3}30%{transform:translateY(-6px);opacity:1}}
|
||||
#inputBar{display:flex;gap:8px;padding:10px 16px;background:#16162a;border-top:1px solid #2a2a4a;flex-shrink:0}
|
||||
#inputBar textarea{flex:1;resize:none;background:#1a1a2e;color:#eaeaea;border:1px solid #2a2a4a;border-radius:8px;padding:9px 12px;font-size:.88rem;font-family:inherit;outline:none;transition:border .2s,box-shadow .2s;min-height:40px;max-height:120px;rows:1}
|
||||
#inputBar textarea:focus{border-color:#f06292;box-shadow:0 0 8px #f0629230}
|
||||
#sendBtn,#stopBtn{padding:9px 16px;border:none;border-radius:8px;font-size:.88rem;font-weight:600;cursor:pointer;transition:background .2s,opacity .2s}
|
||||
#sendBtn{background:#f06292;color:#fff}
|
||||
#sendBtn:hover{background:#d4517a}
|
||||
#sendBtn:disabled{opacity:.45;cursor:not-allowed}
|
||||
#stopBtn{display:none;background:#4a4a6a;color:#ccc}
|
||||
#stopBtn:hover{background:#5a5a7a}
|
||||
#stopBtn.visible{display:inline-block}
|
||||
.emptyState{text-align:center;color:#4a4a6a;margin-top:20vh;font-size:.92rem;line-height:1.7}
|
||||
.emptyState .logo{font-size:2rem;margin-bottom:8px}
|
||||
@keyframes fadeIn{from{opacity:0;transform:translateY(5px)}to{opacity:1;transform:translateY(0)}}
|
||||
@media(max-width:600px){.bubble{max-width:92%}#inputBar{padding:8px 10px;gap:6px}#sendBtn,#stopBtn{padding:8px 14px;font-size:.82rem}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="header">
|
||||
<h1>◆ dstalk Web <span id="dot"></span></h1>
|
||||
<div class="status-row">
|
||||
<span id="lblModel">-</span>
|
||||
<button id="clearBtn" title="Clear conversation / 清空对话">Clear</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="messages"><div class="emptyState"><div class="logo">◆</div>dstalk Web<br>Send a message to begin.<br>发送消息开始对话。</div></div>
|
||||
<div id="typing"><span></span><span></span><span></span></div>
|
||||
<div id="inputBar">
|
||||
<textarea id="msgInput" placeholder="输入消息... (Enter 发送 / Shift+Enter 换行)" rows="1"></textarea>
|
||||
<button id="sendBtn">Send</button>
|
||||
<button id="stopBtn">Stop</button>
|
||||
</div>
|
||||
<script>
|
||||
const msgs=document.getElementById('messages'),input=document.getElementById('msgInput'),
|
||||
sendBtn=document.getElementById('sendBtn'),stopBtn=document.getElementById('stopBtn'),
|
||||
dot=document.getElementById('dot'),typing=document.getElementById('typing'),
|
||||
lblModel=document.getElementById('lblModel');
|
||||
let abortCtrl=null,streaming=false,lastAiBubble=null,tokenBuf='';
|
||||
|
||||
function scrollDown(){msgs.scrollTop=msgs.scrollHeight}
|
||||
|
||||
function clearEmptyState(){const e=msgs.querySelector('.emptyState');if(e)e.remove()}
|
||||
|
||||
function addBubble(role,text){
|
||||
clearEmptyState();
|
||||
const d=document.createElement('div');
|
||||
d.className='bubble '+role;
|
||||
d.innerHTML=renderMD(text);
|
||||
msgs.appendChild(d);
|
||||
scrollDown();
|
||||
return d;
|
||||
}
|
||||
|
||||
function renderMD(t){
|
||||
t=t.replace(/&/g,'&').replace(/</g,'<').replace(/>/g,'>');
|
||||
// 代码块 / code fences
|
||||
t=t.replace(/```(\w*)\n?([\s\S]*?)```/g,(_,lang,code)=>{
|
||||
const label=lang?'<span class="lang">'+lang+'</span>':'';
|
||||
return '<pre><code>'+label+code.trim()+'</code></pre>';
|
||||
});
|
||||
// 行内代码 / inline code
|
||||
t=t.replace(/`([^`]+)`/g,'<code>$1</code>');
|
||||
// 粗体 / bold
|
||||
t=t.replace(/\*\*(.+?)\*\*/g,'<strong>$1</strong>');
|
||||
// 换行 / newlines
|
||||
t=t.replace(/\n/g,'<br>');
|
||||
return t;
|
||||
}
|
||||
|
||||
function setStreaming(s){
|
||||
streaming=s;
|
||||
dot.classList.toggle('streaming',s);
|
||||
dot.classList.toggle('connected',!s);
|
||||
typing.classList.toggle('active',s&&!lastAiBubble);
|
||||
sendBtn.disabled=s;
|
||||
stopBtn.classList.toggle('visible',s);
|
||||
if(s){
|
||||
if(!lastAiBubble){lastAiBubble=addBubble('assistant','');typing.classList.remove('active')}
|
||||
tokenBuf='';
|
||||
}else{
|
||||
if(lastAiBubble&&tokenBuf)lastAiBubble.innerHTML=renderMD(tokenBuf);
|
||||
lastAiBubble=null;tokenBuf='';abortCtrl=null;
|
||||
}
|
||||
}
|
||||
|
||||
// 解析 SSE 帧 / Parse SSE frames (separated by double-newline)
|
||||
function parseSSE(text,callback){
|
||||
let idx;
|
||||
while((idx=text.indexOf('\n\n'))!==-1){
|
||||
const frame=text.slice(0,idx);
|
||||
text=text.slice(idx+2);
|
||||
const lines=frame.split('\n');
|
||||
let event='message',data='';
|
||||
for(const line of lines){
|
||||
if(line.startsWith('event: '))event=line.slice(7).trim();
|
||||
else if(line.startsWith('data: '))data+=line.slice(6);
|
||||
}
|
||||
callback(event,data);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
async function send(){
|
||||
const text=input.value.trim();
|
||||
if(!text||streaming)return;
|
||||
input.value='';input.style.height='auto';
|
||||
addBubble('user',text);
|
||||
setStreaming(true);
|
||||
abortCtrl=new AbortController();
|
||||
try{
|
||||
const res=await fetch('/chat',{
|
||||
method:'POST',headers:{'Content-Type':'application/json'},
|
||||
body:JSON.stringify({message:text}),signal:abortCtrl.signal
|
||||
});
|
||||
if(!res.ok)throw new Error('HTTP '+res.status);
|
||||
const reader=res.body.getReader(),decoder=new TextDecoder();
|
||||
let leftover='';
|
||||
while(true){
|
||||
const{value,done}=await reader.read();
|
||||
if(done)break;
|
||||
leftover+=decoder.decode(value,{stream:true});
|
||||
leftover=parseSSE(leftover,(event,data)=>{
|
||||
if(event==='error'){
|
||||
if(lastAiBubble)lastAiBubble.innerHTML=renderMD(tokenBuf||'');
|
||||
const errEl=addBubble('assistant','[Error] '+data);
|
||||
errEl.style.borderColor='#f06292';
|
||||
}else if(event==='done'){
|
||||
/* stream complete */
|
||||
}else{
|
||||
tokenBuf+=data;
|
||||
if(lastAiBubble){lastAiBubble.innerHTML=renderMD(tokenBuf);scrollDown()}
|
||||
}
|
||||
});
|
||||
}
|
||||
// 处理残留在 leftover 中的非帧数据 / handle leftover non-frame data
|
||||
if(leftover.trim()){
|
||||
const trimmed=leftover.trim();
|
||||
if(trimmed.startsWith('data: '))tokenBuf+=trimmed.slice(6);
|
||||
}
|
||||
}catch(e){
|
||||
if(e.name!=='AbortError'){
|
||||
if(lastAiBubble&&tokenBuf)lastAiBubble.innerHTML=renderMD(tokenBuf);
|
||||
}
|
||||
}
|
||||
setStreaming(false);
|
||||
loadStatus();
|
||||
}
|
||||
|
||||
function stop(){if(abortCtrl){abortCtrl.abort();setStreaming(false)}}
|
||||
|
||||
input.addEventListener('keydown',e=>{
|
||||
if(e.key==='Enter'&&!e.shiftKey){e.preventDefault();send()}
|
||||
setTimeout(()=>{input.style.height='auto';input.style.height=Math.min(input.scrollHeight,120)+'px'},0);
|
||||
});
|
||||
|
||||
sendBtn.addEventListener('click',send);
|
||||
stopBtn.addEventListener('click',stop);
|
||||
|
||||
async function loadStatus(){
|
||||
try{
|
||||
const r=await fetch('/status',{method:'POST'});
|
||||
if(!r.ok)return;
|
||||
const s=await r.json();
|
||||
lblModel.textContent=(s.model||'-')+' | '+(s.provider||'-');
|
||||
dot.classList.toggle('connected',!!s.ai);
|
||||
}catch(e){}
|
||||
}
|
||||
|
||||
document.getElementById('clearBtn').addEventListener('click',async()=>{
|
||||
try{await fetch('/clear',{method:'POST'})}catch(e){}
|
||||
msgs.innerHTML='<div class="emptyState"><div class="logo">◆</div>dstalk Web<br>Send a message to begin.<br>发送消息开始对话。</div>';
|
||||
lastAiBubble=null;tokenBuf='';
|
||||
if(streaming)setStreaming(false);
|
||||
loadStatus();
|
||||
});
|
||||
|
||||
// 页面加载时检查后端状态 / Check backend status on load
|
||||
loadStatus();
|
||||
</script>
|
||||
</body>
|
||||
</html>)html";
|
||||
|
||||
#endif // DSTALK_WEB_UI_HPP
|
||||
@@ -1,7 +1,10 @@
|
||||
/*
|
||||
* example_plugin.cpp - Minimal dstalk plugin demonstrating the API contract.
|
||||
* @file example_plugin.cpp
|
||||
* @brief Example plugin demonstrating the dstalk plugin API contract.
|
||||
* 示例插件:演示 dstalk 插件 API 契约。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*
|
||||
* Build instructions (conceptual):
|
||||
* Build instructions (conceptual) / 构建说明(概念性):
|
||||
*
|
||||
* Linux / macOS:
|
||||
* g++ -std=c++20 -shared -fPIC -fvisibility=hidden \
|
||||
@@ -14,6 +17,7 @@
|
||||
* /Fe:example_plugin.dll example_plugin.cpp
|
||||
*
|
||||
* The resulting `.so` / `.dylib` / `.dll` can be loaded with:
|
||||
* 生成的 .so / .dylib / .dll 可通过以下方式加载:
|
||||
*
|
||||
* int id = dstalk_plugin_load("./example_plugin.so");
|
||||
*/
|
||||
@@ -25,11 +29,12 @@
|
||||
#include <cstring> /* strlen, strcmp */
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* Private state (one instance per plugin load)
|
||||
* 私有状态(每个插件加载实例一份) / Private state (one instance per plugin load)
|
||||
* ------------------------------------------------------------------
|
||||
*
|
||||
* In a more complex plugin this struct would hold open database
|
||||
* connections, configuration, etc.
|
||||
* 在更复杂的插件中,此结构体可包含打开的数据库连接、配置等。
|
||||
*/
|
||||
|
||||
struct ExampleState {
|
||||
@@ -37,31 +42,33 @@ struct ExampleState {
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* Stored host API table so callbacks can use host services.
|
||||
* 保存主机 API 表,以便回调函数使用主机服务 / Stored host API table so callbacks can use host services.
|
||||
* ------------------------------------------------------------------ */
|
||||
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
static ExampleState g_state; /* not heap-allocated: stays valid
|
||||
static ExampleState g_state; /* 非堆分配:在库映射期间持续有效 / not heap-allocated: stays valid
|
||||
while the library is mapped */
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* on_init (was on_load)
|
||||
* on_init(原 on_load) / on_init (was on_load)
|
||||
* ------------------------------------------------------------------ */
|
||||
|
||||
// 插件初始化:保存主机指针,重置调用计数,记录加载消息 / Plugin init: store host pointer, reset call count, log loaded message.
|
||||
static int my_on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
g_host = host;
|
||||
g_state.call_count = 0;
|
||||
|
||||
/* TODO: real plugins would initialise resources here:
|
||||
* - parse a plugin-specific config file via host->config_get
|
||||
* - open a log file
|
||||
* - connect to a local service
|
||||
* - register services via host->register_service
|
||||
/* TODO: 真实插件应在此处初始化资源 / real plugins would initialise resources here:
|
||||
* - 通过 host->config_get 解析插件专属配置文件 / parse a plugin-specific config file via host->config_get
|
||||
* - 打开日志文件 / open a log file
|
||||
* - 连接到本地服务 / connect to a local service
|
||||
* - 通过 host->register_service 注册服务 / register services via host->register_service
|
||||
*
|
||||
* Return non-zero to signal a fatal initialisation error to the
|
||||
* host, which will then unload the plugin immediately.
|
||||
* 返回非零值以向主机报告致命初始化错误,主机将立即卸载该插件。
|
||||
*/
|
||||
|
||||
if (host) {
|
||||
@@ -73,12 +80,13 @@ static int my_on_init(const dstalk_host_api_t* host)
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* on_shutdown (was on_unload)
|
||||
* on_shutdown(原 on_unload) / on_shutdown (was on_unload)
|
||||
* ------------------------------------------------------------------ */
|
||||
|
||||
// 插件关闭:记录调用次数,释放资源 / Plugin shutdown: log call count, release any resources.
|
||||
static void my_on_shutdown(void)
|
||||
{
|
||||
/* TODO: release any resources allocated in on_init. After this
|
||||
/* TODO: 释放 on_init 中分配的所有资源。此函数返回后主机将卸载共享库。 / release any resources allocated in on_init. After this
|
||||
* function returns the host will unmap the shared library. */
|
||||
|
||||
if (g_host) {
|
||||
@@ -92,20 +100,21 @@ static void my_on_shutdown(void)
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* on_event (was on_message)
|
||||
* on_event(原 on_message) / on_event (was on_message)
|
||||
* ------------------------------------------------------------------ */
|
||||
|
||||
// 插件事件处理:记录消息事件,忽略其他事件类型 / Plugin event handler: log message events, ignore other event types.
|
||||
static void my_on_event(int event_type, const void* data)
|
||||
{
|
||||
if (event_type == DSTALK_EVENT_MESSAGE && data) {
|
||||
const auto* msg = static_cast<const dstalk_message_t*>(data);
|
||||
g_state.call_count++;
|
||||
|
||||
/* A real plugin might:
|
||||
* - log the conversation to a file
|
||||
* - apply content moderation
|
||||
* - translate messages on the fly
|
||||
* - enrich messages with external data
|
||||
/* 真实插件可能: / A real plugin might:
|
||||
* - 将对话记录到文件 / log the conversation to a file
|
||||
* - 实施内容审核 / apply content moderation
|
||||
* - 实时翻译消息 / translate messages on the fly
|
||||
* - 用外部数据丰富消息 / enrich messages with external data
|
||||
*/
|
||||
|
||||
if (g_host) {
|
||||
@@ -117,19 +126,19 @@ static void my_on_event(int event_type, const void* data)
|
||||
msg->role, std::strlen(msg->content));
|
||||
}
|
||||
}
|
||||
/* Other event types (DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_CONFIG_CHANGED,
|
||||
/* 其他事件类型 / Other event types (DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_CONFIG_CHANGED,
|
||||
DSTALK_EVENT_PLUGIN_LOADED, DSTALK_EVENT_PLUGIN_UNLOADED, DSTALK_EVENT_CUSTOM+)
|
||||
are silently ignored by this minimal plugin. */
|
||||
此最小化插件静默忽略 / are silently ignored by this minimal plugin. */
|
||||
}
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* Plugin descriptor (static -- lives for the lifetime of the .so)
|
||||
* 插件描述符(静态 —— 在 .so 的生命周期内有效) / Plugin descriptor (static -- lives for the lifetime of the .so)
|
||||
* ------------------------------------------------------------------ */
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "example-plugin",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "An example plugin for dstalk",
|
||||
/* .description = */ "An example plugin for dstalk / dstalk 示例插件",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ {nullptr},
|
||||
/* .on_init = */ my_on_init,
|
||||
@@ -138,13 +147,16 @@ static dstalk_plugin_info_t g_info = {
|
||||
};
|
||||
|
||||
/* ------------------------------------------------------------------
|
||||
* Mandatory entry point
|
||||
* 必须入口点 / Mandatory entry point
|
||||
* ------------------------------------------------------------------
|
||||
*
|
||||
* The host looks for this symbol via dlsym / GetProcAddress.
|
||||
* 主机通过 dlsym / GetProcAddress 查找此符号。
|
||||
* It MUST be declared extern "C" so the name is not mangled.
|
||||
* 必须声明为 extern "C" 以避免名称修饰。
|
||||
*/
|
||||
|
||||
// 返回插件描述符给主机加载器 / Returns the plugin descriptor to the host loader.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file anthropic_plugin.cpp
|
||||
* @brief Anthropic Claude Messages API provider plugin with streaming support.
|
||||
* Anthropic Claude Messages API 提供者插件,支持流式输出。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -11,14 +18,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写
|
||||
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 / Global pointers — W17.4: std::atomic protects concurrent read/write between on_shutdown and service functions
|
||||
// ============================================================================
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
static dstalk_config_service_t* g_config = nullptr;
|
||||
|
||||
// ============================================================================
|
||||
// 配置数据
|
||||
// 配置数据 / Config data
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
@@ -29,19 +36,21 @@ struct PluginConfig {
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
static std::string g_tools_json; // W21.2: cached by configure(), consumed by chat/chat_stream
|
||||
static std::string g_tools_json; // W21.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
|
||||
static void secure_zero(void* p, size_t n) {
|
||||
volatile char* vp = (volatile char*)p;
|
||||
while (n--) *vp++ = 0;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:提取 host / target
|
||||
// 辅助:提取 host / target / Helper: extract host / target
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 Anthropic headers JSON
|
||||
// 构建 Anthropic headers JSON / Build Anthropic headers JSON
|
||||
// ============================================================================
|
||||
// 构建包含 x-api-key 和 anthropic-version 的 JSON headers 对象 / Build the JSON headers object containing x-api-key and anthropic-version.
|
||||
static std::string build_headers_json()
|
||||
{
|
||||
json::object h;
|
||||
@@ -76,8 +86,11 @@ static std::string build_headers_json()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 Anthropic JSON 请求体
|
||||
// 构建 Anthropic JSON 请求体 / Build Anthropic JSON request body
|
||||
// ============================================================================
|
||||
// 构建 Anthropic Messages API 的完整 JSON 请求体。
|
||||
// 按 Anthropic 规范将 system 消息提取为顶层 system 字段 / Build the full JSON request body for the Anthropic Messages API.
|
||||
// Extracts system messages as a top-level "system" field per Anthropic spec.
|
||||
static std::string build_request_json(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const std::string& user_input,
|
||||
@@ -89,7 +102,7 @@ static std::string build_request_json(
|
||||
root["max_tokens"] = g_cfg.max_tokens;
|
||||
root["stream"] = stream;
|
||||
|
||||
// 提取 system 消息作为顶层字段
|
||||
// 提取 system 消息作为顶层字段 / Extract system messages as top-level field
|
||||
std::string system_prompt;
|
||||
json::array msgs;
|
||||
|
||||
@@ -106,7 +119,7 @@ static std::string build_request_json(
|
||||
msgs.push_back(obj);
|
||||
}
|
||||
|
||||
// 追加当前用户输入
|
||||
// 追加当前用户输入 / Append current user input
|
||||
{
|
||||
json::object obj;
|
||||
obj["role"] = "user";
|
||||
@@ -124,7 +137,7 @@ static std::string build_request_json(
|
||||
root["temperature"] = g_cfg.temperature;
|
||||
}
|
||||
|
||||
// W21.2: tools 定义传递给 API
|
||||
// W21.2: tools 定义传递给 API / Pass tools definition to API
|
||||
if (!tools_json.empty()) {
|
||||
root["tools"] = json::parse(tools_json);
|
||||
}
|
||||
@@ -133,8 +146,11 @@ static std::string build_request_json(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 解析非流式响应
|
||||
// 解析非流式响应 / Parse non-streaming response
|
||||
// ============================================================================
|
||||
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t。
|
||||
// 处理 text 和 tool_use content block,将 tool_use 转换为 OpenAI 格式 / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
|
||||
// Handles text and tool_use content blocks, converting tool_use to OpenAI format.
|
||||
static void parse_response(const char* body, int http_status,
|
||||
dstalk_chat_result_t& r)
|
||||
{
|
||||
@@ -169,7 +185,7 @@ static void parse_response(const char* body, int http_status,
|
||||
auto obj = jv.as_object();
|
||||
auto content = obj["content"].as_array();
|
||||
if (!content.empty()) {
|
||||
// W21.2: 提取 text 和 tool_use content blocks
|
||||
// W21.2: 提取 text 和 tool_use content blocks / Extract text and tool_use content blocks
|
||||
std::string text_content;
|
||||
json::array tool_use_blocks;
|
||||
|
||||
@@ -181,7 +197,7 @@ static void parse_response(const char* body, int http_status,
|
||||
if (btype == "text") {
|
||||
text_content = json::value_to<std::string>(bobj["text"]);
|
||||
} else if (btype == "tool_use") {
|
||||
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}}
|
||||
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}} / Convert to OpenAI-compatible format: {id, type:"function", function:{name, arguments}}
|
||||
json::object tc;
|
||||
tc["id"] = bobj["id"];
|
||||
tc["type"] = "function";
|
||||
@@ -206,7 +222,7 @@ static void parse_response(const char* body, int http_status,
|
||||
r.error = nullptr;
|
||||
return;
|
||||
} else if (!tool_use_blocks.empty()) {
|
||||
// tool-only 响应
|
||||
// tool-only 响应 / tool-only response
|
||||
r.content = nullptr;
|
||||
r.ok = 1;
|
||||
r.error = nullptr;
|
||||
@@ -235,15 +251,15 @@ static void parse_response(const char* body, int http_status,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE 事件解析(Anthropic 格式: event/content_block_delta)
|
||||
// SSE 事件解析(Anthropic 格式: event/content_block_delta) / SSE event parsing (Anthropic format: event/content_block_delta)
|
||||
// ============================================================================
|
||||
|
||||
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量
|
||||
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量 / Accumulate Anthropic tool_use increments by content_block index
|
||||
struct ToolCallAccum {
|
||||
int index = -1;
|
||||
std::string id;
|
||||
std::string name;
|
||||
std::string arguments; // 从 input_json_delta.partial_json 累积
|
||||
std::string arguments; // 从 input_json_delta.partial_json 累积 / accumulated from input_json_delta.partial_json
|
||||
};
|
||||
|
||||
struct StreamContext {
|
||||
@@ -252,10 +268,15 @@ struct StreamContext {
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool saw_data_line = false;
|
||||
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks
|
||||
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks / accumulate tool_use content blocks by index
|
||||
};
|
||||
|
||||
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析
|
||||
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 / Parse Anthropic SSE events with tool_use content_block incremental parsing
|
||||
// 解析单个 Anthropic SSE "data:" JSON 事件。处理 content_block_start、
|
||||
// content_block_delta (text_delta/input_json_delta) 和 message_stop。
|
||||
// 如果产生了 content token 则返回 true,否则返回 false / Parse a single Anthropic SSE "data:" JSON event. Handles content_block_start,
|
||||
// content_block_delta (text_delta/input_json_delta), and message_stop.
|
||||
// Returns true if a content token was produced, false otherwise.
|
||||
static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
{
|
||||
@@ -268,7 +289,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
std::string type = json::value_to<std::string>(*type_ptr);
|
||||
|
||||
if (type == "content_block_start") {
|
||||
// content_block_start 可能为 tool_use
|
||||
// content_block_start 可能为 tool_use / content_block_start may be tool_use
|
||||
auto* cb = obj.if_contains("content_block");
|
||||
if (!cb || !cb->is_object()) return false;
|
||||
auto& cb_obj = cb->as_object();
|
||||
@@ -311,7 +332,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
return true;
|
||||
}
|
||||
} else if (delta_type == "input_json_delta" && ctx) {
|
||||
// W21.2: 累积 tool_use arguments 分片
|
||||
// W21.2: 累积 tool_use arguments 分片 / Accumulate tool_use arguments fragments
|
||||
auto* pj = dobj.if_contains("partial_json");
|
||||
if (pj && pj->is_string()) {
|
||||
auto* idx_ptr = obj.if_contains("index");
|
||||
@@ -326,18 +347,19 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
}
|
||||
} else if (type == "message_stop") {
|
||||
token_out.clear();
|
||||
return true; // 流结束
|
||||
return true; // 流结束 / stream end
|
||||
}
|
||||
// 忽略: message_start, content_block_stop, ping, message_delta
|
||||
// 忽略: message_start, content_block_stop, ping, message_delta / Ignore: message_start, content_block_stop, ping, message_delta
|
||||
} catch (...) {
|
||||
// 解析失败忽略
|
||||
// 解析失败忽略 / Ignore parse failures
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// configure
|
||||
// configure / configure
|
||||
// ============================================================================
|
||||
// 配置插件:provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
|
||||
static int my_configure(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature)
|
||||
@@ -352,7 +374,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) {
|
||||
// W21.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用
|
||||
// W21.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
h->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
@@ -381,8 +403,9 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat
|
||||
// chat / chat
|
||||
// ============================================================================
|
||||
// 非流式 chat completion:发送 history + user input,返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
|
||||
static dstalk_chat_result_t my_chat(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -447,26 +470,27 @@ static dstalk_chat_result_t my_chat(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat_stream
|
||||
// chat_stream / chat_stream
|
||||
// ============================================================================
|
||||
|
||||
// 行回调
|
||||
// 行回调 / SSE line callback
|
||||
// SSE 行回调:解析每个 Anthropic SSE 行并将文本 token 转发给用户 / SSE line callback: parses each Anthropic SSE line and forwards text tokens to user.
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
|
||||
// SSE 格式: "data: <json>"
|
||||
// SSE 格式: "data: <json>" / SSE format: "data: <json>"
|
||||
if (line_str.rfind("data: ", 0) == 0) {
|
||||
std::string data = line_str.substr(6);
|
||||
std::string token;
|
||||
if (parse_sse_data(data, token, ctx)) {
|
||||
ctx->saw_data_line = true;
|
||||
if (token.empty()) {
|
||||
// message_stop
|
||||
// message_stop / message_stop
|
||||
return 0;
|
||||
}
|
||||
ctx->accumulated += token;
|
||||
@@ -475,7 +499,7 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
}
|
||||
// "event: ..." 行和其他 -> 忽略
|
||||
// "event: ..." 行和其他 -> 忽略 / "event: ..." lines and others -> ignored
|
||||
return 1;
|
||||
} catch (const std::exception& e) {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -488,6 +512,9 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
|
||||
// 流式 chat completion:以 stream=true 发送 history + user input,通过回调传递 token。
|
||||
// 累积 tool_use blocks 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
|
||||
// via callback. Accumulates tool_use blocks and serializes them at end.
|
||||
static dstalk_chat_result_t my_chat_stream(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -531,7 +558,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
r.http_status = status_code;
|
||||
|
||||
// 检查错误状态
|
||||
// 检查错误状态 / Check error status
|
||||
if (status_code < 200 || status_code >= 300) {
|
||||
r.ok = 0;
|
||||
if (response_body && response_body[0]) {
|
||||
@@ -560,7 +587,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
if (response_body) host->free(response_body);
|
||||
|
||||
// W21.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling)
|
||||
// W21.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling) / Success = has content OR has tool_calls (tool-only responses like function calling)
|
||||
bool has_content = !ctx.accumulated.empty();
|
||||
bool has_tool_calls = !ctx.tool_calls.empty();
|
||||
|
||||
@@ -575,7 +602,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
r.content = has_content
|
||||
? host->strdup(ctx.accumulated.c_str()) : nullptr;
|
||||
|
||||
// W21.2: 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式)
|
||||
// W21.2: 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
@@ -614,8 +641,9 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// free_result
|
||||
// free_result / free_result
|
||||
// ============================================================================
|
||||
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -626,7 +654,7 @@ static void my_free_result(dstalk_chat_result_t* result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
static dstalk_ai_service_t g_service = {
|
||||
&my_configure,
|
||||
@@ -636,8 +664,9 @@ static dstalk_ai_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期
|
||||
// 生命周期 / Lifecycle
|
||||
// ============================================================================
|
||||
// 插件初始化:查询 http 和 config 服务,注册 ai.anthropic 服务 / Plugin init: query http and config services, register ai.anthropic service.
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
@@ -666,6 +695,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:从内存安全擦除 API key,清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
@@ -686,12 +716,12 @@ static void on_shutdown()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "anthropic-ai",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "Anthropic Claude AI provider (Messages API)",
|
||||
/* .description = */ "Anthropic Claude AI provider (Messages API) / Anthropic Claude AI 提供者 (Messages API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "http", "config", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
@@ -699,6 +729,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
/*
|
||||
* @file toml_parse.h
|
||||
* @brief Lightweight single-header TOML parser (subset: flat key-value pairs).
|
||||
* 轻量级单头文件 TOML 解析器(子集:扁平键值对)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
// Shared TOML parser — used by both ConfigStore (core) and config plugin.
|
||||
// 共享 TOML 解析器 —— 由 ConfigStore(核心)和 config 插件共同使用 / Shared TOML parser — used by both ConfigStore (core) and config plugin.
|
||||
// W12.2: Extracted from config_store.cpp:23-61 and config_plugin.cpp:28-66
|
||||
// to eliminate the 74-line code duplication (W11.2 audit Finding 1).
|
||||
// Does NOT support: inline tables, arrays, multi-line strings, escape sequences.
|
||||
// 不支持:内联表、数组、多行字符串、转义序列。
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace dstalk {
|
||||
namespace toml {
|
||||
|
||||
/// Parse a TOML string, calling on_kv(full_key, value) for each key-value pair.
|
||||
/// Supports [section] headers, key = "value" pairs, # comments, blank lines.
|
||||
/// 解析 TOML 字符串,对每个键值对调用 on_kv(full_key, value) / Parse a TOML string, calling on_kv(full_key, value) for each key-value pair.
|
||||
/// 支持 [section] 标题、key = "value" 键值对、# 注释、空行 / Supports [section] headers, key = "value" pairs, # comments, blank lines.
|
||||
template<typename F>
|
||||
inline void parse(const std::string& content, F&& on_kv)
|
||||
{
|
||||
@@ -18,31 +26,31 @@ inline void parse(const std::string& content, F&& on_kv)
|
||||
size_t pos = 0;
|
||||
|
||||
while (pos < content.size()) {
|
||||
// Trim left whitespace
|
||||
// 去除左侧空白 / Trim left whitespace
|
||||
while (pos < content.size() && (content[pos] == ' ' || content[pos] == '\t'))
|
||||
pos++;
|
||||
if (pos >= content.size()) break;
|
||||
|
||||
// Extract next line
|
||||
// 提取下一行 / Extract next line
|
||||
size_t nl = content.find('\n', pos);
|
||||
std::string line = (nl != std::string::npos)
|
||||
? content.substr(pos, nl - pos) : content.substr(pos);
|
||||
pos = (nl != std::string::npos) ? nl + 1 : content.size();
|
||||
|
||||
// Trim right whitespace (including \r)
|
||||
// 去除右侧空白(包括 \r) / Trim right whitespace (including \r)
|
||||
while (!line.empty() && (line.back() == '\r' || line.back() == ' '))
|
||||
line.pop_back();
|
||||
|
||||
// Skip empty lines and comments
|
||||
// 跳过空行和注释 / Skip empty lines and comments
|
||||
if (line.empty() || line[0] == '#') continue;
|
||||
|
||||
// Section header: [section_name]
|
||||
// 节标题: [section_name] / Section header: [section_name]
|
||||
if (line[0] == '[' && line.back() == ']') {
|
||||
current_section = line.substr(1, line.size() - 2);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Key = value
|
||||
// 键 = 值 / Key = value
|
||||
size_t eq = line.find('=');
|
||||
if (eq == std::string::npos) continue;
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file config_plugin.cpp
|
||||
* @brief Config plugin: TOML file parsing and key-value configuration service.
|
||||
* 配置插件:TOML 文件解析和键值配置服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "../include/toml_parse.h"
|
||||
@@ -7,12 +14,12 @@
|
||||
#include <sstream>
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
//
|
||||
// W12.2: Eliminated private ConfigStore (was 90 lines duplicating core).
|
||||
// All get/set/load_file now delegate to the host store via g_host->config_get
|
||||
@@ -20,16 +27,19 @@ static const dstalk_host_api_t* g_host = nullptr;
|
||||
// TOML parsing uses the shared dstalk::toml::parse() from toml_parse.h.
|
||||
// ============================================================
|
||||
|
||||
// 从主机存储中按 key 获取配置值 / Retrieve a configuration value by key from the host store.
|
||||
static const char* config_get(const char* key) {
|
||||
if (!g_host) return nullptr;
|
||||
return g_host->config_get(key);
|
||||
}
|
||||
|
||||
// 将键值对存入主机存储 / Store a configuration key-value pair into the host store.
|
||||
static int config_set(const char* key, const char* value) {
|
||||
if (!g_host) return -1;
|
||||
return g_host->config_set(key, value);
|
||||
}
|
||||
|
||||
// 解析指定路径的 TOML 文件,将所有键值对加载到主机存储中 / Parse a TOML file at `path` and load all key-value pairs into the host store.
|
||||
static int config_load_file(const char* path) {
|
||||
if (!g_host || !path) return -1;
|
||||
|
||||
@@ -58,12 +68,13 @@ static dstalk_config_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// 插件初始化:保存主机指针并注册 config 服务 vtable / Plugin init: store host pointer and register the config service vtable.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
|
||||
// W12.2: This service is now a thin wrapper around host->config_get/set.
|
||||
// W12.2: 该服务现为 host->config_get/set 的薄封装,建议直接调用主机 API / This service is now a thin wrapper around host->config_get/set.
|
||||
// Direct host API calls are preferred.
|
||||
host->log(DSTALK_LOG_INFO,
|
||||
"plugin config service is deprecated, prefer host->config_get/set");
|
||||
@@ -76,8 +87,10 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
return (rc >= 0) ? 0 : -1;
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理本地存储(所有数据在主机存储中) / Plugin shutdown: no local store to clean up (all data lives in host store).
|
||||
static void on_shutdown() {
|
||||
// W12.2: No local store to clean up — all data lives in host store.
|
||||
// 无需清理本地存储——所有数据位于主机存储中。
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
@@ -91,6 +104,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-context: 上下文管理服务插件
|
||||
// 提供 dstalk_context_service_t vtable 实现
|
||||
// 依赖: session (获取历史消息做 token 计数)
|
||||
/*
|
||||
* @file context_plugin.cpp
|
||||
* @brief Context plugin: token counting and context window trimming.
|
||||
* 上下文插件:token 计数和上下文窗口裁剪。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-context: 上下文管理服务插件 / Context management service plugin
|
||||
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
|
||||
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -15,21 +22,26 @@
|
||||
#include <vector>
|
||||
|
||||
// ============================================================
|
||||
// 全局状态
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// 内部 C++ 辅助:共享 UTF-8 token 计数
|
||||
// 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
|
||||
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
|
||||
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
|
||||
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6)
|
||||
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
|
||||
// ============================================================
|
||||
|
||||
// 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。
|
||||
// overhead: 每条消息的固定开销 token(role + separators = 4)
|
||||
// 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。
|
||||
// Count estimated tokens for UTF-8 byte sequence [text, text+len).
|
||||
// overhead: fixed token overhead per message (role + separators = 4).
|
||||
// Multi-byte sequences fall back to single-byte other_chars counting when out-of-bounds or invalid continuation bytes.
|
||||
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
if (!text || len == 0) return overhead;
|
||||
|
||||
@@ -42,12 +54,12 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||
|
||||
if (c < 0x80) {
|
||||
// ASCII
|
||||
// ASCII / ASCII
|
||||
ascii_chars++;
|
||||
i += 1;
|
||||
} else if (c >= 0xE4 && c <= 0xE9) {
|
||||
// CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
|
||||
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内
|
||||
// CJK 统一表意文字 (U+4E00-U+9FFF): 3 字节 UTF-8 0xE4-0xE9 / CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
|
||||
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 / Check if subsequent 2 bytes are in valid range
|
||||
if (i + 2 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
|
||||
@@ -58,8 +70,8 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 3;
|
||||
}
|
||||
} else if (c >= 0xC2 && c < 0xE0) {
|
||||
// 2-byte sequence (valid range 0xC2-0xDF)
|
||||
// W18.1 (F-11.1-4): 检查后续 1 字节
|
||||
// 2 字节序列 (有效范围 0xC2-0xDF) / 2-byte sequence (valid range 0xC2-0xDF)
|
||||
// W18.1 (F-11.1-4): 检查后续 1 字节 / Check subsequent 1 byte
|
||||
if (i + 1 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80) {
|
||||
other_chars++;
|
||||
@@ -69,13 +81,13 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 2;
|
||||
}
|
||||
} else if (c == 0xC0 || c == 0xC1) {
|
||||
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节
|
||||
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars
|
||||
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 / Overlong encoding, invalid UTF-8 start byte
|
||||
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars / 0xC0/0xC1 never appear in valid UTF-8; counted as single-byte in other_chars
|
||||
other_chars++;
|
||||
i += 1;
|
||||
} else if (c >= 0xE0 && c < 0xF0) {
|
||||
// Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
|
||||
// CJK 范围 0xE4-0xE9 已在上方分支处理
|
||||
// 非 CJK 3 字节序列 (0xE0-0xE3, 0xEA-0xEF) / Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
|
||||
// CJK 范围 0xE4-0xE9 已在上方分支处理 / CJK range 0xE4-0xE9 handled in branch above
|
||||
if (i + 2 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
|
||||
@@ -86,7 +98,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 3;
|
||||
}
|
||||
} else if (c >= 0xF0 && c < 0xF8) {
|
||||
// 4-byte sequence
|
||||
// 4 字节序列 / 4-byte sequence
|
||||
if (i + 3 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80 ||
|
||||
@@ -98,7 +110,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 4;
|
||||
}
|
||||
} else {
|
||||
// Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
|
||||
// 续字节 (0x80-0xBF) 和其他无效起始字节 (0xF8-0xFF) / Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
|
||||
other_chars++;
|
||||
i += 1;
|
||||
}
|
||||
@@ -108,15 +120,17 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装)
|
||||
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl)
|
||||
// ============================================================
|
||||
|
||||
// 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct.
|
||||
static size_t count_tokens_one_message(const dstalk_message_t& msg) {
|
||||
const char* text = msg.content;
|
||||
if (!text) return 4; // 只有 overhead
|
||||
if (!text) return 4; // 只有 overhead / overhead only
|
||||
return count_tokens_utf8(text, std::strlen(text), 4);
|
||||
}
|
||||
|
||||
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
|
||||
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
|
||||
size_t total = 0;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
@@ -126,10 +140,10 @@ static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 内部 trim 逻辑
|
||||
// 内部 trim 逻辑 / Internal trim logic
|
||||
// ============================================================
|
||||
|
||||
// 为 trim 操作将 C 消息数组复制到内部 struct
|
||||
// 为 trim 操作将 C 消息数组复制到内部 struct / Copy C message array to internal struct for trim operation
|
||||
struct TrimMessage {
|
||||
std::string role;
|
||||
std::string content;
|
||||
@@ -148,7 +162,7 @@ static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
|
||||
return total;
|
||||
}
|
||||
|
||||
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚)
|
||||
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
|
||||
static void free_msg_strs(dstalk_message_t* msg) {
|
||||
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
|
||||
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
|
||||
@@ -158,6 +172,8 @@ static void free_msg_strs(dstalk_message_t* msg) {
|
||||
|
||||
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
|
||||
// 成功返回 0;OOM 时释放当前消息已分配字段并返回 -1。
|
||||
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
|
||||
// On OOM, frees already-allocated fields and returns -1.
|
||||
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
|
||||
memset(dst, 0, sizeof(dstalk_message_t));
|
||||
|
||||
@@ -184,7 +200,10 @@ oom:
|
||||
return -1;
|
||||
}
|
||||
|
||||
// W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3)
|
||||
// W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
|
||||
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
|
||||
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
|
||||
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
|
||||
static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
dstalk_message_t** out, int* out_count,
|
||||
size_t max_tokens) {
|
||||
@@ -192,10 +211,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
if (!in || in_count <= 0 || !out || !out_count) return -1;
|
||||
|
||||
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens;
|
||||
// 传 0 时使用硬编码默认值 4096。
|
||||
// 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
|
||||
// when 0 is passed, use hardcoded default 4096.
|
||||
if (max_tokens == 0) max_tokens = 4096;
|
||||
|
||||
// 将 C 数组转换为内部 vector
|
||||
// 将 C 数组转换为内部 vector / Convert C array to internal vector
|
||||
std::vector<TrimMessage> messages;
|
||||
messages.reserve(in_count);
|
||||
for (int i = 0; i < in_count; ++i) {
|
||||
@@ -207,13 +227,13 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
messages.push_back(std::move(tm));
|
||||
}
|
||||
|
||||
// 如果已在限制内,直接返回完整副本
|
||||
// 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
|
||||
size_t current = count_tokens_trim_vec(messages);
|
||||
if (current <= max_tokens) {
|
||||
*out_count = in_count;
|
||||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
|
||||
if (!*out) return -1;
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
|
||||
for (int i = 0; i < in_count; ++i) {
|
||||
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
|
||||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||||
@@ -225,7 +245,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 分离 system 消息和非 system 消息
|
||||
// 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
|
||||
std::vector<TrimMessage> system_msgs;
|
||||
std::vector<TrimMessage> non_system_msgs;
|
||||
for (const auto& msg : messages) {
|
||||
@@ -243,7 +263,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
system_tokens, max_tokens);
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
for (const auto& msg : non_system_msgs) {
|
||||
size_t msg_tokens = count_tokens_trim(msg);
|
||||
if (msg_tokens > max_tokens) {
|
||||
@@ -257,19 +277,19 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
|
||||
while (!non_system_msgs.empty()) {
|
||||
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
|
||||
if (current <= max_tokens) break;
|
||||
|
||||
// 找第一个 "user" 消息
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
auto user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") {
|
||||
++user_it;
|
||||
}
|
||||
if (user_it == non_system_msgs.end()) break;
|
||||
|
||||
// 找下一个 "assistant"
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
auto assistant_it = user_it + 1;
|
||||
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
|
||||
++assistant_it;
|
||||
@@ -278,7 +298,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
if (assistant_it == non_system_msgs.end()) {
|
||||
non_system_msgs.erase(user_it);
|
||||
} else {
|
||||
// 先删 assistant 再删 user 避免迭代器失效
|
||||
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
|
||||
non_system_msgs.erase(assistant_it);
|
||||
user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
|
||||
@@ -286,7 +306,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens
|
||||
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
|
||||
{
|
||||
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
|
||||
if (max_msg_count < 1) max_msg_count = 1;
|
||||
@@ -295,7 +315,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// 组装结果
|
||||
// 组装结果 / Assemble result
|
||||
std::vector<TrimMessage> result;
|
||||
result.reserve(system_msgs.size() + non_system_msgs.size());
|
||||
result.insert(result.end(), system_msgs.begin(), system_msgs.end());
|
||||
@@ -306,7 +326,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
|
||||
if (!*out) return -1;
|
||||
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
|
||||
for (int i = 0; i < result_count; ++i) {
|
||||
if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
|
||||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||||
@@ -318,7 +338,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate()
|
||||
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
@@ -328,10 +348,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Context 服务 vtable 实现
|
||||
// Context 服务 vtable 实现 / Context service vtable implementation
|
||||
// ============================================================
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate()
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate()
|
||||
// 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input.
|
||||
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
|
||||
try {
|
||||
if (!msgs || count <= 0) return 0;
|
||||
@@ -341,7 +362,8 @@ static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
|
||||
}
|
||||
}
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||||
// 裁剪消息列表以适应 max_tokens,返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array.
|
||||
static int context_trim(const dstalk_message_t* in, int in_count,
|
||||
dstalk_message_t** out, int* out_count,
|
||||
size_t max_tokens) {
|
||||
@@ -355,21 +377,24 @@ static int context_trim(const dstalk_message_t* in, int in_count,
|
||||
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。
|
||||
// max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入;
|
||||
// 传 0 时 trim_impl 使用硬编码默认值 4096。
|
||||
// g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly
|
||||
// by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0.
|
||||
static dstalk_context_service_t g_context_service = {
|
||||
context_count_tokens,
|
||||
context_trim
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||||
// 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host = host;
|
||||
|
||||
// 查询依赖服务: session
|
||||
// 查询依赖服务: session / Query dependency service: session
|
||||
void* raw = host->query_service("session", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
|
||||
@@ -387,7 +412,8 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log
|
||||
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs
|
||||
// 插件关闭:清空指针。try/catch 保护 ABI(void 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function).
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
g_session = nullptr;
|
||||
@@ -406,7 +432,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"context",
|
||||
"1.0.0",
|
||||
"Context management plugin with token counting and trim support",
|
||||
"Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -414,6 +440,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file deepseek_plugin.cpp
|
||||
* @brief DeepSeek/OpenAI-compatible AI provider plugin with SSE streaming and tool calls.
|
||||
* DeepSeek/OpenAI 兼容 AI 提供者插件,支持 SSE 流式输出和工具调用。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -11,14 +18,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态)
|
||||
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
|
||||
// ============================================================================
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
static std::atomic<dstalk_config_service_t*> g_config{nullptr};
|
||||
|
||||
// ============================================================================
|
||||
// 配置数据(由 configure() 设置)
|
||||
// 配置数据(由 configure() 设置) / Config data (set by configure())
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
@@ -29,19 +36,21 @@ struct PluginConfig {
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
static std::string g_tools_json; // W20.2: cached by configure(), consumed by chat/chat_stream
|
||||
static std::string g_tools_json; // W20.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
|
||||
static void secure_zero(void* p, size_t n) {
|
||||
volatile char* vp = (volatile char*)p;
|
||||
while (n--) *vp++ = 0;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:从 base_url 提取 host 和 target
|
||||
// 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:构建 headers JSON 字符串
|
||||
// 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
|
||||
// ============================================================================
|
||||
// 构建包含 Bearer 授权令牌的 JSON headers 对象 / Build the JSON headers object containing the Bearer authorization token.
|
||||
static std::string build_headers_json(const std::string& auth_header_value)
|
||||
{
|
||||
json::object h;
|
||||
@@ -75,8 +85,9 @@ static std::string build_headers_json(const std::string& auth_header_value)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:dstalk_message_t[] -> boost::json::array
|
||||
// 辅助:dstalk_message_t[] -> boost::json::array / Helper: dstalk_message_t[] -> boost::json::array
|
||||
// ============================================================================
|
||||
// 将 dstalk_message_t 数组转换为 Boost.JSON 数组,用于 API 请求体 / Convert dstalk_message_t array into a Boost.JSON array for the API request body.
|
||||
static void append_history(json::array& msgs,
|
||||
const dstalk_message_t* history, int history_len)
|
||||
{
|
||||
@@ -100,8 +111,9 @@ static void append_history(json::array& msgs,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 DeepSeek JSON 请求体
|
||||
// 构建 DeepSeek JSON 请求体 / Build DeepSeek JSON request body
|
||||
// ============================================================================
|
||||
// 构建 DeepSeek/OpenAI chat completions API 的完整 JSON 请求体 / Build the full JSON request body for the DeepSeek/OpenAI chat completions API.
|
||||
static std::string build_request_json(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const std::string& user_input,
|
||||
@@ -117,7 +129,7 @@ static std::string build_request_json(
|
||||
json::array msgs;
|
||||
append_history(msgs, history, history_len);
|
||||
|
||||
// 追加当前用户输入
|
||||
// 追加当前用户输入 / Append current user input
|
||||
if (!user_input.empty()) {
|
||||
json::object obj;
|
||||
obj["role"] = "user";
|
||||
@@ -127,7 +139,7 @@ static std::string build_request_json(
|
||||
|
||||
root["messages"] = msgs;
|
||||
|
||||
// tools 定义
|
||||
// tools 定义 / tools definition
|
||||
if (!tools_json.empty()) {
|
||||
root["tools"] = json::parse(tools_json);
|
||||
}
|
||||
@@ -136,8 +148,9 @@ static std::string build_request_json(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 解析非流式 JSON 响应
|
||||
// 解析非流式 JSON 响应 / Parse non-streaming JSON response
|
||||
// ============================================================================
|
||||
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
|
||||
static void parse_response(const dstalk_host_api_t* host,
|
||||
const char* body, int http_status,
|
||||
dstalk_chat_result_t& r)
|
||||
@@ -207,13 +220,13 @@ static void parse_response(const dstalk_host_api_t* host,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 流式上下文:在 SSE 回调间累积内容和 tool_calls
|
||||
// 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
|
||||
// ============================================================================
|
||||
struct ToolCallAccum {
|
||||
int index = -1;
|
||||
std::string id;
|
||||
std::string name;
|
||||
std::string arguments; // 增量拼接的 JSON arguments 字符串
|
||||
std::string arguments; // 增量拼接的 JSON arguments 字符串 / incrementally concatenated JSON arguments string
|
||||
};
|
||||
|
||||
struct StreamContext {
|
||||
@@ -222,12 +235,18 @@ struct StreamContext {
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool streaming_ok = true;
|
||||
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls
|
||||
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls / accumulate delta tool_calls by index
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// SSE 行解析(OpenAI 兼容格式)
|
||||
// SSE 行解析(OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
|
||||
// ============================================================================
|
||||
// 解析单行 SSE "data:" 行。如果包含 content delta,将 token 写入 token_out。
|
||||
// 如果包含 tool_calls delta,累积到 ctx->tool_calls。
|
||||
// 如果产生了 content token 则返回 true,否则返回 false(tool_calls 或未知)。
|
||||
// Parse a single SSE "data:" line. If it contains a content delta, writes the token
|
||||
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
|
||||
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
|
||||
static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
{
|
||||
@@ -235,7 +254,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
|
||||
std::string data = line.substr(6);
|
||||
|
||||
// F-13.2-3: Trim leading/trailing whitespace before comparing [DONE] sentinel.
|
||||
// F-13.2-3: 比较 [DONE] 哨兵前去除首尾空白 / Trim leading/trailing whitespace before comparing [DONE] sentinel.
|
||||
const char* ws = " \t\r\n";
|
||||
size_t start = data.find_first_not_of(ws);
|
||||
if (start != std::string::npos) {
|
||||
@@ -244,7 +263,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
if (data == "[DONE]") {
|
||||
token_out.clear();
|
||||
return true; // 流结束信号
|
||||
return true; // 流结束信号 / stream end signal
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -254,12 +273,12 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
if (!choices.empty()) {
|
||||
auto delta = choices[0].as_object()["delta"].as_object();
|
||||
|
||||
// W20.2: 处理 delta["tool_calls"] 增量 chunk
|
||||
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输:
|
||||
// 事件 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
|
||||
// 事件 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
|
||||
// 事件 3: {"index":0, "function":{"arguments":"1}"}}
|
||||
// 需要按 index 累积 id/name/arguments。
|
||||
// W20.2: 处理 delta["tool_calls"] 增量 chunk / Handle delta["tool_calls"] incremental chunks
|
||||
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输 / DeepSeek/OpenAI streaming mode: tool_calls transmitted across multiple SSE event chunks:
|
||||
// 事件 1 / Event 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
|
||||
// 事件 2 / Event 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
|
||||
// 事件 3 / Event 3: {"index":0, "function":{"arguments":"1}"}}
|
||||
// 需要按 index 累积 id/name/arguments / Need to accumulate id/name/arguments by index.
|
||||
if (delta.contains("tool_calls") && ctx) {
|
||||
auto tc_array = delta["tool_calls"].as_array();
|
||||
for (auto& tc_val : tc_array) {
|
||||
@@ -288,7 +307,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
return false; // tool_calls 已处理,无内容 token 给用户回调
|
||||
return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
|
||||
}
|
||||
|
||||
if (delta.contains("content")) {
|
||||
@@ -297,14 +316,15 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
}
|
||||
} catch (...) {
|
||||
// 忽略解析失败
|
||||
// 忽略解析失败 / Ignore parse failures
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// configure 实现
|
||||
// configure 实现 / configure implementation
|
||||
// ============================================================================
|
||||
// 配置插件:provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
|
||||
static int my_configure(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature)
|
||||
@@ -319,7 +339,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) {
|
||||
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用
|
||||
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
host->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
@@ -348,8 +368,9 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat 实现
|
||||
// chat 实现 / chat implementation
|
||||
// ============================================================================
|
||||
// 非流式 chat completion:发送 history + user input,返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
|
||||
static dstalk_chat_result_t my_chat(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -412,29 +433,29 @@ static dstalk_chat_result_t my_chat(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat_stream 实现
|
||||
// chat_stream 实现 / chat_stream implementation
|
||||
// ============================================================================
|
||||
|
||||
// 行回调:解析 SSE line,将 token 传递给用户回调
|
||||
// 行回调:解析 SSE line,将 token 传递给用户回调 / SSE line callback: parses each line and forwards content tokens to the user callback.
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
std::string token;
|
||||
|
||||
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续
|
||||
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
|
||||
|
||||
if (token.empty()) return 0; // [DONE],停止
|
||||
if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
|
||||
|
||||
ctx->accumulated += token;
|
||||
|
||||
if (ctx->user_cb) {
|
||||
return ctx->user_cb(token.c_str(), ctx->userdata);
|
||||
}
|
||||
return 1; // 继续
|
||||
return 1; // 继续 / continue
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[deepseek] sse_line_callback exception: %s", e.what());
|
||||
@@ -446,6 +467,9 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
|
||||
// 流式 chat completion:以 stream=true 发送 history + user input,通过回调传递 token。
|
||||
// 在 SSE 分片中累积 tool_calls 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
|
||||
// via callback. Accumulates tool_calls across SSE chunks and serializes them at end.
|
||||
static dstalk_chat_result_t my_chat_stream(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -488,10 +512,10 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
r.http_status = status_code;
|
||||
|
||||
// 检查传输层错误或非 2xx 状态
|
||||
// 检查传输层错误或非 2xx 状态 / Check transport errors or non-2xx status
|
||||
if (status_code < 200 || status_code >= 300) {
|
||||
r.ok = 0;
|
||||
// 尝试从响应体提取错误信息
|
||||
// 尝试从响应体提取错误信息 / Try to extract error info from response body
|
||||
if (response_body && response_body[0]) {
|
||||
try {
|
||||
auto jv = json::parse(response_body);
|
||||
@@ -518,7 +542,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
if (response_body && host) host->free(response_body);
|
||||
|
||||
// W20.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling)
|
||||
// W20.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling) / Success = has content OR has tool_calls (tool-only responses like function calling)
|
||||
bool has_content = !ctx.accumulated.empty();
|
||||
bool has_tool_calls = !ctx.tool_calls.empty();
|
||||
|
||||
@@ -533,7 +557,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
r.content = has_content
|
||||
? host->strdup(ctx.accumulated.c_str()) : nullptr;
|
||||
|
||||
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式)
|
||||
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
@@ -572,8 +596,9 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// free_result 实现
|
||||
// free_result 实现 / free_result implementation
|
||||
// ============================================================================
|
||||
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -584,7 +609,7 @@ static void my_free_result(dstalk_chat_result_t* result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
static dstalk_ai_service_t g_service = {
|
||||
&my_configure,
|
||||
@@ -594,8 +619,9 @@ static dstalk_ai_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期
|
||||
// 生命周期 / Lifecycle
|
||||
// ============================================================================
|
||||
// 插件初始化:查询 http 和 config 服务,注册 ai.deepseek 服务 / Plugin init: query http and config services, register ai.deepseek service.
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
@@ -624,6 +650,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:从内存安全擦除 API key,清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
@@ -644,12 +671,12 @@ static void on_shutdown()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "deepseek-ai",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "DeepSeek AI provider (OpenAI-compatible API)",
|
||||
/* .description = */ "DeepSeek AI provider (OpenAI-compatible API) / DeepSeek AI 提供者 (OpenAI 兼容 API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "http", "config", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
@@ -657,6 +684,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file file_io_plugin.cpp
|
||||
* @brief File I/O plugin: basic file read/write service.
|
||||
* 文件 I/O 插件:基本文件读写服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -6,20 +13,21 @@
|
||||
#include <cstring>
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
// ============================================================
|
||||
// 读取文件全部内容到主机分配的缓冲区,调用方须通过 host->free 释放 / Read the entire contents of a file into a host-allocated buffer. Caller must free via host->free.
|
||||
static int file_read(const char* path, char** content) {
|
||||
if (!path || !content) return -1;
|
||||
|
||||
FILE* fp = fopen(path, "rb");
|
||||
if (!fp) return -1;
|
||||
|
||||
// Get file size
|
||||
// 获取文件大小 / Get file size
|
||||
fseek(fp, 0, SEEK_END);
|
||||
long fsize = ftell(fp);
|
||||
fseek(fp, 0, SEEK_SET);
|
||||
@@ -29,7 +37,7 @@ static int file_read(const char* path, char** content) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Allocate buffer via host allocator (+1 for null terminator)
|
||||
// 通过主机分配器分配缓冲区(+1 用于空终止符) / Allocate buffer via host allocator (+1 for null terminator)
|
||||
char* buf = (char*)g_host->alloc((size_t)fsize + 1);
|
||||
if (!buf) {
|
||||
fclose(fp);
|
||||
@@ -49,6 +57,7 @@ static int file_read(const char* path, char** content) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 将字符串写入文件,覆盖已有内容 / Write a string to a file, overwriting any existing content.
|
||||
static int file_write(const char* path, const char* content) {
|
||||
if (!path || !content) return -1;
|
||||
|
||||
@@ -68,28 +77,31 @@ static dstalk_file_io_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// 插件初始化:保存主机指针并注册 file_io 服务 / Plugin init: store host pointer and register the file_io service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
return host->register_service("file_io", 1, &g_service);
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
|
||||
static void on_shutdown() {
|
||||
// nothing to clean up
|
||||
// 无需清理 / nothing to clean up
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"file-io", // name
|
||||
"1.0.0", // version
|
||||
"Basic file I/O service", // description
|
||||
"file-io", // name 名称
|
||||
"1.0.0", // version 版本
|
||||
"Basic file I/O service", // description 描述
|
||||
DSTALK_API_VERSION, // api_version
|
||||
{nullptr}, // dependencies (none)
|
||||
{nullptr}, // dependencies 依赖 (none)
|
||||
on_init, // on_init
|
||||
on_shutdown, // on_shutdown
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
/*
|
||||
* plugin-lsp — LSP (Language Server Protocol) 服务
|
||||
*
|
||||
* 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信。
|
||||
* 无外部服务依赖(不依赖 http/config 等其他插件)。
|
||||
* @file lsp_plugin.cpp
|
||||
* @brief LSP plugin: Language Server Protocol JSON-RPC client for diagnostics, hover, completion.
|
||||
* LSP 插件:Language Server Protocol JSON-RPC 客户端,用于诊断、悬停、补全。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-lsp — LSP (Language Server Protocol) 服务 / LSP (Language Server Protocol) service
|
||||
//
|
||||
// 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信 / Self-manages language server subprocess, communicates via JSON-RPC 2.0 over stdio.
|
||||
// 无外部服务依赖(不依赖 http/config 等其他插件) / No external service dependencies (does not depend on http/config or other plugins).
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -22,7 +27,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
// ============================================================================
|
||||
// 平台相关 — 子进程管理 (内嵌 subprocess::Process)
|
||||
// 平台相关 — 子进程管理 (内嵌 subprocess::Process) / Platform specific — subprocess management (embedded subprocess::Process)
|
||||
// ============================================================================
|
||||
|
||||
#ifdef _WIN32
|
||||
@@ -45,12 +50,12 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针
|
||||
// 全局指针 / Global pointers
|
||||
// ============================================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================================
|
||||
// 子进程封装 (内嵌 subprocess.hpp)
|
||||
// 子进程封装 (内嵌 subprocess.hpp) / Subprocess wrapper (embedded subprocess.hpp)
|
||||
// ============================================================================
|
||||
struct Process {
|
||||
#ifdef _WIN32
|
||||
@@ -64,6 +69,7 @@ struct Process {
|
||||
int stdout_fd = -1;
|
||||
#endif
|
||||
|
||||
// 从给定命令行启动子进程。为 stdin/stdout 设置管道 / Start a child process from the given command line. Sets up pipes for stdin/stdout.
|
||||
bool start(const char* cmd) {
|
||||
if (!cmd || !cmd[0]) return false;
|
||||
stop();
|
||||
@@ -169,6 +175,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 优雅终止子进程,回退到 SIGKILL/TerminateProcess / Gracefully terminate the child process, with fallback to SIGKILL/TerminateProcess.
|
||||
void stop() {
|
||||
#ifdef _WIN32
|
||||
if (hProcess != INVALID_HANDLE_VALUE) {
|
||||
@@ -198,6 +205,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 将数据字符串写入子进程 stdin 管道 / Write a data string to the child's stdin pipe.
|
||||
bool write(const std::string& data) {
|
||||
if (data.empty()) return true;
|
||||
#ifdef _WIN32
|
||||
@@ -219,6 +227,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 从子进程 stdout 管道读取一行(到并包括 '\n') / Read one line (up to and including '\n') from the child's stdout pipe.
|
||||
bool read_line(std::string& line) {
|
||||
line.clear();
|
||||
#ifdef _WIN32
|
||||
@@ -242,6 +251,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 从子进程 stdout 管道读取恰好 count 字节到 buf / Read exactly `count` bytes from the child's stdout pipe into `buf`.
|
||||
bool read_bytes(std::string& buf, int count) {
|
||||
if (count <= 0) { buf.clear(); return true; }
|
||||
#ifdef _WIN32
|
||||
@@ -274,7 +284,7 @@ struct Process {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// LSP 状态(静态单例)
|
||||
// LSP 状态(静态单例) / LSP state (static singleton)
|
||||
// ============================================================================
|
||||
struct LspState {
|
||||
Process proc;
|
||||
@@ -283,23 +293,24 @@ struct LspState {
|
||||
|
||||
std::atomic<int> next_id{1};
|
||||
|
||||
// 响应用于同步等待
|
||||
// 响应用于同步等待 / Responses for synchronous waiting
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv;
|
||||
std::unordered_map<int, std::string> pending_responses;
|
||||
|
||||
// 诊断缓存: URI -> JSON 字符串
|
||||
// 诊断缓存: URI -> JSON 字符串 / Diagnostics cache: URI -> JSON string
|
||||
std::unordered_map<std::string, std::string> diagnostics;
|
||||
|
||||
// 读取线程
|
||||
// 读取线程 / Reader thread
|
||||
std::thread reader_thread;
|
||||
};
|
||||
static LspState g_lsp;
|
||||
|
||||
// ============================================================================
|
||||
// 辅助函数
|
||||
// 辅助函数 / Helper functions
|
||||
// ============================================================================
|
||||
|
||||
// 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view.
|
||||
static std::string_view trim(std::string_view sv) {
|
||||
while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' ||
|
||||
sv.front() == '\r' || sv.front() == '\n'))
|
||||
@@ -310,6 +321,7 @@ static std::string_view trim(std::string_view sv) {
|
||||
return sv;
|
||||
}
|
||||
|
||||
// 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n).
|
||||
static std::string frame_message(const std::string& body) {
|
||||
std::string frame;
|
||||
frame.reserve(64 + body.size());
|
||||
@@ -320,6 +332,7 @@ static std::string frame_message(const std::string& body) {
|
||||
return frame;
|
||||
}
|
||||
|
||||
// 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure.
|
||||
static int parse_content_length(const std::string& line) {
|
||||
auto sv = trim(std::string_view(line));
|
||||
const char prefix[] = "Content-Length:";
|
||||
@@ -341,9 +354,10 @@ static int parse_content_length(const std::string& line) {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// JSON-RPC 消息发送
|
||||
// JSON-RPC 消息发送 / JSON-RPC message sending
|
||||
// ============================================================================
|
||||
|
||||
// 向 LSP 服务器发送 JSON-RPC 请求并返回分配的请求 id / Send a JSON-RPC request to the LSP server and return the assigned request id.
|
||||
static int send_request(const std::string& method, const json::object& params) {
|
||||
int id = g_lsp.next_id.fetch_add(1);
|
||||
|
||||
@@ -358,6 +372,7 @@ static int send_request(const std::string& method, const json::object& params) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 JSON-RPC 通知(无 id 字段,不期待响应) / Send a JSON-RPC notification to the LSP server (no id field, no response expected).
|
||||
static void send_notification(const std::string& method, const json::object& params) {
|
||||
json::object msg;
|
||||
msg["jsonrpc"] = "2.0";
|
||||
@@ -369,9 +384,12 @@ static void send_notification(const std::string& method, const json::object& par
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 消息处理
|
||||
// 消息处理 / Message handling
|
||||
// ============================================================================
|
||||
|
||||
// 分发接收到的 JSON-RPC 消息:将响应路由到待处理队列,
|
||||
// 处理 textDocument/publishDiagnostics 通知并存入诊断缓存 / Dispatch a received JSON-RPC message: route responses to pending queue,
|
||||
// handle textDocument/publishDiagnostics notifications into diagnostics cache.
|
||||
static void handle_message(const std::string& body) {
|
||||
try {
|
||||
json::value val;
|
||||
@@ -383,14 +401,14 @@ static void handle_message(const std::string& body) {
|
||||
catch (...) { return; }
|
||||
|
||||
if (msg.contains("id") && !msg.contains("method")) {
|
||||
// 响应 (有 id, 无 method)
|
||||
// 响应 (有 id, 无 method) / Response (has id, no method)
|
||||
int id = static_cast<int>(msg["id"].as_int64());
|
||||
std::lock_guard<std::mutex> lock(g_lsp.mutex);
|
||||
g_lsp.pending_responses[id] = body;
|
||||
g_lsp.cv.notify_all();
|
||||
|
||||
} else if (msg.contains("method") && !msg.contains("id")) {
|
||||
// 通知 (有 method, 无 id)
|
||||
// 通知 (有 method, 无 id) / Notification (has method, no id)
|
||||
std::string method;
|
||||
try { method = json::value_to<std::string>(msg["method"]); }
|
||||
catch (...) { return; }
|
||||
@@ -419,17 +437,18 @@ static void handle_message(const std::string& body) {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 读取线程主循环
|
||||
// 读取线程主循环 / Reader thread main loop
|
||||
// ============================================================================
|
||||
|
||||
// 读取线程主循环:解析 LSP header+body 帧并分发消息 / Main loop for the reader thread: parse LSP header+body frames and dispatch messages.
|
||||
static void reader_loop() {
|
||||
try {
|
||||
while (g_lsp.running) {
|
||||
int content_length = -1;
|
||||
bool pipe_ok = true;
|
||||
|
||||
// 状态机式读取 header 块:循环 read_line 直到读到空行
|
||||
// LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header
|
||||
// 状态机式读取 header 块:循环 read_line 直到读到空行 / State-machine header block read: loop read_line until empty line
|
||||
// LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header / LSP 3.17: header block ends with empty line (\r\n), allows other headers like Content-Type
|
||||
while (pipe_ok) {
|
||||
std::string line;
|
||||
if (!g_lsp.proc.read_line(line)) {
|
||||
@@ -437,18 +456,18 @@ static void reader_loop() {
|
||||
break;
|
||||
}
|
||||
|
||||
// header 块以空行结束
|
||||
// header 块以空行结束 / header block ends with empty line
|
||||
auto sv = trim(std::string_view(line));
|
||||
if (sv.empty()) break;
|
||||
|
||||
// 累积 Content-Length;遇到其他 header 不丢弃,继续读取下一行
|
||||
// 累积 Content-Length;遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line
|
||||
int len = parse_content_length(line);
|
||||
if (len >= 0) content_length = len;
|
||||
}
|
||||
|
||||
if (!pipe_ok) break;
|
||||
|
||||
// 空行前都没读到 Content-Length,协议错误——记日志并跳过这一帧
|
||||
// 空行前都没读到 Content-Length,协议错误——记日志并跳过这一帧 / Content-Length not read before empty line, protocol error — log and skip this frame
|
||||
if (content_length < 0) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] Invalid LSP frame: missing Content-Length header");
|
||||
continue;
|
||||
@@ -471,38 +490,39 @@ static void reader_loop() {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LSP 服务 vtable 实现 (定义在 vtable 变量之前)
|
||||
// LSP 服务 vtable 实现 (定义在 vtable 变量之前) / LSP service vtable implementation (defined before vtable variable)
|
||||
// ============================================================================
|
||||
|
||||
static void g_lsp_impl_stop();
|
||||
static void g_lsp_impl_stop_nolock();
|
||||
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock);
|
||||
|
||||
// 启动 LSP 服务器进程,发送 initialize/initialized 握手,启动读取线程 / Start the LSP server process, send initialize/initialized handshake, start reader thread.
|
||||
static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
if (!server_cmd || !server_cmd[0]) return -1;
|
||||
|
||||
try {
|
||||
// 如果已在运行, 先停止
|
||||
// 如果已在运行, 先停止 / If already running, stop first
|
||||
if (g_lsp.running) {
|
||||
g_lsp_impl_stop();
|
||||
}
|
||||
|
||||
g_lsp.language = language ? language : "";
|
||||
|
||||
// 启动进程
|
||||
// 启动进程 / Start process
|
||||
if (!g_lsp.proc.start(server_cmd)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] failed to start: %s", server_cmd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 重置 ID 计数器
|
||||
// 重置 ID 计数器 / Reset ID counter
|
||||
g_lsp.next_id = 1;
|
||||
|
||||
// 启动读取线程
|
||||
// 启动读取线程 / Start reader thread
|
||||
g_lsp.running = true;
|
||||
g_lsp.reader_thread = std::thread(reader_loop);
|
||||
|
||||
// 构建 initialize 参数
|
||||
// 构建 initialize 参数 / Build initialize params
|
||||
json::object text_doc_caps;
|
||||
{
|
||||
json::object hover;
|
||||
@@ -526,10 +546,10 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
init_params["rootUri"] = nullptr;
|
||||
init_params["capabilities"] = capabilities;
|
||||
|
||||
// 发送 initialize 请求
|
||||
// 发送 initialize 请求 / Send initialize request
|
||||
int init_id = send_request("initialize", init_params);
|
||||
|
||||
// 等待 initialize 响应 (最多 10 秒)
|
||||
// 等待 initialize 响应 (最多 10 秒) / Wait for initialize response (max 10 seconds)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_lsp.mutex);
|
||||
bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [init_id]() {
|
||||
@@ -544,7 +564,7 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
g_lsp.pending_responses.erase(init_id);
|
||||
}
|
||||
|
||||
// 发送 initialized 通知
|
||||
// 发送 initialized 通知 / Send initialized notification
|
||||
send_notification("initialized", json::object{});
|
||||
|
||||
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server started: %s", server_cmd);
|
||||
@@ -558,14 +578,15 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
}
|
||||
}
|
||||
|
||||
// 停止 LSP 服务器:发送 shutdown 请求,发送 exit 通知,停止进程和线程 / Stop the LSP server: send shutdown request, send exit notification, stop process & thread.
|
||||
static void g_lsp_impl_stop_nolock() {
|
||||
try {
|
||||
if (!g_lsp.running) return;
|
||||
|
||||
// 发送 shutdown 请求
|
||||
// 发送 shutdown 请求 / Send shutdown request
|
||||
int shutdown_id = send_request("shutdown", json::object{});
|
||||
|
||||
// 等待 shutdown 响应 (最多 2 秒)
|
||||
// 等待 shutdown 响应 (最多 2 秒) / Wait for shutdown response (max 2 seconds)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_lsp.mutex);
|
||||
g_lsp.cv.wait_for(lock, std::chrono::seconds(2), [shutdown_id]() {
|
||||
@@ -574,10 +595,10 @@ static void g_lsp_impl_stop_nolock() {
|
||||
g_lsp.pending_responses.clear();
|
||||
}
|
||||
|
||||
// 发送 exit 通知
|
||||
// 发送 exit 通知 / Send exit notification
|
||||
send_notification("exit", json::object{});
|
||||
|
||||
// 停止读取线程
|
||||
// 停止读取线程 / Stop reader thread
|
||||
g_lsp.running = false;
|
||||
g_lsp.proc.stop();
|
||||
|
||||
@@ -593,15 +614,18 @@ static void g_lsp_impl_stop_nolock() {
|
||||
}
|
||||
}
|
||||
|
||||
// 公开 stop:无锁获取(委托给 g_lsp_impl_stop_nolock) / Public stop: acquires no lock (delegates to g_lsp_impl_stop_nolock).
|
||||
static void g_lsp_impl_stop() {
|
||||
g_lsp_impl_stop_nolock();
|
||||
}
|
||||
|
||||
// Stop 辅助函数:在调用 g_lsp_impl_stop_nolock 前解锁给定的 unique_lock / Stop helper: unlocks the given unique_lock before calling g_lsp_impl_stop_nolock.
|
||||
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock) {
|
||||
lock.unlock();
|
||||
g_lsp_impl_stop_nolock();
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 textDocument/didOpen 通知 / Send a textDocument/didOpen notification to the LSP server.
|
||||
static int g_lsp_impl_open_document(const char* uri, const char* content,
|
||||
const char* lang_id) {
|
||||
if (!g_lsp.running) return -1;
|
||||
@@ -628,6 +652,7 @@ static int g_lsp_impl_open_document(const char* uri, const char* content,
|
||||
}
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 textDocument/didClose 通知 / Send a textDocument/didClose notification to the LSP server.
|
||||
static int g_lsp_impl_close_document(const char* uri) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri) return -1;
|
||||
@@ -650,6 +675,7 @@ static int g_lsp_impl_close_document(const char* uri) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回给定文档 URI 的缓存诊断 JSON / Return the cached diagnostics JSON for the given document URI.
|
||||
static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -674,6 +700,7 @@ static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 textDocument/hover 请求并以 JSON 返回悬停结果 / Send a textDocument/hover request and return the hover result as JSON.
|
||||
static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -727,6 +754,7 @@ static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 textDocument/completion 请求并以 JSON 返回补全列表 / Send a textDocument/completion request and return the completion list as JSON.
|
||||
static int g_lsp_impl_get_completion(const char* uri, int line, int col, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -781,7 +809,7 @@ static int g_lsp_impl_get_completion(const char* uri, int line, int col, char**
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
|
||||
static dstalk_lsp_service_t g_service_vtable = {
|
||||
@@ -795,15 +823,17 @@ static dstalk_lsp_service_t g_service_vtable = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期回调
|
||||
// 生命周期回调 / Lifecycle callbacks
|
||||
// ============================================================================
|
||||
|
||||
// 插件初始化:保存主机指针并注册 lsp 服务 / Plugin init: store host pointer and register the lsp service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] initializing LSP service plugin");
|
||||
return host->register_service("lsp", 1, &g_service_vtable);
|
||||
}
|
||||
|
||||
// 插件关闭:如果运行中则停止 LSP 服务器,清空主机指针 / Plugin shutdown: stop LSP server if running, null out host pointer.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
if (g_lsp.running) {
|
||||
@@ -821,20 +851,21 @@ static void on_shutdown() {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "lsp",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "Language Server Protocol client (subprocess manager)",
|
||||
/* .description = */ "Language Server Protocol client (subprocess manager) / Language Server Protocol 客户端(子进程管理器)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { NULL }, // 无依赖,自行管理子进程
|
||||
/* .dependencies = */ { NULL }, // 无依赖,自行管理子进程 / No dependencies, self-manages subprocess
|
||||
/* .on_init = */ on_init,
|
||||
/* .on_shutdown = */ on_shutdown,
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
// MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
|
||||
/*
|
||||
* @file network_plugin.cpp
|
||||
* @brief Network plugin: HTTP/HTTPS POST and streaming via Boost.Beast + OpenSSL.
|
||||
* 网络插件:基于 Boost.Beast + OpenSSL 的 HTTP/HTTPS POST 和流式传输。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// MSVC 14.16 (VS 2017) 不提供 std::to_address (C++20) / MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
|
||||
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
@@ -29,21 +36,22 @@ namespace ssl = boost::asio::ssl;
|
||||
using tcp = asio::ip::tcp;
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
static dstalk_config_service_t* g_config_svc = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Minimal JSON header parser
|
||||
// Parses {"key1":"value1","key2":"value2"} into unordered_map
|
||||
// 极简 JSON 头解析器 / Minimal JSON header parser
|
||||
// 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map
|
||||
// ============================================================
|
||||
// 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map.
|
||||
static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
if (!json || !*json) return headers;
|
||||
|
||||
std::string s(json);
|
||||
// Very simple state-machine parser for flat string-key/value objects
|
||||
// 极简状态机解析器,处理扁平的字符串键值对象 / Very simple state-machine parser for flat string-key/value objects
|
||||
enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE;
|
||||
std::string current_key;
|
||||
std::string current_value;
|
||||
@@ -64,7 +72,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
|
||||
break;
|
||||
case IN_VALUE:
|
||||
if (c == '"') {
|
||||
// Read until closing quote
|
||||
// 读取到闭合引号 / Read until closing quote
|
||||
++i;
|
||||
while (i < s.size() && s[i] != '"') {
|
||||
if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; }
|
||||
@@ -81,7 +89,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// HTTP Client implementation (adapted from dstalk-core HttpClient)
|
||||
// HTTP 客户端实现(改编自 dstalk-core HttpClient) / HTTP Client implementation (adapted from dstalk-core HttpClient)
|
||||
// ============================================================
|
||||
struct HttpClientCtx {
|
||||
asio::io_context ioc;
|
||||
@@ -91,15 +99,22 @@ struct HttpClientCtx {
|
||||
|
||||
HttpClientCtx() {
|
||||
ssl_ctx.set_default_verify_paths();
|
||||
// Enable peer certificate verification (CVSS 7.4 fix).
|
||||
// set_default_verify_paths() loads system CA bundle; without verify_peer
|
||||
// 启用对等证书验证 (CVSS 7.4 修复) / Enable peer certificate verification (CVSS 7.4 fix).
|
||||
// set_default_verify_paths() 加载系统 CA 包;没有 verify_peer
|
||||
// CA 存储不会被查询——任何证书(自签名/过期)都将被接受 / set_default_verify_paths() loads system CA bundle; without verify_peer
|
||||
// the CA store is never consulted — any cert (self-signed/expired) is accepted.
|
||||
// TODO: Windows: set_default_verify_paths() may not locate system CAs;
|
||||
// TODO: Windows: set_default_verify_paths() 可能无法定位系统 CA;
|
||||
// 如果验证失败,设置 SSL_CERT_FILE 环境变量或捆绑 cacert.pem / Windows: set_default_verify_paths() may not locate system CAs;
|
||||
// if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem.
|
||||
ssl_ctx.set_verify_mode(ssl::verify_peer);
|
||||
}
|
||||
};
|
||||
|
||||
// 核心 HTTP/HTTPS POST,支持可选 SSE 流式传输。执行 DNS 解析、
|
||||
// TLS 握手(含 SNI 和主机名验证),然后发送请求。
|
||||
// 如果 cb 非空,响应体将逐行解析用于流式传输 / Core HTTP/HTTPS POST with optional SSE streaming. Performs DNS resolve,
|
||||
// TLS handshake with SNI and hostname verification, then sends the request.
|
||||
// If `cb` is non-null, response body is parsed line-by-line for streaming.
|
||||
static int do_post_stream(
|
||||
const char* host,
|
||||
const char* port,
|
||||
@@ -117,11 +132,11 @@ static int do_post_stream(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Initialize output
|
||||
// 初始化输出 / Initialize output
|
||||
*response_body = nullptr;
|
||||
*status_code = -1;
|
||||
|
||||
// Build C++ lambda from C callback
|
||||
// 从 C 回调构建 C++ lambda / Build C++ lambda from C callback
|
||||
std::function<bool(const std::string&)> on_line;
|
||||
if (cb) {
|
||||
on_line = [cb, userdata](const std::string& line) -> bool {
|
||||
@@ -131,7 +146,7 @@ static int do_post_stream(
|
||||
|
||||
HttpClientCtx ctx;
|
||||
|
||||
// Read timeouts from config if available
|
||||
// 从配置读取超时设置 / Read timeouts from config if available
|
||||
if (g_config_svc) {
|
||||
const char* ct = g_config_svc->get("http.connect_timeout");
|
||||
const char* rt = g_config_svc->get("http.request_timeout");
|
||||
@@ -147,7 +162,9 @@ static int do_post_stream(
|
||||
try {
|
||||
tcp::resolver resolver(ctx.ioc);
|
||||
|
||||
// DNS resolve with 10-second timeout. Boost.Asio's synchronous
|
||||
// DNS 解析,10 秒超时。Boost.Asio 的同步 resolve()
|
||||
// 在内部运行 io_context,因此定时器的 async_wait 回调在 resolve() 期间执行,
|
||||
// 并在超时触发时调用 resolver.cancel() / DNS resolve with 10-second timeout. Boost.Asio's synchronous
|
||||
// resolve() runs the io_context internally, so the timer's async_wait
|
||||
// callback executes during resolve() and calls resolver.cancel() when
|
||||
// the deadline fires.
|
||||
@@ -172,7 +189,7 @@ static int do_post_stream(
|
||||
beast::ssl_stream<beast::tcp_stream> stream(ctx.ioc, ctx.ssl_ctx);
|
||||
beast::flat_buffer buffer;
|
||||
|
||||
// SNI hostname
|
||||
// SNI 主机名 / SNI hostname
|
||||
if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: SNI hostname set failed for %s", host);
|
||||
@@ -180,7 +197,9 @@ static int do_post_stream(
|
||||
goto done;
|
||||
}
|
||||
|
||||
// Hostname verification: require server certificate CN/SAN to match
|
||||
// 主机名验证:要求服务器证书 CN/SAN 匹配 'host'。
|
||||
// 与 ssl::verify_peer 协同工作——没有它的话,
|
||||
// 使用不同主机名的有效 CA 签名证书进行 MITM 攻击仍可通过 / Hostname verification: require server certificate CN/SAN to match
|
||||
// 'host'. This works in conjunction with ssl::verify_peer on the
|
||||
// context — without it MITM with a valid CA-signed cert for a
|
||||
// different hostname would still pass.
|
||||
@@ -191,19 +210,19 @@ static int do_post_stream(
|
||||
goto done;
|
||||
}
|
||||
|
||||
// Connect
|
||||
// 连接 / Connect
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
beast::get_lowest_layer(stream).connect(endpoints);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// SSL handshake
|
||||
// SSL 握手 / SSL handshake
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
stream.handshake(ssl::stream_base::client);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// Build HTTP POST request
|
||||
// 构建 HTTP POST 请求 / Build HTTP POST request
|
||||
http::request<http::string_body> req{http::verb::post, target, 11};
|
||||
req.set(http::field::host, host);
|
||||
req.set(http::field::user_agent, "dstalk/0.1");
|
||||
@@ -211,19 +230,19 @@ static int do_post_stream(
|
||||
req.body() = body;
|
||||
req.prepare_payload();
|
||||
|
||||
// Add extra headers from JSON
|
||||
// 从 JSON 添加额外的头 / Add extra headers from JSON
|
||||
auto extra_headers = parse_headers_json(headers_json);
|
||||
for (const auto& h : extra_headers) {
|
||||
req.set(h.first, h.second);
|
||||
}
|
||||
|
||||
// Send
|
||||
// 发送 / Send
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::write(stream, req);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// Read response
|
||||
// 读取响应 / Read response
|
||||
http::response_parser<http::string_body> parser;
|
||||
parser.body_limit(16 * 1024 * 1024);
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
@@ -310,8 +329,9 @@ done:
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
// ============================================================
|
||||
// 同步 HTTP POST,返回完整响应体 / Synchronous HTTP POST returning the complete response body.
|
||||
static int http_post_json(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
@@ -322,6 +342,7 @@ static int http_post_json(
|
||||
nullptr, nullptr, response_body, status_code);
|
||||
}
|
||||
|
||||
// HTTP POST 带 SSE 流式传输:响应行到达时通过 cb 回调传递 / HTTP POST with SSE streaming: response lines are delivered to `cb` as they arrive.
|
||||
static int http_post_stream(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
@@ -339,32 +360,35 @@ static dstalk_http_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// 插件初始化:保存主机指针,查询 config 服务,注册 http 服务 / Plugin init: store host pointer, query config service, register http service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
|
||||
// Query config service (declared dependency)
|
||||
// 查询 config 服务(声明的依赖) / Query config service (declared dependency)
|
||||
g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1);
|
||||
|
||||
return host->register_service("http", 1, &g_service);
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
|
||||
static void on_shutdown() {
|
||||
// nothing to clean up
|
||||
// 无需清理 / nothing to clean up
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"http", // name
|
||||
"1.0.0", // version
|
||||
"HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description
|
||||
"http", // name 名称
|
||||
"1.0.0", // version 版本
|
||||
"HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description 描述
|
||||
DSTALK_API_VERSION, // api_version
|
||||
{"config", nullptr}, // dependencies
|
||||
{"config", nullptr}, // dependencies 依赖
|
||||
on_init, // on_init
|
||||
on_shutdown, // on_shutdown
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-session: 会话管理服务插件
|
||||
// 提供 dstalk_session_service_t vtable 实现
|
||||
// 依赖: file_io (save/load 需要文件操作)
|
||||
/*
|
||||
* @file session_plugin.cpp
|
||||
* @brief Session plugin: conversation message history management with save/load.
|
||||
* 会话插件:对话消息历史管理,支持保存/加载。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-session: 会话管理服务插件 / Session management service plugin
|
||||
// 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation
|
||||
// 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -24,14 +31,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 内部 C++ 数据结构
|
||||
// 内部 C++ 数据结构 / Internal C++ data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
// 内部消息结构(C++ 易用,外部暴露 C struct)
|
||||
// 内部消息结构(C++ 易用,外部暴露 C struct) / Internal message struct (C++ friendly, externally exposed as C struct)
|
||||
struct InternalMessage {
|
||||
std::string role;
|
||||
std::string content;
|
||||
@@ -39,21 +46,24 @@ struct InternalMessage {
|
||||
std::string tool_calls_json;
|
||||
};
|
||||
|
||||
// 会话历史 + 缓存 —— W14.3: mutex 保护读写
|
||||
// 会话历史 + 缓存 —— W14.3: mutex 保护读写 / Session history + cache — W14.3: mutex protects read/write
|
||||
static std::vector<InternalMessage> g_history;
|
||||
static std::vector<dstalk_message_t> g_cached_history;
|
||||
static std::mutex g_session_mutex;
|
||||
|
||||
// ============================================================
|
||||
// Token 计数工具(内联,避免硬依赖 context 头文件)
|
||||
// Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers)
|
||||
// ============================================================
|
||||
|
||||
// 如果字节是 ASCII (0x00–0x7F) 则返回 true / Returns true if the byte is ASCII (0x00–0x7F).
|
||||
static bool is_ascii(unsigned char c) { return c < 0x80; }
|
||||
|
||||
// 启发式判断:如果字节起始一个 UTF-8 CJK 统一表意文字 (0xE4–0xE9) 则返回 true / Heuristic: returns true if the byte starts a CJK Unified Ideograph in UTF-8 (0xE4–0xE9).
|
||||
static bool starts_cjk(unsigned char c) {
|
||||
return c >= 0xE4 && c <= 0xE9;
|
||||
}
|
||||
|
||||
// 使用启发式 UTF-8 字节计数估算单条消息的 token 数 / Estimate token count for a single message using heuristic UTF-8 byte counting.
|
||||
static size_t count_tokens_one(const std::string& text) {
|
||||
size_t ascii_chars = 0;
|
||||
size_t chinese_chars = 0;
|
||||
@@ -85,9 +95,10 @@ static size_t count_tokens_one(const std::string& text) {
|
||||
}
|
||||
|
||||
size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3);
|
||||
return content_tokens + 4; // +4 per message overhead
|
||||
return content_tokens + 4; // +4 每条消息开销 / +4 per message overhead
|
||||
}
|
||||
|
||||
// 估算所有消息的总 token 数 / Estimate total token count across all messages.
|
||||
static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
|
||||
size_t total = 0;
|
||||
for (const auto& m : msgs) {
|
||||
@@ -97,13 +108,15 @@ static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex)
|
||||
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex) / Helper: rebuild C cached array (caller must hold g_session_mutex)
|
||||
// ============================================================
|
||||
|
||||
// 从内部消息 vector 重建 C 兼容的缓存历史数组。调用方必须持有 g_session_mutex / Rebuild the C-compatible cached history array from the internal message vector.
|
||||
// Caller must hold g_session_mutex.
|
||||
static void rebuild_cached_history_locked() {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
|
||||
// 释放旧的字符串
|
||||
// 释放旧的字符串 / Free old strings
|
||||
for (auto& m : g_cached_history) {
|
||||
if (m.role) { host->free(const_cast<char*>(m.role)); }
|
||||
if (m.content) { host->free(const_cast<char*>(m.content)); }
|
||||
@@ -112,7 +125,7 @@ static void rebuild_cached_history_locked() {
|
||||
}
|
||||
g_cached_history.clear();
|
||||
|
||||
// 重建
|
||||
// 重建 / Rebuild
|
||||
g_cached_history.reserve(g_history.size());
|
||||
for (const auto& im : g_history) {
|
||||
dstalk_message_t cm;
|
||||
@@ -125,9 +138,10 @@ static void rebuild_cached_history_locked() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Session 服务 vtable 实现 (W14.3: try/catch + mutex)
|
||||
// Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
// 向对话历史追加一条消息 / Append a message to the conversation history.
|
||||
static void session_add(const dstalk_message_t* msg) {
|
||||
try {
|
||||
if (!msg) return;
|
||||
@@ -148,11 +162,13 @@ static void session_add(const dstalk_message_t* msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// 清空对话历史中的所有消息 / Clear all messages from the conversation history.
|
||||
static void session_clear() {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
g_history.clear();
|
||||
}
|
||||
|
||||
// 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`.
|
||||
static int session_save(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
@@ -187,6 +203,7 @@ static int session_save(const char* path) {
|
||||
}
|
||||
}
|
||||
|
||||
// 从 JSON 行文件中加载对话历史,替换当前历史 / Load conversation history from a JSON lines file at `path`, replacing current history.
|
||||
static int session_load(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
@@ -246,6 +263,7 @@ static int session_load(const char* path) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回指向缓存 C 消息数组的指针,并将 *out_count 设置为数组大小 / Return a pointer to the cached C-array of messages and set *out_count to its size.
|
||||
static const dstalk_message_t* session_history(int* out_count) {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
@@ -265,6 +283,7 @@ static const dstalk_message_t* session_history(int* out_count) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history.
|
||||
static int session_token_count() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
@@ -290,11 +309,12 @@ static dstalk_session_service_t g_session_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// W20.6: 默认会话保存路径(平台标准目录)
|
||||
// W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory)
|
||||
// ============================================================
|
||||
|
||||
// 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed.
|
||||
static std::string get_default_session_path() {
|
||||
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录
|
||||
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure
|
||||
static std::string cached_path = []() -> std::string {
|
||||
#ifdef _WIN32
|
||||
char* buf = nullptr;
|
||||
@@ -323,14 +343,17 @@ static std::string get_default_session_path() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:保存主机指针,查询 file_io 依赖,注册 session 服务,
|
||||
// 并从默认路径自动加载已有会话 / Plugin init: store host pointer, query file_io dependency, register session service,
|
||||
// and auto-load any existing session from the default path.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found");
|
||||
@@ -338,11 +361,11 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 注册自身服务
|
||||
// 注册自身服务 / Register own service
|
||||
int ret = host->register_service("session", 1, &g_session_service);
|
||||
if (ret != 0) return ret;
|
||||
|
||||
// W20.6: 从默认路径恢复会话(文件不存在则静默失败)
|
||||
// W20.6: 从默认路径恢复会话(文件不存在则静默失败) / Restore session from default path (silent fail if file missing)
|
||||
session_load(get_default_session_path().c_str());
|
||||
|
||||
return 0;
|
||||
@@ -357,10 +380,13 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:自动保存会话到默认路径,失败时回退到当前目录,
|
||||
// 然后释放缓存历史和清空状态 / Plugin shutdown: auto-save session to default path, fallback to current dir on failure,
|
||||
// then release cached history and clear state.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
// W20.6: 清空前自动保存到默认路径
|
||||
// W21.4: 失败告警 + 当前目录 fallback
|
||||
// W20.6: 清空前自动保存到默认路径 / Auto-save to default path before clearing
|
||||
// W21.4: 失败告警 + 当前目录 fallback / Failure warning + current dir fallback
|
||||
int ret = session_save(get_default_session_path().c_str());
|
||||
if (ret != 0) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -389,7 +415,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"session",
|
||||
"1.0.0",
|
||||
"Session management plugin with save/load support",
|
||||
"Session management plugin with save/load support / 支持保存/加载的会话管理插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -397,6 +423,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-tools: 工具注册服务插件
|
||||
// 提供 dstalk_tools_service_t vtable 实现
|
||||
// 依赖: file_io (内置 file_read / file_write 工具)
|
||||
/*
|
||||
* @file tools_plugin.cpp
|
||||
* @brief Tools plugin: tool registration, schema management, and execution registry.
|
||||
* 工具插件:工具注册、schema 管理和执行注册表。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-tools: 工具注册服务插件 / Tool registration service plugin
|
||||
// 提供 dstalk_tools_service_t vtable 实现 / Provides dstalk_tools_service_t vtable implementation
|
||||
// 依赖: file_io (内置 file_read / file_write 工具) / Depends on: file_io (built-in file_read / file_write tools)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -20,21 +27,22 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 路径安全校验 (W14.3: 防止路径遍历攻击)
|
||||
// 路径安全校验 (W14.3: 防止路径遍历攻击) / Path safety validation (W14.3: prevent path traversal attacks)
|
||||
// ============================================================
|
||||
|
||||
// 验证文件路径是否安全(无绝对路径、无 ".." 遍历、非空) / Validate that a file path is safe (no absolute paths, no ".." traversal, no empty).
|
||||
static bool is_safe_path(const std::string& path) {
|
||||
// 拒绝空路径
|
||||
// 拒绝空路径 / Reject empty path
|
||||
if (path.empty()) return false;
|
||||
|
||||
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':')
|
||||
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':') / Reject absolute paths: Unix '/' prefix or Windows drive letter (second char ':')
|
||||
if (path[0] == '/' || path[0] == '\\') return false;
|
||||
if (path.size() >= 2 && path[1] == ':') return false;
|
||||
|
||||
// 拒绝含 ".." 段的目录遍历
|
||||
// 拒绝含 ".." 段的目录遍历 / Reject directory traversal with ".." segments
|
||||
if (path.find("..") != std::string::npos) return false;
|
||||
|
||||
// lexical_normal 消解相对组件后再次校验
|
||||
// lexical_normal 消解相对组件后再次校验 / Re-validate after resolving relative components with lexical_normal
|
||||
std::string norm = std::filesystem::path(path).lexically_normal().string();
|
||||
if (norm.empty()) return false;
|
||||
if (norm[0] == '/' || norm[0] == '\\') return false;
|
||||
@@ -45,10 +53,10 @@ static bool is_safe_path(const std::string& path) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 内部数据结构
|
||||
// 内部数据结构 / Internal data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
@@ -59,14 +67,15 @@ struct ToolDef {
|
||||
dstalk_tool_handler_fn handler;
|
||||
};
|
||||
|
||||
// W14.3: g_tools 使用 mutex 保护读写
|
||||
// W14.3: g_tools 使用 mutex 保护读写 / g_tools uses mutex to protect read/write
|
||||
static std::vector<ToolDef> g_tools;
|
||||
static std::mutex g_tools_mutex;
|
||||
|
||||
// ============================================================
|
||||
// 内置工具: file_read, file_write
|
||||
// 内置工具: file_read, file_write / Built-in tools: file_read, file_write
|
||||
// ============================================================
|
||||
|
||||
// 内置工具处理器:读取文件并以 JSON 字符串返回内容 / Built-in tool handler: read a file and return its contents as a JSON string.
|
||||
static char* builtin_file_read(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
@@ -83,7 +92,7 @@ static char* builtin_file_read(const char* args_json) {
|
||||
}
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
|
||||
// W14.3: 路径遍历防护
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
@@ -110,6 +119,7 @@ static char* builtin_file_read(const char* args_json) {
|
||||
}
|
||||
}
|
||||
|
||||
// 内置工具处理器:将内容写入文件,返回成功/错误 JSON 对象 / Built-in tool handler: write content to a file, returning a success/error JSON object.
|
||||
static char* builtin_file_write(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
@@ -132,7 +142,7 @@ static char* builtin_file_write(const char* args_json) {
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
std::string content = json::value_to<std::string>(*content_j);
|
||||
|
||||
// W14.3: 路径遍历防护
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
@@ -155,18 +165,19 @@ static char* builtin_file_write(const char* args_json) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex)
|
||||
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex) / Tools service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
static void tools_unregister_tool(const char* name);
|
||||
|
||||
// 注册命名工具及其描述、JSON Schema 参数和处理函数 / Register a named tool with its description, JSON Schema parameters, and handler function.
|
||||
static int tools_register_tool(const char* name, const char* desc,
|
||||
const char* params_schema,
|
||||
dstalk_tool_handler_fn handler) {
|
||||
try {
|
||||
if (!name || !handler) return -1;
|
||||
|
||||
// 如果已存在同名工具,先注销
|
||||
// 如果已存在同名工具,先注销 / If a tool with the same name exists, unregister first
|
||||
tools_unregister_tool(name);
|
||||
|
||||
ToolDef td;
|
||||
@@ -189,6 +200,7 @@ static int tools_register_tool(const char* name, const char* desc,
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称注销之前注册的工具 / Unregister a previously registered tool by name.
|
||||
static void tools_unregister_tool(const char* name) {
|
||||
try {
|
||||
if (!name) return;
|
||||
@@ -207,6 +219,7 @@ static void tools_unregister_tool(const char* name) {
|
||||
}
|
||||
}
|
||||
|
||||
// 将所有已注册工具序列化为 OpenAI function-calling 格式的 JSON 数组 / Serialize all registered tools into a JSON array in OpenAI function-calling format.
|
||||
static char* tools_get_tools_json() {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -249,6 +262,7 @@ static char* tools_get_tools_json() {
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称查找工具并分派执行到注册的处理器 / Look up a tool by name and dispatch execution to its registered handler.
|
||||
static char* tools_execute(const char* name, const char* args_json) {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -298,14 +312,15 @@ static dstalk_tools_service_t g_tools_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:查询 file_io 依赖,注册内置文件工具,注册 tools 服务 / Plugin init: query file_io dependency, register built-in file tools, register tools service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found");
|
||||
@@ -313,7 +328,7 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 向自身注册内置工具
|
||||
// 向自身注册内置工具 / Register built-in tools with self
|
||||
tools_register_tool(
|
||||
"file_read",
|
||||
"Read the contents of a file at the given path",
|
||||
@@ -340,6 +355,7 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:清空所有已注册工具并清空服务指针 / Plugin shutdown: clear all registered tools and null out service pointers.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
@@ -358,7 +374,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"tools",
|
||||
"1.0.0",
|
||||
"Tool registration and execution plugin with built-in file tools",
|
||||
"Tool registration and execution plugin with built-in file tools / 内置文件工具的工具注册和执行插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -366,6 +382,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
BIN
scripts/__pycache__/check_agents_metadata.cpython-313.pyc
Normal file
BIN
scripts/__pycache__/check_agents_metadata.cpython-313.pyc
Normal file
Binary file not shown.
@@ -1,8 +1,12 @@
|
||||
// ============================================================================
|
||||
// anthropic_plugin_test.cpp — Anthropic AI 插件单元测试
|
||||
// W21.6 (qa-wang): 覆盖 SSE 解析 / JSON 请求构建 / URL 解析 / 安全擦除
|
||||
// 通过 #include plugin source 访问 file-scope static 函数
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file anthropic_plugin_test.cpp
|
||||
* @brief Anthropic AI plugin unit tests: SSE parsing (parse_sse_data edge cases),
|
||||
* request building (build_request_json), header construction, URL parsing
|
||||
* (extract_host_port), secure_zero, and null-safety for free_result/configure.
|
||||
* Anthropic AI 插件单元测试:SSE 解析(parse_sse_data 边界情况)、请求构建(build_request_json)、
|
||||
* 头部构造、URL 解析(extract_host_port)、secure_zero、free_result/configure 空指针安全。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
#define BOOST_JSON_HEADER_ONLY
|
||||
#define BOOST_ALL_NO_LIB
|
||||
#include "../plugins/anthropic/src/anthropic_plugin.cpp"
|
||||
@@ -12,6 +16,7 @@
|
||||
#include <string>
|
||||
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion macro: increments g_failures counter on failure
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -22,6 +27,8 @@ static int g_failures = 0;
|
||||
} while (0)
|
||||
|
||||
// Test helper: populate g_cfg for build functions
|
||||
// Test helper: populate g_cfg with valid anthropic defaults before build_* tests
|
||||
// 测试辅助函数:为 build_* 测试填充 g_cfg 的有效 anthropic 默认值
|
||||
static void setup_config() {
|
||||
g_cfg.provider = "anthropic";
|
||||
g_cfg.base_url = "https://api.anthropic.com";
|
||||
@@ -31,10 +38,18 @@ static void setup_config() {
|
||||
g_cfg.temperature = 0.7;
|
||||
}
|
||||
|
||||
// Anthropic 插件测试 (W21.6):parse_sse_data 畸形/无效 JSON、content_block_delta 文本提取、
|
||||
// message_stop/忽略类型、深层/边界结构、build_request_json 基础+边界、build_headers_json、
|
||||
// extract_host_port、secure_zero、my_free_result 空指针安全、my_configure 空指针安全。
|
||||
// Anthropic plugin tests (W21.6): parse_sse_data for malformed/invalid JSON,
|
||||
// content_block_delta text extraction, message_stop/ignored types, deep/edge structures,
|
||||
// build_request_json basics+edges, build_headers_json, extract_host_port,
|
||||
// secure_zero, my_free_result null-safety, and my_configure null-safety.
|
||||
int main()
|
||||
{
|
||||
// ================================================================
|
||||
// Test Block 1: parse_sse_data — invalid/malformed inputs
|
||||
// 测试块 1:parse_sse_data — 无效/畸形输入
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 1: parse_sse_data invalid/malformed ---\n";
|
||||
|
||||
@@ -69,14 +84,14 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Malformed JSON: unclosed brace
|
||||
// Malformed JSON: unclosed brace / 畸形 JSON:未闭合的花括号
|
||||
std::string token;
|
||||
bool ret = parse_sse_data("{\"type\":\"ping\"", token, nullptr);
|
||||
CHECK(!ret, "T1.6: malformed JSON (unclosed brace) returns false (no crash)");
|
||||
}
|
||||
|
||||
{
|
||||
// Random garbage bytes
|
||||
// Random garbage bytes / 随机垃圾字节
|
||||
std::string token;
|
||||
bool ret = parse_sse_data("\x00\x01\xFF\xFE", token, nullptr);
|
||||
CHECK(!ret, "T1.7: binary garbage returns false (no crash)");
|
||||
@@ -84,6 +99,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 2: parse_sse_data — content_block_delta
|
||||
// 测试块 2:parse_sse_data — content_block_delta
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 2: parse_sse_data content_block_delta ---\n";
|
||||
|
||||
@@ -146,6 +162,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 3: parse_sse_data — message_stop / ignored types
|
||||
// 测试块 3:parse_sse_data — message_stop / 忽略的类型
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 3: parse_sse_data message_stop / ignored types ---\n";
|
||||
|
||||
@@ -194,11 +211,12 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 4: parse_sse_data — deeply nested / edge structures
|
||||
// 测试块 4:parse_sse_data — 深层嵌套 / 边界结构
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 4: parse_sse_data deep/edge structures ---\n";
|
||||
|
||||
{
|
||||
// Unrecognized event type should just be ignored
|
||||
// Unrecognized event type should just be ignored / 未识别的事件类型应被忽略
|
||||
std::string token;
|
||||
const char* json = "{\"type\":\"some_unknown_future_type\"}";
|
||||
bool ret = parse_sse_data(json, token, nullptr);
|
||||
@@ -206,7 +224,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// text_delta with unicode content (Japanese)
|
||||
// text_delta with unicode content (Japanese) / 含 unicode 内容的 text_delta(日语)
|
||||
std::string token;
|
||||
const char* json =
|
||||
"{\"type\":\"content_block_delta\","
|
||||
@@ -221,6 +239,7 @@ int main()
|
||||
|
||||
{
|
||||
// Realistic Anthropic SSE chunk (content_block_delta + text_delta)
|
||||
// 真实的 Anthropic SSE 数据块(content_block_delta + text_delta)
|
||||
std::string token;
|
||||
const char* json =
|
||||
"{\"type\":\"content_block_delta\","
|
||||
@@ -233,12 +252,13 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 5: build_request_json — basic cases
|
||||
// 测试块 5:build_request_json — 基础用例
|
||||
// ================================================================
|
||||
setup_config();
|
||||
std::cout << "\n--- Block 5: build_request_json basic ---\n";
|
||||
|
||||
{
|
||||
// Single user input, no history, stream=false
|
||||
// Single user input, no history, stream=false / 单一用户输入,无历史,stream=false
|
||||
std::string json = build_request_json(nullptr, 0, "Hello", "", false);
|
||||
CHECK(!json.empty(), "T5.1: non-empty JSON produced");
|
||||
CHECK(json.find("\"messages\"") != std::string::npos,
|
||||
@@ -257,7 +277,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// With system message in history
|
||||
// With system message in history / 历史中包含系统消息
|
||||
dstalk_message_t msgs[1] = {
|
||||
{"system", "You are a helpful assistant", nullptr, nullptr}
|
||||
};
|
||||
@@ -268,6 +288,7 @@ int main()
|
||||
"T5.9: system prompt content present");
|
||||
// messages should NOT contain the system role
|
||||
// (since system messages are stripped from messages[] and put in system field)
|
||||
// messages 不应包含 system 角色(系统消息从 messages[] 中提取出来,放入 system 字段)
|
||||
// Actually, the code puts non-system into msgs. Let me check if system is in messages...
|
||||
// The loop skips system: `if (m.role && strcmp(m.role, "system")==0) { ... continue; }`
|
||||
// So system should NOT be in the messages array.
|
||||
@@ -276,7 +297,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// With user+assistant history
|
||||
// With user+assistant history / 包含 user+assistant 历史
|
||||
dstalk_message_t msgs[2] = {
|
||||
{"user", "What is 2+2?", nullptr, nullptr},
|
||||
{"assistant", "It is 4.", nullptr, nullptr}
|
||||
@@ -292,11 +313,12 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 6: build_request_json — edge cases
|
||||
// 测试块 6:build_request_json — 边界情况
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 6: build_request_json edge cases ---\n";
|
||||
|
||||
{
|
||||
// Empty user input
|
||||
// Empty user input / 空用户输入
|
||||
std::string json = build_request_json(nullptr, 0, "", "", false);
|
||||
CHECK(!json.empty(), "T6.1: empty user input produces valid JSON");
|
||||
CHECK(json.find("\"role\":\"user\"") != std::string::npos,
|
||||
@@ -314,7 +336,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Temperature in valid range -> should be included
|
||||
// Temperature in valid range -> should be included / 有效范围内的 temperature -> 应包含
|
||||
g_cfg.temperature = 1.0;
|
||||
std::string json = build_request_json(nullptr, 0, "Hi", "", false);
|
||||
CHECK(json.find("\"temperature\"") != std::string::npos,
|
||||
@@ -323,7 +345,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Temperature out of range -> should NOT be included
|
||||
// Temperature out of range -> should NOT be included / 超出范围的 temperature -> 不应包含
|
||||
g_cfg.temperature = 1.5;
|
||||
std::string json = build_request_json(nullptr, 0, "Hi", "", false);
|
||||
CHECK(json.find("\"temperature\"") == std::string::npos,
|
||||
@@ -336,7 +358,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// History with null role (should default to "")
|
||||
// History with null role (should default to "") / null 角色的历史(应默认为 "")
|
||||
dstalk_message_t msgs[1] = {
|
||||
{nullptr, "some content", nullptr, nullptr}
|
||||
};
|
||||
@@ -345,7 +367,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// History with null content
|
||||
// History with null content / null 内容的历史
|
||||
dstalk_message_t msgs[1] = {
|
||||
{"user", nullptr, nullptr, nullptr}
|
||||
};
|
||||
@@ -355,6 +377,7 @@ int main()
|
||||
|
||||
{
|
||||
// Very long message (>2000 chars) — validate no truncation / crash
|
||||
// 超长消息 (>2000 字符) — 验证无截断/崩溃
|
||||
std::string long_input(5000, 'A');
|
||||
std::string json = build_request_json(nullptr, 0, long_input, "", false);
|
||||
CHECK(!json.empty(), "T6.10: 5000-char input produces valid JSON");
|
||||
@@ -362,7 +385,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Multiple system messages concatenated
|
||||
// Multiple system messages concatenated / 多条系统消息拼接
|
||||
dstalk_message_t msgs[2] = {
|
||||
{"system", "Rule 1: be polite", nullptr, nullptr},
|
||||
{"system", "Rule 2: be concise", nullptr, nullptr}
|
||||
@@ -376,6 +399,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 7: build_headers_json
|
||||
// 测试块 7:build_headers_json
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 7: build_headers_json ---\n";
|
||||
|
||||
@@ -392,7 +416,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// With empty API key
|
||||
// With empty API key / 空 API key
|
||||
std::string saved = g_cfg.api_key;
|
||||
g_cfg.api_key = "";
|
||||
std::string headers = build_headers_json();
|
||||
@@ -405,6 +429,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 8: extract_host_port
|
||||
// 测试块 8:extract_host_port
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 8: extract_host_port ---\n";
|
||||
|
||||
@@ -473,6 +498,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 9: secure_zero
|
||||
// 测试块 9:secure_zero
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 9: secure_zero ---\n";
|
||||
|
||||
@@ -488,7 +514,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Zero-length should not crash
|
||||
// Zero-length should not crash / 零长度不应崩溃
|
||||
char buf[4] = {1,2,3,4};
|
||||
secure_zero(buf, 0);
|
||||
CHECK(buf[0] == 1 && buf[3] == 4,
|
||||
@@ -496,18 +522,20 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Null pointer + zero length = no-op
|
||||
// Null pointer + zero length = no-op / 空指针 + 零长度 = 空操作
|
||||
secure_zero(nullptr, 0);
|
||||
CHECK(true, "T9.3: secure_zero(nullptr, 0) does not crash");
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Test Block 10: my_free_result — null safety
|
||||
// 测试块 10:my_free_result — 空指针安全
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 10: my_free_result null safety ---\n";
|
||||
|
||||
{
|
||||
// g_host is nullptr, so free_result should early-return
|
||||
// g_host 为 nullptr,free_result 应提前返回
|
||||
my_free_result(nullptr);
|
||||
CHECK(true, "T10.1: free_result(nullptr) does not crash (null host)");
|
||||
}
|
||||
@@ -524,11 +552,13 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 11: my_configure — null host safety
|
||||
// 测试块 11:my_configure — null host 安全
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 11: my_configure null host safety ---\n";
|
||||
|
||||
{
|
||||
// g_host is nullptr, configure should still return 0 (log skipped)
|
||||
// g_host 为 nullptr,configure 仍应返回 0(跳过日志)
|
||||
int ret = my_configure(
|
||||
"anthropic", "https://api.anthropic.com",
|
||||
"sk-key", "claude-sonnet", 2048, 0.5);
|
||||
@@ -539,13 +569,13 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Null string params — should not crash
|
||||
// Null string params — should not crash / null 字符串参数 — 不应崩溃
|
||||
int ret = my_configure(nullptr, nullptr, nullptr, nullptr, 4096, 0.7);
|
||||
CHECK(ret == 0, "T11.5: my_configure with all-null strings returns 0");
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Summary
|
||||
// Summary / 总结
|
||||
// ================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// ============================================================================
|
||||
// context_plugin_test.cpp — 上下文插件单元测试
|
||||
// ============================================================================
|
||||
// W18.1 (qa-wang + architect-lin): 覆盖 token 计数、trim、UTF-8 边界、
|
||||
// 0xC0/0xC1 过短编码检测。修复 F-11.1-3/4/5/6 后补充测试。
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file context_plugin_test.cpp
|
||||
* @brief Context plugin unit tests: token counting (ASCII, CJK, mixed, emoji),
|
||||
* UTF-8 truncation safety, trim edge cases, and system message preservation.
|
||||
* Context 插件单元测试:token 计数(ASCII、CJK、混合、emoji)、UTF-8 截断安全、trim 边界情况、系统消息保留。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
|
||||
@@ -14,6 +15,7 @@
|
||||
#include <string>
|
||||
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion macro: increments g_failures counter on failure
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -23,6 +25,12 @@ static int g_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Context 插件测试:token 计数边界(null、空、ASCII、CJK、混合)、截断 UTF-8 边界保护 (F-11.1-4)、
|
||||
// 0xC0/0xC1 超长编码 (F-11.1-6)、多消息 token、trim 的各种场景、系统消息保留、4 字节 emoji、孤立的续字节。
|
||||
// Context plugin tests: token counting edge cases (null, empty, ASCII, CJK, mixed),
|
||||
// truncated UTF-8 bounds protection (F-11.1-4), 0xC0/0xC1 overlong encoding (F-11.1-6),
|
||||
// multiple-message tokens, trim null/edge/within-limit/exceeds-limit scenarios,
|
||||
// system message preservation, 4-byte emoji, and lone continuation bytes.
|
||||
int main()
|
||||
{
|
||||
const auto dir = std::filesystem::temp_directory_path() / "dstalk-ctx-test";
|
||||
@@ -54,6 +62,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 1: count_tokens edge cases (null / empty)
|
||||
// 测试块 1:count_tokens 边界情况(null / 空)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 1: count_tokens edge cases ---\n";
|
||||
|
||||
@@ -77,6 +86,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 2: count_tokens — ASCII
|
||||
// 测试块 2:count_tokens — ASCII
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 2: count_tokens ASCII ---\n";
|
||||
|
||||
@@ -107,6 +117,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 3: count_tokens — Chinese (CJK U+4E00-U+9FFF)
|
||||
// 测试块 3:count_tokens — 中文 (CJK U+4E00-U+9FFF)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 3: count_tokens Chinese (CJK) ---\n";
|
||||
|
||||
@@ -132,6 +143,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 4: count_tokens — Mixed content
|
||||
// 测试块 4:count_tokens — 混合内容
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 4: count_tokens mixed content ---\n";
|
||||
|
||||
@@ -146,6 +158,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 5: Truncated UTF-8 bounds protection (F-11.1-4)
|
||||
// 测试块 5:截断 UTF-8 边界保护 (F-11.1-4)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 5: Truncated UTF-8 (F-11.1-4 fix) ---\n";
|
||||
|
||||
@@ -197,6 +210,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6)
|
||||
// 测试块 6:0xC0/0xC1 超长编码 (F-11.1-6)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 6: 0xC0/0xC1 overlong encoding (F-11.1-6 fix) ---\n";
|
||||
|
||||
@@ -230,6 +244,7 @@ int main()
|
||||
{
|
||||
// Verify 0xC0/0xC1 are NOT treated as valid 2-byte sequences
|
||||
// They should each count as 1 other_char, not as 2-byte sequence
|
||||
// 验证 0xC0/0xC1 不被视为合法的 2 字节序列 / 它们每个应计为 1 个 other_char,而非 2 字节序列
|
||||
// 0xC0 + 0xC1 + 2 ASCII = 2 other + 2 ascii
|
||||
// = (2/3) + (2/4) + 4 overhead = 0 + 0 + 4 = 4
|
||||
// Actually 2/4 = 0 (integer division) for ascii, 2/3 = 0 for other
|
||||
@@ -244,6 +259,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 7: count_tokens — multiple messages
|
||||
// 测试块 7:count_tokens — 多消息
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 7: multiple messages ---\n";
|
||||
|
||||
@@ -275,6 +291,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 8: trim — null and edge cases
|
||||
// 测试块 8:trim — null 和边界情况
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 8: trim edge cases ---\n";
|
||||
|
||||
@@ -291,6 +308,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 9: trim — within limit (no trimming needed)
|
||||
// 测试块 9:trim — 预算内(无需裁剪)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 9: trim within limit ---\n";
|
||||
|
||||
@@ -320,6 +338,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 10: trim — exceeds limit (trimming required)
|
||||
// 测试块 10:trim — 超预算(需要裁剪)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 10: trim exceeds limit ---\n";
|
||||
|
||||
@@ -358,6 +377,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 11: trim — system message preservation
|
||||
// 测试块 11:trim — 系统消息保留
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 11: trim preserves system messages ---\n";
|
||||
|
||||
@@ -387,11 +407,12 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 12: count_tokens — 4-byte UTF-8 (emoji / supplementary)
|
||||
// 测试块 12:count_tokens — 4 字节 UTF-8(emoji / 补充平面)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 12: 4-byte UTF-8 ---\n";
|
||||
|
||||
{
|
||||
// U+1F600 (😀) = F0 9F 98 80
|
||||
// U+1F600 (<EFBFBD><EFBFBD>) = F0 9F 98 80
|
||||
char buf[6] = {static_cast<char>(0xF0), static_cast<char>(0x9F),
|
||||
static_cast<char>(0x98), static_cast<char>(0x80), '\0'};
|
||||
dstalk_message_t msg = {"user", buf, nullptr, nullptr};
|
||||
@@ -403,6 +424,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 13: count_tokens — continuation bytes as lone chars
|
||||
// 测试块 13:count_tokens — 孤立的续字节
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 13: lone continuation bytes ---\n";
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
// ============================================================================
|
||||
// deepseek_plugin_test.cpp — DeepSeek AI 插件单元测试
|
||||
// W21.6 (qa-wang): 覆盖 SSE 解析 / [DONE] 匹配 / JSON 请求构建 / tool_calls
|
||||
// 通过 #include plugin source 访问 file-scope static 函数
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file deepseek_plugin_test.cpp
|
||||
* @brief DeepSeek AI plugin unit tests: SSE parsing (parse_sse_line edge cases),
|
||||
* [DONE] sentinel matching, tool_calls delta extraction, request building,
|
||||
* append_history, extract_host_port, secure_zero, and null-safety.
|
||||
* DeepSeek AI 插件单元测试:SSE 解析(parse_sse_line 边界情况)、[DONE] 标记匹配、
|
||||
* tool_calls delta 提取、请求构建、append_history、extract_host_port、secure_zero、空指针安全。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
#define BOOST_JSON_HEADER_ONLY
|
||||
#define BOOST_ALL_NO_LIB
|
||||
#include "../plugins/deepseek/src/deepseek_plugin.cpp"
|
||||
@@ -12,6 +16,7 @@
|
||||
#include <string>
|
||||
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion macro: increments g_failures counter on failure
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -22,6 +27,8 @@ static int g_failures = 0;
|
||||
} while (0)
|
||||
|
||||
// Test helper: populate g_cfg for build functions
|
||||
// Test helper: populate g_cfg with valid deepseek defaults before build_* tests
|
||||
// 测试辅助函数:为 build_* 测试填充 g_cfg 的有效 deepseek 默认值
|
||||
static void setup_config() {
|
||||
g_cfg.provider = "deepseek";
|
||||
g_cfg.base_url = "https://api.deepseek.com/v1";
|
||||
@@ -31,10 +38,19 @@ static void setup_config() {
|
||||
g_cfg.temperature = 0.7;
|
||||
}
|
||||
|
||||
// DeepSeek 插件测试 (W21.6):parse_sse_line 无效/畸形输入、[DONE] 标记及空白变体、
|
||||
// content delta 提取、tool_calls delta 累积、build_request_json(基础、tools、边界)、
|
||||
// build_headers_json、extract_host_port、secure_zero、append_history(所有消息类型)、
|
||||
// my_free_result、my_configure。
|
||||
// DeepSeek plugin tests (W21.6): parse_sse_line invalid/malformed inputs, [DONE] sentinel
|
||||
// with whitespace variants, content delta extraction, tool_calls delta accumulation,
|
||||
// build_request_json (basic, tools, edge cases), build_headers_json, extract_host_port,
|
||||
// secure_zero, append_history (all message types), my_free_result, and my_configure.
|
||||
int main()
|
||||
{
|
||||
// ================================================================
|
||||
// Test Block 1: parse_sse_line — invalid/malformed inputs
|
||||
// 测试块 1:parse_sse_line — 无效/畸形输入
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 1: parse_sse_line invalid/malformed ---\n";
|
||||
|
||||
@@ -58,27 +74,28 @@ int main()
|
||||
|
||||
{
|
||||
// "data:" without space — rfind("data: ", 0) should fail
|
||||
// "data:" 无空格 — rfind("data: ", 0) 应失败
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data:{\"x\":1}", token, nullptr);
|
||||
CHECK(!ret, "T1.4: 'data:' without trailing space returns false (rfind mismatch)");
|
||||
}
|
||||
|
||||
{
|
||||
// "data: " followed by invalid JSON
|
||||
// "data: " followed by invalid JSON / "data: " 后跟无效 JSON
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: not valid json!!!", token, nullptr);
|
||||
CHECK(!ret, "T1.5: 'data: ' + invalid JSON returns false (no crash)");
|
||||
}
|
||||
|
||||
{
|
||||
// "data: " followed by binary garbage
|
||||
// "data: " followed by binary garbage / "data: " 后跟二进制垃圾
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: \x00\x01\xFF\xFE", token, nullptr);
|
||||
CHECK(!ret, "T1.6: 'data: ' + binary garbage returns false (no crash)");
|
||||
}
|
||||
|
||||
{
|
||||
// Empty data after "data: "
|
||||
// Empty data after "data: " / "data: " 后数据为空
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: ", token, nullptr);
|
||||
CHECK(!ret, "T1.7: 'data: ' with empty payload returns false");
|
||||
@@ -86,6 +103,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 2: parse_sse_line — [DONE] sentinel
|
||||
// 测试块 2:parse_sse_line — [DONE] 标记
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 2: parse_sse_line [DONE] sentinel ---\n";
|
||||
|
||||
@@ -97,7 +115,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// [DONE] with leading whitespace
|
||||
// [DONE] with leading whitespace / [DONE] 前导空白
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: [DONE]", token, nullptr);
|
||||
CHECK(ret, "T2.3: 'data: [DONE]' (leading spaces) returns true");
|
||||
@@ -105,7 +123,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// [DONE] with trailing whitespace
|
||||
// [DONE] with trailing whitespace / [DONE] 尾部空白
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: [DONE] ", token, nullptr);
|
||||
CHECK(ret, "T2.5: 'data: [DONE] ' (trailing spaces) returns true");
|
||||
@@ -113,7 +131,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// [DONE] with tabs and newlines around it
|
||||
// [DONE] with tabs and newlines around it / [DONE] 周围有制表符和换行符
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: \t [DONE] \t\r\n", token, nullptr);
|
||||
CHECK(ret, "T2.7: '[DONE]' with mixed whitespace returns true");
|
||||
@@ -121,7 +139,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// [DONE] without spaces — exact match
|
||||
// [DONE] without spaces — exact match / [DONE] 精确匹配(无空格)
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: [DONE]", token, nullptr);
|
||||
CHECK(ret, "T2.9: '[DONE]' exact match returns true");
|
||||
@@ -129,13 +147,14 @@ int main()
|
||||
|
||||
{
|
||||
// "[done]" lowercase — should NOT match (case-sensitive)
|
||||
// "[done]" 小写 — 不应匹配(大小写敏感)
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: [done]", token, nullptr);
|
||||
CHECK(!ret, "T2.10: '[done]' lowercase NOT treated as DONE (case-sensitive)");
|
||||
}
|
||||
|
||||
{
|
||||
// "[DONE" without closing bracket
|
||||
// "[DONE" without closing bracket / "[DONE" 缺少闭括号
|
||||
std::string token;
|
||||
bool ret = parse_sse_line("data: [DONE", token, nullptr);
|
||||
CHECK(!ret, "T2.11: '[DONE' (no closing bracket) not treated as DONE");
|
||||
@@ -143,6 +162,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 3: parse_sse_line — content delta
|
||||
// 测试块 3:parse_sse_line — content delta
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 3: parse_sse_line content delta ---\n";
|
||||
|
||||
@@ -166,7 +186,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Delta with no content field
|
||||
// Delta with no content field / delta 不含 content 字段
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[{\"delta\":{},\"index\":0}]}";
|
||||
@@ -175,7 +195,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Empty choices array
|
||||
// Empty choices array / 空 choices 数组
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[]}";
|
||||
@@ -184,7 +204,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Single character token (typical streaming)
|
||||
// Single character token (typical streaming) / 单字符 token(典型流式)
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[{\"delta\":{\"content\":\"H\"},\"index\":0}]}";
|
||||
@@ -194,7 +214,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Multi-byte UTF-8 content (emoji) in delta
|
||||
// Multi-byte UTF-8 content (emoji) in delta / delta 中的多字节 UTF-8 内容(emoji)
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[{\"delta\":{\"content\":\"\\uD83D\\uDE00\"},"
|
||||
@@ -207,7 +227,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Malformed JSON structure — no "delta" key
|
||||
// Malformed JSON structure — no "delta" key / 畸形 JSON 结构 — 无 "delta" key
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[{\"no_delta\":{},\"index\":0}]}";
|
||||
@@ -217,6 +237,7 @@ int main()
|
||||
|
||||
{
|
||||
// Realistic DeepSeek streaming chunk (with finish_reason)
|
||||
// 真实的 DeepSeek 流式数据块(含 finish_reason)
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"id\":\"chatcmpl-xxx\","
|
||||
@@ -233,11 +254,13 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 4: parse_sse_line — tool_calls delta
|
||||
// 测试块 4:parse_sse_line — tool_calls delta
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 4: parse_sse_line tool_calls delta ---\n";
|
||||
|
||||
{
|
||||
// tool_calls chunk with id + function name (first chunk)
|
||||
// tool_calls 数据块含 id + function name(首个数据块)
|
||||
StreamContext ctx = {};
|
||||
std::string token;
|
||||
const char* json =
|
||||
@@ -258,8 +281,9 @@ int main()
|
||||
|
||||
{
|
||||
// tool_calls arguments chunk (second chunk, same index)
|
||||
// tool_calls arguments 数据块(第二个数据块,相同 index)
|
||||
StreamContext ctx;
|
||||
// First, set up the initial state
|
||||
// First, set up the initial state / 先设置初始状态
|
||||
ctx.tool_calls.push_back({0, "call_abc123", "get_weather", ""});
|
||||
|
||||
std::string token;
|
||||
@@ -276,7 +300,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// tool_calls final arguments chunk
|
||||
// tool_calls final arguments chunk / tool_calls 最终 arguments 数据块
|
||||
StreamContext ctx;
|
||||
ctx.tool_calls.push_back({0, "call_abc123", "get_weather", "{\"city\":\""});
|
||||
|
||||
@@ -295,6 +319,7 @@ int main()
|
||||
|
||||
{
|
||||
// tool_calls with null ctx — should skip tool_calls processing
|
||||
// tool_calls 配合 null ctx — 应跳过 tool_calls 处理
|
||||
std::string token;
|
||||
const char* json =
|
||||
"data: {\"choices\":[{\"index\":0,"
|
||||
@@ -306,6 +331,7 @@ int main()
|
||||
|
||||
{
|
||||
// Multiple tool_calls in single chunk (unusual but valid)
|
||||
// 单个数据块中有多个 tool_calls(不常见但合法)
|
||||
StreamContext ctx;
|
||||
std::string token;
|
||||
const char* json =
|
||||
@@ -325,6 +351,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 5: build_request_json — basic cases
|
||||
// 测试块 5:build_request_json — 基础用例
|
||||
// ================================================================
|
||||
setup_config();
|
||||
std::cout << "\n--- Block 5: build_request_json basic ---\n";
|
||||
@@ -351,7 +378,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// With user+assistant history
|
||||
// With user+assistant history / 包含 user+assistant 历史
|
||||
dstalk_message_t msgs[2] = {
|
||||
{"user", "What is 2+2?", nullptr, nullptr},
|
||||
{"assistant", "It is 4.", nullptr, nullptr}
|
||||
@@ -376,22 +403,26 @@ int main()
|
||||
|
||||
{
|
||||
// Empty user input — no user message appended
|
||||
// 空用户输入 — 不追加 user 消息
|
||||
std::string json = build_request_json(
|
||||
nullptr, 0, "", "", false);
|
||||
CHECK(!json.empty(), "T5.13: empty user input produces valid JSON");
|
||||
// DeepSeek's build_request_json checks `if (!user_input.empty())` before adding
|
||||
// So there should be no user message for empty input
|
||||
// DeepSeek 的 build_request_json 在添加前检查 `if (!user_input.empty())`
|
||||
// 因此空输入时不应有 user 消息
|
||||
CHECK(json.find("\"role\":\"user\"") == std::string::npos,
|
||||
"T5.14: empty user input NOT added to messages (DeepSeek guard)");
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Test Block 6: build_request_json — tools / edge cases
|
||||
// 测试块 6:build_request_json — tools / 边界情况
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 6: build_request_json tools / edges ---\n";
|
||||
|
||||
{
|
||||
// With tools_json
|
||||
// With tools_json / 含 tools_json
|
||||
std::string tools = "[{\"type\":\"function\","
|
||||
"\"function\":{\"name\":\"get_weather\","
|
||||
"\"description\":\"Get current weather\","
|
||||
@@ -407,7 +438,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Empty tools_json — no tools field
|
||||
// Empty tools_json — no tools field / 空 tools_json — 无 tools 字段
|
||||
std::string json = build_request_json(
|
||||
nullptr, 0, "Hello", "", false);
|
||||
CHECK(json.find("\"tools\"") == std::string::npos,
|
||||
@@ -418,6 +449,8 @@ int main()
|
||||
// Malformed tools_json — build_request_json calls json::parse()
|
||||
// without try/catch, so it will throw std::exception.
|
||||
// This test verifies that the exception is thrown (rather than crashing).
|
||||
// 畸形 tools_json — build_request_json 调用 json::parse() 不含 try/catch,
|
||||
// 因此会抛出 std::exception。本测试验证异常被抛出(而非崩溃)。
|
||||
bool threw = false;
|
||||
try {
|
||||
build_request_json(nullptr, 0, "Hello", "NOT JSON", false);
|
||||
@@ -430,7 +463,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// History with null role
|
||||
// History with null role / null 角色的历史
|
||||
dstalk_message_t msgs[1] = {
|
||||
{nullptr, "some content", nullptr, nullptr}
|
||||
};
|
||||
@@ -439,7 +472,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// History with null content
|
||||
// History with null content / null 内容的历史
|
||||
dstalk_message_t msgs[1] = {
|
||||
{"user", nullptr, nullptr, nullptr}
|
||||
};
|
||||
@@ -448,7 +481,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Very long message
|
||||
// Very long message / 超长消息
|
||||
std::string long_input(5000, 'A');
|
||||
std::string json = build_request_json(
|
||||
nullptr, 0, long_input, "", false);
|
||||
@@ -458,6 +491,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 7: build_headers_json
|
||||
// 测试块 7:build_headers_json
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 7: build_headers_json ---\n";
|
||||
|
||||
@@ -470,7 +504,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Empty API key
|
||||
// Empty API key / 空 API key
|
||||
std::string headers = build_headers_json("");
|
||||
CHECK(headers.find("Authorization") != std::string::npos,
|
||||
"T7.3: Authorization header present with empty key");
|
||||
@@ -480,6 +514,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 8: extract_host_port (same logic as anthropic)
|
||||
// 测试块 8:extract_host_port(逻辑同 anthropic)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 8: extract_host_port ---\n";
|
||||
|
||||
@@ -525,6 +560,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 9: secure_zero
|
||||
// 测试块 9:secure_zero
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 9: secure_zero ---\n";
|
||||
|
||||
@@ -546,6 +582,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 10: append_history
|
||||
// 测试块 10:append_history
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 10: append_history ---\n";
|
||||
|
||||
@@ -561,7 +598,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Tool message (should include tool_call_id)
|
||||
// Tool message (should include tool_call_id) / Tool 消息(应包含 tool_call_id)
|
||||
json::array msgs;
|
||||
dstalk_message_t m = {"tool", "result data", "call_xyz", nullptr};
|
||||
append_history(msgs, &m, 1);
|
||||
@@ -575,7 +612,7 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Assistant with tool_calls_json
|
||||
// Assistant with tool_calls_json / Assistant 含 tool_calls_json
|
||||
json::array msgs;
|
||||
const char* tc_json = "[{\"id\":\"call_1\",\"type\":\"function\","
|
||||
"\"function\":{\"name\":\"get_weather\",\"arguments\":\"{}\"}}]";
|
||||
@@ -589,14 +626,14 @@ int main()
|
||||
}
|
||||
|
||||
{
|
||||
// Empty history (0 messages)
|
||||
// Empty history (0 messages) / 空历史(0 条消息)
|
||||
json::array msgs;
|
||||
append_history(msgs, nullptr, 0);
|
||||
CHECK(msgs.size() == 0, "T10.12: empty history produces empty array");
|
||||
}
|
||||
|
||||
{
|
||||
// Multiple messages
|
||||
// Multiple messages / 多条消息
|
||||
json::array msgs;
|
||||
dstalk_message_t ms[2] = {
|
||||
{"user", "Q1", nullptr, nullptr},
|
||||
@@ -608,6 +645,7 @@ int main()
|
||||
|
||||
{
|
||||
// Null role and null content — default to empty strings
|
||||
// null 角色与 null 内容 — 默认为空字符串
|
||||
json::array msgs;
|
||||
dstalk_message_t m = {nullptr, nullptr, nullptr, nullptr};
|
||||
append_history(msgs, &m, 1);
|
||||
@@ -619,11 +657,13 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 11: my_free_result — null safety
|
||||
// 测试块 11:my_free_result — 空指针安全
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 11: my_free_result null safety ---\n";
|
||||
|
||||
{
|
||||
// g_host is nullptr, so free_result should early-return
|
||||
// g_host 为 nullptr,free_result 应提前返回
|
||||
my_free_result(nullptr);
|
||||
CHECK(true, "T11.1: free_result(nullptr) does not crash (null host)");
|
||||
}
|
||||
@@ -637,6 +677,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 12: my_configure — null host safety
|
||||
// 测试块 12:my_configure — null host 安全
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 12: my_configure null host safety ---\n";
|
||||
|
||||
@@ -656,7 +697,7 @@ int main()
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Summary
|
||||
// Summary / 总结
|
||||
// ================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// ============================================================================
|
||||
// event_bus_test.cpp — EventBus 单元测试
|
||||
// ============================================================================
|
||||
// 测试: subscribe / unsubscribe / emit / 多订阅者 / 空总线
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file event_bus_test.cpp
|
||||
* @brief EventBus unit tests: subscribe, emit, unsubscribe, multi-handler
|
||||
* dispatch order, independent event types.
|
||||
* EventBus 单元测试:订阅、发布、取消订阅、多处理器分发顺序、独立事件类型。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
@@ -13,6 +15,7 @@
|
||||
|
||||
// ---- 轻量断言 ----
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion helper: increments g_failures counter on failure
|
||||
#define TCHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -22,13 +25,16 @@ static int g_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// ============================================================
|
||||
// EventBus 单元测试:订阅+发布、取消订阅、多处理器分发顺序、空总线、独立事件类型路由、取消不存在的订阅。
|
||||
// EventBus unit tests: subscribe+emit, unsubscribe, multi-handler dispatch order,
|
||||
// empty bus, independent event type routing, and non-existent unsubscribe safety.
|
||||
int main()
|
||||
{
|
||||
std::cout << "=== dstalk event_bus unit tests ===\n\n";
|
||||
|
||||
// ====================================================================
|
||||
// Test 1: subscribe + emit — 基本发布订阅流程
|
||||
// Test 1: subscribe + emit — basic pub/sub flow
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
@@ -49,6 +55,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 2: unsubscribe — 取消订阅后 handler 不再被调用
|
||||
// Test 2: unsubscribe — handler NOT called after unsubscription
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
@@ -64,6 +71,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 3: 多订阅者 — 同一事件多个 handler 按订阅顺序全部调用
|
||||
// Test 3: multi-subscriber — all handlers for same event invoked in subscription order
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
@@ -77,13 +85,14 @@ int main()
|
||||
TCHECK(emitted == 3, "emit returns 3 handlers called");
|
||||
TCHECK(order.size() == 3, "all 3 handlers invoked");
|
||||
|
||||
// 验证订阅顺序 (FIFO: 按 subscribe 顺序触发)
|
||||
// 验证订阅顺序 (FIFO: 按 subscribe 顺序触发) / Verify subscription order (FIFO: in subscribe order)
|
||||
bool ordered = (order[0] == 1 && order[1] == 2 && order[2] == 3);
|
||||
TCHECK(ordered, "handlers invoked in subscription order (1,2,3)");
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// Test 4: 空总线 emit 不崩溃,返回 0
|
||||
// Test 4: emit on empty bus no crash, returns 0
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
@@ -93,6 +102,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 5: 不同 event_type 独立分发 — 只触发匹配的 handler
|
||||
// Test 5: independent event_type dispatch — only matching handler triggered
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
@@ -112,15 +122,16 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 6: 退订不存在的 ID 不崩溃
|
||||
// Test 6: unsubscribe non-existent ID does not crash
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::EventBus bus;
|
||||
bus.unsubscribe(99999); // 不存在的 ID
|
||||
bus.unsubscribe(99999); // 不存在的 ID / non-existent ID
|
||||
std::cout << "[OK] unsubscribe non-existent ID (99999) did not crash\n";
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// 结果
|
||||
// 结果 / Result
|
||||
// ====================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
// ============================================================================
|
||||
// host_api_test.cpp — host API 单元测试 (独立于 smoke_test)
|
||||
// ============================================================================
|
||||
// 测试: register_service / query_service / alloc / free / log / init / shutdown
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file host_api_test.cpp
|
||||
* @brief Host API unit tests: service registration, event bus, config store,
|
||||
* alloc/free, logging, init/shutdown lifecycle.
|
||||
* Host API 单元测试:服务注册、事件总线、配置存储、alloc/free、日志、init/shutdown 生命周期。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdio>
|
||||
@@ -13,13 +15,14 @@
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
// 引入 ServiceRegistry 实现做纯单元测试
|
||||
// 引入 ServiceRegistry 实现做纯单元测试 / Include ServiceRegistry impl for pure unit tests
|
||||
#include "service_registry.hpp"
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
|
||||
// ---- 轻量断言 ----
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion helper: increments g_failures counter on failure
|
||||
#define TCHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -29,26 +32,32 @@ static int g_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// ---- 辅助: 创建临时配置文件 ----
|
||||
|
||||
// Helper: creates a temporary config.toml pointing to a non-existent plugin dir,
|
||||
// so dstalk_init loads no external plugins during tests.
|
||||
// 辅助函数:创建临时 config.toml 指向不存在的插件目录,使 dstalk_init 在测试时不加载任何外部插件。
|
||||
static std::string make_temp_config(const std::string& tag) {
|
||||
auto dir = std::filesystem::temp_directory_path() / ("dstalk-host-api-" + tag);
|
||||
std::filesystem::create_directories(dir);
|
||||
auto config_path = dir / "config.toml";
|
||||
{
|
||||
std::ofstream c(config_path);
|
||||
// 指向不存在的插件目录,避免加载任何 .dll
|
||||
// 指向不存在的插件目录,避免加载任何 .dll / Point to nonexistent plugin dir, avoid loading any .dll
|
||||
c << "plugin_dir = \"__no_such_plugins_dir__\"\n";
|
||||
}
|
||||
return config_path.string();
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Host API 单元测试:覆盖注册/查询重复、版本不匹配、双重 init 防护、alloc/free 边界、日志级别、shutdown 后查询。
|
||||
// Host API unit tests: covers register/query duplicates, version mismatch,
|
||||
// double-init guard, alloc/free edge cases, logging levels, and post-shutdown query.
|
||||
int main()
|
||||
{
|
||||
std::cout << "=== dstalk host_api unit tests ===\n\n";
|
||||
|
||||
// ====================================================================
|
||||
// Test 1: register_service 重复注册 同名+同版本 → 应返回 -2
|
||||
// Test 1: register_service duplicate same-name+same-version -> should return -2
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -64,6 +73,8 @@ int main()
|
||||
// ====================================================================
|
||||
// Test 2: register_service 同名+不同版本 → 应返回 -2
|
||||
// 名称已占用,与版本无关
|
||||
// Test 2: register_service same-name+different-version -> should return -2
|
||||
// Name already taken, regardless of version
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -78,6 +89,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 3: query_service 不存在的 name → nullptr
|
||||
// Test 3: query_service nonexistent name -> nullptr
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -88,6 +100,8 @@ int main()
|
||||
// ====================================================================
|
||||
// Test 4: query_service 错误版本号 → nullptr
|
||||
// 注册 v=1, 查询 min_version=2 → 不满足 → nullptr
|
||||
// Test 4: query_service wrong version -> nullptr
|
||||
// Registered v=1, query min_version=2 -> unsatisfied -> nullptr
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -97,13 +111,14 @@ int main()
|
||||
void* q = reg.query_service("solo", 2);
|
||||
TCHECK(q == nullptr, "query_service(\"solo\",2) with only v1 available returns nullptr");
|
||||
|
||||
// 确证以正确版本查询能拿到
|
||||
// 确证以正确版本查询能拿到 / Confirm correct version query works
|
||||
void* q2 = reg.query_service("solo", 1);
|
||||
TCHECK(q2 == dummy_vtable, "query_service(\"solo\",1) with v1 available returns vtable");
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// Test 5: dstalk_init 多次调用 → 第二次应返回 -1 (幂等拒绝)
|
||||
// Test 5: dstalk_init multiple calls -> second should return -1 (idempotent guard)
|
||||
// ====================================================================
|
||||
{
|
||||
std::string cfg = make_temp_config("init-twice");
|
||||
@@ -120,6 +135,9 @@ int main()
|
||||
// Test 6: alloc(0) / free(nullptr) 行为
|
||||
// malloc(0) 可返回 null 或合法指针; 两者都可 free
|
||||
// free(nullptr) 是安全空操作
|
||||
// Test 6: alloc(0) / free(nullptr) behavior
|
||||
// malloc(0) may return null or valid pointer; both are free-able
|
||||
// free(nullptr) is a safe no-op
|
||||
// ====================================================================
|
||||
{
|
||||
void* p = dstalk_alloc(0);
|
||||
@@ -134,6 +152,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 7: log 各 level 不崩溃 (DEBUG / INFO / WARN / ERROR)
|
||||
// Test 7: log at each level no crash (DEBUG / INFO / WARN / ERROR)
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk_log(DSTALK_LOG_DEBUG, "host_api_test: debug level message");
|
||||
@@ -148,7 +167,7 @@ int main()
|
||||
dstalk_log(DSTALK_LOG_ERROR, "host_api_test: error level message");
|
||||
std::cout << "[OK] dstalk_log(ERROR) no crash\n";
|
||||
|
||||
// 带格式参数
|
||||
// 带格式参数 / With format args
|
||||
dstalk_log(DSTALK_LOG_INFO, "formatted: %s %d", "answer", 42);
|
||||
std::cout << "[OK] dstalk_log with format args no crash\n";
|
||||
}
|
||||
@@ -156,6 +175,8 @@ int main()
|
||||
// ====================================================================
|
||||
// Test 8: dstalk_shutdown 后 query_service → nullptr
|
||||
// g_service_registry 已被 delete 置空
|
||||
// Test 8: query_service after dstalk_shutdown -> nullptr
|
||||
// g_service_registry has been deleted and nulled
|
||||
// ====================================================================
|
||||
{
|
||||
std::string cfg = make_temp_config("after-shutdown");
|
||||
@@ -167,7 +188,7 @@ int main()
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// 结果
|
||||
// 结果 / Result
|
||||
// ====================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
// ============================================================================
|
||||
// network_plugin_test.cpp — Network 插件单元测试
|
||||
// W22.2 (qa-xu): 覆盖 parse_headers_json / SSE 行解析 / 参数校验
|
||||
// 通过 #include plugin source 访问 file-scope static 函数
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file network_plugin_test.cpp
|
||||
* @brief Network plugin unit tests (W22.2): parse_headers_json (normal, empty,
|
||||
* malformed, long values), SSE line splitting boundaries, and
|
||||
* http_post_json/http_post_stream parameter validation.
|
||||
* Network 插件单元测试 (W22.2):parse_headers_json(正常、空、畸形、长值)、SSE 行解析边界、
|
||||
* http_post_json/http_post_stream 参数校验。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
#define _CRT_SECURE_NO_WARNINGS
|
||||
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
|
||||
#include "../plugins/network/src/network_plugin.cpp"
|
||||
@@ -15,6 +19,7 @@
|
||||
#include <vector>
|
||||
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion macro: increments g_failures counter on failure
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -24,9 +29,9 @@ static int g_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// ================================================================
|
||||
// SSE 行分割 helper (复刻 do_post_stream 的 emit_lines 逻辑)
|
||||
// ================================================================
|
||||
// SSE line-split helper: mirrors do_post_stream's emit_lines logic for unit-testing
|
||||
// SSE chunk parsing without a live network connection.
|
||||
// SSE 行分割辅助函数:镜像 do_post_stream 的 emit_lines 逻辑,无需实时网络连接即可单元测试 SSE 数据块解析。
|
||||
static std::vector<std::string> split_sse_lines(std::string fragment) {
|
||||
std::vector<std::string> lines;
|
||||
size_t pos = 0;
|
||||
@@ -55,11 +60,17 @@ static std::vector<std::string> split_sse_lines(std::string fragment) {
|
||||
return lines;
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Network 插件测试 (W22.2):parse_headers_json 正常/空/畸形/长值、
|
||||
// SSE 行分割(LF、CRLF、空、null 字节、尾部 CR)、
|
||||
// http_post_json/http_post_stream 参数校验(空指针)。
|
||||
// Network plugin tests (W22.2): parse_headers_json normal/empty/malformed/long,
|
||||
// SSE line splitting (LF, CRLF, empty, null-bytes, trailing CR),
|
||||
// and http_post_json/http_post_stream parameter validation (null pointers).
|
||||
int main()
|
||||
{
|
||||
// ================================================================
|
||||
// Test Block 1: parse_headers_json — 正常 JSON
|
||||
// Test Block 1: parse_headers_json — normal JSON
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 1: parse_headers_json normal JSON ---\n";
|
||||
|
||||
@@ -98,6 +109,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 2: parse_headers_json — 空 / null 输入
|
||||
// Test Block 2: parse_headers_json — empty/null input
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 2: parse_headers_json empty/null input ---\n";
|
||||
|
||||
@@ -124,6 +136,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 3: parse_headers_json — 畸形 JSON
|
||||
// Test Block 3: parse_headers_json — malformed JSON
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 3: parse_headers_json malformed JSON ---\n";
|
||||
|
||||
@@ -177,6 +190,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 4: parse_headers_json — 超长 header 值
|
||||
// Test Block 4: parse_headers_json — long values
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 4: parse_headers_json long values ---\n";
|
||||
|
||||
@@ -209,6 +223,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 5: SSE 行解析边界
|
||||
// Test Block 5: SSE line splitting boundaries
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 5: SSE line splitting boundaries ---\n";
|
||||
|
||||
@@ -274,6 +289,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 6: http_post_json — 参数校验 (null ptr, early return)
|
||||
// Test Block 6: http_post_json — parameter validation (null ptr, early return)
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 6: http_post_json parameter validation ---\n";
|
||||
|
||||
@@ -319,6 +335,7 @@ int main()
|
||||
|
||||
// ================================================================
|
||||
// Test Block 7: http_post_stream — 参数校验
|
||||
// Test Block 7: http_post_stream — parameter validation
|
||||
// ================================================================
|
||||
std::cout << "\n--- Block 7: http_post_stream parameter validation ---\n";
|
||||
|
||||
@@ -338,7 +355,7 @@ int main()
|
||||
}
|
||||
|
||||
// ================================================================
|
||||
// Summary
|
||||
// Summary / 总结
|
||||
// ================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
// ============================================================================
|
||||
// plugin_loader_test.cpp — PluginLoader 安全回归测试
|
||||
// ============================================================================
|
||||
// W20.3 (qa-xu 徐磊): 覆盖 W19 修复的 5 条发现 (F-18.3-1~5)
|
||||
// - F-18.3-3: 路径验证 (lexically_normal + 扩展名 + 目录约束)
|
||||
// - F-18.3-4: next_id_ atomic 唯一性 + 单调递增
|
||||
// - F-18.3-2: host_api_->log 调用 (mock 验证)
|
||||
// - F-18.3-1: try/catch 异常安全边界 (间接: 注入 mock 不崩溃)
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file plugin_loader_test.cpp
|
||||
* @brief PluginLoader safety regression tests (W20.3): path validation,
|
||||
* ABI checks, next_id_ atomicity, failure-path logging with mock host API.
|
||||
* PluginLoader 安全回归测试 (W20.3):路径验证、ABI 检查、next_id_ 原子性、失败路径日志(使用 mock host API)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "plugin_loader.hpp"
|
||||
|
||||
@@ -24,6 +22,7 @@ namespace fs = std::filesystem;
|
||||
|
||||
// ---- 轻量断言 ----
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion macro: increments g_failures counter on failure
|
||||
#define CHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -35,11 +34,14 @@ static int g_failures = 0;
|
||||
|
||||
// ============================================================================
|
||||
// Mock host_api — 捕获 log 调用以验证失败路径日志 (F-18.3-2)
|
||||
// Mock host_api — captures log calls to verify failure-path logging (F-18.3-2)
|
||||
// ============================================================================
|
||||
static int g_log_call_count = 0;
|
||||
static int g_last_severity = 0;
|
||||
static char g_last_log_msg[1024] = {0};
|
||||
|
||||
// Mock host_api::log implementation: counts calls and captures last severity+message
|
||||
// Mock host_api::log 实现:计数调用并捕获最后的 severity+message
|
||||
static void mock_log(int level, const char* fmt, ...) {
|
||||
g_log_call_count++;
|
||||
g_last_severity = level;
|
||||
@@ -49,6 +51,8 @@ static void mock_log(int level, const char* fmt, ...) {
|
||||
va_end(args);
|
||||
}
|
||||
|
||||
// Stub host_api functions: return failure/default for all operations except log
|
||||
// Stub host_api 函数:除 log 外所有操作均返回失败/默认值
|
||||
static int stub_reg(const char*, int, void*) { return -1; }
|
||||
static void* stub_query(const char*, int) { return nullptr; }
|
||||
static int stub_sub(int, dstalk_event_handler_fn, void*) { return -1; }
|
||||
@@ -60,6 +64,8 @@ static void* stub_alloc(size_t) { return nullptr; }
|
||||
static void stub_free(void*) {}
|
||||
static char* stub_strdup(const char*) { return nullptr; }
|
||||
|
||||
// Mock host_api vtable: all stubs except mock_log for capturing error-path diagnostics
|
||||
// Mock host_api 虚表:除 mock_log 外全部 stub,用于捕获错误路径诊断
|
||||
static dstalk_host_api_t g_mock_host_api = {
|
||||
stub_reg, stub_query,
|
||||
stub_sub, stub_emit, stub_unsub,
|
||||
@@ -68,15 +74,16 @@ static dstalk_host_api_t g_mock_host_api = {
|
||||
stub_alloc, stub_free, stub_strdup
|
||||
};
|
||||
|
||||
// Reset log capture state between tests
|
||||
// 重置日志捕获状态(测试间使用)
|
||||
static void reset_log_state() {
|
||||
g_log_call_count = 0;
|
||||
g_last_severity = 0;
|
||||
g_last_log_msg[0] = '\0';
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Helper: 获取已构建的 plugins/ 目录绝对路径
|
||||
// ============================================================================
|
||||
// Get the absolute path to the build output plugins/ directory
|
||||
// 获取构建输出 plugins/ 目录的绝对路径
|
||||
static fs::path get_plugins_dir() {
|
||||
#ifdef DSTALK_TEST_PLUGINS_DIR
|
||||
return fs::path(DSTALK_TEST_PLUGINS_DIR);
|
||||
@@ -85,50 +92,56 @@ static fs::path get_plugins_dir() {
|
||||
#endif
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// PluginLoader 回归测试 (W20.3):F-18.3-3 路径验证拒绝、F-18.3-4 next_id_ 唯一性+单调性+并发、
|
||||
// F-18.3-2 失败路径日志,以及边界情况(空 loader、无效操作)。
|
||||
// PluginLoader regression tests (W20.3): F-18.3-3 path validation rejection,
|
||||
// F-18.3-4 next_id_ uniqueness+monotonic+concurrent, F-18.3-2 failure-path logging,
|
||||
// and edge cases (empty loader, invalid operations).
|
||||
int main()
|
||||
{
|
||||
std::cout << "=== dstalk plugin_loader regression tests (W20.3) ===\n\n";
|
||||
|
||||
// ========================================================================
|
||||
// Block 1: 路径验证 — 拒绝非法路径 (F-18.3-3)
|
||||
// Block 1: Path validation — reject illegal paths (F-18.3-3)
|
||||
// ========================================================================
|
||||
std::cout << "--- Block 1: Path validation — rejection ---\n";
|
||||
{
|
||||
dstalk::PluginLoader loader;
|
||||
|
||||
// T1.1: nullptr
|
||||
// T1.1: nullptr / null pointer
|
||||
CHECK(loader.load_plugin(nullptr) == -1,
|
||||
"T1.1: nullptr path returns -1");
|
||||
|
||||
// T1.2: 非法扩展名 .txt
|
||||
// T1.2: 非法扩展名 .txt / illegal .txt extension
|
||||
CHECK(loader.load_plugin("plugins/test.txt") == -1,
|
||||
"T1.2: .txt extension rejected");
|
||||
|
||||
// T1.3: 路径含 .. 遍历
|
||||
// T1.3: 路径含 .. 遍历 / path contains .. traversal
|
||||
CHECK(loader.load_plugin("../plugins/test.dll") == -1,
|
||||
"T1.3: ../ traversal rejected");
|
||||
|
||||
// T1.4: 不在 plugins/ 目录下
|
||||
// T1.4: 不在 plugins/ 目录下 / not under plugins/ dir
|
||||
auto tmp = fs::temp_directory_path() / "dstalk_test_no_plugins" / "test.dll";
|
||||
CHECK(loader.load_plugin(tmp.string().c_str()) == -1,
|
||||
"T1.4: path not under plugins/ dir rejected");
|
||||
|
||||
// T1.5: 路径中间的 .. 段
|
||||
// T1.5: 路径中间的 .. 段 / .. segment in middle of path
|
||||
CHECK(loader.load_plugin("plugins/../secret/test.dll") == -1,
|
||||
"T1.5: .. in middle of path rejected");
|
||||
|
||||
// T1.6: 无扩展名
|
||||
// T1.6: 无扩展名 / no extension
|
||||
CHECK(loader.load_plugin("plugins/test") == -1,
|
||||
"T1.6: no extension rejected");
|
||||
|
||||
// T1.7: 合法扩展名但不在 plugins/ 下
|
||||
// T1.7: 合法扩展名但不在 plugins/ 下 / valid extension but not under plugins/
|
||||
CHECK(loader.load_plugin("/etc/someconfig.so") == -1,
|
||||
"T1.7: .so extension but not under plugins/ rejected");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Block 2: 合法路径 — 成功加载 + next_id_ 验证 (F-18.3-4)
|
||||
// Block 2: Valid path — successful load + ID uniqueness (F-18.3-4)
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Block 2: Valid path — successful load + ID uniqueness ---\n";
|
||||
{
|
||||
@@ -144,23 +157,23 @@ int main()
|
||||
std::cout << "[WARN] Plugin DLLs not found at " << plugins_dir.string()
|
||||
<< " — skipping Block 2\n";
|
||||
} else {
|
||||
// T2.1: 加载第一个插件
|
||||
// T2.1: 加载第一个插件 / load first plugin
|
||||
int id1 = loader.load_plugin(dll_config.string().c_str());
|
||||
CHECK(id1 >= 1, "T2.1: first plugin loaded with positive ID");
|
||||
std::cout << " id1 = " << id1 << "\n";
|
||||
|
||||
// T2.2: 加载第二个不同插件
|
||||
// T2.2: 加载第二个不同插件 / load second (different) plugin
|
||||
int id2 = loader.load_plugin(dll_fileio.string().c_str());
|
||||
CHECK(id2 >= 1, "T2.2: second plugin loaded with positive ID");
|
||||
std::cout << " id2 = " << id2 << "\n";
|
||||
|
||||
// T2.3: ID 唯一
|
||||
// T2.3: ID 唯一 / IDs are unique
|
||||
CHECK(id1 != id2, "T2.3: IDs are unique (next_id_ atomicity)");
|
||||
|
||||
// T2.4: ID 单调递增
|
||||
// T2.4: ID 单调递增 / IDs monotonically increasing
|
||||
CHECK(id2 > id1, "T2.4: IDs monotonically increasing");
|
||||
|
||||
// T2.5: get_plugin 可查询到已加载插件
|
||||
// T2.5: get_plugin 可查询到已加载插件 / get_plugin can find loaded plugin
|
||||
const dstalk::PluginInfo* info1 = loader.get_plugin(id1);
|
||||
CHECK(info1 != nullptr, "T2.5: get_plugin(id1) returns non-null");
|
||||
if (info1) {
|
||||
@@ -168,23 +181,24 @@ int main()
|
||||
std::cout << " plugin1 name: " << info1->name << "\n";
|
||||
}
|
||||
|
||||
// T2.7: get_plugin 对无效 ID 返回 nullptr
|
||||
// T2.7: get_plugin 对无效 ID 返回 nullptr / get_plugin returns nullptr for invalid ID
|
||||
CHECK(loader.get_plugin(99999) == nullptr,
|
||||
"T2.7: get_plugin(invalid_id) returns nullptr");
|
||||
|
||||
// T2.8: 卸载后 get_plugin 返回 nullptr
|
||||
// T2.8: 卸载后 get_plugin 返回 nullptr / get_plugin returns nullptr after unload
|
||||
int ret = loader.unload_plugin(id1);
|
||||
CHECK(ret == 0, "T2.8: unload_plugin returns 0");
|
||||
CHECK(loader.get_plugin(id1) == nullptr,
|
||||
"T2.9: get_plugin returns nullptr after unload");
|
||||
|
||||
// 清理
|
||||
// 清理 / cleanup
|
||||
loader.unload_plugin(id2);
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Block 3: next_id_ 原子性 — 多线程并发加载 (F-18.3-4)
|
||||
// Block 3: next_id_ atomicity — concurrent loads (F-18.3-4)
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Block 3: next_id_ atomicity — concurrent loads ---\n";
|
||||
{
|
||||
@@ -213,7 +227,7 @@ int main()
|
||||
|
||||
for (auto& t : threads) t.join();
|
||||
|
||||
// 验证: 所有 load 成功, ID 唯一且 > 0
|
||||
// 验证: 所有 load 成功, ID 唯一且 > 0 / Verify: all loads succeed, IDs unique and > 0
|
||||
std::vector<int> valid_ids;
|
||||
for (size_t i = 0; i < ids.size(); ++i) {
|
||||
CHECK(ids[i] >= 1, "T3." + std::to_string(i)
|
||||
@@ -222,7 +236,7 @@ int main()
|
||||
if (ids[i] >= 1) valid_ids.push_back(ids[i]);
|
||||
}
|
||||
|
||||
// 去重后大小应等于成功加载数
|
||||
// 去重后大小应等于成功加载数 / dedup size should equal successful load count
|
||||
std::sort(valid_ids.begin(), valid_ids.end());
|
||||
auto dup = std::unique(valid_ids.begin(), valid_ids.end());
|
||||
size_t unique_count = std::distance(valid_ids.begin(), dup);
|
||||
@@ -231,26 +245,27 @@ int main()
|
||||
+ std::to_string(unique_count) + "/"
|
||||
+ std::to_string(valid_ids.size()) + ")");
|
||||
|
||||
// 清理
|
||||
// 清理 / cleanup
|
||||
for (int id : valid_ids) loader.unload_plugin(id);
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Block 4: 失败路径日志 — host_api->log 被调用 (F-18.3-2)
|
||||
// Block 4: Failure-path logging — host_api->log is called (F-18.3-2)
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Block 4: Failure-path logging (host_api->log) ---\n";
|
||||
{
|
||||
dstalk::PluginLoader loader;
|
||||
|
||||
// 4.1: 无 host_api 时 load_plugin 失败不崩溃
|
||||
// 4.1: 无 host_api 时 load_plugin 失败不崩溃 / load_plugin fails without crash when no host_api
|
||||
reset_log_state();
|
||||
int id = loader.load_plugin("bad_ext.noext");
|
||||
CHECK(id == -1, "T4.1: load_plugin with invalid ext returns -1 (no host_api)");
|
||||
CHECK(g_log_call_count == 0,
|
||||
"T4.2: log NOT called when host_api_ is null");
|
||||
|
||||
// 4.2: 设置 mock host_api 后验证 log 被调用
|
||||
// 4.2: 设置 mock host_api 后验证 log 被调用 / set mock host_api and verify log is called
|
||||
int init_ret = loader.initialize_all(&g_mock_host_api);
|
||||
CHECK(init_ret == 0, "T4.3: initialize_all with mock host_api returns 0");
|
||||
|
||||
@@ -263,7 +278,7 @@ int main()
|
||||
"T4.6: log severity is DSTALK_LOG_ERROR");
|
||||
std::cout << " log msg: " << g_last_log_msg << "\n";
|
||||
|
||||
// 4.3: LoadLibrary 失败也触发 log (文件不存在)
|
||||
// 4.3: LoadLibrary 失败也触发 log (文件不存在) / LoadLibrary failure also triggers log (file missing)
|
||||
reset_log_state();
|
||||
fs::path missing = get_plugins_dir() / "nonexistent_plugin.dll";
|
||||
id = loader.load_plugin(missing.string().c_str());
|
||||
@@ -275,28 +290,29 @@ int main()
|
||||
|
||||
// ========================================================================
|
||||
// Block 5: 边界 — 空 loader / 无效操作
|
||||
// Block 5: Edge cases — empty loader / invalid operations
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Block 5: Edge cases — empty loader / invalid op ---\n";
|
||||
{
|
||||
dstalk::PluginLoader loader;
|
||||
|
||||
// T5.1: unload 不存在的 ID 返回 -1
|
||||
// T5.1: unload 不存在的 ID 返回 -1 / unload non-existent ID returns -1
|
||||
CHECK(loader.unload_plugin(42) == -1,
|
||||
"T5.1: unload_plugin(nonexistent) returns -1");
|
||||
|
||||
// T5.2: 空 PluginLoader 的 list_plugins 返回 "[]"
|
||||
// T5.2: 空 PluginLoader 的 list_plugins 返回 "[]" / empty PluginLoader list_plugins returns "[]"
|
||||
std::string json = loader.list_plugins();
|
||||
CHECK(!json.empty(), "T5.2: list_plugins returns non-empty string");
|
||||
CHECK(json == "[]", "T5.3: empty loader produces empty JSON array");
|
||||
std::cout << " list_plugins (empty): " << json << "\n";
|
||||
|
||||
// T5.3: get_plugin 在空 loader 上返回 nullptr
|
||||
// T5.3: get_plugin 在空 loader 上返回 nullptr / get_plugin on empty loader returns nullptr
|
||||
CHECK(loader.get_plugin(1) == nullptr,
|
||||
"T5.4: get_plugin on empty loader returns nullptr");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// 结果
|
||||
// 结果 / Result
|
||||
// ========================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// ============================================================================
|
||||
// service_registry_test.cpp — ServiceRegistry 单元测试(补充覆盖,不与 host_api_test 重叠)
|
||||
// ============================================================================
|
||||
// host_api_test 已覆盖: 重复注册(同名同版/同名异版)、查询不存在服务、版本不满足、
|
||||
// shutdown 后查询。本测试补充边界与生命周期路径。
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file service_registry_test.cpp
|
||||
* @brief ServiceRegistry unit tests (supplement to host_api_test): register,
|
||||
* query, version check, unregister, null-pointer safety, re-registration.
|
||||
* ServiceRegistry 单元测试(host_api_test 补充):注册、查询、版本检查、取消注册、空指针安全、重新注册。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
@@ -12,6 +13,7 @@
|
||||
|
||||
// ---- 轻量断言 ----
|
||||
static int g_failures = 0;
|
||||
// Lightweight assertion helper: increments g_failures counter on failure
|
||||
#define TCHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
std::cout << "[OK] " << (msg) << "\n"; \
|
||||
@@ -21,7 +23,11 @@ static int g_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// ============================================================
|
||||
// ServiceRegistry 补充测试:空名称/虚表拒绝、完整生命周期(注册→查询→取消注册→查询为空)、
|
||||
// 取消注册空指针安全、取消注册后重新注册、空名称查询。
|
||||
// ServiceRegistry supplement tests: null-name/vtable rejection, full lifecycle
|
||||
// (register->query->unregister->query nullptr), unregister nullptr safety,
|
||||
// re-registration after unregister, and query with nullptr name.
|
||||
int main()
|
||||
{
|
||||
std::cout << "=== dstalk service_registry unit tests (supplement) ===\n\n";
|
||||
@@ -47,6 +53,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 3: 完整生命周期 — register → query → unregister → query(nullptr)
|
||||
// Test 3: full lifecycle — register → query → unregister → query(nullptr)
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -66,6 +73,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 4: unregister_service(nullptr name) 不崩溃(安全空操作)
|
||||
// Test 4: unregister_service(nullptr name) does not crash (safe no-op)
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -75,6 +83,7 @@ int main()
|
||||
|
||||
// ====================================================================
|
||||
// Test 5: 注册后重新注册同名 → 先 unregister 再 register 成功
|
||||
// Test 5: re-register same name after unregister → succeeds
|
||||
// ====================================================================
|
||||
{
|
||||
dstalk::ServiceRegistry reg;
|
||||
@@ -101,7 +110,7 @@ int main()
|
||||
}
|
||||
|
||||
// ====================================================================
|
||||
// 结果
|
||||
// 结果 / Result
|
||||
// ====================================================================
|
||||
std::cout << "\n";
|
||||
if (g_failures == 0) {
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
// ============================================================================
|
||||
// smoke_test.cpp — 插件化架构烟雾测试
|
||||
// ============================================================================
|
||||
// 测试: 核心初始化、插件加载、服务查询、file_io、session 功能
|
||||
// W13.6 (qa-xu 徐磊): 新增 R1-R4 回归保护点,覆盖 W11.7/W12 已修 bug
|
||||
// ============================================================================
|
||||
/*
|
||||
* @file smoke_test.cpp
|
||||
* @brief Basic smoke test: verifies dstalk_init/shutdown cycle, service queries,
|
||||
* file_io, session, null-safety, escape boundaries, tool chain, and
|
||||
* regression protections R1-R4 (W13.6 qa-xu).
|
||||
* 基础冒烟测试:验证 dstalk_init/shutdown 生命周期、服务查询、file_io、session、
|
||||
* 空指针安全、转义边界、工具链调用,以及回归保护 R1-R4 (W13.6 qa-xu)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
|
||||
@@ -14,6 +17,7 @@
|
||||
#include <string>
|
||||
|
||||
// ---- 回归测试断言 (W13.6 qa-xu) ----
|
||||
// Regression test assertion macro (W13.6 qa-xu): prints [OK]/[FAIL] and tracks failures
|
||||
static int g_regression_failures = 0;
|
||||
#define REGCHECK(cond, msg) do { \
|
||||
if (cond) { \
|
||||
@@ -24,19 +28,26 @@ static int g_regression_failures = 0;
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// ---- W21.5 mock tool handler (qa-xu) ----
|
||||
// W21.5 mock tool handler (qa-xu): increments call counter and returns mock result JSON
|
||||
// W21.5 模拟工具处理函数 (qa-xu):递增调用计数器并返回模拟结果 JSON
|
||||
static int g_mock_tool_called = 0;
|
||||
static char* mock_tool_handler(const char* /*args_json*/) {
|
||||
g_mock_tool_called++;
|
||||
return dstalk_strdup("{\"mock_result\":\"ok\"}");
|
||||
}
|
||||
|
||||
// 冒烟测试主流程:init → 服务查询 → file_io → session → ai → config,
|
||||
// 然后是扩展测试(空指针安全、转义边界、工具链、session 健壮性),
|
||||
// 接着是回归保护 R1-R3、W21.5 工具调用边界和 R4 生命周期循环。
|
||||
// Smoke test main: init -> service queries -> file_io -> session -> ai -> config,
|
||||
// then extended tests (null-safety, escape, tool chain, session robustness),
|
||||
// then regression protections R1-R3, W21.5 tool-call boundaries, and R4 lifecycle cycles.
|
||||
int main()
|
||||
{
|
||||
const auto dir = std::filesystem::temp_directory_path() / "dstalk-smoke-test";
|
||||
std::filesystem::create_directories(dir);
|
||||
|
||||
// 写一个配置文件用于初始化
|
||||
// 写一个配置文件用于初始化 / Write a config file for initialization
|
||||
const auto config_path = dir / "config.toml";
|
||||
{
|
||||
std::ofstream config(config_path);
|
||||
@@ -47,14 +58,14 @@ int main()
|
||||
<< "model = \"deepseek-v4-pro\"\n";
|
||||
}
|
||||
|
||||
// 初始化主机(会自动扫描 plugins/ 加载插件)
|
||||
// 初始化主机(会自动扫描 plugins/ 加载插件)/ Init host (auto-scans plugins/ to load plugins)
|
||||
if (dstalk_init(config_path.string().c_str()) != 0) {
|
||||
std::cerr << "dstalk_init failed\n";
|
||||
return 1;
|
||||
}
|
||||
std::cout << "[OK] dstalk_init succeeded\n";
|
||||
|
||||
// 验证插件列表
|
||||
// 验证插件列表 / Verify plugin list
|
||||
{
|
||||
char* list_json = nullptr;
|
||||
int ret = dstalk_plugin_list(&list_json);
|
||||
@@ -66,13 +77,13 @@ int main()
|
||||
}
|
||||
}
|
||||
|
||||
// 测试服务查询: file_io
|
||||
// 测试服务查询: file_io / Test service query: file_io
|
||||
auto* file_io = static_cast<const dstalk_file_io_service_t*>(
|
||||
dstalk_service_query("file_io", 1));
|
||||
if (file_io) {
|
||||
std::cout << "[OK] file_io service found\n";
|
||||
|
||||
// 测试写入
|
||||
// 测试写入 / Test write
|
||||
const auto file_path = dir / "sample.txt";
|
||||
constexpr const char* sample_content = "hello dstalk\nquote=\"yes\" tab=\t slash=\\";
|
||||
if (file_io->write(file_path.string().c_str(), sample_content) == 0) {
|
||||
@@ -83,7 +94,7 @@ int main()
|
||||
return 1;
|
||||
}
|
||||
|
||||
// 测试读取
|
||||
// 测试读取 / Test read
|
||||
char* content = nullptr;
|
||||
if (file_io->read(file_path.string().c_str(), &content) == 0 && content) {
|
||||
bool ok = std::strcmp(content, sample_content) == 0;
|
||||
@@ -104,13 +115,13 @@ int main()
|
||||
std::cerr << "[WARN] file_io service not found (plugin may not be in plugins/ dir)\n";
|
||||
}
|
||||
|
||||
// 测试服务查询: session
|
||||
// 测试服务查询: session / Test service query: session
|
||||
auto* session = static_cast<const dstalk_session_service_t*>(
|
||||
dstalk_service_query("session", 1));
|
||||
if (session) {
|
||||
std::cout << "[OK] session service found\n";
|
||||
|
||||
// 测试 session save/load
|
||||
// 测试 session save/load / Test session save/load
|
||||
const auto session_path = dir / "session.jsonl";
|
||||
const auto saved_path = dir / "session-saved.jsonl";
|
||||
constexpr const char* session_content =
|
||||
@@ -137,7 +148,7 @@ int main()
|
||||
return 1;
|
||||
}
|
||||
|
||||
// 验证保存的内容
|
||||
// 验证保存的内容 / Verify saved content
|
||||
if (file_io) {
|
||||
char* saved = nullptr;
|
||||
if (file_io->read(saved_path.string().c_str(), &saved) == 0 && saved) {
|
||||
@@ -153,16 +164,16 @@ int main()
|
||||
}
|
||||
}
|
||||
|
||||
// 测试 token 计数
|
||||
// 测试 token 计数 / Test token count
|
||||
int tokens = session->token_count();
|
||||
std::cout << "[OK] session->token_count: " << tokens << "\n";
|
||||
|
||||
// 测试 history
|
||||
// 测试 history / Test history
|
||||
int count = 0;
|
||||
session->history(&count);
|
||||
std::cout << "[OK] session->history count: " << count << "\n";
|
||||
|
||||
// 测试 clear
|
||||
// 测试 clear / Test clear
|
||||
session->clear();
|
||||
session->history(&count);
|
||||
if (count == 0) {
|
||||
@@ -173,6 +184,7 @@ int main()
|
||||
}
|
||||
|
||||
// 测试服务查询: ai(可能因为没有真实 API key 而失败,但服务应存在)
|
||||
// Test service query: ai (may fail without real API key, but service should exist)
|
||||
const char* ai_provider = dstalk_config_get("ai.provider");
|
||||
if (!ai_provider) ai_provider = "ai.deepseek";
|
||||
auto* ai = static_cast<const dstalk_ai_service_t*>(
|
||||
@@ -183,7 +195,7 @@ int main()
|
||||
std::cerr << "[WARN] ai service not found\n";
|
||||
}
|
||||
|
||||
// 测试服务查询: config
|
||||
// 测试服务查询: config / Test service query: config
|
||||
auto* config_svc = static_cast<const dstalk_config_service_t*>(
|
||||
dstalk_service_query("config", 1));
|
||||
if (config_svc) {
|
||||
@@ -196,21 +208,22 @@ int main()
|
||||
std::cerr << "[WARN] config service not found\n";
|
||||
}
|
||||
|
||||
// 测试 dstalk_config_get(主机级配置 API)
|
||||
// 测试 dstalk_config_get(主机级配置 API)/ Test dstalk_config_get (host-level config API)
|
||||
const char* model = dstalk_config_get("api.model");
|
||||
if (model) {
|
||||
std::cout << "[OK] dstalk_config_get(\"api.model\"): " << model << "\n";
|
||||
}
|
||||
|
||||
// 测试 dstalk_log
|
||||
// 测试 dstalk_log / Test dstalk_log
|
||||
dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully");
|
||||
|
||||
// ========================================================================
|
||||
// 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性
|
||||
// Extended test block C2: null-safety / escape boundaries / tools chain / session robustness
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Extended Smoke Tests (C2) ---\n";
|
||||
|
||||
// 提前查询 tools 服务,供后续测试块使用
|
||||
// 提前查询 tools 服务,供后续测试块使用 / Pre-query tools service for subsequent test blocks
|
||||
auto* tools = static_cast<const dstalk_tools_service_t*>(
|
||||
dstalk_service_query("tools", 1));
|
||||
|
||||
@@ -234,7 +247,7 @@ int main()
|
||||
std::cerr << "[FAIL] file_io->write(nullptr, ...) should return error\n";
|
||||
}
|
||||
|
||||
// read 的 content 参数也为 null
|
||||
// read 的 content 参数也为 null / read's content param also null
|
||||
ret = file_io->read("dummy_path", nullptr);
|
||||
if (ret != 0) {
|
||||
std::cout << "[OK] file_io->read(path, nullptr) returned error (" << ret << ")\n";
|
||||
@@ -242,7 +255,7 @@ int main()
|
||||
std::cerr << "[FAIL] file_io->read(path, nullptr) should return error\n";
|
||||
}
|
||||
|
||||
// write 的 content 参数为 null
|
||||
// write 的 content 参数为 null / write's content param is null
|
||||
ret = file_io->write("dummy_path", nullptr);
|
||||
if (ret != 0) {
|
||||
std::cout << "[OK] file_io->write(path, nullptr) returned error (" << ret << ")\n";
|
||||
@@ -278,6 +291,7 @@ int main()
|
||||
char* result = tools->execute(nullptr, nullptr);
|
||||
if (result) {
|
||||
// 实现返回了错误字符串(如 {"error":"tool name is null"}),未崩溃
|
||||
// Implementation returned error string (e.g. {"error":"tool name is null"}), no crash
|
||||
std::cout << "[OK] tools->execute(nullptr, nullptr) did not crash"
|
||||
<< " (returned: " << result << ")\n";
|
||||
dstalk_free(result);
|
||||
@@ -303,7 +317,7 @@ int main()
|
||||
std::cerr << "[FAIL] config->set(nullptr, nullptr) should return error\n";
|
||||
}
|
||||
|
||||
// set 的 value 为 null
|
||||
// set 的 value 为 null / set's value is null
|
||||
ret = config_svc->set("some.key", nullptr);
|
||||
if (ret != 0) {
|
||||
std::cout << "[OK] config->set(key, nullptr) returned error (" << ret << ")\n";
|
||||
@@ -316,6 +330,8 @@ int main()
|
||||
|
||||
// ---- 2. 转义边界测试 ----
|
||||
// 写入含特殊字符的内容,读回后验证内容一致
|
||||
// ---- Escape boundary tests ----
|
||||
// Write content with special chars, verify round-trip integrity
|
||||
std::cout << "\n[Block] Escape boundary tests\n";
|
||||
|
||||
if (file_io) {
|
||||
@@ -325,6 +341,12 @@ int main()
|
||||
// - 实际反斜杠 (0x5C)
|
||||
// - 实际制表符 (0x09)
|
||||
// - 以及字面上的 \n \" \\ \t 转义序列文本
|
||||
// Build content with various special bytes:
|
||||
// - literal newline (0x0A)
|
||||
// - literal double-quote (0x22)
|
||||
// - literal backslash (0x5C)
|
||||
// - literal tab (0x09)
|
||||
// - plus textual \n \" \\ \t escape sequences
|
||||
constexpr const char* escape_content =
|
||||
"line1\nline2\n"
|
||||
"quote=\"yes\"\n"
|
||||
@@ -363,22 +385,25 @@ int main()
|
||||
|
||||
// ---- 3. Tools 调用链测试 ----
|
||||
// 通过 tools->execute("file_read", ...) 验证内置工具可正确调用 file_io
|
||||
// ---- Tools call chain tests ----
|
||||
// Verify built-in tools correctly call file_io via tools->execute("file_read", ...)
|
||||
std::cout << "\n[Block] Tools call chain tests\n";
|
||||
|
||||
if (tools && file_io) {
|
||||
// 准备测试文件
|
||||
// 准备测试文件 / Prepare test file
|
||||
const auto chain_path = dir / "tool_chain_test.txt";
|
||||
constexpr const char* chain_content = "tools-chain-ok\n";
|
||||
file_io->write(chain_path.string().c_str(), chain_content);
|
||||
|
||||
// 用 generic_string() 获取正斜杠路径,避免 JSON 中反斜杠转义问题
|
||||
// Use generic_string() for forward-slash paths to avoid backslash escaping in JSON
|
||||
std::string generic_path = chain_path.generic_string();
|
||||
std::string args_json = "{\"path\":\"" + generic_path + "\"}";
|
||||
|
||||
char* result = tools->execute("file_read", args_json.c_str());
|
||||
if (result) {
|
||||
std::cout << "[OK] tools->execute(\"file_read\", ...) returned result\n";
|
||||
// 验证返回的 JSON 中包含原始文件内容
|
||||
// 验证返回的 JSON 中包含原始文件内容 / Verify returned JSON contains original file content
|
||||
if (std::strstr(result, "tools-chain-ok")) {
|
||||
std::cout << "[OK] tools->execute chain correctly called file_io\n";
|
||||
} else {
|
||||
@@ -391,7 +416,7 @@ int main()
|
||||
<< " (tool may not be registered)\n";
|
||||
}
|
||||
|
||||
// 额外测试:查询 tools 返回的工具列表
|
||||
// 额外测试:查询 tools 返回的工具列表 / Additional test: query tools list
|
||||
char* tools_json = tools->get_tools_json();
|
||||
if (tools_json) {
|
||||
std::cout << "[OK] tools->get_tools_json() returned: " << tools_json << "\n";
|
||||
@@ -406,14 +431,17 @@ int main()
|
||||
// ---- 4. Session 健壮性测试 ----
|
||||
// session->add(nullptr) 后验证 history 不变
|
||||
// session->clear 后验证 token_count 为 0
|
||||
// ---- Session robustness tests ----
|
||||
// Verify history unchanged after session->add(nullptr)
|
||||
// Verify token_count == 0 after session->clear
|
||||
std::cout << "\n[Block] Session robustness tests\n";
|
||||
|
||||
if (session) {
|
||||
// 记录 add(nullptr) 前的 history 计数
|
||||
// 记录 add(nullptr) 前的 history 计数 / Record history count before add(nullptr)
|
||||
int count_before = 0;
|
||||
session->history(&count_before);
|
||||
|
||||
// 传 null 不应改变 history
|
||||
// 传 null 不应改变 history / Passing null should not change history
|
||||
session->add(nullptr);
|
||||
|
||||
int count_after = 0;
|
||||
@@ -427,7 +455,7 @@ int main()
|
||||
<< count_before << " -> " << count_after << "\n";
|
||||
}
|
||||
|
||||
// clear 后 token_count 应为 0
|
||||
// clear 后 token_count 应为 0 / token_count should be 0 after clear
|
||||
session->clear();
|
||||
int tokens = session->token_count();
|
||||
if (tokens == 0) {
|
||||
@@ -443,6 +471,8 @@ int main()
|
||||
// ========================================================================
|
||||
// W13.6 回归保护点 R1-R3 (qa-xu 徐磊)
|
||||
// 覆盖: W11.7 BUG-2/3/4 + W11.1 Discovery 2/3 + W12.2/W12.3 修复
|
||||
// W13.6 regression protections R1-R3 (qa-xu)
|
||||
// Covers: W11.7 BUG-2/3/4 + W11.1 Discovery 2/3 + W12.2/W12.3 fixes
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Regression Tests (R1-R3: W11.7/W12 bug protection) ---\n";
|
||||
|
||||
@@ -450,6 +480,10 @@ int main()
|
||||
// 回归: W11.1 Discovery 3 (g_max_tokens 死变量 — W12.3 已修, W18.1 彻底移除)
|
||||
// W11.7 BUG-3 (/context 静默 — W12.3 已修)
|
||||
// 验证: trim 能正确裁剪消息数,调用链完整不崩溃
|
||||
// ---- R1: context max_tokens takes effect ----
|
||||
// Regression: W11.1 Discovery 3 (g_max_tokens dead var — fixed W12.3, removed W18.1)
|
||||
// W11.7 BUG-3 (/context silent — fixed W12.3)
|
||||
// Verify: trim reduces message count correctly, full call chain without crash
|
||||
{
|
||||
auto* ctx = static_cast<const dstalk_context_service_t*>(
|
||||
dstalk_service_query("context", 1));
|
||||
@@ -457,6 +491,7 @@ int main()
|
||||
std::cout << "[OK] R1: context service found\n";
|
||||
|
||||
// 构造 5 条消息,每条 ~50 字符 / ~15 token,总计 ~75 token > 50 max
|
||||
// Build 5 messages, each ~50 chars / ~15 tokens, total ~75 tokens > 50 max
|
||||
dstalk_message_t msgs[5];
|
||||
msgs[0] = {"user", "Hello this is message one with enough text to count tokens", nullptr, nullptr};
|
||||
msgs[1] = {"assistant", "Message two also has sufficient length for token counting", nullptr, nullptr};
|
||||
@@ -476,6 +511,7 @@ int main()
|
||||
dstalk_free(out);
|
||||
} else if (ret >= 0) {
|
||||
// 首条消息即超 max_tokens 时 trim 可能返回空,这也是合法路径
|
||||
// When first message exceeds max_tokens, trim may return empty; also valid
|
||||
std::cout << "[WARN] R1: trim returned null output (single msg exceeds max?)\n";
|
||||
}
|
||||
} else {
|
||||
@@ -487,15 +523,19 @@ int main()
|
||||
// 回归: W11.2 Discovery 2 (双 ConfigStore 数据孤岛 — W12.2 已修)
|
||||
// W11.2 Discovery 3 (c_str() 悬垂 — W12.2 已修)
|
||||
// 验证: dstalk_config_set 写入后,dstalk_config_get 和 config_service->get 返回一致值
|
||||
// ---- R2: config dual-store consistency ----
|
||||
// Regression: W11.2 Discovery 2 (dual ConfigStore islands — fixed W12.2)
|
||||
// W11.2 Discovery 3 (c_str() dangling — fixed W12.2)
|
||||
// Verify: after dstalk_config_set write, dstalk_config_get and config_service->get return same value
|
||||
{
|
||||
constexpr const char* k = "__regr_w13_6_dual";
|
||||
constexpr const char* v = "dual_ok_42";
|
||||
|
||||
// 通过 host API 写入
|
||||
// 通过 host API 写入 / Write via host API
|
||||
int set_ret = dstalk_config_set(k, v);
|
||||
REGCHECK(set_ret == 0, "R2: dstalk_config_set returned 0");
|
||||
|
||||
// 通过 host API 读回
|
||||
// 通过 host API 读回 / Read back via host API
|
||||
const char* host_val = dstalk_config_get(k);
|
||||
REGCHECK(host_val && std::strcmp(host_val, v) == 0,
|
||||
"R2: dstalk_config_get matches written value");
|
||||
@@ -503,6 +543,9 @@ int main()
|
||||
// 通过 plugin config 服务读回 — 验证双 store 整合后数据可见性一致
|
||||
// 注: W12.2 双 store 整合尚未部署,跨 store 可见性当前为已知 gap;
|
||||
// 本检查用 WARN 记录现状,待 W12.2 fix 落地后改为 REGCHECK
|
||||
// Read back via plugin config service — verify visibility after dual-store merge
|
||||
// Note: W12.2 dual-store merge not yet deployed; cross-store visibility is a known gap;
|
||||
// this check uses WARN to record status, upgrade to REGCHECK after W12.2 lands
|
||||
auto* cfg_svc = static_cast<const dstalk_config_service_t*>(
|
||||
dstalk_service_query("config", 1));
|
||||
if (cfg_svc) {
|
||||
@@ -520,7 +563,7 @@ int main()
|
||||
std::cerr << "[WARN] R2: config service not found, partial skip\n";
|
||||
}
|
||||
|
||||
// 清理测试 key
|
||||
// 清理测试 key / Clean up test key
|
||||
dstalk_config_set(k, "");
|
||||
}
|
||||
|
||||
@@ -529,6 +572,11 @@ int main()
|
||||
// W11.7 BUG-4 (/file write 落空) 同类的错误路径静默问题
|
||||
// 验证: http post_json 到不可达目标返回错误而不崩溃;
|
||||
// 若 http 服务不可用,回退测 ai 服务错误路径
|
||||
// ---- R3: HTTP / AI service error paths do not crash ----
|
||||
// Regression: W12.1 removed TLS/http_client code (removed rewritten network layer)
|
||||
// W11.7 BUG-4 (/file write miss) similar error-path silent issues
|
||||
// Verify: http post_json to unreachable target returns error without crash;
|
||||
// fall back to ai service error path if http unavailable
|
||||
{
|
||||
auto* http = static_cast<const dstalk_http_service_t*>(
|
||||
dstalk_service_query("http", 1));
|
||||
@@ -536,6 +584,8 @@ int main()
|
||||
std::cout << "[OK] R3: http service found\n";
|
||||
// 向 127.0.0.1:1 发请求 — 端口 1 在 Windows 上几乎肯定无服务监听
|
||||
// 连接拒绝应立即返回错误而非崩溃
|
||||
// Send request to 127.0.0.1:1 — port 1 on Windows almost certainly has no listener
|
||||
// Connection refused should return error immediately, not crash
|
||||
char* body = nullptr;
|
||||
int status = 0;
|
||||
int ret = http->post_json("127.0.0.1", "1", "/",
|
||||
@@ -549,6 +599,7 @@ int main()
|
||||
}
|
||||
} else {
|
||||
// 回退:测 AI 服务 (ai.deepseek) 错误路径
|
||||
// Fallback: test AI service (ai.deepseek) error path
|
||||
auto* ai_svc = static_cast<const dstalk_ai_service_t*>(
|
||||
dstalk_service_query("ai.deepseek", 1));
|
||||
if (ai_svc) {
|
||||
@@ -556,6 +607,7 @@ int main()
|
||||
dstalk_message_t msg = {"user", "hi", nullptr, nullptr};
|
||||
dstalk_chat_result_t r = ai_svc->chat(&msg, 1, "", nullptr);
|
||||
// api_key="test-key" 为无效 key,应返回 error result 而非崩溃
|
||||
// api_key="test-key" is invalid, should return error result, not crash
|
||||
REGCHECK(r.ok == 0 || r.error != nullptr,
|
||||
"R3: ai->chat with invalid key returned error result (no crash)");
|
||||
if (r.content) dstalk_free((void*)r.content);
|
||||
@@ -570,11 +622,14 @@ int main()
|
||||
// ========================================================================
|
||||
// W21.5 Tool Calls 边界测试 (qa-xu 徐磊)
|
||||
// 覆盖: null tool_calls_json / 空数组 "[]" / 有效 tool_calls mock 验证
|
||||
// W21.5 Tool Calls boundary tests (qa-xu)
|
||||
// Covers: null tool_calls_json / empty array "[]" / valid tool_calls mock verification
|
||||
// ========================================================================
|
||||
std::cout << "\n--- Tool Calls Boundary Tests (W21.5) ---\n";
|
||||
|
||||
if (tools && session) {
|
||||
// ---- W21.5-1: null tool_calls_json → 正常处理(不崩溃)----
|
||||
// ---- W21.5-1: null tool_calls_json → handle normally (no crash) ----
|
||||
{
|
||||
int before = 0;
|
||||
session->history(&before);
|
||||
@@ -595,6 +650,7 @@ int main()
|
||||
}
|
||||
|
||||
// ---- W21.5-2: 空 JSON 数组 "[]" → 正常处理(不崩溃)----
|
||||
// ---- W21.5-2: empty JSON array "[]" → handle normally (no crash) ----
|
||||
{
|
||||
int before = 0;
|
||||
session->history(&before);
|
||||
@@ -616,6 +672,7 @@ int main()
|
||||
}
|
||||
|
||||
// ---- W21.5-3: 有效 tool_calls JSON → 验证 execute 被调用 (mock) ----
|
||||
// ---- W21.5-3: valid tool_calls JSON → verify execute is called (mock) ----
|
||||
{
|
||||
g_mock_tool_called = 0;
|
||||
int reg = tools->register_tool(
|
||||
@@ -638,6 +695,7 @@ int main()
|
||||
tools->unregister_tool("__w21_5_mock");
|
||||
|
||||
// 验证已注销的工具返回 error 而非崩溃
|
||||
// Verify unregistered tool returns error, not crash
|
||||
char* err_result = tools->execute("__w21_5_mock", "{}");
|
||||
REGCHECK(err_result && std::strstr(err_result, "error") != nullptr,
|
||||
"W21.5-3d: unregistered tool returns error (not crash)");
|
||||
@@ -645,6 +703,7 @@ int main()
|
||||
}
|
||||
|
||||
// ---- W21.5-4: save/load 往返保留 tool_calls_json ----
|
||||
// ---- W21.5-4: save/load round-trip preserves tool_calls_json ----
|
||||
if (file_io) {
|
||||
const auto rtt_path = dir / "w21_5_tc_rtt.jsonl";
|
||||
int ret = session->save(rtt_path.string().c_str());
|
||||
@@ -664,23 +723,28 @@ int main()
|
||||
std::cerr << "[WARN] W21.5: tools or session service not available\n";
|
||||
}
|
||||
|
||||
// 清理
|
||||
// 清理 / Cleanup
|
||||
dstalk_shutdown();
|
||||
std::cout << "[OK] dstalk_shutdown succeeded\n";
|
||||
|
||||
// ========================================================================
|
||||
// W13.6 回归保护点 R4 (qa-xu 徐磊)
|
||||
// W13.6 regression protection R4 (qa-xu)
|
||||
// ========================================================================
|
||||
|
||||
// ---- R4: 重复 init / shutdown 生命周期 ----
|
||||
// 回归: W9.8 initialize_all 容错 (插件生命周期健壮性)
|
||||
// W11.7 BUG-1 [CRITICAL] build/bin/ 损坏副本 (stale state 残留)
|
||||
// 验证: 多次 dstalk_init/dstalk_shutdown 循环不崩溃,每次 reload 正常
|
||||
// ---- R4: repeat init/shutdown lifecycle ----
|
||||
// Regression: W9.8 initialize_all fault tolerance (plugin lifecycle robustness)
|
||||
// W11.7 BUG-1 [CRITICAL] build/bin/ corrupt copy (stale state residue)
|
||||
// Verify: multiple dstalk_init/dstalk_shutdown cycles without crash, each reload ok
|
||||
{
|
||||
std::cout << "\n[Block] R4: Repeat init/shutdown lifecycle\n";
|
||||
constexpr int cycles = 3;
|
||||
for (int i = 0; i < cycles; i++) {
|
||||
// 每轮重写配置(模拟独立启动)
|
||||
// 每轮重写配置(模拟独立启动)/ Rewrite config each cycle (simulate independent start)
|
||||
{
|
||||
std::ofstream c(config_path);
|
||||
c << "[api]\n"
|
||||
@@ -700,7 +764,7 @@ int main()
|
||||
break;
|
||||
}
|
||||
|
||||
// 快速验证服务可用
|
||||
// 快速验证服务可用 / Quick verify service is available
|
||||
void* q = dstalk_service_query("config", 1);
|
||||
REGCHECK(q != nullptr, "R4: service query ok after init");
|
||||
|
||||
@@ -710,7 +774,7 @@ int main()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 最终结果 ----
|
||||
// ---- 最终结果 / Final result ----
|
||||
std::cout << "\n";
|
||||
if (g_regression_failures == 0) {
|
||||
std::cout << "=== All smoke tests passed ===\n";
|
||||
|
||||
0
模块目录和功能说明.md
Normal file
0
模块目录和功能说明.md
Normal file
39
说明此文件不可AI修改.txt
Normal file
39
说明此文件不可AI修改.txt
Normal file
@@ -0,0 +1,39 @@
|
||||
此文件不可AI修改!
|
||||
此文件不可AI修改!
|
||||
此文件不可AI修改!
|
||||
|
||||
说明:此文件包含重要信息,禁止使用AI进行修改。
|
||||
|
||||
dstalk名称中的ds来源于deepseek和display,希望能够实现交流的简洁,为人类提供更好的服务。
|
||||
|
||||
|
||||
dstalk是基于C/C++开发的基础cli接口,提供高性能接口以方便开发者进行二次开发,同时提供了丰富的接口以满足不同的需求。
|
||||
dstalk基于多模块的自动加载和自我依赖,完全自下而上设计,
|
||||
所有基于dstalk的模块均可以跟随dstalk实现跨平台接入和相互依赖,自动加载等。
|
||||
dstalk更像是一种编程框架和基础系统支持。
|
||||
|
||||
dstalk本质上只有dstalk网关是核心,所有模块的沟通和相互依赖以及使用都是通过dstalk网关来实现的。
|
||||
|
||||
dstalk提供以下默认安装的模块,方便大多数模块的开发和使用:
|
||||
1. 基础模块:提供了基础的输入输出、文件操作、网络通信等接口,方便开发者进行二次开发。
|
||||
2. 模块管理模块:提供了模块化的接口,方便开发者进行模块化开发和管理。
|
||||
3. AI接入Openai兼容格式模块:提供了AI接入openai兼容格式的接口,方便开发者进行AI接入和使用。
|
||||
4. AI接入Anthropic兼容格式模块:提供了AI接入Anthropic兼容格式的接口,方便开发者进行AI接入和使用。
|
||||
5. Anthropic和Openai兼容格式转换模块:提供了Anthropic和Openai兼容格式转换的接口,方便开发者进行格式转换和使用。
|
||||
|
||||
依赖于基础模块的模块:
|
||||
1. AI接入自动识别接口:依赖于基础模块,提供了AI接入所有格式的接口,方便开发者进行AI接入和使用。
|
||||
|
||||
|
||||
|
||||
|
||||
dstalk类似openclaw的geteway可以实现相互的沟通和控制,一切操作基于dstalk,源于模块,用于模块,服务于模块。
|
||||
dstalk的设计理念是模块化、自动加载、自我依赖、跨平台、易用性和高性能,旨在为开发者提供一个强大而灵活的基础系统支持,
|
||||
以便他们能够专注于开发自己的应用程序,而不必担心底层的细节。
|
||||
dstalk的目标是成为一个强大而灵活的操作系统,帮助开发者更高效地开发应用程序,以满足不同的需求。
|
||||
|
||||
|
||||
newapi测试数据:
|
||||
|
||||
"key":"sk-DWiHMg4T3cIxWUSwRGtjLuPe1c8FuwM0FiGyoyuNFWGpkhjY"
|
||||
"url":"https://api.ai.pulsareon.com"
|
||||
Reference in New Issue
Block a user