feat: add OpenAI-compatible AI provider plugin with SSE streaming support
- Implemented the OpenAI-compatible AI provider plugin, including configuration, chat, and chat_stream functionalities. - Added support for SSE streaming and tool calls. - Integrated Boost.JSON for JSON handling. - Created CMake configuration for the plugin. - Added error handling and logging throughout the plugin.
This commit is contained in:
7
plugins_middle/CMakeLists.txt
Normal file
7
plugins_middle/CMakeLists.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
# ============================================================
|
||||
# 依赖基础插件的插件 / Plugins depending on base plugins only
|
||||
# ============================================================
|
||||
|
||||
add_subdirectory(network) # 依赖 config / depends on config
|
||||
add_subdirectory(session) # 依赖 file_io / depends on file_io
|
||||
add_subdirectory(tools) # 依赖 file_io / depends on file_io
|
||||
16
plugins_middle/network/CMakeLists.txt
Normal file
16
plugins_middle/network/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
find_package(OpenSSL REQUIRED CONFIG)
|
||||
|
||||
add_library(plugin-network SHARED src/network_plugin.cpp)
|
||||
|
||||
target_link_libraries(plugin-network PRIVATE
|
||||
dstalk
|
||||
boost::boost
|
||||
openssl::openssl
|
||||
)
|
||||
|
||||
set_target_properties(plugin-network PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
)
|
||||
394
plugins_middle/network/src/network_plugin.cpp
Normal file
394
plugins_middle/network/src/network_plugin.cpp
Normal file
@@ -0,0 +1,394 @@
|
||||
/*
|
||||
* @file network_plugin.cpp
|
||||
* @brief Network plugin: HTTP/HTTPS POST and streaming via Boost.Beast + OpenSSL.
|
||||
* 网络插件:基于 Boost.Beast + OpenSSL 的 HTTP/HTTPS POST 和流式传输。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// MSVC 14.16 (VS 2017) 不提供 std::to_address (C++20) / MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20)
|
||||
#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
#include <boost/asio/connect.hpp>
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/ssl.hpp>
|
||||
#include <boost/asio/steady_timer.hpp>
|
||||
#include <boost/beast/core.hpp>
|
||||
#include <boost/beast/http.hpp>
|
||||
#include <boost/beast/ssl.hpp>
|
||||
#include <boost/beast/version.hpp>
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace beast = boost::beast;
|
||||
namespace http = beast::http;
|
||||
namespace asio = boost::asio;
|
||||
namespace ssl = boost::asio::ssl;
|
||||
using tcp = asio::ip::tcp;
|
||||
|
||||
// ============================================================
|
||||
// 全局状态 / Global state
|
||||
// ============================================================
|
||||
static const dstalk_host_api_t* g_host = nullptr;
|
||||
static dstalk_config_service_t* g_config_svc = nullptr;
|
||||
|
||||
// ============================================================
|
||||
// 极简 JSON 头解析器 / Minimal JSON header parser
|
||||
// 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map
|
||||
// ============================================================
|
||||
// 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map.
|
||||
static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) {
|
||||
std::unordered_map<std::string, std::string> headers;
|
||||
if (!json || !*json) return headers;
|
||||
|
||||
std::string s(json);
|
||||
// 极简状态机解析器,处理扁平的字符串键值对象 / Very simple state-machine parser for flat string-key/value objects
|
||||
enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE;
|
||||
std::string current_key;
|
||||
std::string current_value;
|
||||
|
||||
for (size_t i = 0; i < s.size(); ++i) {
|
||||
char c = s[i];
|
||||
switch (state) {
|
||||
case OUTSIDE:
|
||||
if (c == '"') { state = IN_KEY; current_key.clear(); }
|
||||
break;
|
||||
case IN_KEY:
|
||||
if (c == '"') { state = AFTER_KEY; }
|
||||
else if (c == '\\' && i + 1 < s.size()) { current_key += s[++i]; }
|
||||
else { current_key += c; }
|
||||
break;
|
||||
case AFTER_KEY:
|
||||
if (c == ':') { state = IN_VALUE; current_value.clear(); }
|
||||
break;
|
||||
case IN_VALUE:
|
||||
if (c == '"') {
|
||||
// 读取到闭合引号 / Read until closing quote
|
||||
++i;
|
||||
while (i < s.size() && s[i] != '"') {
|
||||
if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; }
|
||||
else { current_value += s[i]; }
|
||||
++i;
|
||||
}
|
||||
headers[current_key] = current_value;
|
||||
state = OUTSIDE;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// HTTP 客户端实现(改编自 dstalk_core HttpClient) / HTTP Client implementation (adapted from dstalk_core HttpClient)
|
||||
// ============================================================
|
||||
struct HttpClientCtx {
|
||||
asio::io_context ioc;
|
||||
ssl::context ssl_ctx{ssl::context::tlsv12_client};
|
||||
int connect_timeout = 30;
|
||||
int request_timeout = 120;
|
||||
|
||||
HttpClientCtx() {
|
||||
ssl_ctx.set_default_verify_paths();
|
||||
// 启用对等证书验证 (CVSS 7.4 修复) / Enable peer certificate verification (CVSS 7.4 fix).
|
||||
// set_default_verify_paths() 加载系统 CA 包;没有 verify_peer
|
||||
// CA 存储不会被查询——任何证书(自签名/过期)都将被接受 / set_default_verify_paths() loads system CA bundle; without verify_peer
|
||||
// the CA store is never consulted — any cert (self-signed/expired) is accepted.
|
||||
// TODO: Windows: set_default_verify_paths() 可能无法定位系统 CA;
|
||||
// 如果验证失败,设置 SSL_CERT_FILE 环境变量或捆绑 cacert.pem / Windows: set_default_verify_paths() may not locate system CAs;
|
||||
// if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem.
|
||||
ssl_ctx.set_verify_mode(ssl::verify_peer);
|
||||
}
|
||||
};
|
||||
|
||||
// 核心 HTTP/HTTPS POST,支持可选 SSE 流式传输。执行 DNS 解析、
|
||||
// TLS 握手(含 SNI 和主机名验证),然后发送请求。
|
||||
// 如果 cb 非空,响应体将逐行解析用于流式传输 / Core HTTP/HTTPS POST with optional SSE streaming. Performs DNS resolve,
|
||||
// TLS handshake with SNI and hostname verification, then sends the request.
|
||||
// If `cb` is non-null, response body is parsed line-by-line for streaming.
|
||||
static int do_post_stream(
|
||||
const char* host,
|
||||
const char* port,
|
||||
const char* target,
|
||||
const char* body,
|
||||
const char* headers_json,
|
||||
dstalk_stream_cb cb,
|
||||
void* userdata,
|
||||
char** response_body,
|
||||
int* status_code)
|
||||
{
|
||||
if (!host || !port || !target || !body || !response_body || !status_code) {
|
||||
if (response_body) *response_body = nullptr;
|
||||
if (status_code) *status_code = -1;
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 初始化输出 / Initialize output
|
||||
*response_body = nullptr;
|
||||
*status_code = -1;
|
||||
|
||||
// 从 C 回调构建 C++ lambda / Build C++ lambda from C callback
|
||||
std::function<bool(const std::string&)> on_line;
|
||||
if (cb) {
|
||||
on_line = [cb, userdata](const std::string& line) -> bool {
|
||||
return cb(line.c_str(), userdata) == 0;
|
||||
};
|
||||
}
|
||||
|
||||
HttpClientCtx ctx;
|
||||
|
||||
// 从配置读取超时设置 / Read timeouts from config if available
|
||||
if (g_config_svc) {
|
||||
const char* ct = g_config_svc->get("http.connect_timeout");
|
||||
const char* rt = g_config_svc->get("http.request_timeout");
|
||||
if (ct) ctx.connect_timeout = std::atoi(ct);
|
||||
if (rt) ctx.request_timeout = std::atoi(rt);
|
||||
if (ctx.connect_timeout <= 0) ctx.connect_timeout = 30;
|
||||
if (ctx.request_timeout <= 0) ctx.request_timeout = 120;
|
||||
}
|
||||
|
||||
std::string result_body;
|
||||
int result_code = -1;
|
||||
|
||||
try {
|
||||
tcp::resolver resolver(ctx.ioc);
|
||||
|
||||
// DNS 解析,10 秒超时。Boost.Asio 的同步 resolve()
|
||||
// 在内部运行 io_context,因此定时器的 async_wait 回调在 resolve() 期间执行,
|
||||
// 并在超时触发时调用 resolver.cancel() / DNS resolve with 10-second timeout. Boost.Asio's synchronous
|
||||
// resolve() runs the io_context internally, so the timer's async_wait
|
||||
// callback executes during resolve() and calls resolver.cancel() when
|
||||
// the deadline fires.
|
||||
asio::steady_timer resolve_timer(ctx.ioc);
|
||||
resolve_timer.expires_after(std::chrono::seconds(10));
|
||||
resolve_timer.async_wait([&](const beast::error_code& ec) {
|
||||
if (!ec) resolver.cancel();
|
||||
});
|
||||
|
||||
beast::error_code resolve_ec;
|
||||
auto endpoints = resolver.resolve(host, port, resolve_ec);
|
||||
resolve_timer.cancel();
|
||||
|
||||
if (resolve_ec) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: DNS resolve %s:%s failed: %s",
|
||||
host, port, resolve_ec.message().c_str());
|
||||
result_body = std::string("DNS resolve failed: ") + resolve_ec.message();
|
||||
goto done;
|
||||
}
|
||||
|
||||
beast::ssl_stream<beast::tcp_stream> stream(ctx.ioc, ctx.ssl_ctx);
|
||||
beast::flat_buffer buffer;
|
||||
|
||||
// SNI 主机名 / SNI hostname
|
||||
if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: SNI hostname set failed for %s", host);
|
||||
result_body = "SNI hostname set failed";
|
||||
goto done;
|
||||
}
|
||||
|
||||
// 主机名验证:要求服务器证书 CN/SAN 匹配 'host'。
|
||||
// 与 ssl::verify_peer 协同工作——没有它的话,
|
||||
// 使用不同主机名的有效 CA 签名证书进行 MITM 攻击仍可通过 / Hostname verification: require server certificate CN/SAN to match
|
||||
// 'host'. This works in conjunction with ssl::verify_peer on the
|
||||
// context — without it MITM with a valid CA-signed cert for a
|
||||
// different hostname would still pass.
|
||||
if (!SSL_set1_host(stream.native_handle(), host)) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: SSL_set1_host failed for %s", host);
|
||||
result_body = "SSL_set1_host failed";
|
||||
goto done;
|
||||
}
|
||||
|
||||
// 连接 / Connect
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
beast::get_lowest_layer(stream).connect(endpoints);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// SSL 握手 / SSL handshake
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.connect_timeout));
|
||||
stream.handshake(ssl::stream_base::client);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// 构建 HTTP POST 请求 / Build HTTP POST request
|
||||
http::request<http::string_body> req{http::verb::post, target, 11};
|
||||
req.set(http::field::host, host);
|
||||
req.set(http::field::user_agent, "dstalk/0.1");
|
||||
req.set(http::field::content_type, "application/json");
|
||||
req.body() = body;
|
||||
req.prepare_payload();
|
||||
|
||||
// 从 JSON 添加额外的头 / Add extra headers from JSON
|
||||
auto extra_headers = parse_headers_json(headers_json);
|
||||
for (const auto& h : extra_headers) {
|
||||
req.set(h.first, h.second);
|
||||
}
|
||||
|
||||
// 发送 / Send
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::write(stream, req);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
// 读取响应 / Read response
|
||||
http::response_parser<http::string_body> parser;
|
||||
parser.body_limit(16 * 1024 * 1024);
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::read_header(stream, buffer, parser);
|
||||
beast::get_lowest_layer(stream).expires_never();
|
||||
|
||||
result_code = parser.get().result_int();
|
||||
|
||||
beast::error_code ec;
|
||||
|
||||
if (on_line) {
|
||||
std::string fragment = parser.get().body();
|
||||
auto emit_lines = [&]() -> bool {
|
||||
size_t pos = 0;
|
||||
while (pos < fragment.size()) {
|
||||
size_t nl = fragment.find('\n', pos);
|
||||
if (nl == std::string::npos) break;
|
||||
std::string line = fragment.substr(pos, nl - pos);
|
||||
if (!line.empty() && line.back() == '\r')
|
||||
line.pop_back();
|
||||
if (!on_line(line)) return false;
|
||||
pos = nl + 1;
|
||||
}
|
||||
if (pos > 0)
|
||||
fragment = fragment.substr(pos);
|
||||
return true;
|
||||
};
|
||||
if (!emit_lines()) goto done;
|
||||
|
||||
size_t processed = parser.get().body().size();
|
||||
while (!parser.is_done()) {
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::read_some(stream, buffer, parser, ec);
|
||||
if (ec) break;
|
||||
|
||||
const std::string& full_body = parser.get().body();
|
||||
if (full_body.size() > processed) {
|
||||
std::string_view new_data(full_body.data() + processed,
|
||||
full_body.size() - processed);
|
||||
processed = full_body.size();
|
||||
|
||||
fragment.append(new_data.data(), new_data.size());
|
||||
if (!emit_lines()) goto done;
|
||||
}
|
||||
}
|
||||
if (!fragment.empty()) {
|
||||
if (fragment.back() == '\r')
|
||||
fragment.pop_back();
|
||||
if (!fragment.empty())
|
||||
on_line(fragment);
|
||||
}
|
||||
} else {
|
||||
while (!parser.is_done()) {
|
||||
beast::get_lowest_layer(stream).expires_after(
|
||||
std::chrono::seconds(ctx.request_timeout));
|
||||
http::read_some(stream, buffer, parser, ec);
|
||||
if (ec) break;
|
||||
}
|
||||
}
|
||||
|
||||
result_body = parser.get().body();
|
||||
beast::get_lowest_layer(stream).cancel();
|
||||
stream.shutdown(ec);
|
||||
} catch (const std::exception& e) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: %s", e.what());
|
||||
result_code = -1;
|
||||
result_body = e.what();
|
||||
} catch (...) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR,
|
||||
"do_post_stream: unknown exception (non-std::exception)");
|
||||
result_code = -1;
|
||||
result_body = "unknown exception";
|
||||
}
|
||||
|
||||
done:
|
||||
*status_code = result_code;
|
||||
if (!result_body.empty()) {
|
||||
*response_body = g_host->strdup(result_body.c_str());
|
||||
}
|
||||
return (result_code >= 200 && result_code < 300) ? 0 : -1;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 服务实现 / Service implementations
|
||||
// ============================================================
|
||||
// 同步 HTTP POST,返回完整响应体 / Synchronous HTTP POST returning the complete response body.
|
||||
static int http_post_json(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
const char* headers_json,
|
||||
char** response_body, int* status_code)
|
||||
{
|
||||
return do_post_stream(host, port, target, body, headers_json,
|
||||
nullptr, nullptr, response_body, status_code);
|
||||
}
|
||||
|
||||
// HTTP POST 带 SSE 流式传输:响应行到达时通过 cb 回调传递 / HTTP POST with SSE streaming: response lines are delivered to `cb` as they arrive.
|
||||
static int http_post_stream(
|
||||
const char* host, const char* port,
|
||||
const char* target, const char* body,
|
||||
const char* headers_json,
|
||||
dstalk_stream_cb cb, void* userdata,
|
||||
char** response_body, int* status_code)
|
||||
{
|
||||
return do_post_stream(host, port, target, body, headers_json,
|
||||
cb, userdata, response_body, status_code);
|
||||
}
|
||||
|
||||
static dstalk_http_service_t g_service = {
|
||||
http_post_json,
|
||||
http_post_stream
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
// 插件初始化:保存主机指针,查询 config 服务,注册 http 服务 / Plugin init: store host pointer, query config service, register http service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
g_host = host;
|
||||
|
||||
// 查询 config 服务(声明的依赖) / Query config service (declared dependency)
|
||||
g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1);
|
||||
|
||||
return host->register_service("http", 1, &g_service);
|
||||
}
|
||||
|
||||
// 插件关闭:无需清理 / Plugin shutdown: nothing to clean up.
|
||||
static void on_shutdown() {
|
||||
// 无需清理 / nothing to clean up
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"http", // name 名称
|
||||
"1.0.0", // version 版本
|
||||
"HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description 描述
|
||||
DSTALK_API_VERSION, // api_version
|
||||
{"config", nullptr}, // dependencies 依赖
|
||||
on_init, // on_init
|
||||
on_shutdown, // on_shutdown
|
||||
nullptr // on_event
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
12
plugins_middle/session/CMakeLists.txt
Normal file
12
plugins_middle/session/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_library(plugin-session SHARED src/session_plugin.cpp)
|
||||
|
||||
target_link_libraries(plugin-session PRIVATE dstalk)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
target_link_libraries(plugin-session PRIVATE boost::boost dstalk_boost_config)
|
||||
|
||||
set_target_properties(plugin-session PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
)
|
||||
429
plugins_middle/session/src/session_plugin.cpp
Normal file
429
plugins_middle/session/src/session_plugin.cpp
Normal file
@@ -0,0 +1,429 @@
|
||||
/*
|
||||
* @file session_plugin.cpp
|
||||
* @brief Session plugin: conversation message history management with save/load.
|
||||
* 会话插件:对话消息历史管理,支持保存/加载。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-session: 会话管理服务插件 / Session management service plugin
|
||||
// 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation
|
||||
// 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 内部 C++ 数据结构 / Internal C++ data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
// 内部消息结构(C++ 易用,外部暴露 C struct) / Internal message struct (C++ friendly, externally exposed as C struct)
|
||||
struct InternalMessage {
|
||||
std::string role;
|
||||
std::string content;
|
||||
std::string tool_call_id;
|
||||
std::string tool_calls_json;
|
||||
};
|
||||
|
||||
// 会话历史 + 缓存 —— W14.3: mutex 保护读写 / Session history + cache — W14.3: mutex protects read/write
|
||||
static std::vector<InternalMessage> g_history;
|
||||
static std::vector<dstalk_message_t> g_cached_history;
|
||||
static std::mutex g_session_mutex;
|
||||
|
||||
// ============================================================
|
||||
// Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers)
|
||||
// ============================================================
|
||||
|
||||
// 如果字节是 ASCII (0x00–0x7F) 则返回 true / Returns true if the byte is ASCII (0x00–0x7F).
|
||||
static bool is_ascii(unsigned char c) { return c < 0x80; }
|
||||
|
||||
// 启发式判断:如果字节起始一个 UTF-8 CJK 统一表意文字 (0xE4–0xE9) 则返回 true / Heuristic: returns true if the byte starts a CJK Unified Ideograph in UTF-8 (0xE4–0xE9).
|
||||
static bool starts_cjk(unsigned char c) {
|
||||
return c >= 0xE4 && c <= 0xE9;
|
||||
}
|
||||
|
||||
// 使用启发式 UTF-8 字节计数估算单条消息的 token 数 / Estimate token count for a single message using heuristic UTF-8 byte counting.
|
||||
static size_t count_tokens_one(const std::string& text) {
|
||||
size_t ascii_chars = 0;
|
||||
size_t chinese_chars = 0;
|
||||
size_t other_chars = 0;
|
||||
|
||||
size_t i = 0;
|
||||
while (i < text.size()) {
|
||||
unsigned char c = static_cast<unsigned char>(text[i]);
|
||||
|
||||
if (is_ascii(c)) {
|
||||
ascii_chars++;
|
||||
i += 1;
|
||||
} else if (starts_cjk(c)) {
|
||||
chinese_chars++;
|
||||
i += 3;
|
||||
} else if (c >= 0xC0 && c < 0xE0) {
|
||||
other_chars++;
|
||||
i += 2;
|
||||
} else if (c >= 0xE0 && c < 0xF0) {
|
||||
other_chars++;
|
||||
i += 3;
|
||||
} else if (c >= 0xF0 && c < 0xF8) {
|
||||
other_chars++;
|
||||
i += 4;
|
||||
} else {
|
||||
other_chars++;
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3);
|
||||
return content_tokens + 4; // +4 每条消息开销 / +4 per message overhead
|
||||
}
|
||||
|
||||
// 估算所有消息的总 token 数 / Estimate total token count across all messages.
|
||||
static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
|
||||
size_t total = 0;
|
||||
for (const auto& m : msgs) {
|
||||
total += count_tokens_one(m.content);
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex) / Helper: rebuild C cached array (caller must hold g_session_mutex)
|
||||
// ============================================================
|
||||
|
||||
// 从内部消息 vector 重建 C 兼容的缓存历史数组。调用方必须持有 g_session_mutex / Rebuild the C-compatible cached history array from the internal message vector.
|
||||
// Caller must hold g_session_mutex.
|
||||
static void rebuild_cached_history_locked() {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
|
||||
// 释放旧的字符串 / Free old strings
|
||||
for (auto& m : g_cached_history) {
|
||||
if (m.role) { host->free(const_cast<char*>(m.role)); }
|
||||
if (m.content) { host->free(const_cast<char*>(m.content)); }
|
||||
if (m.tool_call_id) { host->free(const_cast<char*>(m.tool_call_id)); }
|
||||
if (m.tool_calls_json){ host->free(const_cast<char*>(m.tool_calls_json)); }
|
||||
}
|
||||
g_cached_history.clear();
|
||||
|
||||
// 重建 / Rebuild
|
||||
g_cached_history.reserve(g_history.size());
|
||||
for (const auto& im : g_history) {
|
||||
dstalk_message_t cm;
|
||||
cm.role = im.role.empty() ? nullptr : host->strdup(im.role.c_str());
|
||||
cm.content = im.content.empty() ? nullptr : host->strdup(im.content.c_str());
|
||||
cm.tool_call_id = im.tool_call_id.empty() ? nullptr : host->strdup(im.tool_call_id.c_str());
|
||||
cm.tool_calls_json = im.tool_calls_json.empty() ? nullptr : host->strdup(im.tool_calls_json.c_str());
|
||||
g_cached_history.push_back(cm);
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
// 向对话历史追加一条消息 / Append a message to the conversation history.
|
||||
static void session_add(const dstalk_message_t* msg) {
|
||||
try {
|
||||
if (!msg) return;
|
||||
InternalMessage im;
|
||||
if (msg->role) im.role = msg->role;
|
||||
if (msg->content) im.content = msg->content;
|
||||
if (msg->tool_call_id) im.tool_call_id = msg->tool_call_id;
|
||||
if (msg->tool_calls_json) im.tool_calls_json = msg->tool_calls_json;
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
g_history.push_back(std::move(im));
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_add: %s", e.what());
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_add: unknown exception");
|
||||
}
|
||||
}
|
||||
|
||||
// 清空对话历史中的所有消息 / Clear all messages from the conversation history.
|
||||
static void session_clear() {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
g_history.clear();
|
||||
}
|
||||
|
||||
// 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`.
|
||||
static int session_save(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
if (!fio) return -1;
|
||||
|
||||
std::string data;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
for (const auto& m : g_history) {
|
||||
json::object entry;
|
||||
entry["role"] = m.role;
|
||||
entry["content"] = m.content;
|
||||
if (!m.tool_call_id.empty())
|
||||
entry["tool_call_id"] = m.tool_call_id;
|
||||
if (!m.tool_calls_json.empty())
|
||||
entry["tool_calls_json"] = m.tool_calls_json;
|
||||
data += json::serialize(entry);
|
||||
data += '\n';
|
||||
}
|
||||
}
|
||||
return fio->write(path, data.c_str());
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_save: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_save: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 从 JSON 行文件中加载对话历史,替换当前历史 / Load conversation history from a JSON lines file at `path`, replacing current history.
|
||||
static int session_load(const char* path) {
|
||||
try {
|
||||
if (!path) return -1;
|
||||
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
if (!fio) return -1;
|
||||
|
||||
char* content = nullptr;
|
||||
int ret = fio->read(path, &content);
|
||||
if (ret != 0 || !content) return -1;
|
||||
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
std::string data(content);
|
||||
host->free(content);
|
||||
|
||||
std::vector<InternalMessage> parsed;
|
||||
size_t pos = 0;
|
||||
while (pos < data.size()) {
|
||||
size_t nl = data.find('\n', pos);
|
||||
std::string line = (nl != std::string::npos)
|
||||
? data.substr(pos, nl - pos) : data.substr(pos);
|
||||
pos = (nl != std::string::npos) ? nl + 1 : data.size();
|
||||
if (line.empty()) continue;
|
||||
|
||||
auto obj = json::parse(line).as_object();
|
||||
auto* role_j = obj.if_contains("role");
|
||||
auto* content_j = obj.if_contains("content");
|
||||
if (role_j && content_j && role_j->is_string() && content_j->is_string()) {
|
||||
InternalMessage im;
|
||||
im.role = json::value_to<std::string>(*role_j);
|
||||
im.content = json::value_to<std::string>(*content_j);
|
||||
auto* tci = obj.if_contains("tool_call_id");
|
||||
if (tci && tci->is_string())
|
||||
im.tool_call_id = json::value_to<std::string>(*tci);
|
||||
auto* tcj = obj.if_contains("tool_calls_json");
|
||||
if (tcj && tcj->is_string())
|
||||
im.tool_calls_json = json::value_to<std::string>(*tcj);
|
||||
parsed.push_back(std::move(im));
|
||||
}
|
||||
}
|
||||
|
||||
if (parsed.empty()) return -1;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
g_history = std::move(parsed);
|
||||
}
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_load: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_load: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 返回指向缓存 C 消息数组的指针,并将 *out_count 设置为数组大小 / Return a pointer to the cached C-array of messages and set *out_count to its size.
|
||||
static const dstalk_message_t* session_history(int* out_count) {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
rebuild_cached_history_locked();
|
||||
if (out_count) *out_count = static_cast<int>(g_cached_history.size());
|
||||
return g_cached_history.empty() ? nullptr : g_cached_history.data();
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_history: %s", e.what());
|
||||
if (out_count) *out_count = 0;
|
||||
return nullptr;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_history: unknown exception");
|
||||
if (out_count) *out_count = 0;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history.
|
||||
static int session_token_count() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
return static_cast<int>(count_tokens_all(g_history));
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static dstalk_session_service_t g_session_service = {
|
||||
session_add,
|
||||
session_clear,
|
||||
session_save,
|
||||
session_load,
|
||||
session_history,
|
||||
session_token_count
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory)
|
||||
// ============================================================
|
||||
|
||||
// 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed.
|
||||
static std::string get_default_session_path() {
|
||||
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure
|
||||
static std::string cached_path = []() -> std::string {
|
||||
#ifdef _WIN32
|
||||
char* buf = nullptr;
|
||||
size_t len = 0;
|
||||
_dupenv_s(&buf, &len, "APPDATA");
|
||||
std::string dir = buf ? std::string(buf) + "/dstalk" : "dstalk";
|
||||
free(buf);
|
||||
#else
|
||||
const char* home = std::getenv("HOME");
|
||||
std::string dir = home ? std::string(home) + "/.dstalk" : "/tmp/dstalk";
|
||||
#endif
|
||||
|
||||
std::error_code ec;
|
||||
std::filesystem::create_directories(dir, ec);
|
||||
if (ec) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_WARN,
|
||||
"get_default_session_path: cannot mkdir '%s' (%s), fallback to .",
|
||||
dir.c_str(), ec.message().c_str());
|
||||
return std::string("./session.json");
|
||||
}
|
||||
|
||||
return dir + "/session.json";
|
||||
}();
|
||||
return cached_path;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:保存主机指针,查询 file_io 依赖,注册 session 服务,
|
||||
// 并从默认路径自动加载已有会话 / Plugin init: store host pointer, query file_io dependency, register session service,
|
||||
// and auto-load any existing session from the default path.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found");
|
||||
return -1;
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 注册自身服务 / Register own service
|
||||
int ret = host->register_service("session", 1, &g_session_service);
|
||||
if (ret != 0) return ret;
|
||||
|
||||
// W20.6: 从默认路径恢复会话(文件不存在则静默失败) / Restore session from default path (silent fail if file missing)
|
||||
session_load(get_default_session_path().c_str());
|
||||
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:自动保存会话到默认路径,失败时回退到当前目录,
|
||||
// 然后释放缓存历史和清空状态 / Plugin shutdown: auto-save session to default path, fallback to current dir on failure,
|
||||
// then release cached history and clear state.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
// W20.6: 清空前自动保存到默认路径 / Auto-save to default path before clearing
|
||||
// W21.4: 失败告警 + 当前目录 fallback / Failure warning + current dir fallback
|
||||
int ret = session_save(get_default_session_path().c_str());
|
||||
if (ret != 0) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_WARN, "on_shutdown[session]: auto-save failed (ret=%d), trying fallback", ret);
|
||||
int fret = session_save("./dstalk_session_backup.json");
|
||||
if (fret != 0) {
|
||||
if (h) h->log(DSTALK_LOG_ERROR, "on_shutdown[session]: fallback also failed (ret=%d), data may be lost", fret);
|
||||
}
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_session_mutex);
|
||||
rebuild_cached_history_locked();
|
||||
g_cached_history.clear();
|
||||
g_history.clear();
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: %s", e.what());
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: unknown exception");
|
||||
}
|
||||
g_file_io.store(nullptr, std::memory_order_release);
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"session",
|
||||
"1.0.0",
|
||||
"Session management plugin with save/load support / 支持保存/加载的会话管理插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
on_shutdown,
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
12
plugins_middle/tools/CMakeLists.txt
Normal file
12
plugins_middle/tools/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
add_library(plugin-tools SHARED src/tools_plugin.cpp)
|
||||
|
||||
target_link_libraries(plugin-tools PRIVATE dstalk)
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
target_link_libraries(plugin-tools PRIVATE boost::boost dstalk_boost_config)
|
||||
|
||||
set_target_properties(plugin-tools PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
)
|
||||
388
plugins_middle/tools/src/tools_plugin.cpp
Normal file
388
plugins_middle/tools/src/tools_plugin.cpp
Normal file
@@ -0,0 +1,388 @@
|
||||
/*
|
||||
* @file tools_plugin.cpp
|
||||
* @brief Tools plugin: tool registration, schema management, and execution registry.
|
||||
* 工具插件:工具注册、schema 管理和执行注册表。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-tools: 工具注册服务插件 / Tool registration service plugin
|
||||
// 提供 dstalk_tools_service_t vtable 实现 / Provides dstalk_tools_service_t vtable implementation
|
||||
// 依赖: file_io (内置 file_read / file_write 工具) / Depends on: file_io (built-in file_read / file_write tools)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_types.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <exception>
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================
|
||||
// 路径安全校验 (W14.3: 防止路径遍历攻击) / Path safety validation (W14.3: prevent path traversal attacks)
|
||||
// ============================================================
|
||||
|
||||
// 验证文件路径是否安全(无绝对路径、无 ".." 遍历、非空) / Validate that a file path is safe (no absolute paths, no ".." traversal, no empty).
|
||||
static bool is_safe_path(const std::string& path) {
|
||||
// 拒绝空路径 / Reject empty path
|
||||
if (path.empty()) return false;
|
||||
|
||||
// 拒绝绝对路径: Unix '/' 开头 或 Windows 盘符 (第二字符 ':') / Reject absolute paths: Unix '/' prefix or Windows drive letter (second char ':')
|
||||
if (path[0] == '/' || path[0] == '\\') return false;
|
||||
if (path.size() >= 2 && path[1] == ':') return false;
|
||||
|
||||
// 拒绝含 ".." 段的目录遍历 / Reject directory traversal with ".." segments
|
||||
if (path.find("..") != std::string::npos) return false;
|
||||
|
||||
// lexical_normal 消解相对组件后再次校验 / Re-validate after resolving relative components with lexical_normal
|
||||
std::string norm = std::filesystem::path(path).lexically_normal().string();
|
||||
if (norm.empty()) return false;
|
||||
if (norm[0] == '/' || norm[0] == '\\') return false;
|
||||
if (norm.size() >= 2 && norm[1] == ':') return false;
|
||||
if (norm.find("..") != std::string::npos) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 内部数据结构 / Internal data structures
|
||||
// ============================================================
|
||||
|
||||
// W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
|
||||
|
||||
struct ToolDef {
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string parameters_schema;
|
||||
dstalk_tool_handler_fn handler;
|
||||
};
|
||||
|
||||
// W14.3: g_tools 使用 mutex 保护读写 / g_tools uses mutex to protect read/write
|
||||
static std::vector<ToolDef> g_tools;
|
||||
static std::mutex g_tools_mutex;
|
||||
|
||||
// ============================================================
|
||||
// 内置工具: file_read, file_write / Built-in tools: file_read, file_write
|
||||
// ============================================================
|
||||
|
||||
// 内置工具处理器:读取文件并以 JSON 字符串返回内容 / Built-in tool handler: read a file and return its contents as a JSON string.
|
||||
static char* builtin_file_read(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
|
||||
if (!fio) {
|
||||
return host ? host->strdup("{\"error\":\"file_io service not available\"}") : nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
auto args = json::parse(args_json).as_object();
|
||||
auto* path_j = args.if_contains("path");
|
||||
if (!path_j || !path_j->is_string()) {
|
||||
return host ? host->strdup("{\"error\":\"missing 'path' argument\"}") : nullptr;
|
||||
}
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
}
|
||||
|
||||
char* content = nullptr;
|
||||
int ret = fio->read(path.c_str(), &content);
|
||||
if (ret != 0 || !content) {
|
||||
return host ? host->strdup("{\"error\":\"failed to read file\"}") : nullptr;
|
||||
}
|
||||
|
||||
std::string escaped_content = json::serialize(json::string(content));
|
||||
if (host) host->free(content);
|
||||
|
||||
std::string result = "{\"content\":" + escaped_content + "}";
|
||||
return host ? host->strdup(result.c_str()) : nullptr;
|
||||
} catch (const std::exception& e) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: %s", e.what());
|
||||
std::string err = "{\"error\":\"file_read internal error\"}";
|
||||
return host ? host->strdup(err.c_str()) : nullptr;
|
||||
} catch (...) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_read: unknown exception");
|
||||
return host ? host->strdup("{\"error\":\"file_read internal error\"}") : nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// 内置工具处理器:将内容写入文件,返回成功/错误 JSON 对象 / Built-in tool handler: write content to a file, returning a success/error JSON object.
|
||||
static char* builtin_file_write(const char* args_json) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
|
||||
|
||||
if (!fio) {
|
||||
return host ? host->strdup("{\"error\":\"file_io service not available\"}") : nullptr;
|
||||
}
|
||||
|
||||
try {
|
||||
auto args = json::parse(args_json).as_object();
|
||||
auto* path_j = args.if_contains("path");
|
||||
auto* content_j = args.if_contains("content");
|
||||
if (!path_j || !path_j->is_string()) {
|
||||
return host ? host->strdup("{\"error\":\"missing 'path' argument\"}") : nullptr;
|
||||
}
|
||||
if (!content_j || !content_j->is_string()) {
|
||||
return host ? host->strdup("{\"error\":\"missing 'content' argument\"}") : nullptr;
|
||||
}
|
||||
|
||||
std::string path = json::value_to<std::string>(*path_j);
|
||||
std::string content = json::value_to<std::string>(*content_j);
|
||||
|
||||
// W14.3: 路径遍历防护 / Path traversal protection
|
||||
if (!is_safe_path(path)) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unsafe path rejected");
|
||||
return host ? host->strdup("{\"error\":\"access denied: unsafe path\"}") : nullptr;
|
||||
}
|
||||
|
||||
int ret = fio->write(path.c_str(), content.c_str());
|
||||
if (ret != 0) {
|
||||
return host ? host->strdup("{\"error\":\"failed to write file\"}") : nullptr;
|
||||
}
|
||||
|
||||
return host ? host->strdup("{\"success\":true}") : nullptr;
|
||||
} catch (const std::exception& e) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: %s", e.what());
|
||||
std::string err = "{\"error\":\"file_write internal error\"}";
|
||||
return host ? host->strdup(err.c_str()) : nullptr;
|
||||
} catch (...) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "builtin_file_write: unknown exception");
|
||||
return host ? host->strdup("{\"error\":\"file_write internal error\"}") : nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Tools 服务 vtable 实现 (W14.3: try/catch + mutex) / Tools service vtable implementation (W14.3: try/catch + mutex)
|
||||
// ============================================================
|
||||
|
||||
static void tools_unregister_tool(const char* name);
|
||||
|
||||
// 注册命名工具及其描述、JSON Schema 参数和处理函数 / Register a named tool with its description, JSON Schema parameters, and handler function.
|
||||
static int tools_register_tool(const char* name, const char* desc,
|
||||
const char* params_schema,
|
||||
dstalk_tool_handler_fn handler) {
|
||||
try {
|
||||
if (!name || !handler) return -1;
|
||||
|
||||
// 如果已存在同名工具,先注销 / If a tool with the same name exists, unregister first
|
||||
tools_unregister_tool(name);
|
||||
|
||||
ToolDef td;
|
||||
td.name = name;
|
||||
td.description = desc ? desc : "";
|
||||
td.parameters_schema = params_schema ? params_schema : "";
|
||||
td.handler = handler;
|
||||
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
g_tools.push_back(std::move(td));
|
||||
return 0;
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_register_tool: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_register_tool: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称注销之前注册的工具 / Unregister a previously registered tool by name.
|
||||
static void tools_unregister_tool(const char* name) {
|
||||
try {
|
||||
if (!name) return;
|
||||
std::string n(name);
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
g_tools.erase(
|
||||
std::remove_if(g_tools.begin(), g_tools.end(),
|
||||
[&n](const ToolDef& t) { return t.name == n; }),
|
||||
g_tools.end());
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_unregister_tool: %s", e.what());
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_unregister_tool: unknown exception");
|
||||
}
|
||||
}
|
||||
|
||||
// 将所有已注册工具序列化为 OpenAI function-calling 格式的 JSON 数组 / Serialize all registered tools into a JSON array in OpenAI function-calling format.
|
||||
static char* tools_get_tools_json() {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
json::array tools_arr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
for (const auto& t : g_tools) {
|
||||
json::object tool_obj;
|
||||
tool_obj["type"] = "function";
|
||||
|
||||
json::object func_obj;
|
||||
func_obj["name"] = t.name;
|
||||
func_obj["description"] = t.description;
|
||||
|
||||
if (!t.parameters_schema.empty()) {
|
||||
func_obj["parameters"] = json::parse(t.parameters_schema);
|
||||
} else {
|
||||
json::object empty_params;
|
||||
empty_params["type"] = "object";
|
||||
empty_params["properties"] = json::object{};
|
||||
func_obj["parameters"] = empty_params;
|
||||
}
|
||||
|
||||
tool_obj["function"] = func_obj;
|
||||
tools_arr.push_back(tool_obj);
|
||||
}
|
||||
}
|
||||
|
||||
std::string result = json::serialize(tools_arr);
|
||||
return host ? host->strdup(result.c_str()) : nullptr;
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_get_tools_json: %s", e.what());
|
||||
return nullptr;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_get_tools_json: unknown exception");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// 按名称查找工具并分派执行到注册的处理器 / Look up a tool by name and dispatch execution to its registered handler.
|
||||
static char* tools_execute(const char* name, const char* args_json) {
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!name) {
|
||||
return host ? host->strdup("{\"error\":\"tool name is null\"}") : nullptr;
|
||||
}
|
||||
|
||||
std::string n(name);
|
||||
ToolDef* found = nullptr;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
for (auto& t : g_tools) {
|
||||
if (t.name == n) {
|
||||
found = &t;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!found) {
|
||||
json::object err_obj;
|
||||
err_obj["error"] = "unknown tool: " + n;
|
||||
return host ? host->strdup(json::serialize(err_obj).c_str()) : nullptr;
|
||||
}
|
||||
|
||||
const char* args = args_json ? args_json : "{}";
|
||||
return found->handler(args);
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_execute: %s", e.what());
|
||||
json::object err_obj;
|
||||
err_obj["error"] = "tool execution internal error";
|
||||
return host ? host->strdup(json::serialize(err_obj).c_str()) : nullptr;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "tools_execute: unknown exception");
|
||||
return host ? host->strdup("{\"error\":\"tool execution internal error\"}") : nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
static dstalk_tools_service_t g_tools_service = {
|
||||
tools_register_tool,
|
||||
tools_unregister_tool,
|
||||
tools_get_tools_json,
|
||||
tools_execute
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// 插件生命周期 / Plugin lifecycle
|
||||
// ============================================================
|
||||
|
||||
// 插件初始化:查询 file_io 依赖,注册内置文件工具,注册 tools 服务 / Plugin init: query file_io dependency, register built-in file tools, register tools service.
|
||||
static int on_init(const dstalk_host_api_t* host) {
|
||||
try {
|
||||
g_host.store(host, std::memory_order_release);
|
||||
|
||||
// 查询依赖服务: file_io / Query dependency service: file_io
|
||||
void* raw = host->query_service("file_io", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found");
|
||||
return -1;
|
||||
}
|
||||
g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);
|
||||
|
||||
// 向自身注册内置工具 / Register built-in tools with self
|
||||
tools_register_tool(
|
||||
"file_read",
|
||||
"Read the contents of a file at the given path",
|
||||
"{\"type\":\"object\",\"properties\":{\"path\":{\"type\":\"string\",\"description\":\"Path to the file to read\"}},\"required\":[\"path\"]}",
|
||||
builtin_file_read
|
||||
);
|
||||
|
||||
tools_register_tool(
|
||||
"file_write",
|
||||
"Write content to a file at the given path",
|
||||
"{\"type\":\"object\",\"properties\":{\"path\":{\"type\":\"string\",\"description\":\"Path to the file to write\"},\"content\":{\"type\":\"string\",\"description\":\"Content to write to the file\"}},\"required\":[\"path\",\"content\"]}",
|
||||
builtin_file_write
|
||||
);
|
||||
|
||||
return host->register_service("tools", 1, &g_tools_service);
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_ERROR, "on_init[tools]: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_ERROR, "on_init[tools]: unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// 插件关闭:清空所有已注册工具并清空服务指针 / Plugin shutdown: clear all registered tools and null out service pointers.
|
||||
static void on_shutdown() {
|
||||
try {
|
||||
std::lock_guard<std::mutex> lock(g_tools_mutex);
|
||||
g_tools.clear();
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[tools]: %s", e.what());
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[tools]: unknown exception");
|
||||
}
|
||||
g_file_io.store(nullptr, std::memory_order_release);
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
"tools",
|
||||
"1.0.0",
|
||||
"Tool registration and execution plugin with built-in file tools / 内置文件工具的工具注册和执行插件",
|
||||
DSTALK_API_VERSION,
|
||||
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
|
||||
on_init,
|
||||
on_shutdown,
|
||||
nullptr
|
||||
};
|
||||
|
||||
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
|
||||
return &g_info;
|
||||
}
|
||||
Reference in New Issue
Block a user