- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum, StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result - Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common - Fixed anthropic g_config raw pointer → std::atomic (data race) - Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5) - Fixed missing closing brace in both plugins' error-body catch block - Updated test targets: ai_common include path + link, using namespace dstalk_ai - plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service - Includes pre-existing uncommitted changes from prior waves Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
119 lines
4.7 KiB
C++
119 lines
4.7 KiB
C++
/*
|
||
* @file ai_common.hpp
|
||
* @brief Shared types and utilities for AI provider plugins (OpenAI / Anthropic).
|
||
* AI 提供者插件(OpenAI / Anthropic)的共享类型和工具函数。
|
||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||
*/
|
||
|
||
#pragma once
|
||
|
||
#include "dstalk/dstalk_host.h"
|
||
#include "dstalk/dstalk_services.h"
|
||
|
||
#include <boost/json.hpp>
|
||
#include <string>
|
||
#include <vector>
|
||
|
||
namespace dstalk_ai {
|
||
|
||
namespace json = boost::json;
|
||
|
||
// ============================================================================
|
||
// 共享类型 / Shared types
|
||
// ============================================================================
|
||
|
||
/// @brief Connection settings for one AI provider.
///
/// Plain aggregate: every field has a default, so `PluginConfig{}` is valid
/// and brace-initialization follows declaration order.
struct PluginConfig {
    std::string provider;     ///< Provider identifier.
    std::string base_url;     ///< Base URL of the provider endpoint.
    std::string api_key;      ///< API credential.
    std::string model;        ///< Model name.
    int max_tokens{4096};     ///< Token limit; defaults to 4096.
    double temperature{0.7};  ///< Sampling temperature; defaults to 0.7.
};
|
||
|
||
/// @brief Accumulator for one tool call assembled from streamed SSE deltas,
///        identified by its per-call index.
struct ToolCallAccum {
    int index{-1};          ///< Tool-call index; defaults to -1.
    std::string id;         ///< Tool-call id.
    std::string name;       ///< Function name.
    std::string arguments;  ///< JSON arguments text, concatenated incrementally.
};
|
||
|
||
/// Streaming context passed through SSE callbacks / 通过 SSE 回调传递的流式上下文
|
||
struct StreamContext {
|
||
const dstalk_host_api_t* host = nullptr;
|
||
dstalk_stream_cb user_cb = nullptr;
|
||
void* userdata = nullptr;
|
||
std::string accumulated;
|
||
std::vector<ToolCallAccum> tool_calls;
|
||
int sse_parse_errors = 0; // 连续 SSE 解析错误计数器 / consecutive SSE parse error counter
|
||
bool streaming_ok = true; // OpenAI: tracks stream health
|
||
bool saw_data_line = false; // Anthropic: tracks if any SSE data received
|
||
};
|
||
|
||
/// Number of consecutive SSE parse errors tolerated before the stream is aborted.
inline constexpr int kMaxSseParseErrors{5};
|
||
|
||
// ============================================================================
|
||
// 函数声明(实现于 ai_common.cpp) / Function declarations (implemented in ai_common.cpp)
|
||
// ============================================================================
|
||
|
||
/// @brief Zero @p n bytes starting at @p p using volatile writes, so the
///        compiler cannot optimize the wipe away.
/// @param p  Start of the memory region to wipe.
/// @param n  Number of bytes to wipe.
/// @note Defined in ai_common.cpp.
auto secure_zero(void* p, size_t n) -> void;
|
||
|
||
/// @brief Split @p url into its scheme, host, port and target-path components.
/// @param url         URL to parse.
/// @param scheme_out  Receives the scheme.
/// @param host_out    Receives the host.
/// @param port_out    Receives the port, as a string.
/// @param target_out  Receives the target path.
/// @return Presumably true when parsing succeeds. Confirm the exact failure
///         conditions and out-parameter state on failure in ai_common.cpp.
/// @note Defined in ai_common.cpp.
auto extract_host_port(const std::string& url,
                       std::string& scheme_out,
                       std::string& host_out,
                       std::string& port_out,
                       std::string& target_out) -> bool;
|
||
|
||
// ============================================================================
|
||
// 内联工具函数 / Inline utility functions
|
||
// ============================================================================
|
||
|
||
/// Free all host-allocated string fields in a chat result struct.
|
||
/// 释放 chat result 结构体中所有主机分配的字符串字段。
|
||
inline void free_chat_result(const dstalk_host_api_t* host, dstalk_chat_result_t* result) {
|
||
if (!result || !host) return;
|
||
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
|
||
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
|
||
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
|
||
}
|
||
|
||
/// Cache tools_json from the tools service for reuse in chat/chat_stream.
|
||
/// 从 tools service 缓存 tools_json,供 chat/chat_stream 复用。
|
||
inline void cache_tools_json(const dstalk_host_api_t* host, std::string& tools_json) {
|
||
if (!host) return;
|
||
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
|
||
host->query_service("tools", 1));
|
||
if (tools_svc && tools_svc->get_tools_json) {
|
||
char* j = tools_svc->get_tools_json();
|
||
if (j) {
|
||
tools_json = j;
|
||
host->free(j);
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Serialize accumulated tool_calls into OpenAI-compatible JSON array.
|
||
/// 将累积的 tool_calls 序列化为兼容 OpenAI 格式的 JSON 数组。
|
||
inline std::string serialize_tool_calls(const std::vector<ToolCallAccum>& tool_calls) {
|
||
json::array tc_array;
|
||
for (const auto& tc : tool_calls) {
|
||
json::object tc_obj;
|
||
tc_obj["index"] = tc.index;
|
||
if (!tc.id.empty()) tc_obj["id"] = tc.id;
|
||
tc_obj["type"] = "function";
|
||
json::object func;
|
||
if (!tc.name.empty()) func["name"] = tc.name;
|
||
func["arguments"] = tc.arguments;
|
||
tc_obj["function"] = func;
|
||
tc_array.push_back(std::move(tc_obj));
|
||
}
|
||
return json::serialize(tc_array);
|
||
}
|
||
|
||
} // namespace dstalk_ai
|