Add metadata validation script and module documentation
- Introduced a new Python script `check_agents_metadata.py` for validating agent metadata, including YAML parsing, rating ranges, and cross-references.
- Added usage instructions and exit codes for the script.
- Created a new markdown file `模块目录和功能说明.md` to outline the directory structure and functionality of the modules.
- Added a text file `说明此文件不可AI修改.txt` to specify that certain files should not be modified by AI, including important information about the `dstalk` framework and its modules.
This commit is contained in:
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file anthropic_plugin.cpp
|
||||
* @brief Anthropic Claude Messages API provider plugin with streaming support.
|
||||
* Anthropic Claude Messages API 提供者插件,支持流式输出。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -11,14 +18,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写
|
||||
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 / Global pointers — W17.4: std::atomic protects concurrent read/write between on_shutdown and service functions
|
||||
// ============================================================================
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
static dstalk_config_service_t* g_config = nullptr;
|
||||
|
||||
// ============================================================================
|
||||
// 配置数据
|
||||
// 配置数据 / Config data
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
@@ -29,19 +36,21 @@ struct PluginConfig {
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
static std::string g_tools_json; // W21.2: cached by configure(), consumed by chat/chat_stream
|
||||
static std::string g_tools_json; // W21.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
|
||||
// Overwrite the n bytes starting at p with zeros.
// Every store goes through a volatile-qualified pointer so the compiler may
// not discard the writes as dead stores, even if the buffer is never read
// again afterwards.
// A null p is treated as a no-op regardless of n.
static void secure_zero(void* p, size_t n) {
    if (p == nullptr) {
        return;
    }
    volatile unsigned char* cursor = static_cast<volatile unsigned char*>(p);
    for (; n != 0; --n) {
        *cursor++ = 0;
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:提取 host / target
|
||||
// 辅助:提取 host / target / Helper: extract host / target
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 Anthropic headers JSON
|
||||
// 构建 Anthropic headers JSON / Build Anthropic headers JSON
|
||||
// ============================================================================
|
||||
// 构建包含 x-api-key 和 anthropic-version 的 JSON headers 对象 / Build the JSON headers object containing x-api-key and anthropic-version.
|
||||
static std::string build_headers_json()
|
||||
{
|
||||
json::object h;
|
||||
@@ -76,8 +86,11 @@ static std::string build_headers_json()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 Anthropic JSON 请求体
|
||||
// 构建 Anthropic JSON 请求体 / Build Anthropic JSON request body
|
||||
// ============================================================================
|
||||
// 构建 Anthropic Messages API 的完整 JSON 请求体。
|
||||
// 按 Anthropic 规范将 system 消息提取为顶层 system 字段 / Build the full JSON request body for the Anthropic Messages API.
|
||||
// Extracts system messages as a top-level "system" field per Anthropic spec.
|
||||
static std::string build_request_json(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const std::string& user_input,
|
||||
@@ -89,7 +102,7 @@ static std::string build_request_json(
|
||||
root["max_tokens"] = g_cfg.max_tokens;
|
||||
root["stream"] = stream;
|
||||
|
||||
// 提取 system 消息作为顶层字段
|
||||
// 提取 system 消息作为顶层字段 / Extract system messages as top-level field
|
||||
std::string system_prompt;
|
||||
json::array msgs;
|
||||
|
||||
@@ -106,7 +119,7 @@ static std::string build_request_json(
|
||||
msgs.push_back(obj);
|
||||
}
|
||||
|
||||
// 追加当前用户输入
|
||||
// 追加当前用户输入 / Append current user input
|
||||
{
|
||||
json::object obj;
|
||||
obj["role"] = "user";
|
||||
@@ -124,7 +137,7 @@ static std::string build_request_json(
|
||||
root["temperature"] = g_cfg.temperature;
|
||||
}
|
||||
|
||||
// W21.2: tools 定义传递给 API
|
||||
// W21.2: tools 定义传递给 API / Pass tools definition to API
|
||||
if (!tools_json.empty()) {
|
||||
root["tools"] = json::parse(tools_json);
|
||||
}
|
||||
@@ -133,8 +146,11 @@ static std::string build_request_json(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 解析非流式响应
|
||||
// 解析非流式响应 / Parse non-streaming response
|
||||
// ============================================================================
|
||||
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t。
|
||||
// 处理 text 和 tool_use content block,将 tool_use 转换为 OpenAI 格式 / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
|
||||
// Handles text and tool_use content blocks, converting tool_use to OpenAI format.
|
||||
static void parse_response(const char* body, int http_status,
|
||||
dstalk_chat_result_t& r)
|
||||
{
|
||||
@@ -169,7 +185,7 @@ static void parse_response(const char* body, int http_status,
|
||||
auto obj = jv.as_object();
|
||||
auto content = obj["content"].as_array();
|
||||
if (!content.empty()) {
|
||||
// W21.2: 提取 text 和 tool_use content blocks
|
||||
// W21.2: 提取 text 和 tool_use content blocks / Extract text and tool_use content blocks
|
||||
std::string text_content;
|
||||
json::array tool_use_blocks;
|
||||
|
||||
@@ -181,7 +197,7 @@ static void parse_response(const char* body, int http_status,
|
||||
if (btype == "text") {
|
||||
text_content = json::value_to<std::string>(bobj["text"]);
|
||||
} else if (btype == "tool_use") {
|
||||
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}}
|
||||
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}} / Convert to OpenAI-compatible format: {id, type:"function", function:{name, arguments}}
|
||||
json::object tc;
|
||||
tc["id"] = bobj["id"];
|
||||
tc["type"] = "function";
|
||||
@@ -206,7 +222,7 @@ static void parse_response(const char* body, int http_status,
|
||||
r.error = nullptr;
|
||||
return;
|
||||
} else if (!tool_use_blocks.empty()) {
|
||||
// tool-only 响应
|
||||
// tool-only 响应 / tool-only response
|
||||
r.content = nullptr;
|
||||
r.ok = 1;
|
||||
r.error = nullptr;
|
||||
@@ -235,15 +251,15 @@ static void parse_response(const char* body, int http_status,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE 事件解析(Anthropic 格式: event/content_block_delta)
|
||||
// SSE 事件解析(Anthropic 格式: event/content_block_delta) / SSE event parsing (Anthropic format: event/content_block_delta)
|
||||
// ============================================================================
|
||||
|
||||
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量
|
||||
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量 / Accumulate Anthropic tool_use increments by content_block index
|
||||
// Accumulator for one streamed Anthropic tool_use content block.
// Entries are keyed by content_block index.
struct ToolCallAccum {
    int index = -1;         // -1 until an index has been stored
    std::string id;
    std::string name;
    std::string arguments;  // accumulated from input_json_delta.partial_json
};
|
||||
|
||||
struct StreamContext {
|
||||
@@ -252,10 +268,15 @@ struct StreamContext {
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool saw_data_line = false;
|
||||
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks
|
||||
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks / accumulate tool_use content blocks by index
|
||||
};
|
||||
|
||||
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析
|
||||
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 / Parse Anthropic SSE events with tool_use content_block incremental parsing
|
||||
// 解析单个 Anthropic SSE "data:" JSON 事件。处理 content_block_start、
|
||||
// content_block_delta (text_delta/input_json_delta) 和 message_stop。
|
||||
// 如果产生了 content token 则返回 true,否则返回 false / Parse a single Anthropic SSE "data:" JSON event. Handles content_block_start,
|
||||
// content_block_delta (text_delta/input_json_delta), and message_stop.
|
||||
// Returns true if a content token was produced, false otherwise.
|
||||
static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
{
|
||||
@@ -268,7 +289,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
std::string type = json::value_to<std::string>(*type_ptr);
|
||||
|
||||
if (type == "content_block_start") {
|
||||
// content_block_start 可能为 tool_use
|
||||
// content_block_start 可能为 tool_use / content_block_start may be tool_use
|
||||
auto* cb = obj.if_contains("content_block");
|
||||
if (!cb || !cb->is_object()) return false;
|
||||
auto& cb_obj = cb->as_object();
|
||||
@@ -311,7 +332,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
return true;
|
||||
}
|
||||
} else if (delta_type == "input_json_delta" && ctx) {
|
||||
// W21.2: 累积 tool_use arguments 分片
|
||||
// W21.2: 累积 tool_use arguments 分片 / Accumulate tool_use arguments fragments
|
||||
auto* pj = dobj.if_contains("partial_json");
|
||||
if (pj && pj->is_string()) {
|
||||
auto* idx_ptr = obj.if_contains("index");
|
||||
@@ -326,18 +347,19 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
}
|
||||
} else if (type == "message_stop") {
|
||||
token_out.clear();
|
||||
return true; // 流结束
|
||||
return true; // 流结束 / stream end
|
||||
}
|
||||
// 忽略: message_start, content_block_stop, ping, message_delta
|
||||
// 忽略: message_start, content_block_stop, ping, message_delta / Ignore: message_start, content_block_stop, ping, message_delta
|
||||
} catch (...) {
|
||||
// 解析失败忽略
|
||||
// 解析失败忽略 / Ignore parse failures
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// configure
|
||||
// configure
|
||||
// ============================================================================
|
||||
// 配置插件:provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
|
||||
static int my_configure(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature)
|
||||
@@ -352,7 +374,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) {
|
||||
// W21.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用
|
||||
// W21.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
h->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
@@ -381,8 +403,9 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat
|
||||
// chat
|
||||
// ============================================================================
|
||||
// 非流式 chat completion:发送 history + user input,返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
|
||||
static dstalk_chat_result_t my_chat(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -447,26 +470,27 @@ static dstalk_chat_result_t my_chat(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat_stream
|
||||
// chat_stream
|
||||
// ============================================================================
|
||||
|
||||
// 行回调
|
||||
// 行回调 / SSE line callback
|
||||
// SSE 行回调:解析每个 Anthropic SSE 行并将文本 token 转发给用户 / SSE line callback: parses each Anthropic SSE line and forwards text tokens to user.
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
|
||||
// SSE 格式: "data: <json>"
|
||||
// SSE 格式: "data: <json>" / SSE format: "data: <json>"
|
||||
if (line_str.rfind("data: ", 0) == 0) {
|
||||
std::string data = line_str.substr(6);
|
||||
std::string token;
|
||||
if (parse_sse_data(data, token, ctx)) {
|
||||
ctx->saw_data_line = true;
|
||||
if (token.empty()) {
|
||||
// message_stop
|
||||
// message_stop / message_stop
|
||||
return 0;
|
||||
}
|
||||
ctx->accumulated += token;
|
||||
@@ -475,7 +499,7 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
}
|
||||
// "event: ..." 行和其他 -> 忽略
|
||||
// "event: ..." 行和其他 -> 忽略 / "event: ..." lines and others -> ignored
|
||||
return 1;
|
||||
} catch (const std::exception& e) {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -488,6 +512,9 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
|
||||
// 流式 chat completion:以 stream=true 发送 history + user input,通过回调传递 token。
|
||||
// 累积 tool_use blocks 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
|
||||
// via callback. Accumulates tool_use blocks and serializes them at end.
|
||||
static dstalk_chat_result_t my_chat_stream(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -531,7 +558,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
r.http_status = status_code;
|
||||
|
||||
// 检查错误状态
|
||||
// 检查错误状态 / Check error status
|
||||
if (status_code < 200 || status_code >= 300) {
|
||||
r.ok = 0;
|
||||
if (response_body && response_body[0]) {
|
||||
@@ -560,7 +587,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
if (response_body) host->free(response_body);
|
||||
|
||||
// W21.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling)
|
||||
// W21.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling) / Success = has content OR has tool_calls (tool-only responses like function calling)
|
||||
bool has_content = !ctx.accumulated.empty();
|
||||
bool has_tool_calls = !ctx.tool_calls.empty();
|
||||
|
||||
@@ -575,7 +602,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
r.content = has_content
|
||||
? host->strdup(ctx.accumulated.c_str()) : nullptr;
|
||||
|
||||
// W21.2: 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式)
|
||||
// W21.2: 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
@@ -614,8 +641,9 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// free_result
|
||||
// free_result
|
||||
// ============================================================================
|
||||
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -626,7 +654,7 @@ static void my_free_result(dstalk_chat_result_t* result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
static dstalk_ai_service_t g_service = {
|
||||
&my_configure,
|
||||
@@ -636,8 +664,9 @@ static dstalk_ai_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期
|
||||
// 生命周期 / Lifecycle
|
||||
// ============================================================================
|
||||
// 插件初始化:查询 http 和 config 服务,注册 ai.anthropic 服务 / Plugin init: query http and config services, register ai.anthropic service.
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
@@ -666,6 +695,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:从内存安全擦除 API key,清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
@@ -686,12 +716,12 @@ static void on_shutdown()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "anthropic-ai",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "Anthropic Claude AI provider (Messages API)",
|
||||
/* .description = */ "Anthropic Claude AI provider (Messages API) / Anthropic Claude AI 提供者 (Messages API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "http", "config", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
@@ -699,6 +729,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
/*
|
||||
* @file toml_parse.h
|
||||
* @brief Lightweight single-header TOML parser (subset: flat key-value pairs).
|
||||
* 轻量级单头文件 TOML 解析器(子集:扁平键值对)。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
// Shared TOML parser — used by both ConfigStore (core) and config plugin.
|
||||
// 共享 TOML 解析器 —— 由 ConfigStore(核心)和 config 插件共同使用 / Shared TOML parser — used by both ConfigStore (core) and config plugin.
|
||||
// W12.2: Extracted from config_store.cpp:23-61 and config_plugin.cpp:28-66
|
||||
// to eliminate the 74-line code duplication (W11.2 audit Finding 1).
|
||||
// Does NOT support: inline tables, arrays, multi-line strings, escape sequences.
|
||||
// 不支持:内联表、数组、多行字符串、转义序列。
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace dstalk {
|
||||
namespace toml {
|
||||
|
||||
/// Parse a TOML string, calling on_kv(full_key, value) for each key-value pair.
|
||||
/// Supports [section] headers, key = "value" pairs, # comments, blank lines.
|
||||
/// 解析 TOML 字符串,对每个键值对调用 on_kv(full_key, value) / Parse a TOML string, calling on_kv(full_key, value) for each key-value pair.
|
||||
/// 支持 [section] 标题、key = "value" 键值对、# 注释、空行 / Supports [section] headers, key = "value" pairs, # comments, blank lines.
|
||||
template<typename F>
|
||||
inline void parse(const std::string& content, F&& on_kv)
|
||||
{
|
||||
@@ -18,31 +26,31 @@ inline void parse(const std::string& content, F&& on_kv)
|
||||
size_t pos = 0;
|
||||
|
||||
while (pos < content.size()) {
|
||||
// Trim left whitespace
|
||||
// 去除左侧空白 / Trim left whitespace
|
||||
while (pos < content.size() && (content[pos] == ' ' || content[pos] == '\t'))
|
||||
pos++;
|
||||
if (pos >= content.size()) break;
|
||||
|
||||
// Extract next line
|
||||
// 提取下一行 / Extract next line
|
||||
size_t nl = content.find('\n', pos);
|
||||
std::string line = (nl != std::string::npos)
|
||||
? content.substr(pos, nl - pos) : content.substr(pos);
|
||||
pos = (nl != std::string::npos) ? nl + 1 : content.size();
|
||||
|
||||
// Trim right whitespace (including \r)
|
||||
// 去除右侧空白(包括 \r) / Trim right whitespace (including \r)
|
||||
while (!line.empty() && (line.back() == '\r' || line.back() == ' '))
|
||||
line.pop_back();
|
||||
|
||||
// Skip empty lines and comments
|
||||
// 跳过空行和注释 / Skip empty lines and comments
|
||||
if (line.empty() || line[0] == '#') continue;
|
||||
|
||||
// Section header: [section_name]
|
||||
// 节标题: [section_name] / Section header: [section_name]
|
||||
if (line[0] == '[' && line.back() == ']') {
|
||||
current_section = line.substr(1, line.size() - 2);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Key = value
|
||||
// 键 = 值 / Key = value
|
||||
size_t eq = line.find('=');
|
||||
if (eq == std::string::npos) continue;
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file config_plugin.cpp
|
||||
* @brief Config plugin: TOML file parsing and key-value configuration service.
|
||||
* 配置插件:TOML 文件解析和键值配置服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "../include/toml_parse.h"
|
||||
@@ -7,12 +14,12 @@
|
||||
#include <sstream>
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
//
|
||||
// W12.2: Eliminated private ConfigStore (was 90 lines duplicating core).
|
||||
// All get/set/load_file now delegate to the host store via g_host->config_get
|
||||
@@ -20,16 +27,19 @@ static const dstalk_host_api_t* g_host = nullptr;
|
||||
// TOML parsing uses the shared dstalk::toml::parse() from toml_parse.h.
|
||||
// ============================================================
|
||||
|
||||
// 从主机存储中按 key 获取配置值 / Retrieve a configuration value by key from the host store.
|
||||
static const char* config_get(const char* key) {
|
||||
if (!g_host) return nullptr;
|
||||
return g_host->config_get(key);
|
||||
}
|
||||
|
||||
// 将键值对存入主机存储 / Store a configuration key-value pair into the host store.
|
||||
static int config_set(const char* key, const char* value) {
|
||||
if (!g_host) return -1;
|
||||
return g_host->config_set(key, value);
|
||||
}
|
||||
|
||||
// 解析指定路径的 TOML 文件,将所有键值对加载到主机存储中 / Parse a TOML file at `path` and load all key-value pairs into the host store.
|
||||
static int config_load_file(const char* path) {
|
||||
if (!g_host || !path) return -1;
|
||||
|
||||
@@ -58,12 +68,13 @@ static dstalk_config_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// 插件初始化:保存主机指针并注册 config 服务 vtable / Plugin init: store host pointer and register the config service vtable.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
|
||||
// W12.2: This service is now a thin wrapper around host->config_get/set.
|
||||
// W12.2: 该服务现为 host->config_get/set 的薄封装,建议直接调用主机 API / This service is now a thin wrapper around host->config_get/set.
|
||||
// Direct host API calls are preferred.
|
||||
host->log(DSTALK_LOG_INFO,
|
||||
"plugin config service is deprecated, prefer host->config_get/set");
|
||||
@@ -76,8 +87,10 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
return (rc >= 0) ? 0 : -1;
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理本地存储(所有数据在主机存储中) / Plugin shutdown: no local store to clean up (all data lives in host store).
|
||||
// Plugin shutdown hook. Deliberately a no-op: this function body performs no work.
static void on_shutdown() {
|
||||
// W12.2: No local store to clean up — all data lives in host store.
|
||||
// Nothing to release locally; every value is kept in the host store.
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
@@ -91,6 +104,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
// Exported C entry point that hands the plugin descriptor to the caller.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
// g_info is a file-scope static, so the returned address stays valid after this call.
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-context: 上下文管理服务插件
|
||||
// 提供 dstalk_context_service_t vtable 实现
|
||||
// 依赖: session (获取历史消息做 token 计数)
|
||||
/*
|
||||
* @file context_plugin.cpp
|
||||
* @brief Context plugin: token counting and context window trimming.
|
||||
* 上下文插件:token 计数和上下文窗口裁剪。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-context: 上下文管理服务插件 / Context management service plugin
|
||||
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
|
||||
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -15,21 +22,26 @@
|
||||
#include <vector>
|
||||
|
||||
// ============================================================
|
||||
// 全局状态
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
static const dstalk_session_service_t* g_session = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// 内部 C++ 辅助:共享 UTF-8 token 计数
|
||||
// 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
|
||||
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
|
||||
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
|
||||
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过长编码检测 (F-11.1-6)
|
||||
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
|
||||
// ============================================================
|
||||
|
||||
// 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。
|
||||
// overhead: 每条消息的固定开销 token(role + separators = 4)
|
||||
// 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。
|
||||
// Count estimated tokens for UTF-8 byte sequence [text, text+len).
|
||||
// overhead: fixed token overhead per message (role + separators = 4).
|
||||
// Multi-byte sequences fall back to single-byte other_chars counting when out-of-bounds or invalid continuation bytes.
|
||||
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
if (!text || len == 0) return overhead;
|
||||
|
||||
@@ -42,12 +54,12 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||
|
||||
if (c < 0x80) {
|
||||
// ASCII
|
||||
// ASCII / ASCII
|
||||
ascii_chars++;
|
||||
i += 1;
|
||||
} else if (c >= 0xE4 && c <= 0xE9) {
|
||||
// CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
|
||||
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内
|
||||
// CJK 统一表意文字 (U+4E00-U+9FFF): 3 字节 UTF-8 0xE4-0xE9 / CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
|
||||
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 / Check if subsequent 2 bytes are in valid range
|
||||
if (i + 2 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
|
||||
@@ -58,8 +70,8 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 3;
|
||||
}
|
||||
} else if (c >= 0xC2 && c < 0xE0) {
|
||||
// 2-byte sequence (valid range 0xC2-0xDF)
|
||||
// W18.1 (F-11.1-4): 检查后续 1 字节
|
||||
// 2 字节序列 (有效范围 0xC2-0xDF) / 2-byte sequence (valid range 0xC2-0xDF)
|
||||
// W18.1 (F-11.1-4): 检查后续 1 字节 / Check subsequent 1 byte
|
||||
if (i + 1 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80) {
|
||||
other_chars++;
|
||||
@@ -69,13 +81,13 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 2;
|
||||
}
|
||||
} else if (c == 0xC0 || c == 0xC1) {
|
||||
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节
|
||||
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars
|
||||
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 / Overlong encoding, invalid UTF-8 start byte
|
||||
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars / 0xC0/0xC1 never appear in valid UTF-8; counted as single-byte in other_chars
|
||||
other_chars++;
|
||||
i += 1;
|
||||
} else if (c >= 0xE0 && c < 0xF0) {
|
||||
// Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
|
||||
// CJK 范围 0xE4-0xE9 已在上方分支处理
|
||||
// 非 CJK 3 字节序列 (0xE0-0xE3, 0xEA-0xEF) / Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
|
||||
// CJK 范围 0xE4-0xE9 已在上方分支处理 / CJK range 0xE4-0xE9 handled in branch above
|
||||
if (i + 2 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
|
||||
@@ -86,7 +98,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 3;
|
||||
}
|
||||
} else if (c >= 0xF0 && c < 0xF8) {
|
||||
// 4-byte sequence
|
||||
// 4 字节序列 / 4-byte sequence
|
||||
if (i + 3 >= len ||
|
||||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
|
||||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80 ||
|
||||
@@ -98,7 +110,7 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
i += 4;
|
||||
}
|
||||
} else {
|
||||
// Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
|
||||
// 续字节 (0x80-0xBF) 和其他无效起始字节 (0xF8-0xFF) / Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
|
||||
other_chars++;
|
||||
i += 1;
|
||||
}
|
||||
@@ -108,15 +120,17 @@ static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装)
|
||||
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl)
|
||||
// ============================================================
|
||||
|
||||
// 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct.
|
||||
static size_t count_tokens_one_message(const dstalk_message_t& msg) {
|
||||
const char* text = msg.content;
|
||||
if (!text) return 4; // 只有 overhead
|
||||
if (!text) return 4; // 只有 overhead / overhead only
|
||||
return count_tokens_utf8(text, std::strlen(text), 4);
|
||||
}
|
||||
|
||||
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
|
||||
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
|
||||
size_t total = 0;
|
||||
for (int i = 0; i < count; ++i) {
|
||||
@@ -126,10 +140,10 @@ static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 内部 trim 逻辑
|
||||
// 内部 trim 逻辑 / Internal trim logic
|
||||
// ============================================================
|
||||
|
||||
// 为 trim 操作将 C 消息数组复制到内部 struct
|
||||
// 为 trim 操作将 C 消息数组复制到内部 struct / Copy C message array to internal struct for trim operation
|
||||
struct TrimMessage {
|
||||
std::string role;
|
||||
std::string content;
|
||||
@@ -148,7 +162,7 @@ static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
|
||||
return total;
|
||||
}
|
||||
|
||||
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚)
|
||||
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
|
||||
static void free_msg_strs(dstalk_message_t* msg) {
|
||||
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
|
||||
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
|
||||
@@ -158,6 +172,8 @@ static void free_msg_strs(dstalk_message_t* msg) {
|
||||
|
||||
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
|
||||
// 成功返回 0;OOM 时释放当前消息已分配字段并返回 -1。
|
||||
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
|
||||
// On OOM, frees already-allocated fields and returns -1.
|
||||
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
|
||||
memset(dst, 0, sizeof(dstalk_message_t));
|
||||
|
||||
@@ -184,7 +200,10 @@ oom:
|
||||
return -1;
|
||||
}
|
||||
|
||||
// W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3)
|
||||
// W12.1 修复:trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
|
||||
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
|
||||
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
|
||||
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
|
||||
static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
dstalk_message_t** out, int* out_count,
|
||||
size_t max_tokens) {
|
||||
@@ -192,10 +211,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
if (!in || in_count <= 0 || !out || !out_count) return -1;
|
||||
|
||||
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens;
|
||||
// 传 0 时使用硬编码默认值 4096。
|
||||
// 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
|
||||
// when 0 is passed, use hardcoded default 4096.
|
||||
if (max_tokens == 0) max_tokens = 4096;
|
||||
|
||||
// 将 C 数组转换为内部 vector
|
||||
// 将 C 数组转换为内部 vector / Convert C array to internal vector
|
||||
std::vector<TrimMessage> messages;
|
||||
messages.reserve(in_count);
|
||||
for (int i = 0; i < in_count; ++i) {
|
||||
@@ -207,13 +227,13 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
messages.push_back(std::move(tm));
|
||||
}
|
||||
|
||||
// 如果已在限制内,直接返回完整副本
|
||||
// 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
|
||||
size_t current = count_tokens_trim_vec(messages);
|
||||
if (current <= max_tokens) {
|
||||
*out_count = in_count;
|
||||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
|
||||
if (!*out) return -1;
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
|
||||
for (int i = 0; i < in_count; ++i) {
|
||||
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
|
||||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||||
@@ -225,7 +245,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 分离 system 消息和非 system 消息
|
||||
// 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
|
||||
std::vector<TrimMessage> system_msgs;
|
||||
std::vector<TrimMessage> non_system_msgs;
|
||||
for (const auto& msg : messages) {
|
||||
@@ -243,7 +263,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
system_tokens, max_tokens);
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
for (const auto& msg : non_system_msgs) {
|
||||
size_t msg_tokens = count_tokens_trim(msg);
|
||||
if (msg_tokens > max_tokens) {
|
||||
@@ -257,19 +277,19 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
|
||||
while (!non_system_msgs.empty()) {
|
||||
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
|
||||
if (current <= max_tokens) break;
|
||||
|
||||
// 找第一个 "user" 消息
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
auto user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") {
|
||||
++user_it;
|
||||
}
|
||||
if (user_it == non_system_msgs.end()) break;
|
||||
|
||||
// 找下一个 "assistant"
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
auto assistant_it = user_it + 1;
|
||||
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
|
||||
++assistant_it;
|
||||
@@ -278,7 +298,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
if (assistant_it == non_system_msgs.end()) {
|
||||
non_system_msgs.erase(user_it);
|
||||
} else {
|
||||
// 先删 assistant 再删 user 避免迭代器失效
|
||||
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
|
||||
non_system_msgs.erase(assistant_it);
|
||||
user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
|
||||
@@ -286,7 +306,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens
|
||||
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token),使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
|
||||
{
|
||||
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
|
||||
if (max_msg_count < 1) max_msg_count = 1;
|
||||
@@ -295,7 +315,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
}
|
||||
|
||||
// 组装结果
|
||||
// 组装结果 / Assemble result
|
||||
std::vector<TrimMessage> result;
|
||||
result.reserve(system_msgs.size() + non_system_msgs.size());
|
||||
result.insert(result.end(), system_msgs.begin(), system_msgs.end());
|
||||
@@ -306,7 +326,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
|
||||
if (!*out) return -1;
|
||||
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息
|
||||
// W12.1: strdup 返回值逐一检查,OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
|
||||
for (int i = 0; i < result_count; ++i) {
|
||||
if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
|
||||
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
|
||||
@@ -318,7 +338,7 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate()
|
||||
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
@@ -328,10 +348,11 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Context 服务 vtable 实现
|
||||
// Context 服务 vtable 实现 / Context service vtable implementation
|
||||
// ============================================================
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate()
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate()
|
||||
// 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input.
|
||||
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
|
||||
try {
|
||||
if (!msgs || count <= 0) return 0;
|
||||
@@ -341,7 +362,8 @@ static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
|
||||
}
|
||||
}
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||||
// 裁剪消息列表以适应 max_tokens,返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array.
|
||||
static int context_trim(const dstalk_message_t* in, int in_count,
|
||||
dstalk_message_t** out, int* out_count,
|
||||
size_t max_tokens) {
|
||||
@@ -355,21 +377,24 @@ static int context_trim(const dstalk_message_t* in, int in_count,
|
||||
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。
|
||||
// max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入;
|
||||
// 传 0 时 trim_impl 使用硬编码默认值 4096。
|
||||
// g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly
|
||||
// by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0.
|
||||
// Context service function table: the token-counting entry point first,
// the trim entry point second.
static dstalk_context_service_t g_context_service = {
|
||||
context_count_tokens,
|
||||
context_trim
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界
|
||||
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
|
||||
// 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host = host;
|
||||
|
||||
// 查询依赖服务: session
|
||||
// 查询依赖服务: session / Query dependency service: session
|
||||
void* raw = host->query_service("session", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
|
||||
@@ -387,7 +412,8 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log
|
||||
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs
|
||||
// 插件关闭:清空指针。try/catch 保护 ABI(void 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function).
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
g_session = nullptr;
|
||||
@@ -406,7 +432,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"context",
|
||||
"1.0.0",
|
||||
"Context management plugin with token counting and trim support",
|
||||
"Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -414,6 +440,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// Mandatory plugin entry point: returns a pointer to the file-static g_info descriptor.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file deepseek_plugin.cpp
|
||||
* @brief DeepSeek/OpenAI-compatible AI provider plugin with SSE streaming and tool calls.
|
||||
* DeepSeek/OpenAI 兼容 AI 提供者插件,支持 SSE 流式输出和工具调用。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -11,14 +18,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态)
|
||||
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
|
||||
// ============================================================================
|
||||
// Host API table plus the http and config service pointers. All three are
// std::atomic and start out null.
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
static std::atomic<dstalk_config_service_t*> g_config{nullptr};
|
||||
|
||||
// ============================================================================
|
||||
// 配置数据(由 configure() 设置)
|
||||
// 配置数据(由 configure() 设置) / Config data (set by configure())
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
@@ -29,19 +36,21 @@ struct PluginConfig {
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
// W20.2: tools JSON cached by configure(), consumed by chat/chat_stream.
static std::string g_tools_json;
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
/// Overwrite `n` bytes starting at `p` with zeros.
///
/// Every store goes through a volatile-qualified pointer so the compiler
/// cannot drop the writes as dead stores.
///
/// @param p  Start of the buffer to wipe; a null pointer makes the call a no-op.
/// @param n  Number of bytes to clear.
static void secure_zero(void* p, size_t n) {
    if (!p) return;  // nothing to wipe; avoids dereferencing null when n > 0
    volatile unsigned char* cursor = static_cast<volatile unsigned char*>(p);
    for (size_t remaining = n; remaining != 0; --remaining) {
        *cursor++ = 0;
    }
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:从 base_url 提取 host 和 target
|
||||
// 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
@@ -65,8 +74,9 @@ static bool extract_host_port(const std::string& url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:构建 headers JSON 字符串
|
||||
// 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
|
||||
// ============================================================================
|
||||
// 构建包含 Bearer 授权令牌的 JSON headers 对象 / Build the JSON headers object containing the Bearer authorization token.
|
||||
static std::string build_headers_json(const std::string& auth_header_value)
|
||||
{
|
||||
json::object h;
|
||||
@@ -75,8 +85,9 @@ static std::string build_headers_json(const std::string& auth_header_value)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:dstalk_message_t[] -> boost::json::array
|
||||
// 辅助:dstalk_message_t[] -> boost::json::array / Helper: dstalk_message_t[] -> boost::json::array
|
||||
// ============================================================================
|
||||
// 将 dstalk_message_t 数组转换为 Boost.JSON 数组,用于 API 请求体 / Convert dstalk_message_t array into a Boost.JSON array for the API request body.
|
||||
static void append_history(json::array& msgs,
|
||||
const dstalk_message_t* history, int history_len)
|
||||
{
|
||||
@@ -100,8 +111,9 @@ static void append_history(json::array& msgs,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 DeepSeek JSON 请求体
|
||||
// 构建 DeepSeek JSON 请求体 / Build DeepSeek JSON request body
|
||||
// ============================================================================
|
||||
// 构建 DeepSeek/OpenAI chat completions API 的完整 JSON 请求体 / Build the full JSON request body for the DeepSeek/OpenAI chat completions API.
|
||||
static std::string build_request_json(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const std::string& user_input,
|
||||
@@ -117,7 +129,7 @@ static std::string build_request_json(
|
||||
json::array msgs;
|
||||
append_history(msgs, history, history_len);
|
||||
|
||||
// 追加当前用户输入
|
||||
// 追加当前用户输入 / Append current user input
|
||||
if (!user_input.empty()) {
|
||||
json::object obj;
|
||||
obj["role"] = "user";
|
||||
@@ -127,7 +139,7 @@ static std::string build_request_json(
|
||||
|
||||
root["messages"] = msgs;
|
||||
|
||||
// tools 定义
|
||||
// tools 定义 / tools definition
|
||||
if (!tools_json.empty()) {
|
||||
root["tools"] = json::parse(tools_json);
|
||||
}
|
||||
@@ -136,8 +148,9 @@ static std::string build_request_json(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 解析非流式 JSON 响应
|
||||
// 解析非流式 JSON 响应 / Parse non-streaming JSON response
|
||||
// ============================================================================
|
||||
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
|
||||
static void parse_response(const dstalk_host_api_t* host,
|
||||
const char* body, int http_status,
|
||||
dstalk_chat_result_t& r)
|
||||
@@ -207,13 +220,13 @@ static void parse_response(const dstalk_host_api_t* host,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 流式上下文:在 SSE 回调间累积内容和 tool_calls
|
||||
// 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
|
||||
// ============================================================================
|
||||
/// Accumulator for one streamed tool call, filled in across SSE delta chunks.
struct ToolCallAccum {
    int index = -1;         ///< Tool-call index; defaults to -1.
    std::string id;         ///< Call id.
    std::string name;       ///< Function name.
    std::string arguments;  ///< Incrementally concatenated JSON arguments string.
};
|
||||
|
||||
struct StreamContext {
|
||||
@@ -222,12 +235,18 @@ struct StreamContext {
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool streaming_ok = true;
|
||||
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls
|
||||
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls / accumulate delta tool_calls by index
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// SSE 行解析(OpenAI 兼容格式)
|
||||
// SSE 行解析(OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
|
||||
// ============================================================================
|
||||
// 解析单行 SSE "data:" 行。如果包含 content delta,将 token 写入 token_out。
|
||||
// 如果包含 tool_calls delta,累积到 ctx->tool_calls。
|
||||
// 如果产生了 content token 则返回 true,否则返回 false(tool_calls 或未知)。
|
||||
// Parse a single SSE "data:" line. If it contains a content delta, writes the token
|
||||
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
|
||||
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
|
||||
static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
{
|
||||
@@ -235,7 +254,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
|
||||
std::string data = line.substr(6);
|
||||
|
||||
// F-13.2-3: Trim leading/trailing whitespace before comparing [DONE] sentinel.
|
||||
// F-13.2-3: 比较 [DONE] 哨兵前去除首尾空白 / Trim leading/trailing whitespace before comparing [DONE] sentinel.
|
||||
const char* ws = " \t\r\n";
|
||||
size_t start = data.find_first_not_of(ws);
|
||||
if (start != std::string::npos) {
|
||||
@@ -244,7 +263,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
if (data == "[DONE]") {
|
||||
token_out.clear();
|
||||
return true; // 流结束信号
|
||||
return true; // 流结束信号 / stream end signal
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -254,12 +273,12 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
if (!choices.empty()) {
|
||||
auto delta = choices[0].as_object()["delta"].as_object();
|
||||
|
||||
// W20.2: 处理 delta["tool_calls"] 增量 chunk
|
||||
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输:
|
||||
// 事件 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
|
||||
// 事件 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
|
||||
// 事件 3: {"index":0, "function":{"arguments":"1}"}}
|
||||
// 需要按 index 累积 id/name/arguments。
|
||||
// W20.2: 处理 delta["tool_calls"] 增量 chunk / Handle delta["tool_calls"] incremental chunks
|
||||
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输 / DeepSeek/OpenAI streaming mode: tool_calls transmitted across multiple SSE event chunks:
|
||||
// 事件 1 / Event 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
|
||||
// 事件 2 / Event 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
|
||||
// 事件 3 / Event 3: {"index":0, "function":{"arguments":"1}"}}
|
||||
// 需要按 index 累积 id/name/arguments / Need to accumulate id/name/arguments by index.
|
||||
if (delta.contains("tool_calls") && ctx) {
|
||||
auto tc_array = delta["tool_calls"].as_array();
|
||||
for (auto& tc_val : tc_array) {
|
||||
@@ -288,7 +307,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
return false; // tool_calls 已处理,无内容 token 给用户回调
|
||||
return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
|
||||
}
|
||||
|
||||
if (delta.contains("content")) {
|
||||
@@ -297,14 +316,15 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
}
|
||||
} catch (...) {
|
||||
// 忽略解析失败
|
||||
// 忽略解析失败 / Ignore parse failures
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// configure 实现
|
||||
// configure 实现 / configure implementation
|
||||
// ============================================================================
|
||||
// 配置插件:provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
|
||||
static int my_configure(const char* provider, const char* base_url,
|
||||
const char* api_key, const char* model,
|
||||
int max_tokens, double temperature)
|
||||
@@ -319,7 +339,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) {
|
||||
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用
|
||||
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
host->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
@@ -348,8 +368,9 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat 实现
|
||||
// chat 实现 / chat implementation
|
||||
// ============================================================================
|
||||
// 非流式 chat completion:发送 history + user input,返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
|
||||
static dstalk_chat_result_t my_chat(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -412,29 +433,29 @@ static dstalk_chat_result_t my_chat(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// chat_stream 实现
|
||||
// chat_stream 实现 / chat_stream implementation
|
||||
// ============================================================================
|
||||
|
||||
// 行回调:解析 SSE line,将 token 传递给用户回调
|
||||
// 行回调:解析 SSE line,将 token 传递给用户回调 / SSE line callback: parses each line and forwards content tokens to the user callback.
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
std::string token;
|
||||
|
||||
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续
|
||||
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
|
||||
|
||||
if (token.empty()) return 0; // [DONE],停止
|
||||
if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
|
||||
|
||||
ctx->accumulated += token;
|
||||
|
||||
if (ctx->user_cb) {
|
||||
return ctx->user_cb(token.c_str(), ctx->userdata);
|
||||
}
|
||||
return 1; // 继续
|
||||
return 1; // 继续 / continue
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[deepseek] sse_line_callback exception: %s", e.what());
|
||||
@@ -446,6 +467,9 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
}
|
||||
}
|
||||
|
||||
// 流式 chat completion:以 stream=true 发送 history + user input,通过回调传递 token。
|
||||
// 在 SSE 分片中累积 tool_calls 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
|
||||
// via callback. Accumulates tool_calls across SSE chunks and serializes them at end.
|
||||
static dstalk_chat_result_t my_chat_stream(
|
||||
const dstalk_message_t* history, int history_len,
|
||||
const char* user_input,
|
||||
@@ -488,10 +512,10 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
r.http_status = status_code;
|
||||
|
||||
// 检查传输层错误或非 2xx 状态
|
||||
// 检查传输层错误或非 2xx 状态 / Check transport errors or non-2xx status
|
||||
if (status_code < 200 || status_code >= 300) {
|
||||
r.ok = 0;
|
||||
// 尝试从响应体提取错误信息
|
||||
// 尝试从响应体提取错误信息 / Try to extract error info from response body
|
||||
if (response_body && response_body[0]) {
|
||||
try {
|
||||
auto jv = json::parse(response_body);
|
||||
@@ -518,7 +542,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
if (response_body && host) host->free(response_body);
|
||||
|
||||
// W20.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling)
|
||||
// W20.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling) / Success = has content OR has tool_calls (tool-only responses like function calling)
|
||||
bool has_content = !ctx.accumulated.empty();
|
||||
bool has_tool_calls = !ctx.tool_calls.empty();
|
||||
|
||||
@@ -533,7 +557,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
r.content = has_content
|
||||
? host->strdup(ctx.accumulated.c_str()) : nullptr;
|
||||
|
||||
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式)
|
||||
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
@@ -572,8 +596,9 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// free_result 实现
|
||||
// free_result 实现 / free_result implementation
|
||||
// ============================================================================
|
||||
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -584,7 +609,7 @@ static void my_free_result(dstalk_chat_result_t* result)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
static dstalk_ai_service_t g_service = {
|
||||
&my_configure,
|
||||
@@ -594,8 +619,9 @@ static dstalk_ai_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期
|
||||
// 生命周期 / Lifecycle
|
||||
// ============================================================================
|
||||
// 插件初始化:查询 http 和 config 服务,注册 ai.deepseek 服务 / Plugin init: query http and config services, register ai.deepseek service.
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
@@ -624,6 +650,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:从内存安全擦除 API key,清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
@@ -644,12 +671,12 @@ static void on_shutdown()
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "deepseek-ai",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "DeepSeek AI provider (OpenAI-compatible API)",
|
||||
/* .description = */ "DeepSeek AI provider (OpenAI-compatible API) / DeepSeek AI 提供者 (OpenAI 兼容 API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "http", "config", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
@@ -657,6 +684,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
/*
|
||||
* @file file_io_plugin.cpp
|
||||
* @brief File I/O plugin: basic file read/write service.
|
||||
* 文件 I/O 插件:基本文件读写服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -6,20 +13,21 @@
|
||||
#include <cstring>
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
// ============================================================
|
||||
// 读取文件全部内容到主机分配的缓冲区,调用方须通过 host->free 释放 / Read the entire contents of a file into a host-allocated buffer. Caller must free via host->free.
|
||||
static int file_read(const char* path, char** content) {
|
||||
if (!path || !content) return -1;
|
||||
|
||||
FILE* fp = fopen(path, "rb");
|
||||
if (!fp) return -1;
|
||||
|
||||
// Get file size
|
||||
// 获取文件大小 / Get file size
|
||||
fseek(fp, 0, SEEK_END);
|
||||
long fsize = ftell(fp);
|
||||
fseek(fp, 0, SEEK_SET);
|
||||
@@ -29,7 +37,7 @@ static int file_read(const char* path, char** content) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Allocate buffer via host allocator (+1 for null terminator)
|
||||
// 通过主机分配器分配缓冲区(+1 用于空终止符) / Allocate buffer via host allocator (+1 for null terminator)
|
||||
char* buf = (char*)g_host->alloc((size_t)fsize + 1);
|
||||
if (!buf) {
|
||||
fclose(fp);
|
||||
@@ -49,6 +57,7 @@ static int file_read(const char* path, char** content) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 将字符串写入文件,覆盖已有内容 / Write a string to a file, overwriting any existing content.
|
||||
static int file_write(const char* path, const char* content) {
|
||||
if (!path || !content) return -1;
|
||||
|
||||
@@ -68,28 +77,31 @@ static dstalk_file_io_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// Plugin init: stores the host API pointer in g_host and registers this
// plugin's service table (g_service) under the name "file_io".
// Returns whatever host->register_service() returns.
// NOTE(review): `host` is dereferenced without a null check — presumably the
// loader never passes nullptr; confirm against the host implementation.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
return host->register_service("file_io", 1, &g_service);
|
||||
}
|
||||
|
||||
// Plugin shutdown hook for the file-io plugin. The body is empty.
|
||||
static void on_shutdown() {
|
||||
// Intentionally empty:
|
||||
// nothing to clean up
|
||||
}
|
||||
|
||||
// Plugin descriptor for the file-io plugin; the trailing comment on each
// initializer names the field it is meant to fill.
static dstalk_plugin_info_t g_info = {
|
||||
"file-io", // name
|
||||
"1.0.0", // version
|
||||
"Basic file I/O service", // description
|
||||
"file-io", // name
|
||||
"1.0.0", // version
|
||||
"Basic file I/O service", // description
|
||||
DSTALK_API_VERSION, // api_version
|
||||
{nullptr}, // dependencies (none)
|
||||
{nullptr}, // dependencies (none)
|
||||
on_init, // on_init
|
||||
on_shutdown, // on_shutdown
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// Mandatory entry point: returns the plugin descriptor to the host.
// Exported with C linkage; the returned pointer refers to the file-static
// g_info object.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
/*
|
||||
* plugin-lsp — LSP (Language Server Protocol) 服务
|
||||
*
|
||||
* 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信。
|
||||
* 无外部服务依赖(不依赖 http/config 等其他插件)。
|
||||
* @file lsp_plugin.cpp
|
||||
* @brief LSP plugin: Language Server Protocol JSON-RPC client for diagnostics, hover, completion.
|
||||
* LSP 插件:Language Server Protocol JSON-RPC 客户端,用于诊断、悬停、补全。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-lsp — LSP (Language Server Protocol) 服务 / LSP (Language Server Protocol) service
|
||||
//
|
||||
// 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信 / Self-manages language server subprocess, communicates via JSON-RPC 2.0 over stdio.
|
||||
// 无外部服务依赖(不依赖 http/config 等其他插件) / No external service dependencies (does not depend on http/config or other plugins).
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
@@ -22,7 +27,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
// ============================================================================
|
||||
// 平台相关 — 子进程管理 (内嵌 subprocess::Process)
|
||||
// 平台相关 — 子进程管理 (内嵌 subprocess::Process) / Platform specific — subprocess management (embedded subprocess::Process)
|
||||
// ============================================================================
|
||||
|
||||
#ifdef _WIN32
|
||||
@@ -45,12 +50,12 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针
|
||||
// 全局指针 / Global pointers
|
||||
// ============================================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
|
||||
// ============================================================================
|
||||
// 子进程封装 (内嵌 subprocess.hpp)
|
||||
// 子进程封装 (内嵌 subprocess.hpp) / Subprocess wrapper (embedded subprocess.hpp)
|
||||
// ============================================================================
|
||||
struct Process {
|
||||
#ifdef _WIN32
|
||||
@@ -64,6 +69,7 @@ struct Process {
|
||||
int stdout_fd = -1;
|
||||
#endif
|
||||
|
||||
// 从给定命令行启动子进程。为 stdin/stdout 设置管道 / Start a child process from the given command line. Sets up pipes for stdin/stdout.
|
||||
bool start(const char* cmd) {
|
||||
if (!cmd || !cmd[0]) return false;
|
||||
stop();
|
||||
@@ -169,6 +175,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 优雅终止子进程,回退到 SIGKILL/TerminateProcess / Gracefully terminate the child process, with fallback to SIGKILL/TerminateProcess.
|
||||
void stop() {
|
||||
#ifdef _WIN32
|
||||
if (hProcess != INVALID_HANDLE_VALUE) {
|
||||
@@ -198,6 +205,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 将数据字符串写入子进程 stdin 管道 / Write a data string to the child's stdin pipe.
|
||||
bool write(const std::string& data) {
|
||||
if (data.empty()) return true;
|
||||
#ifdef _WIN32
|
||||
@@ -219,6 +227,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 从子进程 stdout 管道读取一行(到并包括 '\n') / Read one line (up to and including '\n') from the child's stdout pipe.
|
||||
bool read_line(std::string& line) {
|
||||
line.clear();
|
||||
#ifdef _WIN32
|
||||
@@ -242,6 +251,7 @@ struct Process {
|
||||
#endif
|
||||
}
|
||||
|
||||
// 从子进程 stdout 管道读取恰好 count 字节到 buf / Read exactly `count` bytes from the child's stdout pipe into `buf`.
|
||||
bool read_bytes(std::string& buf, int count) {
|
||||
if (count <= 0) { buf.clear(); return true; }
|
||||
#ifdef _WIN32
|
||||
@@ -274,7 +284,7 @@ struct Process {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// LSP 状态(静态单例)
|
||||
// LSP 状态(静态单例) / LSP state (static singleton)
|
||||
// ============================================================================
|
||||
struct LspState {
|
||||
Process proc;
|
||||
@@ -283,23 +293,24 @@ struct LspState {
|
||||
|
||||
std::atomic<int> next_id{1};
|
||||
|
||||
// 响应用于同步等待
|
||||
// 响应用于同步等待 / Responses for synchronous waiting
|
||||
std::mutex mutex;
|
||||
std::condition_variable cv;
|
||||
std::unordered_map<int, std::string> pending_responses;
|
||||
|
||||
// 诊断缓存: URI -> JSON 字符串
|
||||
// 诊断缓存: URI -> JSON 字符串 / Diagnostics cache: URI -> JSON string
|
||||
std::unordered_map<std::string, std::string> diagnostics;
|
||||
|
||||
// 读取线程
|
||||
// 读取线程 / Reader thread
|
||||
std::thread reader_thread;
|
||||
};
|
||||
static LspState g_lsp;
|
||||
|
||||
// ============================================================================
|
||||
// 辅助函数
|
||||
// 辅助函数 / Helper functions
|
||||
// ============================================================================
|
||||
|
||||
// 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view.
|
||||
static std::string_view trim(std::string_view sv) {
|
||||
while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' ||
|
||||
sv.front() == '\r' || sv.front() == '\n'))
|
||||
@@ -310,6 +321,7 @@ static std::string_view trim(std::string_view sv) {
|
||||
return sv;
|
||||
}
|
||||
|
||||
// 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n).
|
||||
static std::string frame_message(const std::string& body) {
|
||||
std::string frame;
|
||||
frame.reserve(64 + body.size());
|
||||
@@ -320,6 +332,7 @@ static std::string frame_message(const std::string& body) {
|
||||
return frame;
|
||||
}
|
||||
|
||||
// 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure.
|
||||
static int parse_content_length(const std::string& line) {
|
||||
auto sv = trim(std::string_view(line));
|
||||
const char prefix[] = "Content-Length:";
|
||||
@@ -341,9 +354,10 @@ static int parse_content_length(const std::string& line) {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// JSON-RPC 消息发送
|
||||
// JSON-RPC 消息发送 / JSON-RPC message sending
|
||||
// ============================================================================
|
||||
|
||||
// 向 LSP 服务器发送 JSON-RPC 请求并返回分配的请求 id / Send a JSON-RPC request to the LSP server and return the assigned request id.
|
||||
static int send_request(const std::string& method, const json::object& params) {
|
||||
int id = g_lsp.next_id.fetch_add(1);
|
||||
|
||||
@@ -358,6 +372,7 @@ static int send_request(const std::string& method, const json::object& params) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 JSON-RPC 通知(无 id 字段,不期待响应) / Send a JSON-RPC notification to the LSP server (no id field, no response expected).
|
||||
static void send_notification(const std::string& method, const json::object& params) {
|
||||
json::object msg;
|
||||
msg["jsonrpc"] = "2.0";
|
||||
@@ -369,9 +384,12 @@ static void send_notification(const std::string& method, const json::object& par
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 消息处理
|
||||
// 消息处理 / Message handling
|
||||
// ============================================================================
|
||||
|
||||
// 分发接收到的 JSON-RPC 消息:将响应路由到待处理队列,
|
||||
// 处理 textDocument/publishDiagnostics 通知并存入诊断缓存 / Dispatch a received JSON-RPC message: route responses to pending queue,
|
||||
// handle textDocument/publishDiagnostics notifications into diagnostics cache.
|
||||
static void handle_message(const std::string& body) {
|
||||
try {
|
||||
json::value val;
|
||||
@@ -383,14 +401,14 @@ static void handle_message(const std::string& body) {
|
||||
catch (...) { return; }
|
||||
|
||||
if (msg.contains("id") && !msg.contains("method")) {
|
||||
// 响应 (有 id, 无 method)
|
||||
// 响应 (有 id, 无 method) / Response (has id, no method)
|
||||
int id = static_cast<int>(msg["id"].as_int64());
|
||||
std::lock_guard<std::mutex> lock(g_lsp.mutex);
|
||||
g_lsp.pending_responses[id] = body;
|
||||
g_lsp.cv.notify_all();
|
||||
|
||||
} else if (msg.contains("method") && !msg.contains("id")) {
|
||||
// 通知 (有 method, 无 id)
|
||||
// 通知 (有 method, 无 id) / Notification (has method, no id)
|
||||
std::string method;
|
||||
try { method = json::value_to<std::string>(msg["method"]); }
|
||||
catch (...) { return; }
|
||||
@@ -419,17 +437,18 @@ static void handle_message(const std::string& body) {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 读取线程主循环
|
||||
// 读取线程主循环 / Reader thread main loop
|
||||
// ============================================================================
|
||||
|
||||
// 读取线程主循环:解析 LSP header+body 帧并分发消息 / Main loop for the reader thread: parse LSP header+body frames and dispatch messages.
|
||||
static void reader_loop() {
|
||||
try {
|
||||
while (g_lsp.running) {
|
||||
int content_length = -1;
|
||||
bool pipe_ok = true;
|
||||
|
||||
// 状态机式读取 header 块:循环 read_line 直到读到空行
|
||||
// LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header
|
||||
// 状态机式读取 header 块:循环 read_line 直到读到空行 / State-machine header block read: loop read_line until empty line
|
||||
// LSP 3.17: header 块以空行(\r\n)结束,允许 Content-Type 等其他 header / LSP 3.17: header block ends with empty line (\r\n), allows other headers like Content-Type
|
||||
while (pipe_ok) {
|
||||
std::string line;
|
||||
if (!g_lsp.proc.read_line(line)) {
|
||||
@@ -437,18 +456,18 @@ static void reader_loop() {
|
||||
break;
|
||||
}
|
||||
|
||||
// header 块以空行结束
|
||||
// header 块以空行结束 / header block ends with empty line
|
||||
auto sv = trim(std::string_view(line));
|
||||
if (sv.empty()) break;
|
||||
|
||||
// 累积 Content-Length;遇到其他 header 不丢弃,继续读取下一行
|
||||
// 累积 Content-Length;遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line
|
||||
int len = parse_content_length(line);
|
||||
if (len >= 0) content_length = len;
|
||||
}
|
||||
|
||||
if (!pipe_ok) break;
|
||||
|
||||
// 空行前都没读到 Content-Length,协议错误——记日志并跳过这一帧
|
||||
// 空行前都没读到 Content-Length,协议错误——记日志并跳过这一帧 / Content-Length not read before empty line, protocol error — log and skip this frame
|
||||
if (content_length < 0) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] Invalid LSP frame: missing Content-Length header");
|
||||
continue;
|
||||
@@ -471,38 +490,39 @@ static void reader_loop() {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LSP 服务 vtable 实现 (定义在 vtable 变量之前)
|
||||
// LSP 服务 vtable 实现 (定义在 vtable 变量之前) / LSP service vtable implementation (defined before vtable variable)
|
||||
// ============================================================================
|
||||
|
||||
// Forward declarations of the stop helpers: g_lsp_impl_start() calls
// g_lsp_impl_stop() before any of these functions is defined.
static void g_lsp_impl_stop();
|
||||
static void g_lsp_impl_stop_nolock();
|
||||
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock);
|
||||
|
||||
// 启动 LSP 服务器进程,发送 initialize/initialized 握手,启动读取线程 / Start the LSP server process, send initialize/initialized handshake, start reader thread.
|
||||
static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
if (!server_cmd || !server_cmd[0]) return -1;
|
||||
|
||||
try {
|
||||
// 如果已在运行, 先停止
|
||||
// 如果已在运行, 先停止 / If already running, stop first
|
||||
if (g_lsp.running) {
|
||||
g_lsp_impl_stop();
|
||||
}
|
||||
|
||||
g_lsp.language = language ? language : "";
|
||||
|
||||
// 启动进程
|
||||
// 启动进程 / Start process
|
||||
if (!g_lsp.proc.start(server_cmd)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] failed to start: %s", server_cmd);
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 重置 ID 计数器
|
||||
// 重置 ID 计数器 / Reset ID counter
|
||||
g_lsp.next_id = 1;
|
||||
|
||||
// 启动读取线程
|
||||
// 启动读取线程 / Start reader thread
|
||||
g_lsp.running = true;
|
||||
g_lsp.reader_thread = std::thread(reader_loop);
|
||||
|
||||
// 构建 initialize 参数
|
||||
// 构建 initialize 参数 / Build initialize params
|
||||
json::object text_doc_caps;
|
||||
{
|
||||
json::object hover;
|
||||
@@ -526,10 +546,10 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
init_params["rootUri"] = nullptr;
|
||||
init_params["capabilities"] = capabilities;
|
||||
|
||||
// 发送 initialize 请求
|
||||
// 发送 initialize 请求 / Send initialize request
|
||||
int init_id = send_request("initialize", init_params);
|
||||
|
||||
// 等待 initialize 响应 (最多 10 秒)
|
||||
// 等待 initialize 响应 (最多 10 秒) / Wait for initialize response (max 10 seconds)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_lsp.mutex);
|
||||
bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [init_id]() {
|
||||
@@ -544,7 +564,7 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
g_lsp.pending_responses.erase(init_id);
|
||||
}
|
||||
|
||||
// 发送 initialized 通知
|
||||
// 发送 initialized 通知 / Send initialized notification
|
||||
send_notification("initialized", json::object{});
|
||||
|
||||
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server started: %s", server_cmd);
|
||||
@@ -558,14 +578,15 @@ static int g_lsp_impl_start(const char* server_cmd, const char* language) {
|
||||
}
|
||||
}
|
||||
|
||||
// 停止 LSP 服务器:发送 shutdown 请求,发送 exit 通知,停止进程和线程 / Stop the LSP server: send shutdown request, send exit notification, stop process & thread.
|
||||
static void g_lsp_impl_stop_nolock() {
|
||||
try {
|
||||
if (!g_lsp.running) return;
|
||||
|
||||
// 发送 shutdown 请求
|
||||
// 发送 shutdown 请求 / Send shutdown request
|
||||
int shutdown_id = send_request("shutdown", json::object{});
|
||||
|
||||
// 等待 shutdown 响应 (最多 2 秒)
|
||||
// 等待 shutdown 响应 (最多 2 秒) / Wait for shutdown response (max 2 seconds)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(g_lsp.mutex);
|
||||
g_lsp.cv.wait_for(lock, std::chrono::seconds(2), [shutdown_id]() {
|
||||
@@ -574,10 +595,10 @@ static void g_lsp_impl_stop_nolock() {
|
||||
g_lsp.pending_responses.clear();
|
||||
}
|
||||
|
||||
// 发送 exit 通知
|
||||
// 发送 exit 通知 / Send exit notification
|
||||
send_notification("exit", json::object{});
|
||||
|
||||
// 停止读取线程
|
||||
// 停止读取线程 / Stop reader thread
|
||||
g_lsp.running = false;
|
||||
g_lsp.proc.stop();
|
||||
|
||||
@@ -593,15 +614,18 @@ static void g_lsp_impl_stop_nolock() {
|
||||
}
|
||||
}
|
||||
|
||||
// Public stop entry point: acquires no lock itself and simply delegates to
// g_lsp_impl_stop_nolock().
|
||||
static void g_lsp_impl_stop() {
|
||||
g_lsp_impl_stop_nolock();
|
||||
}
|
||||
|
||||
// Stop helper for callers that already hold `lock`: releases the given
// unique_lock first, then calls g_lsp_impl_stop_nolock().
// Presumably `lock` guards g_lsp.mutex, which g_lsp_impl_stop_nolock() locks
// itself while waiting for the shutdown response — TODO confirm with callers.
// NOTE(review): the lock is not re-acquired, so it is left unlocked on return.
|
||||
static void g_lsp_impl_stop_locked(std::unique_lock<std::mutex>& lock) {
|
||||
lock.unlock();
|
||||
g_lsp_impl_stop_nolock();
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 textDocument/didOpen 通知 / Send a textDocument/didOpen notification to the LSP server.
|
||||
static int g_lsp_impl_open_document(const char* uri, const char* content,
|
||||
const char* lang_id) {
|
||||
if (!g_lsp.running) return -1;
|
||||
@@ -628,6 +652,7 @@ static int g_lsp_impl_open_document(const char* uri, const char* content,
|
||||
}
|
||||
}
|
||||
|
||||
// 向 LSP 服务器发送 textDocument/didClose 通知 / Send a textDocument/didClose notification to the LSP server.
|
||||
static int g_lsp_impl_close_document(const char* uri) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri) return -1;
|
||||
@@ -650,6 +675,7 @@ static int g_lsp_impl_close_document(const char* uri) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回给定文档 URI 的缓存诊断 JSON / Return the cached diagnostics JSON for the given document URI.
|
||||
static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -674,6 +700,7 @@ static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) {
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 textDocument/hover 请求并以 JSON 返回悬停结果 / Send a textDocument/hover request and return the hover result as JSON.
|
||||
static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -727,6 +754,7 @@ static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 textDocument/completion 请求并以 JSON 返回补全列表 / Send a textDocument/completion request and return the completion list as JSON.
|
||||
static int g_lsp_impl_get_completion(const char* uri, int line, int col, char** json_out) {
|
||||
if (!g_lsp.running) return -1;
|
||||
if (!uri || !json_out) return -1;
|
||||
@@ -781,7 +809,7 @@ static int g_lsp_impl_get_completion(const char* uri, int line, int col, char**
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 服务 vtable
|
||||
// 服务 vtable / Service vtable
|
||||
// ============================================================================
|
||||
|
||||
static dstalk_lsp_service_t g_service_vtable = {
|
||||
@@ -795,15 +823,17 @@ static dstalk_lsp_service_t g_service_vtable = {
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// 生命周期回调
|
||||
// 生命周期回调 / Lifecycle callbacks
|
||||
// ============================================================================
|
||||
|
||||
// Plugin init: stores the host API pointer in g_host, logs an informational
// message, and registers g_service_vtable under the name "lsp".
// Returns whatever host->register_service() returns.
// NOTE(review): the log call is guarded by a null check on g_host, but `host`
// is then dereferenced unconditionally — a null host would still crash here.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] initializing LSP service plugin");
|
||||
return host->register_service("lsp", 1, &g_service_vtable);
|
||||
}
|
||||
|
||||
// 插件关闭:如果运行中则停止 LSP 服务器,清空主机指针 / Plugin shutdown: stop LSP server if running, null out host pointer.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
if (g_lsp.running) {
|
||||
@@ -821,20 +851,21 @@ static void on_shutdown() {
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 插件描述符
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
|
||||
// Plugin descriptor for the lsp plugin; the leading /* .field = */ comment on
// each initializer names the field it is meant to fill.
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "lsp",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "Language Server Protocol client (subprocess manager)",
|
||||
/* .description = */ "Language Server Protocol client (subprocess manager) / Language Server Protocol 客户端(子进程管理器)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { NULL }, // No dependencies; the plugin manages its own subprocess
|
||||
/* .dependencies = */ { NULL }, // No dependencies; the plugin manages its own subprocess
|
||||
/* .on_init = */ on_init,
|
||||
/* .on_shutdown = */ on_shutdown,
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
// Mandatory entry point: returns the plugin descriptor to the host.
// Exported with C linkage; the returned pointer refers to the file-static
// g_info object.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,11 @@
|
||||
// MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
|
||||
/*
|
||||
* @file network_plugin.cpp
|
||||
* @brief Network plugin: HTTP/HTTPS POST and streaming via Boost.Beast + OpenSSL.
|
||||
* 网络插件:基于 Boost.Beast + OpenSSL 的 HTTP/HTTPS POST 和流式传输。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// MSVC 14.16 (VS 2017) 不提供 std::to_address (C++20) / MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
|
||||
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
@@ -29,21 +36,22 @@ namespace ssl = boost::asio::ssl;
|
||||
using tcp = asio::ip::tcp;
|
||||
|
||||
// ============================================================
|
||||
// Global state
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
static dstalk_config_service_t* g_config_svc = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// Minimal JSON header parser
|
||||
// Parses {"key1":"value1","key2":"value2"} into unordered_map
|
||||
// 极简 JSON 头解析器 / Minimal JSON header parser
|
||||
// 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map
|
||||
// ============================================================
|
||||
// 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map.
|
||||
static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
if (!json || !*json) return headers;
|
||||
|
||||
std::string s(json);
|
||||
// Very simple state-machine parser for flat string-key/value objects
|
||||
// 极简状态机解析器,处理扁平的字符串键值对象 / Very simple state-machine parser for flat string-key/value objects
|
||||
enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE;
|
||||
std::string current_key;
|
||||
std::string current_value;
|
||||
@@ -64,7 +72,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
|
||||
break;
|
||||
case IN_VALUE:
|
||||
if (c == '"') {
|
||||
// Read until closing quote
|
||||
// 读取到闭合引号 / Read until closing quote
|
||||
++i;
|
||||
while (i < s.size() && s[i] != '"') {
|
||||
if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; }
|
||||
@@ -81,7 +89,7 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// HTTP Client implementation (adapted from dstalk-core HttpClient)
|
||||
// HTTP 客户端实现(改编自 dstalk-core HttpClient) / HTTP Client implementation (adapted from dstalk-core HttpClient)
|
||||
// ============================================================
|
||||
struct HttpClientCtx {
|
||||
asio::io_context ioc;
|
||||
@@ -91,15 +99,22 @@ struct HttpClientCtx {
|
||||
|
||||
// Constructor: points ssl_ctx at the default CA locations and turns on
// peer certificate verification.
HttpClientCtx() {
|
||||
ssl_ctx.set_default_verify_paths();
|
||||
// Enable peer certificate verification (CVSS 7.4 fix).
|
||||
// set_default_verify_paths() loads the system CA bundle, but that
|
||||
// alone is not enough: without verify_peer the CA store is never
|
||||
// consulted, so any certificate (self-signed or expired) would be
|
||||
// accepted.
|
||||
//
|
||||
// TODO: Windows: set_default_verify_paths() may not locate system CAs;
|
||||
// if verification fails there, set the SSL_CERT_FILE environment
|
||||
// variable or
|
||||
// bundle a cacert.pem.
|
||||
ssl_ctx.set_verify_mode(ssl::verify_peer);
|
||||
}
|
||||
};
|
||||
|
||||
// 核心 HTTP/HTTPS POST,支持可选 SSE 流式传输。执行 DNS 解析、
|
||||
// TLS 握手(含 SNI 和主机名验证),然后发送请求。
|
||||
// 如果 cb 非空,响应体将逐行解析用于流式传输 / Core HTTP/HTTPS POST with optional SSE streaming. Performs DNS resolve,
|
||||
// TLS handshake with SNI and hostname verification, then sends the request.
|
||||
// If `cb` is non-null, response body is parsed line-by-line for streaming.
|
||||
static int do_post_stream(
|
||||
const char* host,
|
||||
const char* port,
|
||||
@@ -117,11 +132,11 @@ static int do_post_stream(
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Initialize output
|
||||
// 初始化输出 / Initialize output
|
||||
*response_body = nullptr;
|
||||
*status_code = -1;
|
||||
|
||||
// Build C++ lambda from C callback
|
||||
// 从 C 回调构建 C++ lambda / Build C++ lambda from C callback
|
||||
std::function<bool(const std::string&)> on_line;
|
||||
if (cb) {
|
||||
on_line = [cb, userdata](const std::string& line) -> bool {
|
||||
@@ -131,7 +146,7 @@ static int do_post_stream(
|
||||
|
||||
HttpClientCtx ctx;
|
||||
|
||||
// Read timeouts from config if available
|
||||
// 从配置读取超时设置 / Read timeouts from config if available
|
||||
if (g_config_svc) {
|
||||
const char* ct = g_config_svc->get("http.connect_timeout");
|
||||
const char* rt = g_config_svc->get("http.request_timeout");
|
||||
@@ -147,7 +162,9 @@ static int do_post_stream(
|
||||
try {
|
||||
tcp::resolver resolver(ctx.ioc);
|
||||
|
||||
// DNS resolve with 10-second timeout. Boost.Asio's synchronous
|
||||
// DNS 解析,10 秒超时。Boost.Asio 的同步 resolve()
|
||||
// 在内部运行 io_context,因此定时器的 async_wait 回调在 resolve() 期间执行,
|
||||
// 并在超时触发时调用 resolver.cancel() / DNS resolve with 10-second timeout. Boost.Asio's synchronous
|
||||
// resolve() runs the io_context internally, so the timer's async_wait
|
||||
// callback executes during resolve() and calls resolver.cancel() when
|
||||
// the deadline fires.
|
||||
@@ -172,7 +189,7 @@ static int do_post_stream(
|
||||
beast::ssl_stream<beast::tcp_stream> stream(ctx.ioc, ctx.ssl_ctx);
|
||||
beast::flat_buffer buffer;
|
||||
|
||||
// SNI hostname
|
||||
// SNI 主机名 / SNI hostname
|
||||
if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: SNI hostname set failed for %s", host);
|
||||
@@ -180,7 +197,9 @@ static int do_post_stream(
|
||||
goto done;
|
||||
}
|
||||
|
||||
// Hostname verification: require server certificate CN/SAN to match
|
||||
// 主机名验证:要求服务器证书 CN/SAN 匹配 'host'。
|
||||
// 与 ssl::verify_peer 协同工作——没有它的话,
|
||||
// 使用不同主机名的有效 CA 签名证书进行 MITM 攻击仍可通过 / Hostname verification: require server certificate CN/SAN to match
|
||||
// 'host'. This works in conjunction with ssl::verify_peer on the
|
||||
// context — without it MITM with a valid CA-signed cert for a
|
||||
// different hostname would still pass.
|
||||
@@ -191,19 +210,19 @@ static int do_post_stream(
|
||||
goto done;
|
||||
}
|
||||
|
||||
// Connect
|
||||
// 连接 / Connect
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
beast::get_lowest_layer(stream).connect(endpoints);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// SSL handshake
|
||||
// SSL 握手 / SSL handshake
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
stream.handshake(ssl::stream_base::client);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// Build HTTP POST request
|
||||
// 构建 HTTP POST 请求 / Build HTTP POST request
|
||||
http::request<http::string_body> req{http::verb::post, target, 11};
|
||||
req.set(http::field::host, host);
|
||||
req.set(http::field::user_agent, "dstalk/0.1");
|
||||
@@ -211,19 +230,19 @@ static int do_post_stream(
|
||||
req.body() = body;
|
||||
req.prepare_payload();
|
||||
|
||||
// Add extra headers from JSON
|
||||
// 从 JSON 添加额外的头 / Add extra headers from JSON
|
||||
auto extra_headers = parse_headers_json(headers_json);
|
||||
for (const auto& h : extra_headers) {
|
||||
req.set(h.first, h.second);
|
||||
}
|
||||
|
||||
// Send
|
||||
// 发送 / Send
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::write(stream, req);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// Read response
|
||||
// 读取响应 / Read response
|
||||
http::response_parser<http::string_body> parser;
|
||||
parser.body_limit(16 * 1024 * 1024);
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
@@ -310,8 +329,9 @@ done:
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Service implementations
|
||||
// 服务实现 / Service implementations
|
||||
// ============================================================
|
||||
// 同步 HTTP POST,返回完整响应体 / Synchronous HTTP POST returning the complete response body.
|
||||
static int http_post_json(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
@@ -322,6 +342,7 @@ static int http_post_json(
|
||||
nullptr, nullptr, response_body, status_code);
|
||||
}
|
||||
|
||||
// HTTP POST 带 SSE 流式传输:响应行到达时通过 cb 回调传递 / HTTP POST with SSE streaming: response lines are delivered to `cb` as they arrive.
|
||||
static int http_post_stream(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
@@ -339,32 +360,35 @@ static dstalk_http_service_t g_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// Plugin lifecycle
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// Plugin init: stores the host API pointer in g_host, looks up the "config"
// service, and registers g_service under the name "http".
// Returns whatever host->register_service() returns.
// NOTE(review): `host` is dereferenced without a null check — presumably the
// loader never passes nullptr; confirm against the host implementation.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
|
||||
// Query config service (declared dependency)
|
||||
// The result is stored as-is; do_post_stream() null-checks g_config_svc before use.
|
||||
g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1);
|
||||
|
||||
return host->register_service("http", 1, &g_service);
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
|
||||
static void on_shutdown() {
|
||||
// nothing to clean up
|
||||
// Intentionally empty: no teardown is performed here, so the pointers stored by on_init are left as they are.
|
||||
}
|
||||
|
||||
// Descriptor for the http plugin; dstalk_plugin_init() returns its address.
// Each field is listed exactly once, in the order the trailing comments name.
static dstalk_plugin_info_t g_info = {
    "http",                                                   // name
    "1.0.0",                                                  // version
    "HTTP/HTTPS client service using Boost.Beast + OpenSSL",  // description
    DSTALK_API_VERSION,                                       // api_version
    {"config", nullptr},                                      // dependencies
    on_init,                                                  // on_init
    on_shutdown,                                              // on_shutdown
    nullptr                                                   // on_event: no handler supplied
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
// Exported with C linkage; returns the address of the file-scope static g_info descriptor.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-session: 会话管理服务插件
|
||||
// 提供 dstalk_session_service_t vtable 实现
|
||||
// 依赖: file_io (save/load 需要文件操作)
|
||||
/*
|
||||
* @file session_plugin.cpp
|
||||
* @brief Session plugin: conversation message history management with save/load.
|
||||
* 会话插件:对话消息历史管理,支持保存/加载。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-session: 会话管理服务插件 / Session management service plugin
|
||||
// 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation
|
||||
// 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -24,14 +31,14 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 内部 C++ 数据结构
|
||||
// 内部 C++ 数据结构 / Internal C++ data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入用 release、读取用 acquire,读取无锁 / g_host and g_file_io are atomic pointers: stores use release ordering, loads use acquire ordering, so reads need no lock
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
// 内部消息结构(C++ 易用,外部暴露 C struct)
|
||||
// 内部消息结构(C++ 易用,外部暴露 C struct) / Internal message struct (C++ friendly, externally exposed as C struct)
|
||||
struct InternalMessage {
|
||||
std::string role;
|
||||
std::string content;
|
||||
@@ -39,21 +46,24 @@ struct InternalMessage {
|
||||
std::string tool_calls_json;
|
||||
};
|
||||
|
||||
// 会话历史 + 缓存 —— W14.3: mutex 保护读写
|
||||
// 会话历史 + 缓存 —— W14.3: mutex 保护读写 / Session history + cache — W14.3: mutex protects read/write
|
||||
static std::vector<InternalMessage> g_history;
|
||||
static std::vector<dstalk_message_t> g_cached_history;
|
||||
static std::mutex g_session_mutex;
|
||||
|
||||
// ============================================================
|
||||
// Token 计数工具(内联,避免硬依赖 context 头文件)
|
||||
// Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers)
|
||||
// ============================================================
|
||||
|
||||
// 如果字节是 ASCII (0x00–0x7F) 则返回 true / Returns true if the byte is ASCII (0x00–0x7F).
|
||||
// True when the high bit of `c` is clear, i.e. the byte is in 0x00..0x7F.
static bool is_ascii(unsigned char c) { return (c & 0x80u) == 0u; }
|
||||
|
||||
// 启发式判断:如果字节起始一个 UTF-8 CJK 统一表意文字 (0xE4–0xE9) 则返回 true / Heuristic: returns true if the byte starts a CJK Unified Ideograph in UTF-8 (0xE4–0xE9).
|
||||
static bool starts_cjk(unsigned char c) {
    // Heuristic: accept only bytes in the inclusive range 0xE4..0xE9, the
    // UTF-8 lead bytes used by the CJK Unified Ideographs block.
    const unsigned int lead = c;
    if (lead < 0xE4u) {
        return false;
    }
    return lead <= 0xE9u;
}
|
||||
|
||||
// 使用启发式 UTF-8 字节计数估算单条消息的 token 数 / Estimate token count for a single message using heuristic UTF-8 byte counting.
|
||||
static size_t count_tokens_one(const std::string& text) {
|
||||
size_t ascii_chars = 0;
|
||||
size_t chinese_chars = 0;
|
||||
@@ -85,9 +95,10 @@ static size_t count_tokens_one(const std::string& text) {
|
||||
}
|
||||
|
||||
size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3);
|
||||
return content_tokens + 4; // +4 per message overhead
|
||||
return content_tokens + 4; // +4 每条消息开销 / +4 per message overhead
|
||||
}
|
||||
|
||||
// 估算所有消息的总 token 数 / Estimate total token count across all messages.
|
||||
static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
|
||||
size_t total = 0;
|
||||
for (const auto& m : msgs) {
|
||||
@@ -97,13 +108,15 @@ static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex)
|
||||
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex) / Helper: rebuild C cached array (caller must hold g_session_mutex)
|
||||
// ============================================================
|
||||
|
||||
// 从内部消息 vector 重建 C 兼容的缓存历史数组。调用方必须持有 g_session_mutex / Rebuild the C-compatible cached history array from the internal message vector.
|
||||
// Caller must hold g_session_mutex.
|
||||
static void rebuild_cached_history_locked() {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
|
||||
// 释放旧的字符串
|
||||
// 释放旧的字符串 / Free old strings
|
||||
for (auto& m : g_cached_history) {
|
||||
if (m.role) { host->free(const_cast<char*>(m.role)); }
|
||||
if (m.content) { host->free(const_cast<char*>(m.content)); }
|
||||
@@ -112,7 +125,7 @@ static void rebuild_cached_history_locked() {
|
||||
}
|
||||
g_cached_history.clear();
|
||||
|
||||
// 重建
|
||||
// 重建 / Rebuild
|
||||
g_cached_history.reserve(g_history.size());
|
||||
for (const auto& im : g_history) {
|
||||
dstalk_message_t cm;
|
||||
@@ -125,9 +138,10 @@ static void rebuild_cached_history_locked() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Session 服务 vtable 实现 (W14.3: try/catch + mutex)
|
||||
// Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
// 向对话历史追加一条消息 / Append a message to the conversation history.
|
||||
static void session_add(const dstalk_message_t* msg) {
|
||||
try {
|
||||
if (!msg) return;
|
||||
@@ -148,11 +162,13 @@ static void session_add(const dstalk_message_t* msg) {
|
||||
}
|
||||
}
|
||||
|
||||
// 清空对话历史中的所有消息 / Clear all messages from the conversation history.
|
||||
static void session_clear() {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
g_history.clear();
|
||||
}
|
||||
|
||||
// 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`.
|
||||
static int session_save(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
@@ -187,6 +203,7 @@ static int session_save(const char* path) {
|
||||
}
|
||||
}
|
||||
|
||||
// 从 JSON 行文件中加载对话历史,替换当前历史 / Load conversation history from a JSON lines file at `path`, replacing current history.
|
||||
static int session_load(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
@@ -246,6 +263,7 @@ static int session_load(const char* path) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回指向缓存 C 消息数组的指针,并将 *out_count 设置为数组大小 / Return a pointer to the cached C-array of messages and set *out_count to its size.
|
||||
static const dstalk_message_t* session_history(int* out_count) {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
@@ -265,6 +283,7 @@ static const dstalk_message_t* session_history(int* out_count) {
|
||||
}
|
||||
}
|
||||
|
||||
// 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history.
|
||||
static int session_token_count() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
@@ -290,11 +309,12 @@ static dstalk_session_service_t g_session_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// W20.6: 默认会话保存路径(平台标准目录)
|
||||
// W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory)
|
||||
// ============================================================
|
||||
|
||||
// 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed.
|
||||
static std::string get_default_session_path() {
|
||||
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录
|
||||
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure
|
||||
static std::string cached_path = []() -> std::string {
|
||||
#ifdef _WIN32
|
||||
char* buf = nullptr;
|
||||
@@ -323,14 +343,17 @@ static std::string get_default_session_path() {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:保存主机指针,查询 file_io 依赖,注册 session 服务,
|
||||
// 并从默认路径自动加载已有会话 / Plugin init: store host pointer, query file_io dependency, register session service,
|
||||
// and auto-load any existing session from the default path.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found");
|
||||
@@ -338,11 +361,11 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 注册自身服务
|
||||
// 注册自身服务 / Register own service
|
||||
int ret = host->register_service("session", 1, &g_session_service);
|
||||
if (ret != 0) return ret;
|
||||
|
||||
// W20.6: 从默认路径恢复会话(文件不存在则静默失败)
|
||||
// W20.6: 从默认路径恢复会话(文件不存在则静默失败) / Restore session from default path (silent fail if file missing)
|
||||
session_load(get_default_session_path().c_str());
|
||||
|
||||
return 0;
|
||||
@@ -357,10 +380,13 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:自动保存会话到默认路径,失败时回退到当前目录,
|
||||
// 然后释放缓存历史和清空状态 / Plugin shutdown: auto-save session to default path, fallback to current dir on failure,
|
||||
// then release cached history and clear state.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
// W20.6: 清空前自动保存到默认路径
|
||||
// W21.4: 失败告警 + 当前目录 fallback
|
||||
// W20.6: 清空前自动保存到默认路径 / Auto-save to default path before clearing
|
||||
// W21.4: 失败告警 + 当前目录 fallback / Failure warning + current dir fallback
|
||||
int ret = session_save(get_default_session_path().c_str());
|
||||
if (ret != 0) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -389,7 +415,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"session",
|
||||
"1.0.0",
|
||||
"Session management plugin with save/load support",
|
||||
"Session management plugin with save/load support / 支持保存/加载的会话管理插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -397,6 +423,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
// Exported with C linkage; returns the address of the session plugin's file-scope static g_info descriptor.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
// plugin-tools: 工具注册服务插件
|
||||
// 提供 dstalk_tools_service_t vtable 实现
|
||||
// 依赖: file_io (内置 file_read / file_write 工具)
|
||||
/*
|
||||
* @file tools_plugin.cpp
|
||||
* @brief Tools plugin: tool registration, schema management, and execution registry.
|
||||
* 工具插件:工具注册、schema 管理和执行注册表。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-tools: 工具注册服务插件 / Tool registration service plugin
|
||||
// 提供 dstalk_tools_service_t vtable 实现 / Provides dstalk_tools_service_t vtable implementation
|
||||
// 依赖: file_io (内置 file_read / file_write 工具) / Depends on: file_io (built-in file_read / file_write tools)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
@@ -20,21 +27,22 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 路径安全校验 (W14.3: 防止路径遍历攻击)
|
||||
// 路径安全校验 (W14.3: 防止路径遍历攻击) / Path safety validation (W14.3: prevent path traversal attacks)
|
||||
// ============================================================
|
||||
|
||||
// 验证文件路径是否安全(无绝对路径、无 ".." 遍历、非空) / Validate that a file path is safe (non-empty, not absolute, and free of ".." traversal).
|
||||
static bool is_safe_path(const std::string& path) {
|
||||
// 拒绝空路径
|
||||
// 拒绝空路径 / Reject empty path
|
||||
if (path.empty()) return false;
|
||||
|
||||
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':')
|
||||
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':') / Reject absolute paths: Unix '/' prefix or Windows drive letter (second char ':')
|
||||
if (path[0] == '/' || path[0] == '\\') return false;
|
||||
if (path.size() >= 2 && path[1] == ':') return false;
|
||||
|
||||
// 拒绝含 ".." 段的目录遍历
|
||||
// 拒绝含 ".." 段的目录遍历 / Reject directory traversal with ".." segments
|
||||
if (path.find("..") != std::string::npos) return false;
|
||||
|
||||
// lexical_normal 消解相对组件后再次校验
|
||||
// 经 lexically_normal 消解相对组件后再次校验 / Re-validate after resolving relative components with lexically_normal
|
||||
std::string norm = std::filesystem::path(path).lexically_normal().string();
|
||||
if (norm.empty()) return false;
|
||||
if (norm[0] == '/' || norm[0] == '\\') return false;
|
||||
@@ -45,10 +53,10 @@ static bool is_safe_path(const std::string& path) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 内部数据结构
|
||||
// 内部数据结构 / Internal data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入用 release、读取用 acquire,读取无锁 / g_host and g_file_io are atomic pointers: stores use release ordering, loads use acquire ordering, so reads need no lock
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
@@ -59,14 +67,15 @@ struct ToolDef {
|
||||
dstalk_tool_handler_fn handler;
|
||||
};
|
||||
|
||||
// W14.3: g_tools 使用 mutex 保护读写
|
||||
// W14.3: g_tools 使用 mutex 保护读写 / g_tools uses mutex to protect read/write
|
||||
static std::vector<ToolDef> g_tools;
|
||||
static std::mutex g_tools_mutex;
|
||||
|
||||
// ============================================================
|
||||
// 内置工具: file_read, file_write
|
||||
// 内置工具: file_read, file_write / Built-in tools: file_read, file_write
|
||||
// ============================================================
|
||||
|
||||
// 内置工具处理器:读取文件并以 JSON 字符串返回内容 / Built-in tool handler: read a file and return its contents as a JSON string.
|
||||
static char* builtin_file_read(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
@@ -83,7 +92,7 @@ static char* builtin_file_read(const char* args_json) {
|
||||
}
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
|
||||
// W14.3: 路径遍历防护
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
@@ -110,6 +119,7 @@ static char* builtin_file_read(const char* args_json) {
|
||||
}
|
||||
}
|
||||
|
||||
// 内置工具处理器:将内容写入文件,返回成功/错误 JSON 对象 / Built-in tool handler: write content to a file, returning a success/error JSON object.
|
||||
static char* builtin_file_write(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
@@ -132,7 +142,7 @@ static char* builtin_file_write(const char* args_json) {
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
std::string content = json::value_to<std::string>(*content_j);
|
||||
|
||||
// W14.3: 路径遍历防护
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
@@ -155,18 +165,19 @@ static char* builtin_file_write(const char* args_json) {
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex)
|
||||
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex) / Tools service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
static void tools_unregister_tool(const char* name);
|
||||
|
||||
// 注册命名工具及其描述、JSON Schema 参数和处理函数 / Register a named tool with its description, JSON Schema parameters, and handler function.
|
||||
static int tools_register_tool(const char* name, const char* desc,
|
||||
const char* params_schema,
|
||||
dstalk_tool_handler_fn handler) {
|
||||
try {
|
||||
if (!name || !handler) return -1;
|
||||
|
||||
// 如果已存在同名工具,先注销
|
||||
// 如果已存在同名工具,先注销 / If a tool with the same name exists, unregister first
|
||||
tools_unregister_tool(name);
|
||||
|
||||
ToolDef td;
|
||||
@@ -189,6 +200,7 @@ static int tools_register_tool(const char* name, const char* desc,
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称注销之前注册的工具 / Unregister a previously registered tool by name.
|
||||
static void tools_unregister_tool(const char* name) {
|
||||
try {
|
||||
if (!name) return;
|
||||
@@ -207,6 +219,7 @@ static void tools_unregister_tool(const char* name) {
|
||||
}
|
||||
}
|
||||
|
||||
// 将所有已注册工具序列化为 OpenAI function-calling 格式的 JSON 数组 / Serialize all registered tools into a JSON array in OpenAI function-calling format.
|
||||
static char* tools_get_tools_json() {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -249,6 +262,7 @@ static char* tools_get_tools_json() {
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称查找工具并分派执行到注册的处理器 / Look up a tool by name and dispatch execution to its registered handler.
|
||||
static char* tools_execute(const char* name, const char* args_json) {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
@@ -298,14 +312,15 @@ static dstalk_tools_service_t g_tools_service = {
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:查询 file_io 依赖,注册内置文件工具,注册 tools 服务 / Plugin init: query file_io dependency, register built-in file tools, register tools service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found");
|
||||
@@ -313,7 +328,7 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 向自身注册内置工具
|
||||
// 向自身注册内置工具 / Register built-in tools with self
|
||||
tools_register_tool(
|
||||
"file_read",
|
||||
"Read the contents of a file at the given path",
|
||||
@@ -340,6 +355,7 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:清空所有已注册工具并清空服务指针 / Plugin shutdown: clear all registered tools and null out service pointers.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
@@ -358,7 +374,7 @@ static void on_shutdown() {
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"tools",
|
||||
"1.0.0",
|
||||
"Tool registration and execution plugin with built-in file tools",
|
||||
"Tool registration and execution plugin with built-in file tools / 内置文件工具的工具注册和执行插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
@@ -366,6 +382,7 @@ static dstalk_plugin_info_t g_info = {
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
// Exported with C linkage; returns the address of the tools plugin's file-scope static g_info descriptor.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user