feat: add OpenAI-compatible AI provider plugin with SSE streaming support

- Implemented the OpenAI-compatible AI provider plugin, including configuration, chat, and chat_stream functionalities.
- Added support for SSE streaming and tool calls.
- Integrated Boost.JSON for JSON handling.
- Created CMake configuration for the plugin.
- Added error handling and logging throughout the plugin.
This commit is contained in:
2026-05-31 05:37:04 +08:00
parent f6cb51b40a
commit ba7382db2a
61 changed files with 163 additions and 147 deletions

View File

@@ -0,0 +1,7 @@
# ============================================================
# 依赖其他插件的插件 / Plugins depending on non-base plugins
# ============================================================
add_subdirectory(context) # 依赖 session / depends on session
add_subdirectory(openai) # 依赖 http, config / depends on http, config
add_subdirectory(anthropic) # 依赖 http, config / depends on http, config

View File

@@ -0,0 +1,22 @@
cmake_minimum_required(VERSION 3.21)
# ============================================================
# plugin-anthropic — Anthropic Claude AI 服务
# 依赖: http 服务 (查询), config 服务 (查询)
# ============================================================
add_library(plugin-anthropic SHARED
src/anthropic_plugin.cpp
)
target_link_libraries(plugin-anthropic PRIVATE dstalk)
# Boost.JSON 用于构建/解析请求和响应
find_package(Boost REQUIRED CONFIG)
target_link_libraries(plugin-anthropic PRIVATE boost::boost dstalk_boost_config)
set_target_properties(plugin-anthropic PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
)

View File

