feat: add AI endpoint manager plugin with configuration and routing capabilities
Some checks failed
CI / Determine matrix (push) Has been cancelled
CI / Sanitizer (ASan+UBSan) / ubuntu-24.04 (push) Has been cancelled
CI / Coverage (gcovr) / ubuntu-24.04 (push) Has been cancelled
CI / ${{ matrix.os }} / ${{ matrix.build_type }} (push) Has been cancelled

- Introduced `ai_endpoint_mgr` plugin to manage multiple AI provider endpoints.
- Added configuration reference documentation for `config.toml`.
- Implemented endpoint loading, active endpoint switching, and model mutation.
- Included error handling for missing endpoints and configuration failures.
- Developed unit tests covering various scenarios including error paths and concurrency.
This commit is contained in:
2026-06-03 21:07:25 +08:00
parent 28ae90a6cc
commit 4745ce1f1c
18 changed files with 1570 additions and 34 deletions

View File

@@ -0,0 +1,32 @@
# ============================================================
# AI endpoint manager plugin / AI endpoint manager 插件
# ============================================================
find_package(Boost REQUIRED CONFIG)
add_library(plugin_ai_endpoint_mgr SHARED
src/endpoint_mgr_plugin.cpp
)
target_include_directories(plugin_ai_endpoint_mgr
PRIVATE
${CMAKE_SOURCE_DIR}/dstalk_core/include
${CMAKE_SOURCE_DIR}/plugins_upper/ai_common/include
)
target_link_libraries(plugin_ai_endpoint_mgr
PRIVATE
dstalk
ai_common
dstalk_boost_config
boost::boost
)
# cxx_std_20 已由 dstalk 和 ai_common (PUBLIC) 传播,无需重复声明
# cxx_std_20 is already propagated by dstalk and ai_common (PUBLIC); no need to redeclare
set_target_properties(plugin_ai_endpoint_mgr PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
)

View File

