W17: extract ai_common shared module + fix anthropic data race + brace bugs

- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum,
  StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result
- Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common
- Fixed anthropic g_config raw pointer → std::atomic (data race)
- Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5)
- Fixed missing closing brace in both plugins' error-body catch block
- Updated test targets: ai_common include path + link, using namespace dstalk_ai
- plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service
- Includes pre-existing uncommitted changes from prior waves

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 16:58:25 +08:00
parent ba7382db2a
commit 8faa02c3d5
49 changed files with 1062 additions and 413 deletions

View File

@@ -1,19 +1,19 @@
# ============================================================
# plugin-openai — OpenAI 兼容 AI 服务 / OpenAI-compatible AI service
# plugin_openai — OpenAI 兼容 AI 服务 / OpenAI-compatible AI service
# ============================================================
find_package(Boost REQUIRED CONFIG)
add_library(plugin-openai SHARED
add_library(plugin_openai SHARED
src/openai_plugin.cpp
)
target_link_libraries(plugin-openai PRIVATE dstalk)
target_link_libraries(plugin_openai PRIVATE dstalk ai_common)
# Boost.JSON (header-only)
target_link_libraries(plugin-openai PRIVATE boost::boost dstalk_boost_config)
target_link_libraries(plugin_openai PRIVATE boost::boost dstalk_boost_config)
set_target_properties(plugin-openai PROPERTIES
set_target_properties(plugin_openai PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins

View File

@@ -7,6 +7,7 @@
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"
#include "ai_common.hpp"
#include <boost/json.hpp>
#include <boost/json/src.hpp>
@@ -18,7 +19,7 @@
namespace json = boost::json;
// ============================================================================
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
// 全局指针:从 on_init 获取(atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (atomic acquire/release protects read/write races)
// ============================================================================
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
@@ -27,52 +28,9 @@ static std::atomic<dstalk_config_service_t*> g_config{nullptr};
// ============================================================================
// 配置数据(由 configure() 设置) / Config data (set by configure())
// ============================================================================
// Connection and generation settings for the OpenAI-compatible plugin.
// A single file-level instance (g_cfg) holds these values; per the section
// header they are set by configure().
struct PluginConfig {
// Provider identifier. NOTE(review): how this value is consumed is not visible here — confirm.
std::string provider;
// Base URL of the API endpoint; split into host/port/target by extract_host_port()
// before "/chat/completions" is appended to the target.
std::string base_url;
// API key passed to build_headers_json(); zeroed with secure_zero() and cleared in on_shutdown().
std::string api_key;
// Model name; reported in the "configured" log line.
std::string model;
// Token limit, defaults to 4096. Presumably sent with each request — TODO confirm in build_request_json().
int max_tokens = 4096;
// Sampling temperature, defaults to 0.7. Presumably sent with each request — TODO confirm.
double temperature = 0.7;
};
static PluginConfig g_cfg;
static dstalk_ai::PluginConfig g_cfg;
static std::string g_tools_json; // W20.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
// ============================================================================
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
// ============================================================================
// Overwrite the first `n` bytes at `p` with zero, one byte at a time.
// Every store goes through a volatile-qualified pointer so the compiler may
// not elide the writes as dead stores (unlike a plain memset on memory that
// is about to be released).
//   p - start of the region to wipe; never dereferenced when n == 0
//   n - number of bytes to wipe
static void secure_zero(void* p, size_t n) {
    volatile char* const bytes = static_cast<volatile char*>(p);
    for (size_t i = 0; i < n; ++i) {
        bytes[i] = 0;
    }
}
// ============================================================================
// 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
// ============================================================================
// Split a URL of the form scheme://authority[/path...] into its components.
//
//   url        - URL to parse; must contain "://".
//   scheme_out - text before "://" (e.g. "https").
//   host_out   - host part of the authority. A bracketed IPv6 literal keeps
//                its brackets (e.g. "[::1]"), matching what is produced when
//                an explicit port follows the literal.
//   port_out   - explicit port if one is present and non-empty; otherwise
//                "443" for the "https" scheme and "80" for anything else.
//   target_out - everything from the first '/' after the authority, or "/"
//                when the URL has no path.
//
// Returns false (leaving all outputs untouched) when "://" is missing,
// true otherwise.
static bool extract_host_port(const std::string& url,
                              std::string& scheme_out, std::string& host_out,
                              std::string& port_out, std::string& target_out)
{
    const size_t scheme_end = url.find("://");
    if (scheme_end == std::string::npos) return false;

    scheme_out = url.substr(0, scheme_end);
    const std::string rest = url.substr(scheme_end + 3);

    const size_t slash = rest.find('/');
    const std::string authority =
        (slash != std::string::npos) ? rest.substr(0, slash) : rest;
    target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";

    // Locate the host/port separator. Inside a bracketed IPv6 literal the
    // colons belong to the address, so only a ':' directly after the closing
    // ']' may introduce a port.
    size_t colon = std::string::npos;
    const size_t bracket_close = authority.rfind(']');
    if (!authority.empty() && authority.front() == '[' &&
        bracket_close != std::string::npos) {
        if (bracket_close + 1 < authority.size() &&
            authority[bracket_close + 1] == ':') {
            colon = bracket_close + 1;
        }
    } else {
        colon = authority.rfind(':');
    }

    if (colon != std::string::npos) {
        host_out = authority.substr(0, colon);
        port_out = authority.substr(colon + 1);
    } else {
        host_out = authority;
        port_out.clear();
    }

    // No port, or an empty one ("host:"), means the scheme's default port.
    if (port_out.empty()) {
        port_out = (scheme_out == "https") ? "443" : "80";
    }
    return true;
}
// ============================================================================
// 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
// ============================================================================
@@ -219,25 +177,6 @@ static void parse_response(const dstalk_host_api_t* host,
}
}
// ============================================================================
// 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
// ============================================================================
// One tool call being assembled from streamed delta fragments.
// When the stream finishes, each entry is serialized into an
// OpenAI-style tool_calls array element of type "function".
struct ToolCallAccum {
// Tool-call index, written to the "index" field on serialization; defaults to -1.
int index = -1;
// Tool-call id; written to the "id" field only when non-empty.
std::string id;
// Function name; written to "function.name" only when non-empty.
std::string name;
std::string arguments; // JSON arguments string, concatenated incrementally; always written to "function.arguments"
};
// State for one streaming chat call, handed to the SSE line callback
// through its userdata pointer.
// host, user_cb and userdata have no default initializers; my_chat_stream()
// assigns all three before starting the stream.
struct StreamContext {
// Host API handle used by the streaming call.
const dstalk_host_api_t* host;
// Stream callback supplied by the caller of my_chat_stream().
dstalk_stream_cb user_cb;
// Opaque pointer supplied by the caller alongside user_cb.
void* userdata;
// Content tokens appended by sse_line_callback() as they are parsed.
std::string accumulated;
// NOTE(review): presumably cleared when streaming fails — the code that writes it is not visible here; confirm.
bool streaming_ok = true;
std::vector<ToolCallAccum> tool_calls; // W20.2: delta tool_calls accumulated by index
};
// ============================================================================
// SSE 行解析(OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
// ============================================================================
@@ -248,7 +187,7 @@ struct StreamContext {
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
static bool parse_sse_line(const std::string& line, std::string& token_out,
StreamContext* ctx)
dstalk_ai::StreamContext* ctx)
{
if (line.rfind("data: ", 0) != 0) return false;
@@ -263,6 +202,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
}
if (data == "[DONE]") {
token_out.clear();
if (ctx) ctx->sse_parse_errors = 0; // 成功解析,重置错误计数 / successful parse, reset error counter
return true; // 流结束信号 / stream end signal
}
@@ -307,16 +247,31 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
}
}
}
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
}
if (delta.contains("content")) {
token_out = json::value_to<std::string>(delta["content"]);
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
return true;
}
}
// 有效 JSON 但不是已知格式 — 非错误,只是未知事件类型 / valid JSON but unknown format — not an error, just unknown event type
// 重置计数器:JSON 本身解析成功 / reset counter: JSON itself parsed successfully
if (ctx) ctx->sse_parse_errors = 0;
} catch (...) {
// 解析失败不抛出,但累计连续失败次数并记录日志 / Parse failures are not propagated, but consecutive failures are counted and logged
if (ctx) {
ctx->sse_parse_errors++;
const dstalk_host_api_t* log_host = g_host.load(std::memory_order_acquire);
if (log_host) {
if (ctx->sse_parse_errors == 1 || ctx->sse_parse_errors % 5 == 0) {
log_host->log(DSTALK_LOG_WARN,
"[openai] SSE parse error (#%d consecutive)",
ctx->sse_parse_errors);
}
}
}
}
return false;
}
@@ -340,15 +295,7 @@ static int my_configure(const char* provider, const char* base_url,
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) {
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
host->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) {
char* json = tools_svc->get_tools_json();
if (json) {
g_tools_json = json;
host->free(json);
}
}
dstalk_ai::cache_tools_json(host, g_tools_json);
host->log(DSTALK_LOG_INFO,
"[openai] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f",
@@ -376,6 +323,8 @@ static dstalk_chat_result_t my_chat(
const char* user_input,
const char* tools_json)
{
char* response_body = nullptr;
int status_code = 0;
try {
dstalk_chat_result_t r = {};
r.ok = 0;
@@ -389,7 +338,7 @@ static dstalk_chat_result_t my_chat(
}
std::string scheme, host_name, port, target;
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
std::string target_path = target + "/chat/completions";
std::string body = build_request_json(history, history_len,
@@ -397,15 +346,13 @@ static dstalk_chat_result_t my_chat(
std::string headers_json = build_headers_json(g_cfg.api_key);
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_json(
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(), &response_body, &status_code);
if (ret != 0) {
r.error = host ? host->strdup("http request failed") : nullptr;
if (response_body) host->free(response_body);
return r;
}
@@ -418,6 +365,7 @@ static dstalk_chat_result_t my_chat(
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat exception: %s", e.what());
if (response_body && host) host->free(response_body);
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
@@ -425,6 +373,7 @@ static dstalk_chat_result_t my_chat(
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat unknown exception");
if (response_body && host) host->free(response_body);
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
@@ -440,7 +389,7 @@ static dstalk_chat_result_t my_chat(
static int sse_line_callback(const char* line, void* userdata)
{
try {
auto* ctx = static_cast<StreamContext*>(userdata);
auto* ctx = static_cast<dstalk_ai::StreamContext*>(userdata);
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
std::string line_str(line);
@@ -448,6 +397,15 @@ static int sse_line_callback(const char* line, void* userdata)
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
// W21.5: 连续 SSE 解析错误超过阈值,中止流 / consecutive SSE parse errors exceed threshold, abort stream
if (ctx && ctx->sse_parse_errors >= dstalk_ai::kMaxSseParseErrors) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR,
"[openai] SSE stream aborted: %d consecutive parse errors",
ctx->sse_parse_errors);
return 0;
}
if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
ctx->accumulated += token;
@@ -475,6 +433,8 @@ static dstalk_chat_result_t my_chat_stream(
const char* user_input,
dstalk_stream_cb cb, void* userdata)
{
char* response_body = nullptr;
int status_code = 0;
try {
dstalk_chat_result_t r = {};
r.ok = 0;
@@ -488,7 +448,7 @@ static dstalk_chat_result_t my_chat_stream(
}
std::string scheme, host_name, port, target;
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
std::string target_path = target + "/chat/completions";
std::string body = build_request_json(history, history_len,
@@ -496,14 +456,11 @@ static dstalk_chat_result_t my_chat_stream(
std::string headers_json = build_headers_json(g_cfg.api_key);
StreamContext ctx;
dstalk_ai::StreamContext ctx;
ctx.host = host;
ctx.user_cb = cb;
ctx.userdata = userdata;
char* response_body = nullptr;
int status_code = 0;
int ret = http->post_stream(
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
headers_json.c_str(),
@@ -525,7 +482,9 @@ static dstalk_chat_result_t my_chat_stream(
r.error = host ? host->strdup(
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
}
} catch (...) {}
} catch (...) {
if (host) host->log(DSTALK_LOG_WARN, "[openai] SSE error body parse error (ignored)");
}
}
if (!r.error && host) {
if (status_code <= 0)
@@ -559,19 +518,7 @@ static dstalk_chat_result_t my_chat_stream(
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
if (has_tool_calls) {
json::array tc_array;
for (auto& tc : ctx.tool_calls) {
json::object tc_obj;
tc_obj["index"] = tc.index;
if (!tc.id.empty()) tc_obj["id"] = tc.id;
tc_obj["type"] = "function";
json::object func;
if (!tc.name.empty()) func["name"] = tc.name;
func["arguments"] = tc.arguments;
tc_obj["function"] = func;
tc_array.push_back(std::move(tc_obj));
}
std::string tc_json = json::serialize(tc_array);
std::string tc_json = dstalk_ai::serialize_tool_calls(ctx.tool_calls);
r.tool_calls_json = host ? host->strdup(tc_json.c_str()) : nullptr;
} else {
r.tool_calls_json = nullptr;
@@ -581,6 +528,7 @@ static dstalk_chat_result_t my_chat_stream(
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream exception: %s", e.what());
if (response_body && host) host->free(response_body);
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup(e.what()) : nullptr;
@@ -588,6 +536,7 @@ static dstalk_chat_result_t my_chat_stream(
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream unknown exception");
if (response_body && host) host->free(response_body);
dstalk_chat_result_t r = {};
r.ok = 0;
r.error = host ? host->strdup("unknown exception") : nullptr;
@@ -602,10 +551,7 @@ static dstalk_chat_result_t my_chat_stream(
static void my_free_result(dstalk_chat_result_t* result)
{
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (!result || !host) return;
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
dstalk_ai::free_chat_result(host, result);
}
// ============================================================================
@@ -638,7 +584,7 @@ static int on_init(const dstalk_host_api_t* host)
if (host) host->log(DSTALK_LOG_INFO, "[openai] initializing OpenAI-compatible AI plugin");
return host->register_service("ai.openai", 1, &g_service);
return host->register_service("ai_openai", 1, &g_service);
} catch (const std::exception& e) {
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[openai] on_init exception: %s", e.what());
@@ -656,7 +602,7 @@ static void on_shutdown()
try {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_INFO, "[openai] shutdown");
secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
dstalk_ai::secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
g_cfg.api_key.clear();
g_http.store(nullptr, std::memory_order_release);
g_config.store(nullptr, std::memory_order_release);
@@ -674,7 +620,7 @@ static void on_shutdown()
// 插件描述符 / Plugin descriptor
// ============================================================================
static dstalk_plugin_info_t g_info = {
/* .name = */ "openai-compat",
/* .name = */ "openai_compat",
/* .version = */ "1.0.0",
/* .description = */ "OpenAI-compatible AI provider (OpenAI-compatible API) / OpenAI-compatible AI 提供者 (OpenAI 兼容 API)",
/* .api_version = */ DSTALK_API_VERSION,