feat: add AI endpoint manager plugin with configuration and routing capabilities
Some checks failed
Some checks failed
- Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints. - Added configuration reference documentation for `config.toml`. - Implemented endpoint loading, active endpoint switching, and model mutation. - Included error handling for missing endpoints and configuration failures. - Developed unit tests covering various scenarios including error paths and concurrency.
This commit is contained in:
@@ -6,3 +6,4 @@ add_subdirectory(ai_common) # 共享 AI 工具库(静态库)/ shared AI u
|
||||
add_subdirectory(context) # 依赖 session / depends on session
|
||||
add_subdirectory(openai) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
add_subdirectory(anthropic) # 依赖 http, config, ai_common / depends on http, config, ai_common
|
||||
add_subdirectory(ai_endpoint_mgr) # 路由多个 AI endpoint / routes multiple AI endpoints
|
||||
|
||||
32
plugins_upper/ai_endpoint_mgr/CMakeLists.txt
Normal file
32
plugins_upper/ai_endpoint_mgr/CMakeLists.txt
Normal file
@@ -0,0 +1,32 @@
|
||||
# ============================================================
|
||||
# AI endpoint manager plugin / AI endpoint manager 插件
|
||||
# ============================================================
|
||||
|
||||
find_package(Boost REQUIRED CONFIG)
|
||||
|
||||
add_library(plugin_ai_endpoint_mgr SHARED
|
||||
src/endpoint_mgr_plugin.cpp
|
||||
)
|
||||
|
||||
target_include_directories(plugin_ai_endpoint_mgr
|
||||
PRIVATE
|
||||
${CMAKE_SOURCE_DIR}/dstalk_core/include
|
||||
${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include
|
||||
)
|
||||
|
||||
target_link_libraries(plugin_ai_endpoint_mgr
|
||||
PRIVATE
|
||||
dstalk
|
||||
ai_common
|
||||
dstalk_boost_config
|
||||
boost::boost
|
||||
)
|
||||
|
||||
# cxx_std_20 已由 dstalk 和 ai_common (PUBLIC) 传播,无需重复声明
|
||||
# cxx_std_20 is already propagated by dstalk and ai_common (PUBLIC); no need to redeclare
|
||||
|
||||
set_target_properties(plugin_ai_endpoint_mgr PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
)
|
||||
400
plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp
Normal file
400
plugins_upper/ai_endpoint_mgr/src/endpoint_mgr_plugin.cpp
Normal file
@@ -0,0 +1,400 @@
|
||||
/*
|
||||
* @file endpoint_mgr_plugin.cpp
|
||||
* @brief AI endpoint manager: routes named endpoint configs to AI provider services.
|
||||
* AI endpoint 管理器:把命名 endpoint 配置路由到具体 AI provider 服务。
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
#include "dstalk/dstalk_host.h"
|
||||
#include "dstalk/dstalk_services.h"
|
||||
#include "ai_common.hpp"
|
||||
|
||||
#include <boost/json.hpp>
|
||||
#include <boost/json/src.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <exception>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <shared_mutex>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace json = boost::json;
|
||||
|
||||
struct EndpointConfig {
|
||||
std::string name;
|
||||
std::string provider;
|
||||
std::string base_url;
|
||||
std::string api_key;
|
||||
std::string model;
|
||||
int max_tokens = 4096;
|
||||
double temperature = 0.7;
|
||||
};
|
||||
|
||||
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
|
||||
static std::unordered_map<std::string, EndpointConfig> g_endpoints;
|
||||
static std::string g_active_endpoint;
|
||||
static std::shared_mutex g_endpoints_mutex;
|
||||
|
||||
// 按 provider 名称动态分配互斥锁;避免未知 provider 错误共享 OpenAI/Anthropic 专用锁
|
||||
// Dynamically allocate mutex per provider name; prevents unknown providers from incorrectly sharing the OpenAI/Anthropic-specific locks
|
||||
static std::shared_mutex g_provider_mutexes_mutex;
|
||||
static std::unordered_map<std::string, std::unique_ptr<std::mutex>> g_provider_mutexes;
|
||||
|
||||
static std::mutex& provider_mutex(const std::string& provider)
|
||||
{
|
||||
// 快速路径:读锁查找已有 mutex / Fast path: read lock to find existing mutex
|
||||
{
|
||||
std::shared_lock lock(g_provider_mutexes_mutex);
|
||||
auto it = g_provider_mutexes.find(provider);
|
||||
if (it != g_provider_mutexes.end()) return *it->second;
|
||||
}
|
||||
// 慢速路径:写锁创建新 mutex (双重检查) / Slow path: write lock to create new mutex (double-check)
|
||||
std::unique_lock lock(g_provider_mutexes_mutex);
|
||||
auto it = g_provider_mutexes.find(provider);
|
||||
if (it != g_provider_mutexes.end()) return *it->second;
|
||||
auto [new_it, _] = g_provider_mutexes.emplace(provider, std::make_unique<std::mutex>());
|
||||
return *new_it->second;
|
||||
}
|
||||
|
||||
static std::string trim_copy(std::string s)
|
||||
{
|
||||
auto is_space = [](unsigned char c) { return c == ' ' || c == '\t' || c == '\r' || c == '\n'; };
|
||||
while (!s.empty() && is_space(static_cast<unsigned char>(s.front()))) s.erase(s.begin());
|
||||
while (!s.empty() && is_space(static_cast<unsigned char>(s.back()))) s.pop_back();
|
||||
return s;
|
||||
}
|
||||
|
||||
static std::vector<std::string> split_csv(const char* raw)
|
||||
{
|
||||
std::vector<std::string> out;
|
||||
if (!raw || !*raw) return out;
|
||||
std::stringstream ss(raw);
|
||||
std::string item;
|
||||
while (std::getline(ss, item, ',')) {
|
||||
item = trim_copy(item);
|
||||
if (!item.empty()) out.push_back(item);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static int parse_int_or_default(const char* raw, int fallback)
|
||||
{
|
||||
if (!raw || !*raw) return fallback;
|
||||
char* end = nullptr;
|
||||
long v = std::strtol(raw, &end, 10);
|
||||
if (!end || *end != '\0' || v <= 0 || v > 1000000) return fallback;
|
||||
return static_cast<int>(v);
|
||||
}
|
||||
|
||||
static double parse_double_or_default(const char* raw, double fallback)
|
||||
{
|
||||
if (!raw || !*raw) return fallback;
|
||||
char* end = nullptr;
|
||||
double v = std::strtod(raw, &end);
|
||||
if (!end || *end != '\0' || v < 0.0 || v > 2.0) return fallback;
|
||||
return v;
|
||||
}
|
||||
|
||||
static const char* cfg_get(const dstalk_host_api_t* host, const std::string& key)
|
||||
{
|
||||
if (!host || !host->config_get) return nullptr;
|
||||
return host->config_get(key.c_str());
|
||||
}
|
||||
|
||||
static std::string cfg_get_copy(const dstalk_host_api_t* host, const std::string& key)
|
||||
{
|
||||
const char* value = cfg_get(host, key);
|
||||
return value ? std::string(value) : std::string();
|
||||
}
|
||||
|
||||
static const char* default_base_url_for_provider(const std::string& provider)
|
||||
{
|
||||
if (provider == "ai_anthropic") return "https://api.anthropic.com";
|
||||
if (provider == "ai_openai") return "https://api.openai.com/v1";
|
||||
return nullptr; // 未知 provider 必须显式配置 base_url / unknown provider must explicitly configure base_url
|
||||
}
|
||||
|
||||
static void clear_endpoints_locked()
|
||||
{
|
||||
for (auto& kv : g_endpoints) {
|
||||
dstalk_ai::secure_zero(kv.second.api_key.data(), kv.second.api_key.size());
|
||||
kv.second.api_key.clear();
|
||||
}
|
||||
g_endpoints.clear();
|
||||
g_active_endpoint.clear();
|
||||
}
|
||||
|
||||
static const dstalk_ai_service_t* lookup_ai_service(const EndpointConfig& ep)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!host || !host->query_service || ep.provider.empty()) return nullptr;
|
||||
return static_cast<const dstalk_ai_service_t*>(host->query_service(ep.provider.c_str(), 1));
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t make_error(const char* msg, int status = 0)
|
||||
{
|
||||
dstalk_chat_result_t r = {};
|
||||
r.ok = 0;
|
||||
r.http_status = status;
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
r.error = (host && host->strdup) ? host->strdup(msg ? msg : "endpoint manager error") : nullptr;
|
||||
return r;
|
||||
}
|
||||
|
||||
static bool load_endpoint(const dstalk_host_api_t* host, const std::string& name, EndpointConfig& out)
|
||||
{
|
||||
std::string prefix = "endpoint." + name + ".";
|
||||
std::string provider = cfg_get_copy(host, prefix + "provider");
|
||||
std::string base_url = cfg_get_copy(host, prefix + "base_url");
|
||||
std::string api_key = cfg_get_copy(host, prefix + "api_key");
|
||||
std::string model = cfg_get_copy(host, prefix + "model");
|
||||
if (provider.empty() || model.empty()) return false;
|
||||
|
||||
out.name = name;
|
||||
out.provider = provider;
|
||||
|
||||
// 设定 base_url: 显式配置优先,其次用已知 provider 的默认值;未知 provider 必须显式配置
|
||||
// Determine base_url: explicit config first, then known provider default; unknown providers must configure explicitly
|
||||
if (!base_url.empty()) {
|
||||
out.base_url = base_url;
|
||||
} else {
|
||||
const char* default_url = default_base_url_for_provider(out.provider);
|
||||
if (default_url) {
|
||||
out.base_url = default_url;
|
||||
} else {
|
||||
return false; // 未知 provider 且未配置 base_url / unknown provider without explicit base_url
|
||||
}
|
||||
}
|
||||
|
||||
out.api_key = api_key;
|
||||
out.model = model;
|
||||
out.max_tokens = parse_int_or_default(cfg_get(host, prefix + "max_tokens"), 4096);
|
||||
out.temperature = parse_double_or_default(cfg_get(host, prefix + "temperature"), 0.7);
|
||||
return host && host->query_service && host->query_service(out.provider.c_str(), 1) != nullptr;
|
||||
}
|
||||
|
||||
static int reload_endpoints_locked(const dstalk_host_api_t* host)
|
||||
{
|
||||
clear_endpoints_locked();
|
||||
if (!host) return 0;
|
||||
|
||||
std::vector<std::string> names = split_csv(cfg_get(host, "endpoints.names"));
|
||||
if (names.empty()) return 0;
|
||||
|
||||
for (const std::string& name : names) {
|
||||
if (g_endpoints.find(name) != g_endpoints.end()) {
|
||||
if (host->log) host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping duplicate endpoint '%s'", name.c_str());
|
||||
continue;
|
||||
}
|
||||
EndpointConfig ep;
|
||||
if (load_endpoint(host, name, ep)) {
|
||||
if (g_active_endpoint.empty()) g_active_endpoint = name;
|
||||
g_endpoints.emplace(name, std::move(ep));
|
||||
} else if (host->log) {
|
||||
host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping invalid endpoint '%s'", name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const char* active = cfg_get(host, "endpoints.active");
|
||||
if (active && g_endpoints.count(active)) {
|
||||
g_active_endpoint = active;
|
||||
}
|
||||
return static_cast<int>(g_endpoints.size());
|
||||
}
|
||||
|
||||
static int mgr_count()
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
return static_cast<int>(g_endpoints.size());
|
||||
}
|
||||
|
||||
static char* mgr_list_json()
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (!host || !host->strdup) return nullptr;
|
||||
json::array arr;
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
std::vector<std::string> names;
|
||||
names.reserve(g_endpoints.size());
|
||||
for (const auto& kv : g_endpoints) names.push_back(kv.first);
|
||||
std::sort(names.begin(), names.end());
|
||||
for (const auto& name : names) {
|
||||
const auto& ep = g_endpoints.at(name);
|
||||
json::object o;
|
||||
o["name"] = ep.name;
|
||||
o["provider"] = ep.provider;
|
||||
o["base_url"] = ep.base_url;
|
||||
o["model"] = ep.model;
|
||||
o["active"] = (ep.name == g_active_endpoint);
|
||||
arr.emplace_back(std::move(o));
|
||||
}
|
||||
}
|
||||
return host->strdup(json::serialize(arr).c_str());
|
||||
}
|
||||
|
||||
static const char* mgr_get_active()
|
||||
{
|
||||
static thread_local std::string active;
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
active = g_active_endpoint;
|
||||
return active.empty() ? nullptr : active.c_str();
|
||||
}
|
||||
|
||||
static int mgr_set_active(const char* endpoint_name)
|
||||
{
|
||||
if (!endpoint_name || !*endpoint_name) return -1;
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
auto it = g_endpoints.find(endpoint_name);
|
||||
if (it == g_endpoints.end()) return -2;
|
||||
g_active_endpoint = endpoint_name;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static EndpointConfig lookup_endpoint(const char* endpoint_name)
|
||||
{
|
||||
std::shared_lock lock(g_endpoints_mutex);
|
||||
std::string name;
|
||||
if (endpoint_name && *endpoint_name) name = endpoint_name;
|
||||
else name = g_active_endpoint;
|
||||
auto it = g_endpoints.find(name);
|
||||
if (it == g_endpoints.end()) return {};
|
||||
return it->second;
|
||||
}
|
||||
|
||||
static int mgr_set_model(const char* endpoint_name, const char* model)
|
||||
{
|
||||
if (!model || !*model) return -1;
|
||||
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
std::string selected;
|
||||
{
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
selected = (endpoint_name && *endpoint_name) ? endpoint_name : g_active_endpoint;
|
||||
auto it = g_endpoints.find(selected);
|
||||
if (it == g_endpoints.end()) return -2;
|
||||
it->second.model = model;
|
||||
}
|
||||
|
||||
if (host && host->config_set) {
|
||||
std::string key = "endpoint." + selected + ".model";
|
||||
host->config_set(key.c_str(), model);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t mgr_chat(const char* endpoint_name,
|
||||
const dstalk_message_t* history,
|
||||
int history_len,
|
||||
const char* user_input,
|
||||
const char* tools_json)
|
||||
{
|
||||
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
|
||||
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
|
||||
EndpointConfig ep = lookup_endpoint(endpoint_name);
|
||||
const dstalk_ai_service_t* service = lookup_ai_service(ep);
|
||||
if (!service) return make_error("endpoint not found");
|
||||
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
|
||||
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
|
||||
ep.model.c_str(), ep.max_tokens, ep.temperature);
|
||||
if (rc != 0) return make_error("endpoint configure failed");
|
||||
return service->chat(history, history_len, user_input, tools_json);
|
||||
}
|
||||
|
||||
static dstalk_chat_result_t mgr_chat_stream(const char* endpoint_name,
|
||||
const dstalk_message_t* history,
|
||||
int history_len,
|
||||
const char* user_input,
|
||||
dstalk_stream_cb cb,
|
||||
void* userdata)
|
||||
{
|
||||
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
|
||||
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
|
||||
EndpointConfig ep = lookup_endpoint(endpoint_name);
|
||||
const dstalk_ai_service_t* service = lookup_ai_service(ep);
|
||||
if (!service) return make_error("endpoint not found");
|
||||
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
|
||||
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
|
||||
ep.model.c_str(), ep.max_tokens, ep.temperature);
|
||||
if (rc != 0) return make_error("endpoint configure failed");
|
||||
return service->chat_stream(history, history_len, user_input, cb, userdata);
|
||||
}
|
||||
|
||||
static void mgr_free_result(dstalk_chat_result_t* result)
|
||||
{
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
dstalk_ai::free_chat_result(host, result);
|
||||
}
|
||||
|
||||
static dstalk_ai_endpoint_mgr_t g_service = {
|
||||
&mgr_count,
|
||||
&mgr_list_json,
|
||||
&mgr_get_active,
|
||||
&mgr_set_active,
|
||||
&mgr_set_model,
|
||||
&mgr_chat,
|
||||
&mgr_chat_stream,
|
||||
&mgr_free_result,
|
||||
};
|
||||
|
||||
static int on_init(const dstalk_host_api_t* host)
|
||||
{
|
||||
try {
|
||||
if (!host) return -1;
|
||||
g_host.store(host, std::memory_order_release);
|
||||
{
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
reload_endpoints_locked(host);
|
||||
}
|
||||
if (host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] loaded %d endpoint(s)", mgr_count());
|
||||
return host->register_service("ai_endpoint_mgr", 1, &g_service);
|
||||
} catch (const std::exception& e) {
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
static void on_shutdown()
|
||||
{
|
||||
try {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] shutdown");
|
||||
std::unique_lock lock(g_endpoints_mutex);
|
||||
clear_endpoints_locked();
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
} catch (const std::exception& e) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown exception: %s", e.what());
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
} catch (...) {
|
||||
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
|
||||
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown unknown exception");
|
||||
g_host.store(nullptr, std::memory_order_release);
|
||||
}
|
||||
}
|
||||
|
||||
static dstalk_plugin_info_t g_info = {
|
||||
/* .name = */ "ai_endpoint_mgr",
|
||||
/* .version = */ "1.0.0",
|
||||
/* .description = */ "AI endpoint manager for multiple named provider/model endpoints / 多命名 AI endpoint 管理器",
|
||||
/* .api_version = */ DSTALK_API_VERSION,
|
||||
/* .dependencies = */ { "openai_compat", "anthropic_ai", NULL },
|
||||
/* .on_init = */ on_init,
|
||||
/* .on_shutdown = */ on_shutdown,
|
||||
/* .on_event = */ nullptr,
|
||||
};
|
||||
|
||||
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
|
||||
{
|
||||
return &g_info;
|
||||
}
|
||||
@@ -65,6 +65,7 @@ static std::string build_request_json(
|
||||
std::string system_prompt;
|
||||
json::array msgs;
|
||||
|
||||
if (history) { // 防御: history 为空时跳过历史遍历 / Defensive: skip history iteration when null
|
||||
for (int i = 0; i < history_len; ++i) {
|
||||
const auto& m = history[i];
|
||||
if (m.role && std::strcmp(m.role, "system") == 0) {
|
||||
@@ -77,6 +78,7 @@ static std::string build_request_json(
|
||||
obj["content"] = m.content ? m.content : "";
|
||||
msgs.push_back(obj);
|
||||
}
|
||||
} // if (history)
|
||||
|
||||
// 追加当前用户输入 / Append current user input
|
||||
{
|
||||
|
||||
@@ -49,6 +49,7 @@ static std::string build_headers_json(const std::string& auth_header_value)
|
||||
static void append_history(json::array& msgs,
|
||||
const dstalk_message_t* history, int history_len)
|
||||
{
|
||||
if (!history) return; // 防御: history 为空时直接返回 / Defensive: return early when history is null
|
||||
for (int i = 0; i < history_len; ++i) {
|
||||
const auto& m = history[i];
|
||||
json::object obj;
|
||||
|
||||
Reference in New Issue
Block a user