/* * @file ai_common.hpp * @brief Shared types and utilities for AI provider plugins (OpenAI / Anthropic). * AI 提供者插件(OpenAI / Anthropic)的共享类型和工具函数。 * Copyright (c) 2026 dstalk contributors. GPLv3. */ #pragma once #include "dstalk/dstalk_host.h" #include "dstalk/dstalk_services.h" #include #include #include namespace dstalk_ai { namespace json = boost::json; // ============================================================================ // 共享类型 / Shared types // ============================================================================ /// Provider connection configuration / 服务商连接配置 struct PluginConfig { std::string provider; std::string base_url; std::string api_key; std::string model; int max_tokens = 4096; double temperature = 0.7; }; /// Per-index tool-call accumulator for SSE streaming / SSE 流式传输的按索引工具调用累积器 struct ToolCallAccum { int index = -1; std::string id; std::string name; std::string arguments; // 增量拼接的 JSON arguments / incrementally concatenated JSON arguments }; /// Streaming context passed through SSE callbacks / 通过 SSE 回调传递的流式上下文 struct StreamContext { const dstalk_host_api_t* host = nullptr; dstalk_stream_cb user_cb = nullptr; void* userdata = nullptr; std::string accumulated; std::vector tool_calls; int sse_parse_errors = 0; // 连续 SSE 解析错误计数器 / consecutive SSE parse error counter bool streaming_ok = true; // OpenAI: tracks stream health bool saw_data_line = false; // Anthropic: tracks if any SSE data received }; /// Maximum consecutive SSE parse errors before aborting the stream / 中止流之前的最大连续 SSE 解析错误数 inline constexpr int kMaxSseParseErrors = 5; // ============================================================================ // 函数声明(实现于 ai_common.cpp) / Function declarations (implemented in ai_common.cpp) // ============================================================================ /// Securely zero memory through volatile write to prevent compiler optimization. /// 通过 volatile 写入零来安全擦除内存,防止编译器优化。 void secure_zero(void* p, size_t n); /// Parse a URL into scheme, host, port, and target path components. /// 将 URL 解析为 scheme、host、port 和 target path 组件。 bool extract_host_port(const std::string& url, std::string& scheme_out, std::string& host_out, std::string& port_out, std::string& target_out); // ============================================================================ // 内联工具函数 / Inline utility functions // ============================================================================ /// Free all host-allocated string fields in a chat result struct. /// 释放 chat result 结构体中所有主机分配的字符串字段。 inline void free_chat_result(const dstalk_host_api_t* host, dstalk_chat_result_t* result) { if (!result || !host) return; if (result->content) { host->free((void*)result->content); result->content = nullptr; } if (result->error) { host->free((void*)result->error); result->error = nullptr; } if (result->tool_calls_json) { host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; } } /// Cache tools_json from the tools service for reuse in chat/chat_stream. /// 从 tools service 缓存 tools_json,供 chat/chat_stream 复用。 inline void cache_tools_json(const dstalk_host_api_t* host, std::string& tools_json) { if (!host) return; auto* tools_svc = reinterpret_cast( host->query_service("tools", 1)); if (tools_svc && tools_svc->get_tools_json) { char* j = tools_svc->get_tools_json(); if (j) { tools_json = j; host->free(j); } } } /// Serialize accumulated tool_calls into OpenAI-compatible JSON array. /// 将累积的 tool_calls 序列化为兼容 OpenAI 格式的 JSON 数组。 inline std::string serialize_tool_calls(const std::vector& tool_calls) { json::array tc_array; for (const auto& tc : tool_calls) { json::object tc_obj; tc_obj["index"] = tc.index; if (!tc.id.empty()) tc_obj["id"] = tc.id; tc_obj["type"] = "function"; json::object func; if (!tc.name.empty()) func["name"] = tc.name; func["arguments"] = tc.arguments; tc_obj["function"] = func; tc_array.push_back(std::move(tc_obj)); } return json::serialize(tc_array); } } // namespace dstalk_ai