@@ -0,0 +1,400 @@
/*
* @file endpoint_mgr_plugin.cpp
* @brief AI endpoint manager: routes named endpoint configs to AI provider services.
* AI endpoint 管理器:把命名 endpoint 配置路由到具体 AI provider 服务。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"
#include "ai_common.hpp"
#include <boost/json.hpp>
#include <boost/json/src.hpp>
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
namespace json = boost::json;
struct EndpointConfig {
std::string name;
std::string provider;
std::string base_url;
std::string api_key;
std::string model;
int max_tokens = 4096;
double temperature = 0.7;
};
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::unordered_map<std::string, EndpointConfig> g_endpoints;
static std::string g_active_endpoint;
static std::shared_mutex g_endpoints_mutex;
// 按 provider 名称动态分配互斥锁;避免未知 provider 错误共享 OpenAI/Anthropic 专用锁
// Dynamically allocate mutex per provider name; prevents unknown providers from incorrectly sharing the OpenAI/Anthropic-specific locks
static std::shared_mutex g_provider_mutexes_mutex;
static std::unordered_map<std::string, std::unique_ptr<std::mutex>> g_provider_mutexes;
static std::mutex& provider_mutex(const std::string& provider)
{
// 快速路径:读锁查找已有 mutex / Fast path: read lock to find existing mutex
{
std::shared_lock lock(g_provider_mutexes_mutex);
auto it = g_provider_mutexes.find(provider);
if (it != g_provider_mutexes.end()) return *it->second;
}
// 慢速路径:写锁创建新 mutex (双重检查) / Slow path: write lock to create new mutex (double-check)
std::unique_lock lock(g_provider_mutexes_mutex);
auto it = g_provider_mutexes.find(provider);
if (it != g_provider_mutexes.end()) return *it->second;
auto [new_it, _] = g_provider_mutexes.emplace(provider, std::make_unique<std::mutex>());
return *new_it->second;
}
static std::string trim_copy(std::string s)
{
auto is_space = [](unsigned char c) { return c == ' ' || c == '\t' || c == '\r' || c == '\n'; };
while (!s.empty() && is_space(static_cast<unsigned char>(s.front()))) s.erase(s.begin());
while (!s.empty() && is_space(static_cast<unsigned char>(s.back()))) s.pop_back();
return s;
}
static std::vector<std::string> split_csv(const char* raw)
{
std::vector<std::string> out;
if (!raw || !*raw) return out;
std::stringstream ss(raw);
std::string item;
while (std::getline(ss, item, ',')) {
item = trim_copy(item);
if (!item.empty()) out.push_back(item);
}
return out;
}
static int parse_int_or_default(const char* raw, int fallback)
{
if (!raw || !*raw) return fallback;
char* end = nullptr;
long v = std::strtol(raw, &end, 10);
if (!end || *end != '\0' || v <= 0 || v > 1000000) return fallback;
return static_cast<int>(v);
}
static double parse_double_or_default(const char* raw, double fallback)
{
if (!raw || !*raw) return fallback;
char* end = nullptr;
double v = std::strtod(raw, &end);
if (!end || *end != '\0' || v < 0.0 || v > 2.0) return fallback;
return v;
}
static const char* cfg_get(const dstalk_host_api_t* host, const std::string& key)
{
if (!host || !host->config_get) return nullptr;
return host->config_get(key.c_str());
}
static std::string cfg_get_copy(const dstalk_host_api_t* host, const std::string& key)
{
const char* value = cfg_get(host, key);
return value ? std::string(value) : std::string();
}
static const char* default_base_url_for_provider(const std::string& provider)
{
if (provider == "ai_anthropic") return "https://api.anthropic.com";
if (provider == "ai_openai") return "https://api.openai.com/v1";
return nullptr; // 未知 provider 必须显式配置 base_url / unknown provider must explicitly configure base_url
}
static void clear_endpoints_locked()
{
for (auto& kv : g_endpoints) {
dstalk_ai::secure_zero(kv.second.api_key.data(), kv.second.api_key.size());
kv.second.api_key.clear();
}
g_endpoints.clear();
g_active_endpoint.clear();
}
static const dstalk_ai_service_t* lookup_ai_service(const EndpointConfig& ep)
{
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (!host || !host->query_service || ep.provider.empty()) return nullptr;
return static_cast<const dstalk_ai_service_t*>(host->query_service(ep.provider.c_str(), 1));
}
static dstalk_chat_result_t make_error(const char* msg, int status = 0)
{
dstalk_chat_result_t r = {};
r.ok = 0;
r.http_status = status;
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
r.error = (host && host->strdup) ? host->strdup(msg ? msg : "endpoint manager error") : nullptr;
return r;
}
static bool load_endpoint(const dstalk_host_api_t* host, const std::string& name, EndpointConfig& out)
{
std::string prefix = "endpoint." + name + ".";
std::string provider = cfg_get_copy(host, prefix + "provider");
std::string base_url = cfg_get_copy(host, prefix + "base_url");
std::string api_key = cfg_get_copy(host, prefix + "api_key");
std::string model = cfg_get_copy(host, prefix + "model");
if (provider.empty() || model.empty()) return false;
out.name = name;
out.provider = provider;
// 设定 base_url: 显式配置优先,其次用已知 provider 的默认值;未知 provider 必须显式配置
// Determine base_url: explicit config first, then known provider default; unknown providers must configure explicitly
if (!base_url.empty()) {
out.base_url = base_url;
} else {
const char* default_url = default_base_url_for_provider(out.provider);
if (default_url) {
out.base_url = default_url;
} else {
return false; // 未知 provider 且未配置 base_url / unknown provider without explicit base_url
}
}
out.api_key = api_key;
out.model = model;
out.max_tokens = parse_int_or_default(cfg_get(host, prefix + "max_tokens"), 4096);
out.temperature = parse_double_or_default(cfg_get(host, prefix + "temperature"), 0.7);
return host && host->query_service && host->query_service(out.provider.c_str(), 1) != nullptr;
}
static int reload_endpoints_locked(const dstalk_host_api_t* host)
{
clear_endpoints_locked();
if (!host) return 0;
std::vector<std::string> names = split_csv(cfg_get(host, "endpoints.names"));
if (names.empty()) return 0;
for (const std::string& name : names) {
if (g_endpoints.find(name) != g_endpoints.end()) {
if (host->log) host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping duplicate endpoint '%s'", name.c_str());
continue;
}
EndpointConfig ep;
if (load_endpoint(host, name, ep)) {
if (g_active_endpoint.empty()) g_active_endpoint = name;
g_endpoints.emplace(name, std::move(ep));
} else if (host->log) {
host->log(DSTALK_LOG_WARN, "[ai_endpoint_mgr] skipping invalid endpoint '%s'", name.c_str());
}
}
const char* active = cfg_get(host, "endpoints.active");
if (active && g_endpoints.count(active)) {
g_active_endpoint = active;
}
return static_cast<int>(g_endpoints.size());
}
static int mgr_count()
{
std::shared_lock lock(g_endpoints_mutex);
return static_cast<int>(g_endpoints.size());
}
static char* mgr_list_json()
{
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (!host || !host->strdup) return nullptr;
json::array arr;
{
std::shared_lock lock(g_endpoints_mutex);
std::vector<std::string> names;
names.reserve(g_endpoints.size());
for (const auto& kv : g_endpoints) names.push_back(kv.first);
std::sort(names.begin(), names.end());
for (const auto& name : names) {
const auto& ep = g_endpoints.at(name);
json::object o;
o["name"] = ep.name;
o["provider"] = ep.provider;
o["base_url"] = ep.base_url;
o["model"] = ep.model;
o["active"] = (ep.name == g_active_endpoint);
arr.emplace_back(std::move(o));
}
}
return host->strdup(json::serialize(arr).c_str());
}
static const char* mgr_get_active()
{
static thread_local std::string active;
std::shared_lock lock(g_endpoints_mutex);
active = g_active_endpoint;
return active.empty() ? nullptr : active.c_str();
}
static int mgr_set_active(const char* endpoint_name)
{
if (!endpoint_name || !*endpoint_name) return -1;
std::unique_lock lock(g_endpoints_mutex);
auto it = g_endpoints.find(endpoint_name);
if (it == g_endpoints.end()) return -2;
g_active_endpoint = endpoint_name;
return 0;
}
static EndpointConfig lookup_endpoint(const char* endpoint_name)
{
std::shared_lock lock(g_endpoints_mutex);
std::string name;
if (endpoint_name && *endpoint_name) name = endpoint_name;
else name = g_active_endpoint;
auto it = g_endpoints.find(name);
if (it == g_endpoints.end()) return {};
return it->second;
}
static int mgr_set_model(const char* endpoint_name, const char* model)
{
if (!model || !*model) return -1;
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
std::string selected;
{
std::unique_lock lock(g_endpoints_mutex);
selected = (endpoint_name && *endpoint_name) ? endpoint_name : g_active_endpoint;
auto it = g_endpoints.find(selected);
if (it == g_endpoints.end()) return -2;
it->second.model = model;
}
if (host && host->config_set) {
std::string key = "endpoint." + selected + ".model";
host->config_set(key.c_str(), model);
}
return 0;
}
static dstalk_chat_result_t mgr_chat(const char* endpoint_name,
const dstalk_message_t* history,
int history_len,
const char* user_input,
const char* tools_json)
{
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
EndpointConfig ep = lookup_endpoint(endpoint_name);
const dstalk_ai_service_t* service = lookup_ai_service(ep);
if (!service) return make_error("endpoint not found");
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
ep.model.c_str(), ep.max_tokens, ep.temperature);
if (rc != 0) return make_error("endpoint configure failed");
return service->chat(history, history_len, user_input, tools_json);
}
static dstalk_chat_result_t mgr_chat_stream(const char* endpoint_name,
const dstalk_message_t* history,
int history_len,
const char* user_input,
dstalk_stream_cb cb,
void* userdata)
{
// 防御: history_len > 0 时 history 不得为 nullptr / Guard: history must not be null when history_len > 0
if (history_len > 0 && history == nullptr) return make_error("null history with non-zero length");
EndpointConfig ep = lookup_endpoint(endpoint_name);
const dstalk_ai_service_t* service = lookup_ai_service(ep);
if (!service) return make_error("endpoint not found");
std::lock_guard<std::mutex> guard(provider_mutex(ep.provider));
int rc = service->configure(ep.provider.c_str(), ep.base_url.c_str(), ep.api_key.c_str(),
ep.model.c_str(), ep.max_tokens, ep.temperature);
if (rc != 0) return make_error("endpoint configure failed");
return service->chat_stream(history, history_len, user_input, cb, userdata);
}
static void mgr_free_result(dstalk_chat_result_t* result)
{
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
dstalk_ai::free_chat_result(host, result);
}
static dstalk_ai_endpoint_mgr_t g_service = {
&mgr_count,
&mgr_list_json,
&mgr_get_active,
&mgr_set_active,
&mgr_set_model,
&mgr_chat,
&mgr_chat_stream,
&mgr_free_result,
};
static int on_init(const dstalk_host_api_t* host)
{
try {
if (!host) return -1;
g_host.store(host, std::memory_order_release);
{
std::unique_lock lock(g_endpoints_mutex);
reload_endpoints_locked(host);
}
if (host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] loaded %d endpoint(s)", mgr_count());
return host->register_service("ai_endpoint_mgr", 1, &g_service);
} catch (const std::exception& e) {
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init exception: %s", e.what());
return -1;
} catch (...) {
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_init unknown exception");
return -1;
}
}
static void on_shutdown()
{
try {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_INFO, "[ai_endpoint_mgr] shutdown");
std::unique_lock lock(g_endpoints_mutex);
clear_endpoints_locked();
g_host.store(nullptr, std::memory_order_release);
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown exception: %s", e.what());
g_host.store(nullptr, std::memory_order_release);
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host && host->log) host->log(DSTALK_LOG_ERROR, "[ai_endpoint_mgr] on_shutdown unknown exception");
g_host.store(nullptr, std::memory_order_release);
}
}
static dstalk_plugin_info_t g_info = {
/* .name = */ "ai_endpoint_mgr",
/* .version = */ "1.0.0",
/* .description = */ "AI endpoint manager for multiple named provider/model endpoints / 多命名 AI endpoint 管理器",
/* .api_version = */ DSTALK_API_VERSION,
/* .dependencies = */ { "openai_compat", "anthropic_ai", NULL },
/* .on_init = */ on_init,
/* .on_shutdown = */ on_shutdown,
/* .on_event = */ nullptr,
};
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void)
{
return &g_info;
}