Files
dstalk/dstalk-core/src/plugin_loader.cpp
XiuChengWu df3bf298ee
Some checks failed
CI / Determine matrix (push) Has been cancelled
CI / ${{ matrix.os }} / ${{ matrix.build_type }} (push) Has been cancelled
CI / Sanitizer (ASan+UBSan) / ubuntu-24.04 (push) Has been cancelled
CI / Coverage (gcovr) / ubuntu-24.04 (push) Has been cancelled
W22: coverage metric + network tests + Tool stream feedback + stdin pipe + session path + dependency check (W22.1-W22.6)
- W22.1: gcovr 覆盖率度量 + CI coverage job(40% 阈值 warning)
- W22.2: network_plugin 单元测试(parse_headers_json/extract_host_port/SSE/异常保护)
- W22.3: Tool Calling 流式反馈(chat_stream + "[工具调用]/[工具结果]" 状态行)
- W22.4: --prompt stdin pipe(--prompt - 从 stdin 读取)
- W22.5: session 路径健壮化(static 缓存 + mkdir + fallback)
- W22.6: 插件依赖拓扑静态校验(validate_dependencies 循环/缺失检测)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-27 21:21:24 +08:00

521 lines
15 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "plugin_loader.hpp"
#include <boost/json.hpp>
#ifdef _WIN32
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <algorithm>
#include <cctype>
#include <exception>
#include <filesystem>
#include <memory>
#include <queue>
#include <stdexcept>
#include <system_error>
#include <unordered_set>
namespace dstalk {
namespace json = boost::json;
namespace fs = std::filesystem;
/// Destructor: delegates all teardown work to shutdown_all().
PluginLoader::~PluginLoader()
{
    this->shutdown_all();
}
int PluginLoader::load_plugin(const char* path)
{
if (!path) return -1;
// === Path validation (F-18.3-3) ===
{
fs::path p = fs::absolute(fs::path(path)).lexically_normal();
// Extension check (case-insensitive)
std::string ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(),
[](unsigned char c) { return static_cast<char>(std::tolower(c)); });
bool valid_ext = false;
#ifdef _WIN32
valid_ext = (ext == ".dll");
#elif defined(__APPLE__)
valid_ext = (ext == ".dylib" || ext == ".so");
#else
valid_ext = (ext == ".so");
#endif
if (!valid_ext) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': invalid extension '%s', expected .dll/.so/.dylib",
path, ext.c_str());
}
return -1;
}
// Directory traversal check
bool has_dotdot = false;
bool in_plugins_dir = false;
for (const auto& comp : p) {
if (comp == "..") {
has_dotdot = true;
break;
}
if (comp == "plugins") {
in_plugins_dir = true;
}
}
if (has_dotdot) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': directory traversal rejected", path);
}
return -1;
}
// Directory constraint: must be under a 'plugins' directory or be a plain filename
if (!in_plugins_dir && p.has_parent_path()) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': path not under a 'plugins' directory", path);
}
return -1;
}
}
// 加载DLL
#ifdef _WIN32
void* handle = LoadLibraryA(path);
#else
void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL);
#endif
if (!handle) {
if (host_api_) {
#ifdef _WIN32
DWORD err = GetLastError();
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': LoadLibraryA failed (error %lu)", path, (unsigned long)err);
#else
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': dlopen failed: %s", path, dlerror());
#endif
}
return -1;
}
// 获取入口函数
#ifdef _WIN32
auto init_fn = (dstalk_plugin_init_fn)GetProcAddress(
(HMODULE)handle, "dstalk_plugin_init");
#else
auto init_fn = (dstalk_plugin_init_fn)dlsym(handle, "dstalk_plugin_init");
#endif
if (!init_fn) {
if (host_api_) {
#ifdef _WIN32
DWORD err = GetLastError();
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': GetProcAddress(dstalk_plugin_init) failed (error %lu)",
path, (unsigned long)err);
#else
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': dlsym(dstalk_plugin_init) failed: %s",
path, dlerror());
#endif
}
#ifdef _WIN32
FreeLibrary((HMODULE)handle);
#else
dlclose(handle);
#endif
return -1;
}
// 调用入口函数获取插件信息
dstalk_plugin_info_t* info = nullptr;
try {
info = init_fn();
} catch (const std::exception& e) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] %s: init_fn threw: %s", path, e.what());
} catch (...) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] %s: init_fn threw unknown exception", path);
}
if (!info) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': dstalk_plugin_init returned null", path);
}
#ifdef _WIN32
FreeLibrary((HMODULE)handle);
#else
dlclose(handle);
#endif
return -1;
}
// 检查API版本兼容性
if (info->api_version != DSTALK_API_VERSION) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] '%s': API version mismatch (got %d, expected %d)",
path, info->api_version, DSTALK_API_VERSION);
}
#ifdef _WIN32
FreeLibrary((HMODULE)handle);
#else
dlclose(handle);
#endif
return -1;
}
// 创建插件信息
int id = next_id_++;
PluginInfo plugin;
plugin.id = id;
plugin.name = info->name ? info->name : "";
plugin.version = info->version ? info->version : "";
plugin.description = info->description ? info->description : "";
plugin.api_version = info->api_version;
plugin.handle = handle;
plugin.info = info;
plugin.initialized = false;
// 解析依赖
for (int i = 0; i < DSTALK_MAX_DEPS && info->dependencies[i]; i++) {
plugin.dependencies.push_back(info->dependencies[i]);
}
plugins_[id] = std::move(plugin);
return id;
}
int PluginLoader::unload_plugin(int plugin_id)
{
auto it = plugins_.find(plugin_id);
if (it == plugins_.end()) return -1;
PluginInfo& plugin = it->second;
// 调用关闭回调
if (plugin.initialized && plugin.info->on_shutdown) {
try {
plugin.info->on_shutdown();
} catch (const std::exception& e) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' on_shutdown threw: %s",
plugin.name.c_str(), e.what());
} catch (...) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' on_shutdown threw unknown exception",
plugin.name.c_str());
}
}
// 卸载DLL
#ifdef _WIN32
FreeLibrary((HMODULE)plugin.handle);
#else
dlclose(plugin.handle);
#endif
plugins_.erase(it);
return 0;
}
std::string PluginLoader::list_plugins() const
{
json::array arr;
for (const auto& [id, plugin] : plugins_) {
json::object obj;
obj["id"] = id;
obj["name"] = plugin.name;
obj["version"] = plugin.version;
obj["description"] = plugin.description;
obj["api_version"] = plugin.api_version;
obj["initialized"] = plugin.initialized;
json::array deps;
for (const auto& dep : plugin.dependencies) {
deps.push_back(json::value(dep));
}
obj["dependencies"] = std::move(deps);
arr.push_back(std::move(obj));
}
return json::serialize(arr);
}
/// Orders the registered plugin ids so that every plugin comes after the
/// plugins it depends on (Kahn's algorithm).
///
/// Dependency names that do not match any registered plugin are ignored here.
///
/// @return Plugin ids in a valid initialization order.
/// @throws std::runtime_error if not every plugin could be ordered, i.e. the
///         dependency graph contains a cycle.
std::vector<int> PluginLoader::topological_sort() const
{
    // Plugin name -> plugin id.
    std::unordered_map<std::string, int> id_by_name;
    for (const auto& entry : plugins_) {
        id_by_name[entry.second.name] = entry.first;
    }

    // Number of still-unsatisfied dependencies per plugin.
    std::unordered_map<int, int> pending_deps;
    for (const auto& entry : plugins_) {
        pending_deps[entry.first] = 0;
    }

    // Dependency id -> ids of the plugins that depend on it.
    std::unordered_map<int, std::vector<int>> successors;
    for (const auto& entry : plugins_) {
        const int plugin_id = entry.first;
        for (const auto& dep_name : entry.second.dependencies) {
            const auto found = id_by_name.find(dep_name);
            if (found == id_by_name.end()) {
                continue;  // unknown dependency: contributes no edge
            }
            successors[found->second].push_back(plugin_id);
            ++pending_deps[plugin_id];
        }
    }

    // Seed the work queue with every plugin that has no known dependency.
    std::queue<int> ready;
    for (const auto& entry : pending_deps) {
        if (entry.second == 0) {
            ready.push(entry.first);
        }
    }

    std::vector<int> order;
    while (!ready.empty()) {
        const int current = ready.front();
        ready.pop();
        order.push_back(current);

        const auto succ = successors.find(current);
        if (succ == successors.end()) {
            continue;  // nothing depends on this plugin
        }
        for (const int next : succ->second) {
            int& remaining = pending_deps[next];
            --remaining;
            if (remaining == 0) {
                ready.push(next);
            }
        }
    }

    // Anything left unordered is part of (or behind) a cycle.
    if (order.size() != plugins_.size()) {
        throw std::runtime_error("Circular dependency detected");
    }
    return order;
}
int PluginLoader::validate_dependencies() const
{
int error_count = 0;
// 构建名称到ID的映射
std::unordered_map<std::string, int> name_to_id;
for (const auto& [id, plugin] : plugins_) {
name_to_id[plugin.name] = id;
}
// 检查1缺失依赖deps 引用的插件未加载)
for (const auto& [id, plugin] : plugins_) {
for (const auto& dep_name : plugin.dependencies) {
if (name_to_id.find(dep_name) == name_to_id.end()) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] Plugin '%s': dependency '%s' not found (plugin not loaded)",
plugin.name.c_str(), dep_name.c_str());
}
error_count++;
}
}
}
// 检查2循环依赖拓扑排序失败
try {
topological_sort();
} catch (const std::runtime_error&) {
if (host_api_) {
host_api_->log(DSTALK_LOG_ERROR,
"[plugin_loader] Circular dependency detected among loaded plugins");
}
error_count++;
}
return error_count > 0 ? -1 : 0;
}
int PluginLoader::initialize_all(const dstalk_host_api_t* host_api)
{
if (!host_api) return -1;
host_api_ = host_api;
// 依赖合法性校验log 错误但不 crash继续初始化流程
if (validate_dependencies() != 0) {
host_api->log(DSTALK_LOG_WARN,
"[plugin_loader] Dependency validation failed; initialization may be incomplete");
}
try {
std::vector<int> order = topological_sort();
std::unordered_set<std::string> failed_names;
int failed_count = 0;
for (int id : order) {
auto it = plugins_.find(id);
if (it == plugins_.end()) continue;
PluginInfo& plugin = it->second;
if (plugin.initialized) continue;
// 检查依赖是否已失败
bool dep_unavailable = false;
for (const auto& dep_name : plugin.dependencies) {
if (failed_names.count(dep_name)) {
dep_unavailable = true;
break;
}
}
if (dep_unavailable) {
host_api->log(DSTALK_LOG_WARN, "[plugin_loader] Plugin '%s' skipped: dependency unavailable",
plugin.name.c_str());
failed_names.insert(plugin.name);
failed_count++;
continue;
}
if (plugin.info->on_init) {
int result;
try {
result = plugin.info->on_init(host_api);
} catch (const std::exception& e) {
host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init threw: %s",
plugin.name.c_str(), e.what());
failed_names.insert(plugin.name);
failed_count++;
continue;
} catch (...) {
host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init threw unknown exception",
plugin.name.c_str());
failed_names.insert(plugin.name);
failed_count++;
continue;
}
if (result != 0) {
host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init failed (code %d)",
plugin.name.c_str(), result);
failed_names.insert(plugin.name);
failed_count++;
continue;
}
}
plugin.initialized = true;
}
return failed_count;
} catch (const std::runtime_error&) {
// 循环依赖
return -1;
} catch (const std::exception&) {
return -1;
}
}
/// Initializes, in dependency order, every registered plugin that is not yet
/// marked initialized, stopping at the first failure.
///
/// Plugins initialized earlier in the same call stay initialized when a later
/// one fails. Every failure path logs its reason when host_api is non-null,
/// using the same messages as initialize_all().
///
/// @param host_api  Host API passed to each plugin's on_init and stored in
///                  host_api_. It is not null-checked before being forwarded.
/// @return Number of plugins newly marked initialized by this call, or -1 if
///         an on_init throws or returns non-zero, or if a std::exception
///         (e.g. from topological_sort()) interrupts the run.
int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api)
{
    host_api_ = host_api;
    try {
        const std::vector<int> order = topological_sort();
        int count = 0;
        for (const int id : order) {
            auto it = plugins_.find(id);
            if (it == plugins_.end()) continue;
            PluginInfo& plugin = it->second;
            if (plugin.initialized) continue;

            // Plugins without an on_init callback are simply marked initialized.
            if (plugin.info->on_init) {
                int result = 0;
                try {
                    result = plugin.info->on_init(host_api);
                } catch (const std::exception& e) {
                    if (host_api) host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init threw: %s",
                        plugin.name.c_str(), e.what());
                    return -1;
                } catch (...) {
                    if (host_api) host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init threw unknown exception",
                        plugin.name.c_str());
                    return -1;
                }
                if (result != 0) {
                    // Report the failing plugin and its code instead of
                    // returning -1 with no trace in the log.
                    if (host_api) host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' init failed (code %d)",
                        plugin.name.c_str(), result);
                    return -1;
                }
            }
            plugin.initialized = true;
            ++count;
        }
        return count;
    } catch (const std::exception& e) {
        // Reached e.g. when topological_sort() reports a circular dependency.
        if (host_api) host_api->log(DSTALK_LOG_ERROR, "[plugin_loader] initialize_pending aborted: %s",
            e.what());
        return -1;
    }
}
void PluginLoader::shutdown_all()
{
// 按逆序关闭
std::vector<int> order;
try {
order = topological_sort();
std::reverse(order.begin(), order.end());
} catch (...) {
// 如果排序失败,按任意顺序关闭
for (const auto& [id, _] : plugins_) {
order.push_back(id);
}
}
for (int id : order) {
auto it = plugins_.find(id);
if (it == plugins_.end()) continue;
PluginInfo& plugin = it->second;
if (plugin.initialized && plugin.info->on_shutdown) {
try {
plugin.info->on_shutdown();
} catch (const std::exception& e) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' shutdown threw: %s",
plugin.name.c_str(), e.what());
} catch (...) {
if (host_api_) host_api_->log(DSTALK_LOG_ERROR, "[plugin_loader] Plugin '%s' shutdown threw unknown exception",
plugin.name.c_str());
}
}
plugin.initialized = false;
}
// 释放所有 DLL 句柄
for (auto& [id, plugin] : plugins_) {
if (plugin.handle) {
#ifdef _WIN32
FreeLibrary((HMODULE)plugin.handle);
#else
dlclose(plugin.handle);
#endif
plugin.handle = nullptr;
}
}
plugins_.clear();
}
/// Looks up a registered plugin by id.
///
/// @param plugin_id  Id previously returned by load_plugin().
/// @return Pointer to the record stored in plugins_, or nullptr when no
///         plugin with that id is registered.
const PluginInfo* PluginLoader::get_plugin(int plugin_id) const
{
    const auto found = plugins_.find(plugin_id);
    return (found != plugins_.end()) ? &found->second : nullptr;
}
} // namespace dstalk