@@ -0,0 +1,736 @@
/*
* @file anthropic_plugin.cpp
* @brief Anthropic Claude Messages API provider plugin with streaming support.
* Anthropic Claude Messages API 提供者插件,支持流式输出。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"
#include <boost/json.hpp>
#include <boost/json/src.hpp>
#include <atomic>
#include <cstring>
#include <string>
#include <vector>
namespace json = boost::json;
// ============================================================================
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 / Global pointers — W17.4: std::atomic protects concurrent read/write between on_shutdown and service functions
// ============================================================================
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
static dstalk_config_service_t* g_config = nullptr;
// ============================================================================
// 配置数据 / Config data
// ============================================================================
struct PluginConfig {
std::string provider;
std::string base_url;
std::string api_key;
std::string model;
int max_tokens = 4096;
double temperature = 0.7;
};
static PluginConfig g_cfg;
static std::string g_tools_json; // W21.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
// ============================================================================
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
// ============================================================================
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
static void secure_zero(void* p, size_t n) {
volatile char* vp = (volatile char*)p;
while (n--) *vp++ = 0;
}
// ============================================================================
// 辅助:提取 host / target / Helper: extract host / target
// ============================================================================
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
static bool extract_host_port(const std::string& url,
std::string& scheme_out, std::string& host_out,
std::string& port_out, std::string& target_out)
{
size_t scheme_end = url.find("://");
if (scheme_end == std::string::npos) return false;
scheme_out = url.substr(0, scheme_end);
std::string rest = url.substr(scheme_end + 3);
size_t slash = rest.find('/');
std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest;
target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";
size_t colon = authority.rfind(':');
if (colon != std::string::npos) {
host_out = authority.substr(0, colon);
port_out = authority.substr(colon + 1);
} else {
host_out = authority;
port_out = (scheme_out == "https") ? "443" : "80";
}
return true;
}
// ============================================================================
// 构建 Anthropic headers JSON / Build Anthropic headers JSON
// ============================================================================
// 构建包含 x-api-key 和 anthropic-version 的 JSON headers 对象 / Build the JSON headers object containing x-api-key and anthropic-version.
static std::string build_headers_json()
{
json::object h;
h["x-api-key"] = g_cfg.api_key;
h["anthropic-version"] = "2023-06-01";
return json::serialize(h);
}
// ============================================================================
// 构建 Anthropic JSON 请求体 / Build Anthropic JSON request body
// ============================================================================
// 构建 Anthropic Messages API 的完整 JSON 请求体。
// 按 Anthropic 规范将 system 消息提取为顶层 system 字段 / Build the full JSON request body for the Anthropic Messages API.
// Extracts system messages as a top-level "system" field per Anthropic spec.
static std::string build_request_json(
const dstalk_message_t* history, int history_len,
const std::string& user_input,
const std::string& tools_json,
bool stream)
{
json::object root;
root["model"] = g_cfg.model;
root["max_tokens"] = g_cfg.max_tokens;
root["stream"] = stream;
// 提取 system 消息作为顶层字段 / Extract system messages as top-level field
std::string system_prompt;
json::array msgs;
for (int i = 0; i < history_len; ++i) {
const auto& m = history[i];
if (m.role && std::strcmp(m.role, "system") == 0) {
if (!system_prompt.empty()) system_prompt += "\n\n";
system_prompt += m.content ? m.content : "";
continue;
}
json::object obj;
obj["role"] = m.role ? m.role : "";
obj["content"] = m.content ? m.content : "";
msgs.push_back(obj);
}
// 追加当前用户输入 / Append current user input
{
json::object obj;
obj["role"] = "user";
obj["content"] = user_input;
msgs.push_back(obj);
}
root["messages"] = msgs;
if (!system_prompt.empty()) {
root["system"] = system_prompt;
}
if (g_cfg.temperature >= 0.0 && g_cfg.temperature <= 1.0) {
root["temperature"] = g_cfg.temperature;
}
// W21.2: tools 定义传递给 API / Pass tools definition to API
if (!tools_json.empty()) {
root["tools"] = json::parse(tools_json);
}
return json::serialize(root);
}
// ============================================================================
// 解析非流式响应 / Parse non-streaming response
// ============================================================================
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t。
// 处理 text 和 tool_use content block将 tool_use 转换为 OpenAI 格式 / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
// Handles text and tool_use content blocks, converting tool_use to OpenAI format.
static void parse_response(const char* body, int http_status,
dstalk_chat_result_t& r)
{
const auto* h = g_host.load(std::memory_order_acquire);
r.http_status = http_status;
if (http_status < 200 || http_status >= 300) {
r.ok = 0;
try {
auto jv = json::parse(body ? body : "{}");
auto obj = jv.as_object();
if (obj.contains("error")) {
auto err = obj["error"].as_object();
r.error = h->strdup(
json::value_to<std::string>(err["message"]).c_str());
}
} catch (...) {
std::string msg = "HTTP " + std::to_string(http_status);
r.error = h->strdup(msg.c_str());
}
if (!r.error) {
std::string msg = "HTTP " + std::to_string(http_status);
r.error = h->strdup(msg.c_str());
}
r.content = nullptr;
r.tool_calls_json = nullptr;
return;
}
try {
auto jv = json::parse(body ? body : "{}");
auto obj = jv.as_object();
auto content = obj["content"].as_array();
if (!content.empty()) {
// W21.2: 提取 text 和 tool_use content blocks / Extract text and tool_use content blocks
std::string text_content;
json::array tool_use_blocks;
for (const auto& block : content) {
auto bobj = block.as_object();
if (!bobj.contains("type")) continue;
std::string btype = json::value_to<std::string>(bobj["type"]);
if (btype == "text") {
text_content = json::value_to<std::string>(bobj["text"]);
} else if (btype == "tool_use") {
// 转换为 OpenAI 兼容格式: {id, type:"function", function:{name, arguments}} / Convert to OpenAI-compatible format: {id, type:"function", function:{name, arguments}}
json::object tc;
tc["id"] = bobj["id"];
tc["type"] = "function";
json::object func;
func["name"] = bobj["name"];
func["arguments"] = json::serialize(bobj["input"]);
tc["function"] = func;
tool_use_blocks.push_back(std::move(tc));
}
}
if (!tool_use_blocks.empty()) {
r.tool_calls_json = h->strdup(
json::serialize(tool_use_blocks).c_str());
} else {
r.tool_calls_json = nullptr;
}
if (!text_content.empty()) {
r.content = h->strdup(text_content.c_str());
r.ok = 1;
r.error = nullptr;
return;
} else if (!tool_use_blocks.empty()) {
// tool-only 响应 / tool-only response
r.content = nullptr;
r.ok = 1;
r.error = nullptr;
return;
}
r.ok = 0;
r.error = h->strdup("no text or tool_use content block found");
} else {
r.ok = 0;
r.error = h->strdup("empty response");
}
r.content = nullptr;
r.tool_calls_json = nullptr;
} catch (std::exception& e) {
r.ok = 0;
std::string msg = std::string("json parse: ") + e.what();
r.error = h->strdup(msg.c_str());
r.content = nullptr;
r.tool_calls_json = nullptr;
} catch (...) {
r.ok = 0;
r.error = h->strdup("json parse error");
r.content = nullptr;
r.tool_calls_json = nullptr;
}
}
// ============================================================================
// SSE 事件解析Anthropic 格式: event/content_block_delta) / SSE event parsing (Anthropic format: event/content_block_delta)
// ============================================================================
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量 / Accumulate Anthropic tool_use increments by content_block index
struct ToolCallAccum {
int index = -1;
std::string id;
std::string name;
std::string arguments; // 从 input_json_delta.partial_json 累积 / accumulated from input_json_delta.partial_json
};
struct StreamContext {
const dstalk_host_api_t* host;
dstalk_stream_cb user_cb;
void* userdata;
std::string accumulated;
bool saw_data_line = false;
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks / accumulate tool_use content blocks by index
};
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 / Parse Anthropic SSE events with tool_use content_block incremental parsing
// 解析单个 Anthropic SSE "data:" JSON 事件。处理 content_block_start、
// content_block_delta (text_delta/input_json_delta) 和 message_stop。
// 如果产生了 content token 则返回 true否则返回 false / Parse a single Anthropic SSE "data:" JSON event. Handles content_block_start,
// content_block_delta (text_delta/input_json_delta), and message_stop.
// Returns true if a content token was produced, false otherwise.
static bool parse_sse_data(const std::string& data, std::string& token_out,
StreamContext* ctx)
{
try {
auto jv = json::parse(data);
auto obj = jv.as_object();
auto* type_ptr = obj.if_contains("type");
if (!type_ptr || !type_ptr->is_string()) return false;
std::string type = json::value_to<std::string>(*type_ptr);
if (type == "content_block_start") {
// content_block_start 可能为 tool_use / content_block_start may be tool_use
auto* cb = obj.if_contains("content_block");
if (!cb || !cb->is_object()) return false;
auto& cb_obj = cb->as_object();
auto* cb_type = cb_obj.if_contains("type");
if (!cb_type || !cb_type->is_string()) return false;
std::string cb_type_str = json::value_to<std::string>(*cb_type);
if (cb_type_str == "tool_use" && ctx) {
auto* idx_ptr = obj.if_contains("index");
int idx = idx_ptr ? static_cast<int>(
json::value_to<int64_t>(*idx_ptr)) : -1;
if (idx < 0) return false;
while (static_cast<int>(ctx->tool_calls.size()) <= idx) {
ctx->tool_calls.push_back({});
}
auto& acc = ctx->tool_calls[idx];
acc.index = idx;
if (cb_obj.contains("id") && cb_obj["id"].is_string())
acc.id = json::value_to<std::string>(cb_obj["id"]);
if (cb_obj.contains("name") && cb_obj["name"].is_string())
acc.name = json::value_to<std::string>(cb_obj["name"]);
}
return false;
}
if (type == "content_block_delta") {
auto* delta = obj.if_contains("delta");
if (!delta || !delta->is_object()) return false;
auto& dobj = delta->as_object();
auto* dtype = dobj.if_contains("type");
if (!dtype || !dtype->is_string()) return false;
std::string delta_type = json::value_to<std::string>(*dtype);
if (delta_type == "text_delta") {
auto* text = dobj.if_contains("text");
if (text && text->is_string()) {
token_out = json::value_to<std::string>(*text);
return true;
}
} else if (delta_type == "input_json_delta" && ctx) {
// W21.2: 累积 tool_use arguments 分片 / Accumulate tool_use arguments fragments
auto* pj = dobj.if_contains("partial_json");
if (pj && pj->is_string()) {
auto* idx_ptr = obj.if_contains("index");
int idx = idx_ptr ? static_cast<int>(
json::value_to<int64_t>(*idx_ptr)) : -1;
if (idx >= 0 && idx < static_cast<int>(ctx->tool_calls.size())) {
ctx->tool_calls[idx].arguments +=
json::value_to<std::string>(*pj);
}
}
return false;
}
} else if (type == "message_stop") {
token_out.clear();
return true; // 流结束 / stream end
}
// 忽略: message_start, content_block_stop, ping, message_delta / Ignore: message_start, content_block_stop, ping, message_delta
} catch (...) {
// 解析失败忽略 / Ignore parse failures
}
return false;
}
// ============================================================================
// configure / configure
// ============================================================================
// 配置插件provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
static int my_configure(const char* provider, const char* base_url,
const char* api_key, const char* model,
int max_tokens, double temperature)
{
try {
if (provider) g_cfg.provider = provider;
if (base_url) g_cfg.base_url = base_url;
if (api_key) g_cfg.api_key = api_key;
if (model) g_cfg.model = model;
g_cfg.max_tokens = max_tokens;
g_cfg.temperature = temperature;
const auto* h = g_host.load(std::memory_order_acquire);
if (h) {
// W21.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
h->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) {
char* json = tools_svc->get_tools_json();
if (json) {
g_tools_json = json;
h->free(json);
}
}
h->log(DSTALK_LOG_INFO,
"[anthropic] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f",
g_cfg.model.c_str(), g_cfg.base_url.c_str(),
g_cfg.max_tokens, g_cfg.temperature);
}
return 0;
} catch (const std::exception& e) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] my_configure exception: %s", e.what());
return -1;
} catch (...) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] my_configure unknown exception");
return -1;
}
}
// ============================================================================
// chat / chat
// ============================================================================
// 非流式 chat completion发送 history + user input返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
static dstalk_chat_result_t my_chat(
const dstalk_message_t* history, int history_len,
const char* user_input,
const char* tools_json)
{
try {
dstalk_chat_result_t r = {};
r.ok = 0;
const auto* host = g_host.load(std::memory_order_acquire);
const auto* http = g_http.load(std::memory_order_acquire);
if (!http) {
r.error = host->strdup("http service not available");
return r;
}
std::string scheme, hostname, port, target;
extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
std::string target_path = target + "/v1/messages";
std::string body = build_request_json(history, history_len,
user_input ? user_input : "",
tools_json ? tools_json : g_tools_json, false);
std::string headers_json = build_headers_json();
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_json(
hostname.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(), &response_body, &status_code);
if (ret != 0) {
r.error = host->strdup("http request failed");
if (response_body) host->free(response_body);
return r;
}
parse_response(response_body, status_code, r);
if (response_body) {
host->free(response_body);
}
return r;
} catch (const std::exception& e) {
const auto* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[anthropic] my_chat exception: %s", e.what());
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
return r;
} catch (...) {
const auto* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[anthropic] my_chat unknown exception");
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
return r;
}
}
// ============================================================================
// chat_stream / chat_stream
// ============================================================================
// 行回调 / SSE line callback
// SSE 行回调:解析每个 Anthropic SSE 行并将文本 token 转发给用户 / SSE line callback: parses each Anthropic SSE line and forwards text tokens to user.
static int sse_line_callback(const char* line, void* userdata)
{
try {
auto* ctx = static_cast<StreamContext*>(userdata);
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
std::string line_str(line);
// SSE 格式: "data: <json>" / SSE format: "data: <json>"
if (line_str.rfind("data: ", 0) == 0) {
std::string data = line_str.substr(6);
std::string token;
if (parse_sse_data(data, token, ctx)) {
ctx->saw_data_line = true;
if (token.empty()) {
// message_stop / message_stop
return 0;
}
ctx->accumulated += token;
if (ctx->user_cb) {
return ctx->user_cb(token.c_str(), ctx->userdata);
}
}
}
// "event: ..." 行和其他 -> 忽略 / "event: ..." lines and others -> ignored
return 1;
} catch (const std::exception& e) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] sse_line_callback exception: %s", e.what());
return 0;
} catch (...) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] sse_line_callback unknown exception");
return 0;
}
}
// 流式 chat completion以 stream=true 发送 history + user input通过回调传递 token。
// 累积 tool_use blocks 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
// via callback. Accumulates tool_use blocks and serializes them at end.
static dstalk_chat_result_t my_chat_stream(
const dstalk_message_t* history, int history_len,
const char* user_input,
dstalk_stream_cb cb, void* userdata)
{
try {
dstalk_chat_result_t r = {};
r.ok = 0;
const auto* host = g_host.load(std::memory_order_acquire);
const auto* http = g_http.load(std::memory_order_acquire);
if (!http) {
r.error = host->strdup("http service not available");
return r;
}
std::string scheme, hostname, port, target;
extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
std::string target_path = target + "/v1/messages";
std::string body = build_request_json(history, history_len,
user_input ? user_input : "", g_tools_json, true);
std::string headers_json = build_headers_json();
StreamContext ctx;
ctx.host = host;
ctx.user_cb = cb;
ctx.userdata = userdata;
ctx.saw_data_line = false;
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_stream(
hostname.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(),
sse_line_callback, &ctx,
&response_body, &status_code);
r.http_status = status_code;
// 检查错误状态 / Check error status
if (status_code < 200 || status_code >= 300) {
r.ok = 0;
if (response_body && response_body[0]) {
try {
auto jv = json::parse(response_body);
auto obj = jv.as_object();
if (obj.contains("error")) {
auto err = obj["error"].as_object();
r.error = host->strdup(
json::value_to<std::string>(err["message"]).c_str());
}
} catch (...) {}
}
if (!r.error) {
if (status_code <= 0)
r.error = host->strdup("transport error");
else
r.error = host->strdup(
("HTTP " + std::to_string(status_code)).c_str());
}
if (response_body) host->free(response_body);
r.content = nullptr;
r.tool_calls_json = nullptr;
return r;
}
if (response_body) host->free(response_body);
// W21.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling / Success = has content OR has tool_calls (tool-only responses like function calling)
bool has_content = !ctx.accumulated.empty();
bool has_tool_calls = !ctx.tool_calls.empty();
if (!has_content && !has_tool_calls) {
r.ok = 0;
r.error = host->strdup("no content received");
r.content = nullptr;
r.tool_calls_json = nullptr;
} else {
r.ok = 1;
r.error = nullptr;
r.content = has_content
? host->strdup(ctx.accumulated.c_str()) : nullptr;
// W21.2: 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible format)
if (has_tool_calls) {
json::array tc_array;
for (auto& tc : ctx.tool_calls) {
json::object tc_obj;
tc_obj["index"] = tc.index;
if (!tc.id.empty()) tc_obj["id"] = tc.id;
tc_obj["type"] = "function";
json::object func;
if (!tc.name.empty()) func["name"] = tc.name;
func["arguments"] = tc.arguments;
tc_obj["function"] = func;
tc_array.push_back(std::move(tc_obj));
}
std::string tc_json = json::serialize(tc_array);
r.tool_calls_json = host ? host->strdup(tc_json.c_str()) : nullptr;
} else {
r.tool_calls_json = nullptr;
}
}
return r;
} catch (const std::exception& e) {
const auto* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[anthropic] my_chat_stream exception: %s", e.what());
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
return r;
} catch (...) {
const auto* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[anthropic] my_chat_stream unknown exception");
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
return r;
}
}
// ============================================================================
// free_result / free_result
// ============================================================================
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
static void my_free_result(dstalk_chat_result_t* result)
{
const auto* h = g_host.load(std::memory_order_acquire);
if (!result || !h) return;
if (result->content) { h->free((void*)result->content); result->content = nullptr; }
if (result->error) { h->free((void*)result->error); result->error = nullptr; }
if (result->tool_calls_json) { h->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
}
// ============================================================================
// 服务 vtable / Service vtable
// ============================================================================
static dstalk_ai_service_t g_service = {
&my_configure,
&my_chat,
&my_chat_stream,
&my_free_result,
};
// ============================================================================
// 生命周期 / Lifecycle
// ============================================================================
// 插件初始化:查询 http 和 config 服务,注册 ai.anthropic 服务 / Plugin init: query http and config services, register ai.anthropic service.
static int on_init(const dstalk_host_api_t* host)
{
try {
g_host.store(host, std::memory_order_release);
auto* http_svc = (dstalk_http_service_t*)host->query_service("http", 1);
g_http.store(http_svc, std::memory_order_release);
g_config = (dstalk_config_service_t*)host->query_service("config", 1);
if (!http_svc) {
if (host) host->log(DSTALK_LOG_ERROR, "[anthropic] http service not found");
return -1;
}
if (host) host->log(DSTALK_LOG_INFO, "[anthropic] initializing Anthropic AI plugin");
return host->register_service("ai.anthropic", 1, &g_service);
} catch (const std::exception& e) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] on_init exception: %s", e.what());
return -1;
} catch (...) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] on_init unknown exception");
return -1;
}
}
// 插件关闭:从内存安全擦除 API key清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
static void on_shutdown()
{
try {
const auto* h = g_host.load(std::memory_order_acquire);
if (h) h->log(DSTALK_LOG_INFO, "[anthropic] shutdown");
secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
g_cfg.api_key.clear();
g_http.store(nullptr, std::memory_order_release);
g_config = nullptr;
g_host.store(nullptr, std::memory_order_release);
} catch (const std::exception& e) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] on_shutdown exception: %s", e.what());
} catch (...) {
const auto* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] on_shutdown unknown exception");
}
}
// ============================================================================
// 插件描述符 / Plugin descriptor
// ============================================================================
static dstalk_plugin_info_t g_info = {
/* .name = */ "anthropic-ai",
/* .version = */ "1.0.0",
/* .description = */ "Anthropic Claude AI provider (Messages API) / Anthropic Claude AI 提供者 (Messages API)",
/* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { "http", "config", NULL },
/* .on_init = */ on_init,
/* .on_shutdown = */ on_shutdown,
/* .on_event = */ nullptr,
};
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{
return &g_info;
}

