Files
dstalk/plugins_upper/ai_common/include/ai_common.hpp
XiuChengWu 8faa02c3d5 W17: extract ai_common shared module + fix anthropic data race + brace bugs
- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum,
  StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result
- Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common
- Fixed anthropic g_config raw pointer → std::atomic (data race)
- Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5)
- Fixed missing closing brace in both plugins' error-body catch block
- Updated test targets: ai_common include path + link, using namespace dstalk_ai
- plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service
- Includes pre-existing uncommitted changes from prior waves

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-31 16:58:25 +08:00

119 lines
4.7 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* @file ai_common.hpp
* @brief Shared types and utilities for AI provider plugins (OpenAI / Anthropic).
* AI 提供者插件 (OpenAI / Anthropic) 的共享类型和工具函数。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#pragma once
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"

#include <boost/json.hpp>

#include <cstddef>
#include <string>
#include <utility>
#include <vector>
namespace dstalk_ai {
namespace json = boost::json;
// ============================================================================
// 共享类型 / Shared types
// ============================================================================
/// Provider connection configuration / 服务商连接配置
struct PluginConfig {
std::string provider;
std::string base_url;
std::string api_key;
std::string model;
int max_tokens = 4096;
double temperature = 0.7;
};
/// Per-index tool-call accumulator for SSE streaming / SSE 流式传输的按索引工具调用累积器
struct ToolCallAccum {
int index = -1;
std::string id;
std::string name;
std::string arguments; // 增量拼接的 JSON arguments / incrementally concatenated JSON arguments
};
/// Streaming context passed through SSE callbacks / 通过 SSE 回调传递的流式上下文
struct StreamContext {
const dstalk_host_api_t* host = nullptr;
dstalk_stream_cb user_cb = nullptr;
void* userdata = nullptr;
std::string accumulated;
std::vector<ToolCallAccum> tool_calls;
int sse_parse_errors = 0; // 连续 SSE 解析错误计数器 / consecutive SSE parse error counter
bool streaming_ok = true; // OpenAI: tracks stream health
bool saw_data_line = false; // Anthropic: tracks if any SSE data received
};
/// Maximum consecutive SSE parse errors before aborting the stream / 中止流之前的最大连续 SSE 解析错误数
inline constexpr int kMaxSseParseErrors = 5;
// ============================================================================
// 函数声明(实现于 ai_common.cpp / Function declarations (implemented in ai_common.cpp)
// ============================================================================
/// Securely zero memory through volatile write to prevent compiler optimization.
/// 通过 volatile 写入零来安全擦除内存,防止编译器优化。
void secure_zero(void* p, size_t n);
/// Parse a URL into scheme, host, port, and target path components.
/// 将 URL 解析为 scheme、host、port 和 target path 组件。
bool extract_host_port(const std::string& url,
std::string& scheme_out, std::string& host_out,
std::string& port_out, std::string& target_out);
// ============================================================================
// 内联工具函数 / Inline utility functions
// ============================================================================
/// Free all host-allocated string fields in a chat result struct.
/// 释放 chat result 结构体中所有主机分配的字符串字段。
inline void free_chat_result(const dstalk_host_api_t* host, dstalk_chat_result_t* result) {
if (!result || !host) return;
if (result->content) { host->free((void*)result->content); result->content = nullptr; }
if (result->error) { host->free((void*)result->error); result->error = nullptr; }
if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; }
}
/// Cache tools_json from the tools service for reuse in chat/chat_stream.
/// 从 tools service 缓存 tools_json供 chat/chat_stream 复用。
inline void cache_tools_json(const dstalk_host_api_t* host, std::string& tools_json) {
if (!host) return;
auto* tools_svc = reinterpret_cast<const dstalk_tools_service_t*>(
host->query_service("tools", 1));
if (tools_svc && tools_svc->get_tools_json) {
char* j = tools_svc->get_tools_json();
if (j) {
tools_json = j;
host->free(j);
}
}
}
/// Serialize accumulated tool_calls into OpenAI-compatible JSON array.
/// 将累积的 tool_calls 序列化为兼容 OpenAI 格式的 JSON 数组。
inline std::string serialize_tool_calls(const std::vector<ToolCallAccum>& tool_calls) {
json::array tc_array;
for (const auto& tc : tool_calls) {
json::object tc_obj;
tc_obj["index"] = tc.index;
if (!tc.id.empty()) tc_obj["id"] = tc.id;
tc_obj["type"] = "function";
json::object func;
if (!tc.name.empty()) func["name"] = tc.name;
func["arguments"] = tc.arguments;
tc_obj["function"] = func;
tc_array.push_back(std::move(tc_obj));
}
return json::serialize(tc_array);
}
} // namespace dstalk_ai