W17: extract ai_common shared module + fix anthropic data race + brace bugs
- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum, StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result - Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common - Fixed anthropic g_config raw pointer → std::atomic (data race) - Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5) - Fixed missing closing brace in both plugins' error-body catch block - Updated test targets: ai_common include path + link, using namespace dstalk_ai - plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service - Includes pre-existing uncommitted changes from prior waves Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# 依赖其他插件的插件 / Plugins depending on non-base plugins
|
||||
# ============================================================
|
||||
|
||||
add_subdirectory(ai_common) # 共享 AI 工具库(静态库)/ shared AI utility library (static)
|
||||
add_subdirectory(context) # 依赖 session / depends on session
|
||||
add_subdirectory(openai) # 依赖 http, config / depends on http, config
|
||||
add_subdirectory(anthropic) # 依赖 http, config / depends on http, config
|
||||
add_subdirectory(openai) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
add_subdirectory(anthropic) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
|
||||
20
plugins_upper/ai_common/CMakeLists.txt
Normal file
20
plugins_upper/ai_common/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
# ============================================================
|
||||
# ai_common — 共享 AI 插件工具库(静态库)/ Shared AI plugin utility library (static)
|
||||
# ============================================================
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
add_library(ai_common STATIC
|
||||
src/ai_common.cpp
|
||||
)
|
||||
|
||||
target_include_directories(ai_common PUBLIC include)
|
||||
|
||||
target_compile_features(ai_common PUBLIC cxx_std_20)
|
||||
|
||||
target_link_libraries(ai_common
|
||||
PUBLIC
|
||||
dstalk
|
||||
dstalk_boost_config
|
||||
boost::boost
|
||||
)
|
||||
118
plugins_upper/ai_common/include/ai_common.hpp
Normal file
118
plugins_upper/ai_common/include/ai_common.hpp
Normal file
@@ -0,0 +1,118 @@
|
||||
/*
|
||||
* @file ai_common.hpp
|
||||
* @brief Shared types and utilities for AI provider plugins (OpenAI / Anthropic).
|
||||
* AI 提供者插件(OpenAI / Anthropic)的共享类型和工具函数。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace dstalk_ai {
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 共享类型 / Shared types
|
||||
// ============================================================================
|
||||
|
||||
/// Provider connection configuration / 服务商连接配置
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 4096;
|
||||
double temperature = 0.7;
|
||||
};
|
||||
|
||||
/// Per-index tool-call accumulator for SSE streaming / SSE 流式传输的按索引工具调用累积器
|
||||
struct ToolCallAccum {
|
||||
int index = -1;
|
||||
std::string id;
|
||||
std::string name;
|
||||
std::string arguments; // 增量拼接的 JSON arguments / incrementally concatenated JSON arguments
|
||||
};
|
||||
|
||||
/// Streaming context passed through SSE callbacks / 通过 SSE 回调传递的流式上下文
|
||||
struct StreamContext {
|
||||
const dstalk_host_api_t* host = nullptr;
|
||||
dstalk_stream_cb user_cb = nullptr;
|
||||
void* userdata = nullptr;
|
||||
std::string accumulated;
|
||||
std::vector<ToolCallAccum> tool_calls;
|
||||
int sse_parse_errors = 0; // 连续 SSE 解析错误计数器 / consecutive SSE parse error counter
|
||||
bool streaming_ok = true; // OpenAI: tracks stream health
|
||||
bool saw_data_line = false; // Anthropic: tracks if any SSE data received
|
||||
};
|
||||
|
||||
/// Maximum consecutive SSE parse errors before aborting the stream / 中止流之前的最大连续 SSE 解析错误数
|
||||
inline constexpr int kMaxSseParseErrors = 5;
|
||||
|
||||
// ============================================================================
|
||||
// 函数声明(实现于 ai_common.cpp) / Function declarations (implemented in ai_common.cpp)
|
||||
// ============================================================================
|
||||
|
||||
/// Securely zero memory through volatile write to prevent compiler optimization.
|
||||
/// 通过 volatile 写入零来安全擦除内存,防止编译器优化。
|
||||
void secure_zero(void* p, size_t n);
|
||||
|
||||
/// Parse a URL into scheme, host, port, and target path components.
|
||||
/// 将 URL 解析为 scheme、host、port 和 target path 组件。
|
||||
bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out);
|
||||
|
||||
// ============================================================================
|
||||
// 内联工具函数 / Inline utility functions
|
||||
// ============================================================================
|
||||
|
||||
/// Free all host-allocated string fields in a chat result struct.
|
||||
/// 释放 chat result 结构体中所有主机分配的字符串字段。
|
||||
inline void free_chat_result(const dstalk_host_api_t* host, dstalk_chat_result_t* result) {
|
||||
if (!result || !host) return;
|
||||
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
|
||||
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
|
||||
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
|
||||
}
|
||||
|
||||
/// Cache tools_json from the tools service for reuse in chat/chat_stream.
|
||||
/// 从 tools service 缓存 tools_json,供 chat/chat_stream 复用。
|
||||
inline void cache_tools_json(const dstalk_host_api_t* host, std::string& tools_json) {
|
||||
if (!host) return;
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
host->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
char* j = tools_svc->get_tools_json();
|
||||
if (j) {
|
||||
tools_json = j;
|
||||
host->free(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serialize accumulated tool_calls into OpenAI-compatible JSON array.
|
||||
/// 将累积的 tool_calls 序列化为兼容 OpenAI 格式的 JSON 数组。
|
||||
inline std::string serialize_tool_calls(const std::vector<ToolCallAccum>& tool_calls) {
|
||||
json::array tc_array;
|
||||
for (const auto& tc : tool_calls) {
|
||||
json::object tc_obj;
|
||||
tc_obj["index"] = tc.index;
|
||||
if (!tc.id.empty()) tc_obj["id"] = tc.id;
|
||||
tc_obj["type"] = "function";
|
||||
json::object func;
|
||||
if (!tc.name.empty()) func["name"] = tc.name;
|
||||
func["arguments"] = tc.arguments;
|
||||
tc_obj["function"] = func;
|
||||
tc_array.push_back(std::move(tc_obj));
|
||||
}
|
||||
return json::serialize(tc_array);
|
||||
}
|
||||
|
||||
} // namespace dstalk_ai
|
||||
39
plugins_upper/ai_common/src/ai_common.cpp
Normal file
39
plugins_upper/ai_common/src/ai_common.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* @file ai_common.cpp
|
||||
* @brief Shared utility implementations for AI provider plugins.
|
||||
* AI 提供者插件的共享工具函数实现。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "ai_common.hpp"
|
||||
|
||||
namespace dstalk_ai {
|
||||
|
||||
void secure_zero(void* p, size_t n) {
|
||||
volatile char* vp = static_cast<volatile char*>(p);
|
||||
while (n--) *vp++ = 0;
|
||||
}
|
||||
|
||||
bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
{
|
||||
size_t scheme_end = url.find("://");
|
||||
if (scheme_end == std::string::npos) return false;
|
||||
scheme_out = url.substr(0, scheme_end);
|
||||
std::string rest = url.substr(scheme_end + 3);
|
||||
size_t slash = rest.find('/');
|
||||
std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest;
|
||||
target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";
|
||||
size_t colon = authority.rfind(':');
|
||||
if (colon != std::string::npos) {
|
||||
host_out = authority.substr(0, colon);
|
||||
port_out = authority.substr(colon + 1);
|
||||
} else {
|
||||
host_out = authority;
|
||||
port_out = (scheme_out == "https") ? "443" : "80";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace dstalk_ai
|
||||
@@ -1,21 +1,21 @@
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
|
||||
# ============================================================
|
||||
# plugin-anthropic — Anthropic Claude AI 服务
|
||||
# plugin_anthropic — Anthropic Claude AI 服务
|
||||
# 依赖: http 服务 (查询), config 服务 (查询)
|
||||
# ============================================================
|
||||
|
||||
add_library(plugin-anthropic SHARED
|
||||
add_library(plugin_anthropic SHARED
|
||||
src/anthropic_plugin.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(plugin-anthropic PRIVATE dstalk)
|
||||
target_link_libraries(plugin_anthropic PRIVATE dstalk ai_common)
|
||||
|
||||
# Boost.JSON 用于构建/解析请求和响应
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
target_link_libraries(plugin-anthropic PRIVATE boost::boost dstalk_boost_config)
|
||||
target_link_libraries(plugin_anthropic PRIVATE boost::boost dstalk_boost_config)
|
||||
|
||||
set_target_properties(plugin-anthropic PROPERTIES
|
||||
set_target_properties(plugin_anthropic PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "ai_common.hpp"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
@@ -19,60 +20,18 @@ namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针 — W17.4: std::atomic 保护 on_shutdown 与 service 函数并发读写 / Global pointers — W17.4: std::atomic protects concurrent read/write between on_shutdown and service functions
|
||||
// W21.5: g_config 改为 atomic,修复数据竞争 / g_config changed to atomic, fixing data race
|
||||
// ============================================================================
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
static dstalk_config_service_t* g_config = nullptr;
|
||||
static std::atomic<dstalk_config_service_t*> g_config{nullptr};
|
||||
|
||||
// ============================================================================
|
||||
// 配置数据 / Config data
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 4096;
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
static dstalk_ai::PluginConfig g_cfg;
|
||||
static std::string g_tools_json; // W21.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
|
||||
static void secure_zero(void* p, size_t n) {
|
||||
volatile char* vp = (volatile char*)p;
|
||||
while (n--) *vp++ = 0;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:提取 host / target / Helper: extract host / target
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
{
|
||||
size_t scheme_end = url.find("://");
|
||||
if (scheme_end == std::string::npos) return false;
|
||||
scheme_out = url.substr(0, scheme_end);
|
||||
std::string rest = url.substr(scheme_end + 3);
|
||||
size_t slash = rest.find('/');
|
||||
std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest;
|
||||
target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";
|
||||
size_t colon = authority.rfind(':');
|
||||
if (colon != std::string::npos) {
|
||||
host_out = authority.substr(0, colon);
|
||||
port_out = authority.substr(colon + 1);
|
||||
} else {
|
||||
host_out = authority;
|
||||
port_out = (scheme_out == "https") ? "443" : "80";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 构建 Anthropic headers JSON / Build Anthropic headers JSON
|
||||
// ============================================================================
|
||||
@@ -151,10 +110,11 @@ static std::string build_request_json(
|
||||
// 将非流式 JSON 响应体解析为 dstalk_chat_result_t。
|
||||
// 处理 text 和 tool_use content block,将 tool_use 转换为 OpenAI 格式 / Parse a non-streaming JSON response body into a dstalk_chat_result_t.
|
||||
// Handles text and tool_use content blocks, converting tool_use to OpenAI format.
|
||||
static void parse_response(const char* body, int http_status,
|
||||
// W21.5: 添加 host nullptr 守卫,防止空指针解引用 / Added host nullptr guard to prevent null dereference
|
||||
static void parse_response(const dstalk_host_api_t* host,
|
||||
const char* body, int http_status,
|
||||
dstalk_chat_result_t& r)
|
||||
{
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
r.http_status = http_status;
|
||||
|
||||
if (http_status < 200 || http_status >= 300) {
|
||||
@@ -164,16 +124,16 @@ static void parse_response(const char* body, int http_status,
|
||||
auto obj = jv.as_object();
|
||||
if (obj.contains("error")) {
|
||||
auto err = obj["error"].as_object();
|
||||
r.error = h->strdup(
|
||||
json::value_to<std::string>(err["message"]).c_str());
|
||||
r.error = host ? host->strdup(
|
||||
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
|
||||
}
|
||||
} catch (...) {
|
||||
std::string msg = "HTTP " + std::to_string(http_status);
|
||||
r.error = h->strdup(msg.c_str());
|
||||
r.error = host ? host->strdup(msg.c_str()) : nullptr;
|
||||
}
|
||||
if (!r.error) {
|
||||
if (!r.error && host) {
|
||||
std::string msg = "HTTP " + std::to_string(http_status);
|
||||
r.error = h->strdup(msg.c_str());
|
||||
r.error = host->strdup(msg.c_str());
|
||||
}
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
@@ -210,14 +170,14 @@ static void parse_response(const char* body, int http_status,
|
||||
}
|
||||
|
||||
if (!tool_use_blocks.empty()) {
|
||||
r.tool_calls_json = h->strdup(
|
||||
json::serialize(tool_use_blocks).c_str());
|
||||
r.tool_calls_json = host ? host->strdup(
|
||||
json::serialize(tool_use_blocks).c_str()) : nullptr;
|
||||
} else {
|
||||
r.tool_calls_json = nullptr;
|
||||
}
|
||||
|
||||
if (!text_content.empty()) {
|
||||
r.content = h->strdup(text_content.c_str());
|
||||
r.content = host ? host->strdup(text_content.c_str()) : nullptr;
|
||||
r.ok = 1;
|
||||
r.error = nullptr;
|
||||
return;
|
||||
@@ -229,22 +189,22 @@ static void parse_response(const char* body, int http_status,
|
||||
return;
|
||||
}
|
||||
r.ok = 0;
|
||||
r.error = h->strdup("no text or tool_use content block found");
|
||||
r.error = host ? host->strdup("no text or tool_use content block found") : nullptr;
|
||||
} else {
|
||||
r.ok = 0;
|
||||
r.error = h->strdup("empty response");
|
||||
r.error = host ? host->strdup("empty response") : nullptr;
|
||||
}
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
} catch (std::exception& e) {
|
||||
r.ok = 0;
|
||||
std::string msg = std::string("json parse: ") + e.what();
|
||||
r.error = h->strdup(msg.c_str());
|
||||
r.error = host ? host->strdup(msg.c_str()) : nullptr;
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
} catch (...) {
|
||||
r.ok = 0;
|
||||
r.error = h->strdup("json parse error");
|
||||
r.error = host ? host->strdup("json parse error") : nullptr;
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
}
|
||||
@@ -254,23 +214,6 @@ static void parse_response(const char* body, int http_status,
|
||||
// SSE 事件解析(Anthropic 格式: event/content_block_delta) / SSE event parsing (Anthropic format: event/content_block_delta)
|
||||
// ============================================================================
|
||||
|
||||
// W21.2: 按 content_block index 累积 Anthropic tool_use 增量 / Accumulate Anthropic tool_use increments by content_block index
|
||||
struct ToolCallAccum {
|
||||
int index = -1;
|
||||
std::string id;
|
||||
std::string name;
|
||||
std::string arguments; // 从 input_json_delta.partial_json 累积 / accumulated from input_json_delta.partial_json
|
||||
};
|
||||
|
||||
struct StreamContext {
|
||||
const dstalk_host_api_t* host;
|
||||
dstalk_stream_cb user_cb;
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool saw_data_line = false;
|
||||
std::vector<ToolCallAccum> tool_calls; // W21.2: 按 index 累积 tool_use content blocks / accumulate tool_use content blocks by index
|
||||
};
|
||||
|
||||
// W21.2: 解析 Anthropic SSE 事件,含 tool_use content_block 增量解析 / Parse Anthropic SSE events with tool_use content_block incremental parsing
|
||||
// 解析单个 Anthropic SSE "data:" JSON 事件。处理 content_block_start、
|
||||
// content_block_delta (text_delta/input_json_delta) 和 message_stop。
|
||||
@@ -278,7 +221,7 @@ struct StreamContext {
|
||||
// content_block_delta (text_delta/input_json_delta), and message_stop.
|
||||
// Returns true if a content token was produced, false otherwise.
|
||||
static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
dstalk_ai::StreamContext* ctx)
|
||||
{
|
||||
try {
|
||||
auto jv = json::parse(data);
|
||||
@@ -313,6 +256,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
if (cb_obj.contains("name") && cb_obj["name"].is_string())
|
||||
acc.name = json::value_to<std::string>(cb_obj["name"]);
|
||||
}
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -329,6 +273,7 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
auto* text = dobj.if_contains("text");
|
||||
if (text && text->is_string()) {
|
||||
token_out = json::value_to<std::string>(*text);
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return true;
|
||||
}
|
||||
} else if (delta_type == "input_json_delta" && ctx) {
|
||||
@@ -343,15 +288,29 @@ static bool parse_sse_data(const std::string& data, std::string& token_out,
|
||||
json::value_to<std::string>(*pj);
|
||||
}
|
||||
}
|
||||
ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return false;
|
||||
}
|
||||
} else if (type == "message_stop") {
|
||||
token_out.clear();
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return true; // 流结束 / stream end
|
||||
}
|
||||
// 忽略: message_start, content_block_stop, ping, message_delta / Ignore: message_start, content_block_stop, ping, message_delta
|
||||
// 已知事件类型但无需处理 — 重置计数器 / known event type but no processing needed — reset counter
|
||||
if (ctx) ctx->sse_parse_errors = 0;
|
||||
} catch (...) {
|
||||
// 解析失败忽略 / Ignore parse failures
|
||||
if (ctx) {
|
||||
ctx->sse_parse_errors++;
|
||||
const auto* log_host = g_host.load(std::memory_order_acquire);
|
||||
if (log_host) {
|
||||
if (ctx->sse_parse_errors == 1 || ctx->sse_parse_errors % 5 == 0) {
|
||||
log_host->log(DSTALK_LOG_WARN,
|
||||
"[anthropic] SSE parse error (#%d consecutive)",
|
||||
ctx->sse_parse_errors);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -375,15 +334,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) {
|
||||
// W21.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
h->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
char* json = tools_svc->get_tools_json();
|
||||
if (json) {
|
||||
g_tools_json = json;
|
||||
h->free(json);
|
||||
}
|
||||
}
|
||||
dstalk_ai::cache_tools_json(h, g_tools_json);
|
||||
|
||||
h->log(DSTALK_LOG_INFO,
|
||||
"[anthropic] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f",
|
||||
@@ -419,12 +370,12 @@ static dstalk_chat_result_t my_chat(
|
||||
const auto* http = g_http.load(std::memory_order_acquire);
|
||||
|
||||
if (!http) {
|
||||
r.error = host->strdup("http service not available");
|
||||
r.error = host ? host->strdup("http service not available") : nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
std::string scheme, hostname, port, target;
|
||||
extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
|
||||
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
|
||||
std::string target_path = target + "/v1/messages";
|
||||
|
||||
std::string body = build_request_json(history, history_len,
|
||||
@@ -441,15 +392,15 @@ static dstalk_chat_result_t my_chat(
|
||||
headers_json.c_str(), &response_body, &status_code);
|
||||
|
||||
if (ret != 0) {
|
||||
r.error = host->strdup("http request failed");
|
||||
if (response_body) host->free(response_body);
|
||||
r.error = host ? host->strdup("http request failed") : nullptr;
|
||||
if (response_body && host) host->free(response_body);
|
||||
return r;
|
||||
}
|
||||
|
||||
parse_response(response_body, status_code, r);
|
||||
parse_response(host, response_body, status_code, r);
|
||||
|
||||
if (response_body) {
|
||||
host->free(response_body);
|
||||
if (host) host->free(response_body);
|
||||
}
|
||||
return r;
|
||||
} catch (const std::exception& e) {
|
||||
@@ -473,12 +424,11 @@ static dstalk_chat_result_t my_chat(
|
||||
// chat_stream / chat_stream
|
||||
// ============================================================================
|
||||
|
||||
// 行回调 / SSE line callback
|
||||
// SSE 行回调:解析每个 Anthropic SSE 行并将文本 token 转发给用户 / SSE line callback: parses each Anthropic SSE line and forwards text tokens to user.
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
auto* ctx = static_cast<dstalk_ai::StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
@@ -489,6 +439,16 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
std::string token;
|
||||
if (parse_sse_data(data, token, ctx)) {
|
||||
ctx->saw_data_line = true;
|
||||
|
||||
// W21.5: 连续 SSE 解析错误超过阈值,中止流 / consecutive SSE parse errors exceed threshold, abort stream
|
||||
if (ctx->sse_parse_errors >= dstalk_ai::kMaxSseParseErrors) {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_ERROR,
|
||||
"[anthropic] SSE stream aborted: %d consecutive parse errors",
|
||||
ctx->sse_parse_errors);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (token.empty()) {
|
||||
// message_stop / message_stop
|
||||
return 0;
|
||||
@@ -528,12 +488,12 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
const auto* http = g_http.load(std::memory_order_acquire);
|
||||
|
||||
if (!http) {
|
||||
r.error = host->strdup("http service not available");
|
||||
r.error = host ? host->strdup("http service not available") : nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
std::string scheme, hostname, port, target;
|
||||
extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
|
||||
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, hostname, port, target);
|
||||
std::string target_path = target + "/v1/messages";
|
||||
|
||||
std::string body = build_request_json(history, history_len,
|
||||
@@ -541,7 +501,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
std::string headers_json = build_headers_json();
|
||||
|
||||
StreamContext ctx;
|
||||
dstalk_ai::StreamContext ctx;
|
||||
ctx.host = host;
|
||||
ctx.user_cb = cb;
|
||||
ctx.userdata = userdata;
|
||||
@@ -567,25 +527,27 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
auto obj = jv.as_object();
|
||||
if (obj.contains("error")) {
|
||||
auto err = obj["error"].as_object();
|
||||
r.error = host->strdup(
|
||||
json::value_to<std::string>(err["message"]).c_str());
|
||||
r.error = host ? host->strdup(
|
||||
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
|
||||
}
|
||||
} catch (...) {}
|
||||
} catch (...) {
|
||||
if (host) host->log(DSTALK_LOG_WARN, "[anthropic] SSE error body parse error (ignored)");
|
||||
}
|
||||
}
|
||||
if (!r.error) {
|
||||
if (!r.error && host) {
|
||||
if (status_code <= 0)
|
||||
r.error = host->strdup("transport error");
|
||||
else
|
||||
r.error = host->strdup(
|
||||
("HTTP " + std::to_string(status_code)).c_str());
|
||||
}
|
||||
if (response_body) host->free(response_body);
|
||||
if (response_body && host) host->free(response_body);
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
if (response_body) host->free(response_body);
|
||||
if (response_body && host) host->free(response_body);
|
||||
|
||||
// W21.2: 成功条件 = 有内容 OR 有 tool_calls(tool-only 响应如 function calling) / Success = has content OR has tool_calls (tool-only responses like function calling)
|
||||
bool has_content = !ctx.accumulated.empty();
|
||||
@@ -593,7 +555,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
if (!has_content && !has_tool_calls) {
|
||||
r.ok = 0;
|
||||
r.error = host->strdup("no content received");
|
||||
r.error = host ? host->strdup("no content received") : nullptr;
|
||||
r.content = nullptr;
|
||||
r.tool_calls_json = nullptr;
|
||||
} else {
|
||||
@@ -604,19 +566,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
// W21.2: 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
json::object tc_obj;
|
||||
tc_obj["index"] = tc.index;
|
||||
if (!tc.id.empty()) tc_obj["id"] = tc.id;
|
||||
tc_obj["type"] = "function";
|
||||
json::object func;
|
||||
if (!tc.name.empty()) func["name"] = tc.name;
|
||||
func["arguments"] = tc.arguments;
|
||||
tc_obj["function"] = func;
|
||||
tc_array.push_back(std::move(tc_obj));
|
||||
}
|
||||
std::string tc_json = json::serialize(tc_array);
|
||||
std::string tc_json = dstalk_ai::serialize_tool_calls(ctx.tool_calls);
|
||||
r.tool_calls_json = host ? host->strdup(tc_json.c_str()) : nullptr;
|
||||
} else {
|
||||
r.tool_calls_json = nullptr;
|
||||
@@ -647,10 +597,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (!result || !h) return;
|
||||
if (result->content) { h->free((void*)result->content); result->content = nullptr; }
|
||||
if (result->error) { h->free((void*)result->error); result->error = nullptr; }
|
||||
if (result->tool_calls_json) { h->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
|
||||
dstalk_ai::free_chat_result(h, result);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -674,7 +621,9 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
|
||||
auto* http_svc = (dstalk_http_service_t*)host->query_service("http", 1);
|
||||
g_http.store(http_svc, std::memory_order_release);
|
||||
g_config = (dstalk_config_service_t*)host->query_service("config", 1);
|
||||
// W21.5: atomic store 替代裸指针赋值 / atomic store replaces raw pointer assignment
|
||||
auto* cfg_svc = (dstalk_config_service_t*)host->query_service("config", 1);
|
||||
g_config.store(cfg_svc, std::memory_order_release);
|
||||
|
||||
if (!http_svc) {
|
||||
if (host) host->log(DSTALK_LOG_ERROR, "[anthropic] http service not found");
|
||||
@@ -683,7 +632,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
|
||||
if (host) host->log(DSTALK_LOG_INFO, "[anthropic] initializing Anthropic AI plugin");
|
||||
|
||||
return host->register_service("ai.anthropic", 1, &g_service);
|
||||
return host->register_service("ai_anthropic", 1, &g_service);
|
||||
} catch (const std::exception& e) {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[anthropic] on_init exception: %s", e.what());
|
||||
@@ -701,10 +650,11 @@ static void on_shutdown()
|
||||
try {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
if (h) h->log(DSTALK_LOG_INFO, "[anthropic] shutdown");
|
||||
secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
|
||||
dstalk_ai::secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
|
||||
g_cfg.api_key.clear();
|
||||
g_http.store(nullptr, std::memory_order_release);
|
||||
g_config = nullptr;
|
||||
// W21.5: atomic store 替代裸指针赋值,消除数据竞争 / atomic store replaces raw pointer assignment, eliminates data race
|
||||
g_config.store(nullptr, std::memory_order_release);
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
} catch (const std::exception& e) {
|
||||
const auto* h = g_host.load(std::memory_order_acquire);
|
||||
@@ -719,7 +669,7 @@ static void on_shutdown()
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "anthropic-ai",
|
||||
/* .name = */ "anthropic_ai",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "Anthropic Claude AI provider (Messages API) / Anthropic Claude AI 提供者 (Messages API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
add_library(plugin-context SHARED src/context_plugin.cpp)
|
||||
add_library(plugin_context SHARED src/context_plugin.cpp)
|
||||
|
||||
target_link_libraries(plugin-context PRIVATE dstalk)
|
||||
target_link_libraries(plugin_context PRIVATE dstalk)
|
||||
|
||||
set_target_properties(plugin-context PROPERTIES
|
||||
set_target_properties(plugin_context PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-context: 上下文管理服务插件 / Context management service plugin
|
||||
// plugin_context: 上下文管理服务插件 / Context management service plugin
|
||||
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
|
||||
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
@@ -263,14 +263,21 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
system_tokens, max_tokens);
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
// Precompute per-message token counts for non_system_msgs so that the
|
||||
// trim pass below is O(N) instead of O(N*K) (no re-counting per iteration).
|
||||
std::vector<size_t> ns_token_counts;
|
||||
ns_token_counts.reserve(non_system_msgs.size());
|
||||
for (const auto& msg : non_system_msgs) {
|
||||
size_t msg_tokens = count_tokens_trim(msg);
|
||||
if (msg_tokens > max_tokens) {
|
||||
ns_token_counts.push_back(count_tokens_trim(msg));
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
|
||||
if (ns_token_counts[i] > max_tokens) {
|
||||
std::fprintf(stderr, "[context] WARNING: single message "
|
||||
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
|
||||
"Returning empty list.\n",
|
||||
msg.role.c_str(), msg_tokens, max_tokens);
|
||||
non_system_msgs[i].role.c_str(), ns_token_counts[i], max_tokens);
|
||||
*out = nullptr;
|
||||
*out_count = 0;
|
||||
return -1;
|
||||
@@ -278,31 +285,53 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
|
||||
while (!non_system_msgs.empty()) {
|
||||
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
|
||||
if (current <= max_tokens) break;
|
||||
// O(N): precompute token counts once, then mark removal candidates in a single forward pass
|
||||
{
|
||||
size_t ns_total = 0;
|
||||
for (size_t t : ns_token_counts) ns_total += t;
|
||||
current = system_tokens + ns_total;
|
||||
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
auto user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") {
|
||||
++user_it;
|
||||
}
|
||||
if (user_it == non_system_msgs.end()) break;
|
||||
if (current > max_tokens) {
|
||||
std::vector<bool> keep(non_system_msgs.size(), true);
|
||||
size_t idx = 0;
|
||||
while (idx < non_system_msgs.size() && current > max_tokens) {
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "user") {
|
||||
++idx;
|
||||
}
|
||||
if (idx >= non_system_msgs.size()) break;
|
||||
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
auto assistant_it = user_it + 1;
|
||||
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
|
||||
++assistant_it;
|
||||
}
|
||||
size_t user_idx = idx;
|
||||
++idx;
|
||||
|
||||
if (assistant_it == non_system_msgs.end()) {
|
||||
non_system_msgs.erase(user_it);
|
||||
} else {
|
||||
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
|
||||
non_system_msgs.erase(assistant_it);
|
||||
user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
|
||||
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "assistant") {
|
||||
++idx;
|
||||
}
|
||||
|
||||
if (idx >= non_system_msgs.size()) {
|
||||
// 没有配对的 assistant,只移除 user / No paired assistant, remove user only
|
||||
keep[user_idx] = false;
|
||||
current -= ns_token_counts[user_idx];
|
||||
idx = user_idx + 1; // restart search after the removed message
|
||||
} else {
|
||||
// 移除 user + assistant 对 / Remove user + assistant pair
|
||||
keep[user_idx] = false;
|
||||
keep[idx] = false;
|
||||
current -= ns_token_counts[user_idx] + ns_token_counts[idx];
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild non_system_msgs with only kept messages (single O(N) pass)
|
||||
std::vector<TrimMessage> kept;
|
||||
kept.reserve(non_system_msgs.size());
|
||||
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
|
||||
if (keep[i]) {
|
||||
kept.push_back(std::move(non_system_msgs[i]));
|
||||
}
|
||||
}
|
||||
non_system_msgs = std::move(kept);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,8 +339,10 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
{
|
||||
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
|
||||
if (max_msg_count < 1) max_msg_count = 1;
|
||||
while (non_system_msgs.size() > max_msg_count) {
|
||||
non_system_msgs.erase(non_system_msgs.begin());
|
||||
// O(N) single range-erase instead of O(N²) repeated erase(begin())
|
||||
if (non_system_msgs.size() > max_msg_count) {
|
||||
size_t to_remove = non_system_msgs.size() - max_msg_count;
|
||||
non_system_msgs.erase(non_system_msgs.begin(), non_system_msgs.begin() + to_remove);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,17 +428,17 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
// 查询依赖服务: session / Query dependency service: session
|
||||
void* raw = host->query_service("session", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin_context] required service 'session' not found");
|
||||
return -1;
|
||||
}
|
||||
g_session = static_cast<const dstalk_session_service_t*>(raw);
|
||||
|
||||
return host->register_service("context", 1, &g_context_service);
|
||||
} catch (const std::exception& e) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -419,11 +450,11 @@ static void on_shutdown() {
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
} catch (const std::exception& e) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: %s", e.what());
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
} catch (...) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: unknown exception");
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
}
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
# ============================================================
|
||||
# plugin-openai — OpenAI 兼容 AI 服务 / OpenAI-compatible AI service
|
||||
# plugin_openai — OpenAI 兼容 AI 服务 / OpenAI-compatible AI service
|
||||
# ============================================================
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
add_library(plugin-openai SHARED
|
||||
add_library(plugin_openai SHARED
|
||||
src/openai_plugin.cpp
|
||||
)
|
||||
|
||||
target_link_libraries(plugin-openai PRIVATE dstalk)
|
||||
target_link_libraries(plugin_openai PRIVATE dstalk ai_common)
|
||||
|
||||
# Boost.JSON (header-only)
|
||||
target_link_libraries(plugin-openai PRIVATE boost::boost dstalk_boost_config)
|
||||
target_link_libraries(plugin_openai PRIVATE boost::boost dstalk_boost_config)
|
||||
|
||||
set_target_properties(plugin-openai PROPERTIES
|
||||
set_target_properties(plugin_openai PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins
|
||||
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/plugins
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "ai_common.hpp"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
@@ -18,7 +19,7 @@
|
||||
namespace json = boost::json;
|
||||
|
||||
// ============================================================================
|
||||
// 全局指针:从 on_init 获取(W14.3: atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (W14.3: atomic acquire/release protects read/write races)
|
||||
// 全局指针:从 on_init 获取(atomic acquire/release 保护读写竞态) / Global pointers: obtained from on_init (atomic acquire/release protects read/write races)
|
||||
// ============================================================================
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::atomic<dstalk_http_service_t*> g_http{nullptr};
|
||||
@@ -27,52 +28,9 @@ static std::atomic<dstalk_config_service_t*> g_config{nullptr};
|
||||
// ============================================================================
|
||||
// 配置数据(由 configure() 设置) / Config data (set by configure())
|
||||
// ============================================================================
|
||||
struct PluginConfig {
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 4096;
|
||||
double temperature = 0.7;
|
||||
};
|
||||
static PluginConfig g_cfg;
|
||||
static dstalk_ai::PluginConfig g_cfg;
|
||||
static std::string g_tools_json; // W20.2: 由 configure() 缓存,供 chat/chat_stream 使用 / cached by configure(), consumed by chat/chat_stream
|
||||
|
||||
// ============================================================================
|
||||
// 安全擦除:用 volatile 写零循环防止编译器优化 / Secure erase: write zero loop through volatile to prevent compiler optimization
|
||||
// ============================================================================
|
||||
// 通过 volatile 写入零来安全擦除内存,防止编译器优化 / Securely zero out memory by writing through volatile to prevent compiler optimization.
|
||||
static void secure_zero(void* p, size_t n) {
|
||||
volatile char* vp = (volatile char*)p;
|
||||
while (n--) *vp++ = 0;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:从 base_url 提取 host 和 target / Helper: extract host and target from base_url
|
||||
// ============================================================================
|
||||
// 将 URL 解析为 scheme、host、port 和 target path 组件 / Parse a URL into scheme, host, port, and target path components.
|
||||
static bool extract_host_port(const std::string& url,
|
||||
std::string& scheme_out, std::string& host_out,
|
||||
std::string& port_out, std::string& target_out)
|
||||
{
|
||||
size_t scheme_end = url.find("://");
|
||||
if (scheme_end == std::string::npos) return false;
|
||||
scheme_out = url.substr(0, scheme_end);
|
||||
std::string rest = url.substr(scheme_end + 3);
|
||||
size_t slash = rest.find('/');
|
||||
std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest;
|
||||
target_out = (slash != std::string::npos) ? rest.substr(slash) : "/";
|
||||
size_t colon = authority.rfind(':');
|
||||
if (colon != std::string::npos) {
|
||||
host_out = authority.substr(0, colon);
|
||||
port_out = authority.substr(colon + 1);
|
||||
} else {
|
||||
host_out = authority;
|
||||
port_out = (scheme_out == "https") ? "443" : "80";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助:构建 headers JSON 字符串 / Helper: build headers JSON string
|
||||
// ============================================================================
|
||||
@@ -219,25 +177,6 @@ static void parse_response(const dstalk_host_api_t* host,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 流式上下文:在 SSE 回调间累积内容和 tool_calls / Stream context: accumulate content and tool_calls across SSE callbacks
|
||||
// ============================================================================
|
||||
struct ToolCallAccum {
|
||||
int index = -1;
|
||||
std::string id;
|
||||
std::string name;
|
||||
std::string arguments; // 增量拼接的 JSON arguments 字符串 / incrementally concatenated JSON arguments string
|
||||
};
|
||||
|
||||
struct StreamContext {
|
||||
const dstalk_host_api_t* host;
|
||||
dstalk_stream_cb user_cb;
|
||||
void* userdata;
|
||||
std::string accumulated;
|
||||
bool streaming_ok = true;
|
||||
std::vector<ToolCallAccum> tool_calls; // W20.2: 按 index 累积 delta tool_calls / accumulate delta tool_calls by index
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// SSE 行解析(OpenAI 兼容格式) / SSE line parsing (OpenAI-compatible format)
|
||||
// ============================================================================
|
||||
@@ -248,7 +187,7 @@ struct StreamContext {
|
||||
// to token_out. If it contains tool_calls delta, accumulates into ctx->tool_calls.
|
||||
// Returns true if a content token was produced, false otherwise (tool_calls or unknown).
|
||||
static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
StreamContext* ctx)
|
||||
dstalk_ai::StreamContext* ctx)
|
||||
{
|
||||
if (line.rfind("data: ", 0) != 0) return false;
|
||||
|
||||
@@ -263,6 +202,7 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
if (data == "[DONE]") {
|
||||
token_out.clear();
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析,重置错误计数 / successful parse, reset error counter
|
||||
return true; // 流结束信号 / stream end signal
|
||||
}
|
||||
|
||||
@@ -307,16 +247,31 @@ static bool parse_sse_line(const std::string& line, std::string& token_out,
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return false; // tool_calls 已处理,无内容 token 给用户回调 / tool_calls processed, no content token for user callback
|
||||
}
|
||||
|
||||
if (delta.contains("content")) {
|
||||
token_out = json::value_to<std::string>(delta["content"]);
|
||||
if (ctx) ctx->sse_parse_errors = 0; // 成功解析 / successful parse
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// 有效 JSON 但不是已知格式 — 非错误,只是未知事件类型 / valid JSON but unknown format — not an error, just unknown event type
|
||||
// 重置计数器:JSON 本身解析成功 / reset counter: JSON itself parsed successfully
|
||||
if (ctx) ctx->sse_parse_errors = 0;
|
||||
} catch (...) {
|
||||
// 忽略解析失败 / Ignore parse failures
|
||||
if (ctx) {
|
||||
ctx->sse_parse_errors++;
|
||||
const dstalk_host_api_t* log_host = g_host.load(std::memory_order_acquire);
|
||||
if (log_host) {
|
||||
if (ctx->sse_parse_errors == 1 || ctx->sse_parse_errors % 5 == 0) {
|
||||
log_host->log(DSTALK_LOG_WARN,
|
||||
"[openai] SSE parse error (#%d consecutive)",
|
||||
ctx->sse_parse_errors);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -340,15 +295,7 @@ static int my_configure(const char* provider, const char* base_url,
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) {
|
||||
// W20.2: 从 tools service 缓存 tools_json,供 chat/chat_stream 复用 / Cache tools_json from tools service for reuse in chat/chat_stream
|
||||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||||
host->query_service("tools", 1));
|
||||
if (tools_svc && tools_svc->get_tools_json) {
|
||||
char* json = tools_svc->get_tools_json();
|
||||
if (json) {
|
||||
g_tools_json = json;
|
||||
host->free(json);
|
||||
}
|
||||
}
|
||||
dstalk_ai::cache_tools_json(host, g_tools_json);
|
||||
|
||||
host->log(DSTALK_LOG_INFO,
|
||||
"[openai] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f",
|
||||
@@ -376,6 +323,8 @@ static dstalk_chat_result_t my_chat(
|
||||
const char* user_input,
|
||||
const char* tools_json)
|
||||
{
|
||||
char* response_body = nullptr;
|
||||
int status_code = 0;
|
||||
try {
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
@@ -389,7 +338,7 @@ static dstalk_chat_result_t my_chat(
|
||||
}
|
||||
|
||||
std::string scheme, host_name, port, target;
|
||||
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
|
||||
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
|
||||
std::string target_path = target + "/chat/completions";
|
||||
|
||||
std::string body = build_request_json(history, history_len,
|
||||
@@ -397,15 +346,13 @@ static dstalk_chat_result_t my_chat(
|
||||
|
||||
std::string headers_json = build_headers_json(g_cfg.api_key);
|
||||
|
||||
char* response_body = nullptr;
|
||||
int status_code = 0;
|
||||
|
||||
int ret = http->post_json(
|
||||
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
|
||||
headers_json.c_str(), &response_body, &status_code);
|
||||
|
||||
if (ret != 0) {
|
||||
r.error = host ? host->strdup("http request failed") : nullptr;
|
||||
if (response_body) host->free(response_body);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -418,6 +365,7 @@ static dstalk_chat_result_t my_chat(
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat exception: %s", e.what());
|
||||
if (response_body && host) host->free(response_body);
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.error = host ? host->strdup(e.what()) : nullptr;
|
||||
@@ -425,6 +373,7 @@ static dstalk_chat_result_t my_chat(
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat unknown exception");
|
||||
if (response_body && host) host->free(response_body);
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.error = host ? host->strdup("unknown exception") : nullptr;
|
||||
@@ -440,7 +389,7 @@ static dstalk_chat_result_t my_chat(
|
||||
static int sse_line_callback(const char* line, void* userdata)
|
||||
{
|
||||
try {
|
||||
auto* ctx = static_cast<StreamContext*>(userdata);
|
||||
auto* ctx = static_cast<dstalk_ai::StreamContext*>(userdata);
|
||||
if (!line || !line[0]) return 1; // 空行,继续 / empty line, continue
|
||||
|
||||
std::string line_str(line);
|
||||
@@ -448,6 +397,15 @@ static int sse_line_callback(const char* line, void* userdata)
|
||||
|
||||
if (!parse_sse_line(line_str, token, ctx)) return 1; // 非 data/tool_calls 行,继续 / not a data/tool_calls line, continue
|
||||
|
||||
// W21.5: 连续 SSE 解析错误超过阈值,中止流 / consecutive SSE parse errors exceed threshold, abort stream
|
||||
if (ctx && ctx->sse_parse_errors >= dstalk_ai::kMaxSseParseErrors) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_ERROR,
|
||||
"[openai] SSE stream aborted: %d consecutive parse errors",
|
||||
ctx->sse_parse_errors);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (token.empty()) return 0; // [DONE],停止 / [DONE], stop
|
||||
|
||||
ctx->accumulated += token;
|
||||
@@ -475,6 +433,8 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
const char* user_input,
|
||||
dstalk_stream_cb cb, void* userdata)
|
||||
{
|
||||
char* response_body = nullptr;
|
||||
int status_code = 0;
|
||||
try {
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
@@ -488,7 +448,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
}
|
||||
|
||||
std::string scheme, host_name, port, target;
|
||||
extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
|
||||
dstalk_ai::extract_host_port(g_cfg.base_url, scheme, host_name, port, target);
|
||||
std::string target_path = target + "/chat/completions";
|
||||
|
||||
std::string body = build_request_json(history, history_len,
|
||||
@@ -496,14 +456,11 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
std::string headers_json = build_headers_json(g_cfg.api_key);
|
||||
|
||||
StreamContext ctx;
|
||||
dstalk_ai::StreamContext ctx;
|
||||
ctx.host = host;
|
||||
ctx.user_cb = cb;
|
||||
ctx.userdata = userdata;
|
||||
|
||||
char* response_body = nullptr;
|
||||
int status_code = 0;
|
||||
|
||||
int ret = http->post_stream(
|
||||
host_name.c_str(), port.c_str(), target_path.c_str(), body.c_str(),
|
||||
headers_json.c_str(),
|
||||
@@ -525,7 +482,9 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
r.error = host ? host->strdup(
|
||||
json::value_to<std::string>(err["message"]).c_str()) : nullptr;
|
||||
}
|
||||
} catch (...) {}
|
||||
} catch (...) {
|
||||
if (host) host->log(DSTALK_LOG_WARN, "[openai] SSE error body parse error (ignored)");
|
||||
}
|
||||
}
|
||||
if (!r.error && host) {
|
||||
if (status_code <= 0)
|
||||
@@ -559,19 +518,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
|
||||
// 序列化累积的 tool_calls 为 JSON(兼容 OpenAI tool_calls 格式) / Serialize accumulated tool_calls to JSON (OpenAI-compatible tool_calls format)
|
||||
if (has_tool_calls) {
|
||||
json::array tc_array;
|
||||
for (auto& tc : ctx.tool_calls) {
|
||||
json::object tc_obj;
|
||||
tc_obj["index"] = tc.index;
|
||||
if (!tc.id.empty()) tc_obj["id"] = tc.id;
|
||||
tc_obj["type"] = "function";
|
||||
json::object func;
|
||||
if (!tc.name.empty()) func["name"] = tc.name;
|
||||
func["arguments"] = tc.arguments;
|
||||
tc_obj["function"] = func;
|
||||
tc_array.push_back(std::move(tc_obj));
|
||||
}
|
||||
std::string tc_json = json::serialize(tc_array);
|
||||
std::string tc_json = dstalk_ai::serialize_tool_calls(ctx.tool_calls);
|
||||
r.tool_calls_json = host ? host->strdup(tc_json.c_str()) : nullptr;
|
||||
} else {
|
||||
r.tool_calls_json = nullptr;
|
||||
@@ -581,6 +528,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream exception: %s", e.what());
|
||||
if (response_body && host) host->free(response_body);
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.error = host ? host->strdup(e.what()) : nullptr;
|
||||
@@ -588,6 +536,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[openai] my_chat_stream unknown exception");
|
||||
if (response_body && host) host->free(response_body);
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.error = host ? host->strdup("unknown exception") : nullptr;
|
||||
@@ -602,10 +551,7 @@ static dstalk_chat_result_t my_chat_stream(
|
||||
static void my_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!result || !host) return;
|
||||
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
|
||||
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
|
||||
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
|
||||
dstalk_ai::free_chat_result(host, result);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
@@ -638,7 +584,7 @@ static int on_init(const dstalk_host_api_t* host)
|
||||
|
||||
if (host) host->log(DSTALK_LOG_INFO, "[openai] initializing OpenAI-compatible AI plugin");
|
||||
|
||||
return host->register_service("ai.openai", 1, &g_service);
|
||||
return host->register_service("ai_openai", 1, &g_service);
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
|
||||
if (h && h->log) h->log(DSTALK_LOG_ERROR, "[openai] on_init exception: %s", e.what());
|
||||
@@ -656,7 +602,7 @@ static void on_shutdown()
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host) host->log(DSTALK_LOG_INFO, "[openai] shutdown");
|
||||
secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
|
||||
dstalk_ai::secure_zero(g_cfg.api_key.data(), g_cfg.api_key.size());
|
||||
g_cfg.api_key.clear();
|
||||
g_http.store(nullptr, std::memory_order_release);
|
||||
g_config.store(nullptr, std::memory_order_release);
|
||||
@@ -674,7 +620,7 @@ static void on_shutdown()
|
||||
// 插件描述符 / Plugin descriptor
|
||||
// ============================================================================
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "openai-compat",
|
||||
/* .name = */ "openai_compat",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "OpenAI-compatible AI provider (OpenAI-compatible API) / OpenAI-compatible AI 提供者 (OpenAI 兼容 API)",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
|
||||
Reference in New Issue
Block a user