View File

@@ -0,0 +1,9 @@
add_library(plugin-context SHARED src/context_plugin.cpp)
target_link_libraries(plugin-context PRIVATE dstalk)
set_target_properties(plugin-context PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
)

View File

@@ -0,0 +1,446 @@
/*
* @file context_plugin.cpp
* @brief Context plugin: token counting and context window trimming.
* 上下文插件token 计数和上下文窗口裁剪。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-context: 上下文管理服务插件 / Context management service plugin
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <exception>
#include <string>
#include <vector>
// ============================================================
// 全局状态 / Global state
// ============================================================
static const dstalk_host_api_t* g_host = nullptr;
static const dstalk_session_service_t* g_session = nullptr;
// ============================================================
// 内部 C++ 辅助:共享 UTF-8 token 计数 / Internal C++ helper: shared UTF-8 token counting
// W18.1: 合并 count_tokens_one_message / count_tokens_trim 的重复逻辑 (F-11.1-5)
// Merge duplicated logic between count_tokens_one_message / count_tokens_trim (F-11.1-5)
// 添加 UTF-8 越界保护 (F-11.1-4) 和 0xC0/0xC1 过短编码检测 (F-11.1-6)
// Add UTF-8 out-of-bounds protection (F-11.1-4) and 0xC0/0xC1 overlong encoding detection (F-11.1-6)
// ============================================================
// 统计 UTF-8 字节序列 [text, text+len) 的估算 token 数。
// overhead: 每条消息的固定开销 tokenrole + separators = 4
// 多字节序列在越界或无效后继字节时回退为单字节 other_chars 计数,不崩溃。
// Count estimated tokens for UTF-8 byte sequence [text, text+len).
// overhead: fixed token overhead per message (role + separators = 4).
// Multi-byte sequences fall back to single-byte other_chars counting when out-of-bounds or invalid continuation bytes.
static size_t count_tokens_utf8(const char* text, size_t len, size_t overhead) {
if (!text || len == 0) return overhead;
size_t ascii_chars = 0;
size_t chinese_chars = 0;
size_t other_chars = 0;
size_t i = 0;
while (i < len && text[i] != '\0') {
unsigned char c = static_cast<unsigned char>(text[i]);
if (c < 0x80) {
// ASCII / ASCII
ascii_chars++;
i += 1;
} else if (c >= 0xE4 && c <= 0xE9) {
// CJK 统一表意文字 (U+4E00-U+9FFF): 3 字节 UTF-8 0xE4-0xE9 / CJK Unified Ideographs (U+4E00-U+9FFF): 3-byte UTF-8 0xE4-0xE9
// W18.1 (F-11.1-4): 检查后续 2 字节是否在有效范围内 / Check if subsequent 2 bytes are in valid range
if (i + 2 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
other_chars++;
i += 1;
} else {
chinese_chars++;
i += 3;
}
} else if (c >= 0xC2 && c < 0xE0) {
// 2 字节序列 (有效范围 0xC2-0xDF) / 2-byte sequence (valid range 0xC2-0xDF)
// W18.1 (F-11.1-4): 检查后续 1 字节 / Check subsequent 1 byte
if (i + 1 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80) {
other_chars++;
i += 1;
} else {
other_chars++;
i += 2;
}
} else if (c == 0xC0 || c == 0xC1) {
// W18.1 (F-11.1-6): 过短编码 (overlong encoding),非法 UTF-8 起始字节 / Overlong encoding, invalid UTF-8 start byte
// 0xC0/0xC1 永远不会出现在合法 UTF-8 中;视为单字节计入 other_chars / 0xC0/0xC1 never appear in valid UTF-8; counted as single-byte in other_chars
other_chars++;
i += 1;
} else if (c >= 0xE0 && c < 0xF0) {
// 非 CJK 3 字节序列 (0xE0-0xE3, 0xEA-0xEF) / Non-CJK 3-byte sequence (0xE0-0xE3, 0xEA-0xEF)
// CJK 范围 0xE4-0xE9 已在上方分支处理 / CJK range 0xE4-0xE9 handled in branch above
if (i + 2 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80) {
other_chars++;
i += 1;
} else {
other_chars++;
i += 3;
}
} else if (c >= 0xF0 && c < 0xF8) {
// 4 字节序列 / 4-byte sequence
if (i + 3 >= len ||
(static_cast<unsigned char>(text[i + 1]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 2]) & 0xC0) != 0x80 ||
(static_cast<unsigned char>(text[i + 3]) & 0xC0) != 0x80) {
other_chars++;
i += 1;
} else {
other_chars++;
i += 4;
}
} else {
// 续字节 (0x80-0xBF) 和其他无效起始字节 (0xF8-0xFF) / Continuation bytes (0x80-0xBF) and other invalid start bytes (0xF8-0xFF)
other_chars++;
i += 1;
}
}
return (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3) + overhead;
}
// ============================================================
// 消息级 token 计数(供 count_tokens_all 和 trim_impl 调用的薄封装) / Message-level token counting (thin wrappers for count_tokens_all and trim_impl)
// ============================================================
// 对单条 C 消息结构体封装 count_tokens_utf8 / Wrap count_tokens_utf8 for a single C message struct.
static size_t count_tokens_one_message(const dstalk_message_t& msg) {
const char* text = msg.content;
if (!text) return 4; // 只有 overhead / overhead only
return count_tokens_utf8(text, std::strlen(text), 4);
}
// 对 C 消息数组求和估算 token / Sum token estimates across an array of C messages.
static size_t count_tokens_all(const dstalk_message_t* msgs, int count) {
size_t total = 0;
for (int i = 0; i < count; ++i) {
total += count_tokens_one_message(msgs[i]);
}
return total;
}
// ============================================================
// 内部 trim 逻辑 / Internal trim logic
// ============================================================
// 为 trim 操作将 C 消息数组复制到内部 struct / Copy C message array to internal struct for trim operation
struct TrimMessage {
std::string role;
std::string content;
std::string tool_call_id;
std::string tool_calls_json;
};
static size_t count_tokens_trim(const TrimMessage& msg) {
if (msg.content.empty()) return 4;
return count_tokens_utf8(msg.content.c_str(), msg.content.size(), 4);
}
static size_t count_tokens_trim_vec(const std::vector<TrimMessage>& msgs) {
size_t total = 0;
for (const auto& m : msgs) total += count_tokens_trim(m);
return total;
}
// 释放单条消息中所有已分配的字符串字段(用于 OOM 回滚) / Free all host-allocated string fields in a single dstalk_message_t (OOM rollback helper).
static void free_msg_strs(dstalk_message_t* msg) {
if (msg->role) { g_host->free((void*)msg->role); msg->role = nullptr; }
if (msg->content) { g_host->free((void*)msg->content); msg->content = nullptr; }
if (msg->tool_call_id) { g_host->free((void*)msg->tool_call_id); msg->tool_call_id = nullptr; }
if (msg->tool_calls_json) { g_host->free((void*)msg->tool_calls_json); msg->tool_calls_json = nullptr; }
}
// 将 TrimMessage 的字符串字段通过 g_host->strdup 复制到 dstalk_message_t。
// 成功返回 0OOM 时释放当前消息已分配字段并返回 -1。
// Copy TrimMessage string fields into a dstalk_message_t via host->strdup.
// On OOM, frees already-allocated fields and returns -1.
static int strdup_message_fields(dstalk_message_t* dst, const TrimMessage& src) {
memset(dst, 0, sizeof(dstalk_message_t));
if (!src.role.empty()) {
dst->role = g_host->strdup(src.role.c_str());
if (!dst->role) goto oom;
}
if (!src.content.empty()) {
dst->content = g_host->strdup(src.content.c_str());
if (!dst->content) goto oom;
}
if (!src.tool_call_id.empty()) {
dst->tool_call_id = g_host->strdup(src.tool_call_id.c_str());
if (!dst->tool_call_id) goto oom;
}
if (!src.tool_calls_json.empty()) {
dst->tool_calls_json = g_host->strdup(src.tool_calls_json.c_str());
if (!dst->tool_calls_json) goto oom;
}
return 0;
oom:
free_msg_strs(dst);
return -1;
}
// W12.1 修复trim_impl 包裹 try/catch 防止 C++ 异常穿越 ABI 边界 (§5.3) / W12.1 fix: trim_impl wrapped in try/catch to prevent C++ exceptions crossing ABI boundary (§5.3)
// 核心裁剪逻辑:通过删除最旧的 user/assistant 对来减少消息列表以适应 max_tokens。
// 保留 system 消息。try/catch 保护 ABI / Core trim logic: reduce message list to fit within max_tokens by removing
// oldest user/assistant pairs. Preserves system messages. try/catch guards ABI.
static int trim_impl(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count,
size_t max_tokens) {
try {
if (!in || in_count <= 0 || !out || !out_count) return -1;
// W18.1 (F-11.1-3): g_max_tokens 已移除,调用方必须提供有效 max_tokens
// 传 0 时使用硬编码默认值 4096 / g_max_tokens removed, caller must provide valid max_tokens;
// when 0 is passed, use hardcoded default 4096.
if (max_tokens == 0) max_tokens = 4096;
// 将 C 数组转换为内部 vector / Convert C array to internal vector
std::vector<TrimMessage> messages;
messages.reserve(in_count);
for (int i = 0; i < in_count; ++i) {
TrimMessage tm;
if (in[i].role) tm.role = in[i].role;
if (in[i].content) tm.content = in[i].content;
if (in[i].tool_call_id) tm.tool_call_id = in[i].tool_call_id;
if (in[i].tool_calls_json) tm.tool_calls_json = in[i].tool_calls_json;
messages.push_back(std::move(tm));
}
// 如果已在限制内,直接返回完整副本 / If already within limit, return full copy directly
size_t current = count_tokens_trim_vec(messages);
if (current <= max_tokens) {
*out_count = in_count;
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * in_count));
if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback already allocated on OOM
for (int i = 0; i < in_count; ++i) {
if (strdup_message_fields(&(*out)[i], messages[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
g_host->free(*out);
*out = nullptr;
return -1;
}
}
return 0;
}
// 分离 system 消息和非 system 消息 / Separate system messages from non-system messages
std::vector<TrimMessage> system_msgs;
std::vector<TrimMessage> non_system_msgs;
for (const auto& msg : messages) {
if (msg.role == "system") {
system_msgs.push_back(msg);
} else {
non_system_msgs.push_back(msg);
}
}
size_t system_tokens = count_tokens_trim_vec(system_msgs);
if (system_tokens > max_tokens) {
std::fprintf(stderr, "[context] WARNING: system messages alone "
"(%zu tokens) exceed max_context_tokens (%zu)\n",
system_tokens, max_tokens);
}
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
for (const auto& msg : non_system_msgs) {
size_t msg_tokens = count_tokens_trim(msg);
if (msg_tokens > max_tokens) {
std::fprintf(stderr, "[context] WARNING: single message "
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
"Returning empty list.\n",
msg.role.c_str(), msg_tokens, max_tokens);
*out = nullptr;
*out_count = 0;
return -1;
}
}
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
while (!non_system_msgs.empty()) {
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
if (current <= max_tokens) break;
// 找第一个 "user" 消息 / Find first "user" message
auto user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") {
++user_it;
}
if (user_it == non_system_msgs.end()) break;
// 找下一个 "assistant" / Find next "assistant"
auto assistant_it = user_it + 1;
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
++assistant_it;
}
if (assistant_it == non_system_msgs.end()) {
non_system_msgs.erase(user_it);
} else {
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
non_system_msgs.erase(assistant_it);
user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
}
}
// W18.1 (F-11.1-3): 消息数量上限粗略估算(每消息 ~100 token使用当前 max_tokens / Message count upper bound rough estimate (~100 tokens per message), uses current max_tokens
{
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
if (max_msg_count < 1) max_msg_count = 1;
while (non_system_msgs.size() > max_msg_count) {
non_system_msgs.erase(non_system_msgs.begin());
}
}
// 组装结果 / Assemble result
std::vector<TrimMessage> result;
result.reserve(system_msgs.size() + non_system_msgs.size());
result.insert(result.end(), system_msgs.begin(), system_msgs.end());
result.insert(result.end(), non_system_msgs.begin(), non_system_msgs.end());
int result_count = static_cast<int>(result.size());
*out_count = result_count;
*out = static_cast<dstalk_message_t*>(g_host->alloc(sizeof(dstalk_message_t) * result_count));
if (!*out) return -1;
// W12.1: strdup 返回值逐一检查OOM 时回滚已分配消息 / strdup return value checked one-by-one, rollback on OOM
for (int i = 0; i < result_count; ++i) {
if (strdup_message_fields(&(*out)[i], result[i]) != 0) {
for (int j = 0; j < i; ++j) free_msg_strs(&(*out)[j]);
g_host->free(*out);
*out = nullptr;
return -1;
}
}
return 0;
} catch (const std::exception& e) {
// W12.1: 防止 std::bad_alloc 等 C++ 异常穿越 C ABI 边界 -> std::terminate() / Prevent C++ exceptions (std::bad_alloc etc.) from crossing C ABI boundary -> std::terminate()
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl exception: %s", e.what());
return -1;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[context] trim_impl unknown exception");
return -1;
}
}
// ============================================================
// Context 服务 vtable 实现 / Context service vtable implementation
// ============================================================
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 -> std::terminate() / Wrapped try/catch prevents exceptions crossing C ABI boundary -> std::terminate()
// 对 C 消息数组进行 token 计数。输入为 null/空时返回 0 / Count tokens across an array of C messages. Returns 0 on null/empty input.
static size_t context_count_tokens(const dstalk_message_t* msgs, int count) {
try {
if (!msgs || count <= 0) return 0;
return count_tokens_all(msgs, count);
} catch (...) {
return 0;
}
}
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
// 裁剪消息列表以适应 max_tokens返回新分配的主机内存数组 / Trim a message list to fit within max_tokens, returning a new host-allocated array.
static int context_trim(const dstalk_message_t* in, int in_count,
dstalk_message_t** out, int* out_count,
size_t max_tokens) {
try {
return trim_impl(in, in_count, out, out_count, max_tokens);
} catch (...) {
return -1;
}
}
// W18.1 (F-11.1-3): g_max_tokens / context_set_max_tokens 已移除。
// max_tokens 由调用方通过 trim() 的 max_tokens 参数直接传入;
// 传 0 时 trim_impl 使用硬编码默认值 4096。
// g_max_tokens / context_set_max_tokens removed. max_tokens is passed directly
// by caller via trim()'s max_tokens parameter; trim_impl uses hardcoded default 4096 when 0.
static dstalk_context_service_t g_context_service = {
context_count_tokens,
context_trim
};
// ============================================================
// 插件生命周期 / Plugin lifecycle
// ============================================================
// W12.1: 包裹 try/catch 防止异常穿越 C ABI 边界 / Wrapped try/catch prevents exceptions crossing C ABI boundary
// 插件初始化:保存主机指针,查询 session 依赖,注册 context 服务 / Plugin init: store host pointer, query session dependency, register context service.
static int on_init(const dstalk_host_api_t* host) {
try {
g_host = host;
// 查询依赖服务: session / Query dependency service: session
void* raw = host->query_service("session", 1);
if (!raw) {
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
return -1;
}
g_session = static_cast<const dstalk_session_service_t*>(raw);
return host->register_service("context", 1, &g_context_service);
} catch (const std::exception& e) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
return -1;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
return -1;
}
}
// W16.2: 包裹 try/catch 防止异常穿越 C ABI 边界 -- void 函数仅 log / Wrapped try/catch prevents exceptions crossing C ABI boundary -- void function only logs
// 插件关闭清空指针。try/catch 保护 ABIvoid 函数) / Plugin shutdown: null out pointers. try/catch guards ABI (void function).
static void on_shutdown() {
try {
g_session = nullptr;
g_host = nullptr;
} catch (const std::exception& e) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
g_session = nullptr;
g_host = nullptr;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
g_session = nullptr;
g_host = nullptr;
}
}
static dstalk_plugin_info_t g_info = {
"context",
"1.0.0",
"Context management plugin with token counting and trim support / 支持 token 计数和裁剪的上下文管理插件",
DSTALK_API_VERSION,
{"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
on_init,
on_shutdown,
nullptr
};
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
return &g_info;
}

View File

@@ -0,0 +1,20 @@
# ============================================================
# plugin-openai — OpenAI 兼容 AI 服务 / OpenAI-compatible AI service
# ============================================================
find_package(Boost REQUIRED CONFIG)
add_library(plugin-openai SHARED
src/openai_plugin.cpp
)
target_link_libraries(plugin-openai PRIVATE dstalk)
# Boost.JSON (header-only)
target_link_libraries(plugin-openai PRIVATE boost::boost dstalk_boost_config)
set_target_properties(plugin-openai PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins
)

View File

@@ -0,0 +1,691 @@
/*
* @file openai_plugin.cpp
* @brief OpenAI-compatible AI provider plugin with SSE streaming and tool calls.
* DeepSeek/OpenAI 兼容 AI 提供者插件,支持 SSE 流式输出和工具调用。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"
#include <boost/json.hpp>
#include <boost/json/src.hpp>
#include <atomic>
#include <cstring>
#include <string>
#include <vector>
namespace json = boost::json;
// ============================================================================
// 全局指针:从 on_init 获取W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
// ============================================================================
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
static std::atomic<dstalk_config_service_t*> g_config{nullptr};
// ============================================================================
// 配置数据(由 configure() 设置) / Config data (set by configure())
// ============================================================================
struct PluginConfig {
std::string provider;
std::string base_url;
std::string api_key;
std::string model;
int max_tokens = 4096;
double temperature = 0.7;
};
static PluginConfig g_cfg;
static std::string g_tools_json; // W20.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
// ============================================================================
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
// ============================================================================
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
static void secure_zero(void* p, size_t n) {
volatile char* vp = (volatile char*)p;
while (n--) *vp++ = 0;
}
// ============================================================================
// 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
// ============================================================================
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
static bool extract_host_port(const std::string& url,
std::string& scheme_out, std::string& host_out,
std::string& port_out, std::string& target_out)
{
size_t scheme_end = url.find("://");
if (scheme_end == std::string::npos) return false;
scheme_out = url.substr(0, scheme_end);
std::string rest = url.substr(scheme_end + 3);
size_t slash = rest.find('/');
std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest;
target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";
size_t colon = authority.rfind(':');
if (colon != std::string::npos) {
host_out = authority.substr(0, colon);
port_out = authority.substr(colon + 1);
} else {
host_out = authority;
port_out = (scheme_out == "https") ? "443" : "80";
}
return true;
}
// ============================================================================
// 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
// ============================================================================
// 构建包含 Bearer 授权令牌的 JSON headers 对象 / Build the JSON headers object containing the Bearer authorization token.
static std::string build_headers_json(const std::string& auth_header_value)
{
json::object h;
h["Authorization"] = "Bearer " + auth_header_value;
return json::serialize(h);
}
// ============================================================================
// 辅助dstalk_message_t[] -> boost::json::array / Helper: dstalk_message_t[] -> boost::json::array
// ============================================================================
// 将 dstalk_message_t 数组转换为 Boost.JSON 数组,用于 API 请求体 / Convert dstalk_message_t array into a Boost.JSON array for the API request body.
static void append_history(json::array& msgs,
const dstalk_message_t* history, int history_len)
{
for (int i = 0; i < history_len; ++i) {
const auto& m = history[i];
json::object obj;
obj["role"] = m.role ? m.role : "";
if (m.role && std::strcmp(m.role, "tool") == 0) {
obj["tool_call_id"] = m.tool_call_id ? m.tool_call_id : "";
obj["content"] = m.content ? m.content : "";
} else if (m.role && std::strcmp(m.role, "assistant") == 0 &&
m.tool_calls_json && m.tool_calls_json[0] != '\0') {
obj["content"] = m.content ? m.content : "";
obj["tool_calls"] = json::parse(m.tool_calls_json);
} else {
obj["content"] = m.content ? m.content : "";
}
msgs.push_back(obj);
}
}
// ============================================================================
// 构建 DeepSeek JSON 请求体 / Build DeepSeek JSON request body
// ============================================================================
// 构建 DeepSeek/OpenAI chat completions API 的完整 JSON 请求体 / Build the full JSON request body for the DeepSeek/OpenAI chat completions API.
static std::string build_request_json(
const dstalk_message_t* history, int history_len,
const std::string& user_input,
const std::string& tools_json,
bool stream)
{
json::object root;
root["model"] = g_cfg.model;
root["max_tokens"] = g_cfg.max_tokens;
root["temperature"] = g_cfg.temperature;
root["stream"] = stream;
json::array msgs;
append_history(msgs, history, history_len);
// 追加当前用户输入 / Append current user input
if (!user_input.empty()) {
json::object obj;
obj["role"] = "user";
obj["content"] = user_input;
msgs.push_back(obj);
}
root["messages"] = msgs;
// tools 定义 / tools definition
if (!tools_json.empty()) {
root["tools"] = json::parse(tools_json);
}
return json::serialize(root);
}
// ============================================================================
// 解析非流式 JSON 响应 / Parse non-streaming JSON response
// ============================================================================
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
static void parse_response(const dstalk_host_api_t* host,
const char* body, int http_status,
dstalk_chat_result_t& r)
{
r.http_status = http_status;
if (http_status < 200 || http_status >= 300) {
r.ok = 0;
try {
auto jv = json::parse(body ? body : "{}");
auto obj = jv.as_object();
if (obj.contains("error")) {
auto err = obj["error"].as_object();
r.error = host ? host->strdup(
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
}
} catch (...) {
std::string msg = "HTTP " + std::to_string(http_status);
r.error = host ? host->strdup(msg.c_str()) : nullptr;
}
if (!r.error && host) {
std::string msg = "HTTP " + std::to_string(http_status);
r.error = host->strdup(msg.c_str());
}
r.content = nullptr;
r.tool_calls_json = nullptr;
return;
}
try {
auto jv = json::parse(body ? body : "{}");
auto obj = jv.as_object();
auto choices = obj["choices"].as_array();
if (!choices.empty()) {
auto msg = choices[0].as_object()["message"].as_object();
std::string content = json::value_to<std::string>(msg["content"]);
r.content = host ? host->strdup(content.c_str()) : nullptr;
if (msg.contains("tool_calls")) {
std::string tc = json::serialize(msg["tool_calls"]);
r.tool_calls_json = host ? host->strdup(tc.c_str()) : nullptr;
} else {
r.tool_calls_json = nullptr;
}
r.ok = 1;
r.error = nullptr;
} else {
r.ok = 0;
r.error = host ? host->strdup("empty response") : nullptr;
r.content = nullptr;
r.tool_calls_json = nullptr;
}
} catch (std::exception& e) {
r.ok = 0;
std::string msg = std::string("json parse: ") + e.what();
r.error = host ? host->strdup(msg.c_str()) : nullptr;
r.content = nullptr;
r.tool_calls_json = nullptr;
} catch (...) {
r.ok = 0;
r.error = host ? host->strdup("json parse error") : nullptr;
r.content = nullptr;
r.tool_calls_json = nullptr;
}
}
// ============================================================================
// 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
// ============================================================================
struct ToolCallAccum {
int index = -1;
std::string id;
std::string name;
std::string arguments; // 增量拼接的 JSON arguments 字符串 / incrementally concatenated JSON arguments string
};
struct StreamContext {
const dstalk_host_api_t* host;
dstalk_stream_cb user_cb;
void* userdata;
std::string accumulated;
bool streaming_ok = true;
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls / accumulate delta tool_calls by index
};
// ============================================================================
// SSE 行解析OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
// ============================================================================
// 解析单行 SSE "data:" 行。如果包含 content delta将 token 写入 token_out。
// 如果包含 tool_calls delta累积到 ctx->tool_calls。
// 如果产生了 content token 则返回 true否则返回 falsetool_calls 或未知)。
// Parse a single SSE "data:" line. If it contains a content delta, writes the token
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
static bool parse_sse_line(const std::string& line, std::string& token_out,
StreamContext* ctx)
{
if (line.rfind("data: ", 0) != 0) return false;
std::string data = line.substr(6);
// F-13.2-3: 比较 [DONE] 哨兵前去除首尾空白 / Trim leading/trailing whitespace before comparing [DONE] sentinel.
const char* ws = " \t\r\n";
size_t start = data.find_first_not_of(ws);
if (start != std::string::npos) {
data.erase(0, start);
data.erase(data.find_last_not_of(ws) + 1);
}
if (data == "[DONE]") {
token_out.clear();
return true; // 流结束信号 / stream end signal
}
try {
auto jv = json::parse(data);
auto obj = jv.as_object();
auto choices = obj["choices"].as_array();
if (!choices.empty()) {
auto delta = choices[0].as_object()["delta"].as_object();
// W20.2: 处理 delta["tool_calls"] 增量 chunk / Handle delta["tool_calls"] incremental chunks
// DeepSeek/OpenAI 流式模式 tool_calls 跨多个 SSE 事件分片传输 / DeepSeek/OpenAI streaming mode: tool_calls transmitted across multiple SSE event chunks:
// 事件 1 / Event 1: {"index":0, "id":"call_xxx", "function":{"name":"foo"}}
// 事件 2 / Event 2: {"index":0, "function":{"arguments":"{\"bar\":"}}
// 事件 3 / Event 3: {"index":0, "function":{"arguments":"1}"}}
// 需要按 index 累积 id/name/arguments / Need to accumulate id/name/arguments by index.
if (delta.contains("tool_calls") && ctx) {
auto tc_array = delta["tool_calls"].as_array();
for (auto& tc_val : tc_array) {
auto tc_obj = tc_val.as_object();
int idx = tc_obj.contains("index")
? static_cast<int>(json::value_to<int64_t>(tc_obj["index"])) : -1;
if (idx < 0) continue;
while (static_cast<int>(ctx->tool_calls.size()) <= idx) {
ctx->tool_calls.push_back({});
}
auto& acc = ctx->tool_calls[idx];
acc.index = idx;
if (tc_obj.contains("id") && tc_obj["id"].is_string()) {
acc.id = json::value_to<std::string>(tc_obj["id"]);
}
if (tc_obj.contains("function") && tc_obj["function"].is_object()) {
auto func = tc_obj["function"].as_object();
if (func.contains("name") && func["name"].is_string()) {
acc.name = json::value_to<std::string>(func["name"]);
}
if (func.contains("arguments") && func["arguments"].is_string()) {
acc.arguments += json::value_to<std::string>(func["arguments"]);
}
}
}
return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
}
if (delta.contains("content")) {
token_out = json::value_to<std::string>(delta["content"]);
return true;
}
}
} catch (...) {
// 忽略解析失败 / Ignore parse failures
}
return false;
}
// ============================================================================
// configure 实现 / configure implementation
// ============================================================================
// 配置插件provider、endpoint、auth、model 和生成参数 / Configure the plugin with provider, endpoint, auth, model, and generation parameters.
static int my_configure(const char* provider, const char* base_url,
const char* api_key, const char* model,
int max_tokens, double temperature)
{
try {
if (provider) g_cfg.provider = provider;
if (base_url) g_cfg.base_url = base_url;
if (api_key) g_cfg.api_key = api_key;
if (model) g_cfg.model = model;
g_cfg.max_tokens = max_tokens;
g_cfg.temperature = temperature;
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) {
// W20.2: 从 tools service 缓存 tools_json供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
host->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) {
char* json = tools_svc->get_tools_json();
if (json) {
g_tools_json = json;
host->free(json);
}
}
host->log(DSTALK_LOG_INFO,
"[openai] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f",
g_cfg.model.c_str(), g_cfg.base_url.c_str(),
g_cfg.max_tokens, g_cfg.temperature);
}
return 0;
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_configure exception: %s", e.what());
return -1;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_configure unknown exception");
return -1;
}
}
// ============================================================================
// chat 实现 / chat implementation
// ============================================================================
// 非流式 chat completion发送 history + user input返回完整响应 / Non-streaming chat completion: send history + user input, return full response.
static dstalk_chat_result_t my_chat(
const dstalk_message_t* history, int history_len,
const char* user_input,
const char* tools_json)
{
try {
dstalk_chat_result_t r = {};
r.ok = 0;
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
dstalk_http_service_t* http = g_http.load(std::memory_order_acquire);
if (!http) {
r.error = host ? host->strdup("http service not available") : nullptr;
return r;
}
std::string scheme, host_name, port, target;
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
std::string target_path = target + "/chat/completions";
std::string body = build_request_json(history, history_len,
user_input ? user_input : "", tools_json ? tools_json : "", false);
std::string headers_json = build_headers_json(g_cfg.api_key);
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_json(
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(), &response_body, &status_code);
if (ret != 0) {
r.error = host ? host->strdup("http request failed") : nullptr;
return r;
}
parse_response(host, response_body, status_code, r);
if (response_body) {
if (host) host->free(response_body);
}
return r;
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat exception: %s", e.what());
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
return r;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat unknown exception");
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
return r;
}
}
// ============================================================================
// chat_stream 实现 / chat_stream implementation
// ============================================================================
// 行回调:解析 SSE line将 token 传递给用户回调 / SSE line callback: parses each line and forwards content tokens to the user callback.
static int sse_line_callback(const char* line, void* userdata)
{
try {
auto* ctx = static_cast<StreamContext*>(userdata);
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
std::string line_str(line);
std::string token;
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
ctx->accumulated += token;
if (ctx->user_cb) {
return ctx->user_cb(token.c_str(), ctx->userdata);
}
return 1; // 继续 / continue
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] sse_line_callback exception: %s", e.what());
return 0;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] sse_line_callback unknown exception");
return 0;
}
}
// 流式 chat completion以 stream=true 发送 history + user input通过回调传递 token。
// 在 SSE 分片中累积 tool_calls 并在结束时序列化 / Streaming chat completion: send history + user input with stream=true, deliver tokens
// via callback. Accumulates tool_calls across SSE chunks and serializes them at end.
static dstalk_chat_result_t my_chat_stream(
const dstalk_message_t* history, int history_len,
const char* user_input,
dstalk_stream_cb cb, void* userdata)
{
try {
dstalk_chat_result_t r = {};
r.ok = 0;
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
dstalk_http_service_t* http = g_http.load(std::memory_order_acquire);
if (!http) {
r.error = host ? host->strdup("http service not available") : nullptr;
return r;
}
std::string scheme, host_name, port, target;
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
std::string target_path = target + "/chat/completions";
std::string body = build_request_json(history, history_len,
user_input ? user_input : "", g_tools_json, true);
std::string headers_json = build_headers_json(g_cfg.api_key);
StreamContext ctx;
ctx.host = host;
ctx.user_cb = cb;
ctx.userdata = userdata;
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_stream(
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(),
sse_line_callback, &ctx,
&response_body, &status_code);
r.http_status = status_code;
// 检查传输层错误或非 2xx 状态 / Check transport errors or non-2xx status
if (status_code < 200 || status_code >= 300) {
r.ok = 0;
// 尝试从响应体提取错误信息 / Try to extract error info from response body
if (response_body && response_body[0]) {
try {
auto jv = json::parse(response_body);
auto obj = jv.as_object();
if (obj.contains("error")) {
auto err = obj["error"].as_object();
r.error = host ? host->strdup(
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
}
} catch (...) {}
}
if (!r.error && host) {
if (status_code <= 0)
r.error = host->strdup("transport error");
else
r.error = host->strdup(
("HTTP " + std::to_string(status_code)).c_str());
}
if (response_body && host) host->free(response_body);
r.content = nullptr;
r.tool_calls_json = nullptr;
return r;
}
if (response_body && host) host->free(response_body);
// W20.2: 成功条件 = 有内容 OR 有 tool_callstool-only 响应如 function calling / Success = has content OR has tool_calls (tool-only responses like function calling)
bool has_content = !ctx.accumulated.empty();
bool has_tool_calls = !ctx.tool_calls.empty();
if (!has_content && !has_tool_calls) {
r.ok = 0;
r.error = host ? host->strdup("no content received") : nullptr;
r.content = nullptr;
r.tool_calls_json = nullptr;
} else {
r.ok = 1;
r.error = nullptr;
r.content = has_content
? host->strdup(ctx.accumulated.c_str()) : nullptr;
// 序列化累积的 tool_calls 为 JSON兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
if (has_tool_calls) {
json::array tc_array;
for (auto& tc : ctx.tool_calls) {
json::object tc_obj;
tc_obj["index"] = tc.index;
if (!tc.id.empty()) tc_obj["id"] = tc.id;
tc_obj["type"] = "function";
json::object func;
if (!tc.name.empty()) func["name"] = tc.name;
func["arguments"] = tc.arguments;
tc_obj["function"] = func;
tc_array.push_back(std::move(tc_obj));
}
std::string tc_json = json::serialize(tc_array);
r.tool_calls_json = host ? host->strdup(tc_json.c_str()) : nullptr;
} else {
r.tool_calls_json = nullptr;
}
}
return r;
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream exception: %s", e.what());
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
return r;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream unknown exception");
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
return r;
}
}
// ============================================================================
// free_result 实现 / free_result implementation
// ============================================================================
// 释放 chat result 结构体中所有主机分配的字符串字段 / Free all host-allocated string fields in a chat result struct.
static void my_free_result(dstalk_chat_result_t* result)
{
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (!result || !host) return;
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
}
// ============================================================================
// 服务 vtable / Service vtable
// ============================================================================
static dstalk_ai_service_t g_service = {
&my_configure,
&my_chat,
&my_chat_stream,
&my_free_result,
};
// ============================================================================
// 生命周期 / Lifecycle
// ============================================================================
// 插件初始化:查询 http 和 config 服务,注册 ai.openai 服务 / Plugin init: query http and config services, register ai.openai service.
static int on_init(const dstalk_host_api_t* host)
{
try {
dstalk_http_service_t* http = (dstalk_http_service_t*)host->query_service("http", 1);
dstalk_config_service_t* cfg = (dstalk_config_service_t*)host->query_service("config", 1);
g_host.store(host, std::memory_order_release);
g_http.store(http, std::memory_order_release);
g_config.store(cfg, std::memory_order_release);
if (!http) {
if (host) host->log(DSTALK_LOG_ERROR, "[openai] http service not found");
return -1;
}
if (host) host->log(DSTALK_LOG_INFO, "[openai] initializing OpenAI-compatible AI plugin");
return host->register_service("ai.openai", 1, &g_service);
} catch (const std::exception& e) {
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[openai] on_init exception: %s", e.what());
return -1;
} catch (...) {
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[openai] on_init unknown exception");
return -1;
}
}
// 插件关闭:从内存安全擦除 API key清空服务指针 / Plugin shutdown: securely erase API key from memory, null out service pointers.
static void on_shutdown()
{
try {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_INFO, "[openai] shutdown");
secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
g_cfg.api_key.clear();
g_http.store(nullptr, std::memory_order_release);
g_config.store(nullptr, std::memory_order_release);
g_host.store(nullptr, std::memory_order_release);
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] on_shutdown exception: %s", e.what());
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] on_shutdown unknown exception");
}
}
// ============================================================================
// 插件描述符 / Plugin descriptor
// ============================================================================
static dstalk_plugin_info_t g_info = {
/* .name = */ "openai-compat",
/* .version = */ "1.0.0",
/* .description = */ "OpenAI-compatible AI provider (OpenAI-compatible API) / OpenAI-compatible AI 提供者 (OpenAI 兼容 API)",
/* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { "http", "config", NULL },
/* .on_init = */ on_init,
/* .on_shutdown = */ on_shutdown,
/* .on_event = */ nullptr,
};
// 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host.
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{
return &g_info;
}