diff --git a/CMakeLists.txt b/CMakeLists.txt index 63994ff..e28a186 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ option(DSTALK_BUILD_TESTS "Build dstalk tests" ON) add_subdirectory(dstalk-core) add_subdirectory(dstalk-cli) +add_subdirectory(plugins) if(DSTALK_BUILD_GUI) add_subdirectory(dstalk-gui) diff --git a/README.md b/README.md index 7997a8f..27d0364 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ dstalk 是一款 AI 编程助手命令行工具。通过调用 DeepSeek V4 大模型(兼容 OpenAI 和 Anthropic API),在终端里用自然语言完成代码编写、重构、调试和文件操作。功能对标 Claude Code、OpenCode、KiloCode。 -核心设计为 **CDLL + 多前端解耦**: +核心设计为 **插件化 CDLL + 多前端解耦**: ```text ┌───────────────────────────────────────────────────────────┐ @@ -26,20 +26,41 @@ dstalk 是一款 AI 编程助手命令行工具。通过调用 DeepSeek V4 大 └──────────────────────────┼─────────────────────────────────┘ │ ┌──────────────────────────▼─────────────────────────────────┐ -│ 核心层 (dstalk-core.dll) │ -│ ┌────────────┐ ┌────────────┐ ┌──────────────────────┐ │ -│ │ 网络通讯 │ │ 文件读写 │ │ AI 接口适配 │ │ -│ │ Boost.Beast│ │ C++ 标准库 │ │ DeepSeek / OpenAI │ │ -│ │ + OpenSSL │ │ │ │ / Anthropic │ │ -│ └────────────┘ └────────────┘ └──────────────────────┘ │ +│ 核心层 (dstalk-core.dll) — 插件宿主 │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Host: 插件加载 · 服务注册 · 事件总线 · 配置管理 │ │ +│ └──────────────────────────────────────────────────────┘ │ +│ │ 服务查询 │ +│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ +│ │ deepseek │ │ anthropic│ │ network │ │ lsp │ │ +│ │ (ai) │ │ (ai) │ │ (http) │ │ 客户端 │ │ +│ └──────────┘ └──────────┘ └──────────┘ └──────────────┘ │ +│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ +│ │ session │ │ context │ │ file-io │ │ tools │ │ +│ └──────────┘ └──────────┘ └──────────┘ └──────────────┘ │ └─────────────────────────────────────────────────────────────┘ ``` -- **`dstalk-core`** —— C11/C++20 高性能核心 DLL,负责网络通信、AI 接口调用、文件 I/O。 +- **`dstalk-core`** —— C11/C++20 插件化宿主 DLL,负责插件加载、服务注册/查询、事件总线、配置管理。 - **`dstalk-cli`** —— 命令行前端,ANSI 转义码实现,调用 `dstalk.dll`。 - **`dstalk-gui`** —— 图形化前端,SDL3 跨平台窗口,调用 `dstalk.dll`。 +- **`plugins/`** —— 9 个功能插件,编译为独立 DLL,通过 C ABI 动态注册服务。 -核心与界面完全解耦,可以轻松编写自己的前端,或把 AI 能力嵌入到现有工具中。 +核心与界面完全解耦,可以轻松编写自己的前端,或把 AI 能力嵌入到现有工具中。所有功能通过插件实现,插件只需引用 `dstalk.dll` 即可。 + +--- + +## 核心功能 + +| 功能 | 状态 | 说明 | +|------|------|------| +| **多后端 AI 支持** | 已完成 | DeepSeek V4 和 Anthropic Claude 通过插件独立加载,`config.toml` 中 `ai.provider` 一键切换 | +| **流式输出** | 已完成 | SSE 流式响应,终端逐字打印 AI 思考过程 | +| **多轮会话** | 已完成 | 上下文窗口连续对话,支持 `/clear` 清空、`/save` `/load` 持久化 | +| **文件读写工具** | 已完成 | 内置 `/file` 命令集,支持列目录、查看、读取、写入文件 | +| **LSP 集成** | 已完成 | 完整 LSP 客户端(子进程管理、JSON-RPC 2.0),支持诊断、悬停、补全 | +| **插件系统** | 已完成 | 9 个功能插件,拓扑排序依赖管理,DLL 动态加载,服务注册/查询 | +| **GUI 前端** | 已完成 | SDL3 跨平台图形界面,流式输出、会话管理、输入历史、剪贴板 | --- @@ -108,7 +129,7 @@ build/dstalk-cli/dstalk-cli.exe # 命令行模式 ```text $ dstalk-cli - dstalk v0.1.0 | 模型: deepseek-v4 | /help 查看帮助 + dstalk v0.1.0 | 模型: deepseek-v4-pro | /help 查看帮助 > 帮我写一个读取 CSV 并计算平均值的 C 程序 @@ -170,24 +191,43 @@ $ dstalk-cli ```text dstalk/ ├── deps/ -│ └── conanfile.txt # Conan2 依赖声明 -├── dstalk-core/ # 核心 DLL +│ └── conanfile.txt # Conan2 依赖声明 (Boost, OpenSSL, SDL3) +├── dstalk-core/ # 核心 DLL — 插件宿主 │ ├── include/dstalk/ -│ │ └── dstalk_api.h # 公开 C API 头文件 +│ │ ├── dstalk_host.h # 公开 API: 宏定义、宿主API、插件生命周期 +│ │ ├── dstalk_services.h # 服务接口 vtable 定义 (AI/Session/Context/HTTP/FileIO/Config/Tools/LSP) +│ │ ├── dstalk_types.h # 共享类型: 消息、结果、事件、日志等级 +│ │ └── dstalk_lsp.h # LSP 便捷函数 (委托给 lsp 插件) │ ├── src/ -│ │ ├── api.cpp # API 实现 -│ │ ├── net/ # 网络通信 (HTTP/HTTPS) -│ │ ├── ai/ # AI 接口适配 -│ │ └── file/ # 文件读写 +│ │ ├── host.cpp # 宿主: 初始化、服务查询、LSP 便捷函数 +│ │ ├── config_store.cpp/.hpp # 配置管理 (TOML 解析) +│ │ ├── event_bus.cpp/.hpp # 事件总线 (发布/订阅) +│ │ ├── service_registry.cpp/.hpp # 服务注册表 (名称→vtable) +│ │ ├── plugin_loader.cpp/.hpp # 插件加载器 (DLL 加载、拓扑排序、依赖管理) +│ │ └── boost_json.cpp # Boost.JSON 编译单元 │ └── CMakeLists.txt +├── plugins/ # 功能插件 (每个编译为独立 DLL) +│ ├── deepseek/ # DeepSeek AI (服务名: ai.deepseek) +│ ├── anthropic/ # Anthropic Claude (服务名: ai.anthropic) +│ ├── network/ # HTTP/HTTPS 客户端 (服务名: http) +│ ├── session/ # 会话管理 (服务名: session) +│ ├── context/ # 上下文/Token 管理 (服务名: context) +│ ├── file-io/ # 文件读写 (服务名: file_io) +│ ├── tools/ # 工具注册/执行 (服务名: tools) +│ ├── lsp/ # LSP 客户端 (服务名: lsp) +│ ├── config/ # 配置服务 (服务名: config) +│ └── CMakeLists.txt # 插件构建 (按依赖顺序) ├── dstalk-cli/ # 命令行前端 (ANSI) │ ├── src/main.cpp │ └── CMakeLists.txt ├── dstalk-gui/ # 图形化前端 (SDL3) │ ├── src/main.cpp │ └── CMakeLists.txt -├── tests/ # 单元测试 -│ └── CMakeLists.txt +├── examples/ # 示例代码 +│ └── example_plugin/ +│ └── example_plugin.cpp # 插件开发示例 +├── tests/ # 集成测试 +│ └── smoke_test.cpp ├── CMakeLists.txt # 根 CMake └── README.md ``` @@ -196,47 +236,79 @@ dstalk/ ## 公开 API -头文件: [dstalk-core/include/dstalk/dstalk_api.h](dstalk-core/include/dstalk/dstalk_api.h) +头文件: +- [dstalk_host.h](dstalk-core/include/dstalk/dstalk_host.h) — 宿主 API、插件生命周期 +- [dstalk_services.h](dstalk-core/include/dstalk/dstalk_services.h) — 服务接口 vtable 定义 +- [dstalk_types.h](dstalk-core/include/dstalk/dstalk_types.h) — 共享类型 +- [dstalk_lsp.h](dstalk-core/include/dstalk/dstalk_lsp.h) — LSP 便捷函数 ```c -/* 初始化与销毁 */ +/* 宿主生命周期 */ int dstalk_init(const char* config_path); -void dstalk_destroy(void); +void dstalk_shutdown(void); -/* AI 对话 */ -int dstalk_chat(const char* input, char** output); -void dstalk_free_string(char* str); +/* 插件管理 */ +int dstalk_plugin_load(const char* path); +int dstalk_plugin_unload(int plugin_id); +int dstalk_plugin_list(char** output_json); -/* 文件操作 */ -int dstalk_file_read(const char* path, char** content); -int dstalk_file_write(const char* path, const char* content); +/* 服务查询 —— 通过名称获取插件注册的 vtable */ +void* dstalk_service_query(const char* service_name, int min_version); + +/* 事件总线 */ +int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata); +int dstalk_event_emit(int event_type, const void* data); +void dstalk_event_unsubscribe(int subscription_id); + +/* 配置 */ +const char* dstalk_config_get(const char* key); +int dstalk_config_set(const char* key, const char* value); + +/* 内存管理 */ +void* dstalk_alloc(size_t size); +void dstalk_free(void* ptr); +char* dstalk_strdup(const char* s); + +/* LSP 便捷函数 (委托给 lsp 插件) */ +int dstalk_lsp_start(const char* server_cmd, const char* language); +void dstalk_lsp_stop(void); +int dstalk_lsp_open(const char* uri, const char* content, const char* language_id); +int dstalk_lsp_close(const char* uri); +int dstalk_lsp_diagnostics(const char* uri, char** output); +int dstalk_lsp_hover(const char* uri, int line, int character, char** output); +int dstalk_lsp_completion(const char* uri, int line, int character, char** output); ``` **调用约定:** - 所有字符串均为 UTF-8 编码 -- `dstalk_chat` / `dstalk_file_read` 分配的内存由调用方通过 `dstalk_free_string` 释放 +- 通过 `dstalk_service_query` 获取服务 vtable,再通过函数指针调用具体功能 +- `dstalk_free` 释放所有 API 返回的堆内存 - 返回 `0` 成功,负数表示错误码 **跨语言调用示例:** ```c -#include "dstalk/dstalk_api.h" +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" #include int main(void) { - if (dstalk_init("config.json") != 0) { + if (dstalk_init("config.toml") != 0) { fprintf(stderr, "初始化失败\n"); return 1; } - char* reply = NULL; - if (dstalk_chat("解释这段代码", &reply) == 0) { - printf("AI: %s\n", reply); - dstalk_free_string(reply); + // 查询 AI 服务 + const char* provider = dstalk_config_get("ai.provider"); + if (!provider) provider = "ai.deepseek"; + const dstalk_ai_service_t* ai = dstalk_service_query(provider, 1); + if (ai) { + ai->configure(provider, "https://api.deepseek.com/v1", "sk-xxx", + "deepseek-v4-pro", 4096, 0.7); } - dstalk_destroy(); + dstalk_shutdown(); return 0; } ``` @@ -255,16 +327,25 @@ A: 主要支持 DeepSeek V4,同时兼容 OpenAI GPT 系列和 Anthropic Claude A: CLI 适合终端/SSH/CI 环境,GUI 适合需要富文本和鼠标交互的场景。两者共享同一核心 DLL,功能一致。 **Q: 如何配置 API Key?** -A: 首次运行前,手动创建项目目录下的 `config.toml`: +A: 首次运行前,手动创建项目目录下的 `config.toml`,按需选择后端: ```toml -[api] -provider = "deepseek" -base_url = "https://api.deepseek.com/v1" -api_key = "sk-xxxxxxxx" -model = "deepseek-v4" +# 选择 AI 后端插件: ai.deepseek 或 ai.anthropic +ai.provider = "ai.deepseek" + +# DeepSeek +api.base_url = "https://api.deepseek.com/v1" +api.api_key = "sk-xxxxxxxx" +api.model = "deepseek-v4-pro" + +# Anthropic Claude (切换 ai.provider 为 "ai.anthropic" 即可) +# api.base_url = "https://api.anthropic.com/v1" +# api.api_key = "sk-ant-xxxxxxxx" +# api.model = "claude-opus-4-20250514" ``` +修改 `ai.provider` 字段即可在不同后端间切换,无需改动代码。 + --- ## 路线图 @@ -273,8 +354,9 @@ model = "deepseek-v4" |------|------| | **Phase 1** | 项目骨架、CMake 构建、DLL 导出、CLI 前端主循环 | | **Phase 2** | HTTPS 网络层、DeepSeek API 对接、基本对话 | -| **Phase 3** (当前) | 流式输出、多轮会话、文件读写工具、CLI 体验对齐 | -| **Phase 4** | SDL3 GUI 完善、插件系统、LSP 集成 | +| **Phase 3** | ~~流式输出、多轮会话、文件读写工具、CLI 体验对齐~~ | +| **Phase 4** (当前) | ~~插件化架构重构、多后端 AI、LSP 客户端、SDL3 GUI~~ | +| **Phase 5** | GUI 完善、工具调用(Function Calling)、插件生态、多语言扩展 | --- diff --git a/deps/conanfile.txt b/deps/conanfile.txt index 70b8908..baacd93 100644 --- a/deps/conanfile.txt +++ b/deps/conanfile.txt @@ -1,6 +1,7 @@ [requires] boost/1.86.0 openssl/3.4.1 +sdl/3.4.8 [options] boost/*:header_only=True diff --git a/dstalk-cli/src/main.cpp b/dstalk-cli/src/main.cpp index 3de9806..a90c991 100644 --- a/dstalk-cli/src/main.cpp +++ b/dstalk-cli/src/main.cpp @@ -1,4 +1,11 @@ +// ============================================================================ +// dstalk-cli — 命令行前端 (使用插件化架构) +// ============================================================================ +// 通过 dstalk_host.h API 初始化核心,然后查询插件服务 vtable 调用功能。 +// ============================================================================ + #include +#include #include #include #include @@ -9,13 +16,14 @@ #ifdef _WIN32 #include +#include #else #include #include #include #endif -#include "dstalk/dstalk_api.h" +#include "dstalk/dstalk_host.h" // ---- ANSI 简写 ---- #define CLR_RESET "\033[0m" @@ -26,10 +34,42 @@ #define CLR_DIM "\033[2m" #define CLR_BOLD "\033[1m" +// ---- 退出码 ---- +#define EXIT_OK 0 +#define EXIT_INIT_FAIL 1 +#define EXIT_AI_ERROR 2 +#define EXIT_SVC_UNAVAIL 3 + +// ---- 服务 vtable 指针 ---- +static const dstalk_ai_service_t* g_ai = nullptr; +static const dstalk_session_service_t* g_session = nullptr; +static const dstalk_file_io_service_t* g_file_io = nullptr; + +// ---- 运行时状态 ---- +static std::string g_current_model; +static std::atomic g_quit_requested{false}; + +// ---- Ctrl+C 信号处理 ---- +#ifdef _WIN32 +static BOOL WINAPI on_console_event(DWORD event) +{ + if (event == CTRL_C_EVENT || event == CTRL_BREAK_EVENT) { + g_quit_requested = true; + return TRUE; + } + return FALSE; +} +#else +static void on_signal(int /*sig*/) +{ + g_quit_requested = true; +} +#endif + // ---- 工具函数 ---- static void print_banner() { - std::printf("%sdstalk v0.1.0%s | %sDeepSeek V4%s | " + std::printf("%sdstalk v0.1.0%s | %sdstalk AI%s | " "%s/help%s 查看帮助 | %s/quit%s 退出\n", CLR_CYAN CLR_BOLD, CLR_RESET, CLR_GREEN, CLR_RESET, @@ -43,6 +83,8 @@ static void print_help() std::printf(" %s/help%s 显示此帮助\n", CLR_YELLOW, CLR_RESET); std::printf(" %s/quit%s 退出程序\n", CLR_YELLOW, CLR_RESET); std::printf(" %s/clear%s 清空当前会话上下文\n", CLR_YELLOW, CLR_RESET); + std::printf(" %s/context%s 显示当前 Token 数和消息条数\n", CLR_YELLOW, CLR_RESET); + std::printf(" %s/status%s 显示当前运行状态(脱敏)\n", CLR_YELLOW, CLR_RESET); std::printf(" %s/model %s 切换模型\n", CLR_YELLOW, CLR_RESET); std::printf(" %s/file list [path]%s 列出目录内容\n", CLR_YELLOW, CLR_RESET); std::printf(" %s/file show %s 查看文件内容\n", CLR_YELLOW, CLR_RESET); @@ -56,12 +98,16 @@ static void print_help() static void print_file(const char* path) { while (*path == ' ') path++; + if (!g_file_io) { + std::printf(CLR_RED "[ERROR] file_io 服务不可用\n" CLR_RESET); + return; + } char* content = nullptr; - if (dstalk_file_read(path, &content) == 0 && content) { + if (g_file_io->read(path, &content) == 0 && content) { std::printf("%s--- %s ---%s\n", CLR_DIM, path, CLR_RESET); std::printf("%s\n", content); std::printf(CLR_DIM "--- EOF ---\n" CLR_RESET); - dstalk_free_string(content); + dstalk_free(content); } else { std::printf(CLR_RED "[ERROR] 无法读取: %s\n" CLR_RESET, path); } @@ -104,11 +150,11 @@ static void handle_command(const char* line) { if (!line || line[0] != '/') return; - // /quit + // /quit —— 设置退出标志,让控制流自然回到 main 末尾 if (std::strcmp(line, "/quit") == 0 || std::strcmp(line, "/q") == 0) { - dstalk_destroy(); - std::printf(CLR_DIM "再见!\n" CLR_RESET); - std::exit(0); + g_quit_requested = true; + std::printf("再见!\n"); + return; } // /help @@ -119,17 +165,60 @@ static void handle_command(const char* line) // /clear if (std::strcmp(line, "/clear") == 0) { - dstalk_session_clear(); + if (g_session) g_session->clear(); std::printf(CLR_GREEN "[OK] 会话已清空\n" CLR_RESET); return; } + // /context + if (std::strcmp(line, "/context") == 0) { + if (g_session) { + int count = 0; + g_session->history(&count); + int tokens = g_session->token_count(); + std::printf(CLR_DIM "消息条数: " CLR_RESET "%d | " + CLR_DIM "Token 估算: " CLR_RESET "%d\n", + count, tokens); + } + return; + } + + // /status —— 脱敏显示当前运行状态 + if (std::strcmp(line, "/status") == 0) { + const char* provider = dstalk_config_get("ai.provider"); + if (!provider) provider = "ai.deepseek"; + const char* base_url = dstalk_config_get("api.base_url"); + if (!base_url) base_url = "https://api.deepseek.com/v1"; + const char* api_key = dstalk_config_get("api.api_key"); + + std::printf(" 模型: %s\n", g_current_model.empty() ? "(未设置)" : g_current_model.c_str()); + std::printf(" base_url: %s\n", base_url ? base_url : "(未设置)"); + std::printf(" api_key: %s\n", (api_key && api_key[0]) ? "已设置" : "未设置"); + std::printf(" provider: %s\n", provider); + std::printf(" AI 服务: %s\n", g_ai ? "就绪" : "不可用"); + std::printf(" Session 服务: %s\n", g_session ? "就绪" : "不可用"); + std::printf(" File IO 服务: %s\n", g_file_io ? "就绪" : "不可用"); + const dstalk_tools_service_t* tools = static_cast( + dstalk_service_query("tools", 1)); + std::printf(" Tools 服务: %s\n", tools ? "就绪" : "不可用"); + return; + } + // /model if (std::strncmp(line, "/model ", 7) == 0) { const char* model = line + 7; while (*model == ' ') model++; - dstalk_set_model(model); - std::printf(CLR_GREEN "[OK] 模型已切换: %s\n" CLR_RESET, model); + if (*model == '\0') { + std::printf(CLR_RED "[ERROR] /model 需要模型名\n" CLR_RESET); + return; + } + if (g_ai) { + g_ai->configure(nullptr, nullptr, nullptr, model, 0, 0.0); + g_current_model = model; + std::printf(CLR_GREEN "[OK] 模型已切换: %s\n" CLR_RESET, model); + } else { + std::printf(CLR_RED "[ERROR] AI 服务不可用\n" CLR_RESET); + } return; } @@ -156,7 +245,6 @@ static void handle_command(const char* line) if (std::strncmp(line, "/file write ", 12) == 0) { const char* rest = line + 12; while (*rest == ' ') rest++; - // 第一个参数是路径,后面到行尾是内容 const char* space = std::strchr(rest, ' '); if (!space) { std::printf(CLR_RED "[ERROR] 用法: /file write \n" CLR_RESET); @@ -165,7 +253,7 @@ static void handle_command(const char* line) std::string path(rest, space - rest); const char* content = space + 1; while (*content == ' ') content++; - if (dstalk_file_write(path.c_str(), content) == 0) { + if (g_file_io && g_file_io->write(path.c_str(), content) == 0) { std::printf(CLR_GREEN "[OK] 已写入: %s\n" CLR_RESET, path.c_str()); } else { std::printf(CLR_RED "[ERROR] 写入失败: %s\n" CLR_RESET, path.c_str()); @@ -177,7 +265,7 @@ static void handle_command(const char* line) if (std::strncmp(line, "/save ", 6) == 0) { const char* path = line + 6; while (*path == ' ') path++; - if (dstalk_session_save(path) == 0) { + if (g_session && g_session->save(path) == 0) { std::printf(CLR_GREEN "[OK] 会话已保存: %s\n" CLR_RESET, path); } else { std::printf(CLR_RED "[ERROR] 保存失败: %s\n" CLR_RESET, path); @@ -189,7 +277,7 @@ static void handle_command(const char* line) if (std::strncmp(line, "/load ", 6) == 0) { const char* path = line + 6; while (*path == ' ') path++; - if (dstalk_session_load(path) == 0) { + if (g_session && g_session->load(path) == 0) { std::printf(CLR_GREEN "[OK] 会话已恢复: %s\n" CLR_RESET, path); } else { std::printf(CLR_RED "[ERROR] 恢复失败: %s\n" CLR_RESET, path); @@ -200,6 +288,19 @@ static void handle_command(const char* line) std::printf(CLR_RED "未知命令: %s (输入 /help 查看帮助)\n" CLR_RESET, line); } +// ---- 流式回调 ---- +static int on_stream_token(const char* token, void* userdata) +{ + bool* first = static_cast(userdata); + if (*first) { + std::printf(CLR_GREEN); + *first = false; + } + std::printf("%s", token); + std::fflush(stdout); + return 0; +} + // ---- 主程序 ---- int main(int argc, char* argv[]) { @@ -211,17 +312,40 @@ int main(int argc, char* argv[]) SetConsoleMode(hOut, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING); #endif + // ---- C1: batch 模式检测 ---- + bool batch_mode = false; + for (int i = 1; i < argc; ++i) { + if (std::strcmp(argv[i], "--batch") == 0) { + batch_mode = true; + break; + } + } +#ifdef _WIN32 + if (!batch_mode && _isatty(_fileno(stdin)) == 0) batch_mode = true; +#else + if (!batch_mode && isatty(fileno(stdin)) == 0) batch_mode = true; +#endif + + // ---- B1: 安装 Ctrl+C 处理 ---- +#ifdef _WIN32 + SetConsoleCtrlHandler(on_console_event, TRUE); +#else + signal(SIGINT, on_signal); +#endif + // 查找配置文件 const char* config_path = nullptr; if (argc >= 2) { - config_path = argv[1]; - } else { - // 默认路径 -#ifdef _WIN32 + // 跳过 --batch 标志 + for (int i = 1; i < argc; ++i) { + if (std::strcmp(argv[i], "--batch") != 0) { + config_path = argv[i]; + break; + } + } + } + if (!config_path) { const char* default_configs[] = {"config.toml", nullptr}; -#else - const char* default_configs[] = {"config.toml", nullptr}; -#endif for (int i = 0; default_configs[i]; i++) { FILE* f = nullptr; #ifdef _WIN32 @@ -237,22 +361,64 @@ int main(int argc, char* argv[]) } } + // 初始化主机(加载配置 + 自动扫描 plugins/ 目录加载插件) if (dstalk_init(config_path) != 0) { std::fprintf(stderr, CLR_RED "[dstalk] 初始化失败\n" CLR_RESET); - return 1; + return EXIT_INIT_FAIL; } - std::printf("\n"); - print_banner(); - std::printf("\n"); + // 查询插件服务 + const char* ai_provider = dstalk_config_get("ai.provider"); + if (!ai_provider) ai_provider = "ai.deepseek"; + g_ai = static_cast(dstalk_service_query(ai_provider, 1)); + g_session = static_cast(dstalk_service_query("session", 1)); + g_file_io = static_cast(dstalk_service_query("file_io", 1)); + + if (!g_ai) { + std::fprintf(stderr, CLR_RED "[dstalk] AI 服务未找到(请检查插件目录)\n" CLR_RESET); + } + if (!g_session) { + std::fprintf(stderr, CLR_RED "[dstalk] Session 服务未找到\n" CLR_RESET); + } + + // 自动从配置加载 AI 设置 + if (g_ai) { + const char* base_url = dstalk_config_get("api.base_url"); + const char* api_key = dstalk_config_get("api.api_key"); + const char* model = dstalk_config_get("api.model"); + if (!base_url) base_url = "https://api.deepseek.com/v1"; + if (!model) model = "deepseek-v4-pro"; + g_ai->configure(ai_provider, base_url, api_key ? api_key : "", model, 4096, 0.7); + g_current_model = model; // A1: 记录当前模型名 + } + + if (!batch_mode) { + std::printf("\n"); + print_banner(); + std::printf("\n"); + } char buffer[8192]; while (true) { - std::printf(CLR_YELLOW "> " CLR_RESET); - std::fflush(stdout); + // B1: 检查退出标志 + if (g_quit_requested) break; + + // A1: 提示符带模型名(batch 模式不打印) + if (!batch_mode) { + std::printf(CLR_CYAN "[%s] " CLR_RESET CLR_YELLOW "> " CLR_RESET, + g_current_model.empty() ? "?" : g_current_model.c_str()); + std::fflush(stdout); + } if (!std::fgets(buffer, sizeof(buffer), stdin)) break; + // C3: fgets 截断检测 + if (!std::strchr(buffer, '\n') && !feof(stdin)) { + std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB,已截断。建议用文件方式:dstalk --batch < file.txt\n" CLR_RESET); + int c; + while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {} + } + // 去除末尾换行 size_t len = std::strlen(buffer); while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) { @@ -267,25 +433,36 @@ int main(int argc, char* argv[]) continue; } - // AI 对话 - std::printf(CLR_DIM "思考中..." CLR_RESET "\n"); - std::fflush(stdout); - - char* reply = nullptr; - int ret = dstalk_chat(buffer, &reply); - if (ret == 0 && reply) { - std::printf("\n%s\n\n", reply); - dstalk_free_string(reply); - } else { - std::printf(CLR_RED "[ERROR] AI 调用失败" CLR_RESET); - if (reply) { - std::printf(": %s", reply); - dstalk_free_string(reply); - } - std::printf("\n"); + // AI 对话(通过插件服务 vtable) + if (!g_ai || !g_session) { + std::printf(CLR_RED "[ERROR] AI 或 Session 服务不可用\n" CLR_RESET); + continue; } + + // 获取会话历史 + int history_count = 0; + const dstalk_message_t* history = g_session->history(&history_count); + + bool first = true; + dstalk_chat_result_t result = g_ai->chat_stream( + history, history_count, buffer, on_stream_token, &first); + + if (result.ok) { + std::printf(CLR_RESET "\n\n"); + // 将用户消息和 AI 回复添加到会话 + dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr}; + g_session->add(&user_msg); + dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json}; + g_session->add(&ai_msg); + } else { + // A3: error 路径下需 NULL 保护;当前只取 result.error,content 未涉及 + std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, + result.error ? result.error : "unknown error"); + } + g_ai->free_result(&result); } - dstalk_destroy(); - return 0; + // B2: 单一退出点,dstalk_shutdown 只在此调用 + dstalk_shutdown(); + return EXIT_OK; } diff --git a/dstalk-core/CMakeLists.txt b/dstalk-core/CMakeLists.txt index 0c94168..69ff29f 100644 --- a/dstalk-core/CMakeLists.txt +++ b/dstalk-core/CMakeLists.txt @@ -1,16 +1,17 @@ # ============================================================ -# dstalk-core — 核心 DLL -# 包含: 网络通讯 / AI接口 / 文件读写 +# dstalk-core — 核心 DLL (插件宿主) +# 包含: 插件管理 / 服务注册 / 事件总线 / 配置存储 # ============================================================ find_package(Boost REQUIRED CONFIG) find_package(OpenSSL REQUIRED CONFIG) add_library(dstalk SHARED - src/api.cpp - src/file/file_io.cpp - src/net/http_client.cpp - src/ai/deepseek_api.cpp + src/host.cpp + src/config_store.cpp + src/event_bus.cpp + src/service_registry.cpp + src/plugin_loader.cpp src/boost_json.cpp ) @@ -25,6 +26,11 @@ target_link_libraries(dstalk openssl::openssl ) +# dlopen / dlclose / dlsym on Linux and macOS +if(NOT WIN32) + target_link_libraries(dstalk PRIVATE ${CMAKE_DL_LIBS}) +endif() + # 导出 DLL 符号宏 target_compile_definitions(dstalk PRIVATE diff --git a/dstalk-core/include/dstalk/dstalk_api.h b/dstalk-core/include/dstalk/dstalk_api.h deleted file mode 100644 index 585bdd2..0000000 --- a/dstalk-core/include/dstalk/dstalk_api.h +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef DSTALK_API_H -#define DSTALK_API_H - -#ifdef __cplusplus -extern "C" { -#endif - -/* ---- DLL 导出 / 导入宏 ---- */ -#if defined(_WIN32) - #ifdef DSTALK_BUILD_DLL - #define DSTALK_API __declspec(dllexport) - #else - #define DSTALK_API __declspec(dllimport) - #endif -#else - #define DSTALK_API __attribute__((visibility("default"))) -#endif - -/* ---- 初始化和配置 ---- */ -DSTALK_API int dstalk_init(const char* config_path); -DSTALK_API void dstalk_destroy(void); - -/* 在 init 之后可修改 API 参数 (init 也会从配置文件读取) */ -DSTALK_API void dstalk_set_api_key(const char* api_key); -DSTALK_API void dstalk_set_base_url(const char* base_url); -DSTALK_API void dstalk_set_model(const char* model); - -/* ---- AI 对话 ---- */ -/* 同步对话: 发送 input,返回完整 AI 回复 (调用方通过 dstalk_free_string 释放) */ -DSTALK_API int dstalk_chat(const char* input, char** output); - -/* 流式对话: 每收到一个 token 调用回调,回调返回 0 继续,非 0 取消 */ -typedef int (*dstalk_stream_cb)(const char* token, void* userdata); -DSTALK_API int dstalk_chat_stream(const char* input, dstalk_stream_cb cb, void* userdata); - -/* 释放由 dstalk_chat / dstalk_file_read 分配的字符串 */ -DSTALK_API void dstalk_free_string(char* str); - -/* ---- 会话管理 ---- */ -DSTALK_API void dstalk_session_clear(void); /* 清空对话历史 */ -DSTALK_API int dstalk_session_save(const char* path); /* 保存会话到文件 */ -DSTALK_API int dstalk_session_load(const char* path); /* 从文件恢复会话 */ - -/* ---- 文件操作 ---- */ -DSTALK_API int dstalk_file_read(const char* path, char** content); -DSTALK_API int dstalk_file_write(const char* path, const char* content); - -#ifdef __cplusplus -} -#endif - -#endif /* DSTALK_API_H */ diff --git a/dstalk-core/include/dstalk/dstalk_host.h b/dstalk-core/include/dstalk/dstalk_host.h new file mode 100644 index 0000000..276bf61 --- /dev/null +++ b/dstalk-core/include/dstalk/dstalk_host.h @@ -0,0 +1,132 @@ +#ifndef DSTALK_HOST_H +#define DSTALK_HOST_H + +#include "dstalk_types.h" +#include "dstalk_services.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// === 平台导出宏 === +#ifndef DSTALK_API +#if defined(_WIN32) + #ifdef DSTALK_BUILD_DLL + #define DSTALK_API __declspec(dllexport) + #else + #define DSTALK_API __declspec(dllimport) + #endif +#else + #define DSTALK_API __attribute__((visibility("default"))) +#endif +#endif + +// === 插件导出宏 === +#if defined(_WIN32) + #define DSTALK_PLUGIN_EXPORT __declspec(dllexport) +#else + #define DSTALK_PLUGIN_EXPORT __attribute__((visibility("default"))) +#endif + +// === API 版本 === +#define DSTALK_API_VERSION 1 +#define DSTALK_MAX_DEPS 8 + +// === 诊断 === +typedef void (*dstalk_diag_cb)(int severity, const char* file, + int line, const char* func, const char* message); + +#define DSTALK_ERROR_RETURN(expr, retval) do { \ + if (!(expr)) { \ + dstalk_log(DSTALK_LOG_ERROR, "[%s:%d] %s: assertion '%s' failed", \ + __FILE__, __LINE__, __func__, #expr); \ + return (retval); \ + } \ +} while(0) + +DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb); + +// === 事件处理器 === +typedef void (*dstalk_event_handler_fn)(int event_type, const void* data, void* userdata); + +// === Host 提供给插件的 API 表 === +typedef struct { + // 服务注册/查询 + int (*register_service)(const char* name, int version, void* vtable); + void*(*query_service)(const char* name, int min_version); + + // 事件 + int (*event_subscribe)(int event_type, dstalk_event_handler_fn handler, void* userdata); + int (*event_emit)(int event_type, const void* data); + void (*event_unsubscribe)(int sub_id); + + // 配置 + const char* (*config_get)(const char* key); + int (*config_set)(const char* key, const char* value); + + // 日志 + void (*log)(int level, const char* fmt, ...); + + // 内存 + void* (*alloc)(size_t size); + void (*free)(void* ptr); + char* (*strdup)(const char* s); +} dstalk_host_api_t; + +// === 插件信息结构 === +typedef struct { + const char* name; // 插件名称(唯一标识) + const char* version; // 语义化版本号,如 "1.0.0" + const char* description; // 描述 + int api_version; // 必须 == DSTALK_API_VERSION + + // 依赖声明(以 NULL 结尾) + const char* dependencies[DSTALK_MAX_DEPS]; + + // 生命周期回调 + int (*on_init)(const dstalk_host_api_t* host); + void (*on_shutdown)(void); + + // 事件处理(可选) + void (*on_event)(int event_type, const void* data); +} dstalk_plugin_info_t; + +// === 插件入口函数 === +typedef dstalk_plugin_info_t* (*dstalk_plugin_init_fn)(void); + +// === Host 公共 API === + +// 初始化/销毁 +DSTALK_API int dstalk_init(const char* config_path); +DSTALK_API void dstalk_shutdown(void); + +// 插件管理 +DSTALK_API int dstalk_plugin_load(const char* path); +DSTALK_API int dstalk_plugin_unload(int plugin_id); +DSTALK_API int dstalk_plugin_list(char** output_json); + +// 服务查询 +DSTALK_API void* dstalk_service_query(const char* service_name, int min_version); + +// 事件系统 +DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata); +DSTALK_API int dstalk_event_emit(int event_type, const void* data); +DSTALK_API void dstalk_event_unsubscribe(int subscription_id); + +// 配置 +DSTALK_API const char* dstalk_config_get(const char* key); +DSTALK_API int dstalk_config_set(const char* key, const char* value); + +// 日志 +DSTALK_API void dstalk_log(int level, const char* fmt, ...); + +// 内存 +DSTALK_API void* dstalk_alloc(size_t size); +DSTALK_API void dstalk_free(void* ptr); +DSTALK_API char* dstalk_strdup(const char* s); + +#ifdef __cplusplus +} +#endif + +#endif // DSTALK_HOST_H diff --git a/dstalk-core/include/dstalk/dstalk_lsp.h b/dstalk-core/include/dstalk/dstalk_lsp.h new file mode 100644 index 0000000..4a5a31a --- /dev/null +++ b/dstalk-core/include/dstalk/dstalk_lsp.h @@ -0,0 +1,91 @@ +#ifndef DSTALK_LSP_H +#define DSTALK_LSP_H + +#include "dstalk_host.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* ---- LSP 服务器生命周期 ---- */ + +/* + * 启动语言服务器进程 + * server_cmd: 命令字符串,例如 "clangd" 或 "pyright --stdio" 或完整路径 + * language: 语言标识,例如 "c", "cpp", "python", "javascript", "rust" + * returns: 0 成功, -1 失败 + */ +DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language); + +/* + * 停止语言服务器 + * 发送 shutdown 请求,然后发送 exit 通知 + * 关闭管道,终止子进程 + */ +DSTALK_API void dstalk_lsp_stop(void); + +/* ---- 文档管理 ---- */ + +/* + * 在语言服务器中打开一个文档 + * uri: 文件 URI,例如 "file:///path/to/file.c" + * content: 文件内容文本 + * language_id: 语言 ID,例如 "c", "cpp", "python", "javascript" + * returns: 0 成功, -1 失败 + */ +DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, + const char* language_id); + +/* + * 关闭语言服务器中的文档 + * uri: 文件 URI + * returns: 0 成功, -1 失败 + */ +DSTALK_API int dstalk_lsp_close(const char* uri); + +/* ---- 查询操作 ---- */ + +/* + * 获取诊断信息 (编译错误、警告等) + * uri: 文件 URI + * output: 输出参数,JSON 格式的诊断列表 (调用方通过 dstalk_free 释放) + * returns: 0 成功, -1 失败 + * + * JSON 输出格式示例: + * [ + * { + * "range": { "start": {"line":0,"character":0}, "end":{"line":0,"character":5} }, + * "severity": 1, + * "message": "error message" + * } + * ] + */ +DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output); + +/* + * 获取悬停信息 (类型、文档等) + * uri: 文件 URI + * line: 行号 (0-based) + * character: 列号 (0-based, UTF-16 code units) + * output: 输出参数,JSON 格式的悬停信息 (调用方通过 dstalk_free 释放) + * returns: 0 成功, -1 失败 + */ +DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, + char** output); + +/* + * 获取代码补全建议 + * uri: 文件 URI + * line: 行号 (0-based) + * character: 列号 (0-based, UTF-16 code units) + * output: 输出参数,JSON 格式的补全列表 (调用方通过 dstalk_free 释放) + * returns: 0 成功, -1 失败 + */ +DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, + char** output); + +#ifdef __cplusplus +} +#endif + +#endif /* DSTALK_LSP_H */ diff --git a/dstalk-core/include/dstalk/dstalk_services.h b/dstalk-core/include/dstalk/dstalk_services.h new file mode 100644 index 0000000..9702f5d --- /dev/null +++ b/dstalk-core/include/dstalk/dstalk_services.h @@ -0,0 +1,97 @@ +#ifndef DSTALK_SERVICES_H +#define DSTALK_SERVICES_H + +#include "dstalk_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// === AI 服务 vtable (实际服务名由插件注册: "ai.deepseek" / "ai.anthropic") === +typedef struct { + int (*configure)(const char* provider, const char* base_url, + const char* api_key, const char* model, + int max_tokens, double temperature); + dstalk_chat_result_t (*chat)( + const dstalk_message_t* history, int history_len, + const char* user_input, + const char* tools_json); + dstalk_chat_result_t (*chat_stream)( + const dstalk_message_t* history, int history_len, + const char* user_input, + dstalk_stream_cb cb, void* userdata); + void (*free_result)(dstalk_chat_result_t* result); +} dstalk_ai_service_t; + +// === Session 服务 (service name: "session") === +typedef struct { + void (*add)(const dstalk_message_t* msg); + void (*clear)(void); + int (*save)(const char* path); + int (*load)(const char* path); + const dstalk_message_t* (*history)(int* out_count); + int (*token_count)(void); +} dstalk_session_service_t; + +// === Context 服务 (service name: "context") === +typedef struct { + size_t (*count_tokens)(const dstalk_message_t* msgs, int count); + int (*trim)(const dstalk_message_t* in, int in_count, + dstalk_message_t** out, int* out_count, + size_t max_tokens); + void (*set_max_tokens)(size_t max); +} dstalk_context_service_t; + +// === HTTP 服务 (service name: "http") === +typedef struct { + int (*post_json)(const char* host, const char* port, + const char* target, const char* body, + const char* headers_json, + char** response_body, int* status_code); + int (*post_stream)(const char* host, const char* port, + const char* target, const char* body, + const char* headers_json, + dstalk_stream_cb cb, void* userdata, + char** response_body, int* status_code); +} dstalk_http_service_t; + +// === File IO 服务 (service name: "file_io") === +typedef struct { + int (*read)(const char* path, char** content); + int (*write)(const char* path, const char* content); +} dstalk_file_io_service_t; + +// === Config 服务 (service name: "config") === +typedef struct { + const char* (*get)(const char* key); + int (*set)(const char* key, const char* value); + int (*load_file)(const char* path); +} dstalk_config_service_t; + +// === Tools 服务 (service name: "tools") === +typedef char* (*dstalk_tool_handler_fn)(const char* args_json); +typedef struct { + int (*register_tool)(const char* name, const char* desc, + const char* params_schema, + dstalk_tool_handler_fn handler); + void (*unregister_tool)(const char* name); + char* (*get_tools_json)(void); + char* (*execute)(const char* name, const char* args_json); +} dstalk_tools_service_t; + +// === LSP 服务 (service name: "lsp") === +typedef struct { + int (*start)(const char* server_cmd, const char* language); + void (*stop)(void); + int (*open_document)(const char* uri, const char* content, const char* lang_id); + int (*close_document)(const char* uri); + int (*get_diagnostics)(const char* uri, char** json_out); + int (*get_hover)(const char* uri, int line, int col, char** json_out); + int (*get_completion)(const char* uri, int line, int col, char** json_out); +} dstalk_lsp_service_t; + +#ifdef __cplusplus +} +#endif + +#endif // DSTALK_SERVICES_H diff --git a/dstalk-core/include/dstalk/dstalk_types.h b/dstalk-core/include/dstalk/dstalk_types.h new file mode 100644 index 0000000..0dc52a1 --- /dev/null +++ b/dstalk-core/include/dstalk/dstalk_types.h @@ -0,0 +1,52 @@ +#ifndef DSTALK_TYPES_H +#define DSTALK_TYPES_H + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// 消息结构(跨插件共享) +typedef struct { + const char* role; // "user", "assistant", "system", "tool" + const char* content; // 消息内容 + const char* tool_call_id; // tool 响应时必填 + const char* tool_calls_json;// assistant 返回的工具调用(JSON 数组) +} dstalk_message_t; + +// 聊天结果 +typedef struct { + int ok; + const char* content; // dstalk_strdup 分配,调用方 dstalk_free + const char* error; // dstalk_strdup 分配 + int http_status; + const char* tool_calls_json;// dstalk_strdup 分配 +} dstalk_chat_result_t; + +// 流式回调 +typedef int (*dstalk_stream_cb)(const char* token, void* userdata); + +// 事件类型 +enum { + DSTALK_EVENT_MESSAGE = 1, // data = dstalk_message_t* + DSTALK_EVENT_SESSION_CLEAR, + DSTALK_EVENT_CONFIG_CHANGED, + DSTALK_EVENT_PLUGIN_LOADED, // data = plugin info JSON string + DSTALK_EVENT_PLUGIN_UNLOADED, + DSTALK_EVENT_CUSTOM = 1000, // 插件自定义事件起始值 +}; + +// 日志级别 +enum { + DSTALK_LOG_DEBUG = 0, + DSTALK_LOG_INFO = 1, + DSTALK_LOG_WARN = 2, + DSTALK_LOG_ERROR = 3, +}; + +#ifdef __cplusplus +} +#endif + +#endif // DSTALK_TYPES_H diff --git a/dstalk-core/src/ai/deepseek_api.cpp b/dstalk-core/src/ai/deepseek_api.cpp deleted file mode 100644 index 3925556..0000000 --- a/dstalk-core/src/ai/deepseek_api.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "ai/deepseek_api.hpp" -#include "net/http_client.hpp" - -#include -#include -#include - -namespace json = boost::json; - -namespace dstalk { -namespace ai { - -// ---- JSON 构造 ---- -static std::string build_request_json( - const ApiConfig& cfg, - const std::vector& history, - const std::string& user_input, - bool stream) -{ - json::object root; - root["model"] = cfg.model; - root["max_tokens"] = cfg.max_tokens; - root["temperature"] = cfg.temperature; - root["stream"] = stream; - - json::array msgs; - for (const auto& m : history) { - json::object obj; - obj["role"] = m.role; - obj["content"] = m.content; - msgs.push_back(obj); - } - // 追加当前用户输入 - { - json::object obj; - obj["role"] = "user"; - obj["content"] = user_input; - msgs.push_back(obj); - } - root["messages"] = msgs; - - return json::serialize(root); -} - -// ---- JSON 响应解析 ---- -static ChatResult parse_response(const std::string& body, int http_status) -{ - ChatResult r; - r.http_status = http_status; - - if (http_status < 200 || http_status >= 300) { - r.ok = false; - // 尝试提取错误信息 - try { - auto jv = json::parse(body); - auto obj = jv.as_object(); - if (obj.contains("error")) { - auto err = obj["error"].as_object(); - r.error = json::value_to(err["message"]); - } - } catch (...) { - r.error = "HTTP " + std::to_string(http_status); - } - return r; - } - - try { - auto jv = json::parse(body); - auto obj = jv.as_object(); - auto choices = obj["choices"].as_array(); - if (!choices.empty()) { - auto msg = choices[0].as_object()["message"].as_object(); - r.content = json::value_to(msg["content"]); - r.ok = true; - } else { - r.ok = false; - r.error = "empty response"; - } - } catch (std::exception& e) { - r.ok = false; - r.error = std::string("json parse: ") + e.what(); - } - return r; -} - -// ---- SSE 行解析 ---- -static bool parse_sse_line(const std::string& line, std::string& token_out) -{ - // SSE 格式: "data: " 或 "data: [DONE]" - if (line.rfind("data: ", 0) != 0) return false; - std::string data = line.substr(6); - if (data == "[DONE]") { - token_out.clear(); - return true; // 流结束信号 - } - - try { - auto jv = json::parse(data); - auto obj = jv.as_object(); - auto choices = obj["choices"].as_array(); - if (!choices.empty()) { - auto delta = choices[0].as_object()["delta"].as_object(); - if (delta.contains("content")) { - token_out = json::value_to(delta["content"]); - return true; - } - } - } catch (...) { - // 忽略解析失败的行 - } - return false; -} - -// ---- Impl ---- -struct DeepSeekClient::Impl { - net::HttpClient http; - ApiConfig config; - - std::string extract_host_port(std::string& target) { - // base_url 例如 "https://api.deepseek.com/v1" - // 提取 host: "api.deepseek.com" - // 提取 target 前缀: "/v1" - std::string url = config.base_url; - if (url.rfind("https://", 0) == 0) url = url.substr(8); - else if (url.rfind("http://", 0) == 0) url = url.substr(7); - - size_t slash = url.find('/'); - if (slash != std::string::npos) { - target = url.substr(slash); - return url.substr(0, slash); - } - target = "/"; - return url; - } -}; - -DeepSeekClient::DeepSeekClient() : impl_(new Impl{}) {} -DeepSeekClient::~DeepSeekClient() { delete impl_; } - -void DeepSeekClient::configure(const ApiConfig& config) -{ - impl_->config = config; -} - -ChatResult DeepSeekClient::chat( - const std::vector& history, - const std::string& user_input) -{ - std::string target; - std::string host = impl_->extract_host_port(target); - std::string target_path = target + "/chat/completions"; - - std::string body = build_request_json( - impl_->config, history, user_input, false); - - std::unordered_map headers; - headers["Authorization"] = "Bearer " + impl_->config.api_key; - - auto resp = impl_->http.post_json(host, "443", target_path, body, headers); - return parse_response(resp.body, resp.status_code); -} - -ChatResult DeepSeekClient::chat_stream( - const std::vector& history, - const std::string& user_input, - bool (*on_token)(const std::string& token, void* userdata), - void* userdata) -{ - std::string target; - std::string host = impl_->extract_host_port(target); - std::string target_path = target + "/chat/completions"; - - std::string body = build_request_json( - impl_->config, history, user_input, true); - - std::unordered_map headers; - headers["Authorization"] = "Bearer " + impl_->config.api_key; - - ChatResult result; - - auto resp = impl_->http.post_stream(host, "443", target_path, body, headers, - [&](const std::string& line) -> bool { - if (line.empty()) return true; - std::string token; - if (!parse_sse_line(line, token)) return true; - if (token.empty()) return false; // [DONE] - result.content += token; - return on_token ? on_token(token, userdata) : true; - }); - - result.http_status = resp.status_code; - - // 检查传输层错误或非 2xx 状态 - if (resp.status_code < 200 || resp.status_code >= 300) { - result.ok = false; - // 尝试从响应 body 提取错误信息(与 parse_response 等同逻辑) - try { - auto jv = json::parse(resp.body); - auto obj = jv.as_object(); - if (obj.contains("error")) { - auto err = obj["error"].as_object(); - result.error = json::value_to(err["message"]); - } - } catch (...) { - } - if (result.error.empty()) { - if (resp.status_code <= 0) { - result.error = "transport error"; - } else { - result.error = "HTTP " + std::to_string(resp.status_code); - } - } - return result; - } - - if (result.content.empty()) { - result.ok = false; - result.error = "no content received"; - } else { - result.ok = true; - } - return result; -} - -} // namespace ai -} // namespace dstalk diff --git a/dstalk-core/src/ai/deepseek_api.hpp b/dstalk-core/src/ai/deepseek_api.hpp deleted file mode 100644 index e502a07..0000000 --- a/dstalk-core/src/ai/deepseek_api.hpp +++ /dev/null @@ -1,65 +0,0 @@ -#pragma once - -#include -#include - -namespace dstalk { -namespace ai { - -// 单条消息 -struct Message { - std::string role; // "system", "user", "assistant" - std::string content; -}; - -// API 配置 -struct ApiConfig { - std::string provider; // 默认 "deepseek" - std::string base_url; // 默认 "https://api.deepseek.com/v1" - std::string api_key; - std::string model; // 默认 "deepseek-chat" - int max_tokens = 4096; - double temperature = 0.7; -}; - -// 对话补全结果 -struct ChatResult { - bool ok = false; - std::string content; - std::string error; - int http_status = 0; -}; - -/* - * DeepSeek API 客户端 (OpenAI 兼容) - * 内部使用 HttpClient 进行 HTTPS 通信 - */ -class DeepSeekClient { -public: - DeepSeekClient(); - ~DeepSeekClient(); - - // 配置 API 参数 - void configure(const ApiConfig& config); - - // 同步对话 (发送全部历史 + 新消息, 返回完整回复) - ChatResult chat( - const std::vector& history, - const std::string& user_input - ); - - // 流式对话, 每收到一个 token 调用 on_token, 返回 true 继续 / false 取消 - ChatResult chat_stream( - const std::vector& history, - const std::string& user_input, - bool (*on_token)(const std::string& token, void* userdata), - void* userdata = nullptr - ); - -private: - struct Impl; - Impl* impl_; -}; - -} // namespace ai -} // namespace dstalk diff --git a/dstalk-core/src/api.cpp b/dstalk-core/src/api.cpp deleted file mode 100644 index 6d352ce..0000000 --- a/dstalk-core/src/api.cpp +++ /dev/null @@ -1,306 +0,0 @@ -#include "dstalk/dstalk_api.h" -#include "ai/deepseek_api.hpp" -#include "file/file_io.hpp" -#include "net/http_client.hpp" - -#include - -#include -#include -#include -#include -#include -#include - -namespace json = boost::json; - -namespace { - -bool g_initialized = false; -dstalk::ai::DeepSeekClient g_ai; -dstalk::ai::ApiConfig g_config; -std::vector g_history; - -// 默认配置 -const char* DEFAULT_PROVIDER = "deepseek"; -const char* DEFAULT_BASE_URL = "https://api.deepseek.com/v1"; -const char* DEFAULT_MODEL = "deepseek-chat"; - -/* - * 简易 TOML 解析 (只处理 [api] 段中的 key = "value") - * 足够读取 dstalk 配置文件,不引入第三方 TOML 库 - */ -void parse_config_file(const char* path) -{ - if (!path) return; - size_t len = 0; - char* content = file_read_all(path, &len); - if (!content) return; - - std::string data(content, len); - std::free(content); - - std::string current_section; - size_t pos = 0; - while (pos < data.size()) { - // 跳过空白 - while (pos < data.size() && (data[pos] == ' ' || data[pos] == '\t')) - pos++; - if (pos >= data.size()) break; - - // 找行尾 - size_t nl = data.find('\n', pos); - std::string line = (nl != std::string::npos) - ? data.substr(pos, nl - pos) : data.substr(pos); - pos = (nl != std::string::npos) ? nl + 1 : data.size(); - - // 去尾随 \r 和空白 - while (!line.empty() && (line.back() == '\r' || line.back() == ' ')) - line.pop_back(); - - // 跳过空行和注释 - if (line.empty() || line[0] == '#') continue; - - // [section] - if (line[0] == '[' && line.back() == ']') { - current_section = line.substr(1, line.size() - 2); - continue; - } - - // key = "value" 或 key = value - size_t eq = line.find('='); - if (eq == std::string::npos) continue; - - std::string key = line.substr(0, eq); - while (!key.empty() && key.back() == ' ') key.pop_back(); - if (key.empty()) continue; - - std::string val = line.substr(eq + 1); - while (!val.empty() && (val.front() == ' ' || val.front() == '\t')) - val.erase(0, 1); - // 去引号 - if (val.size() >= 2 && val.front() == '"' && val.back() == '"') - val = val.substr(1, val.size() - 2); - - if (current_section == "api") { - if (key == "provider") - g_config.provider = val; - else if (key == "api_key" || key == "apikey") - g_config.api_key = val; - else if (key == "base_url") - g_config.base_url = val; - else if (key == "model") - g_config.model = val; - } - } -} - -char* copy_to_c_string(const std::string& value) -{ - char* output = static_cast(std::malloc(value.size() + 1)); - if (!output) return nullptr; - std::memcpy(output, value.c_str(), value.size() + 1); - return output; -} - -} // anonymous namespace - -// ---- 初始化 / 销毁 ---- - -DSTALK_API int dstalk_init(const char* config_path) -{ - if (g_initialized) return -1; - - // 设置默认值 - g_config.provider = DEFAULT_PROVIDER; - g_config.base_url = DEFAULT_BASE_URL; - g_config.model = DEFAULT_MODEL; - g_config.max_tokens = 4096; - g_config.temperature = 0.7; - g_history.clear(); - - // 读取配置文件 - if (config_path) { - parse_config_file(config_path); - } - - g_ai.configure(g_config); - g_initialized = true; - return 0; -} - -DSTALK_API void dstalk_destroy(void) -{ - if (!g_initialized) return; - g_history.clear(); - g_initialized = false; -} - -// ---- 配置 ---- - -DSTALK_API void dstalk_set_api_key(const char* api_key) -{ - if (!g_initialized || !api_key) return; - g_config.api_key = api_key; - g_ai.configure(g_config); -} - -DSTALK_API void dstalk_set_base_url(const char* base_url) -{ - if (!g_initialized || !base_url) return; - g_config.base_url = base_url; - g_ai.configure(g_config); -} - -DSTALK_API void dstalk_set_model(const char* model) -{ - if (!g_initialized || !model) return; - g_config.model = model; - g_ai.configure(g_config); -} - -// ---- AI 对话 ---- - -DSTALK_API int dstalk_chat(const char* input, char** output) -{ - if (!g_initialized || !input || !output) return -1; - *output = nullptr; - - auto result = g_ai.chat(g_history, input); - if (!result.ok) { - *output = copy_to_c_string(result.error); - return -1; - } - - char* reply = copy_to_c_string(result.content); - if (!reply) return -1; - - g_history.push_back({"user", input}); - g_history.push_back({"assistant", result.content}); - *output = reply; - return 0; -} - - -// 流式回调上下文 -struct StreamCtx { - std::string* buf; - dstalk_stream_cb cb; - void* ud; - bool cancelled; -}; - -static bool on_token_proxy(const std::string& token, void* userdata) -{ - auto* ctx = static_cast(userdata); - *ctx->buf += token; - int ret = ctx->cb(token.c_str(), ctx->ud); - if (ret == 0) return true; - ctx->cancelled = true; - return false; -} - -DSTALK_API int dstalk_chat_stream(const char* input, - dstalk_stream_cb cb, void* userdata) -{ - if (!g_initialized || !input || !cb) return -1; - - std::string full_reply; - StreamCtx ctx{&full_reply, cb, userdata, false}; - auto result = g_ai.chat_stream(g_history, input, on_token_proxy, &ctx); - - if (!result.ok && !ctx.cancelled) return -1; - - // 更新历史 - g_history.push_back({"user", input}); - g_history.push_back({"assistant", full_reply}); - - return 0; -} - -DSTALK_API void dstalk_free_string(char* str) -{ - std::free(str); -} - -// ---- 会话管理 ---- - -DSTALK_API void dstalk_session_clear(void) -{ - if (!g_initialized) return; - g_history.clear(); -} - -DSTALK_API int dstalk_session_save(const char* path) -{ - if (!g_initialized || !path) return -1; - - std::string data; - for (const auto& m : g_history) { - json::object entry; - entry["role"] = m.role; - entry["content"] = m.content; - data += json::serialize(entry); - data += '\n'; - } - return file_write_all(path, data.c_str()); -} - -DSTALK_API int dstalk_session_load(const char* path) -{ - if (!g_initialized || !path) return -1; - size_t len = 0; - char* content = file_read_all(path, &len); - if (!content) return -1; - - std::string data(content, len); - std::free(content); - - std::vector parsed; - - size_t pos = 0; - while (pos < data.size()) { - size_t nl = data.find('\n', pos); - std::string line = (nl != std::string::npos) - ? data.substr(pos, nl - pos) : data.substr(pos); - pos = (nl != std::string::npos) ? nl + 1 : data.size(); - if (line.empty()) continue; - - try { - auto obj = json::parse(line).as_object(); - auto* role = obj.if_contains("role"); - auto* content_val = obj.if_contains("content"); - if (role && content_val && role->is_string() && content_val->is_string()) { - parsed.push_back({json::value_to(*role), - json::value_to(*content_val)}); - } - } catch (const std::exception&) { - return -1; - } - } - - if (parsed.empty()) return -1; - g_history = std::move(parsed); - return 0; -} - -// ---- 文件操作 ---- - -DSTALK_API int dstalk_file_read(const char* path, char** content) -{ - if (!g_initialized || !path || !content) return -1; - *content = nullptr; - - size_t len = 0; - char* buf = file_read_all(path, &len); - if (!buf) return -1; - - *content = buf; - return 0; -} - -DSTALK_API int dstalk_file_write(const char* path, const char* content) -{ - if (!g_initialized || !path || !content) return -1; - return file_write_all(path, content); -} diff --git a/dstalk-core/src/config_store.cpp b/dstalk-core/src/config_store.cpp new file mode 100644 index 0000000..8b865fd --- /dev/null +++ b/dstalk-core/src/config_store.cpp @@ -0,0 +1,83 @@ +#include "config_store.hpp" + +#include +#include +#include +#include + +namespace dstalk { + +int ConfigStore::load_file(const char* path) +{ + if (!path) return -1; + + std::ifstream file(path); + if (!file.is_open()) return -1; + + std::stringstream ss; + ss << file.rdbuf(); + std::string data = ss.str(); + + // 简易 TOML 解析:只处理 [section] 和 key = "value" + std::string current_section; + size_t pos = 0; + while (pos < data.size()) { + while (pos < data.size() && (data[pos] == ' ' || data[pos] == '\t')) + pos++; + if (pos >= data.size()) break; + + size_t nl = data.find('\n', pos); + std::string line = (nl != std::string::npos) + ? data.substr(pos, nl - pos) : data.substr(pos); + pos = (nl != std::string::npos) ? nl + 1 : data.size(); + + while (!line.empty() && (line.back() == '\r' || line.back() == ' ')) + line.pop_back(); + + if (line.empty() || line[0] == '#') continue; + + if (line[0] == '[' && line.back() == ']') { + current_section = line.substr(1, line.size() - 2); + continue; + } + + size_t eq = line.find('='); + if (eq == std::string::npos) continue; + + std::string key = line.substr(0, eq); + while (!key.empty() && key.back() == ' ') key.pop_back(); + if (key.empty()) continue; + + std::string val = line.substr(eq + 1); + while (!val.empty() && (val.front() == ' ' || val.front() == '\t')) + val.erase(0, 1); + if (val.size() >= 2 && val.front() == '"' && val.back() == '"') + val = val.substr(1, val.size() - 2); + + std::lock_guard lock(mutex_); + std::string full_key = current_section.empty() + ? key : current_section + "." + key; + data_[full_key] = val; + } + + return 0; +} + +const char* ConfigStore::get(const char* key) const +{ + if (!key) return nullptr; + std::lock_guard lock(mutex_); + auto it = data_.find(key); + if (it == data_.end()) return nullptr; + return it->second.c_str(); +} + +int ConfigStore::set(const char* key, const char* value) +{ + if (!key || !value) return -1; + std::lock_guard lock(mutex_); + data_[key] = value; + return 0; +} + +} // namespace dstalk diff --git a/dstalk-core/src/config_store.hpp b/dstalk-core/src/config_store.hpp new file mode 100644 index 0000000..c7ec5a8 --- /dev/null +++ b/dstalk-core/src/config_store.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace dstalk { + +class ConfigStore { +public: + ConfigStore() = default; + ~ConfigStore() = default; + + // 从 TOML 文件加载配置 + int load_file(const char* path); + + // 获取配置值(返回内部指针,线程安全) + const char* get(const char* key) const; + + // 设置配置值 + int set(const char* key, const char* value); + +private: + mutable std::mutex mutex_; + std::unordered_map data_; +}; + +} // namespace dstalk diff --git a/dstalk-core/src/event_bus.cpp b/dstalk-core/src/event_bus.cpp new file mode 100644 index 0000000..42d6fdc --- /dev/null +++ b/dstalk-core/src/event_bus.cpp @@ -0,0 +1,39 @@ +#include "event_bus.hpp" + +#include + +namespace dstalk { + +int EventBus::subscribe(int event_type, EventHandler handler) +{ + std::unique_lock lock(mutex_); + int id = next_id_++; + subscriptions_.push_back({id, event_type, std::move(handler)}); + return id; +} + +void EventBus::unsubscribe(int subscription_id) +{ + std::unique_lock lock(mutex_); + subscriptions_.erase( + std::remove_if(subscriptions_.begin(), subscriptions_.end(), + [subscription_id](const Subscription& s) { + return s.id == subscription_id; + }), + subscriptions_.end()); +} + +int EventBus::emit(int event_type, const void* data) +{ + std::shared_lock lock(mutex_); + int count = 0; + for (const auto& sub : subscriptions_) { + if (sub.event_type == event_type) { + sub.handler(event_type, data); + count++; + } + } + return count; +} + +} // namespace dstalk diff --git a/dstalk-core/src/event_bus.hpp b/dstalk-core/src/event_bus.hpp new file mode 100644 index 0000000..e4401a1 --- /dev/null +++ b/dstalk-core/src/event_bus.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace dstalk { + +using EventHandler = std::function; + +class EventBus { +public: + EventBus() = default; + ~EventBus() = default; + + // 订阅事件,返回订阅ID + int subscribe(int event_type, EventHandler handler); + + // 取消订阅 + void unsubscribe(int subscription_id); + + // 发布事件 + int emit(int event_type, const void* data); + +private: + struct Subscription { + int id; + int event_type; + EventHandler handler; + }; + + mutable std::shared_mutex mutex_; + std::vector subscriptions_; + int next_id_ = 1; +}; + +} // namespace dstalk diff --git a/dstalk-core/src/file/file_io.cpp b/dstalk-core/src/file/file_io.cpp deleted file mode 100644 index 30b093a..0000000 --- a/dstalk-core/src/file/file_io.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include "file/file_io.hpp" - -#include -#include -#include - -#ifdef _WIN32 -#include -#define STDIN_FILENO _fileno(stdin) -#else -#include -#endif - -char* file_read_all(const char* path, size_t* out_len) -{ - if (!path || !out_len) return nullptr; - - FILE* f = nullptr; -#ifdef _WIN32 - fopen_s(&f, path, "rb"); -#else - f = fopen(path, "rb"); -#endif - if (!f) { - *out_len = 0; - return nullptr; - } - - fseek(f, 0, SEEK_END); - long sz = ftell(f); - fseek(f, 0, SEEK_SET); - if (sz < 0) { - fclose(f); - *out_len = 0; - return nullptr; - } - - if (sz == 0) { - fclose(f); - char* buf = (char*)std::malloc(1); - if (!buf) { - *out_len = 0; - return nullptr; - } - buf[0] = '\0'; - *out_len = 0; - return buf; - } - - char* buf = (char*)std::malloc(static_cast(sz) + 1); - if (!buf) { - fclose(f); - *out_len = 0; - return nullptr; - } - - size_t n = fread(buf, 1, static_cast(sz), f); - fclose(f); - buf[n] = '\0'; - *out_len = n; - return buf; -} - -int file_write_all(const char* path, const char* content) -{ - if (!path || !content) return -1; - - FILE* f = nullptr; -#ifdef _WIN32 - fopen_s(&f, path, "wb"); -#else - f = fopen(path, "wb"); -#endif - if (!f) return -1; - - size_t len = strlen(content); - size_t written = fwrite(content, 1, len, f); - fclose(f); - - return (written == len) ? 0 : -1; -} diff --git a/dstalk-core/src/file/file_io.hpp b/dstalk-core/src/file/file_io.hpp deleted file mode 100644 index 74ec9a9..0000000 --- a/dstalk-core/src/file/file_io.hpp +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include - -/* - * 内部文件 IO 实现 - * 读取整个文件到内存,返回 malloc 分配的 C 字符串 - * 调用方负责 free - */ -char* file_read_all(const char* path, size_t* out_len); - -/* - * 将内容写入文件(覆盖模式) - * 返回 0 成功,-1 失败 - */ -int file_write_all(const char* path, const char* content); - -#ifdef __cplusplus -} -#endif diff --git a/dstalk-core/src/host.cpp b/dstalk-core/src/host.cpp new file mode 100644 index 0000000..6a6d68b --- /dev/null +++ b/dstalk-core/src/host.cpp @@ -0,0 +1,362 @@ +#include "dstalk/dstalk_host.h" +#include "config_store.hpp" +#include "event_bus.hpp" +#include "service_registry.hpp" +#include "plugin_loader.hpp" + +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +// ============================================================ +// 全局主机上下文 +// ============================================================ +namespace { + std::mutex g_init_mutex; + bool g_initialized = false; + + dstalk::ConfigStore* g_config = nullptr; + dstalk::EventBus* g_event_bus = nullptr; + dstalk::ServiceRegistry* g_service_registry = nullptr; + dstalk::PluginLoader* g_plugin_loader = nullptr; + static dstalk_diag_cb g_diag_callback = nullptr; + + // ---- 内部辅助 ---- + + char* host_strdup(const char* s) { + if (!s) return nullptr; + size_t len = strlen(s); + char* copy = (char*)malloc(len + 1); + if (copy) memcpy(copy, s, len + 1); + return copy; + } + + void host_log_impl(int level, const char* fmt, va_list args) { + const char* prefix = ""; + switch (level) { + case DSTALK_LOG_DEBUG: prefix = "[DEBUG] "; break; + case DSTALK_LOG_INFO: prefix = "[INFO] "; break; + case DSTALK_LOG_WARN: prefix = "[WARN] "; break; + case DSTALK_LOG_ERROR: prefix = "[ERROR] "; break; + } + fprintf(stderr, "%s", prefix); + va_list args_copy; + va_copy(args_copy, args); + vfprintf(stderr, fmt, args); + fprintf(stderr, "\n"); + // 转发到诊断回调 + if (g_diag_callback) { + char buf[1024]; + vsnprintf(buf, sizeof(buf), fmt, args_copy); + g_diag_callback(level, nullptr, 0, nullptr, buf); + } + va_end(args_copy); + } + + void host_log(int level, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + host_log_impl(level, fmt, args); + va_end(args); + } + + // ---- Host API 表回调 ---- + + int api_register_service(const char* name, int version, void* vtable) { + return g_service_registry ? g_service_registry->register_service(name, version, vtable) : -1; + } + + void* api_query_service(const char* name, int min_version) { + return g_service_registry ? g_service_registry->query_service(name, min_version) : nullptr; + } + + int api_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) { + if (!g_event_bus || !handler) return -1; + return g_event_bus->subscribe(event_type, + [handler, userdata](int type, const void* data) { + handler(type, data, userdata); + }); + } + + int api_event_emit(int event_type, const void* data) { + return g_event_bus ? g_event_bus->emit(event_type, data) : -1; + } + + void api_event_unsubscribe(int sub_id) { + if (g_event_bus) g_event_bus->unsubscribe(sub_id); + } + + const char* api_config_get(const char* key) { + return g_config ? g_config->get(key) : nullptr; + } + + int api_config_set(const char* key, const char* value) { + return g_config ? g_config->set(key, value) : -1; + } + + void api_log(int level, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + host_log_impl(level, fmt, args); + va_end(args); + } + + void* api_alloc(size_t size) { return malloc(size); } + void api_free(void* ptr) { free(ptr); } + + char* api_strdup(const char* s) { return host_strdup(s); } + + dstalk_host_api_t g_host_api = { + api_register_service, + api_query_service, + api_event_subscribe, + api_event_emit, + api_event_unsubscribe, + api_config_get, + api_config_set, + api_log, + api_alloc, + api_free, + api_strdup + }; + + // ---- 插件目录扫描 ---- + + int load_plugins_from_directory(const char* plugin_dir) { + if (!plugin_dir) return -1; + + try { + fs::path dir(plugin_dir); + if (!fs::exists(dir) || !fs::is_directory(dir)) return -1; + + int loaded = 0; + for (const auto& entry : fs::directory_iterator(dir)) { + if (!entry.is_regular_file()) continue; + + std::string ext = entry.path().extension().string(); +#ifdef _WIN32 + if (ext != ".dll") continue; +#else + if (ext != ".so" && ext != ".dylib") continue; +#endif + + int id = g_plugin_loader->load_plugin(entry.path().string().c_str()); + if (id >= 0) { + loaded++; + host_log(DSTALK_LOG_INFO, "Loaded plugin: %s", + entry.path().filename().string().c_str()); + } + } + return loaded; + } catch (const std::exception& e) { + host_log(DSTALK_LOG_ERROR, "Failed to scan plugin directory: %s", e.what()); + return -1; + } + } +} + +// ============================================================ +// 公共 API +// ============================================================ + +DSTALK_API int dstalk_init(const char* config_path) +{ + std::lock_guard lock(g_init_mutex); + + if (g_initialized) return -1; + + try { + g_config = new dstalk::ConfigStore(); + g_event_bus = new dstalk::EventBus(); + g_service_registry = new dstalk::ServiceRegistry(); + g_plugin_loader = new dstalk::PluginLoader(); + + // 加载配置 + if (config_path && config_path[0]) { + if (g_config->load_file(config_path) != 0) { + host_log(DSTALK_LOG_WARN, "Failed to load config: %s", config_path); + } + } + + // 扫描插件目录 + const char* plugin_dir = g_config->get("plugin_dir"); + if (!plugin_dir) plugin_dir = "plugins"; + load_plugins_from_directory(plugin_dir); + + // 初始化所有插件 + if (g_plugin_loader->initialize_all(&g_host_api) != 0) { + host_log(DSTALK_LOG_WARN, "Some plugins failed to initialize"); + } + + g_initialized = true; + host_log(DSTALK_LOG_INFO, "dstalk host initialized"); + return 0; + + } catch (const std::exception& e) { + host_log(DSTALK_LOG_ERROR, "Init failed: %s", e.what()); + delete g_plugin_loader; g_plugin_loader = nullptr; + delete g_service_registry; g_service_registry = nullptr; + delete g_event_bus; g_event_bus = nullptr; + delete g_config; g_config = nullptr; + return -1; + } +} + +DSTALK_API void dstalk_shutdown(void) +{ + std::lock_guard lock(g_init_mutex); + if (!g_initialized) return; + + host_log(DSTALK_LOG_INFO, "dstalk shutting down..."); + + if (g_plugin_loader) { + g_plugin_loader->shutdown_all(); + delete g_plugin_loader; + g_plugin_loader = nullptr; + } + + delete g_service_registry; g_service_registry = nullptr; + delete g_event_bus; g_event_bus = nullptr; + delete g_config; g_config = nullptr; + + g_initialized = false; +} + +DSTALK_API int dstalk_plugin_load(const char* path) +{ + if (!g_initialized || !g_plugin_loader) return -1; + int id = g_plugin_loader->load_plugin(path); + if (id >= 0) { + g_plugin_loader->initialize_pending(&g_host_api); + } + return id; +} + +DSTALK_API int dstalk_plugin_unload(int plugin_id) +{ + if (!g_initialized || !g_plugin_loader) return -1; + return g_plugin_loader->unload_plugin(plugin_id); +} + +DSTALK_API int dstalk_plugin_list(char** output_json) +{ + if (!g_initialized || !g_plugin_loader || !output_json) return -1; + *output_json = host_strdup(g_plugin_loader->list_plugins().c_str()); + return *output_json ? 0 : -1; +} + +DSTALK_API void* dstalk_service_query(const char* service_name, int min_version) +{ + if (!g_initialized || !g_service_registry) return nullptr; + return g_service_registry->query_service(service_name, min_version); +} + +DSTALK_API int dstalk_event_subscribe(int event_type, dstalk_event_handler_fn handler, void* userdata) +{ + if (!g_initialized || !g_event_bus || !handler) return -1; + return g_event_bus->subscribe(event_type, + [handler, userdata](int type, const void* data) { handler(type, data, userdata); }); +} + +DSTALK_API int dstalk_event_emit(int event_type, const void* data) +{ + if (!g_initialized || !g_event_bus) return -1; + return g_event_bus->emit(event_type, data); +} + +DSTALK_API void dstalk_event_unsubscribe(int subscription_id) +{ + if (!g_initialized || !g_event_bus) return; + g_event_bus->unsubscribe(subscription_id); +} + +DSTALK_API const char* dstalk_config_get(const char* key) +{ + if (!g_initialized || !g_config) return nullptr; + return g_config->get(key); +} + +DSTALK_API int dstalk_config_set(const char* key, const char* value) +{ + if (!g_initialized || !g_config) return -1; + return g_config->set(key, value); +} + +DSTALK_API void dstalk_log(int level, const char* fmt, ...) +{ + va_list args; + va_start(args, fmt); + host_log_impl(level, fmt, args); + va_end(args); +} + +DSTALK_API void* dstalk_alloc(size_t size) { return malloc(size); } +DSTALK_API void dstalk_free(void* ptr) { free(ptr); } +DSTALK_API char* dstalk_strdup(const char* s) { return host_strdup(s); } + +DSTALK_API void dstalk_set_diag_callback(dstalk_diag_cb cb) { + g_diag_callback = cb; +} + +// ============================================================ +// LSP 便捷函数 (委托给 "lsp" 服务插件) +// ============================================================ + +static const dstalk_lsp_service_t* get_lsp_service() { + if (!g_initialized || !g_service_registry) return nullptr; + return static_cast( + g_service_registry->query_service("lsp", 1)); +} + +DSTALK_API int dstalk_lsp_start(const char* server_cmd, const char* language) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->start) return -1; + return svc->start(server_cmd, language); +} + +DSTALK_API void dstalk_lsp_stop(void) +{ + auto* svc = get_lsp_service(); + if (svc && svc->stop) svc->stop(); +} + +DSTALK_API int dstalk_lsp_open(const char* uri, const char* content, const char* language_id) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->open_document) return -1; + return svc->open_document(uri, content, language_id); +} + +DSTALK_API int dstalk_lsp_close(const char* uri) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->close_document) return -1; + return svc->close_document(uri); +} + +DSTALK_API int dstalk_lsp_diagnostics(const char* uri, char** output) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->get_diagnostics) return -1; + return svc->get_diagnostics(uri, output); +} + +DSTALK_API int dstalk_lsp_hover(const char* uri, int line, int character, char** output) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->get_hover) return -1; + return svc->get_hover(uri, line, character, output); +} + +DSTALK_API int dstalk_lsp_completion(const char* uri, int line, int character, char** output) +{ + auto* svc = get_lsp_service(); + if (!svc || !svc->get_completion) return -1; + return svc->get_completion(uri, line, character, output); +} diff --git a/dstalk-core/src/net/http_client.cpp b/dstalk-core/src/net/http_client.cpp deleted file mode 100644 index a9da252..0000000 --- a/dstalk-core/src/net/http_client.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20) -#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS - -#include "net/http_client.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace beast = boost::beast; -namespace http = beast::http; -namespace asio = boost::asio; -namespace ssl = boost::asio::ssl; -using tcp = asio::ip::tcp; - -namespace dstalk { -namespace net { - -struct HttpClient::Impl { - asio::io_context ioc; - ssl::context ssl_ctx{ssl::context::tlsv12_client}; - int connect_timeout = 30; - int request_timeout = 120; - - Impl() { - ssl_ctx.set_default_verify_paths(); - } -}; - -HttpClient::HttpClient() : impl_(new Impl{}) {} -HttpClient::~HttpClient() { delete impl_; } - -void HttpClient::set_timeout(int connect_sec, int request_sec) -{ - impl_->connect_timeout = connect_sec; - impl_->request_timeout = request_sec; -} - -HttpResponse HttpClient::post_json( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers) -{ - return post_stream(host, port, target, json_body, extra_headers, nullptr); -} - -HttpResponse HttpClient::post_stream( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers, - std::function on_line) -{ - HttpResponse result; - - try { - tcp::resolver resolver(impl_->ioc); - auto endpoints = resolver.resolve(host, port); - - ssl::stream stream(impl_->ioc, impl_->ssl_ctx); - beast::flat_buffer buffer; - - // SNI hostname (required for HTTPS) - if (!SSL_set_tlsext_host_name(stream.native_handle(), host.c_str())) { - result.status_code = -1; - return result; - } - - asio::connect(beast::get_lowest_layer(stream), endpoints); - stream.handshake(ssl::stream_base::client); - - // Build HTTP POST request - http::request req{http::verb::post, target, 11}; - req.set(http::field::host, host); - req.set(http::field::user_agent, "dstalk/0.1"); - req.set(http::field::content_type, "application/json"); - req.body() = json_body; - req.prepare_payload(); - - for (const auto& h : extra_headers) { - req.set(h.first, h.second); - } - - // Send - http::write(stream, req); - - // Read response - http::response_parser parser; - parser.body_limit(16 * 1024 * 1024); - http::read_header(stream, buffer, parser); - - result.status_code = parser.get().result_int(); - - beast::error_code ec; - - if (on_line) { - std::string fragment = parser.get().body(); - auto emit_lines = [&]() -> bool { - size_t pos = 0; - while (pos < fragment.size()) { - size_t nl = fragment.find('\n', pos); - if (nl == std::string::npos) break; - std::string line = fragment.substr(pos, nl - pos); - if (!line.empty() && line.back() == '\r') - line.pop_back(); - if (!on_line(line)) return false; - pos = nl + 1; - } - if (pos > 0) - fragment = fragment.substr(pos); - return true; - }; - if (!emit_lines()) goto done; - - size_t processed = parser.get().body().size(); - while (!parser.is_done()) { - http::read_some(stream, buffer, parser, ec); - if (ec) break; - - const std::string& full_body = parser.get().body(); - if (full_body.size() > processed) { - std::string_view new_data(full_body.data() + processed, - full_body.size() - processed); - processed = full_body.size(); - - fragment.append(new_data.data(), new_data.size()); - if (!emit_lines()) goto done; - } - } - if (!fragment.empty()) { - if (fragment.back() == '\r') - fragment.pop_back(); - if (!fragment.empty()) - on_line(fragment); - } - } else { - while (!parser.is_done()) { - http::read_some(stream, buffer, parser, ec); - if (ec) break; - } - } -done: - result.body = parser.get().body(); - beast::get_lowest_layer(stream).cancel(); - stream.shutdown(ec); - } catch (std::exception& e) { - result.status_code = -1; - result.body = e.what(); - } - - return result; -} - -} // namespace net -} // namespace dstalk diff --git a/dstalk-core/src/net/http_client.hpp b/dstalk-core/src/net/http_client.hpp deleted file mode 100644 index e67503b..0000000 --- a/dstalk-core/src/net/http_client.hpp +++ /dev/null @@ -1,53 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace dstalk { -namespace net { - -struct HttpResponse { - int status_code = 0; - std::string body; - std::unordered_map headers; -}; - -/* - * HTTPS 客户端统一接口 - * 所有平台统一使用 Boost.Beast + OpenSSL 实现 - */ -class HttpClient { -public: - HttpClient(); - ~HttpClient(); - - void set_timeout(int connect_sec, int request_sec); - - // 同步 POST JSON, 返回完整响应 - HttpResponse post_json( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers - ); - - // 流式 POST (SSE 逐行回调), on_line 返回 false 提前终止 - using StreamCallback = std::function; - HttpResponse post_stream( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers, - StreamCallback on_line - ); - -private: - struct Impl; - Impl* impl_; -}; - -} // namespace net -} // namespace dstalk diff --git a/dstalk-core/src/net/http_client_win.cpp b/dstalk-core/src/net/http_client_win.cpp deleted file mode 100644 index aadc331..0000000 --- a/dstalk-core/src/net/http_client_win.cpp +++ /dev/null @@ -1,223 +0,0 @@ -#include "net/http_client.hpp" - -#ifdef _WIN32 - -#define WIN32_LEAN_AND_MEAN -#include -#include -#include -#include - -#pragma comment(lib, "winhttp.lib") - -namespace dstalk { -namespace net { - -// ---- 宽字符转换 ---- -static std::wstring to_w(const std::string& s) -{ - if (s.empty()) return L""; - int len = MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, nullptr, 0); - std::wstring out(len - 1, L'\0'); - MultiByteToWideChar(CP_UTF8, 0, s.c_str(), -1, &out[0], len); - return out; -} - -// ---- 读取全部 body ---- -static std::string read_all(HINTERNET hRequest, DWORD& status_code) -{ - DWORD status = 0, statusSize = sizeof(status); - WinHttpQueryHeaders(hRequest, - WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, - WINHTTP_HEADER_NAME_BY_INDEX, &status, &statusSize, - WINHTTP_NO_HEADER_INDEX); - status_code = status; - - std::string body; - char buf[4096]; - DWORD bytesRead = 0; - while (WinHttpReadData(hRequest, buf, sizeof(buf), &bytesRead)) { - if (bytesRead == 0) break; - body.append(buf, bytesRead); - } - return body; -} - -// ---- 流式读取 (SSE 逐行回调) ---- -static std::string read_stream(HINTERNET hRequest, DWORD& status_code, - HttpClient::StreamCallback on_line) -{ - DWORD status = 0, statusSize = sizeof(status); - WinHttpQueryHeaders(hRequest, - WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, - WINHTTP_HEADER_NAME_BY_INDEX, &status, &statusSize, - WINHTTP_NO_HEADER_INDEX); - status_code = status; - - if (status < 200 || status >= 300) { - return read_all(hRequest, status_code); - } - - std::string body; - std::string lineBuf; - char buf[1024]; - DWORD bytesRead = 0; - - while (WinHttpReadData(hRequest, buf, sizeof(buf), &bytesRead)) { - if (bytesRead == 0) break; - - for (DWORD i = 0; i < bytesRead; i++) { - char c = buf[i]; - body += c; - if (c == '\n') { - while (!lineBuf.empty() && lineBuf.back() == '\r') - lineBuf.pop_back(); - if (!lineBuf.empty()) { - if (!on_line(lineBuf)) return body; - } - lineBuf.clear(); - } else if (c != '\r') { - lineBuf += c; - } - } - } - while (!lineBuf.empty() && lineBuf.back() == '\r') - lineBuf.pop_back(); - if (!lineBuf.empty()) on_line(lineBuf); - return body; -} - -// ---- Impl ---- -struct HttpClient::Impl { - HINTERNET hSession = nullptr; - int connect_timeout = 30; - int request_timeout = 120; -}; - -HttpClient::HttpClient() : impl_(new Impl{}) -{ - impl_->hSession = WinHttpOpen( - L"dstalk/0.1", - WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, - WINHTTP_NO_PROXY_NAME, - WINHTTP_NO_PROXY_BYPASS, 0); -} - -HttpClient::~HttpClient() -{ - if (impl_->hSession) WinHttpCloseHandle(impl_->hSession); - delete impl_; -} - -void HttpClient::set_timeout(int connect_sec, int request_sec) -{ - impl_->connect_timeout = connect_sec; - impl_->request_timeout = request_sec; -} - -// ---- 核心请求逻辑 ---- -static HttpResponse do_request( - HINTERNET hSession, - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers, - int connect_timeout, - int request_timeout, - HttpClient::StreamCallback on_line) -{ - HttpResponse result; - - int nPort = port.empty() ? 443 : std::stoi(port); - DWORD flags = (nPort == 443) ? WINHTTP_FLAG_SECURE : 0; - - std::wstring wHost = to_w(host); - std::wstring wPath = to_w(target); - HINTERNET hConnect = WinHttpConnect(hSession, wHost.c_str(), (WORD)nPort, 0); - if (!hConnect) { result.status_code = -1; return result; } - - LPCWSTR acceptTypes[] = { L"application/json", nullptr }; - HINTERNET hRequest = WinHttpOpenRequest( - hConnect, L"POST", wPath.c_str(), - nullptr, WINHTTP_NO_REFERER, acceptTypes, flags); - if (!hRequest) { - WinHttpCloseHandle(hConnect); - result.status_code = -1; - return result; - } - - // Headers - WinHttpAddRequestHeaders(hRequest, - L"Content-Type: application/json\r\n", -1, - WINHTTP_ADDREQ_FLAG_ADD); - for (const auto& h : extra_headers) { - std::string hdr = h.first + ": " + h.second + "\r\n"; - std::wstring whdr = to_w(hdr); - WinHttpAddRequestHeaders(hRequest, whdr.c_str(), -1, - WINHTTP_ADDREQ_FLAG_ADD); - } - - // Timeouts - WinHttpSetTimeouts(hRequest, - connect_timeout * 1000, connect_timeout * 1000, - request_timeout * 1000, request_timeout * 1000); - - // Send - DWORD bodyLen = (DWORD)json_body.size(); - BOOL sent = WinHttpSendRequest( - hRequest, WINHTTP_NO_ADDITIONAL_HEADERS, 0, - (LPVOID)json_body.data(), bodyLen, bodyLen, 0); - if (!sent) { - WinHttpCloseHandle(hRequest); - WinHttpCloseHandle(hConnect); - result.status_code = -1; - return result; - } - - // Receive - WinHttpReceiveResponse(hRequest, nullptr); - DWORD status = 0; - if (on_line) { - result.body = read_stream(hRequest, status, on_line); - } else { - result.body = read_all(hRequest, status); - } - result.status_code = (int)status; - - WinHttpCloseHandle(hRequest); - WinHttpCloseHandle(hConnect); - return result; -} - -HttpResponse HttpClient::post_json( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers) -{ - return post_stream(host, port, target, json_body, extra_headers, nullptr); -} - -HttpResponse HttpClient::post_stream( - const std::string& host, - const std::string& port, - const std::string& target, - const std::string& json_body, - const std::unordered_map& extra_headers, - StreamCallback on_line) -{ - return do_request(impl_->hSession, host, port, target, json_body, - extra_headers, - impl_->connect_timeout, impl_->request_timeout, - on_line); -} - -} // namespace net -} // namespace dstalk - -#else -// 非 Windows: 需要 Boost.Beast 实现 (编译时会报错提示) -# error "WinHTTP backend is Windows-only. Use net/http_client.cpp for non-Windows builds." -#endif diff --git a/dstalk-core/src/plugin_loader.cpp b/dstalk-core/src/plugin_loader.cpp new file mode 100644 index 0000000..7cab26f --- /dev/null +++ b/dstalk-core/src/plugin_loader.cpp @@ -0,0 +1,291 @@ +#include "plugin_loader.hpp" + +#include + +#ifdef _WIN32 + #include +#else + #include +#endif + +#include +#include +#include + +namespace dstalk { + +namespace json = boost::json; + +PluginLoader::~PluginLoader() +{ + shutdown_all(); +} + +int PluginLoader::load_plugin(const char* path) +{ + if (!path) return -1; + + // 加载DLL +#ifdef _WIN32 + void* handle = LoadLibraryA(path); +#else + void* handle = dlopen(path, RTLD_NOW | RTLD_LOCAL); +#endif + + if (!handle) { + return -1; + } + + // 获取入口函数 +#ifdef _WIN32 + auto init_fn = (dstalk_plugin_init_fn)GetProcAddress( + (HMODULE)handle, "dstalk_plugin_init"); +#else + auto init_fn = (dstalk_plugin_init_fn)dlsym(handle, "dstalk_plugin_init"); +#endif + + if (!init_fn) { +#ifdef _WIN32 + FreeLibrary((HMODULE)handle); +#else + dlclose(handle); +#endif + return -1; + } + + // 调用入口函数获取插件信息 + dstalk_plugin_info_t* info = init_fn(); + if (!info) { +#ifdef _WIN32 + FreeLibrary((HMODULE)handle); +#else + dlclose(handle); +#endif + return -1; + } + + // 检查API版本兼容性 + if (info->api_version != DSTALK_API_VERSION) { +#ifdef _WIN32 + FreeLibrary((HMODULE)handle); +#else + dlclose(handle); +#endif + return -1; + } + + // 创建插件信息 + int id = next_id_++; + PluginInfo plugin; + plugin.id = id; + plugin.name = info->name ? info->name : ""; + plugin.version = info->version ? info->version : ""; + plugin.description = info->description ? info->description : ""; + plugin.api_version = info->api_version; + plugin.handle = handle; + plugin.info = info; + plugin.initialized = false; + + // 解析依赖 + for (int i = 0; i < DSTALK_MAX_DEPS && info->dependencies[i]; i++) { + plugin.dependencies.push_back(info->dependencies[i]); + } + + plugins_[id] = std::move(plugin); + return id; +} + +int PluginLoader::unload_plugin(int plugin_id) +{ + auto it = plugins_.find(plugin_id); + if (it == plugins_.end()) return -1; + + PluginInfo& plugin = it->second; + + // 调用关闭回调 + if (plugin.initialized && plugin.info->on_shutdown) { + plugin.info->on_shutdown(); + } + + // 卸载DLL +#ifdef _WIN32 + FreeLibrary((HMODULE)plugin.handle); +#else + dlclose(plugin.handle); +#endif + + plugins_.erase(it); + return 0; +} + +std::string PluginLoader::list_plugins() const +{ + json::array arr; + for (const auto& [id, plugin] : plugins_) { + json::object obj; + obj["id"] = id; + obj["name"] = plugin.name; + obj["version"] = plugin.version; + obj["description"] = plugin.description; + obj["api_version"] = plugin.api_version; + obj["initialized"] = plugin.initialized; + + json::array deps; + for (const auto& dep : plugin.dependencies) { + deps.push_back(json::value(dep)); + } + obj["dependencies"] = std::move(deps); + + arr.push_back(std::move(obj)); + } + return json::serialize(arr); +} + +std::vector PluginLoader::topological_sort() const +{ + // 构建名称到ID的映射 + std::unordered_map name_to_id; + for (const auto& [id, plugin] : plugins_) { + name_to_id[plugin.name] = id; + } + + // 计算入度 + std::unordered_map in_degree; + std::unordered_map> dependents; + + for (const auto& [id, plugin] : plugins_) { + in_degree[id] = 0; + } + + for (const auto& [id, plugin] : plugins_) { + for (const auto& dep_name : plugin.dependencies) { + auto it = name_to_id.find(dep_name); + if (it != name_to_id.end()) { + int dep_id = it->second; + dependents[dep_id].push_back(id); + in_degree[id]++; + } + } + } + + // 拓扑排序(Kahn算法) + std::queue queue; + for (const auto& [id, degree] : in_degree) { + if (degree == 0) { + queue.push(id); + } + } + + std::vector sorted; + while (!queue.empty()) { + int id = queue.front(); + queue.pop(); + sorted.push_back(id); + + for (int dependent : dependents[id]) { + if (--in_degree[dependent] == 0) { + queue.push(dependent); + } + } + } + + // 检查循环依赖 + if (sorted.size() != plugins_.size()) { + throw std::runtime_error("Circular dependency detected"); + } + + return sorted; +} + +int PluginLoader::initialize_all(const dstalk_host_api_t* host_api) +{ + try { + std::vector order = topological_sort(); + + for (int id : order) { + auto it = plugins_.find(id); + if (it == plugins_.end()) continue; + + PluginInfo& plugin = it->second; + if (plugin.initialized) continue; + + if (plugin.info->on_init) { + int result = plugin.info->on_init(host_api); + if (result != 0) { + return -1; + } + } + plugin.initialized = true; + } + + return 0; + } catch (const std::exception&) { + return -1; + } +} + +int PluginLoader::initialize_pending(const dstalk_host_api_t* host_api) +{ + try { + std::vector order = topological_sort(); + + int count = 0; + for (int id : order) { + auto it = plugins_.find(id); + if (it == plugins_.end()) continue; + + PluginInfo& plugin = it->second; + if (plugin.initialized) continue; + + if (plugin.info->on_init) { + int result = plugin.info->on_init(host_api); + if (result != 0) { + return -1; + } + } + plugin.initialized = true; + count++; + } + + return count; + } catch (const std::exception&) { + return -1; + } +} + +void PluginLoader::shutdown_all() +{ + // 按逆序关闭 + std::vector order; + try { + order = topological_sort(); + std::reverse(order.begin(), order.end()); + } catch (...) { + // 如果排序失败,按任意顺序关闭 + for (const auto& [id, _] : plugins_) { + order.push_back(id); + } + } + + for (int id : order) { + auto it = plugins_.find(id); + if (it == plugins_.end()) continue; + + PluginInfo& plugin = it->second; + if (!plugin.initialized) continue; + + if (plugin.info->on_shutdown) { + plugin.info->on_shutdown(); + } + plugin.initialized = false; + } +} + +const PluginInfo* PluginLoader::get_plugin(int plugin_id) const +{ + auto it = plugins_.find(plugin_id); + if (it == plugins_.end()) return nullptr; + return &it->second; +} + +} // namespace dstalk diff --git a/dstalk-core/src/plugin_loader.hpp b/dstalk-core/src/plugin_loader.hpp new file mode 100644 index 0000000..7d433d7 --- /dev/null +++ b/dstalk-core/src/plugin_loader.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include "dstalk/dstalk_host.h" +#include +#include +#include + +namespace dstalk { + +struct PluginInfo { + int id; + std::string name; + std::string version; + std::string description; + int api_version; + std::vector dependencies; + + void* handle; // DLL handle + dstalk_plugin_info_t* info; + bool initialized; +}; + +class PluginLoader { +public: + PluginLoader() = default; + ~PluginLoader(); + + // 加载插件(返回插件ID,失败返回-1) + int load_plugin(const char* path); + + // 卸载插件 + int unload_plugin(int plugin_id); + + // 获取插件列表(JSON格式) + std::string list_plugins() const; + + // 按依赖顺序初始化所有插件 + int initialize_all(const dstalk_host_api_t* host_api); + + // 仅初始化尚未初始化的插件(增量加载场景) + int initialize_pending(const dstalk_host_api_t* host_api); + + // 关闭所有插件 + void shutdown_all(); + + // 获取插件信息 + const PluginInfo* get_plugin(int plugin_id) const; + +private: + // 拓扑排序(按依赖顺序) + std::vector topological_sort() const; + + std::unordered_map plugins_; + int next_id_ = 1; +}; + +} // namespace dstalk diff --git a/dstalk-core/src/service_registry.cpp b/dstalk-core/src/service_registry.cpp new file mode 100644 index 0000000..54a0683 --- /dev/null +++ b/dstalk-core/src/service_registry.cpp @@ -0,0 +1,42 @@ +#include "service_registry.hpp" + +namespace dstalk { + +int ServiceRegistry::register_service(const char* name, int version, void* vtable) +{ + if (!name || !vtable) return -1; + + std::unique_lock lock(mutex_); + + // 检查是否已注册 + if (services_.find(name) != services_.end()) { + return -2; // 已存在 + } + + services_[name] = {name, version, vtable}; + return 0; +} + +void* ServiceRegistry::query_service(const char* name, int min_version) const +{ + if (!name) return nullptr; + + std::shared_lock lock(mutex_); + + auto it = services_.find(name); + if (it == services_.end()) return nullptr; + + if (it->second.version < min_version) return nullptr; + + return it->second.vtable; +} + +void ServiceRegistry::unregister_service(const char* name) +{ + if (!name) return; + + std::unique_lock lock(mutex_); + services_.erase(name); +} + +} // namespace dstalk diff --git a/dstalk-core/src/service_registry.hpp b/dstalk-core/src/service_registry.hpp new file mode 100644 index 0000000..e93e89f --- /dev/null +++ b/dstalk-core/src/service_registry.hpp @@ -0,0 +1,35 @@ +#pragma once + +#include +#include +#include +#include + +namespace dstalk { + +class ServiceRegistry { +public: + ServiceRegistry() = default; + ~ServiceRegistry() = default; + + // 注册服务 + int register_service(const char* name, int version, void* vtable); + + // 查询服务(返回 vtable 指针,或 nullptr) + void* query_service(const char* name, int min_version) const; + + // 注销服务 + void unregister_service(const char* name); + +private: + struct ServiceEntry { + std::string name; + int version; + void* vtable; + }; + + mutable std::shared_mutex mutex_; + std::unordered_map services_; +}; + +} // namespace dstalk diff --git a/dstalk-gui/CMakeLists.txt b/dstalk-gui/CMakeLists.txt index f8b1b53..d1a7566 100644 --- a/dstalk-gui/CMakeLists.txt +++ b/dstalk-gui/CMakeLists.txt @@ -2,8 +2,8 @@ # dstalk-gui — 图形化前端 (SDL3) # ============================================================ -# 启用 DSTALK_BUILD_GUI=ON 前,需要由系统或外部包管理器提供 SDL3。 -find_package(SDL3 REQUIRED) +# 启用 DSTALK_BUILD_GUI=ON 前,确保 deps/conanfile.txt 中包含 sdl 依赖 +find_package(SDL3 REQUIRED CONFIG) add_executable(dstalk-gui src/main.cpp diff --git a/dstalk-gui/src/main.cpp b/dstalk-gui/src/main.cpp index c5d9046..31b585e 100644 --- a/dstalk-gui/src/main.cpp +++ b/dstalk-gui/src/main.cpp @@ -1,56 +1,902 @@ +// ============================================================================ +// dstalk-gui — SDL3 聊天客户端 +// ============================================================================ +// 使用 SDL3 内置的 SDL_RenderDebugText() 渲染文本(8x8 像素), +// 通过 SDL_SetRenderScale 2 倍缩放至有效的 16x16 像素。 +// +// 该文件是独立的——不需要额外的源文件。 +// ============================================================================ + #include #include +#include +#include +#include +#include +#include #include -#include "dstalk/dstalk_api.h" +#include "dstalk/dstalk_host.h" -int main(int argc, char* argv[]) -{ +// ---- 服务 vtable 指针 ---- +static const dstalk_ai_service_t* g_ai_svc = nullptr; +static const dstalk_session_service_t* g_session_svc = nullptr; + +// ---- 常量 ---- + +static constexpr int WINDOW_W = 1024; +static constexpr int WINDOW_H = 768; +static constexpr float RENDER_SCALE = 2.0f; + +// 逻辑坐标尺寸(物理像素 / RENDER_SCALE) +static constexpr int LOGICAL_W = WINDOW_W / 2; // 512 +static constexpr int LOGICAL_H = WINDOW_H / 2; // 384 + +static constexpr int CHAR_W = 8; // SDL_RenderDebugText 原生字符宽度(逻辑像素) +static constexpr int CHAR_H = 8; // 原生字符高度(逻辑像素) +static constexpr int TITLE_H = 16; // 标题栏高度(逻辑像素) +static constexpr int PADDING = 4; // 内边距(逻辑像素) + +// 侧边栏 +static constexpr int SIDEBAR_W = 80; // 侧边栏宽度(逻辑像素,渲染为 160 物理像素) + +// 状态栏 +static constexpr int STATUS_H = 20; // 状态栏高度(逻辑像素,渲染为 40 物理像素) + +// 输入区域动态高度 +static constexpr int INPUT_H_MIN = 40; // 最小高度(逻辑像素) +static constexpr int INPUT_H_MAX = 120; // 最大高度(逻辑像素) + +// 消息区域(Y 起点不变,宽度和高度动态计算) +static constexpr int MSG_Y = TITLE_H; + +// 颜色(ARGB 格式,用于 SDL_SetRenderDrawColor) +static constexpr SDL_Color COL_BG = {0x1E, 0x1E, 0x2E, 0xFF}; +static constexpr SDL_Color COL_TITLE_BG = {0x2D, 0x2D, 0x44, 0xFF}; +static constexpr SDL_Color COL_INPUT_BG = {0x2A, 0x2A, 0x3E, 0xFF}; +static constexpr SDL_Color COL_USER = {0x00, 0xFF, 0xFF, 0xFF}; // 青色 +static constexpr SDL_Color COL_AI = {0x00, 0xFF, 0x80, 0xFF}; // 绿色 +static constexpr SDL_Color COL_SYS = {0xFF, 0xFF, 0x00, 0xFF}; // 黄色 +static constexpr SDL_Color COL_BTN = {0x50, 0x50, 0x80, 0xFF}; // 按钮 +static constexpr SDL_Color COL_WHITE = {0xFF, 0xFF, 0xFF, 0xFF}; +static constexpr SDL_Color COL_CURSOR = {0xFF, 0xFF, 0xFF, 0xFF}; +static constexpr SDL_Color COL_SEP = {0x50, 0x50, 0x70, 0xFF}; +static constexpr SDL_Color COL_SIDEBAR_BG = {0x18, 0x18, 0x28, 0xFF}; +static constexpr SDL_Color COL_SIDEBAR_ACT = {0x35, 0x35, 0x55, 0xFF}; +static constexpr SDL_Color COL_SIDEBAR_BTN = {0x40, 0x40, 0x68, 0xFF}; +static constexpr SDL_Color COL_STATUSBAR_BG= {0x2D, 0x2D, 0x44, 0xFF}; +static constexpr SDL_Color COL_DIM = {0x80, 0x80, 0x80, 0xFF}; + +// ---- 数据结构 ---- + +struct ChatMessage { + enum Role { USER, ASSISTANT, SYSTEM } role; + std::string content; + + ChatMessage(Role r, std::string c) : role(r), content(std::move(c)) {} +}; + +struct GuiState { + std::vector messages; + std::string inputBuffer; + int scrollOffset = 0; // 从底部滚动的逻辑像素 + bool streaming = false; + bool running = true; + int cursorPos = 0; // 输入缓冲区中的光标位置 + bool cursorVisible = true; + Uint64 lastCursorBlink = 0; + float maxScroll = 0; // 可用的最大滚动距离(逻辑像素) + + // P0 新增字段 + std::vector input_history; // 输入历史(最多 20 条) + int history_index = -1; // 当前历史位置(-1 = 新输入) + std::string saved_input; // 浏览历史时暂存当前输入 + bool sidebar_visible = true; // 侧边栏可见性 + std::string model_name = "deepseek-chat";// 当前模型名 +}; + +// 持有上下文指针,用于将回调传递给流式 API +struct AppContext { + GuiState state; + SDL_Window* window = nullptr; + SDL_Renderer* renderer = nullptr; + bool sendPending = false; // 按下 Enter 后设置为 true + std::string streamBuffer; // 存储当前流式消息 +}; + +// ---- 辅助函数 ---- + +// 获取一个逻辑坐标的 SDL 矩形 +static SDL_FRect mkRect(float x, float y, float w, float h) { + SDL_FRect r; + r.x = x; r.y = y; r.w = w; r.h = h; + return r; +} + +// 使用给定的颜色设置绘制颜色 +static void setColor(SDL_Renderer* r, SDL_Color c) { + SDL_SetRenderDrawColor(r, c.r, c.g, c.b, c.a); +} + +// 使用颜色绘制填充矩形 +static void fillRect(SDL_Renderer* r, SDL_FRect rect, SDL_Color c) { + setColor(r, c); + SDL_RenderFillRect(r, &rect); +} + +// 在给定位置(逻辑坐标)绘制一个调试文本字符串,并设定颜色 +static void drawText(SDL_Renderer* r, float x, float y, + const char* text, SDL_Color color) { + setColor(r, color); + SDL_RenderDebugText(r, x, y, text); +} + +// 绘制一个可见的调试文本字符,避免为空字符串调用 SDL_RenderDebugText +static void drawTextSafe(SDL_Renderer* r, float x, float y, + const char* text) { + if (text && text[0] != '\0') { + SDL_RenderDebugText(r, x, y, text); + } +} + +// 计算输入区域的动态高度(根据输入内容中的换行数) +static int calcInputHeight(const std::string& input) { + int lines = 1; + for (char ch : input) { + if (ch == '\n') lines++; + } + return std::min(INPUT_H_MAX, + std::max(INPUT_H_MIN, lines * CHAR_H + PADDING * 2)); +} + +// ---- 文本换行 ---- + +// 将一段文本按字符数换行。保留嵌入的 '\n',并在单词边界处尽可能按字符数换行。 +// 返回逻辑文本行列表。 +static std::vector wrapText(const std::string& text, int maxChars) { + std::vector lines; + + // 首先按嵌入的换行符分割 + std::string remaining = text; + while (!remaining.empty()) { + std::string segment; + auto nlPos = remaining.find('\n'); + if (nlPos != std::string::npos) { + segment = remaining.substr(0, nlPos); + remaining = remaining.substr(nlPos + 1); + } else { + segment = remaining; + remaining.clear(); + } + + // 将片段按单词换行以适应 maxChars + while (!segment.empty()) { + if (static_cast(segment.size()) <= maxChars) { + lines.push_back(segment); + break; + } + // 在 maxChars 位置寻找空格/单词边界 + int splitAt = maxChars; + for (int i = maxChars; i > 0; --i) { + char ch = segment[i]; + if (ch == ' ' || ch == '\t' || ch == ',' || ch == '.' || + ch == ';' || ch == ':' || ch == '!' || ch == '?' || + ch == '>' || ch == ')' || ch == ']' || ch == '}') { + splitAt = i + 1; + break; + } + if ((ch & 0x80) != 0) { + // UTF-8 多字节字符——不在中间分割 + } + } + if (splitAt <= 0 || splitAt > maxChars) { + splitAt = maxChars; + } + + lines.push_back(segment.substr(0, splitAt)); + // 去除下一行的前导空格 + size_t start = splitAt; + while (start < segment.size() && + (segment[start] == ' ' || segment[start] == '\t')) { + ++start; + } + segment = segment.substr(start); + } + } + return lines; +} + +// 计算所有消息的总渲染高度(逻辑像素)。 +// 注意:这使用当前的侧边栏状态来决定宽度;调用者应在侧边栏宽度正确时调用。 +static int calcTotalMsgHeight(GuiState& state, int charsPerLine) { + int totalH = 0; + for (auto& msg : state.messages) { + auto lines = wrapText(msg.content, charsPerLine); + int msgH = static_cast(lines.size()) * CHAR_H + PADDING; + totalH += msgH; + } + return totalH; +} + +// ---- 侧边栏渲染 ---- + +static void renderSidebar(AppContext& ctx) { + GuiState& gs = ctx.state; + SDL_Renderer* r = ctx.renderer; + float sbW = static_cast(SIDEBAR_W); + float sbY = static_cast(TITLE_H); + float sbH = static_cast(LOGICAL_H) - TITLE_H - STATUS_H; + + // 背景 + fillRect(r, mkRect(0, sbY, sbW, sbH), COL_SIDEBAR_BG); + + // 右侧分隔线 + setColor(r, COL_SEP); + SDL_RenderLine(r, sbW, sbY, sbW, sbY + sbH); + + // "Chats" 标题 + drawText(r, static_cast(PADDING), sbY + PADDING, "Chats", COL_WHITE); + + // 会话列表(当前只有 "default") + float listY = sbY + TITLE_H; + // "default" 条目(活动状态高亮) + float itemH = static_cast(CHAR_H + PADDING); + fillRect(r, mkRect(PADDING, listY, sbW - PADDING * 2, itemH), COL_SIDEBAR_ACT); + drawText(r, PADDING * 2.0f, listY + PADDING / 2.0f, "default", COL_AI); + + // "+ New Chat" 按钮(侧边栏底部) + float btnY = sbY + sbH - CHAR_H - PADDING * 2; + float btnH = static_cast(CHAR_H + PADDING); + fillRect(r, mkRect(PADDING, btnY, sbW - PADDING * 2, btnH), COL_SIDEBAR_BTN); + drawText(r, PADDING * 2.0f, btnY + PADDING / 2.0f, "+ New Chat", COL_WHITE); +} + +// ---- 状态栏渲染 ---- + +static void renderStatusBar(AppContext& ctx) { + GuiState& gs = ctx.state; + SDL_Renderer* r = ctx.renderer; + float lw = static_cast(LOGICAL_W); + float lh = static_cast(LOGICAL_H); + float barY = lh - STATUS_H; + + // 背景 + fillRect(r, mkRect(0, barY, lw, static_cast(STATUS_H)), COL_STATUSBAR_BG); + + // 顶部分隔线 + setColor(r, COL_SEP); + SDL_RenderLine(r, 0, barY, lw, barY); + + // 统计消息数(排除系统消息) + int msgCount = 0; + for (auto& msg : gs.messages) { + if (msg.role != ChatMessage::SYSTEM) msgCount++; + } + + // 状态文本:模型名 | 消息条数 | 流式状态 + char buf[256]; + snprintf(buf, sizeof(buf), "%s | %d messages | %s", + gs.model_name.c_str(), msgCount, + gs.streaming ? "streaming" : "ready"); + drawText(r, static_cast(PADDING), + barY + (STATUS_H - CHAR_H) / 2.0f, buf, COL_WHITE); +} + +// ---- 主渲染 ---- + +static void renderFrame(AppContext& ctx) { + GuiState& gs = ctx.state; + SDL_Renderer* r = ctx.renderer; + float lw = static_cast(LOGICAL_W); + float lh = static_cast(LOGICAL_H); + + // ----- 动态布局计算 ----- + int inputH = calcInputHeight(gs.inputBuffer); + float inputY = lh - STATUS_H - inputH; + float msgAreaX = gs.sidebar_visible ? static_cast(SIDEBAR_W) : 0.0f; + float msgAreaW = lw - msgAreaX; + float msgAreaY = static_cast(MSG_Y); + float msgAreaH = inputY - msgAreaY; + int charsPerLine = std::max(20, + static_cast(msgAreaW - PADDING * 2) / CHAR_W); + + // 1. 设置渲染缩放以获得 2 倍文本大小 + SDL_SetRenderScale(r, RENDER_SCALE, RENDER_SCALE); + + // 2. 清除背景 + setColor(r, COL_BG); + SDL_RenderClear(r); + + // 3. 标题栏(全宽) + fillRect(r, mkRect(0, 0, lw, static_cast(TITLE_H)), COL_TITLE_BG); + drawText(r, static_cast(PADDING), static_cast(PADDING), + "dstalk - AI Chat", COL_WHITE); + // 右侧的状态指示器 + const char* status = gs.streaming ? "[streaming...]" : "[ready]"; + float statusW = static_cast(strlen(status)) * CHAR_W + PADDING; + drawText(r, lw - statusW, static_cast(PADDING), status, COL_WHITE); + + // 4. 标题栏分隔线 + setColor(r, COL_SEP); + SDL_RenderLine(r, 0, static_cast(TITLE_H), + lw, static_cast(TITLE_H)); + + // 5. 侧边栏(可折叠) + if (gs.sidebar_visible) { + renderSidebar(ctx); + } + + // 6. 消息区域(带滚动) + SDL_Rect msgClip; + msgClip.x = static_cast(msgAreaX * RENDER_SCALE); + msgClip.y = static_cast(msgAreaY * RENDER_SCALE); + msgClip.w = static_cast(msgAreaW * RENDER_SCALE); + msgClip.h = static_cast(msgAreaH * RENDER_SCALE); + SDL_SetRenderClipRect(r, &msgClip); + + // 计算总消息高度和滚动限制 + int totalMsgH = calcTotalMsgHeight(gs, charsPerLine); + gs.maxScroll = std::max(0.0f, static_cast(totalMsgH) - msgAreaH); + if (gs.scrollOffset < 0) gs.scrollOffset = 0; + if (gs.scrollOffset > gs.maxScroll) gs.scrollOffset = static_cast(gs.maxScroll); + + // 绘制消息:起始 Y 从消息区域顶部减去 scrollOffset + float drawY = msgAreaY - gs.scrollOffset; + float unusedSpace = msgAreaH - static_cast(totalMsgH); + float bottomOffset = std::max(0.0f, unusedSpace); + drawY += bottomOffset; + + for (auto& msg : gs.messages) { + auto lines = wrapText(msg.content, charsPerLine); + int msgH = static_cast(lines.size()) * CHAR_H + PADDING; + + SDL_Color col; + const char* prefix; + switch (msg.role) { + case ChatMessage::USER: col = COL_USER; prefix = "You> "; break; + case ChatMessage::ASSISTANT: col = COL_AI; prefix = "AI> "; break; + default: col = COL_SYS; prefix = "Sys> "; break; + } + + // 如果该消息可见,则绘制 + float msgBottom = drawY + msgH; + if (msgBottom > msgAreaY && drawY < msgAreaY + msgAreaH) { + float lineY = drawY + 2; + for (size_t li = 0; li < lines.size(); ++li) { + if (lineY >= msgAreaY - CHAR_H && lineY <= msgAreaY + msgAreaH) { + if (li == 0) { + std::string line = prefix + lines[li]; + drawTextSafe(r, msgAreaX + static_cast(PADDING), + lineY, line.c_str()); + } else { + drawTextSafe(r, msgAreaX + static_cast(PADDING) + 4 * CHAR_W, + lineY, lines[li].c_str()); + } + } + lineY += CHAR_H; + } + } + + drawY += msgH; + } + + SDL_SetRenderClipRect(r, nullptr); + + // 7. 输入区域分隔线 + setColor(r, COL_SEP); + SDL_RenderLine(r, msgAreaX, inputY, lw, inputY); + + // 8. 输入区域背景 + fillRect(r, mkRect(msgAreaX, inputY, msgAreaW, static_cast(inputH)), COL_INPUT_BG); + + // 9. 输入文本(支持多行显示) + if (!gs.inputBuffer.empty()) { + std::string remaining = gs.inputBuffer; + int lineIdx = 0; + while (!remaining.empty() && lineIdx * CHAR_H < inputH) { + auto nlPos = remaining.find('\n'); + std::string line = (nlPos != std::string::npos) + ? remaining.substr(0, nlPos) : remaining; + float lineY = inputY + static_cast(PADDING) + CHAR_H + + lineIdx * CHAR_H; + drawTextSafe(r, msgAreaX + static_cast(PADDING) + 2, + lineY, line.c_str()); + lineIdx++; + if (nlPos != std::string::npos) { + remaining = remaining.substr(nlPos + 1); + } else { + break; + } + } + } else if (!gs.streaming) { + float textY = inputY + static_cast(PADDING) + CHAR_H; + setColor(r, COL_DIM); + SDL_RenderDebugText(r, msgAreaX + static_cast(PADDING) + 2, + textY, "Type here..."); + } + + // 10. 光标(多行感知) + if (!gs.streaming) { + Uint64 now = SDL_GetTicks(); + if (now - gs.lastCursorBlink > 530) { + gs.cursorVisible = !gs.cursorVisible; + gs.lastCursorBlink = now; + } + if (gs.cursorVisible && gs.cursorPos <= static_cast(gs.inputBuffer.size())) { + // 计算光标所在行和列 + int curLine = 0; + int charsBeforeLine = 0; + for (int i = 0; i < gs.cursorPos; i++) { + if (gs.inputBuffer[i] == '\n') { + curLine++; + charsBeforeLine = i + 1; + } + } + int colInLine = gs.cursorPos - charsBeforeLine; + float cursorX = msgAreaX + static_cast(PADDING) + 2 + + colInLine * CHAR_W; + float cursorY = inputY + static_cast(PADDING) + + curLine * CHAR_H; + setColor(r, COL_CURSOR); + SDL_RenderLine(r, cursorX, cursorY, + cursorX, cursorY + CHAR_H); + } + } + + // 11. 发送/停止按钮 + float btnW = 5 * CHAR_W + PADDING; + float btnH = CHAR_H + PADDING; + float btnX = lw - btnW - PADDING; + float btnY = inputY + (inputH - btnH) / 2.0f; + fillRect(r, mkRect(btnX, btnY, btnW, btnH), COL_BTN); + float btnTextX = btnX + PADDING / 2.0f; + float btnTextY = btnY + PADDING / 2.0f; + if (gs.streaming) { + drawText(r, btnTextX, btnTextY, "[Stop]", COL_WHITE); + } else { + drawText(r, btnTextX, btnTextY, "[Send]", COL_WHITE); + } + + // 12. 状态栏 + renderStatusBar(ctx); + + // 13. Present + SDL_RenderPresent(r); +} + +// ---- 事件处理 ---- + +// 尝试发送当前输入缓冲区的内容;返回 true 表示消息已排队 +static bool trySendMessage(GuiState& gs) { + std::string text = gs.inputBuffer; + // 去除前导/尾随空白,但保留内容空白 + size_t start = text.find_first_not_of(" \t\r\n"); + size_t end = text.find_last_not_of(" \t\r\n"); + if (start == std::string::npos) return false; // 空输入 + text = text.substr(start, end - start + 1); + if (text.empty()) return false; + + // 保存原始输入到历史(最多保留 20 条) + gs.input_history.push_back(gs.inputBuffer); + if (gs.input_history.size() > 20) + gs.input_history.erase(gs.input_history.begin()); + gs.history_index = -1; + + gs.messages.push_back(ChatMessage(ChatMessage::USER, text)); + gs.inputBuffer.clear(); + gs.cursorPos = 0; + return true; +} + +// 如果输入区域中的 Send/Stop 按钮被点击,返回 true +static bool isSendButtonHit(AppContext& ctx, float physX, float physY) { + float lx = physX / RENDER_SCALE; + float ly = physY / RENDER_SCALE; + + int inputH = calcInputHeight(ctx.state.inputBuffer); + float inputY = LOGICAL_H - STATUS_H - inputH; + + float btnW = 5 * CHAR_W + PADDING; + float btnH = CHAR_H + PADDING; + float btnX = LOGICAL_W - btnW - PADDING; + float btnY = inputY + (inputH - btnH) / 2.0f; + + return lx >= btnX && lx <= btnX + btnW && + ly >= btnY && ly <= btnY + btnH; +} + +// ---- 流式回调 ---- + +static int streamTokenCallback(const char* token, void* userdata) { + AppContext* ctx = static_cast(userdata); + GuiState& gs = ctx->state; + + if (token && token[0] != '\0') { + ctx->streamBuffer += token; + if (!gs.messages.empty() && + gs.messages.back().role == ChatMessage::ASSISTANT) { + gs.messages.back().content = ctx->streamBuffer; + } + } + + // 泵送事件以保持窗口响应 + SDL_PumpEvents(); + + SDL_Event ev; + while (SDL_PollEvent(&ev)) { + if (ev.type == SDL_EVENT_QUIT) { + gs.running = false; + gs.streaming = false; + return 1; + } + if (ev.type == SDL_EVENT_MOUSE_BUTTON_DOWN && + ev.button.button == SDL_BUTTON_LEFT) { + if (isSendButtonHit(*ctx, ev.button.x, ev.button.y)) { + gs.streaming = false; + return 1; + } + } + if (ev.type == SDL_EVENT_MOUSE_WHEEL) { + gs.scrollOffset -= static_cast(ev.wheel.y * CHAR_H * 3); + } + if (ev.type == SDL_EVENT_KEY_DOWN && + ev.key.key == SDLK_ESCAPE) { + gs.streaming = false; + return 1; + } + } + + // 重新渲染以显示进度的令牌 + gs.scrollOffset = 0; + renderFrame(*ctx); + + return 0; +} + +// ---- 主事件处理函数 ---- + +static void processEvent(AppContext& ctx, SDL_Event& ev) { + GuiState& gs = ctx.state; + + switch (ev.type) { + case SDL_EVENT_QUIT: + gs.running = false; + break; + + case SDL_EVENT_KEY_DOWN: { + SDL_Keycode key = ev.key.key; + SDL_Keymod mod = ev.key.mod; + bool ctrl = (mod & SDL_KMOD_CTRL) != 0; + bool shift = (mod & SDL_KMOD_SHIFT) != 0; + + if (gs.streaming) { + // 流式传输期间,按 Escape 键取消 + if (key == SDLK_ESCAPE) { + gs.streaming = false; + } + break; + } + + // Tab 切换侧边栏显示/隐藏 + if (key == SDLK_TAB) { + gs.sidebar_visible = !gs.sidebar_visible; + break; + } + + // 输入历史浏览(↑/↓) + if (key == SDLK_UP && !gs.input_history.empty()) { + if (gs.history_index == -1) { + // 首次进入历史浏览,保存当前输入 + gs.saved_input = gs.inputBuffer; + gs.history_index = static_cast(gs.input_history.size()) - 1; + } else if (gs.history_index > 0) { + gs.history_index--; + } + gs.inputBuffer = gs.input_history[gs.history_index]; + gs.cursorPos = static_cast(gs.inputBuffer.size()); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + break; + } + + if (key == SDLK_DOWN) { + if (gs.history_index >= 0) { + gs.history_index--; + if (gs.history_index >= 0) { + gs.inputBuffer = gs.input_history[gs.history_index]; + } else { + // 回到新输入,恢复暂存的输入 + gs.inputBuffer = gs.saved_input; + } + gs.cursorPos = static_cast(gs.inputBuffer.size()); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + } + + switch (key) { + case SDLK_RETURN: + case SDLK_KP_ENTER: + if (shift) { + // Shift+Enter:插入换行符(不发送) + gs.inputBuffer.insert(gs.cursorPos, "\n"); + gs.cursorPos++; + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } else { + ctx.sendPending = true; + } + break; + case SDLK_BACKSPACE: + if (!gs.inputBuffer.empty() && gs.cursorPos > 0) { + gs.inputBuffer.erase(gs.cursorPos - 1, 1); + gs.cursorPos--; + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + case SDLK_DELETE: + if (gs.cursorPos < static_cast(gs.inputBuffer.size())) { + gs.inputBuffer.erase(gs.cursorPos, 1); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + case SDLK_LEFT: + if (gs.cursorPos > 0) { + gs.cursorPos--; + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + case SDLK_RIGHT: + if (gs.cursorPos < static_cast(gs.inputBuffer.size())) { + gs.cursorPos++; + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + case SDLK_HOME: + gs.cursorPos = 0; + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + break; + case SDLK_END: + gs.cursorPos = static_cast(gs.inputBuffer.size()); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + break; + case SDLK_V: + if (ctrl) { + // Ctrl+V:从剪贴板粘贴 + if (SDL_HasClipboardText()) { + char* clip = SDL_GetClipboardText(); + if (clip) { + gs.inputBuffer.insert(gs.cursorPos, clip); + gs.cursorPos += static_cast(strlen(clip)); + SDL_free(clip); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + } + } + break; + case SDLK_C: + if (ctrl) { + // Ctrl+C:复制到剪贴板(复制最后一条助手消息) + if (!gs.messages.empty()) { + for (int i = static_cast(gs.messages.size()) - 1; i >= 0; --i) { + if (gs.messages[i].role != ChatMessage::USER) { + SDL_SetClipboardText(gs.messages[i].content.c_str()); + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "[Copied last AI response to clipboard]")); + gs.scrollOffset = 0; + break; + } + } + } + } + break; + case SDLK_L: + if (ctrl) { + // Ctrl+L:清除聊天 + if (g_session_svc) g_session_svc->clear(); + gs.messages.clear(); + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Session cleared.")); + gs.scrollOffset = 0; + } + break; + case SDLK_S: + if (ctrl) { + // Ctrl+S:保存会话 + if (g_session_svc && g_session_svc->save("session.json") == 0) { + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Session saved to session.json")); + } else { + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Failed to save session.")); + } + gs.scrollOffset = 0; + } + break; + case SDLK_O: + if (ctrl) { + // Ctrl+O:加载会话 + if (g_session_svc && g_session_svc->load("session.json") == 0) { + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Session loaded from session.json")); + } else { + gs.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Failed to load session.")); + } + gs.scrollOffset = 0; + } + break; + default: + break; + } + break; + } + + case SDL_EVENT_TEXT_INPUT: + if (!gs.streaming) { + // 将文本插入光标位置 + gs.inputBuffer.insert(gs.cursorPos, ev.text.text); + gs.cursorPos += static_cast(strlen(ev.text.text)); + gs.cursorVisible = true; + gs.lastCursorBlink = SDL_GetTicks(); + } + break; + + case SDL_EVENT_MOUSE_BUTTON_DOWN: + if (ev.button.button == SDL_BUTTON_LEFT) { + if (isSendButtonHit(ctx, ev.button.x, ev.button.y)) { + if (gs.streaming) { + gs.streaming = false; + } else { + ctx.sendPending = true; + } + } + } + break; + + case SDL_EVENT_MOUSE_WHEEL: + if (ev.wheel.y != 0) { + gs.scrollOffset -= static_cast(ev.wheel.y * CHAR_H * 3); + } + break; + + case SDL_EVENT_WINDOW_RESIZED: { + // 当窗口大小改变时,不更新我们的常量——保持 1024x768 的逻辑尺寸。 + // SDL 将自动缩放输出。 + break; + } + + default: + break; + } +} + +// ---- 入口 ---- + +int main(int argc, char* argv[]) { + // ----- 初始化 dstalk ----- if (dstalk_init(nullptr) != 0) { - std::fprintf(stderr, "[dstalk] 初始化失败\n"); + std::fprintf(stderr, "[dstalk] Init failed\n"); return 1; } + const char* ai_provider = dstalk_config_get("ai.provider"); + if (!ai_provider) ai_provider = "ai.deepseek"; + g_ai_svc = static_cast(dstalk_service_query(ai_provider, 1)); + g_session_svc = static_cast(dstalk_service_query("session", 1)); + if (!g_ai_svc) dstalk_log(3, "AI service not found (check plugins directory)"); + if (!g_session_svc) dstalk_log(3, "Session service not found"); + + // ----- 初始化 SDL ----- if (!SDL_Init(SDL_INIT_VIDEO)) { - std::fprintf(stderr, "[dstalk] SDL 初始化失败: %s\n", SDL_GetError()); - dstalk_destroy(); + std::fprintf(stderr, "[dstalk] SDL init failed: %s\n", SDL_GetError()); + dstalk_shutdown(); return 1; } SDL_Window* window = SDL_CreateWindow( - "dstalk", 1024, 768, SDL_WINDOW_RESIZABLE); + "dstalk - AI Chat", WINDOW_W, WINDOW_H, + SDL_WINDOW_RESIZABLE); if (!window) { - std::fprintf(stderr, "[dstalk] 窗口创建失败: %s\n", SDL_GetError()); + std::fprintf(stderr, "[dstalk] Window creation failed: %s\n", SDL_GetError()); SDL_Quit(); - dstalk_destroy(); + dstalk_shutdown(); return 1; } SDL_Renderer* renderer = SDL_CreateRenderer(window, nullptr); if (!renderer) { + std::fprintf(stderr, "[dstalk] Renderer creation failed: %s\n", SDL_GetError()); SDL_DestroyWindow(window); SDL_Quit(); - dstalk_destroy(); + dstalk_shutdown(); return 1; } - bool running = true; + // 启用文本输入事件 + SDL_StartTextInput(window); + + // ----- 应用程序状态 ----- + AppContext ctx; + ctx.window = window; + ctx.renderer = renderer; + ctx.state.messages.push_back(ChatMessage( + ChatMessage::SYSTEM, "Welcome to dstalk! Type a message and press Enter to chat. " + "Ctrl+L clear, Ctrl+S save, Ctrl+O load. " + "Shift+Enter for newline, Up/Down for history, Tab toggle sidebar.")); + + // ----- 主循环 ----- SDL_Event event; - while (running) { + while (ctx.state.running) { + // 处理所有待处理事件 while (SDL_PollEvent(&event)) { - if (event.type == SDL_EVENT_QUIT) { - running = false; + processEvent(ctx, event); + if (!ctx.state.running) break; + } + if (!ctx.state.running) break; + + // 检查待发送的消息 + if (ctx.sendPending && !ctx.state.streaming) { + ctx.sendPending = false; + if (trySendMessage(ctx.state)) { + // 开始流式传输 + ctx.state.streaming = true; + ctx.streamBuffer.clear(); + // 为流式响应添加占位消息 + ctx.state.messages.push_back( + ChatMessage(ChatMessage::ASSISTANT, "")); + ctx.state.scrollOffset = 0; + + // 对最后一条消息调用流式 API(通过插件服务 vtable) + std::string& userMsg = + ctx.state.messages[ctx.state.messages.size() - 2].content; + int rc = -1; + if (g_ai_svc) { + int hcount = 0; + const dstalk_message_t* history = g_session_svc + ? g_session_svc->history(&hcount) : nullptr; + dstalk_chat_result_t result = g_ai_svc->chat_stream( + history, hcount, userMsg.c_str(), + streamTokenCallback, &ctx); + rc = result.ok ? 0 : -1; + g_ai_svc->free_result(&result); + } + + // 流式传输完成(或被取消) + if (rc != 0) { + if (!ctx.state.messages.empty() && + ctx.state.messages.back().role == ChatMessage::ASSISTANT) { + if (ctx.state.messages.back().content.empty()) { + ctx.state.messages.back().content = "[Error or cancelled]"; + } + } + } + ctx.state.streaming = false; } } - SDL_SetRenderDrawColor(renderer, 0x1E, 0x1E, 0x2E, 0xFF); - SDL_RenderClear(renderer); - SDL_RenderPresent(renderer); + + // 渲染当前帧 + renderFrame(ctx); + + // 短暂休眠以降低 CPU 使用率 + SDL_Delay(16); // ~60 FPS } + // ----- 清理 ----- + SDL_StopTextInput(window); SDL_DestroyRenderer(renderer); SDL_DestroyWindow(window); SDL_Quit(); - dstalk_destroy(); + dstalk_shutdown(); + return 0; } diff --git a/examples/example_plugin/example_plugin.cpp b/examples/example_plugin/example_plugin.cpp new file mode 100644 index 0000000..c288d82 --- /dev/null +++ b/examples/example_plugin/example_plugin.cpp @@ -0,0 +1,151 @@ +/* + * example_plugin.cpp - Minimal dstalk plugin demonstrating the API contract. + * + * Build instructions (conceptual): + * + * Linux / macOS: + * g++ -std=c++20 -shared -fPIC -fvisibility=hidden \ + * -I \ + * -o example_plugin.so example_plugin.cpp + * + * Windows (MSVC): + * cl /std:c++20 /LD /EHsc \ + * /I \ + * /Fe:example_plugin.dll example_plugin.cpp + * + * The resulting `.so` / `.dylib` / `.dll` can be loaded with: + * + * int id = dstalk_plugin_load("./example_plugin.so"); + */ + +#include "dstalk/dstalk_host.h" + +#include /* fprintf */ +#include /* malloc, free */ +#include /* strlen, strcmp */ + +/* ------------------------------------------------------------------ + * Private state (one instance per plugin load) + * ------------------------------------------------------------------ + * + * In a more complex plugin this struct would hold open database + * connections, configuration, etc. + */ + +struct ExampleState { + int call_count; +}; + +/* ------------------------------------------------------------------ + * Stored host API table so callbacks can use host services. + * ------------------------------------------------------------------ */ + +static const dstalk_host_api_t* g_host = nullptr; + +static ExampleState g_state; /* not heap-allocated: stays valid + while the library is mapped */ + +/* ------------------------------------------------------------------ + * on_init (was on_load) + * ------------------------------------------------------------------ */ + +static int my_on_init(const dstalk_host_api_t* host) +{ + g_host = host; + g_state.call_count = 0; + + /* TODO: real plugins would initialise resources here: + * - parse a plugin-specific config file via host->config_get + * - open a log file + * - connect to a local service + * - register services via host->register_service + * + * Return non-zero to signal a fatal initialisation error to the + * host, which will then unload the plugin immediately. + */ + + if (host) { + host->log(DSTALK_LOG_INFO, "[example-plugin] loaded (v1.0.0)"); + } else { + std::fprintf(stderr, "[example-plugin] loaded (v1.0.0)\n"); + } + return 0; +} + +/* ------------------------------------------------------------------ + * on_shutdown (was on_unload) + * ------------------------------------------------------------------ */ + +static void my_on_shutdown(void) +{ + /* TODO: release any resources allocated in on_init. After this + * function returns the host will unmap the shared library. */ + + if (g_host) { + g_host->log(DSTALK_LOG_INFO, "[example-plugin] unloaded (%d events processed)", + g_state.call_count); + } else { + std::fprintf(stderr, + "[example-plugin] unloaded (%d callbacks processed)\n", + g_state.call_count); + } +} + +/* ------------------------------------------------------------------ + * on_event (was on_message) + * ------------------------------------------------------------------ */ + +static void my_on_event(int event_type, const void* data) +{ + if (event_type == DSTALK_EVENT_MESSAGE && data) { + const auto* msg = static_cast(data); + g_state.call_count++; + + /* A real plugin might: + * - log the conversation to a file + * - apply content moderation + * - translate messages on the fly + * - enrich messages with external data + */ + + if (g_host) { + g_host->log(DSTALK_LOG_DEBUG, "[example-plugin] message | role=%-9s len=%zu", + msg->role, std::strlen(msg->content)); + } else { + std::fprintf(stderr, + "[example-plugin] message | role=%-9s len=%zu\n", + msg->role, std::strlen(msg->content)); + } + } + /* Other event types (DSTALK_EVENT_SESSION_CLEAR, DSTALK_EVENT_CONFIG_CHANGED, + DSTALK_EVENT_PLUGIN_LOADED, DSTALK_EVENT_PLUGIN_UNLOADED, DSTALK_EVENT_CUSTOM+) + are silently ignored by this minimal plugin. */ +} + +/* ------------------------------------------------------------------ + * Plugin descriptor (static -- lives for the lifetime of the .so) + * ------------------------------------------------------------------ */ + +static dstalk_plugin_info_t g_info = { + /* .name = */ "example-plugin", + /* .version = */ "1.0.0", + /* .description = */ "An example plugin for dstalk", + /* .api_version = */ DSTALK_API_VERSION, + /* .dependencies = */ {nullptr}, + /* .on_init = */ my_on_init, + /* .on_shutdown = */ my_on_shutdown, + /* .on_event = */ my_on_event, +}; + +/* ------------------------------------------------------------------ + * Mandatory entry point + * ------------------------------------------------------------------ + * + * The host looks for this symbol via dlsym / GetProcAddress. + * It MUST be declared extern "C" so the name is not mangled. + */ + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) +{ + return &g_info; +} diff --git a/plugins/CMakeLists.txt b/plugins/CMakeLists.txt new file mode 100644 index 0000000..386e243 --- /dev/null +++ b/plugins/CMakeLists.txt @@ -0,0 +1,18 @@ +# ============================================================ +# 插件目录 — 所有功能插件 +# ============================================================ + +# 基础插件(无外部服务依赖) +add_subdirectory(config) +add_subdirectory(file-io) +add_subdirectory(network) + +# 中间插件(依赖基础插件) +add_subdirectory(session) +add_subdirectory(context) + +# 上层插件(依赖中间插件) +add_subdirectory(deepseek) +add_subdirectory(anthropic) +add_subdirectory(tools) +add_subdirectory(lsp) diff --git a/plugins/anthropic/CMakeLists.txt b/plugins/anthropic/CMakeLists.txt new file mode 100644 index 0000000..cf3d7b5 --- /dev/null +++ b/plugins/anthropic/CMakeLists.txt @@ -0,0 +1,32 @@ +cmake_minimum_required(VERSION 3.21) + +# ============================================================ +# plugin-anthropic — Anthropic Claude AI 服务 +# 依赖: http 服务 (查询), config 服务 (查询) +# ============================================================ + +add_library(plugin-anthropic SHARED + src/anthropic_plugin.cpp +) + +target_include_directories(plugin-anthropic PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-anthropic PRIVATE dstalk) + +# Boost.JSON 用于构建/解析请求和响应 +find_package(Boost REQUIRED CONFIG) +target_link_libraries(plugin-anthropic PRIVATE boost::boost) + +target_compile_definitions(plugin-anthropic PRIVATE + BOOST_ALL_NO_LIB + BOOST_ERROR_CODE_HEADER_ONLY + BOOST_JSON_HEADER_ONLY +) + +set_target_properties(plugin-anthropic PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/anthropic/src/anthropic_plugin.cpp b/plugins/anthropic/src/anthropic_plugin.cpp new file mode 100644 index 0000000..a222508 --- /dev/null +++ b/plugins/anthropic/src/anthropic_plugin.cpp @@ -0,0 +1,486 @@ +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include + +namespace json = boost::json; + +// ============================================================================ +// 全局指针 +// ============================================================================ +static const dstalk_host_api_t* g_host = nullptr; +static dstalk_http_service_t* g_http = nullptr; +static dstalk_config_service_t* g_config = nullptr; + +// ============================================================================ +// 配置数据 +// ============================================================================ +struct PluginConfig { + std::string provider; + std::string base_url; + std::string api_key; + std::string model; + int max_tokens = 4096; + double temperature = 0.7; +}; +static PluginConfig g_cfg; + +// ============================================================================ +// 辅助:提取 host / target +// ============================================================================ +static bool extract_host_port(const std::string& url, + std::string& scheme_out, std::string& host_out, + std::string& port_out, std::string& target_out) +{ + size_t scheme_end = url.find("://"); + if (scheme_end == std::string::npos) return false; + scheme_out = url.substr(0, scheme_end); + std::string rest = url.substr(scheme_end + 3); + size_t slash = rest.find('/'); + std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest; + target_out = (slash != std::string::npos) ? rest.substr(slash) : "/"; + size_t colon = authority.rfind(':'); + if (colon != std::string::npos) { + host_out = authority.substr(0, colon); + port_out = authority.substr(colon + 1); + } else { + host_out = authority; + port_out = (scheme_out == "https") ? "443" : "80"; + } + return true; +} + +// ============================================================================ +// 构建 Anthropic headers JSON +// ============================================================================ +static std::string build_headers_json() +{ + json::object h; + h["x-api-key"] = g_cfg.api_key; + h["anthropic-version"] = "2023-06-01"; + return json::serialize(h); +} + +// ============================================================================ +// 构建 Anthropic JSON 请求体 +// ============================================================================ +static std::string build_request_json( + const dstalk_message_t* history, int history_len, + const std::string& user_input, + bool stream) +{ + json::object root; + root["model"] = g_cfg.model; + root["max_tokens"] = g_cfg.max_tokens; + root["stream"] = stream; + + // 提取 system 消息作为顶层字段 + std::string system_prompt; + json::array msgs; + + for (int i = 0; i < history_len; ++i) { + const auto& m = history[i]; + if (m.role && std::strcmp(m.role, "system") == 0) { + if (!system_prompt.empty()) system_prompt += "\n\n"; + system_prompt += m.content ? m.content : ""; + continue; + } + json::object obj; + obj["role"] = m.role ? m.role : ""; + obj["content"] = m.content ? m.content : ""; + msgs.push_back(obj); + } + + // 追加当前用户输入 + { + json::object obj; + obj["role"] = "user"; + obj["content"] = user_input; + msgs.push_back(obj); + } + + root["messages"] = msgs; + + if (!system_prompt.empty()) { + root["system"] = system_prompt; + } + + if (g_cfg.temperature >= 0.0 && g_cfg.temperature <= 1.0) { + root["temperature"] = g_cfg.temperature; + } + + return json::serialize(root); +} + +// ============================================================================ +// 解析非流式响应 +// ============================================================================ +static void parse_response(const char* body, int http_status, + dstalk_chat_result_t& r) +{ + r.http_status = http_status; + + if (http_status < 200 || http_status >= 300) { + r.ok = 0; + try { + auto jv = json::parse(body ? body : "{}"); + auto obj = jv.as_object(); + if (obj.contains("error")) { + auto err = obj["error"].as_object(); + r.error = g_host->strdup( + json::value_to(err["message"]).c_str()); + } + } catch (...) { + std::string msg = "HTTP " + std::to_string(http_status); + r.error = g_host->strdup(msg.c_str()); + } + if (!r.error) { + std::string msg = "HTTP " + std::to_string(http_status); + r.error = g_host->strdup(msg.c_str()); + } + r.content = nullptr; + r.tool_calls_json = nullptr; + return; + } + + try { + auto jv = json::parse(body ? body : "{}"); + auto obj = jv.as_object(); + auto content = obj["content"].as_array(); + if (!content.empty()) { + // 取第一个 text block + for (const auto& block : content) { + auto bobj = block.as_object(); + if (bobj.contains("type") && + json::value_to(bobj["type"]) == "text") { + std::string text = json::value_to(bobj["text"]); + r.content = g_host->strdup(text.c_str()); + r.ok = 1; + r.error = nullptr; + r.tool_calls_json = nullptr; + return; + } + } + r.ok = 0; + r.error = g_host->strdup("no text content block found"); + } else { + r.ok = 0; + r.error = g_host->strdup("empty response"); + } + r.content = nullptr; + r.tool_calls_json = nullptr; + } catch (std::exception& e) { + r.ok = 0; + std::string msg = std::string("json parse: ") + e.what(); + r.error = g_host->strdup(msg.c_str()); + r.content = nullptr; + r.tool_calls_json = nullptr; + } catch (...) { + r.ok = 0; + r.error = g_host->strdup("json parse error"); + r.content = nullptr; + r.tool_calls_json = nullptr; + } +} + +// ============================================================================ +// SSE 事件解析(Anthropic 格式: event/content_block_delta) +// ============================================================================ + +// 状态机:记录当前正在处理的事件类型 +// 简化版:直接从 data: 行解析,不依赖 event: 行 +static bool parse_sse_data(const std::string& data, std::string& token_out) +{ + try { + auto jv = json::parse(data); + auto obj = jv.as_object(); + + auto* type_ptr = obj.if_contains("type"); + if (!type_ptr || !type_ptr->is_string()) return false; + std::string type = json::value_to(*type_ptr); + + if (type == "content_block_delta") { + auto* delta = obj.if_contains("delta"); + if (!delta || !delta->is_object()) return false; + auto& dobj = delta->as_object(); + + auto* dtype = dobj.if_contains("type"); + if (!dtype || !dtype->is_string()) return false; + std::string delta_type = json::value_to(*dtype); + + if (delta_type == "text_delta") { + auto* text = dobj.if_contains("text"); + if (text && text->is_string()) { + token_out = json::value_to(*text); + return true; + } + } + } else if (type == "message_stop") { + token_out.clear(); + return true; // 流结束 + } + // 忽略: message_start, content_block_start, content_block_stop, ping, message_delta + } catch (...) { + // 解析失败忽略 + } + return false; +} + +// ============================================================================ +// configure +// ============================================================================ +static int my_configure(const char* provider, const char* base_url, + const char* api_key, const char* model, + int max_tokens, double temperature) +{ + if (provider) g_cfg.provider = provider; + if (base_url) g_cfg.base_url = base_url; + if (api_key) g_cfg.api_key = api_key; + if (model) g_cfg.model = model; + g_cfg.max_tokens = max_tokens; + g_cfg.temperature = temperature; + + if (g_host) { + g_host->log(DSTALK_LOG_INFO, + "[anthropic] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f", + g_cfg.model.c_str(), g_cfg.base_url.c_str(), + g_cfg.max_tokens, g_cfg.temperature); + } + return 0; +} + +// ============================================================================ +// chat +// ============================================================================ +static dstalk_chat_result_t my_chat( + const dstalk_message_t* history, int history_len, + const char* user_input, + const char* /*tools_json*/) +{ + dstalk_chat_result_t r = {}; + r.ok = 0; + + if (!g_http) { + r.error = g_host->strdup("http service not available"); + return r; + } + + std::string scheme, host, port, target; + extract_host_port(g_cfg.base_url, scheme, host, port, target); + std::string target_path = target + "/v1/messages"; + + std::string body = build_request_json(history, history_len, + user_input ? user_input : "", false); + + std::string headers_json = build_headers_json(); + + char* response_body = nullptr; + int status_code = 0; + + int ret = g_http->post_json( + host.c_str(), port.c_str(), target_path.c_str(), body.c_str(), + headers_json.c_str(), &response_body, &status_code); + + if (ret != 0) { + r.error = g_host->strdup("http request failed"); + return r; + } + + parse_response(response_body, status_code, r); + + if (response_body) { + g_host->free(response_body); + } + return r; +} + +// ============================================================================ +// chat_stream +// ============================================================================ + +struct StreamContext { + const dstalk_host_api_t* host; + dstalk_stream_cb user_cb; + void* userdata; + std::string accumulated; + bool saw_data_line = false; +}; + +// 行回调 +static int sse_line_callback(const char* line, void* userdata) +{ + auto* ctx = static_cast(userdata); + if (!line || !line[0]) return 1; // 空行,继续 + + std::string line_str(line); + + // SSE 格式: "data: " + if (line_str.rfind("data: ", 0) == 0) { + std::string data = line_str.substr(6); + std::string token; + if (parse_sse_data(data, token)) { + ctx->saw_data_line = true; + if (token.empty()) { + // message_stop + return 0; + } + ctx->accumulated += token; + if (ctx->user_cb) { + return ctx->user_cb(token.c_str(), ctx->userdata); + } + } + } + // "event: ..." 行和其他 -> 忽略 + return 1; +} + +static dstalk_chat_result_t my_chat_stream( + const dstalk_message_t* history, int history_len, + const char* user_input, + dstalk_stream_cb cb, void* userdata) +{ + dstalk_chat_result_t r = {}; + r.ok = 0; + + if (!g_http) { + r.error = g_host->strdup("http service not available"); + return r; + } + + std::string scheme, host, port, target; + extract_host_port(g_cfg.base_url, scheme, host, port, target); + std::string target_path = target + "/v1/messages"; + + std::string body = build_request_json(history, history_len, + user_input ? user_input : "", true); + + std::string headers_json = build_headers_json(); + + StreamContext ctx; + ctx.host = g_host; + ctx.user_cb = cb; + ctx.userdata = userdata; + ctx.saw_data_line = false; + + char* response_body = nullptr; + int status_code = 0; + + int ret = g_http->post_stream( + host.c_str(), port.c_str(), target_path.c_str(), body.c_str(), + headers_json.c_str(), + sse_line_callback, &ctx, + &response_body, &status_code); + + r.http_status = status_code; + + // 检查错误状态 + if (status_code < 200 || status_code >= 300) { + r.ok = 0; + if (response_body && response_body[0]) { + try { + auto jv = json::parse(response_body); + auto obj = jv.as_object(); + if (obj.contains("error")) { + auto err = obj["error"].as_object(); + r.error = g_host->strdup( + json::value_to(err["message"]).c_str()); + } + } catch (...) {} + } + if (!r.error) { + if (status_code <= 0) + r.error = g_host->strdup("transport error"); + else + r.error = g_host->strdup( + ("HTTP " + std::to_string(status_code)).c_str()); + } + if (response_body) g_host->free(response_body); + r.content = nullptr; + r.tool_calls_json = nullptr; + return r; + } + + if (response_body) g_host->free(response_body); + + if (ctx.accumulated.empty() && !ctx.saw_data_line) { + r.ok = 0; + r.error = g_host->strdup("no content received"); + r.content = nullptr; + r.tool_calls_json = nullptr; + } else { + r.ok = 1; + r.error = nullptr; + r.content = g_host->strdup(ctx.accumulated.c_str()); + r.tool_calls_json = nullptr; + } + return r; +} + +// ============================================================================ +// free_result +// ============================================================================ +static void my_free_result(dstalk_chat_result_t* result) +{ + if (!result || !g_host) return; + if (result->content) { g_host->free((void*)result->content); result->content = nullptr; } + if (result->error) { g_host->free((void*)result->error); result->error = nullptr; } + if (result->tool_calls_json) { g_host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; } +} + +// ============================================================================ +// 服务 vtable +// ============================================================================ +static dstalk_ai_service_t g_service = { + &my_configure, + &my_chat, + &my_chat_stream, + &my_free_result, +}; + +// ============================================================================ +// 生命周期 +// ============================================================================ +static int on_init(const dstalk_host_api_t* host) +{ + g_host = host; + g_http = (dstalk_http_service_t*)host->query_service("http", 1); + g_config = (dstalk_config_service_t*)host->query_service("config", 1); + + if (!g_http) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, "[anthropic] http service not found"); + return -1; + } + + if (g_host) g_host->log(DSTALK_LOG_INFO, "[anthropic] initializing Anthropic AI plugin"); + + return host->register_service("ai.anthropic", 1, &g_service); +} + +static void on_shutdown() +{ + if (g_host) g_host->log(DSTALK_LOG_INFO, "[anthropic] shutdown"); + g_http = nullptr; + g_config = nullptr; + g_host = nullptr; +} + +// ============================================================================ +// 插件描述符 +// ============================================================================ +static dstalk_plugin_info_t g_info = { + /* .name = */ "anthropic-ai", + /* .version = */ "1.0.0", + /* .description = */ "Anthropic Claude AI provider (Messages API)", + /* .api_version = */ DSTALK_API_VERSION, + /* .dependencies = */ { "http", "config", NULL }, + /* .on_init = */ on_init, + /* .on_shutdown = */ on_shutdown, + /* .on_event = */ nullptr, +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) +{ + return &g_info; +} diff --git a/plugins/config/CMakeLists.txt b/plugins/config/CMakeLists.txt new file mode 100644 index 0000000..ddf16a3 --- /dev/null +++ b/plugins/config/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(plugin-config SHARED src/config_plugin.cpp) + +target_include_directories(plugin-config PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-config PRIVATE dstalk) + +set_target_properties(plugin-config PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/config/src/config_plugin.cpp b/plugins/config/src/config_plugin.cpp new file mode 100644 index 0000000..ee712b4 --- /dev/null +++ b/plugins/config/src/config_plugin.cpp @@ -0,0 +1,146 @@ +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include +#include +#include +#include + +// ============================================================ +// ConfigStore - independent TOML key-value store +// ============================================================ +namespace { + +class ConfigStore { +public: + int load_file(const char* path) { + if (!path) return -1; + + std::ifstream file(path); + if (!file.is_open()) return -1; + + std::stringstream ss; + ss << file.rdbuf(); + std::string data = ss.str(); + + std::string current_section; + size_t pos = 0; + while (pos < data.size()) { + while (pos < data.size() && (data[pos] == ' ' || data[pos] == '\t')) + pos++; + if (pos >= data.size()) break; + + size_t nl = data.find('\n', pos); + std::string line = (nl != std::string::npos) + ? data.substr(pos, nl - pos) : data.substr(pos); + pos = (nl != std::string::npos) ? nl + 1 : data.size(); + + while (!line.empty() && (line.back() == '\r' || line.back() == ' ')) + line.pop_back(); + + if (line.empty() || line[0] == '#') continue; + + if (line[0] == '[' && line.back() == ']') { + current_section = line.substr(1, line.size() - 2); + continue; + } + + size_t eq = line.find('='); + if (eq == std::string::npos) continue; + + std::string key = line.substr(0, eq); + while (!key.empty() && key.back() == ' ') key.pop_back(); + if (key.empty()) continue; + + std::string val = line.substr(eq + 1); + while (!val.empty() && (val.front() == ' ' || val.front() == '\t')) + val.erase(0, 1); + if (val.size() >= 2 && val.front() == '"' && val.back() == '"') + val = val.substr(1, val.size() - 2); + + std::lock_guard lock(mutex_); + std::string full_key = current_section.empty() + ? key : current_section + "." + key; + data_[full_key] = val; + } + + return 0; + } + + const char* get(const char* key) const { + if (!key) return nullptr; + std::lock_guard lock(mutex_); + auto it = data_.find(key); + if (it == data_.end()) return nullptr; + return it->second.c_str(); + } + + int set(const char* key, const char* value) { + if (!key || !value) return -1; + std::lock_guard lock(mutex_); + data_[key] = value; + return 0; + } + +private: + mutable std::mutex mutex_; + std::unordered_map data_; +}; + +} // anonymous namespace + +// ============================================================ +// Global state +// ============================================================ +static const dstalk_host_api_t* g_host = nullptr; +static ConfigStore g_config; + +// ============================================================ +// Service implementations +// ============================================================ +static const char* config_get(const char* key) { + return g_config.get(key); +} + +static int config_set(const char* key, const char* value) { + return g_config.set(key, value); +} + +static int config_load_file(const char* path) { + return g_config.load_file(path); +} + +static dstalk_config_service_t g_service = { + config_get, + config_set, + config_load_file +}; + +// ============================================================ +// Plugin lifecycle +// ============================================================ +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + return host->register_service("config", 1, &g_service); +} + +static void on_shutdown() { + // nothing to clean up +} + +static dstalk_plugin_info_t g_info = { + "config", // name + "1.0.0", // version + "Configuration service with TOML file support", // description + DSTALK_API_VERSION, // api_version + {nullptr}, // dependencies (none) + on_init, // on_init + on_shutdown, // on_shutdown + nullptr // on_event +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/context/CMakeLists.txt b/plugins/context/CMakeLists.txt new file mode 100644 index 0000000..06efd39 --- /dev/null +++ b/plugins/context/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(plugin-context SHARED src/context_plugin.cpp) + +target_include_directories(plugin-context PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-context PRIVATE dstalk) + +set_target_properties(plugin-context PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/context/src/context_plugin.cpp b/plugins/context/src/context_plugin.cpp new file mode 100644 index 0000000..f230430 --- /dev/null +++ b/plugins/context/src/context_plugin.cpp @@ -0,0 +1,289 @@ +// plugin-context: 上下文管理服务插件 +// 提供 dstalk_context_service_t vtable 实现 +// 依赖: session (获取历史消息做 token 计数) +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_types.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include +#include +#include +#include + +// ============================================================ +// 全局状态 +// ============================================================ + +static const dstalk_host_api_t* g_host = nullptr; +static const dstalk_session_service_t* g_session = nullptr; +static size_t g_max_tokens = 4096; + +// ============================================================ +// 内部 C++ 辅助:token 计数 +// ============================================================ + +static bool cjk_is_ascii(unsigned char c) { return c < 0x80; } + +static bool cjk_starts_cjk(unsigned char c) { + // U+4E00-U+9FFF 在 UTF-8 中编码为 0xE4-0xE9 开头的三字节 + return c >= 0xE4 && c <= 0xE9; +} + +static size_t count_tokens_one_message(const dstalk_message_t& msg) { + const char* text = msg.content; + if (!text) return 4; // 只有 overhead + + size_t ascii_chars = 0; + size_t chinese_chars = 0; + size_t other_chars = 0; + + size_t i = 0; + while (text[i] != '\0') { + unsigned char c = static_cast(text[i]); + + if (cjk_is_ascii(c)) { + ascii_chars++; + i += 1; + } else if (cjk_starts_cjk(c)) { + chinese_chars++; + i += 3; + } else if (c >= 0xC0 && c < 0xE0) { + other_chars++; + i += 2; + } else if (c >= 0xE0 && c < 0xF0) { + other_chars++; + i += 3; + } else if (c >= 0xF0 && c < 0xF8) { + other_chars++; + i += 4; + } else { + other_chars++; + i += 1; + } + } + + size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3); + return content_tokens + 4; // +4 条消息开销 (role + separators) +} + +static size_t count_tokens_all(const dstalk_message_t* msgs, int count) { + size_t total = 0; + for (int i = 0; i < count; ++i) { + total += count_tokens_one_message(msgs[i]); + } + return total; +} + +// ============================================================ +// 内部 trim 逻辑 +// ============================================================ + +// 为 trim 操作将 C 消息数组复制到内部 struct +struct TrimMessage { + std::string role; + std::string content; + std::string tool_call_id; + std::string tool_calls_json; +}; + +static size_t count_tokens_trim(const TrimMessage& msg) { + if (msg.content.empty()) return 4; + const std::string& text = msg.content; + size_t ascii_chars = 0, chinese_chars = 0, other_chars = 0; + size_t i = 0; + while (i < text.size()) { + unsigned char c = static_cast(text[i]); + if (cjk_is_ascii(c)) { ascii_chars++; i += 1; } + else if (cjk_starts_cjk(c)) { chinese_chars++; i += 3; } + else if (c >= 0xC0 && c < 0xE0) { other_chars++; i += 2; } + else if (c >= 0xE0 && c < 0xF0) { other_chars++; i += 3; } + else if (c >= 0xF0 && c < 0xF8) { other_chars++; i += 4; } + else { other_chars++; i += 1; } + } + return (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3) + 4; +} + +static size_t count_tokens_trim_vec(const std::vector& msgs) { + size_t total = 0; + for (const auto& m : msgs) total += count_tokens_trim(m); + return total; +} + +static int trim_impl(const dstalk_message_t* in, int in_count, + dstalk_message_t** out, int* out_count, + size_t max_tokens) { + if (!in || in_count <= 0 || !out || !out_count) return -1; + + // 将 C 数组转换为内部 vector + std::vector messages; + messages.reserve(in_count); + for (int i = 0; i < in_count; ++i) { + TrimMessage tm; + if (in[i].role) tm.role = in[i].role; + if (in[i].content) tm.content = in[i].content; + if (in[i].tool_call_id) tm.tool_call_id = in[i].tool_call_id; + if (in[i].tool_calls_json) tm.tool_calls_json = in[i].tool_calls_json; + messages.push_back(std::move(tm)); + } + + // 如果已在限制内,直接返回完整副本 + size_t current = count_tokens_trim_vec(messages); + if (current <= max_tokens) { + *out_count = in_count; + *out = static_cast(g_host->alloc(sizeof(dstalk_message_t) * in_count)); + if (!*out) return -1; + for (int i = 0; i < in_count; ++i) { + (*out)[i].role = messages[i].role.empty() ? nullptr : g_host->strdup(messages[i].role.c_str()); + (*out)[i].content = messages[i].content.empty() ? nullptr : g_host->strdup(messages[i].content.c_str()); + (*out)[i].tool_call_id = messages[i].tool_call_id.empty() ? nullptr : g_host->strdup(messages[i].tool_call_id.c_str()); + (*out)[i].tool_calls_json = messages[i].tool_calls_json.empty() ? nullptr : g_host->strdup(messages[i].tool_calls_json.c_str()); + } + return 0; + } + + // 分离 system 消息和非 system 消息 + std::vector system_msgs; + std::vector non_system_msgs; + for (const auto& msg : messages) { + if (msg.role == "system") { + system_msgs.push_back(msg); + } else { + non_system_msgs.push_back(msg); + } + } + + size_t system_tokens = count_tokens_trim_vec(system_msgs); + if (system_tokens > max_tokens) { + std::fprintf(stderr, "[context] WARNING: system messages alone " + "(%zu tokens) exceed max_context_tokens (%zu)\n", + system_tokens, max_tokens); + } + + // 检查是否有单条消息超过限制 + for (const auto& msg : non_system_msgs) { + size_t msg_tokens = count_tokens_trim(msg); + if (msg_tokens > max_tokens) { + std::fprintf(stderr, "[context] WARNING: single message " + "(%s, %zu tokens) exceeds max_context_tokens (%zu). " + "Returning empty list.\n", + msg.role.c_str(), msg_tokens, max_tokens); + *out = nullptr; + *out_count = 0; + return -1; + } + } + + // 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 + while (!non_system_msgs.empty()) { + current = system_tokens + count_tokens_trim_vec(non_system_msgs); + if (current <= max_tokens) break; + + // 找第一个 "user" 消息 + auto user_it = non_system_msgs.begin(); + while (user_it != non_system_msgs.end() && user_it->role != "user") { + ++user_it; + } + if (user_it == non_system_msgs.end()) break; + + // 找下一个 "assistant" + auto assistant_it = user_it + 1; + while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") { + ++assistant_it; + } + + if (assistant_it == non_system_msgs.end()) { + non_system_msgs.erase(user_it); + } else { + // 先删 assistant 再删 user 避免迭代器失效 + non_system_msgs.erase(assistant_it); + user_it = non_system_msgs.begin(); + while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it; + if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it); + } + } + + // 组装结果 + std::vector result; + result.reserve(system_msgs.size() + non_system_msgs.size()); + result.insert(result.end(), system_msgs.begin(), system_msgs.end()); + result.insert(result.end(), non_system_msgs.begin(), non_system_msgs.end()); + + int result_count = static_cast(result.size()); + *out_count = result_count; + *out = static_cast(g_host->alloc(sizeof(dstalk_message_t) * result_count)); + if (!*out) return -1; + + for (int i = 0; i < result_count; ++i) { + (*out)[i].role = result[i].role.empty() ? nullptr : g_host->strdup(result[i].role.c_str()); + (*out)[i].content = result[i].content.empty() ? nullptr : g_host->strdup(result[i].content.c_str()); + (*out)[i].tool_call_id = result[i].tool_call_id.empty() ? nullptr : g_host->strdup(result[i].tool_call_id.c_str()); + (*out)[i].tool_calls_json = result[i].tool_calls_json.empty() ? nullptr : g_host->strdup(result[i].tool_calls_json.c_str()); + } + + return 0; +} + +// ============================================================ +// Context 服务 vtable 实现 +// ============================================================ + +static size_t context_count_tokens(const dstalk_message_t* msgs, int count) { + if (!msgs || count <= 0) return 0; + return count_tokens_all(msgs, count); +} + +static int context_trim(const dstalk_message_t* in, int in_count, + dstalk_message_t** out, int* out_count, + size_t max_tokens) { + return trim_impl(in, in_count, out, out_count, max_tokens); +} + +static void context_set_max_tokens(size_t max) { + g_max_tokens = max; +} + +static dstalk_context_service_t g_context_service = { + context_count_tokens, + context_trim, + context_set_max_tokens +}; + +// ============================================================ +// 插件生命周期 +// ============================================================ + +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + + // 查询依赖服务: session + void* raw = host->query_service("session", 1); + if (!raw) { + host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found"); + return -1; + } + g_session = static_cast(raw); + + return host->register_service("context", 1, &g_context_service); +} + +static void on_shutdown() { + g_session = nullptr; + g_host = nullptr; +} + +static dstalk_plugin_info_t g_info = { + "context", + "1.0.0", + "Context management plugin with token counting and trim support", + DSTALK_API_VERSION, + {"session", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, + on_init, + on_shutdown, + nullptr +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/deepseek/CMakeLists.txt b/plugins/deepseek/CMakeLists.txt new file mode 100644 index 0000000..2c218dd --- /dev/null +++ b/plugins/deepseek/CMakeLists.txt @@ -0,0 +1,32 @@ +cmake_minimum_required(VERSION 3.21) + +# ============================================================ +# plugin-deepseek — DeepSeek AI 服务 (OpenAI 兼容) +# 依赖: http 服务 (查询), config 服务 (查询) +# ============================================================ + +add_library(plugin-deepseek SHARED + src/deepseek_plugin.cpp +) + +target_include_directories(plugin-deepseek PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-deepseek PRIVATE dstalk) + +# Boost.JSON 用于构建/解析请求和响应 +find_package(Boost REQUIRED CONFIG) +target_link_libraries(plugin-deepseek PRIVATE boost::boost) + +target_compile_definitions(plugin-deepseek PRIVATE + BOOST_ALL_NO_LIB + BOOST_ERROR_CODE_HEADER_ONLY + BOOST_JSON_HEADER_ONLY +) + +set_target_properties(plugin-deepseek PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/deepseek/src/deepseek_plugin.cpp b/plugins/deepseek/src/deepseek_plugin.cpp new file mode 100644 index 0000000..9f62aa2 --- /dev/null +++ b/plugins/deepseek/src/deepseek_plugin.cpp @@ -0,0 +1,475 @@ +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include + +namespace json = boost::json; + +// ============================================================================ +// 全局指针:从 on_init 获取 +// ============================================================================ +static const dstalk_host_api_t* g_host = nullptr; +static dstalk_http_service_t* g_http = nullptr; +static dstalk_config_service_t* g_config = nullptr; + +// ============================================================================ +// 配置数据(由 configure() 设置) +// ============================================================================ +struct PluginConfig { + std::string provider; + std::string base_url; + std::string api_key; + std::string model; + int max_tokens = 4096; + double temperature = 0.7; +}; +static PluginConfig g_cfg; + +// ============================================================================ +// 辅助:从 base_url 提取 host 和 target +// ============================================================================ +static bool extract_host_port(const std::string& url, + std::string& scheme_out, std::string& host_out, + std::string& port_out, std::string& target_out) +{ + size_t scheme_end = url.find("://"); + if (scheme_end == std::string::npos) return false; + scheme_out = url.substr(0, scheme_end); + std::string rest = url.substr(scheme_end + 3); + size_t slash = rest.find('/'); + std::string authority = (slash != std::string::npos) ? rest.substr(0, slash) : rest; + target_out = (slash != std::string::npos) ? rest.substr(slash) : "/"; + size_t colon = authority.rfind(':'); + if (colon != std::string::npos) { + host_out = authority.substr(0, colon); + port_out = authority.substr(colon + 1); + } else { + host_out = authority; + port_out = (scheme_out == "https") ? "443" : "80"; + } + return true; +} + +// ============================================================================ +// 辅助:构建 headers JSON 字符串 +// ============================================================================ +static std::string build_headers_json(const std::string& auth_header_value) +{ + json::object h; + h["Authorization"] = "Bearer " + auth_header_value; + return json::serialize(h); +} + +// ============================================================================ +// 辅助:dstalk_message_t[] -> boost::json::array +// ============================================================================ +static void append_history(json::array& msgs, + const dstalk_message_t* history, int history_len) +{ + for (int i = 0; i < history_len; ++i) { + const auto& m = history[i]; + json::object obj; + obj["role"] = m.role ? m.role : ""; + + if (m.role && std::strcmp(m.role, "tool") == 0) { + obj["tool_call_id"] = m.tool_call_id ? m.tool_call_id : ""; + obj["content"] = m.content ? m.content : ""; + } else if (m.role && std::strcmp(m.role, "assistant") == 0 && + m.tool_calls_json && m.tool_calls_json[0] != '\0') { + obj["content"] = m.content ? m.content : ""; + obj["tool_calls"] = json::parse(m.tool_calls_json); + } else { + obj["content"] = m.content ? m.content : ""; + } + msgs.push_back(obj); + } +} + +// ============================================================================ +// 构建 DeepSeek JSON 请求体 +// ============================================================================ +static std::string build_request_json( + const dstalk_message_t* history, int history_len, + const std::string& user_input, + const std::string& tools_json, + bool stream) +{ + json::object root; + root["model"] = g_cfg.model; + root["max_tokens"] = g_cfg.max_tokens; + root["temperature"] = g_cfg.temperature; + root["stream"] = stream; + + json::array msgs; + append_history(msgs, history, history_len); + + // 追加当前用户输入 + if (!user_input.empty()) { + json::object obj; + obj["role"] = "user"; + obj["content"] = user_input; + msgs.push_back(obj); + } + + root["messages"] = msgs; + + // tools 定义 + if (!tools_json.empty()) { + root["tools"] = json::parse(tools_json); + } + + return json::serialize(root); +} + +// ============================================================================ +// 解析非流式 JSON 响应 +// ============================================================================ +static void parse_response(const char* body, int http_status, + dstalk_chat_result_t& r) +{ + r.http_status = http_status; + + if (http_status < 200 || http_status >= 300) { + r.ok = 0; + try { + auto jv = json::parse(body ? body : "{}"); + auto obj = jv.as_object(); + if (obj.contains("error")) { + auto err = obj["error"].as_object(); + r.error = g_host->strdup( + json::value_to(err["message"]).c_str()); + } + } catch (...) { + std::string msg = "HTTP " + std::to_string(http_status); + r.error = g_host->strdup(msg.c_str()); + } + if (!r.error) { + std::string msg = "HTTP " + std::to_string(http_status); + r.error = g_host->strdup(msg.c_str()); + } + r.content = nullptr; + r.tool_calls_json = nullptr; + return; + } + + try { + auto jv = json::parse(body ? body : "{}"); + auto obj = jv.as_object(); + auto choices = obj["choices"].as_array(); + if (!choices.empty()) { + auto msg = choices[0].as_object()["message"].as_object(); + + std::string content = json::value_to(msg["content"]); + r.content = g_host->strdup(content.c_str()); + + if (msg.contains("tool_calls")) { + std::string tc = json::serialize(msg["tool_calls"]); + r.tool_calls_json = g_host->strdup(tc.c_str()); + } else { + r.tool_calls_json = nullptr; + } + + r.ok = 1; + r.error = nullptr; + } else { + r.ok = 0; + r.error = g_host->strdup("empty response"); + r.content = nullptr; + r.tool_calls_json = nullptr; + } + } catch (std::exception& e) { + r.ok = 0; + std::string msg = std::string("json parse: ") + e.what(); + r.error = g_host->strdup(msg.c_str()); + r.content = nullptr; + r.tool_calls_json = nullptr; + } catch (...) { + r.ok = 0; + r.error = g_host->strdup("json parse error"); + r.content = nullptr; + r.tool_calls_json = nullptr; + } +} + +// ============================================================================ +// SSE 行解析(OpenAI 兼容格式) +// ============================================================================ +static bool parse_sse_line(const std::string& line, std::string& token_out) +{ + if (line.rfind("data: ", 0) != 0) return false; + + std::string data = line.substr(6); + if (data == "[DONE]") { + token_out.clear(); + return true; // 流结束信号 + } + + try { + auto jv = json::parse(data); + auto obj = jv.as_object(); + auto choices = obj["choices"].as_array(); + if (!choices.empty()) { + auto delta = choices[0].as_object()["delta"].as_object(); + if (delta.contains("content")) { + token_out = json::value_to(delta["content"]); + return true; + } + } + } catch (...) { + // 忽略解析失败 + } + return false; +} + +// ============================================================================ +// configure 实现 +// ============================================================================ +static int my_configure(const char* provider, const char* base_url, + const char* api_key, const char* model, + int max_tokens, double temperature) +{ + if (provider) g_cfg.provider = provider; + if (base_url) g_cfg.base_url = base_url; + if (api_key) g_cfg.api_key = api_key; + if (model) g_cfg.model = model; + g_cfg.max_tokens = max_tokens; + g_cfg.temperature = temperature; + + if (g_host) { + g_host->log(DSTALK_LOG_INFO, + "[deepseek] configured: model=%s base_url=%s max_tokens=%d temperature=%.2f", + g_cfg.model.c_str(), g_cfg.base_url.c_str(), + g_cfg.max_tokens, g_cfg.temperature); + } + return 0; +} + +// ============================================================================ +// chat 实现 +// ============================================================================ +static dstalk_chat_result_t my_chat( + const dstalk_message_t* history, int history_len, + const char* user_input, + const char* tools_json) +{ + dstalk_chat_result_t r = {}; + r.ok = 0; + + if (!g_http) { + r.error = g_host->strdup("http service not available"); + return r; + } + + std::string scheme, host, port, target; + extract_host_port(g_cfg.base_url, scheme, host, port, target); + std::string target_path = target + "/chat/completions"; + + std::string body = build_request_json(history, history_len, + user_input ? user_input : "", tools_json ? tools_json : "", false); + + std::string headers_json = build_headers_json(g_cfg.api_key); + + char* response_body = nullptr; + int status_code = 0; + + int ret = g_http->post_json( + host.c_str(), port.c_str(), target_path.c_str(), body.c_str(), + headers_json.c_str(), &response_body, &status_code); + + if (ret != 0) { + r.error = g_host->strdup("http request failed"); + return r; + } + + parse_response(response_body, status_code, r); + + if (response_body) { + g_host->free(response_body); + } + return r; +} + +// ============================================================================ +// chat_stream 实现 +// ============================================================================ + +// 回调上下文:在流式传输中收集累积内容和最终状态 +struct StreamContext { + const dstalk_host_api_t* host; + dstalk_stream_cb user_cb; + void* userdata; + std::string accumulated; + bool streaming_ok = true; +}; + +// 行回调:解析 SSE line,将 token 传递给用户回调 +static int sse_line_callback(const char* line, void* userdata) +{ + auto* ctx = static_cast(userdata); + if (!line || !line[0]) return 1; // 空行,继续 + + std::string line_str(line); + std::string token; + + if (!parse_sse_line(line_str, token)) return 1; // 非 data 行,继续 + + if (token.empty()) return 0; // [DONE],停止 + + ctx->accumulated += token; + + if (ctx->user_cb) { + return ctx->user_cb(token.c_str(), ctx->userdata); + } + return 1; // 继续 +} + +static dstalk_chat_result_t my_chat_stream( + const dstalk_message_t* history, int history_len, + const char* user_input, + dstalk_stream_cb cb, void* userdata) +{ + dstalk_chat_result_t r = {}; + r.ok = 0; + + if (!g_http) { + r.error = g_host->strdup("http service not available"); + return r; + } + + std::string scheme, host, port, target; + extract_host_port(g_cfg.base_url, scheme, host, port, target); + std::string target_path = target + "/chat/completions"; + + std::string body = build_request_json(history, history_len, + user_input ? user_input : "", "", true); // stream=true, no tools + + std::string headers_json = build_headers_json(g_cfg.api_key); + + StreamContext ctx; + ctx.host = g_host; + ctx.user_cb = cb; + ctx.userdata = userdata; + + char* response_body = nullptr; + int status_code = 0; + + int ret = g_http->post_stream( + host.c_str(), port.c_str(), target_path.c_str(), body.c_str(), + headers_json.c_str(), + sse_line_callback, &ctx, + &response_body, &status_code); + + r.http_status = status_code; + + // 检查传输层错误或非 2xx 状态 + if (status_code < 200 || status_code >= 300) { + r.ok = 0; + // 尝试从响应体提取错误信息 + if (response_body && response_body[0]) { + try { + auto jv = json::parse(response_body); + auto obj = jv.as_object(); + if (obj.contains("error")) { + auto err = obj["error"].as_object(); + r.error = g_host->strdup( + json::value_to(err["message"]).c_str()); + } + } catch (...) {} + } + if (!r.error) { + if (status_code <= 0) + r.error = g_host->strdup("transport error"); + else + r.error = g_host->strdup( + ("HTTP " + std::to_string(status_code)).c_str()); + } + if (response_body) g_host->free(response_body); + r.content = nullptr; + r.tool_calls_json = nullptr; + return r; + } + + if (response_body) g_host->free(response_body); + + if (ctx.accumulated.empty()) { + r.ok = 0; + r.error = g_host->strdup("no content received"); + r.content = nullptr; + r.tool_calls_json = nullptr; + } else { + r.ok = 1; + r.error = nullptr; + r.content = g_host->strdup(ctx.accumulated.c_str()); + r.tool_calls_json = nullptr; + } + return r; +} + +// ============================================================================ +// free_result 实现 +// ============================================================================ +static void my_free_result(dstalk_chat_result_t* result) +{ + if (!result || !g_host) return; + if (result->content) { g_host->free((void*)result->content); result->content = nullptr; } + if (result->error) { g_host->free((void*)result->error); result->error = nullptr; } + if (result->tool_calls_json) { g_host->free((void*)result->tool_calls_json); result->tool_calls_json = nullptr; } +} + +// ============================================================================ +// 服务 vtable +// ============================================================================ +static dstalk_ai_service_t g_service = { + &my_configure, + &my_chat, + &my_chat_stream, + &my_free_result, +}; + +// ============================================================================ +// 生命周期 +// ============================================================================ +static int on_init(const dstalk_host_api_t* host) +{ + g_host = host; + g_http = (dstalk_http_service_t*)host->query_service("http", 1); + g_config = (dstalk_config_service_t*)host->query_service("config", 1); + + if (!g_http) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, "[deepseek] http service not found"); + return -1; + } + + if (g_host) g_host->log(DSTALK_LOG_INFO, "[deepseek] initializing DeepSeek AI plugin"); + + return host->register_service("ai.deepseek", 1, &g_service); +} + +static void on_shutdown() +{ + if (g_host) g_host->log(DSTALK_LOG_INFO, "[deepseek] shutdown"); + g_http = nullptr; + g_config = nullptr; + g_host = nullptr; +} + +// ============================================================================ +// 插件描述符 +// ============================================================================ +static dstalk_plugin_info_t g_info = { + /* .name = */ "deepseek-ai", + /* .version = */ "1.0.0", + /* .description = */ "DeepSeek AI provider (OpenAI-compatible API)", + /* .api_version = */ DSTALK_API_VERSION, + /* .dependencies = */ { "http", "config", NULL }, + /* .on_init = */ on_init, + /* .on_shutdown = */ on_shutdown, + /* .on_event = */ nullptr, +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) +{ + return &g_info; +} diff --git a/plugins/file-io/CMakeLists.txt b/plugins/file-io/CMakeLists.txt new file mode 100644 index 0000000..aa70175 --- /dev/null +++ b/plugins/file-io/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(plugin-file-io SHARED src/file_io_plugin.cpp) + +target_include_directories(plugin-file-io PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-file-io PRIVATE dstalk) + +set_target_properties(plugin-file-io PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/file-io/src/file_io_plugin.cpp b/plugins/file-io/src/file_io_plugin.cpp new file mode 100644 index 0000000..6c4af67 --- /dev/null +++ b/plugins/file-io/src/file_io_plugin.cpp @@ -0,0 +1,95 @@ +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include + +// ============================================================ +// Global state +// ============================================================ +static const dstalk_host_api_t* g_host = nullptr; + +// ============================================================ +// Service implementations +// ============================================================ +static int file_read(const char* path, char** content) { + if (!path || !content) return -1; + + FILE* fp = fopen(path, "rb"); + if (!fp) return -1; + + // Get file size + fseek(fp, 0, SEEK_END); + long fsize = ftell(fp); + fseek(fp, 0, SEEK_SET); + + if (fsize < 0) { + fclose(fp); + return -1; + } + + // Allocate buffer (+1 for null terminator) + char* buf = (char*)malloc((size_t)fsize + 1); + if (!buf) { + fclose(fp); + return -1; + } + + size_t read_bytes = fread(buf, 1, (size_t)fsize, fp); + fclose(fp); + + if (read_bytes != (size_t)fsize) { + free(buf); + return -1; + } + + buf[read_bytes] = '\0'; + *content = buf; + return 0; +} + +static int file_write(const char* path, const char* content) { + if (!path || !content) return -1; + + FILE* fp = fopen(path, "wb"); + if (!fp) return -1; + + size_t len = strlen(content); + size_t written = fwrite(content, 1, len, fp); + fclose(fp); + + return (written == len) ? 0 : -1; +} + +static dstalk_file_io_service_t g_service = { + file_read, + file_write +}; + +// ============================================================ +// Plugin lifecycle +// ============================================================ +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + return host->register_service("file_io", 1, &g_service); +} + +static void on_shutdown() { + // nothing to clean up +} + +static dstalk_plugin_info_t g_info = { + "file-io", // name + "1.0.0", // version + "Basic file I/O service", // description + DSTALK_API_VERSION, // api_version + {nullptr}, // dependencies (none) + on_init, // on_init + on_shutdown, // on_shutdown + nullptr // on_event +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/lsp/CMakeLists.txt b/plugins/lsp/CMakeLists.txt new file mode 100644 index 0000000..6095d68 --- /dev/null +++ b/plugins/lsp/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.21) + +# ============================================================ +# plugin-lsp — LSP (Language Server Protocol) 服务 +# 自行管理子进程,无外部服务依赖 +# ============================================================ + +add_library(plugin-lsp SHARED + src/lsp_plugin.cpp +) + +target_include_directories(plugin-lsp PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-lsp PRIVATE dstalk) + +# Boost.JSON 用于 JSON-RPC 消息构建/解析 +find_package(Boost REQUIRED CONFIG) +target_link_libraries(plugin-lsp PRIVATE boost::boost) + +target_compile_definitions(plugin-lsp PRIVATE + BOOST_ALL_NO_LIB + BOOST_ERROR_CODE_HEADER_ONLY + BOOST_JSON_HEADER_ONLY +) + +# POSIX 平台需要 pthread (用于 std::thread) +if(NOT WIN32) + find_package(Threads REQUIRED) + target_link_libraries(plugin-lsp PRIVATE Threads::Threads) +endif() + +set_target_properties(plugin-lsp PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/lsp/src/lsp_plugin.cpp b/plugins/lsp/src/lsp_plugin.cpp new file mode 100644 index 0000000..1899135 --- /dev/null +++ b/plugins/lsp/src/lsp_plugin.cpp @@ -0,0 +1,733 @@ +/* + * plugin-lsp — LSP (Language Server Protocol) 服务 + * + * 自行管理语言服务器子进程,使用 JSON-RPC 2.0 over stdio 通信。 + * 无外部服务依赖(不依赖 http/config 等其他插件)。 + */ + +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// ============================================================================ +// 平台相关 — 子进程管理 (内嵌 subprocess::Process) +// ============================================================================ + +#ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#else +#include +#include +#include +#include +#include +#include +#endif + +namespace json = boost::json; + +// ============================================================================ +// 全局指针 +// ============================================================================ +static const dstalk_host_api_t* g_host = nullptr; + +// ============================================================================ +// 子进程封装 (内嵌 subprocess.hpp) +// ============================================================================ +struct Process { +#ifdef _WIN32 + HANDLE hProcess = INVALID_HANDLE_VALUE; + HANDLE hThread = INVALID_HANDLE_VALUE; + HANDLE hStdIn = INVALID_HANDLE_VALUE; + HANDLE hStdOut = INVALID_HANDLE_VALUE; +#else + pid_t pid = -1; + int stdin_fd = -1; + int stdout_fd = -1; +#endif + + bool start(const char* cmd) { + if (!cmd || !cmd[0]) return false; + stop(); + +#ifdef _WIN32 + SECURITY_ATTRIBUTES sa = {}; + sa.nLength = sizeof(SECURITY_ATTRIBUTES); + sa.bInheritHandle = TRUE; + + HANDLE child_stdin_read = INVALID_HANDLE_VALUE; + HANDLE child_stdout_write = INVALID_HANDLE_VALUE; + + if (!CreatePipe(&child_stdin_read, &hStdIn, &sa, 0)) goto win32_fail; + if (!SetHandleInformation(hStdIn, HANDLE_FLAG_INHERIT, 0)) goto win32_fail; + if (!CreatePipe(&hStdOut, &child_stdout_write, &sa, 0)) goto win32_fail; + if (!SetHandleInformation(hStdOut, HANDLE_FLAG_INHERIT, 0)) goto win32_fail; + + { + STARTUPINFOW si = {}; + si.cb = sizeof(STARTUPINFOW); + si.dwFlags = STARTF_USESTDHANDLES | STARTF_USESHOWWINDOW; + si.wShowWindow = SW_HIDE; + si.hStdInput = child_stdin_read; + si.hStdOutput = child_stdout_write; + si.hStdError = child_stdout_write; + + PROCESS_INFORMATION pi = {}; + std::string cmd_copy(cmd); + wchar_t wcmd[4096] = {}; + if (MultiByteToWideChar(CP_UTF8, 0, cmd_copy.c_str(), -1, wcmd, 4096) == 0) + goto win32_fail; + + if (!CreateProcessW(nullptr, wcmd, nullptr, nullptr, TRUE, + CREATE_NO_WINDOW, nullptr, nullptr, &si, &pi)) + goto win32_fail; + + hProcess = pi.hProcess; + hThread = pi.hThread; + } + + CloseHandle(child_stdin_read); + CloseHandle(child_stdout_write); + return true; + + win32_fail: + if (child_stdin_read != INVALID_HANDLE_VALUE) CloseHandle(child_stdin_read); + if (child_stdout_write != INVALID_HANDLE_VALUE) CloseHandle(child_stdout_write); + if (hStdIn != INVALID_HANDLE_VALUE) { CloseHandle(hStdIn); hStdIn = INVALID_HANDLE_VALUE; } + if (hStdOut != INVALID_HANDLE_VALUE) { CloseHandle(hStdOut); hStdOut = INVALID_HANDLE_VALUE; } + if (hProcess != INVALID_HANDLE_VALUE) { CloseHandle(hProcess); hProcess = INVALID_HANDLE_VALUE; } + if (hThread != INVALID_HANDLE_VALUE) { CloseHandle(hThread); hThread = INVALID_HANDLE_VALUE; } + return false; + +#else + int pin[2] = {-1, -1}; + int pout[2] = {-1, -1}; + + if (pipe(pin) != 0) goto posix_fail; + if (pipe(pout) != 0) goto posix_fail; + + pid = fork(); + if (pid < 0) goto posix_fail; + + if (pid == 0) { + dup2(pin[0], STDIN_FILENO); + close(pin[0]); close(pin[1]); + dup2(pout[1], STDOUT_FILENO); + close(pout[0]); close(pout[1]); + + int max_fd = static_cast(sysconf(_SC_OPEN_MAX)); + if (max_fd > 3) { + for (int i = 3; i < max_fd; ++i) close(i); + } + + char* argv[64] = {}; + int argc = 0; + char* cmd_copy = strdup(cmd); + char* token = strtok(cmd_copy, " \t"); + while (token && argc < 63) { + argv[argc++] = token; + token = strtok(nullptr, " \t"); + } + argv[argc] = nullptr; + execvp(argv[0], argv); + _exit(127); + } + + close(pin[0]); + close(pout[1]); + stdin_fd = pin[1]; + stdout_fd = pout[0]; + return true; + + posix_fail: + if (pin[0] != -1) close(pin[0]); + if (pin[1] != -1) close(pin[1]); + if (pout[0] != -1) close(pout[0]); + if (pout[1] != -1) close(pout[1]); + stdin_fd = -1; + stdout_fd = -1; + pid = -1; + return false; +#endif + } + + void stop() { +#ifdef _WIN32 + if (hProcess != INVALID_HANDLE_VALUE) { + WaitForSingleObject(hProcess, 2000); + TerminateProcess(hProcess, 1); + CloseHandle(hProcess); hProcess = INVALID_HANDLE_VALUE; + } + if (hThread != INVALID_HANDLE_VALUE) { CloseHandle(hThread); hThread = INVALID_HANDLE_VALUE; } + if (hStdIn != INVALID_HANDLE_VALUE) { CloseHandle(hStdIn); hStdIn = INVALID_HANDLE_VALUE; } + if (hStdOut != INVALID_HANDLE_VALUE) { CloseHandle(hStdOut); hStdOut = INVALID_HANDLE_VALUE; } +#else + if (pid > 0) { + kill(pid, SIGTERM); + int status = 0; + for (int i = 0; i < 20; ++i) { + if (waitpid(pid, &status, WNOHANG) > 0) break; + usleep(100000); + } + if (waitpid(pid, &status, WNOHANG) == 0) { + kill(pid, SIGKILL); + waitpid(pid, &status, 0); + } + pid = -1; + } + if (stdin_fd != -1) { close(stdin_fd); stdin_fd = -1; } + if (stdout_fd != -1) { close(stdout_fd); stdout_fd = -1; } +#endif + } + + bool write(const std::string& data) { + if (data.empty()) return true; +#ifdef _WIN32 + if (hStdIn == INVALID_HANDLE_VALUE) return false; + DWORD written = 0; + return WriteFile(hStdIn, data.c_str(), static_cast(data.size()), &written, nullptr) + && written == data.size(); +#else + if (stdin_fd < 0) return false; + size_t total = 0; + const char* buf = data.c_str(); + size_t len = data.size(); + while (total < len) { + ssize_t n = ::write(stdin_fd, buf + total, len - total); + if (n <= 0) return false; + total += static_cast(n); + } + return true; +#endif + } + + bool read_line(std::string& line) { + line.clear(); +#ifdef _WIN32 + if (hStdOut == INVALID_HANDLE_VALUE) return false; + char ch; DWORD nread = 0; + while (true) { + if (!ReadFile(hStdOut, &ch, 1, &nread, nullptr)) return false; + if (nread == 0) return false; + line += ch; + if (ch == '\n') return true; + } +#else + if (stdout_fd < 0) return false; + char ch; + while (true) { + ssize_t n = ::read(stdout_fd, &ch, 1); + if (n <= 0) return false; + line += ch; + if (ch == '\n') return true; + } +#endif + } + + bool read_bytes(std::string& buf, int count) { + if (count <= 0) { buf.clear(); return true; } +#ifdef _WIN32 + if (hStdOut == INVALID_HANDLE_VALUE) return false; + buf.resize(static_cast(count) + 1); + DWORD total = 0, nread = 0; + while (total < static_cast(count)) { + if (!ReadFile(hStdOut, const_cast(buf.data()) + total, + static_cast(count) - total, &nread, nullptr)) + return false; + if (nread == 0) return false; + total += nread; + } + buf[count] = '\0'; + buf.resize(count); + return true; +#else + if (stdout_fd < 0) return false; + buf.resize(count); + size_t total = 0; + while (total < static_cast(count)) { + ssize_t n = ::read(stdout_fd, const_cast(buf.data()) + total, + static_cast(count) - total); + if (n <= 0) return false; + total += static_cast(n); + } + return true; +#endif + } +}; + +// ============================================================================ +// LSP 状态(静态单例) +// ============================================================================ +struct LspState { + Process proc; + std::atomic running{false}; + std::string language; + + std::atomic next_id{1}; + + // 响应用于同步等待 + std::mutex mutex; + std::condition_variable cv; + std::unordered_map pending_responses; + + // 诊断缓存: URI -> JSON 字符串 + std::unordered_map diagnostics; + + // 读取线程 + std::thread reader_thread; +}; +static LspState g_lsp; + +// ============================================================================ +// 辅助函数 +// ============================================================================ + +static std::string_view trim(std::string_view sv) { + while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' || + sv.front() == '\r' || sv.front() == '\n')) + sv.remove_prefix(1); + while (!sv.empty() && (sv.back() == ' ' || sv.back() == '\t' || + sv.back() == '\r' || sv.back() == '\n')) + sv.remove_suffix(1); + return sv; +} + +static std::string frame_message(const std::string& body) { + std::string frame; + frame.reserve(64 + body.size()); + frame += "Content-Length: "; + frame += std::to_string(body.size()); + frame += "\r\n\r\n"; + frame += body; + return frame; +} + +static int parse_content_length(const std::string& line) { + auto sv = trim(std::string_view(line)); + const char prefix[] = "Content-Length:"; + const size_t prefix_len = sizeof(prefix) - 1; + + if (sv.size() <= prefix_len) return -1; + for (size_t i = 0; i < prefix_len; ++i) { + if (std::tolower(static_cast(sv[i])) != + std::tolower(static_cast(prefix[i]))) + return -1; + } + + std::string_view num_sv = sv.substr(prefix_len); + while (!num_sv.empty() && (num_sv.front() == ' ' || num_sv.front() == '\t')) + num_sv.remove_prefix(1); + + try { return std::stoi(std::string(num_sv)); } + catch (...) { return -1; } +} + +// ============================================================================ +// JSON-RPC 消息发送 +// ============================================================================ + +static int send_request(const std::string& method, const json::object& params) { + int id = g_lsp.next_id.fetch_add(1); + + json::object msg; + msg["jsonrpc"] = "2.0"; + msg["id"] = id; + msg["method"] = method; + msg["params"] = params; + + std::string body = json::serialize(msg); + g_lsp.proc.write(frame_message(body)); + return id; +} + +static void send_notification(const std::string& method, const json::object& params) { + json::object msg; + msg["jsonrpc"] = "2.0"; + msg["method"] = method; + msg["params"] = params; + + std::string body = json::serialize(msg); + g_lsp.proc.write(frame_message(body)); +} + +// ============================================================================ +// 消息处理 +// ============================================================================ + +static void handle_message(const std::string& body) { + json::value val; + try { val = json::parse(body); } + catch (...) { return; } + + json::object msg; + try { msg = val.as_object(); } + catch (...) { return; } + + if (msg.contains("id") && !msg.contains("method")) { + // 响应 (有 id, 无 method) + int id = static_cast(msg["id"].as_int64()); + std::lock_guard lock(g_lsp.mutex); + g_lsp.pending_responses[id] = body; + g_lsp.cv.notify_all(); + + } else if (msg.contains("method") && !msg.contains("id")) { + // 通知 (有 method, 无 id) + std::string method; + try { method = json::value_to(msg["method"]); } + catch (...) { return; } + + if (method == "textDocument/publishDiagnostics") { + if (!msg.contains("params")) return; + auto params = msg["params"].as_object(); + if (!params.contains("uri")) return; + + std::string uri = json::value_to(params["uri"]); + std::string diag_json; + if (params.contains("diagnostics")) + diag_json = json::serialize(params["diagnostics"]); + else + diag_json = "[]"; + + std::lock_guard lock(g_lsp.mutex); + g_lsp.diagnostics[uri] = diag_json; + } + } +} + +// ============================================================================ +// 读取线程主循环 +// ============================================================================ + +static void reader_loop() { + while (g_lsp.running) { + std::string header_line; + if (!g_lsp.proc.read_line(header_line)) break; + + int content_length = parse_content_length(header_line); + if (content_length < 0) continue; + + // 跳过后续头直到空行 (\r\n 换行被视为非空行,只检查空行) + while (true) { + std::string line; + if (!g_lsp.proc.read_line(line)) break; + auto sv = trim(std::string_view(line)); + if (sv.empty()) break; + } + + std::string body; + if (!g_lsp.proc.read_bytes(body, content_length)) break; + + handle_message(body); + } + + std::lock_guard lock(g_lsp.mutex); + g_lsp.running = false; + g_lsp.cv.notify_all(); +} + +// ============================================================================ +// LSP 服务 vtable 实现 (定义在 vtable 变量之前) +// ============================================================================ + +static int g_lsp_impl_stop(); + +static int g_lsp_impl_start(const char* server_cmd, const char* language) { + if (!server_cmd || !server_cmd[0]) return -1; + + // 如果已在运行, 先停止 + if (g_lsp.running) { + g_lsp_impl_stop(); + } + + g_lsp.language = language ? language : ""; + + // 启动进程 + if (!g_lsp.proc.start(server_cmd)) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] failed to start: %s", server_cmd); + return -1; + } + + // 重置 ID 计数器 + g_lsp.next_id = 1; + + // 启动读取线程 + g_lsp.running = true; + g_lsp.reader_thread = std::thread(reader_loop); + + // 构建 initialize 参数 + json::object text_doc_caps; + { + json::object hover; + hover["dynamicRegistration"] = false; + text_doc_caps["hover"] = hover; + + json::object completion; + completion["dynamicRegistration"] = false; + text_doc_caps["completion"] = completion; + + json::object diagnostic; + diagnostic["dynamicRegistration"] = false; + text_doc_caps["diagnostic"] = diagnostic; + } + + json::object capabilities; + capabilities["textDocument"] = text_doc_caps; + + json::object init_params; + init_params["processId"] = nullptr; + init_params["rootUri"] = nullptr; + init_params["capabilities"] = capabilities; + + // 发送 initialize 请求 + int init_id = send_request("initialize", init_params); + + // 等待 initialize 响应 (最多 10 秒) + { + std::unique_lock lock(g_lsp.mutex); + bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [init_id]() { + return !g_lsp.running || g_lsp.pending_responses.count(init_id) > 0; + }); + + if (!got || !g_lsp.running) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, "[lsp] initialize timed out"); + g_lsp_impl_stop(); + return -1; + } + g_lsp.pending_responses.erase(init_id); + } + + // 发送 initialized 通知 + send_notification("initialized", json::object{}); + + if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server started: %s", server_cmd); + return 0; +} + +static void g_lsp_impl_stop() { + if (!g_lsp.running) return; + + // 发送 shutdown 请求 + int shutdown_id = send_request("shutdown", json::object{}); + + // 等待 shutdown 响应 (最多 2 秒) + { + std::unique_lock lock(g_lsp.mutex); + g_lsp.cv.wait_for(lock, std::chrono::seconds(2), [shutdown_id]() { + return !g_lsp.running || g_lsp.pending_responses.count(shutdown_id) > 0; + }); + g_lsp.pending_responses.clear(); + } + + // 发送 exit 通知 + send_notification("exit", json::object{}); + + // 停止读取线程 + g_lsp.running = false; + g_lsp.proc.stop(); + + if (g_lsp.reader_thread.joinable()) + g_lsp.reader_thread.join(); + + g_lsp.diagnostics.clear(); + if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] server stopped"); +} + +static int g_lsp_impl_open_document(const char* uri, const char* content, + const char* lang_id) { + if (!g_lsp.running) return -1; + if (!uri || !content || !lang_id) return -1; + + json::object text_doc; + text_doc["uri"] = uri; + text_doc["languageId"] = lang_id; + text_doc["version"] = 1; + text_doc["text"] = content; + + json::object params; + params["textDocument"] = text_doc; + + send_notification("textDocument/didOpen", params); + return 0; +} + +static int g_lsp_impl_close_document(const char* uri) { + if (!g_lsp.running) return -1; + if (!uri) return -1; + + json::object text_doc; + text_doc["uri"] = uri; + + json::object params; + params["textDocument"] = text_doc; + + send_notification("textDocument/didClose", params); + return 0; +} + +static int g_lsp_impl_get_diagnostics(const char* uri, char** json_out) { + if (!g_lsp.running) return -1; + if (!uri || !json_out) return -1; + + std::lock_guard lock(g_lsp.mutex); + auto it = g_lsp.diagnostics.find(uri); + if (it == g_lsp.diagnostics.end()) { + *json_out = g_host->strdup("[]"); + } else { + *json_out = g_host->strdup(it->second.c_str()); + } + return 0; +} + +static int g_lsp_impl_get_hover(const char* uri, int line, int col, char** json_out) { + if (!g_lsp.running) return -1; + if (!uri || !json_out) return -1; + + json::object position; + position["line"] = line; + position["character"] = col; + + json::object text_doc; + text_doc["uri"] = uri; + + json::object params; + params["textDocument"] = text_doc; + params["position"] = position; + + int req_id = send_request("textDocument/hover", params); + + std::unique_lock lock(g_lsp.mutex); + bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [req_id]() { + return !g_lsp.running || g_lsp.pending_responses.count(req_id) > 0; + }); + + if (!got || !g_lsp.running || g_lsp.pending_responses.count(req_id) == 0) { + return -1; + } + + std::string response_body = g_lsp.pending_responses[req_id]; + g_lsp.pending_responses.erase(req_id); + + json::value val; + try { val = json::parse(response_body); } + catch (...) { return -1; } + + json::object resp; + try { resp = val.as_object(); } + catch (...) { return -1; } + + if (!resp.contains("result")) return -1; + + *json_out = g_host->strdup(json::serialize(resp["result"]).c_str()); + return 0; +} + +static int g_lsp_impl_get_completion(const char* uri, int line, int col, char** json_out) { + if (!g_lsp.running) return -1; + if (!uri || !json_out) return -1; + + json::object position; + position["line"] = line; + position["character"] = col; + + json::object text_doc; + text_doc["uri"] = uri; + + json::object params; + params["textDocument"] = text_doc; + params["position"] = position; + + int req_id = send_request("textDocument/completion", params); + + std::unique_lock lock(g_lsp.mutex); + bool got = g_lsp.cv.wait_for(lock, std::chrono::seconds(10), [req_id]() { + return !g_lsp.running || g_lsp.pending_responses.count(req_id) > 0; + }); + + if (!got || !g_lsp.running || g_lsp.pending_responses.count(req_id) == 0) { + return -1; + } + + std::string response_body = g_lsp.pending_responses[req_id]; + g_lsp.pending_responses.erase(req_id); + + json::value val; + try { val = json::parse(response_body); } + catch (...) { return -1; } + + json::object resp; + try { resp = val.as_object(); } + catch (...) { return -1; } + + if (!resp.contains("result")) return -1; + + *json_out = g_host->strdup(json::serialize(resp["result"]).c_str()); + return 0; +} + +// ============================================================================ +// 服务 vtable +// ============================================================================ + +static dstalk_lsp_service_t g_service_vtable = { + &g_lsp_impl_start, + &g_lsp_impl_stop, + &g_lsp_impl_open_document, + &g_lsp_impl_close_document, + &g_lsp_impl_get_diagnostics, + &g_lsp_impl_get_hover, + &g_lsp_impl_get_completion, +}; + +// ============================================================================ +// 生命周期回调 +// ============================================================================ + +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] initializing LSP service plugin"); + return host->register_service("lsp", 1, &g_service_vtable); +} + +static void on_shutdown() { + if (g_lsp.running) { + g_lsp_impl_stop(); + } + if (g_host) g_host->log(DSTALK_LOG_INFO, "[lsp] shutdown"); + g_host = nullptr; +} + +// ============================================================================ +// 插件描述符 +// ============================================================================ + +static dstalk_plugin_info_t g_info = { + /* .name = */ "lsp", + /* .version = */ "1.0.0", + /* .description = */ "Language Server Protocol client (subprocess manager)", + /* .api_version = */ DSTALK_API_VERSION, + /* .dependencies = */ { NULL }, // 无依赖,自行管理子进程 + /* .on_init = */ on_init, + /* .on_shutdown = */ on_shutdown, + /* .on_event = */ nullptr, +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/network/CMakeLists.txt b/plugins/network/CMakeLists.txt new file mode 100644 index 0000000..a04a6ec --- /dev/null +++ b/plugins/network/CMakeLists.txt @@ -0,0 +1,20 @@ +find_package(Boost REQUIRED CONFIG) +find_package(OpenSSL REQUIRED CONFIG) + +add_library(plugin-network SHARED src/network_plugin.cpp) + +target_include_directories(plugin-network PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-network PRIVATE + dstalk + boost::boost + openssl::openssl +) + +set_target_properties(plugin-network PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/network/src/network_plugin.cpp b/plugins/network/src/network_plugin.cpp new file mode 100644 index 0000000..7f64d5b --- /dev/null +++ b/plugins/network/src/network_plugin.cpp @@ -0,0 +1,322 @@ +// MSVC 14.16 (VS 2017) doesn't provide std::to_address (C++20) +#define BOOST_ASIO_DISABLE_STD_TO_ADDRESS + +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace beast = boost::beast; +namespace http = beast::http; +namespace asio = boost::asio; +namespace ssl = boost::asio::ssl; +using tcp = asio::ip::tcp; + +// ============================================================ +// Global state +// ============================================================ +static const dstalk_host_api_t* g_host = nullptr; +static dstalk_config_service_t* g_config_svc = nullptr; + +// ============================================================ +// Minimal JSON header parser +// Parses {"key1":"value1","key2":"value2"} into unordered_map +// ============================================================ +static std::unordered_map parse_headers_json(const char* json) { + std::unordered_map headers; + if (!json || !*json) return headers; + + std::string s(json); + // Very simple state-machine parser for flat string-key/value objects + enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE; + std::string current_key; + std::string current_value; + + for (size_t i = 0; i < s.size(); ++i) { + char c = s[i]; + switch (state) { + case OUTSIDE: + if (c == '"') { state = IN_KEY; current_key.clear(); } + break; + case IN_KEY: + if (c == '"') { state = AFTER_KEY; } + else if (c == '\\' && i + 1 < s.size()) { current_key += s[++i]; } + else { current_key += c; } + break; + case AFTER_KEY: + if (c == ':') { state = IN_VALUE; current_value.clear(); } + break; + case IN_VALUE: + if (c == '"') { + // Read until closing quote + ++i; + while (i < s.size() && s[i] != '"') { + if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; } + else { current_value += s[i]; } + ++i; + } + headers[current_key] = current_value; + state = OUTSIDE; + } + break; + } + } + return headers; +} + +// ============================================================ +// HTTP Client implementation (adapted from dstalk-core HttpClient) +// ============================================================ +struct HttpClientCtx { + asio::io_context ioc; + ssl::context ssl_ctx{ssl::context::tlsv12_client}; + int connect_timeout = 30; + int request_timeout = 120; + + HttpClientCtx() { + ssl_ctx.set_default_verify_paths(); + } +}; + +static int do_post_stream( + const char* host, + const char* port, + const char* target, + const char* body, + const char* headers_json, + dstalk_stream_cb cb, + void* userdata, + char** response_body, + int* status_code) +{ + if (!host || !port || !target || !body || !response_body || !status_code) { + if (response_body) *response_body = nullptr; + if (status_code) *status_code = -1; + return -1; + } + + // Initialize output + *response_body = nullptr; + *status_code = -1; + + // Build C++ lambda from C callback + std::function on_line; + if (cb) { + on_line = [cb, userdata](const std::string& line) -> bool { + return cb(line.c_str(), userdata) == 0; + }; + } + + HttpClientCtx ctx; + + // Read timeouts from config if available + if (g_config_svc) { + const char* ct = g_config_svc->get("http.connect_timeout"); + const char* rt = g_config_svc->get("http.request_timeout"); + if (ct) ctx.connect_timeout = std::atoi(ct); + if (rt) ctx.request_timeout = std::atoi(rt); + if (ctx.connect_timeout <= 0) ctx.connect_timeout = 30; + if (ctx.request_timeout <= 0) ctx.request_timeout = 120; + } + + std::string result_body; + int result_code = -1; + + try { + tcp::resolver resolver(ctx.ioc); + auto endpoints = resolver.resolve(host, port); + + beast::ssl_stream stream(ctx.ioc, ctx.ssl_ctx); + beast::flat_buffer buffer; + + // SNI hostname + if (!SSL_set_tlsext_host_name(stream.native_handle(), host)) { + result_body = "SNI hostname set failed"; + goto done; + } + + // Connect + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.connect_timeout)); + beast::get_lowest_layer(stream).connect(endpoints); + beast::get_lowest_layer(stream).expires_never(); + + // SSL handshake + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.connect_timeout)); + stream.handshake(ssl::stream_base::client); + beast::get_lowest_layer(stream).expires_never(); + + // Build HTTP POST request + http::request req{http::verb::post, target, 11}; + req.set(http::field::host, host); + req.set(http::field::user_agent, "dstalk/0.1"); + req.set(http::field::content_type, "application/json"); + req.body() = body; + req.prepare_payload(); + + // Add extra headers from JSON + auto extra_headers = parse_headers_json(headers_json); + for (const auto& h : extra_headers) { + req.set(h.first, h.second); + } + + // Send + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.request_timeout)); + http::write(stream, req); + beast::get_lowest_layer(stream).expires_never(); + + // Read response + http::response_parser parser; + parser.body_limit(16 * 1024 * 1024); + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.request_timeout)); + http::read_header(stream, buffer, parser); + beast::get_lowest_layer(stream).expires_never(); + + result_code = parser.get().result_int(); + + beast::error_code ec; + + if (on_line) { + std::string fragment = parser.get().body(); + auto emit_lines = [&]() -> bool { + size_t pos = 0; + while (pos < fragment.size()) { + size_t nl = fragment.find('\n', pos); + if (nl == std::string::npos) break; + std::string line = fragment.substr(pos, nl - pos); + if (!line.empty() && line.back() == '\r') + line.pop_back(); + if (!on_line(line)) return false; + pos = nl + 1; + } + if (pos > 0) + fragment = fragment.substr(pos); + return true; + }; + if (!emit_lines()) goto done; + + size_t processed = parser.get().body().size(); + while (!parser.is_done()) { + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.request_timeout)); + http::read_some(stream, buffer, parser, ec); + if (ec) break; + + const std::string& full_body = parser.get().body(); + if (full_body.size() > processed) { + std::string_view new_data(full_body.data() + processed, + full_body.size() - processed); + processed = full_body.size(); + + fragment.append(new_data.data(), new_data.size()); + if (!emit_lines()) goto done; + } + } + if (!fragment.empty()) { + if (fragment.back() == '\r') + fragment.pop_back(); + if (!fragment.empty()) + on_line(fragment); + } + } else { + while (!parser.is_done()) { + beast::get_lowest_layer(stream).expires_after( + std::chrono::seconds(ctx.request_timeout)); + http::read_some(stream, buffer, parser, ec); + if (ec) break; + } + } + + result_body = parser.get().body(); + beast::get_lowest_layer(stream).cancel(); + stream.shutdown(ec); + } catch (std::exception& e) { + result_code = -1; + result_body = e.what(); + } + +done: + *status_code = result_code; + if (!result_body.empty()) { + *response_body = g_host->strdup(result_body.c_str()); + } + return (result_code >= 200 && result_code < 300) ? 0 : -1; +} + +// ============================================================ +// Service implementations +// ============================================================ +static int http_post_json( + const char* host, const char* port, + const char* target, const char* body, + const char* headers_json, + char** response_body, int* status_code) +{ + return do_post_stream(host, port, target, body, headers_json, + nullptr, nullptr, response_body, status_code); +} + +static int http_post_stream( + const char* host, const char* port, + const char* target, const char* body, + const char* headers_json, + dstalk_stream_cb cb, void* userdata, + char** response_body, int* status_code) +{ + return do_post_stream(host, port, target, body, headers_json, + cb, userdata, response_body, status_code); +} + +static dstalk_http_service_t g_service = { + http_post_json, + http_post_stream +}; + +// ============================================================ +// Plugin lifecycle +// ============================================================ +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + + // Query config service (declared dependency) + g_config_svc = (dstalk_config_service_t*)host->query_service("config", 1); + + return host->register_service("http", 1, &g_service); +} + +static void on_shutdown() { + // nothing to clean up +} + +static dstalk_plugin_info_t g_info = { + "http", // name + "1.0.0", // version + "HTTP/HTTPS client service using Boost.Beast + OpenSSL", // description + DSTALK_API_VERSION, // api_version + {"config", nullptr}, // dependencies + on_init, // on_init + on_shutdown, // on_shutdown + nullptr // on_event +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/session/CMakeLists.txt b/plugins/session/CMakeLists.txt new file mode 100644 index 0000000..bc3eb70 --- /dev/null +++ b/plugins/session/CMakeLists.txt @@ -0,0 +1,18 @@ +add_library(plugin-session SHARED src/session_plugin.cpp) + +target_include_directories(plugin-session PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-session PRIVATE dstalk) + +find_package(Boost REQUIRED CONFIG) +target_link_libraries(plugin-session PRIVATE boost::boost) +target_compile_definitions(plugin-session PRIVATE + BOOST_ALL_NO_LIB BOOST_ERROR_CODE_HEADER_ONLY BOOST_JSON_HEADER_ONLY) + +set_target_properties(plugin-session PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/session/src/session_plugin.cpp b/plugins/session/src/session_plugin.cpp new file mode 100644 index 0000000..3f2e56a --- /dev/null +++ b/plugins/session/src/session_plugin.cpp @@ -0,0 +1,263 @@ +// plugin-session: 会话管理服务插件 +// 提供 dstalk_session_service_t vtable 实现 +// 依赖: file_io (save/load 需要文件操作) +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_types.h" +#include "dstalk/dstalk_services.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace json = boost::json; + +// ============================================================ +// 内部 C++ 数据结构 +// ============================================================ + +static const dstalk_host_api_t* g_host = nullptr; + +// 缓存 file_io 服务指针 +static const dstalk_file_io_service_t* g_file_io = nullptr; + +// 内部消息结构(C++ 易用,外部暴露 C struct) +struct InternalMessage { + std::string role; + std::string content; + std::string tool_call_id; + std::string tool_calls_json; +}; + +// 会话历史 +static std::vector g_history; + +// history() 返回的 C 数组缓存(生命周期到下次 history() 或 shutdown) +static std::vector g_cached_history; + +// ============================================================ +// Token 计数工具(内联,避免硬依赖 context 头文件) +// ============================================================ + +static bool is_ascii(unsigned char c) { return c < 0x80; } + +static bool starts_cjk(unsigned char c) { + return c >= 0xE4 && c <= 0xE9; +} + +static size_t count_tokens_one(const std::string& text) { + size_t ascii_chars = 0; + size_t chinese_chars = 0; + size_t other_chars = 0; + + size_t i = 0; + while (i < text.size()) { + unsigned char c = static_cast(text[i]); + + if (is_ascii(c)) { + ascii_chars++; + i += 1; + } else if (starts_cjk(c)) { + chinese_chars++; + i += 3; + } else if (c >= 0xC0 && c < 0xE0) { + other_chars++; + i += 2; + } else if (c >= 0xE0 && c < 0xF0) { + other_chars++; + i += 3; + } else if (c >= 0xF0 && c < 0xF8) { + other_chars++; + i += 4; + } else { + other_chars++; + i += 1; + } + } + + size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3); + return content_tokens + 4; // +4 per message overhead +} + +static size_t count_tokens_all(const std::vector& msgs) { + size_t total = 0; + for (const auto& m : msgs) { + total += count_tokens_one(m.content); + } + return total; +} + +// ============================================================ +// 辅助:刷新 C 缓存数组 +// ============================================================ + +static void rebuild_cached_history() { + // 释放旧的字符串 + for (auto& m : g_cached_history) { + if (m.role) { g_host->free(const_cast(m.role)); } + if (m.content) { g_host->free(const_cast(m.content)); } + if (m.tool_call_id) { g_host->free(const_cast(m.tool_call_id)); } + if (m.tool_calls_json){ g_host->free(const_cast(m.tool_calls_json)); } + } + g_cached_history.clear(); + + // 重建 + g_cached_history.reserve(g_history.size()); + for (const auto& im : g_history) { + dstalk_message_t cm; + cm.role = im.role.empty() ? nullptr : g_host->strdup(im.role.c_str()); + cm.content = im.content.empty() ? nullptr : g_host->strdup(im.content.c_str()); + cm.tool_call_id = im.tool_call_id.empty() ? nullptr : g_host->strdup(im.tool_call_id.c_str()); + cm.tool_calls_json = im.tool_calls_json.empty() ? nullptr : g_host->strdup(im.tool_calls_json.c_str()); + g_cached_history.push_back(cm); + } +} + +// ============================================================ +// Session 服务 vtable 实现 +// ============================================================ + +static void session_add(const dstalk_message_t* msg) { + if (!msg) return; + InternalMessage im; + if (msg->role) im.role = msg->role; + if (msg->content) im.content = msg->content; + if (msg->tool_call_id) im.tool_call_id = msg->tool_call_id; + if (msg->tool_calls_json) im.tool_calls_json = msg->tool_calls_json; + g_history.push_back(std::move(im)); +} + +static void session_clear() { + g_history.clear(); +} + +static int session_save(const char* path) { + if (!path || !g_file_io) return -1; + + std::string data; + for (const auto& m : g_history) { + json::object entry; + entry["role"] = m.role; + entry["content"] = m.content; + if (!m.tool_call_id.empty()) + entry["tool_call_id"] = m.tool_call_id; + if (!m.tool_calls_json.empty()) + entry["tool_calls_json"] = m.tool_calls_json; + data += json::serialize(entry); + data += '\n'; + } + return g_file_io->write(path, data.c_str()); +} + +static int session_load(const char* path) { + if (!path || !g_file_io) return -1; + + char* content = nullptr; + int ret = g_file_io->read(path, &content); + if (ret != 0 || !content) return -1; + + std::string data(content); + std::free(content); + + std::vector parsed; + size_t pos = 0; + while (pos < data.size()) { + size_t nl = data.find('\n', pos); + std::string line = (nl != std::string::npos) + ? data.substr(pos, nl - pos) : data.substr(pos); + pos = (nl != std::string::npos) ? nl + 1 : data.size(); + if (line.empty()) continue; + + try { + auto obj = json::parse(line).as_object(); + auto* role_j = obj.if_contains("role"); + auto* content_j = obj.if_contains("content"); + if (role_j && content_j && role_j->is_string() && content_j->is_string()) { + InternalMessage im; + im.role = json::value_to(*role_j); + im.content = json::value_to(*content_j); + auto* tci = obj.if_contains("tool_call_id"); + if (tci && tci->is_string()) + im.tool_call_id = json::value_to(*tci); + auto* tcj = obj.if_contains("tool_calls_json"); + if (tcj && tcj->is_string()) + im.tool_calls_json = json::value_to(*tcj); + parsed.push_back(std::move(im)); + } + } catch (const std::exception&) { + return -1; + } + } + + if (parsed.empty()) return -1; + g_history = std::move(parsed); + return 0; +} + +static const dstalk_message_t* session_history(int* out_count) { + rebuild_cached_history(); + if (out_count) *out_count = static_cast(g_cached_history.size()); + return g_cached_history.empty() ? nullptr : g_cached_history.data(); +} + +static int session_token_count() { + return static_cast(count_tokens_all(g_history)); +} + +static dstalk_session_service_t g_session_service = { + session_add, + session_clear, + session_save, + session_load, + session_history, + session_token_count +}; + +// ============================================================ +// 插件生命周期 +// ============================================================ + +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + + // 查询依赖服务: file_io + void* raw = host->query_service("file_io", 1); + if (!raw) { + host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found"); + return -1; + } + g_file_io = static_cast(raw); + + // 注册自身服务 + return host->register_service("session", 1, &g_session_service); +} + +static void on_shutdown() { + // 释放缓存 + rebuild_cached_history(); // 这会先清理旧字符串再清空 + g_cached_history.clear(); // 确保空 + g_history.clear(); + g_file_io = nullptr; + g_host = nullptr; +} + +static dstalk_plugin_info_t g_info = { + "session", + "1.0.0", + "Session management plugin with save/load support", + DSTALK_API_VERSION, + {"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, + on_init, + on_shutdown, + nullptr +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/plugins/tools/CMakeLists.txt b/plugins/tools/CMakeLists.txt new file mode 100644 index 0000000..e4f3d56 --- /dev/null +++ b/plugins/tools/CMakeLists.txt @@ -0,0 +1,18 @@ +add_library(plugin-tools SHARED src/tools_plugin.cpp) + +target_include_directories(plugin-tools PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk-core/include +) + +target_link_libraries(plugin-tools PRIVATE dstalk) + +find_package(Boost REQUIRED CONFIG) +target_link_libraries(plugin-tools PRIVATE boost::boost) +target_compile_definitions(plugin-tools PRIVATE + BOOST_ALL_NO_LIB BOOST_ERROR_CODE_HEADER_ONLY BOOST_JSON_HEADER_ONLY) + +set_target_properties(plugin-tools PROPERTIES + PREFIX "" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins" +) diff --git a/plugins/tools/src/tools_plugin.cpp b/plugins/tools/src/tools_plugin.cpp new file mode 100644 index 0000000..2096526 --- /dev/null +++ b/plugins/tools/src/tools_plugin.cpp @@ -0,0 +1,248 @@ +// plugin-tools: 工具注册服务插件 +// 提供 dstalk_tools_service_t vtable 实现 +// 依赖: file_io (内置 file_read / file_write 工具) +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_types.h" +#include "dstalk/dstalk_services.h" + +#include + +#include +#include +#include +#include +#include + +namespace json = boost::json; + +// ============================================================ +// 内部数据结构 +// ============================================================ + +static const dstalk_host_api_t* g_host = nullptr; +static const dstalk_file_io_service_t* g_file_io = nullptr; + +struct ToolDef { + std::string name; + std::string description; + std::string parameters_schema; + dstalk_tool_handler_fn handler; +}; + +static std::vector g_tools; + +// ============================================================ +// 内置工具: file_read, file_write +// ============================================================ + +static char* builtin_file_read(const char* args_json) { + if (!g_file_io) { + return g_host->strdup("{\"error\":\"file_io service not available\"}"); + } + + try { + auto args = json::parse(args_json).as_object(); + auto* path_j = args.if_contains("path"); + if (!path_j || !path_j->is_string()) { + return g_host->strdup("{\"error\":\"missing 'path' argument\"}"); + } + std::string path = json::value_to(*path_j); + + char* content = nullptr; + int ret = g_file_io->read(path.c_str(), &content); + if (ret != 0 || !content) { + return g_host->strdup("{\"error\":\"failed to read file\"}"); + } + + std::string escaped_content = json::serialize(json::string(content)); + std::free(content); + + std::string result = "{\"content\":" + escaped_content + "}"; + return g_host->strdup(result.c_str()); + } catch (const std::exception& e) { + std::string err = "{\"error\":\"file_read error: " + std::string(e.what()) + "\"}"; + return g_host->strdup(err.c_str()); + } +} + +static char* builtin_file_write(const char* args_json) { + if (!g_file_io) { + return g_host->strdup("{\"error\":\"file_io service not available\"}"); + } + + try { + auto args = json::parse(args_json).as_object(); + auto* path_j = args.if_contains("path"); + auto* content_j = args.if_contains("content"); + if (!path_j || !path_j->is_string()) { + return g_host->strdup("{\"error\":\"missing 'path' argument\"}"); + } + if (!content_j || !content_j->is_string()) { + return g_host->strdup("{\"error\":\"missing 'content' argument\"}"); + } + + std::string path = json::value_to(*path_j); + std::string content = json::value_to(*content_j); + + int ret = g_file_io->write(path.c_str(), content.c_str()); + if (ret != 0) { + return g_host->strdup("{\"error\":\"failed to write file\"}"); + } + + return g_host->strdup("{\"success\":true}"); + } catch (const std::exception& e) { + std::string err = "{\"error\":\"file_write error: " + std::string(e.what()) + "\"}"; + return g_host->strdup(err.c_str()); + } +} + +// ============================================================ +// Tools 服务 vtable 实现 +// ============================================================ + +static int tools_register_tool(const char* name, const char* desc, + const char* params_schema, + dstalk_tool_handler_fn handler) { + if (!name || !handler) return -1; + + // 如果已存在同名工具,先注销 + tools_unregister_tool(name); + + ToolDef td; + td.name = name; + td.description = desc ? desc : ""; + td.parameters_schema = params_schema ? params_schema : ""; + td.handler = handler; + g_tools.push_back(std::move(td)); + return 0; +} + +static void tools_unregister_tool(const char* name) { + if (!name) return; + std::string n(name); + g_tools.erase( + std::remove_if(g_tools.begin(), g_tools.end(), + [&n](const ToolDef& t) { return t.name == n; }), + g_tools.end()); +} + +static char* tools_get_tools_json() { + json::array tools_arr; + + for (const auto& t : g_tools) { + json::object tool_obj; + tool_obj["type"] = "function"; + + json::object func_obj; + func_obj["name"] = t.name; + func_obj["description"] = t.description; + + if (!t.parameters_schema.empty()) { + func_obj["parameters"] = json::parse(t.parameters_schema); + } else { + json::object empty_params; + empty_params["type"] = "object"; + empty_params["properties"] = json::object{}; + func_obj["parameters"] = empty_params; + } + + tool_obj["function"] = func_obj; + tools_arr.push_back(tool_obj); + } + + std::string result = json::serialize(tools_arr); + return g_host->strdup(result.c_str()); +} + +static char* tools_execute(const char* name, const char* args_json) { + if (!name) { + return g_host->strdup("{\"error\":\"tool name is null\"}"); + } + + std::string n(name); + ToolDef* found = nullptr; + for (auto& t : g_tools) { + if (t.name == n) { + found = &t; + break; + } + } + + if (!found) { + json::object err_obj; + err_obj["error"] = "unknown tool: " + n; + return g_host->strdup(json::serialize(err_obj).c_str()); + } + + try { + const char* args = args_json ? args_json : "{}"; + return found->handler(args); + } catch (const std::exception& e) { + json::object err_obj; + err_obj["error"] = std::string("tool execution failed: ") + e.what(); + return g_host->strdup(json::serialize(err_obj).c_str()); + } catch (...) { + return g_host->strdup("{\"error\":\"tool execution failed: unknown error\"}"); + } +} + +static dstalk_tools_service_t g_tools_service = { + tools_register_tool, + tools_unregister_tool, + tools_get_tools_json, + tools_execute +}; + +// ============================================================ +// 插件生命周期 +// ============================================================ + +static int on_init(const dstalk_host_api_t* host) { + g_host = host; + + // 查询依赖服务: file_io + void* raw = host->query_service("file_io", 1); + if (!raw) { + host->log(DSTALK_LOG_ERROR, "[plugin-tools] required service 'file_io' not found"); + return -1; + } + g_file_io = static_cast(raw); + + // 向自身注册内置工具 + tools_register_tool( + "file_read", + "Read the contents of a file at the given path", + "{\"type\":\"object\",\"properties\":{\"path\":{\"type\":\"string\",\"description\":\"Path to the file to read\"}},\"required\":[\"path\"]}", + builtin_file_read + ); + + tools_register_tool( + "file_write", + "Write content to a file at the given path", + "{\"type\":\"object\",\"properties\":{\"path\":{\"type\":\"string\",\"description\":\"Path to the file to write\"},\"content\":{\"type\":\"string\",\"description\":\"Content to write to the file\"}},\"required\":[\"path\",\"content\"]}", + builtin_file_write + ); + + return host->register_service("tools", 1, &g_tools_service); +} + +static void on_shutdown() { + g_tools.clear(); + g_file_io = nullptr; + g_host = nullptr; +} + +static dstalk_plugin_info_t g_info = { + "tools", + "1.0.0", + "Tool registration and execution plugin with built-in file tools", + DSTALK_API_VERSION, + {"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, + on_init, + on_shutdown, + nullptr +}; + +extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { + return &g_info; +} diff --git a/scripts/ci-build.bat b/scripts/ci-build.bat new file mode 100644 index 0000000..b61daf4 --- /dev/null +++ b/scripts/ci-build.bat @@ -0,0 +1,25 @@ +@echo off +setlocal + +set PROJECT_DIR=%~dp0.. +set BUILD_DIR=%PROJECT_DIR%\build + +echo === dstalk CI Build === +echo Project: %PROJECT_DIR% + +if not exist "%BUILD_DIR%" mkdir "%BUILD_DIR%" +cd /d "%BUILD_DIR%" + +echo --- CMake Configure --- +cmake "%PROJECT_DIR%" -G Ninja -DCMAKE_BUILD_TYPE=Release -DDSTALK_BUILD_TESTS=ON -DDSTALK_BUILD_GUI=OFF +if errorlevel 1 exit /b 1 + +echo --- Build --- +cmake --build . --parallel +if errorlevel 1 exit /b 1 + +echo --- Test --- +ctest --output-on-failure --parallel 4 +if errorlevel 1 exit /b 1 + +echo === CI Build PASSED === diff --git a/scripts/ci-build.sh b/scripts/ci-build.sh new file mode 100644 index 0000000..7ebf803 --- /dev/null +++ b/scripts/ci-build.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash +set -euo pipefail + +PROJECT_DIR="$(cd "$(dirname "$0")/.." && pwd)" +BUILD_DIR="${PROJECT_DIR}/build" + +echo "=== dstalk CI Build ===" +echo "Project: ${PROJECT_DIR}" + +# 创建构建目录 +mkdir -p "${BUILD_DIR}" +cd "${BUILD_DIR}" + +# CMake 配置 +echo "--- CMake Configure ---" +cmake "${PROJECT_DIR}" -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DDSTALK_BUILD_TESTS=ON \ + -DDSTALK_BUILD_GUI=OFF + +# 编译 +echo "--- Build ---" +cmake --build . --parallel + +# 运行测试 +echo "--- Test ---" +ctest --output-on-failure --parallel 4 + +echo "=== CI Build PASSED ===" diff --git a/tests/smoke_test.cpp b/tests/smoke_test.cpp index cf7bab4..1ae1f65 100644 --- a/tests/smoke_test.cpp +++ b/tests/smoke_test.cpp @@ -1,4 +1,10 @@ -#include "dstalk/dstalk_api.h" +// ============================================================================ +// smoke_test.cpp — 插件化架构烟雾测试 +// ============================================================================ +// 测试: 核心初始化、插件加载、服务查询、file_io、session 功能 +// ============================================================================ + +#include "dstalk/dstalk_host.h" #include #include @@ -11,6 +17,7 @@ int main() const auto dir = std::filesystem::temp_directory_path() / "dstalk-smoke-test"; std::filesystem::create_directories(dir); + // 写一个配置文件用于初始化 const auto config_path = dir / "config.toml"; { std::ofstream config(config_path); @@ -18,71 +25,406 @@ int main() << "provider = \"deepseek\"\n" << "base_url = \"https://api.deepseek.com/v1\"\n" << "api_key = \"test-key\"\n" - << "model = \"deepseek-chat\"\n"; + << "model = \"deepseek-v4-pro\"\n"; } + // 初始化主机(会自动扫描 plugins/ 加载插件) if (dstalk_init(config_path.string().c_str()) != 0) { std::cerr << "dstalk_init failed\n"; return 1; } + std::cout << "[OK] dstalk_init succeeded\n"; - const auto file_path = dir / "sample.txt"; - constexpr const char* sample_content = "hello dstalk\nquote=\"yes\" tab=\t slash=\\"; - if (dstalk_file_write(file_path.string().c_str(), sample_content) != 0) { - std::cerr << "dstalk_file_write failed\n"; - dstalk_destroy(); - return 1; + // 验证插件列表 + { + char* list_json = nullptr; + int ret = dstalk_plugin_list(&list_json); + if (ret == 0 && list_json) { + std::cout << "[OK] plugins loaded: " << list_json << "\n"; + dstalk_free(list_json); + } else { + std::cerr << "[WARN] dstalk_plugin_list returned: " << ret << "\n"; + } } - char* content = nullptr; - if (dstalk_file_read(file_path.string().c_str(), &content) != 0 || !content) { - std::cerr << "dstalk_file_read failed\n"; - dstalk_destroy(); - return 1; + // 测试服务查询: file_io + auto* file_io = static_cast( + dstalk_service_query("file_io", 1)); + if (file_io) { + std::cout << "[OK] file_io service found\n"; + + // 测试写入 + const auto file_path = dir / "sample.txt"; + constexpr const char* sample_content = "hello dstalk\nquote=\"yes\" tab=\t slash=\\"; + if (file_io->write(file_path.string().c_str(), sample_content) == 0) { + std::cout << "[OK] file_io->write succeeded\n"; + } else { + std::cerr << "[FAIL] file_io->write failed\n"; + dstalk_shutdown(); + return 1; + } + + // 测试读取 + char* content = nullptr; + if (file_io->read(file_path.string().c_str(), &content) == 0 && content) { + bool ok = std::strcmp(content, sample_content) == 0; + std::free(content); + if (ok) { + std::cout << "[OK] file_io->read content matches\n"; + } else { + std::cerr << "[FAIL] file_io->read content mismatch\n"; + dstalk_shutdown(); + return 1; + } + } else { + std::cerr << "[FAIL] file_io->read failed\n"; + dstalk_shutdown(); + return 1; + } + } else { + std::cerr << "[WARN] file_io service not found (plugin may not be in plugins/ dir)\n"; } - const bool ok = std::strcmp(content, sample_content) == 0; - dstalk_free_string(content); - if (!ok) { - std::cerr << "unexpected file content\n"; - dstalk_destroy(); - return 1; + // 测试服务查询: session + auto* session = static_cast( + dstalk_service_query("session", 1)); + if (session) { + std::cout << "[OK] session service found\n"; + + // 测试 session save/load + const auto session_path = dir / "session.jsonl"; + const auto saved_path = dir / "session-saved.jsonl"; + constexpr const char* session_content = + "{\"role\":\"user\",\"content\":\"line\\n\\\"quote\\\"\\\\slash\"}\n" + "{\"role\":\"assistant\",\"content\":\"ok\\tready\"}\n"; + + if (file_io) { + file_io->write(session_path.string().c_str(), session_content); + } + + if (session->load(session_path.string().c_str()) == 0) { + std::cout << "[OK] session->load succeeded\n"; + } else { + std::cerr << "[FAIL] session->load failed\n"; + dstalk_shutdown(); + return 1; + } + + if (session->save(saved_path.string().c_str()) == 0) { + std::cout << "[OK] session->save succeeded\n"; + } else { + std::cerr << "[FAIL] session->save failed\n"; + dstalk_shutdown(); + return 1; + } + + // 验证保存的内容 + if (file_io) { + char* saved = nullptr; + if (file_io->read(saved_path.string().c_str(), &saved) == 0 && saved) { + bool session_ok = std::strcmp(saved, session_content) == 0; + std::free(saved); + if (session_ok) { + std::cout << "[OK] session content matches after save/load\n"; + } else { + std::cerr << "[FAIL] session content mismatch after save/load\n"; + dstalk_shutdown(); + return 1; + } + } + } + + // 测试 token 计数 + int tokens = session->token_count(); + std::cout << "[OK] session->token_count: " << tokens << "\n"; + + // 测试 history + int count = 0; + session->history(&count); + std::cout << "[OK] session->history count: " << count << "\n"; + + // 测试 clear + session->clear(); + session->history(&count); + if (count == 0) { + std::cout << "[OK] session->clear succeeded\n"; + } + } else { + std::cerr << "[WARN] session service not found\n"; } - const auto session_path = dir / "session.jsonl"; - const auto saved_session_path = dir / "session-saved.jsonl"; - constexpr const char* session_content = - "{\"role\":\"user\",\"content\":\"line\\n\\\"quote\\\"\\\\slash\"}\n" - "{\"role\":\"assistant\",\"content\":\"ok\\tready\"}\n"; - if (dstalk_file_write(session_path.string().c_str(), session_content) != 0) { - std::cerr << "session fixture write failed\n"; - dstalk_destroy(); - return 1; - } - if (dstalk_session_load(session_path.string().c_str()) != 0) { - std::cerr << "dstalk_session_load failed\n"; - dstalk_destroy(); - return 1; - } - if (dstalk_session_save(saved_session_path.string().c_str()) != 0) { - std::cerr << "dstalk_session_save failed\n"; - dstalk_destroy(); - return 1; + // 测试服务查询: ai(可能因为没有真实 API key 而失败,但服务应存在) + const char* ai_provider = dstalk_config_get("ai.provider"); + if (!ai_provider) ai_provider = "ai.deepseek"; + auto* ai = static_cast( + dstalk_service_query(ai_provider, 1)); + if (ai) { + std::cout << "[OK] ai service found\n"; + } else { + std::cerr << "[WARN] ai service not found\n"; } - char* saved_session = nullptr; - if (dstalk_file_read(saved_session_path.string().c_str(), &saved_session) != 0 || !saved_session) { - std::cerr << "saved session read failed\n"; - dstalk_destroy(); - return 1; + // 测试服务查询: config + auto* config_svc = static_cast( + dstalk_service_query("config", 1)); + if (config_svc) { + std::cout << "[OK] config service found\n"; + const char* val = config_svc->get("api.model"); + if (val) { + std::cout << "[OK] config->get(\"api.model\"): " << val << "\n"; + } + } else { + std::cerr << "[WARN] config service not found\n"; } - const bool session_ok = std::strcmp(saved_session, session_content) == 0; - dstalk_free_string(saved_session); - dstalk_destroy(); - if (!session_ok) { - std::cerr << "unexpected saved session content\n"; - return 1; + // 测试 dstalk_config_get(主机级配置 API) + const char* model = dstalk_config_get("api.model"); + if (model) { + std::cout << "[OK] dstalk_config_get(\"api.model\"): " << model << "\n"; } + + // 测试 dstalk_log + dstalk_log(DSTALK_LOG_INFO, "Smoke test completed successfully"); + + // ======================================================================== + // 扩展测试块 C2: null-safety / 转义边界 / tools 调用链 / session 健壮性 + // ======================================================================== + std::cout << "\n--- Extended Smoke Tests (C2) ---\n"; + + // 提前查询 tools 服务,供后续测试块使用 + auto* tools = static_cast( + dstalk_service_query("tools", 1)); + + // ---- 1. Null-safety 测试 ---- + // 对所有服务 API 传 null 参数,验证不崩溃且返回错误 + std::cout << "\n[Block] Null-safety tests\n"; + + if (file_io) { + char* dummy = nullptr; + int ret = file_io->read(nullptr, &dummy); + if (ret != 0) { + std::cout << "[OK] file_io->read(nullptr, ...) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] file_io->read(nullptr, ...) should return error\n"; + } + + ret = file_io->write(nullptr, "test_content"); + if (ret != 0) { + std::cout << "[OK] file_io->write(nullptr, ...) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] file_io->write(nullptr, ...) should return error\n"; + } + + // read 的 content 参数也为 null + ret = file_io->read("dummy_path", nullptr); + if (ret != 0) { + std::cout << "[OK] file_io->read(path, nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] file_io->read(path, nullptr) should return error\n"; + } + + // write 的 content 参数为 null + ret = file_io->write("dummy_path", nullptr); + if (ret != 0) { + std::cout << "[OK] file_io->write(path, nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] file_io->write(path, nullptr) should return error\n"; + } + } else { + std::cerr << "[WARN] file_io service not available for null-safety tests\n"; + } + + if (session) { + session->add(nullptr); + std::cout << "[OK] session->add(nullptr) did not crash\n"; + + int ret = session->save(nullptr); + if (ret != 0) { + std::cout << "[OK] session->save(nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] session->save(nullptr) should return error\n"; + } + + ret = session->load(nullptr); + if (ret != 0) { + std::cout << "[OK] session->load(nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] session->load(nullptr) should return error\n"; + } + } else { + std::cerr << "[WARN] session service not available for null-safety tests\n"; + } + + if (tools) { + char* result = tools->execute(nullptr, nullptr); + if (result) { + // 实现返回了错误字符串(如 {"error":"tool name is null"}),未崩溃 + std::cout << "[OK] tools->execute(nullptr, nullptr) did not crash" + << " (returned: " << result << ")\n"; + dstalk_free(result); + } else { + std::cout << "[OK] tools->execute(nullptr, nullptr) returned null without crash\n"; + } + } else { + std::cerr << "[WARN] tools service not available for null-safety tests\n"; + } + + if (config_svc) { + const char* val = config_svc->get(nullptr); + if (val == nullptr) { + std::cout << "[OK] config->get(nullptr) returned nullptr\n"; + } else { + std::cerr << "[FAIL] config->get(nullptr) should return nullptr\n"; + } + + int ret = config_svc->set(nullptr, nullptr); + if (ret != 0) { + std::cout << "[OK] config->set(nullptr, nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] config->set(nullptr, nullptr) should return error\n"; + } + + // set 的 value 为 null + ret = config_svc->set("some.key", nullptr); + if (ret != 0) { + std::cout << "[OK] config->set(key, nullptr) returned error (" << ret << ")\n"; + } else { + std::cerr << "[FAIL] config->set(key, nullptr) should return error\n"; + } + } else { + std::cerr << "[WARN] config service not available for null-safety tests\n"; + } + + // ---- 2. 转义边界测试 ---- + // 写入含特殊字符的内容,读回后验证内容一致 + std::cout << "\n[Block] Escape boundary tests\n"; + + if (file_io) { + // 构造包含各种特殊字节的内容: + // - 实际换行符 (0x0A) + // - 实际双引号 (0x22) + // - 实际反斜杠 (0x5C) + // - 实际制表符 (0x09) + // - 以及字面上的 \n \" \\ \t 转义序列文本 + constexpr const char* escape_content = + "line1\nline2\n" + "quote=\"yes\"\n" + "backslash=\\path\n" + "tab=\there\n" + "literal-escapes: newline=\\n quote=\\\" backslash=\\\\ tab=\\t\n" + "endswithbackslash\\\\\n" + "mixed\\t\\\"quoted\\\"\\\\path\n"; + + const auto escape_path = dir / "escape_test.txt"; + + if (file_io->write(escape_path.string().c_str(), escape_content) == 0) { + std::cout << "[OK] escape content write succeeded\n"; + + char* read_back = nullptr; + if (file_io->read(escape_path.string().c_str(), &read_back) == 0 && read_back) { + bool match = (std::strcmp(read_back, escape_content) == 0); + if (match) { + std::cout << "[OK] escape content round-trip matches" + << " (length=" << std::strlen(escape_content) << ")\n"; + } else { + std::cerr << "[FAIL] escape content round-trip mismatch\n" + << " expected length: " << std::strlen(escape_content) << "\n" + << " got length: " << std::strlen(read_back) << "\n"; + } + std::free(read_back); + } else { + std::cerr << "[FAIL] escape content read-back failed\n"; + } + } else { + std::cerr << "[FAIL] escape content write failed\n"; + } + } else { + std::cerr << "[WARN] file_io service not available for escape tests\n"; + } + + // ---- 3. Tools 调用链测试 ---- + // 通过 tools->execute("file_read", ...) 验证内置工具可正确调用 file_io + std::cout << "\n[Block] Tools call chain tests\n"; + + if (tools && file_io) { + // 准备测试文件 + const auto chain_path = dir / "tool_chain_test.txt"; + constexpr const char* chain_content = "tools-chain-ok\n"; + file_io->write(chain_path.string().c_str(), chain_content); + + // 用 generic_string() 获取正斜杠路径,避免 JSON 中反斜杠转义问题 + std::string generic_path = chain_path.generic_string(); + std::string args_json = "{\"path\":\"" + generic_path + "\"}"; + + char* result = tools->execute("file_read", args_json.c_str()); + if (result) { + std::cout << "[OK] tools->execute(\"file_read\", ...) returned result\n"; + // 验证返回的 JSON 中包含原始文件内容 + if (std::strstr(result, "tools-chain-ok")) { + std::cout << "[OK] tools->execute chain correctly called file_io\n"; + } else { + std::cout << "[WARN] tools->execute result does not contain expected content: " + << result << "\n"; + } + dstalk_free(result); + } else { + std::cout << "[WARN] tools->execute(\"file_read\", ...) returned null" + << " (tool may not be registered)\n"; + } + + // 额外测试:查询 tools 返回的工具列表 + char* tools_json = tools->get_tools_json(); + if (tools_json) { + std::cout << "[OK] tools->get_tools_json() returned: " << tools_json << "\n"; + dstalk_free(tools_json); + } else { + std::cout << "[WARN] tools->get_tools_json() returned null\n"; + } + } else { + std::cerr << "[WARN] tools or file_io service not available for chain tests\n"; + } + + // ---- 4. Session 健壮性测试 ---- + // session->add(nullptr) 后验证 history 不变 + // session->clear 后验证 token_count 为 0 + std::cout << "\n[Block] Session robustness tests\n"; + + if (session) { + // 记录 add(nullptr) 前的 history 计数 + int count_before = 0; + session->history(&count_before); + + // 传 null 不应改变 history + session->add(nullptr); + + int count_after = 0; + session->history(&count_after); + + if (count_before == count_after) { + std::cout << "[OK] session->add(nullptr) did not change history count" + << " (before=" << count_before << ", after=" << count_after << ")\n"; + } else { + std::cerr << "[FAIL] session->add(nullptr) changed history count: " + << count_before << " -> " << count_after << "\n"; + } + + // clear 后 token_count 应为 0 + session->clear(); + int tokens = session->token_count(); + if (tokens == 0) { + std::cout << "[OK] session->token_count() == 0 after clear\n"; + } else { + std::cerr << "[FAIL] session->token_count() == " << tokens + << " after clear, expected 0\n"; + } + } else { + std::cerr << "[WARN] session service not available for robustness tests\n"; + } + + // 清理 + dstalk_shutdown(); + std::cout << "[OK] dstalk_shutdown succeeded\n"; + + std::cout << "\n=== All smoke tests passed ===\n"; return 0; } diff --git a/说明.txt b/说明.txt index 168bfc6..56b5ce9 100644 --- a/说明.txt +++ b/说明.txt @@ -1,6 +1,11 @@ + + 软件名称: dstalk +改下功能架构,dstalk的核心做成插件化可分离的架构,支持dll的动态加载和引入,支持功能的动态注册和更新,支持版本管理和接口,所有功能插件化,分为无依赖的基础插件和依赖别的插件的插件,dstalk只作为插件注册平台和调度管理中心和底层应用基础,所有插件只需引用dstalk即可实现,dstalk支持什么平台插件就支持什么平台,插件需要编译成dll + + 网址: dstalk.top @@ -59,6 +64,5 @@ anthropic api:https://api.deepseek.com/anthropic 测试用模型: deepseek-v4-pro -deepseek-v4-flash 密钥请通过本地 config.toml 配置,不要提交到仓库。