diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..492fb35 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,109 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Build & run + +The repo ships its own toolchain under `tools/` (clang-cl/clang, CMake, Ninja, Conan 2 in a venv). Build scripts prefer those over PATH. + +- **One-shot build (Windows)**: `build.bat` — checks/installs toolchain, runs Conan install against `deps/conanfile.txt`, configures CMake with clang-cl, builds with Ninja into `build/`. +- **One-shot build (Linux/macOS)**: `bash build.sh [Release|Debug]` — same flow with clang (or gcc fallback) and Ninja. +- **First-time setup**: `tools/setup.bat` or `bash tools/setup.sh` installs CMake, Ninja, and Conan 2 into `tools/` without touching system PATH. +- **CI-style build**: `scripts/ci-build.sh` / `scripts/ci-build.bat` — minimal Conan + CMake + Ninja flow used by CI. +- **Incremental rebuild after edits**: `cd build && -j` from inside `build/`. No need to re-run Conan unless `deps/conanfile.txt` changes. + +Build outputs: +- Executables → `build/bin/` (`dstalk_cli`, test binaries, optional `dstalk_gui`/`dstalk_web`) +- Plugin DLLs → `build/plugins/` (each `plugin_*` target writes here, prefix stripped) + +CMake options (root `CMakeLists.txt`): +- `DSTALK_BUILD_GUI=ON` — also builds the SDL3 GUI frontend (off by default; requires SDL3 from Conan/system) +- `DSTALK_BUILD_WEB=ON` — also builds the Boost.Beast web frontend +- `DSTALK_BUILD_TESTS=ON` — on by default; turn off to skip `tests/` subdir + +## Tests + +CTest, hand-rolled `CHECK`-macro tests (no GoogleTest — see "Network-restricted environments" below). 
+ +- **Run all tests**: `cd build && ctest --output-on-failure` +- **Run one test**: `cd build && ctest -R dstalk_smoke_test --output-on-failure` (or any other target name from `tests/CMakeLists.txt`) +- **Run a test binary directly**: `build/bin/dstalk_smoke_test.exe` +- **Coverage**: `cmake --build build --target coverage` (requires gcovr + gcov/llvm-cov; produces `build/coverage/index.html`). Only meaningful when configured with `--coverage` flags. + +Test targets and what they cover (see `tests/CMakeLists.txt` for the authoritative list): +- `dstalk_smoke_test` — loads real plugins from `build/plugins/`, end-to-end integration; also carries regression cases (R1-R4, W21.5 tool calls) +- `dstalk_host_api_test`, `dstalk_event_bus_test`, `dstalk_service_registry_test` — core unit tests; compile core sources directly +- `dstalk_plugin_loader_test` — `PluginLoader` regression tests; sets `DSTALK_TEST_PLUGINS_DIR` to `build/plugins` +- `dstalk_context_plugin_test` — token/trim/UTF-8 boundaries +- `dstalk_anthropic_plugin_test`, `dstalk_openai_plugin_test` — currently `#include` the plugin `.cpp` to reach static functions; link `ai_common` and need its include dir +- `dstalk_network_plugin_test` — `#include`s `network_plugin.cpp`; needs OpenSSL +- `dstalk_lsp_plugin_test` — tests `lsp_trim` / `lsp_frame_message` / `lsp_parse_content_length` via `lsp_internal.hpp` + +> After changing the `dstalk_host_api_t` vtable layout, **clean rebuild is mandatory** — stale `.obj` files from incremental builds will surface as segfaults in `dstalk_smoke_test` / `dstalk_host_api_test` because the test binary and plugin DLLs disagree on struct layout. `rm -rf build && bash build.sh` (or delete the specific stale `.obj` files) before re-running ctest. + +## Architecture + +Plugin-host design. One process loads a host DLL and many plugin DLLs over a C ABI. 
+ +``` +Frontend (dstalk_cli / dstalk_gui / dstalk_web) + │ links dstalk_frontend_common (shared bootstrap: config discovery, init, + │ service queries, FrontendServices struct) + ▼ +dstalk_core.dll ─ PluginLoader · ServiceRegistry · EventBus · Config · Logging · Memory + ▲ + │ C ABI (dstalk_host.h) + │ +Plugins (loaded from build/plugins/, each plugin_*.dll exports dstalk_plugin_init) + plugins_base/ config, file_io, lsp + plugins_middle/ network, session, tools (may depend on base) + plugins_upper/ context, openai, anthropic (may depend on base + middle) + ai_common (static lib, shared by openai + anthropic) +``` + +**Plugin tiers matter for build order and dependency direction**: `plugins_upper` may depend on `plugins_middle` may depend on `plugins_base`. Never let a lower tier depend on a higher one. + +**Service vtables are the only cross-DLL contract.** Plugins register service vtables (e.g. `dstalk_ai_service_t`, `dstalk_http_service_t`, `dstalk_session_service_t`) with the host's `ServiceRegistry`. Other plugins query them by `(name, min_version)`. The vtable shapes live in `dstalk_core/include/dstalk/dstalk_services.h`; the host API the loader passes into each plugin's `on_init(const dstalk_host_api_t* host)` lives in `dstalk_host.h`. + +**`ai_common`** (`plugins_upper/ai_common/`) is a static library, not a plugin. It holds shared types (`PluginConfig`, `ToolCallAccum`, `StreamContext`) and utilities (`secure_zero`, `extract_host_port`, `serialize_tool_calls`, `free_chat_result`) used by both `openai` and `anthropic` plugins, all under `namespace dstalk_ai`. Each plugin still compiles as its own DLL and links `ai_common` privately. + +### Hard rules — violating these is undefined behavior + +These come from `docs/reference/plugin-abi.md`. They are not style; they are correctness. + +1. **Cross-DLL heap discipline.** Plugins must NOT call `std::malloc`/`std::free`/`std::strdup`/`new`/`delete` on data that crosses the host↔plugin boundary. 
Use `host->alloc` / `host->free` / `host->strdup` (passed in via `on_init`). Windows /MD CRTs and even some Linux/libc configs give each DLL its own heap — mismatched alloc/free crashes. + +2. **API version is a hard match.** `dstalk_plugin_info_t.api_version` must equal `DSTALK_API_VERSION` (currently 1). The loader rejects mismatches. There is no backward compat — rebuild plugins against the new host. + +3. **String ownership.** `dstalk_chat_result_t.content` / `.error` / `.tool_calls_json` are allocated with `host->strdup` by the producing plugin; the **caller** frees them with `host->free`. Never return `std::string::c_str()` or stack buffers across the ABI. + +4. **C ABI + atomic globals.** Service vtable function pointers and cached service pointers stored in plugin global state must be `std::atomic` with `memory_order_acquire`/`release`. Raw pointers race during shutdown (the anthropic plugin's `g_config` was a real data race fixed in W17). Same applies to `g_host`, `g_http`, etc. + +5. **No new on plugin info strings.** `name`/`version`/`description` in `dstalk_plugin_info_t` only need to live for the duration of `dstalk_plugin_init()` — string literals are fine; the host copies them. + +6. **`api_key` lifecycle.** On `on_shutdown`, overwrite key bytes via `volatile char*` loop, then clear the string. Use `dstalk_ai::secure_zero()` from `ai_common`. + +### Build system facts that bite + +- **Conan provides only `boost::boost`** (the umbrella target) — granular `boost::json` / `boost::asio` / `boost::beast` targets do **not** exist in this Conan setup. Don't migrate to them. +- Every Boost-using target needs its own `find_package(Boost REQUIRED CONFIG)` call before `target_link_libraries(... boost::boost)`. Doing it once at the root is not enough; Conan-generated config exposes the target per-subdir. 
+- `dstalk_boost_config` (defined in `dstalk_core/CMakeLists.txt`) is an INTERFACE target carrying `BOOST_ALL_NO_LIB` / `BOOST_ERROR_CODE_HEADER_ONLY` / `BOOST_JSON_HEADER_ONLY`. Link it from any target using ``. Boost.JSON is header-only here, so each TU using it must `#include ` in exactly one file (see `dstalk_core/src/boost_json.cpp`). +- The network plugin links `openssl::openssl` directly and calls OpenSSL C APIs (`SSL_set_tlsext_host_name`, `SSL_set1_host`). Don't remove that link "because Boost.Asio already pulls SSL" — it doesn't, at least not for these symbols. +- Plugin DLLs are written to `build/plugins/` with `PREFIX ""` (so `plugin_openai.dll`, not `libplugin_openai.dll`). The smoke test and plugin_loader_test look there. + +### Network-restricted environments + +This repo has been built in environments where outbound network is locked down. Two consequences: + +- **Do not add `FetchContent` for GitHub-hosted deps.** `git clone` over smart-HTTP fails with "early EOF" through restrictive proxies (W17.3 lesson — GoogleTest dropped, tests use hand-rolled `CHECK` macros instead). Vendor or skip. +- Conan packages must be pre-cached in `~/.conan2/p/` or available through an allowed mirror; first-run `conan install` from a clean cache will fail offline. + +## Repo-specific conventions + +- **README is in Chinese (Simplified).** Comments and docs are routinely bilingual (Chinese first, English after a `/`). Match the surrounding style in any file you edit. +- **Plugin code style.** Each AI/network plugin has the same skeleton: `g_host` (atomic), `g_cfg`, `on_init`/`on_shutdown`, service vtable, then `dstalk_plugin_init()` returning a static `dstalk_plugin_info_t`. Keep that shape when adding a provider. +- **CRT.** `CMAKE_MSVC_RUNTIME_LIBRARY=MultiThreadedDLL` (/MD). All plugins and the host must agree. 
+ +## The `agents/` directory + +`agents/` (README, WORKFLOW.md, STATUS.md, per-agent `profile.md` files, `groups/`) describes a multi-agent collaboration mode used by the project owner with Claude. It is **documentation, not code** — nothing in the build references it. Treat its contents as historical context; do not invent or extend the "16-person team / waves / 6-stage workflow" apparatus unless the user explicitly asks for it. If the user does invoke it, the rules are in `agents/WORKFLOW.md`. diff --git a/CMakeLists.txt b/CMakeLists.txt index 8c40c2b..550c642 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,6 +12,7 @@ option(DSTALK_BUILD_WEB "Build the web UI frontend" OFF) option(DSTALK_BUILD_TESTS "Build dstalk tests" ON) add_subdirectory(dstalk_core) +add_subdirectory(dstalk_frontend_common) add_subdirectory(dstalk_cli) # 插件按依赖层级分三个目录 / Plugins split into three directories by dependency tier add_subdirectory(plugins_base) diff --git a/agents/STATUS.md b/agents/STATUS.md index 1e4e0f1..e3b96c3 100644 --- a/agents/STATUS.md +++ b/agents/STATUS.md @@ -1,6 +1,6 @@ # dstalk 实时编制状态 -> **最后更新**: 2026-05-27 +> **最后更新**: 2026-05-31 > **数据来源**: 由 `scripts/refresh_status.py` 自动扫描全部 16 个 `agents/*/profile.md` + 5 个 `agents/groups/*.md` 生成。 ## 表 1:员工状态(16 人) @@ -40,7 +40,7 @@ ## Wave 进度 -**已完成高水位**: W16.5(基于 16 份 profile.md 的 performance_log 聚合) +**已完成高水位**: W17(由 CEO 直接执行,ai_common 模块提取) -**已发现 Wave 编号**: W1.1, W2.1, W2.2, W5.1, W6.1, W7, W9.3, W9.4, W9.6, W9.10, W10.1, W10.2, W10.3, W10.4, W11, W11.1, W11.2, W11.3, W11.6, W11.7, W12, W12.1, W12.2, W12.4, W12.5, W12.6, W13.1, W13.2, W13.3, W13.4, W13.5, W13.6, W14.1, W14.2, W14.3, W14.4, W14.5, W15.1, W15.2, W15.3, W15.4, W15.5, W15.6, W15.7, W15.8, W15.9, W16.5 +**已发现 Wave 编号**: W1.1, W2.1, W2.2, W5.1, W6.1, W7, W9.3, W9.4, W9.6, W9.10, W10.1, W10.2, W10.3, W10.4, W11, W11.1, W11.2, W11.3, W11.6, W11.7, W12, W12.1, W12.2, W12.4, W12.5, W12.6, W13.1, W13.2, W13.3, W13.4, W13.5, W13.6, W14.1, W14.2, W14.3, W14.4, 
W14.5, W15.1, W15.2, W15.3, W15.4, W15.5, W15.6, W15.7, W15.8, W15.9, W16.5, W17 diff --git a/agents/mailroom/README.md b/agents/mailroom/README.md new file mode 100644 index 0000000..c62471f --- /dev/null +++ b/agents/mailroom/README.md @@ -0,0 +1,249 @@ +# Mailroom — agents 信箱系统 + +> **版本**: 1.0 草案 (2026-05-31 设计) +> **状态**: 试运行;格式稳定后由 CEO 决定是否在 WORKFLOW.md §15 正式纳入流程 +> **设计目标**: 让无状态的子代理之间能异步传递消息,把进行中的协调状态从 CEO 的隐式上下文搬到磁盘 + +## 1. 设计动机 + +子代理是 stateless 的——每次 spawn 都是新会话,看不到其他子代理当前在做什么。现行所有协调依赖 CEO 在 prompt 里手动列禁忌、手动转交依赖、手动追踪谁回了什么。一波 6-9 路子代理并行时,CEO 的上下文窗口很快被占满。 + +信箱让这些消息**显式落盘**: + +- CEO 派活前可以扫一遍候选执行者的 inbox,看有没有未结的 handoff / blocker +- 子代理 spawn 时可以读到自己的待办邮件,了解上下文 +- 处理完毕的邮件归档到 `archive/W<N>/`,不丢失追溯,但不打扰当前视图 + +## 2. 目录结构 + +``` +agents/mailroom/ +├── README.md # 本文件 +├── inbox/ +│ ├── ceo/ # CEO 的收件箱 +│ ├── architect-lin/ # 各 agent 的收件箱(按需创建) +│ ├── engineer-zhao/ +│ └── ... +└── archive/ + ├── W17/ # 按 Wave 归档已处理邮件 + ├── W18/ + └── ... +``` + +- `inbox/<agent-id>/` —— 待处理邮件(status: pending / read / in_progress) +- `archive/W<N>/` —— 处理完毕邮件(status: closed),按 Wave 归档保留 +- 收件箱目录按需创建,未创建表示该 agent 当前无邮件 + +## 3. 权限规则 + +文件系统不强制权限;权限靠子代理 prompt 中的禁忌条款约束。 + +| 角色 | 自己的 inbox | 别人的 inbox | archive/ | +|------|--------------|--------------|----------| +| 接收者本人 | 读 / 改 status / 移到 archive | 只能投递新邮件(write-only) | 只读 | +| CEO | 所有权限 | 所有权限(含重投递、强制保留) | 所有权限 | +| 其他子代理 | — | 只能投递新邮件 | 只读 | + +派子代理时 prompt 必须显式声明可访问的邮箱范围。新增防御性规则 **R-MAIL-SCOPE**:子代理不得读写超出自身 inbox 之外的他人邮箱内容,仅可投递新邮件。 + +## 4. 消息文件命名 + +`<timestamp>-<kind>-<from>-to-<to>.md` + +- timestamp: `YYYY-MM-DDTHHmm` 本地时间,分钟级即可区分 +- kind: 见 §5 +- from / to: agent-id(CEO 用 `ceo`,广播用 `all`) + +示例:`2026-05-31T1430-task-ceo-to-architect-lin.md` + +## 5. 
五种邮件类型 + +| kind | 用途 | 投递方 | 接收方 | +|------|------|--------|--------| +| task | CEO 派活(prompt 副本,便于事后追溯) | CEO | 执行者 | +| report | 子代理回执(done / blocked / aborted) | 执行者 | CEO | +| handoff | 子代理之间工作交接(前序产出移交下一棒) | 子代理 A | 子代理 B | +| blocker | 阻塞通知(依赖未满足) | 任何 | CEO(必抄送) | +| notice | 广播(in-flight 文件锁定 / 流程变更 / 元数据自检失败) | CEO 或系统 | all | + +## 6. Frontmatter schema + +```yaml +--- +id: msg-W17.3-001 # 全局唯一,格式 msg-<wave>-<序号> +from: ceo +to: architect-lin # 单人用 agent-id;广播用 all +wave: W17.3 +kind: task # task | report | handoff | blocker | notice +status: pending # pending | read | in_progress | closed +created: 2026-05-31T14:30 +closed: # 处理完毕时填写,YYYY-MM-DDTHH:mm +related: # 关联消息 id 列表(如本报告对应的 task) + - msg-W17.3-000 +fixes: # 关联 findings-registry 的 finding id(如有) + - F-13.5-1 +--- + +邮件正文(markdown) +``` + +- 必填字段:`id` / `from` / `to` / `wave` / `kind` / `status` / `created` +- 选填字段:`closed` / `related` / `fixes` +- 字符串值不带引号(YAML 1.2 标准) + +## 7. 生命周期 + +``` +[投递] [归档] + │ │ + ▼ ▼ +inbox/<agent-id>/ pending ──→ read ──→ in_progress ──→ 处理完毕 archive/W<N>/ + status: closed +``` + +操作动作: + +| 动作 | 谁执行 | 文件操作 | status 字段 | +|------|--------|----------|--------------| +| 投递 | 任意 agent | 在 `inbox/<agent-id>/` 创建新文件 | pending | +| 阅读 | 接收者 | frontmatter 改 status | read | +| 开工 | 接收者 | frontmatter 改 status | in_progress | +| 完成 | 接收者 | 移动文件到 `archive/W<N>/` + 改 status | closed | +| 重投递 | 仅 CEO | 复制 archive 中的文件到 inbox(保留原件) | pending | + +接收者**不可物理删除**邮件——只允许移到 archive;如需删除由 CEO 手动操作。这是与 R-NO-FORCE-PUSH 同等级的强约束(新增 **R-MAIL-NO-DELETE**)。 + +## 8. 
与现有机制的边界 + +| 机制 | 职责 | 与 mailroom 的区别 | +|------|------|---------------------| +| `STATUS.md` | 实时编制状态快照 | 由 `scripts/refresh_status.py` 扫 profile.md 生成;mailroom 是消息流 | +| `agents//profile.md` performance_log | 永久任务履历 | 任务粒度,事后追加;mailroom 是消息粒度,进行中 | +| `agents/audits/findings-registry.md` | 审计发现追踪 | 跨 Wave 持续追踪;mailroom 单条消息生命周期短(一波内闭环) | +| `WORKFLOW.md` | 流程规范 | 是规则;mailroom 是规则产生的数据 | + +**重叠处理原则**:邮件正文 ≠ 履历正文。task 邮件中的 prompt 是 CEO 派活的原始材料,report 邮件中的回执是子代理的原始报告;最终摘要仍按现有流程写到 profile.md 的 performance_log 一条短记录。**邮件保留细节,profile.md 保留摘要,两者互补不重复**。 + +审计发现的归宿仍是 findings-registry。blocker 邮件如对应某个 finding,frontmatter 的 `fixes` 字段填写 finding id,CEO INSPECT 时按 §14.4 A3 检查关联。 + +## 9. CEO 派活前 5 步检查 + +扩展 [PROMPT_TEMPLATE.md](PROMPT_TEMPLATE.md) 的 CEO 派活前 4 步检查,新增第 5 步: + +| # | 检查项 | 操作 | 写入模板字段 | +|---|--------|------|--------------| +| 1 | 列 in-flight 工作区 | 查当前有哪些子代理在跑 | 禁忌 | +| 2 | 找前序 Wave 产出 | 从 performance_log 追溯 | 前序成果 | +| 3 | 设定任务范围三档 | 必做 / 可做 / 不做 | 任务范围 | +| 4 | 统一字数上限 | 固定 250 字 | 字数上限 | +| **5** | **扫候选执行者 inbox** | `ls agents/mailroom/inbox/<候选id>/` 看 pending 的 handoff / blocker | 禁忌 / 前序成果 | + +若候选执行者 inbox 有未结的 blocker(依赖未满足)→ 优先让其处理 blocker 再派新活,否则新活也会立即变 blocker。 + +## 10. 邮件示例 + +### 10.1 task 邮件 + +文件:`inbox/architect-lin/2026-05-31T1500-task-ceo-to-architect-lin.md` + +```markdown +--- +id: msg-W18.1-001 +from: ceo +to: architect-lin +wave: W18.1 +kind: task +status: pending +created: 2026-05-31T15:00 +--- + +# W18.1 任务: 评审 mailroom 设计 + +请读 agents/mailroom/README.md 并评估: + +- 是否值得正式纳入 WORKFLOW.md §15 +- 权限规则在 prompt 层面如何强约束 +- 与 findings-registry 的边界是否清晰 + +字数上限 250。完成后请发 report 邮件回 ceo。 +``` + +### 10.2 report 邮件 + +文件:`inbox/ceo/2026-05-31T1630-report-architect-lin-to-ceo.md` + +```markdown +--- +id: msg-W18.1-002 +from: architect-lin +to: ceo +wave: W18.1 +kind: report +status: pending +created: 2026-05-31T16:30 +related: + - msg-W18.1-001 +--- + +# W18.1 评审结论 + +赞成纳入 §15。建议补充三条边界规则: + +1. 接收者不可删除别人邮箱里的邮件 +2. CEO 重投递时必须复制而非移动,保留 archive 原件 +3. 
blocker 邮件必须抄送 CEO + +profile.md 已追加,rating: A-。 +``` + +### 10.3 handoff 邮件 + +文件:`inbox/qa-xu/2026-05-31T1700-handoff-engineer-sun-to-qa-xu.md` + +```markdown +--- +id: msg-W18.2-003 +from: engineer-sun +to: qa-xu +wave: W18.2 +kind: handoff +status: pending +created: 2026-05-31T17:00 +related: + - msg-W18.2-001 +--- + +# W18.2 LSP 死锁修复完成,请测 + +修复了 lsp_plugin.cpp:312 的死锁。需要你跑: + +- `ctest -R lsp_plugin_test` +- `ctest -R smoke` + +测试通过后请新建一封 report 邮件发给 CEO 并抄送我。 +``` + +## 11. 后续路线(v1.0 之外) + +格式稳定 + CEO 在 2-3 个 Wave 中实战使用后,再考虑: + +- 在 WORKFLOW.md §15 加一节正式纳入流程 +- 补充 `scripts/mailroom_summary.py` 自动汇总每个 agent 的未读邮件数与最老 pending 邮件年龄 +- 在 STATUS.md 增加 邮件待办数 列 +- 扩展 `scripts/check_agents_metadata.py` 校验邮件 frontmatter + +不在 v1.0 范围。先用 1-2 个 Wave 看实际效果再说。 + +## 12. 关联文档 + +- [WORKFLOW.md](../WORKFLOW.md) — 流程主文档(mailroom 将来可能并入 §15) +- [PROMPT_TEMPLATE.md](../PROMPT_TEMPLATE.md) — 派活模板(影响第 5 步检查) +- [STATUS.md](../STATUS.md) — 实时编制状态 +- [findings-registry.md](../audits/findings-registry.md) — 审计发现追踪 +- [POSTMORTEM.md](../POSTMORTEM.md) — 防御性规则(新增 R-MAIL-SCOPE / R-MAIL-NO-DELETE) + +## 13. 
变更历史 + +| 日期 | 版本 | 变更 | +|------|------|------| +| 2026-05-31 | 1.0 草案 | 初始化。CEO 与用户讨论后落地,方案 2(归档保留,不真删) | diff --git a/dstalk_cli/CMakeLists.txt b/dstalk_cli/CMakeLists.txt index f9afad1..c184964 100644 --- a/dstalk_cli/CMakeLists.txt +++ b/dstalk_cli/CMakeLists.txt @@ -15,3 +15,9 @@ find_package(Boost REQUIRED CONFIG) target_link_libraries(dstalk_cli PRIVATE dstalk boost::boost dstalk_boost_config ) + +# POSIX 平台需要 pthread (用于 std::thread spinner) +if(NOT WIN32) + find_package(Threads REQUIRED) + target_link_libraries(dstalk_cli PRIVATE Threads::Threads) +endif() diff --git a/dstalk_cli/src/main.cpp b/dstalk_cli/src/main.cpp index e0c181b..8d1844a 100644 --- a/dstalk_cli/src/main.cpp +++ b/dstalk_cli/src/main.cpp @@ -7,12 +7,15 @@ #include #include +#include #include #include #include #include +#include #include #include +#include #include #include @@ -64,6 +67,8 @@ static const dstalk_tools_service_t* g_tools = nullptr; static std::string g_current_model; static std::atomic g_quit_requested{false}; static std::atomic g_quit_via_signal{false}; +static std::atomic g_spinning{false}; +static std::thread g_spinner_thread; // ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ---- // Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT). @@ -90,6 +95,138 @@ static void on_signal(int /*sig*/) // ---- 工具函数 / Utility functions ---- +// ---- 进度指示器 (spinner) / Progress indicator (spinner) ---- +// 在等待 AI 响应时在 stderr 显示旋转字符,通过 atomic flag 控制启停。 +// Displays a rotating character on stderr while waiting for AI responses, controlled via atomic flag. 
+static void spinner_run() +{ + const char chars[] = "|/-\\"; + int i = 0; + while (g_spinning.load(std::memory_order_relaxed)) { + std::fprintf(stderr, "\r%c", chars[i % 4]); + std::fflush(stderr); + i++; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + // 光标归位(不擦除,由下一条 stdout 输出覆盖) / Return cursor (don't erase, let next stdout output overwrite) + std::fprintf(stderr, "\r"); + std::fflush(stderr); +} + +static void spinner_start() +{ + if (g_spinner_thread.joinable()) { + g_spinner_thread.join(); + } + g_spinning = true; + g_spinner_thread = std::thread(spinner_run); +} + +static void spinner_stop() +{ + g_spinning = false; if (g_spinner_thread.joinable()) g_spinner_thread.join(); // join here (waits at most one 100 ms tick) so the thread's final "\r" cannot land after tokens already printed to stdout +} + +static void spinner_join() +{ + if (g_spinner_thread.joinable()) { + g_spinner_thread.join(); + } +} + +// ---- 错误分类与友好提示 / Error classification and user-friendly messages ---- +// 根据 HTTP 状态码和错误消息字符串匹配,将常见错误归类为认证/频率限制/网络/配额问题,并给出中文建议。 +// Classifies common errors into auth/rate-limit/network/quota categories based on HTTP status and string matching, with Chinese suggestions. A null error_msg is reported as "unknown error"; all output goes to stderr. +static void print_error(const char* error_msg, int http_status) +{ + std::string msg(error_msg ? 
error_msg : "unknown error"); + + const char* category = nullptr; + const char* suggestion = nullptr; + + // 先按 HTTP 状态码分类(最可靠) / First classify by HTTP status code (most reliable) + switch (http_status) { + case 401: + case 403: + category = "认证失败"; + suggestion = "请检查 API key 是否正确(输入 /status 查看当前配置)"; + break; + case 429: + category = "请求频率限制"; + suggestion = "API 调用太频繁,请稍后重试或降低请求频率"; + break; + case 400: + category = "请求参数错误"; + suggestion = "请求格式不正确,可能是模型名或参数有误(输入 /status 查看)"; + break; + case 502: + case 503: + case 504: + category = "服务器错误"; + suggestion = "API 服务器暂时不可用,请稍后重试"; + break; + default: + break; + } + + // http_status 未覆盖 → 字符串模式匹配 / HTTP status not covered → string pattern matching + if (!category) { + if (msg.find("401") != std::string::npos || + msg.find("403") != std::string::npos || + msg.find("Unauthorized") != std::string::npos || + msg.find("Forbidden") != std::string::npos || + msg.find("authentication") != std::string::npos || + msg.find("invalid api key") != std::string::npos || + msg.find("Incorrect API key") != std::string::npos) { + category = "认证失败"; + suggestion = "请检查 API key 是否正确(输入 /status 查看当前配置)"; + } else if (msg.find("429") != std::string::npos || + msg.find("rate limit") != std::string::npos || msg.find("rate_limit") != std::string::npos || // not bare "rate": that also matches "generate", "moderate", "separate" + msg.find("Rate limit") != std::string::npos || + msg.find("too many requests") != std::string::npos) { + category = "请求频率限制"; + suggestion = "API 调用太频繁,请稍后重试或降低请求频率"; + } else if (msg.find("connection refused") != std::string::npos || + msg.find("Connection refused") != std::string::npos || + msg.find("connection reset") != std::string::npos || + msg.find("Connection reset") != std::string::npos || + msg.find("timed out") != std::string::npos || + msg.find("Timeout") != std::string::npos || + msg.find("network") != std::string::npos || + msg.find("Network") != std::string::npos || + msg.find("resolve") != 
std::string::npos || + msg.find("Name or service not known") != std::string::npos || + msg.find("Couldn't resolve") != 
std::string::npos || + msg.find("Failed to connect") != std::string::npos || + msg.find("Could not connect") != std::string::npos || + msg.find("could not connect") != std::string::npos || + msg.find("connect error") != std::string::npos || + msg.find("Connect error") != std::string::npos || + msg.find("connect failed") != std::string::npos || + msg.find("Connect failed") != std::string::npos) { + category = "网络错误"; + suggestion = "无法连接到 API 服务器,请检查网络连接和 base_url(输入 /status 查看)"; + } else if (msg.find("400") != std::string::npos || + msg.find("Bad Request") != std::string::npos) { + category = "请求参数错误"; + suggestion = "请求格式不正确,可能是模型名或参数有误(输入 /status 查看)"; + } else if (msg.find("insufficient") != std::string::npos || + msg.find("quota") != std::string::npos || + msg.find("billing") != std::string::npos) { + category = "配额不足"; + suggestion = "API 配额已用完或账户余额不足,请检查账户状态"; + } + } + + if (category && suggestion) { + std::fprintf(stderr, CLR_RED "[ERROR] %s\n" CLR_RESET, category); + std::fprintf(stderr, CLR_YELLOW " -> %s\n" CLR_RESET, suggestion); + std::fprintf(stderr, CLR_DIM " (原始错误: %s)\n" CLR_RESET, msg.c_str()); + } else { + std::fprintf(stderr, CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, msg.c_str()); + } +} + // 打印启动横幅 / Print the dstalk CLI banner with version, AI indicator, and quick command hints. static void print_banner() { @@ -391,12 +528,15 @@ static void handle_command(const char* line) } // ---- 流式回调 / Streaming callback ---- -// 流式输出回调:每收到一个 token 打印到 stdout 并刷新 / Callback invoked for each token during streaming chat; prints the token to stdout and flushes. +// 流式输出回调:每收到一个 token 打印到 stdout 并刷新。 +// 第一个 token 到达时停止 spinner 并用 \r 覆盖旋转字符。 +// Callback invoked for each token during streaming chat; stops spinner on first token and overwrites the spinner character with \r.
static int on_stream_token(const char* token, void* userdata) { bool* first = static_cast(userdata); if (*first) { - std::printf(CLR_GREEN); + spinner_stop(); + std::printf("\r" CLR_GREEN); *first = false; } std::printf("%s", token); @@ -404,6 +544,18 @@ static int on_stream_token(const char* token, void* userdata) return 0; } +// ---- 管道 / --prompt 共用:从 stdin 读入全部内容 / Read all stdin content (shared by pipe and --prompt modes) ---- +static std::string read_all_stdin() +{ + std::string result; + std::string line; + while (std::getline(std::cin, line)) { + if (!result.empty()) result += '\n'; + result += line; + } + return result; +} + // ---- 主程序 / Main entry point ---- // 入口:初始化 dstalk host,查询插件服务,处理 batch/pipe/交互模式。 // Entry point: initializes dstalk host, queries plugin services, handles batch/pipe/interactive modes. @@ -520,11 +672,7 @@ int main(int argc, char* argv[]) // ---- B3: 管道输入模式 (非交互) / Pipe input mode (non-interactive) ---- if (pipe_mode) { - std::string input; - char buf[4096]; - while (std::fgets(buf, sizeof(buf), stdin)) { - input += buf; - } + std::string input = read_all_stdin(); if (input.empty()) { std::fprintf(stderr, "empty prompt\n"); dstalk_shutdown(); @@ -544,8 +692,7 @@ int main(int argc, char* argv[]) dstalk_shutdown(); return EXIT_OK; } else { - std::fprintf(stderr, CLR_RED "[ERROR] AI error: %s\n" CLR_RESET, - result.error ? 
result.error : "unknown"); + print_error(result.error, result.http_status); g_ai->free_result(&result); dstalk_shutdown(); return EXIT_FATAL; @@ -557,10 +704,7 @@ int main(int argc, char* argv[]) std::string prompt_text; if (std::strcmp(prompt_arg, "-") == 0) { // --prompt - or --prompt (no arg): read prompt from stdin / --prompt - 或 --prompt(无参数):从 stdin 读取提示 - char buf[4096]; - while (std::fgets(buf, sizeof(buf), stdin)) { - prompt_text += buf; - } + prompt_text = read_all_stdin(); if (prompt_text.empty()) { std::fprintf(stderr, "empty prompt\n"); dstalk_shutdown(); @@ -588,15 +732,15 @@ int main(int argc, char* argv[]) dstalk_shutdown(); return EXIT_OK; } else { - std::fprintf(stderr, CLR_RED "[ERROR] AI error: %s\n" CLR_RESET, - result.error ? result.error : "unknown"); + print_error(result.error, result.http_status); g_ai->free_result(&result); dstalk_shutdown(); return EXIT_FATAL; } } - char buffer[8192]; + // ---- 交互模式主循环 / Interactive mode main loop ---- + std::string line; while (true) { // B1: 检查退出标志 / Check quit flag if (g_quit_requested) { @@ -611,26 +755,17 @@ int main(int argc, char* argv[]) std::fflush(stdout); } - if (!std::fgets(buffer, sizeof(buffer), stdin)) break; + // 动态读取一行,无大小限制 / Read one line dynamically, no size limit + if (!std::getline(std::cin, line)) break; - // C3: fgets 截断检测 / fgets truncation detection - if (!std::strchr(buffer, '\n') && !feof(stdin)) { - std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB,已截断。建议用文件方式:dstalk --batch < file.txt\n" CLR_RESET); - int c; - while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {} - } + // 去除末尾的 \r(Windows) / Strip trailing \r (Windows) + if (!line.empty() && line.back() == '\r') line.pop_back(); - // 去除末尾换行 / Strip trailing newline - size_t len = std::strlen(buffer); - while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) { - buffer[--len] = '\0'; - } - - if (len == 0) continue; + if (line.empty()) continue; // 命令处理 / Command dispatch - if (buffer[0] == '/') { - 
handle_command(buffer); + if (line[0] == '/') { + handle_command(line.c_str()); continue; } @@ -644,14 +779,19 @@ int main(int argc, char* argv[]) int history_count = 0; const dstalk_message_t* history = g_session->history(&history_count); + // 启动 spinner,等待 AI 响应 / Start spinner while waiting for AI response + spinner_start(); bool first = true; dstalk_chat_result_t result = g_ai->chat_stream( - history, history_count, buffer, on_stream_token, &first); + history, history_count, line.c_str(), on_stream_token, &first); + + // 确保 spinner 已停止(处理无流式输出的情况) / Ensure spinner is stopped (handles no-stream-output case) + spinner_stop(); if (result.ok) { std::printf(CLR_RESET "\n\n"); // 将用户消息和 AI 回复添加到会话 / Add user message and AI reply to session - dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr}; + dstalk_message_t user_msg = {"user", line.c_str(), nullptr, nullptr}; g_session->add(&user_msg); dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json}; g_session->add(&ai_msg); @@ -727,8 +867,10 @@ int main(int argc, char* argv[]) history = g_session->history(&history_count); g_ai->free_result(&result); + spinner_start(); bool tool_stream_first = true; result = g_ai->chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first); + spinner_stop(); if (result.ok) { std::printf(CLR_RESET "\n"); @@ -741,8 +883,7 @@ int main(int argc, char* argv[]) g_session->add(&ai_followup); has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0'); } else { - std::printf(CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, - result.error ? 
result.error : "unknown error"); + print_error(result.error, result.http_status); break; } } @@ -751,14 +892,17 @@ int main(int argc, char* argv[]) std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS); } } else { - // A3: error 路径下需 NULL 保护;当前只取 result.error,content 未涉及 / Error path needs NULL guard; currently only reads result.error, content not involved - std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, - result.error ? result.error : "unknown error"); + // AI 调用失败:reset 颜色,输出分类错误信息 / AI call failed: reset color, output classified error info + std::printf(CLR_RESET "\n"); + print_error(result.error, result.http_status); } g_ai->free_result(&result); } // B2: 单一退出点,dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode) + // 确保 spinner 线程已结束——先发信号停止,再 join 等待线程真正退出 / Ensure spinner thread has ended: signal stop first, then join to wait for thread exit + spinner_stop(); + spinner_join(); dstalk_shutdown(); return g_quit_via_signal ? 
EXIT_INTERRUPT : EXIT_OK; } diff --git a/dstalk_frontend_common/CMakeLists.txt b/dstalk_frontend_common/CMakeLists.txt new file mode 100644 index 0000000..694253c --- /dev/null +++ b/dstalk_frontend_common/CMakeLists.txt @@ -0,0 +1,17 @@ +# ============================================================ +# dstalk_frontend_common — 前端公共初始化静态库 +# ============================================================ + +add_library(dstalk_frontend_common STATIC + src/frontend_common.cpp +) + +target_include_directories(dstalk_frontend_common + PUBLIC include +) + +target_link_libraries(dstalk_frontend_common + PUBLIC dstalk +) + +target_compile_features(dstalk_frontend_common PUBLIC cxx_std_20) diff --git a/dstalk_frontend_common/include/dstalk_frontend_common.hpp b/dstalk_frontend_common/include/dstalk_frontend_common.hpp new file mode 100644 index 0000000..703af21 --- /dev/null +++ b/dstalk_frontend_common/include/dstalk_frontend_common.hpp @@ -0,0 +1,65 @@ +// ============================================================================ +// dstalk_frontend_common — 前端公共初始化模块 +// ============================================================================ +// 提供所有前端(CLI / GUI / Web)共享的启动逻辑: +// - 配置文件发现(argv / 默认路径 / 平台 fopen) +// - dstalk_init() 调用 +// - 常用服务查询(ai, session, file_io, tools, context) +// - AI 服务默认配置(从 config 读取,带 fallback) +// ============================================================================ + +#ifndef DSTALK_FRONTEND_COMMON_HPP +#define DSTALK_FRONTEND_COMMON_HPP + +#include + +#include "dstalk/dstalk_host.h" + +struct FrontendServices { + const dstalk_ai_service_t* ai = nullptr; + const dstalk_session_service_t* session = nullptr; + const dstalk_file_io_service_t* file_io = nullptr; + const dstalk_tools_service_t* tools = nullptr; + + std::string provider; // "ai.deepseek" / "ai.openai" / "ai.anthropic" + std::string model; // e.g. "deepseek-v4-pro" + std::string base_url; // e.g. 
"https://api.deepseek.com/v1" + std::string api_key; + + // 是否已成功初始化 dstalk 核心 + bool initialized = false; +}; + +// ---- 前端公共初始化 ---- +// +// 功能: +// 1. 发现配置文件:优先 argv[1](跳过已知标志),其次 default_config(如 "config.toml") +// 2. 调用 dstalk_init(config_path) +// 3. 查询常用插件服务(ai / session / file_io / tools) +// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置默认值 +// +// 参数: +// svc - [out] 填入查询到的服务指针和配置信息 +// argc/argv - 命令行参数(可为 0/nullptr,例如 GUI 没有命令行参数) +// default_cfg- 默认配置文件名(如 "config.toml"),当 argv 未提供时使用 +// skip_flags - 以 NULL 结尾的字符串数组,argv 扫描时跳过这些标志及其下一个参数 +// +// 返回值: +// 0 - 成功,svc.initialized == true,至少 ai + session 已就绪 +// 1 - dstalk_init 失败 +// 2 - AI 服务未找到 +// 3 - Session 服务未找到 +// +int dstalk_frontend_init(FrontendServices& svc, + int argc = 0, char* argv[] = nullptr, + const char* default_cfg = "config.toml", + const char* const* skip_flags = nullptr); + +// ---- 便捷辅助 ---- + +// 将 dstalk_message_t 数组的内容追加到 session 服务(一次一条)。 +// 常用于 Ctrl+O 加载会话后重建前端消息列表。 +// 返回实际追加的条数。 +int dstalk_frontend_replay_history(FrontendServices& svc); + +#endif // DSTALK_FRONTEND_COMMON_HPP diff --git a/dstalk_frontend_common/src/frontend_common.cpp b/dstalk_frontend_common/src/frontend_common.cpp new file mode 100644 index 0000000..1bb3c09 --- /dev/null +++ b/dstalk_frontend_common/src/frontend_common.cpp @@ -0,0 +1,124 @@ +// ============================================================================ +// dstalk_frontend_common — 实现 +// ============================================================================ + +#include "dstalk_frontend_common.hpp" + +#include +#include +#include + +// ---- 配置文件发现 ---- + +static const char* discover_config(int argc, char* argv[], + const char* default_cfg, + const char* const* skip_flags) +{ + // 1) 从 argv 中查找首个非标志参数 + if (argc >= 2 && argv) { + for (int i = 1; i < argc; ++i) { + bool is_skip = false; + if (skip_flags) { + for (int k = 0; skip_flags[k]; ++k) { + if (std::strcmp(argv[i], skip_flags[k]) == 0) { + is_skip = true; + // 
若该标志有值参数则多跳一个 + if (i + 1 < argc && argv[i + 1][0] != '-') ++i; + break; + } + } + } + if (!is_skip) return argv[i]; + } + } + + // 2) 回退:尝试默认配置文件 + if (!default_cfg || default_cfg[0] == '\0') return nullptr; + + const char* candidates[] = { default_cfg, nullptr }; + for (int i = 0; candidates[i]; ++i) { + FILE* f = nullptr; +#ifdef _WIN32 + if (fopen_s(&f, candidates[i], "r") == 0 && f) { +#else + f = std::fopen(candidates[i], "r"); + if (f) { +#endif + std::fclose(f); + return candidates[i]; + } + } + return nullptr; +} + +// ---- 主初始化 ---- + +int dstalk_frontend_init(FrontendServices& svc, + int argc, char* argv[], + const char* default_cfg, + const char* const* skip_flags) +{ + // (1) 发现并加载配置 + const char* cfg = discover_config(argc, argv, default_cfg, skip_flags); + if (dstalk_init(cfg) != 0) { + std::fprintf(stderr, "[dstalk] 初始化失败\n"); + return 1; + } + + // (2) 查询 AI 服务 + const char* provider = dstalk_config_get("ai.provider"); + if (!provider || provider[0] == '\0') provider = "ai.deepseek"; + svc.provider = provider; + + svc.ai = static_cast( + dstalk_service_query(provider, 1)); + svc.session = static_cast( + dstalk_service_query("session", 1)); + svc.file_io = static_cast( + dstalk_service_query("file_io", 1)); + svc.tools = static_cast( + dstalk_service_query("tools", 1)); + + const dstalk_context_service_t* ctx_svc = + static_cast( + dstalk_service_query("context", 1)); + (void)ctx_svc; // 不强制使用,保留以备将来使用 + + if (!svc.ai) { + std::fprintf(stderr, "[dstalk] AI 服务未找到(请检查插件目录)\n"); + return 2; + } + if (!svc.session) { + std::fprintf(stderr, "[dstalk] Session 服务未找到\n"); + return 3; + } + + // (3) 配置 AI 服务的默认值 + const char* base_url = dstalk_config_get("api.base_url"); + const char* api_key = dstalk_config_get("api.api_key"); + const char* model = dstalk_config_get("api.model"); + + if (!base_url || base_url[0] == '\0') base_url = "https://api.deepseek.com/v1"; + if (!model || model[0] == '\0') model = "deepseek-v4-pro"; + + svc.base_url = base_url; + 
svc.api_key = api_key ? api_key : ""; + svc.model = model; + + svc.ai->configure(provider, base_url, + api_key ? api_key : "", + model, 4096, 0.7); + + svc.initialized = true; + return 0; +} + +// ---- 会话历史回放 ---- + +int dstalk_frontend_replay_history(FrontendServices& svc) +{ + if (!svc.session) return 0; + int count = 0; + const dstalk_message_t* msgs = svc.session->history(&count); + return count; // 调用方自行遍历并重建前端消息列表 +} diff --git a/dstalk_gui/CMakeLists.txt b/dstalk_gui/CMakeLists.txt index 013904e..3e3f49d 100644 --- a/dstalk_gui/CMakeLists.txt +++ b/dstalk_gui/CMakeLists.txt @@ -16,5 +16,6 @@ set_target_properties(dstalk_gui PROPERTIES target_link_libraries(dstalk_gui PRIVATE dstalk + dstalk_frontend_common SDL3::SDL3 ) diff --git a/dstalk_gui/src/main.cpp b/dstalk_gui/src/main.cpp index 27ed57c..dd2df1e 100644 --- a/dstalk_gui/src/main.cpp +++ b/dstalk_gui/src/main.cpp @@ -744,8 +744,23 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) { break; case SDLK_O: if (ctrl) { - // Ctrl+O:加载会话 / Ctrl+O: load session + // Ctrl+O:加载会话 if (g_session_svc && g_session_svc->load("session.json") == 0) { + // BUGFIX: 从 session 历史重建 GUI 消息列表 + int hcount = 0; + const dstalk_message_t* history = g_session_svc->history(&hcount); + gs.messages.clear(); + for (int i = 0; i < hcount; ++i) { + ChatMessage::Role r; + if (std::strcmp(history[i].role, "user") == 0) + r = ChatMessage::USER; + else if (std::strcmp(history[i].role, "assistant") == 0) + r = ChatMessage::ASSISTANT; + else + r = ChatMessage::SYSTEM; + gs.messages.push_back(ChatMessage(r, + history[i].content ? 
history[i].content : "")); + } gs.messages.push_back(ChatMessage( ChatMessage::SYSTEM, "Session loaded from session.json")); } else { @@ -895,8 +910,30 @@ int main(int argc, char* argv[]) { g_ai_svc->free_result(&result); } - // 流式传输完成(或被取消) / Streaming completed (or cancelled) - if (rc != 0) { + // 流式传输完成(或被取消) + if (rc == 0) { + // BUGFIX: 将用户消息和 AI 回复持久化到 session 服务 + if (g_session_svc) { + const std::string& userContent = + ctx.state.messages[ctx.state.messages.size() - 2].content; + const std::string& aiContent = + ctx.state.messages.back().content; + + dstalk_message_t user_msg = { + "user", + userContent.c_str(), + nullptr, nullptr + }; + g_session_svc->add(&user_msg); + + dstalk_message_t ai_msg = { + "assistant", + aiContent.c_str(), + nullptr, nullptr + }; + g_session_svc->add(&ai_msg); + } + } else { if (!ctx.state.messages.empty() && ctx.state.messages.back().role == ChatMessage::ASSISTANT) { if (ctx.state.messages.back().content.empty()) { diff --git a/plugins_base/lsp/src/lsp_internal.hpp b/plugins_base/lsp/src/lsp_internal.hpp new file mode 100644 index 0000000..d7b49c8 --- /dev/null +++ b/plugins_base/lsp/src/lsp_internal.hpp @@ -0,0 +1,22 @@ +// ============================================================================ +// lsp_internal.hpp — 内部声明:供单元测试访问的 LSP 工具函数 +// ============================================================================ +// 仅在 tests 中使用;非 plugin 公共 API +// ============================================================================ + +#ifndef LSP_INTERNAL_HPP +#define LSP_INTERNAL_HPP + +#include +#include + +// ---- 字符串 trim ---- +std::string_view lsp_trim(std::string_view sv); + +// ---- 构建 LSP frame (Content-Length header + body) ---- +std::string lsp_frame_message(const std::string& body); + +// ---- 解析 Content-Length header ---- +int lsp_parse_content_length(const std::string& line); + +#endif // LSP_INTERNAL_HPP diff --git a/plugins_base/lsp/src/lsp_plugin.cpp b/plugins_base/lsp/src/lsp_plugin.cpp index 
46851a8..e70bf2f 100644 --- a/plugins_base/lsp/src/lsp_plugin.cpp +++ b/plugins_base/lsp/src/lsp_plugin.cpp @@ -12,6 +12,7 @@ #include "dstalk/dstalk_host.h" #include "dstalk/dstalk_services.h" +#include "lsp_internal.hpp" #include #include @@ -311,7 +312,7 @@ static LspState g_lsp; // ============================================================================ // 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view. -static std::string_view trim(std::string_view sv) { +std::string_view lsp_trim(std::string_view sv) { while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' || sv.front() == '\r' || sv.front() == '\n')) sv.remove_prefix(1); @@ -322,7 +323,7 @@ static std::string_view trim(std::string_view sv) { } // 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n). -static std::string frame_message(const std::string& body) { +std::string lsp_frame_message(const std::string& body) { std::string frame; frame.reserve(64 + body.size()); frame += "Content-Length: "; @@ -333,8 +334,8 @@ static std::string frame_message(const std::string& body) { } // 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure. 
-static int parse_content_length(const std::string& line) { - auto sv = trim(std::string_view(line)); +int lsp_parse_content_length(const std::string& line) { + auto sv = lsp_trim(std::string_view(line)); const char prefix[] = "Content-Length:"; const size_t prefix_len = sizeof(prefix) - 1; @@ -368,7 +369,7 @@ static int send_request(const std::string& method, const json::object& params) { msg["params"] = params; std::string body = json::serialize(msg); - g_lsp.proc.write(frame_message(body)); + g_lsp.proc.write(lsp_frame_message(body)); return id; } @@ -380,7 +381,7 @@ static void send_notification(const std::string& method, const json::object& par msg["params"] = params; std::string body = json::serialize(msg); - g_lsp.proc.write(frame_message(body)); + g_lsp.proc.write(lsp_frame_message(body)); } // ============================================================================ @@ -457,11 +458,11 @@ static void reader_loop() { } // header 块以空行结束 / header block ends with empty line - auto sv = trim(std::string_view(line)); + auto sv = lsp_trim(std::string_view(line)); if (sv.empty()) break; // 累积 Content-Length;遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line - int len = parse_content_length(line); + int len = lsp_parse_content_length(line); if (len >= 0) content_length = len; } diff --git a/plugins_middle/network/src/network_plugin.cpp b/plugins_middle/network/src/network_plugin.cpp index 631fec7..9a989fe 100644 --- a/plugins_middle/network/src/network_plugin.cpp +++ b/plugins_middle/network/src/network_plugin.cpp @@ -35,6 +35,54 @@ namespace asio = boost::asio; namespace ssl = boost::asio::ssl; using tcp = asio::ip::tcp; +// ============================================================ +// 安全常量和输入验证辅助函数 / Security constants and input-validation helpers +// ============================================================ +static constexpr size_t MAX_HEADER_KEY_LENGTH = 256; +static constexpr size_t 
MAX_HEADER_VALUE_LENGTH = 8192; + +/// 如果字符串包含任何控制字符(< 0x20 或 0x7F DEL),返回 true / Return true if the string contains any control character (< 0x20 or 0x7F DEL). +static bool contains_control_chars(const char* s) { + if (!s) return false; + for (const char* p = s; *p; ++p) { + unsigned char c = static_cast(*p); + if (c < 0x20u || c == 0x7Fu) return true; + } + return false; +} + +/// 基本端口/服务名验证。拒绝空值,对数字端口进行范围检查, +/// 对符号服务名要求字母数字加连字符(RFC 6335)/ Basic port / service-name validation. +/// Rejects empty, bounds-checks numeric ports, and requires alphanumeric+hyphen +/// for symbolic service names (RFC 6335). +static bool is_valid_port(const char* port) { + if (!port || !*port) return false; + bool all_digits = true; + for (const char* p = port; *p; ++p) { + if (static_cast(*p) < '0' || + static_cast(*p) > '9') { + all_digits = false; + break; + } + } + if (all_digits) { + if (std::strlen(port) > 5) return false; // > 65535 + long p = std::atol(port); + return p > 0 && p <= 65535; + } + // 服务名:字母数字加连字符(RFC 6335);最长15个字符 / Service name: alphanumeric plus hyphen (RFC 6335); max 15 chars + for (const char* p = port; *p; ++p) { + unsigned char c = static_cast(*p); + if (!((c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + c == '-')) { + return false; + } + } + return std::strlen(port) <= 15; +} + // ============================================================ // 全局状态 / Global state // ============================================================ @@ -42,8 +90,11 @@ static const dstalk_host_api_t* g_host = nullptr; static dstalk_config_service_t* g_config_svc = nullptr; // ============================================================ -// 极简 JSON 头解析器 / Minimal JSON header parser +// 极简 JSON 头解析器(含安全长度限制)/ Minimal JSON header parser (with security length limits) // 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map +// 强制 MAX_HEADER_KEY_LENGTH (256) 和 MAX_HEADER_VALUE_LENGTH (8192) 
限制, +// 防止恶意输入导致资源耗尽 / Enforces MAX_HEADER_KEY_LENGTH (256) and MAX_HEADER_VALUE_LENGTH +// (8192) to prevent resource exhaustion from malicious input. // ============================================================ // 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map. static std::unordered_map parse_headers_json(const char* json) { @@ -55,31 +106,53 @@ static std::unordered_map parse_headers_json(const cha enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE; std::string current_key; std::string current_value; + bool key_too_long = false; for (size_t i = 0; i < s.size(); ++i) { char c = s[i]; switch (state) { case OUTSIDE: - if (c == '"') { state = IN_KEY; current_key.clear(); } + if (c == '"') { state = IN_KEY; current_key.clear(); key_too_long = false; } break; case IN_KEY: if (c == '"') { state = AFTER_KEY; } - else if (c == '\\' && i + 1 < s.size()) { current_key += s[++i]; } - else { current_key += c; } + else if (c == '\\' && i + 1 < s.size()) { + if (current_key.size() < MAX_HEADER_KEY_LENGTH) + current_key += s[++i]; + else { ++i; key_too_long = true; } + } + else { + if (current_key.size() < MAX_HEADER_KEY_LENGTH) + current_key += c; + else + key_too_long = true; + } break; case AFTER_KEY: if (c == ':') { state = IN_VALUE; current_value.clear(); } + // 跳过键和冒号之间的多余字符(例如 ',' 或 '}')/ Skip stray characters between key and colon (e.g. 
',' or '}') + else if (c == '"' || c == ',' || c == '}') { /* 保持 AFTER_KEY 状态,忽略 / stay in AFTER_KEY, ignore */ } break; case IN_VALUE: if (c == '"') { // 读取到闭合引号 / Read until closing quote ++i; while (i < s.size() && s[i] != '"') { - if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; } - else { current_value += s[i]; } + if (s[i] == '\\' && i + 1 < s.size()) { + if (current_value.size() < MAX_HEADER_VALUE_LENGTH) + current_value += s[++i]; + else + ++i; // 跳过转义字符,值已截断 / skip escaped char, value truncated + } + else { + if (current_value.size() < MAX_HEADER_VALUE_LENGTH) + current_value += s[i]; + } ++i; } - headers[current_key] = current_value; + if (!key_too_long) { + headers[current_key] = current_value; + } state = OUTSIDE; } break; @@ -93,19 +166,76 @@ static std::unordered_map parse_headers_json(const cha // ============================================================ struct HttpClientCtx { asio::io_context ioc; - ssl::context ssl_ctx{ssl::context::tlsv12_client}; + ssl::context ssl_ctx{ssl::context::tls_client}; int connect_timeout = 30; int request_timeout = 120; HttpClientCtx() { - ssl_ctx.set_default_verify_paths(); - // 启用对等证书验证 (CVSS 7.4 修复) / Enable peer certificate verification (CVSS 7.4 fix). - // set_default_verify_paths() 加载系统 CA 包;没有 verify_peer - // CA 存储不会被查询——任何证书(自签名/过期)都将被接受 / set_default_verify_paths() loads system CA bundle; without verify_peer - // the CA store is never consulted — any cert (self-signed/expired) is accepted. - // TODO: Windows: set_default_verify_paths() 可能无法定位系统 CA; - // 如果验证失败,设置 SSL_CERT_FILE 环境变量或捆绑 cacert.pem / Windows: set_default_verify_paths() may not locate system CAs; - // if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem. + // TLS 1.2+ 协商(tls_client 允许 TLS 1.2 和 1.3)/ TLS 1.2+ negotiation (tls_client allows TLS 1.2 and 1.3). 
+ // 启用针对系统 CA 存储的对等证书验证。在 Windows 上 + // set_default_verify_paths() 可能无法定位系统 CA; + // 检测到这种情况时尝试回退源 / Enable peer certificate verification against system CA store. + // On Windows set_default_verify_paths() may not locate system CAs; + // we detect that case and try fallback sources. + + boost::system::error_code ec; + ssl_ctx.set_default_verify_paths(ec); + if (ec) { + // 主路径失败——按顺序尝试回退源 / Primary path failed — try fallback sources in order + bool loaded = false; + + // 回退 1:SSL_CERT_FILE / SSL_CERT_DIR(OpenSSL 内部已查询, + // 但显式 load_verify_file 提供明确的错误码用于报告)/ Fallback 1: SSL_CERT_FILE / SSL_CERT_DIR (already consulted by + // OpenSSL internally, but an explicit load_verify_file gives us + // a clear error code to report). + const char* cert_file = std::getenv("SSL_CERT_FILE"); + if (cert_file && *cert_file) { + ssl_ctx.load_verify_file(cert_file, ec); + if (!ec) loaded = true; + } + if (!loaded) { + const char* cert_dir = std::getenv("SSL_CERT_DIR"); + if (cert_dir && *cert_dir) { + ssl_ctx.add_verify_path(cert_dir, ec); + if (!ec) loaded = true; + } + } + + // 回退 2:http.ca_cert_file 配置项 / Fallback 2: http.ca_cert_file config key + if (!loaded && g_config_svc) { + const char* cfg_cert = g_config_svc->get("http.ca_cert_file"); + if (cfg_cert && *cfg_cert) { + ssl_ctx.load_verify_file(cfg_cert, ec); + if (!ec) loaded = true; + } + } + + // 回退 3:捆绑的 cacert.pem,相对于常见安装路径 / Fallback 3: bundled cacert.pem relative to common install paths + if (!loaded) { + static const char* kBundlePaths[] = { + "cacert.pem", + "share/cacert.pem", + "../share/cacert.pem", + "certs/cacert.pem", + nullptr + }; + for (int pi = 0; kBundlePaths[pi]; ++pi) { + ssl_ctx.load_verify_file(kBundlePaths[pi], ec); + if (!ec) { loaded = true; break; } + } + } + + if (!loaded) { + if (g_host) g_host->log(DSTALK_LOG_WARN, + "TLS CA certificates not found. " + "set_default_verify_paths() failed: %s. " + "Set SSL_CERT_FILE=/path/to/cacert.pem or " + "http.ca_cert_file in config. 
" + "TLS verification will proceed but may fail at handshake.", + ec.message().c_str()); + } + } + ssl_ctx.set_verify_mode(ssl::verify_peer); } }; @@ -132,6 +262,33 @@ static int do_post_stream( return -1; } + // ---- 输入验证(安全加固)/ Input validation (security hardening) ---- + + // 拒绝 host 和 target 中的控制字符(CRLF 注入防护)/ Reject control characters in host and target (CRLF injection prevention) + if (contains_control_chars(host)) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, + "do_post_stream: host contains control characters"); + *response_body = nullptr; + *status_code = -1; + return -1; + } + if (contains_control_chars(target)) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, + "do_post_stream: target contains control characters"); + *response_body = nullptr; + *status_code = -1; + return -1; + } + + // 验证端口:必须是数字 1-65535 或有效的服务名 / Validate port: must be numeric 1-65535 or a valid service name + if (!is_valid_port(port)) { + if (g_host) g_host->log(DSTALK_LOG_ERROR, + "do_post_stream: invalid port '%s'", port); + *response_body = nullptr; + *status_code = -1; + return -1; + } + // 初始化输出 / Initialize output *response_body = nullptr; *status_code = -1; diff --git a/plugins_upper/anthropic/src/anthropic_internal.hpp b/plugins_upper/anthropic/src/anthropic_internal.hpp new file mode 100644 index 0000000..4bdae03 --- /dev/null +++ b/plugins_upper/anthropic/src/anthropic_internal.hpp @@ -0,0 +1,46 @@ +// ============================================================================ +// anthropic_internal.hpp — 内部声明:供单元测试访问的函数与数据结构 +// ============================================================================ +// 仅在 tests 中使用;非 plugin 公共 API +// ============================================================================ + +#ifndef ANTHROPIC_INTERNAL_HPP +#define ANTHROPIC_INTERNAL_HPP + +#include "dstalk/dstalk_host.h" +#include "dstalk/dstalk_services.h" + +#include "ai_common.hpp" + +#include +#include + +// ---- 从 dstalk_ai 命名空间导入共享类型与函数 ---- +using dstalk_ai::PluginConfig; 
+using dstalk_ai::ToolCallAccum; +using dstalk_ai::StreamContext; +using dstalk_ai::secure_zero; +using dstalk_ai::extract_host_port; + +// ---- 全局变量 ---- +extern PluginConfig g_cfg; +extern std::string g_tools_json; + +// ---- 测试用函数声明 ---- +bool parse_sse_data(const std::string& data, std::string& token_out, + StreamContext* ctx); + +std::string build_request_json(const dstalk_message_t* history, int history_len, + const std::string& user_input, + const std::string& tools_json, + bool stream); + +std::string build_headers_json(); + +void my_free_result(dstalk_chat_result_t* result); + +int my_configure(const char* provider, const char* base_url, + const char* api_key, const char* model, + int max_tokens, double temperature); + +#endif // ANTHROPIC_INTERNAL_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0dfc0d5..fbe0fa3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -212,6 +212,40 @@ target_link_libraries(dstalk_network_plugin_test add_test(NAME dstalk_network_plugin_test COMMAND dstalk_network_plugin_test) +# ============================================================ +# dstalk_lsp_plugin_test — LSP 插件单元测试 (轻量 CHECK 宏,无 GoogleTest;新增) +# 覆盖: lsp_trim / lsp_frame_message / lsp_parse_content_length +# ============================================================ + +add_executable(dstalk_lsp_plugin_test + lsp_plugin_test.cpp + ${CMAKE_SOURCE_DIR}/plugins_base/lsp/src/lsp_plugin.cpp +) + +target_include_directories(dstalk_lsp_plugin_test + PRIVATE + ${CMAKE_SOURCE_DIR}/dstalk_core/include + ${CMAKE_SOURCE_DIR}/plugins_base/lsp/src +) + +target_compile_definitions(dstalk_lsp_plugin_test + PRIVATE + BOOST_JSON_HEADER_ONLY + BOOST_ALL_NO_LIB +) + +target_compile_features(dstalk_lsp_plugin_test + PRIVATE cxx_std_17 +) + +target_link_libraries(dstalk_lsp_plugin_test + PRIVATE + dstalk + boost::boost +) + +add_test(NAME dstalk_lsp_plugin_test COMMAND dstalk_lsp_plugin_test) + # ============================================================ # coverage
— gcovr 覆盖率报告 (HTML + 终端摘要) # 用法: cmake --build --target coverage diff --git a/tests/lsp_plugin_test.cpp b/tests/lsp_plugin_test.cpp new file mode 100644 index 0000000..b7f4572 --- /dev/null +++ b/tests/lsp_plugin_test.cpp @@ -0,0 +1,153 @@ +// ============================================================================ +// lsp_plugin_test.cpp — LSP 插件单元测试 (轻量 CHECK 宏,离线环境无 GoogleTest) +// ============================================================================ +// 测试 LSP 插件的可独立验证功能: +// - lsp_trim: 字符串 trim 逻辑 +// - lsp_frame_message: Content-Length header 构建 +// - lsp_parse_content_length: Content-Length header 解析 +// ============================================================================ + +#include "lsp_internal.hpp" + +#include "dstalk/dstalk_host.h" + +#include +#include +#include +#include + +static int g_failures = 0; + +// Lightweight assertion macros (matches project pattern used by other tests) +#define CHECK(cond, msg) do { \ + if (!(cond)) { \ + std::cerr << "[FAIL] " << (msg) << "\n"; \ + ++g_failures; \ + } \ +} while (0) + +#define CHECK_EQ(actual, expected, msg) do { \ + auto _a = (actual); \ + auto _e = (expected); \ + if (!(_a == _e)) { \ + std::cerr << "[FAIL] " << (msg) \ + << " (got=" << _a << " expected=" << _e << ")\n"; \ + ++g_failures; \ + } \ +} while (0) + +// ---------------------------------------------------------------------------- +// lsp_trim +// ---------------------------------------------------------------------------- +static void test_lsp_trim() { + CHECK_EQ(lsp_trim(""), std::string(""), "trim empty string"); + CHECK_EQ(lsp_trim("hello"), std::string("hello"), "trim no whitespace"); + CHECK_EQ(lsp_trim(" hello"), std::string("hello"), "trim leading spaces"); + CHECK_EQ(lsp_trim("hello "), std::string("hello"), "trim trailing spaces"); + CHECK_EQ(lsp_trim(" hello "), std::string("hello"), "trim both sides"); + CHECK_EQ(lsp_trim("\t\n\rhello\t\n\r"), std::string("hello"), "trim tabs/newlines"); + CHECK_EQ(lsp_trim(" 
\t\n\r "), std::string(""), "trim only whitespace"); + CHECK_EQ(lsp_trim("A"), std::string("A"), "trim single char"); + CHECK_EQ(lsp_trim(" hello world "), std::string("hello world"), + "trim preserves internal whitespace"); +} + +// ---------------------------------------------------------------------------- +// lsp_frame_message +// ---------------------------------------------------------------------------- +static void test_lsp_frame_message() { + { + std::string frame = lsp_frame_message(""); + CHECK(frame.find("Content-Length: 0") != std::string::npos, + "frame empty body has Content-Length: 0"); + CHECK(frame.find("\r\n\r\n") != std::string::npos, + "frame empty body has header separator"); + CHECK(frame.find("\r\n\r\n") + 4 == frame.size(), + "frame empty body ends right after separator"); + } + { + std::string body = "{\"jsonrpc\":\"2.0\"}"; + std::string frame = lsp_frame_message(body); + CHECK(frame.find("Content-Length: " + std::to_string(body.size())) != std::string::npos, + "frame simple body Content-Length matches"); + CHECK(frame.find(body) != std::string::npos, "frame simple body present"); + } + { + std::string body = "line1\nline2\nline3"; + std::string frame = lsp_frame_message(body); + CHECK(frame.find("Content-Length: " + std::to_string(body.size())) != std::string::npos, + "frame multiline body Content-Length matches"); + CHECK(frame.find(body) != std::string::npos, "frame multiline body present"); + } + { + std::string body(std::string("\x00\x01\x02\x03\xFF", 5)); + std::string frame = lsp_frame_message(body); + CHECK(frame.find("Content-Length: 5") != std::string::npos, + "frame binary body Content-Length: 5"); + } +} + +// ---------------------------------------------------------------------------- +// lsp_parse_content_length +// ---------------------------------------------------------------------------- +static void test_lsp_parse_content_length() { + CHECK_EQ(lsp_parse_content_length("Content-Length: 1234"), 1234, "parse valid 
header"); + CHECK_EQ(lsp_parse_content_length(" Content-Length: 42"), 42, "parse leading spaces"); + CHECK_EQ(lsp_parse_content_length("content-length: 99"), 99, "parse lowercase"); + CHECK_EQ(lsp_parse_content_length("CONTENT-LENGTH: 77"), 77, "parse uppercase"); + CHECK_EQ(lsp_parse_content_length("Content-Length: 0"), 0, "parse zero length"); + CHECK_EQ(lsp_parse_content_length("Content-Length: 1048576"), 1048576, "parse large value"); + CHECK_EQ(lsp_parse_content_length(""), -1, "parse empty string"); + CHECK_EQ(lsp_parse_content_length("Content-Lengthh: 10"), -1, "parse misspelled (extra h)"); + CHECK_EQ(lsp_parse_content_length("ContentLength: 10"), -1, "parse misspelled (no hyphen)"); + CHECK_EQ(lsp_parse_content_length("Content-Length 10"), -1, "parse missing colon"); + CHECK_EQ(lsp_parse_content_length("Content-Length: abc"), -1, "parse non-numeric"); + CHECK_EQ(lsp_parse_content_length("Content-Length: -5"), -5, "parse negative"); + CHECK_EQ(lsp_parse_content_length("Content-Length: 999999999999"), -1, "parse overflow"); + CHECK_EQ(lsp_parse_content_length(std::string("\x00\x01\xFF", 3)), -1, "parse garbage input"); + CHECK_EQ(lsp_parse_content_length("Content-Length: 1234abc"), 1234, "parse trailing garbage"); + CHECK_EQ(lsp_parse_content_length("Content-Type: application/vscode-jsonrpc"), -1, + "parse other LSP header returns -1"); +} + +// ---------------------------------------------------------------------------- +// frame + parse round-trip +// ---------------------------------------------------------------------------- +static void test_round_trip() { + { + std::string body = "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\"}"; + std::string frame = lsp_frame_message(body); + size_t header_end = frame.find("\r\n\r\n"); + CHECK(header_end != std::string::npos, "round-trip simple finds header end"); + if (header_end != std::string::npos) { + std::string header_block = frame.substr(0, header_end); + CHECK_EQ(lsp_parse_content_length(header_block), 
static_cast(body.size()), + "round-trip simple Content-Length matches body size"); + } + } + { + std::string body; + std::string frame = lsp_frame_message(body); + size_t header_end = frame.find("\r\n\r\n"); + CHECK(header_end != std::string::npos, "round-trip empty finds header end"); + if (header_end != std::string::npos) { + std::string header_block = frame.substr(0, header_end); + CHECK_EQ(lsp_parse_content_length(header_block), 0, + "round-trip empty Content-Length is 0"); + } + } +} + +int main() { + test_lsp_trim(); + test_lsp_frame_message(); + test_lsp_parse_content_length(); + test_round_trip(); + + if (g_failures == 0) { + std::cout << "lsp_plugin_test: all checks passed\n"; + return 0; + } + std::cerr << "lsp_plugin_test: " << g_failures << " check(s) failed\n"; + return 1; +}