feat: Add LSP plugin unit tests and frontend common initialization library
Some checks failed
CI / Determine matrix (push) Has been cancelled
CI / Sanitizer (ASan+UBSan) / ubuntu-24.04 (push) Has been cancelled
CI / Coverage (gcovr) / ubuntu-24.04 (push) Has been cancelled
CI / ${{ matrix.os }} / ${{ matrix.build_type }} (push) Has been cancelled

- Introduced `dstalk_lsp_plugin_test` for testing LSP plugin functionalities including `lsp_trim`, `lsp_frame_message`, and `lsp_parse_content_length`.
- Created `dstalk_frontend_common` static library to encapsulate shared initialization logic for frontend components (CLI, GUI, Web).
- Implemented configuration file discovery and service querying in `dstalk_frontend_init`.
- Added internal headers for LSP and Anthropic plugins to facilitate unit testing.
- Established a mailroom system for asynchronous message passing between stateless agents, enhancing coordination and context management.
This commit is contained in:
2026-06-01 08:51:40 +08:00
parent 8faa02c3d5
commit c0af9c65c7
17 changed files with 1235 additions and 69 deletions

109
CLAUDE.md Normal file
View File

@@ -0,0 +1,109 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Build & run
The repo ships its own toolchain under `tools/` (clang-cl/clang, CMake, Ninja, Conan 2 in a venv). Build scripts prefer those over PATH.
- **One-shot build (Windows)**: `build.bat` — checks/installs toolchain, runs Conan install against `deps/conanfile.txt`, configures CMake with clang-cl, builds with Ninja into `build/`.
- **One-shot build (Linux/macOS)**: `bash build.sh [Release|Debug]` — same flow with clang (or gcc fallback) and Ninja.
- **First-time setup**: `tools/setup.bat` or `bash tools/setup.sh` installs CMake, Ninja, and Conan 2 into `tools/` without touching system PATH.
- **CI-style build**: `scripts/ci-build.sh` / `scripts/ci-build.bat` — minimal Conan + CMake + Ninja flow used by CI.
- **Incremental rebuild after edits**: `cd build && <tools/ninja or system ninja> -j<N>` from inside `build/`. No need to re-run Conan unless `deps/conanfile.txt` changes.
Build outputs:
- Executables → `build/bin/` (`dstalk_cli`, test binaries, optional `dstalk_gui`/`dstalk_web`)
- Plugin DLLs → `build/plugins/` (each `plugin_*` target writes here, prefix stripped)
CMake options (root `CMakeLists.txt`):
- `DSTALK_BUILD_GUI=ON` — also builds the SDL3 GUI frontend (off by default; requires SDL3 from Conan/system)
- `DSTALK_BUILD_WEB=ON` — also builds the Boost.Beast web frontend
- `DSTALK_BUILD_TESTS=ON` — on by default; turn off to skip `tests/` subdir
## Tests
CTest, hand-rolled `CHECK`-macro tests (no GoogleTest — see "Network-restricted environments" below).
- **Run all tests**: `cd build && ctest --output-on-failure`
- **Run one test**: `cd build && ctest -R dstalk_smoke_test --output-on-failure` (or any other target name from `tests/CMakeLists.txt`)
- **Run a test binary directly**: `build/bin/dstalk_smoke_test.exe`
- **Coverage**: `cmake --build build --target coverage` (requires gcovr + gcov/llvm-cov; produces `build/coverage/index.html`). Only meaningful when configured with `--coverage` flags.
Test targets and what they cover (see `tests/CMakeLists.txt` for the authoritative list):
- `dstalk_smoke_test` — loads real plugins from `build/plugins/`, end-to-end integration; also carries regression cases (R1-R4, W21.5 tool calls)
- `dstalk_host_api_test`, `dstalk_event_bus_test`, `dstalk_service_registry_test` — core unit tests; compile core sources directly
- `dstalk_plugin_loader_test``PluginLoader` regression tests; sets `DSTALK_TEST_PLUGINS_DIR` to `build/plugins`
- `dstalk_context_plugin_test` — token/trim/UTF-8 boundaries
- `dstalk_anthropic_plugin_test`, `dstalk_openai_plugin_test` — currently `#include` the plugin `.cpp` to reach static functions; link `ai_common` and need its include dir
- `dstalk_network_plugin_test``#include`s `network_plugin.cpp`; needs OpenSSL
- `dstalk_lsp_plugin_test` — tests `lsp_trim` / `lsp_frame_message` / `lsp_parse_content_length` via `lsp_internal.hpp`
> After changing the `dstalk_host_api_t` vtable layout, **clean rebuild is mandatory** — stale `.obj` files from incremental builds will surface as segfaults in `dstalk_smoke_test` / `dstalk_host_api_test` because the test binary and plugin DLLs disagree on struct layout. `rm -rf build && bash build.sh` (or delete the specific stale `.obj` files) before re-running ctest.
## Architecture
Plugin-host design. One process loads a host DLL and many plugin DLLs over a C ABI.
```
Frontend (dstalk_cli / dstalk_gui / dstalk_web)
│ links dstalk_frontend_common (shared bootstrap: config discovery, init,
│ service queries, FrontendServices struct)
dstalk_core.dll ─ PluginLoader · ServiceRegistry · EventBus · Config · Logging · Memory
│ C ABI (dstalk_host.h)
Plugins (loaded from build/plugins/, each plugin_*.dll exports dstalk_plugin_init)
plugins_base/ config, file_io, lsp
plugins_middle/ network, session, tools (may depend on base)
plugins_upper/ context, openai, anthropic (may depend on base + middle)
ai_common (static lib, shared by openai + anthropic)
```
**Plugin tiers matter for build order and dependency direction**: `plugins_upper` may depend on `plugins_middle` may depend on `plugins_base`. Never let a lower tier depend on a higher one.
**Service vtables are the only cross-DLL contract.** Plugins register service vtables (e.g. `dstalk_ai_service_t`, `dstalk_http_service_t`, `dstalk_session_service_t`) with the host's `ServiceRegistry`. Other plugins query them by `(name, min_version)`. The vtable shapes live in `dstalk_core/include/dstalk/dstalk_services.h`; the host API the loader passes into each plugin's `on_init(const dstalk_host_api_t* host)` lives in `dstalk_host.h`.
**`ai_common`** (`plugins_upper/ai_common/`) is a static library, not a plugin. It holds shared types (`PluginConfig`, `ToolCallAccum`, `StreamContext`) and utilities (`secure_zero`, `extract_host_port`, `serialize_tool_calls`, `free_chat_result`) used by both `openai` and `anthropic` plugins, all under `namespace dstalk_ai`. Each plugin still compiles as its own DLL and links `ai_common` privately.
### Hard rules — violating these is undefined behavior
These come from `docs/reference/plugin-abi.md`. They are not style; they are correctness.
1. **Cross-DLL heap discipline.** Plugins must NOT call `std::malloc`/`std::free`/`std::strdup`/`new`/`delete` on data that crosses the host↔plugin boundary. Use `host->alloc` / `host->free` / `host->strdup` (passed in via `on_init`). Windows /MD CRTs and even some Linux/libc configs give each DLL its own heap — mismatched alloc/free crashes.
2. **API version is a hard match.** `dstalk_plugin_info_t.api_version` must equal `DSTALK_API_VERSION` (currently 1). The loader rejects mismatches. There is no backward compat — rebuild plugins against the new host.
3. **String ownership.** `dstalk_chat_result_t.content` / `.error` / `.tool_calls_json` are allocated with `host->strdup` by the producing plugin; the **caller** frees them with `host->free`. Never return `std::string::c_str()` or stack buffers across the ABI.
4. **C ABI + atomic globals.** Service vtable function pointers and cached service pointers stored in plugin global state must be `std::atomic<T*>` with `memory_order_acquire`/`release`. Raw pointers race during shutdown (the anthropic plugin's `g_config` was a real data race fixed in W17). Same applies to `g_host`, `g_http`, etc.
5. **No new on plugin info strings.** `name`/`version`/`description` in `dstalk_plugin_info_t` only need to live for the duration of `dstalk_plugin_init()` — string literals are fine; the host copies them.
6. **`api_key` lifecycle.** On `on_shutdown`, overwrite key bytes via `volatile char*` loop, then clear the string. Use `dstalk_ai::secure_zero()` from `ai_common`.
### Build system facts that bite
- **Conan provides only `boost::boost`** (the umbrella target) — granular `boost::json` / `boost::asio` / `boost::beast` targets do **not** exist in this Conan setup. Don't migrate to them.
- Every Boost-using target needs its own `find_package(Boost REQUIRED CONFIG)` call before `target_link_libraries(... boost::boost)`. Doing it once at the root is not enough; Conan-generated config exposes the target per-subdir.
- `dstalk_boost_config` (defined in `dstalk_core/CMakeLists.txt`) is an INTERFACE target carrying `BOOST_ALL_NO_LIB` / `BOOST_ERROR_CODE_HEADER_ONLY` / `BOOST_JSON_HEADER_ONLY`. Link it from any target using `<boost/json.hpp>`. Boost.JSON is header-only here, so each TU using it must `#include <boost/json/src.hpp>` in exactly one file (see `dstalk_core/src/boost_json.cpp`).
- The network plugin links `openssl::openssl` directly and calls OpenSSL C APIs (`SSL_set_tlsext_host_name`, `SSL_set1_host`). Don't remove that link "because Boost.Asio already pulls SSL" — it doesn't, at least not for these symbols.
- Plugin DLLs are written to `build/plugins/` with `PREFIX ""` (so `plugin_openai.dll`, not `libplugin_openai.dll`). The smoke test and plugin_loader_test look there.
### Network-restricted environments
This repo has been built in environments where outbound network is locked down. Two consequences:
- **Do not add `FetchContent` for GitHub-hosted deps.** `git clone` over smart-HTTP fails with "early EOF" through restrictive proxies (W17.3 lesson — GoogleTest dropped, tests use hand-rolled `CHECK` macros instead). Vendor or skip.
- Conan packages must be pre-cached in `~/.conan2/p/` or available through an allowed mirror; first-run `conan install` from a clean cache will fail offline.
## Repo-specific conventions
- **README is in Chinese (Simplified).** Comments and docs are routinely bilingual (Chinese first, English after a `/`). Match the surrounding style in any file you edit.
- **Plugin code style.** Each AI/network plugin has the same skeleton: `g_host` (atomic), `g_cfg`, `on_init`/`on_shutdown`, service vtable, then `dstalk_plugin_init()` returning a static `dstalk_plugin_info_t`. Keep that shape when adding a provider.
- **CRT.** `CMAKE_MSVC_RUNTIME_LIBRARY=MultiThreadedDLL` (/MD). All plugins and the host must agree.
## The `agents/` directory
`agents/` (README, WORKFLOW.md, STATUS.md, per-agent `profile.md` files, `groups/`) describes a multi-agent collaboration mode used by the project owner with Claude. It is **documentation, not code** — nothing in the build references it. Treat its contents as historical context; do not invent or extend the "16-person team / waves / 6-stage workflow" apparatus unless the user explicitly asks for it. If the user does invoke it, the rules are in `agents/WORKFLOW.md`.

View File

@@ -12,6 +12,7 @@ option(DSTALK_BUILD_WEB "Build the web UI frontend" OFF)
option(DSTALK_BUILD_TESTS "Build dstalk tests" ON) option(DSTALK_BUILD_TESTS "Build dstalk tests" ON)
add_subdirectory(dstalk_core) add_subdirectory(dstalk_core)
add_subdirectory(dstalk_frontend_common)
add_subdirectory(dstalk_cli) add_subdirectory(dstalk_cli)
# 插件按依赖层级分三个目录 / Plugins split into three directories by dependency tier # 插件按依赖层级分三个目录 / Plugins split into three directories by dependency tier
add_subdirectory(plugins_base) add_subdirectory(plugins_base)

View File

@@ -1,6 +1,6 @@
# dstalk 实时编制状态 # dstalk 实时编制状态
> **最后更新**: 2026-05-27 > **最后更新**: 2026-05-31
> **数据来源**: 由 `scripts/refresh_status.py` 自动扫描全部 16 个 `agents/*/profile.md` + 5 个 `agents/groups/*.md` 生成。 > **数据来源**: 由 `scripts/refresh_status.py` 自动扫描全部 16 个 `agents/*/profile.md` + 5 个 `agents/groups/*.md` 生成。
## 表 1员工状态16 人) ## 表 1员工状态16 人)
@@ -40,7 +40,7 @@
## Wave 进度 ## Wave 进度
**已完成高水位**: W16.5(基于 16 份 profile.md 的 performance_log 聚合 **已完成高水位**: W17由 CEO 直接执行ai_common 模块提取
**已发现 Wave 编号**: W1.1, W2.1, W2.2, W5.1, W6.1, W7, W9.3, W9.4, W9.6, W9.10, W10.1, W10.2, W10.3, W10.4, W11, W11.1, W11.2, W11.3, W11.6, W11.7, W12, W12.1, W12.2, W12.4, W12.5, W12.6, W13.1, W13.2, W13.3, W13.4, W13.5, W13.6, W14.1, W14.2, W14.3, W14.4, W14.5, W15.1, W15.2, W15.3, W15.4, W15.5, W15.6, W15.7, W15.8, W15.9, W16.5 **已发现 Wave 编号**: W1.1, W2.1, W2.2, W5.1, W6.1, W7, W9.3, W9.4, W9.6, W9.10, W10.1, W10.2, W10.3, W10.4, W11, W11.1, W11.2, W11.3, W11.6, W11.7, W12, W12.1, W12.2, W12.4, W12.5, W12.6, W13.1, W13.2, W13.3, W13.4, W13.5, W13.6, W14.1, W14.2, W14.3, W14.4, W14.5, W15.1, W15.2, W15.3, W15.4, W15.5, W15.6, W15.7, W15.8, W15.9, W16.5, W17

249
agents/mailroom/README.md Normal file
View File

@@ -0,0 +1,249 @@
# Mailroom — agents 信箱系统
> **版本**: 1.0 草案 (2026-05-31 设计)
> **状态**: 试运行;格式稳定后由 CEO 决定是否在 WORKFLOW.md §15 正式纳入流程
> **设计目标**: 让无状态的子代理之间能异步传递消息,把进行中的协调状态从 CEO 的隐式上下文搬到磁盘
## 1. 设计动机
子代理是 stateless 的——每次 spawn 都是新会话,看不到其他子代理当前在做什么。现行所有协调依赖 CEO 在 prompt 里手动列禁忌、手动转交依赖、手动追踪谁回了什么。一波 6-9 路子代理并行时CEO 的上下文窗口很快被占满。
信箱让这些消息**显式落盘**
- CEO 派活前可以扫一遍候选执行者的 inbox看有没有未结的 handoff / blocker
- 子代理 spawn 时可以读到自己的待办邮件,了解上下文
- 处理完毕的邮件归档到 `archive/W<n>/`,不丢失追溯,但不打扰当前视图
## 2. 目录结构
```
agents/mailroom/
├── README.md # 本文件
├── inbox/
│ ├── ceo/ # CEO 的收件箱
│ ├── architect-lin/ # 各 agent 的收件箱(按需创建)
│ ├── engineer-zhao/
│ └── ...
└── archive/
├── W17/ # 按 Wave 归档已处理邮件
├── W18/
└── ...
```
- `inbox/<agent-id>/` —— 待处理邮件status: pending / read / in_progress
- `archive/W<n>/` —— 处理完毕邮件status: closed按 Wave 归档保留
- 收件箱目录按需创建,未创建表示该 agent 当前无邮件
## 3. 权限规则
文件系统不强制权限;权限靠子代理 prompt 中的禁忌条款约束。
| 角色 | 自己的 inbox | 别人的 inbox | archive/ |
|------|--------------|--------------|----------|
| 接收者本人 | 读 / 改 status / 移到 archive | 只能投递新邮件write-only | 只读 |
| CEO | 所有权限 | 所有权限(含重投递、强制保留) | 所有权限 |
| 其他子代理 | — | 只能投递新邮件 | 只读 |
派子代理时 prompt 必须显式声明可访问的邮箱范围。新增防御性规则 **R-MAIL-SCOPE**:子代理不得读写超出自身 inbox 之外的他人邮箱内容,仅可投递新邮件。
## 4. 消息文件命名
`<timestamp>-<kind>-<from>-to-<to>.md`
- timestamp: `YYYY-MM-DDTHHmm` 本地时间,分钟级即可区分
- kind: 见 §5
- from / to: agent-idCEO 用 `ceo`,广播用 `all`
示例:`2026-05-31T1430-task-ceo-to-architect-lin.md`
## 5. 五种邮件类型
| kind | 用途 | 投递方 | 接收方 |
|------|------|--------|--------|
| task | CEO 派活prompt 副本,便于事后追溯) | CEO | 执行者 |
| report | 子代理回执done / blocked / aborted | 执行者 | CEO |
| handoff | 子代理之间工作交接(前序产出移交下一棒) | 子代理 A | 子代理 B |
| blocker | 阻塞通知(依赖未满足) | 任何 | CEO必抄送 |
| notice | 广播in-flight 文件锁定 / 流程变更 / 元数据自检失败) | CEO 或系统 | all |
## 6. Frontmatter schema
```yaml
---
id: msg-W17.3-001 # 全局唯一,格式 msg-<Wave>-<序号>
from: ceo
to: architect-lin # 单人用 agent-id广播用 all
wave: W17.3
kind: task # task | report | handoff | blocker | notice
status: pending # pending | read | in_progress | closed
created: 2026-05-31T14:30
closed: # 处理完毕时填写YYYY-MM-DDTHHmm
related: # 关联消息 id 列表(如本报告对应的 task
- msg-W17.3-000
fixes: # 关联 findings-registry 的 finding id如有
- F-13.5-1
---
邮件正文markdown
```
- 必填字段:`id` / `from` / `to` / `wave` / `kind` / `status` / `created`
- 选填字段:`closed` / `related` / `fixes`
- 字符串值不带引号YAML 1.2 标准)
## 7. 生命周期
```
[投递] [归档]
│ │
▼ ▼
inbox/<to>/ pending ──→ read ──→ in_progress ──→ 处理完毕 archive/W<n>/
status: closed
```
操作动作:
| 动作 | 谁执行 | 文件操作 | status 字段 |
|------|--------|----------|--------------|
| 投递 | 任意 agent | 在 `inbox/<to>/` 创建新文件 | pending |
| 阅读 | 接收者 | frontmatter 改 status | read |
| 开工 | 接收者 | frontmatter 改 status | in_progress |
| 完成 | 接收者 | 移动文件到 `archive/W<n>/` + 改 status | closed |
| 重投递 | 仅 CEO | 复制 archive 中的文件到 inbox保留原件 | pending |
接收者**不可物理删除**邮件——只允许移到 archive如需删除由 CEO 手动操作。这是与 R-NO-FORCE-PUSH 同等级的强约束(新增 **R-MAIL-NO-DELETE**)。
## 8. 与现有机制的边界
| 机制 | 职责 | 与 mailroom 的区别 |
|------|------|---------------------|
| `STATUS.md` | 实时编制状态快照 | 由 `scripts/refresh_status.py` 扫 profile.md 生成mailroom 是消息流 |
| `agents/<id>/profile.md` performance_log | 永久任务履历 | 任务粒度事后追加mailroom 是消息粒度,进行中 |
| `agents/audits/findings-registry.md` | 审计发现追踪 | 跨 Wave 持续追踪mailroom 单条消息生命周期短(一波内闭环) |
| `WORKFLOW.md` | 流程规范 | 是规则mailroom 是规则产生的数据 |
**重叠处理原则**:邮件正文 ≠ 履历正文。task 邮件中的 prompt 是 CEO 派活的原始材料report 邮件中的回执是子代理的原始报告;最终摘要仍按现有流程写到 profile.md 的 performance_log 一条短记录。**邮件保留细节profile.md 保留摘要,两者互补不重复**。
审计发现的归宿仍是 findings-registry。blocker 邮件如对应某个 findingfrontmatter 的 `fixes` 字段填写 finding idCEO INSPECT 时按 §14.4 A3 检查关联。
## 9. CEO 派活前 5 步检查
扩展 [PROMPT_TEMPLATE.md](PROMPT_TEMPLATE.md) 的 CEO 派活前 4 步检查,新增第 5 步:
| # | 检查项 | 操作 | 写入模板字段 |
|---|--------|------|--------------|
| 1 | 列 in-flight 工作区 | 查当前有哪些子代理在跑 | 禁忌 |
| 2 | 找前序 Wave 产出 | 从 performance_log 追溯 | 前序成果 |
| 3 | 设定任务范围三档 | 必做 / 可做 / 不做 | 任务范围 |
| 4 | 统一字数上限 | 固定 250 字 | 字数上限 |
| **5** | **扫候选执行者 inbox** | `ls agents/mailroom/inbox/<候选id>/` 看 pending 的 handoff / blocker | 禁忌 / 前序成果 |
若候选执行者 inbox 有未结的 blocker依赖未满足→ 优先让其处理 blocker 再派新活,否则新活也会立即变 blocker。
## 10. 邮件示例
### 10.1 task 邮件
文件:`inbox/architect-lin/2026-05-31T1500-task-ceo-to-architect-lin.md`
```markdown
---
id: msg-W18.1-001
from: ceo
to: architect-lin
wave: W18.1
kind: task
status: pending
created: 2026-05-31T15:00
---
# W18.1 任务: 评审 mailroom 设计
请读 agents/mailroom/README.md 并评估:
- 是否值得正式纳入 WORKFLOW.md §15
- 权限规则在 prompt 层面如何强约束
- 与 findings-registry 的边界是否清晰
字数上限 250。完成后请发 report 邮件回 ceo。
```
### 10.2 report 邮件
文件:`inbox/ceo/2026-05-31T1630-report-architect-lin-to-ceo.md`
```markdown
---
id: msg-W18.1-002
from: architect-lin
to: ceo
wave: W18.1
kind: report
status: pending
created: 2026-05-31T16:30
related:
- msg-W18.1-001
---
# W18.1 评审结论
赞成纳入 §15。建议补充三条边界规则
1. 接收者不可删除别人邮箱里的邮件
2. CEO 重投递时必须复制而非移动,保留 archive 原件
3. blocker 邮件必须抄送 CEO
profile.md 已追加rating: A-。
```
### 10.3 handoff 邮件
文件:`inbox/qa-xu/2026-05-31T1700-handoff-engineer-sun-to-qa-xu.md`
```markdown
---
id: msg-W18.2-003
from: engineer-sun
to: qa-xu
wave: W18.2
kind: handoff
status: pending
created: 2026-05-31T17:00
related:
- msg-W18.2-001
---
# W18.2 LSP 死锁修复完成,请测
修复了 lsp_plugin.cpp:312 的死锁。需要你跑:
- `ctest -R lsp_plugin_test`
- `ctest -R smoke`
测试通过后请新建一封 report 邮件发给 CEO 并抄送我。
```
## 11. 后续路线v1.0 之外)
格式稳定 + CEO 在 2-3 个 Wave 中实战使用后,再考虑:
- 在 WORKFLOW.md §15 加一节正式纳入流程
- 补充 `scripts/mailroom_summary.py` 自动汇总每个 agent 的未读邮件数与最老 pending 邮件年龄
- 在 STATUS.md 增加 邮件待办数 列
- 扩展 `scripts/check_agents_metadata.py` 校验邮件 frontmatter
不在 v1.0 范围。先用 1-2 个 Wave 看实际效果再说。
## 12. 关联文档
- [WORKFLOW.md](../WORKFLOW.md) — 流程主文档mailroom 将来可能并入 §15
- [PROMPT_TEMPLATE.md](../PROMPT_TEMPLATE.md) — 派活模板(影响第 5 步检查)
- [STATUS.md](../STATUS.md) — 实时编制状态
- [findings-registry.md](../audits/findings-registry.md) — 审计发现追踪
- [POSTMORTEM.md](../POSTMORTEM.md) — 防御性规则(新增 R-MAIL-SCOPE / R-MAIL-NO-DELETE
## 13. 变更历史
| 日期 | 版本 | 变更 |
|------|------|------|
| 2026-05-31 | 1.0 草案 | 初始化。CEO 与用户讨论后落地,方案 2归档保留不真删 |

View File

@@ -15,3 +15,9 @@ find_package(Boost REQUIRED CONFIG)
target_link_libraries(dstalk_cli target_link_libraries(dstalk_cli
PRIVATE dstalk boost::boost dstalk_boost_config PRIVATE dstalk boost::boost dstalk_boost_config
) )
# POSIX 平台需要 pthread (用于 std::thread spinner)
if(NOT WIN32)
find_package(Threads REQUIRED)
target_link_libraries(dstalk_cli PRIVATE Threads::Threads)
endif()

View File

@@ -7,12 +7,15 @@
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <chrono>
#include <cstdio> #include <cstdio>
#include <cstdlib> #include <cstdlib>
#include <cstring> #include <cstring>
#include <filesystem> #include <filesystem>
#include <iostream>
#include <string> #include <string>
#include <system_error> #include <system_error>
#include <thread>
#include <vector> #include <vector>
#include <boost/json.hpp> #include <boost/json.hpp>
@@ -64,6 +67,8 @@ static const dstalk_tools_service_t* g_tools = nullptr;
static std::string g_current_model; static std::string g_current_model;
static std::atomic<bool> g_quit_requested{false}; static std::atomic<bool> g_quit_requested{false};
static std::atomic<bool> g_quit_via_signal{false}; static std::atomic<bool> g_quit_via_signal{false};
static std::atomic<bool> g_spinning{false};
static std::thread g_spinner_thread;
// ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ---- // ---- Ctrl+C 信号处理 / Ctrl+C signal handlers ----
// Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT). // Windows console event handler (CTRL_C_EVENT / CTRL_BREAK_EVENT).
@@ -90,6 +95,138 @@ static void on_signal(int /*sig*/)
// ---- 工具函数 / Utility functions ---- // ---- 工具函数 / Utility functions ----
// ---- 进度指示器 (spinner) / Progress indicator (spinner) ----
// 在等待 AI 响应时在 stderr 显示旋转字符,通过 atomic flag 控制启停。
// Displays a rotating character on stderr while waiting for AI responses, controlled via atomic flag.
static void spinner_run()
{
const char chars[] = "|/-\\";
int i = 0;
while (g_spinning.load(std::memory_order_relaxed)) {
std::fprintf(stderr, "\r%c", chars[i % 4]);
std::fflush(stderr);
i++;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
// 光标归位(不擦除,由下一条 stdout 输出覆盖) / Return cursor (don't erase, let next stdout output overwrite)
std::fprintf(stderr, "\r");
std::fflush(stderr);
}
static void spinner_start()
{
if (g_spinner_thread.joinable()) {
g_spinner_thread.join();
}
g_spinning = true;
g_spinner_thread = std::thread(spinner_run);
}
static void spinner_stop()
{
g_spinning = false;
}
static void spinner_join()
{
if (g_spinner_thread.joinable()) {
g_spinner_thread.join();
}
}
// ---- 错误分类与友好提示 / Error classification and user-friendly messages ----
// 根据 HTTP 状态码和错误消息字符串匹配,将常见错误归类为认证/频率限制/网络/配额问题,并给出中文建议。
// Classifies common errors into auth/rate-limit/network/quota categories based on HTTP status and string matching, with Chinese suggestions.
static void print_error(const char* error_msg, int http_status)
{
std::string msg(error_msg ? error_msg : "unknown error");
const char* category = nullptr;
const char* suggestion = nullptr;
// 先按 HTTP 状态码分类(最可靠) / First classify by HTTP status code (most reliable)
switch (http_status) {
case 401:
case 403:
category = "认证失败";
suggestion = "请检查 API key 是否正确(输入 /status 查看当前配置)";
break;
case 429:
category = "请求频率限制";
suggestion = "API 调用太频繁,请稍后重试或降低请求频率";
break;
case 400:
category = "请求参数错误";
suggestion = "请求格式不正确,可能是模型名或参数有误(输入 /status 查看)";
break;
case 502:
case 503:
case 504:
category = "服务器错误";
suggestion = "API 服务器暂时不可用,请稍后重试";
break;
default:
break;
}
// http_status 未覆盖 → 字符串模式匹配 / HTTP status not covered → string pattern matching
if (!category) {
if (msg.find("401") != std::string::npos ||
msg.find("403") != std::string::npos ||
msg.find("Unauthorized") != std::string::npos ||
msg.find("Forbidden") != std::string::npos ||
msg.find("authentication") != std::string::npos ||
msg.find("invalid api key") != std::string::npos ||
msg.find("Incorrect API key") != std::string::npos) {
category = "认证失败";
suggestion = "请检查 API key 是否正确(输入 /status 查看当前配置)";
} else if (msg.find("429") != std::string::npos ||
msg.find("rate") != std::string::npos ||
msg.find("Rate limit") != std::string::npos ||
msg.find("too many requests") != std::string::npos) {
category = "请求频率限制";
suggestion = "API 调用太频繁,请稍后重试或降低请求频率";
} else if (msg.find("connection refused") != std::string::npos ||
msg.find("Connection refused") != std::string::npos ||
msg.find("connection reset") != std::string::npos ||
msg.find("Connection reset") != std::string::npos ||
msg.find("timed out") != std::string::npos ||
msg.find("Timeout") != std::string::npos ||
msg.find("network") != std::string::npos ||
msg.find("Network") != std::string::npos ||
msg.find("resolve") != std::string::npos ||
msg.find("Name or service not known") != std::string::npos ||
msg.find("Couldn't resolve") != std::string::npos ||
msg.find("Failed to connect") != std::string::npos ||
msg.find("Could not connect") != std::string::npos ||
msg.find("could not connect") != std::string::npos ||
msg.find("connect error") != std::string::npos ||
msg.find("Connect error") != std::string::npos ||
msg.find("connect failed") != std::string::npos ||
msg.find("Connect failed") != std::string::npos) {
category = "网络错误";
suggestion = "无法连接到 API 服务器,请检查网络连接和 base_url输入 /status 查看)";
} else if (msg.find("400") != std::string::npos ||
msg.find("Bad Request") != std::string::npos) {
category = "请求参数错误";
suggestion = "请求格式不正确,可能是模型名或参数有误(输入 /status 查看)";
} else if (msg.find("insufficient") != std::string::npos ||
msg.find("quota") != std::string::npos ||
msg.find("billing") != std::string::npos) {
category = "配额不足";
suggestion = "API 配额已用完或账户余额不足,请检查账户状态";
}
}
if (category && suggestion) {
std::fprintf(stderr, CLR_RED "[ERROR] %s\n" CLR_RESET, category);
std::fprintf(stderr, CLR_YELLOW " -> %s\n" CLR_RESET, suggestion);
std::fprintf(stderr, CLR_DIM " (原始错误: %s)\n" CLR_RESET, msg.c_str());
} else {
std::fprintf(stderr, CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, msg.c_str());
}
}
// 打印启动横幅 / Print the dstalk CLI banner with version, AI indicator, and quick command hints. // 打印启动横幅 / Print the dstalk CLI banner with version, AI indicator, and quick command hints.
static void print_banner() static void print_banner()
{ {
@@ -391,12 +528,15 @@ static void handle_command(const char* line)
} }
// ---- 流式回调 / Streaming callback ---- // ---- 流式回调 / Streaming callback ----
// 流式输出回调:每收到一个 token 打印到 stdout 并刷新 / Callback invoked for each token during streaming chat; prints the token to stdout and flushes. // 流式输出回调:每收到一个 token 打印到 stdout 并刷新
// 第一个 token 到达时停止 spinner 并用 \r 覆盖旋转字符。
// Callback invoked for each token during streaming chat; stops spinner on first token and overwrites the spinner character with \r.
static int on_stream_token(const char* token, void* userdata) static int on_stream_token(const char* token, void* userdata)
{ {
bool* first = static_cast<bool*>(userdata); bool* first = static_cast<bool*>(userdata);
if (*first) { if (*first) {
std::printf(CLR_GREEN); spinner_stop();
std::printf("\r" CLR_GREEN);
*first = false; *first = false;
} }
std::printf("%s", token); std::printf("%s", token);
@@ -404,6 +544,18 @@ static int on_stream_token(const char* token, void* userdata)
return 0; return 0;
} }
// ---- 管道 / --prompt 共用:从 stdin 读入全部内容 / Read all stdin content (shared by pipe and --prompt modes) ----
static std::string read_all_stdin()
{
std::string result;
std::string line;
while (std::getline(std::cin, line)) {
if (!result.empty()) result += '\n';
result += line;
}
return result;
}
// ---- 主程序 / Main entry point ---- // ---- 主程序 / Main entry point ----
// 入口:初始化 dstalk host查询插件服务处理 batch/pipe/交互模式。 // 入口:初始化 dstalk host查询插件服务处理 batch/pipe/交互模式。
// Entry point: initializes dstalk host, queries plugin services, handles batch/pipe/interactive modes. // Entry point: initializes dstalk host, queries plugin services, handles batch/pipe/interactive modes.
@@ -520,11 +672,7 @@ int main(int argc, char* argv[])
// ---- B3: 管道输入模式 (非交互) / Pipe input mode (non-interactive) ---- // ---- B3: 管道输入模式 (非交互) / Pipe input mode (non-interactive) ----
if (pipe_mode) { if (pipe_mode) {
std::string input; std::string input = read_all_stdin();
char buf[4096];
while (std::fgets(buf, sizeof(buf), stdin)) {
input += buf;
}
if (input.empty()) { if (input.empty()) {
std::fprintf(stderr, "empty prompt\n"); std::fprintf(stderr, "empty prompt\n");
dstalk_shutdown(); dstalk_shutdown();
@@ -544,8 +692,7 @@ int main(int argc, char* argv[])
dstalk_shutdown(); dstalk_shutdown();
return EXIT_OK; return EXIT_OK;
} else { } else {
std::fprintf(stderr, CLR_RED "[ERROR] AI error: %s\n" CLR_RESET, print_error(result.error, result.http_status);
result.error ? result.error : "unknown");
g_ai->free_result(&result); g_ai->free_result(&result);
dstalk_shutdown(); dstalk_shutdown();
return EXIT_FATAL; return EXIT_FATAL;
@@ -557,10 +704,7 @@ int main(int argc, char* argv[])
std::string prompt_text; std::string prompt_text;
if (std::strcmp(prompt_arg, "-") == 0) { if (std::strcmp(prompt_arg, "-") == 0) {
// --prompt - or --prompt (no arg): read prompt from stdin / --prompt - 或 --prompt无参数从 stdin 读取提示 // --prompt - or --prompt (no arg): read prompt from stdin / --prompt - 或 --prompt无参数从 stdin 读取提示
char buf[4096]; prompt_text = read_all_stdin();
while (std::fgets(buf, sizeof(buf), stdin)) {
prompt_text += buf;
}
if (prompt_text.empty()) { if (prompt_text.empty()) {
std::fprintf(stderr, "empty prompt\n"); std::fprintf(stderr, "empty prompt\n");
dstalk_shutdown(); dstalk_shutdown();
@@ -588,15 +732,15 @@ int main(int argc, char* argv[])
dstalk_shutdown(); dstalk_shutdown();
return EXIT_OK; return EXIT_OK;
} else { } else {
std::fprintf(stderr, CLR_RED "[ERROR] AI error: %s\n" CLR_RESET, print_error(result.error, result.http_status);
result.error ? result.error : "unknown");
g_ai->free_result(&result); g_ai->free_result(&result);
dstalk_shutdown(); dstalk_shutdown();
return EXIT_FATAL; return EXIT_FATAL;
} }
} }
char buffer[8192]; // ---- 交互模式主循环 / Interactive mode main loop ----
std::string line;
while (true) { while (true) {
// B1: 检查退出标志 / Check quit flag // B1: 检查退出标志 / Check quit flag
if (g_quit_requested) { if (g_quit_requested) {
@@ -611,26 +755,17 @@ int main(int argc, char* argv[])
std::fflush(stdout); std::fflush(stdout);
} }
if (!std::fgets(buffer, sizeof(buffer), stdin)) break; // 动态读取一行,无大小限制 / Read one line dynamically, no size limit
if (!std::getline(std::cin, line)) break;
// C3: fgets 截断检测 / fgets truncation detection // 去除末尾的 \rWindows / Strip trailing \r (Windows)
if (!std::strchr(buffer, '\n') && !feof(stdin)) { if (!line.empty() && line.back() == '\r') line.pop_back();
std::fprintf(stderr, CLR_RED "[ERROR] 输入超过 8KB已截断。建议用文件方式dstalk --batch < file.txt\n" CLR_RESET);
int c;
while ((c = std::fgetc(stdin)) != '\n' && c != EOF) {}
}
// 去除末尾换行 / Strip trailing newline if (line.empty()) continue;
size_t len = std::strlen(buffer);
while (len > 0 && (buffer[len-1] == '\n' || buffer[len-1] == '\r')) {
buffer[--len] = '\0';
}
if (len == 0) continue;
// 命令处理 / Command dispatch // 命令处理 / Command dispatch
if (buffer[0] == '/') { if (line[0] == '/') {
handle_command(buffer); handle_command(line.c_str());
continue; continue;
} }
@@ -644,14 +779,19 @@ int main(int argc, char* argv[])
int history_count = 0; int history_count = 0;
const dstalk_message_t* history = g_session->history(&history_count); const dstalk_message_t* history = g_session->history(&history_count);
// 启动 spinner等待 AI 响应 / Start spinner while waiting for AI response
spinner_start();
bool first = true; bool first = true;
dstalk_chat_result_t result = g_ai->chat_stream( dstalk_chat_result_t result = g_ai->chat_stream(
history, history_count, buffer, on_stream_token, &first); history, history_count, line.c_str(), on_stream_token, &first);
// 确保 spinner 已停止(处理无流式输出的情况) / Ensure spinner is stopped (handles no-stream-output case)
spinner_stop();
if (result.ok) { if (result.ok) {
std::printf(CLR_RESET "\n\n"); std::printf(CLR_RESET "\n\n");
// 将用户消息和 AI 回复添加到会话 / Add user message and AI reply to session // 将用户消息和 AI 回复添加到会话 / Add user message and AI reply to session
dstalk_message_t user_msg = {"user", buffer, nullptr, nullptr}; dstalk_message_t user_msg = {"user", line.c_str(), nullptr, nullptr};
g_session->add(&user_msg); g_session->add(&user_msg);
dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json}; dstalk_message_t ai_msg = {"assistant", result.content, nullptr, result.tool_calls_json};
g_session->add(&ai_msg); g_session->add(&ai_msg);
@@ -727,8 +867,10 @@ int main(int argc, char* argv[])
history = g_session->history(&history_count); history = g_session->history(&history_count);
g_ai->free_result(&result); g_ai->free_result(&result);
spinner_start();
bool tool_stream_first = true; bool tool_stream_first = true;
result = g_ai->chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first); result = g_ai->chat_stream(history, history_count, nullptr, on_stream_token, &tool_stream_first);
spinner_stop();
if (result.ok) { if (result.ok) {
std::printf(CLR_RESET "\n"); std::printf(CLR_RESET "\n");
@@ -741,8 +883,7 @@ int main(int argc, char* argv[])
g_session->add(&ai_followup); g_session->add(&ai_followup);
has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0'); has_tool_calls = (result.tool_calls_json && result.tool_calls_json[0] != '\0');
} else { } else {
std::printf(CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, print_error(result.error, result.http_status);
result.error ? result.error : "unknown error");
break; break;
} }
} }
@@ -751,14 +892,17 @@ int main(int argc, char* argv[])
std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS); std::fprintf(stderr, CLR_YELLOW "[WARN] 已达最大工具调用轮次(%d),停止\n" CLR_RESET, MAX_TOOL_ROUNDS);
} }
} else { } else {
// A3: error 路径下需 NULL 保护;当前只取 result.errorcontent 未涉及 / Error path needs NULL guard; currently only reads result.error, content not involved // AI 调用失败reset 颜色,输出分类错误信息 / AI call failed: reset color, output classified error info
std::printf(CLR_RESET "\n" CLR_RED "[ERROR] AI 调用失败: %s\n" CLR_RESET, std::printf(CLR_RESET "\n");
result.error ? result.error : "unknown error"); print_error(result.error, result.http_status);
} }
g_ai->free_result(&result); g_ai->free_result(&result);
} }
// B2: 单一退出点dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode) // B2: 单一退出点dstalk_shutdown 只在此调用(交互模式下) / Single exit point, dstalk_shutdown only called here (in interactive mode)
// 确保 spinner 线程已结束——先发信号停止,再 join 等待线程真正退出 / Ensure spinner thread has ended: signal stop first, then join to wait for thread exit
spinner_stop();
spinner_join();
dstalk_shutdown(); dstalk_shutdown();
return g_quit_via_signal ? EXIT_INTERRUPT : EXIT_OK; return g_quit_via_signal ? EXIT_INTERRUPT : EXIT_OK;
} }

View File

@@ -0,0 +1,17 @@
# ============================================================
# dstalk_frontend_common — 前端公共初始化静态库
# ============================================================
add_library(dstalk_frontend_common STATIC
src/frontend_common.cpp
)
target_include_directories(dstalk_frontend_common
PUBLIC include
)
target_link_libraries(dstalk_frontend_common
PUBLIC dstalk
)
target_compile_features(dstalk_frontend_common PUBLIC cxx_std_20)

View File

@@ -0,0 +1,65 @@
// ============================================================================
// dstalk_frontend_common — 前端公共初始化模块
// ============================================================================
// 提供所有前端CLI / GUI / Web共享的启动逻辑
// - 配置文件发现argv / 默认路径 / 平台 fopen
// - dstalk_init() 调用
// - 常用服务查询ai, session, file_io, tools, context
// - AI 服务默认配置(从 config 读取,带 fallback
// ============================================================================
#ifndef DSTALK_FRONTEND_COMMON_HPP
#define DSTALK_FRONTEND_COMMON_HPP
#include <string>
#include "dstalk/dstalk_host.h"
struct FrontendServices {
const dstalk_ai_service_t* ai = nullptr;
const dstalk_session_service_t* session = nullptr;
const dstalk_file_io_service_t* file_io = nullptr;
const dstalk_tools_service_t* tools = nullptr;
std::string provider; // "ai.deepseek" / "ai.openai" / "ai.anthropic"
std::string model; // e.g. "deepseek-v4-pro"
std::string base_url; // e.g. "https://api.deepseek.com/v1"
std::string api_key;
// 是否已成功初始化 dstalk 核心
bool initialized = false;
};
// ---- 前端公共初始化 ----
//
// 功能:
// 1. 发现配置文件:优先 argv[1](跳过已知标志),其次 default_config如 "config.toml"
// 2. 调用 dstalk_init(config_path)
// 3. 查询常用插件服务ai / session / file_io / tools
// 4. 用 dstalk_config_get 读取 api.* 键并调用 ai->configure() 设置默认值
//
// 参数:
// svc - [out] 填入查询到的服务指针和配置信息
// argc/argv - 命令行参数(可为 0/nullptr例如 GUI 没有命令行参数)
// default_cfg- 默认配置文件名(如 "config.toml"),当 argv 未提供时使用
// skip_flags - 以 NULL 结尾的字符串数组argv 扫描时跳过这些标志及其下一个参数
//
// 返回值:
// 0 - 成功svc.initialized == true至少 ai + session 已就绪
// 1 - dstalk_init 失败
// 2 - AI 服务未找到
// 3 - Session 服务未找到
//
int dstalk_frontend_init(FrontendServices& svc,
int argc = 0, char* argv[] = nullptr,
const char* default_cfg = "config.toml",
const char* const* skip_flags = nullptr);
// ---- 便捷辅助 ----
// 将 dstalk_message_t 数组的内容追加到 session 服务(一次一条)。
// 常用于 Ctrl+O 加载会话后重建前端消息列表。
// 返回实际追加的条数。
int dstalk_frontend_replay_history(FrontendServices& svc);
#endif // DSTALK_FRONTEND_COMMON_HPP

View File

@@ -0,0 +1,124 @@
// ============================================================================
// dstalk_frontend_common — 实现
// ============================================================================
#include "dstalk_frontend_common.hpp"
#include <cstdio>
#include <cstdlib>
#include <cstring>
// ---- 配置文件发现 ----
static const char* discover_config(int argc, char* argv[],
const char* default_cfg,
const char* const* skip_flags)
{
// 1) 从 argv 中查找首个非标志参数
if (argc >= 2 && argv) {
for (int i = 1; i < argc; ++i) {
bool is_skip = false;
if (skip_flags) {
for (int k = 0; skip_flags[k]; ++k) {
if (std::strcmp(argv[i], skip_flags[k]) == 0) {
is_skip = true;
// 若该标志有值参数则多跳一个
if (i + 1 < argc && argv[i + 1][0] != '-') ++i;
break;
}
}
}
if (!is_skip) return argv[i];
}
}
// 2) 回退:尝试默认配置文件
if (!default_cfg || default_cfg[0] == '\0') return nullptr;
const char* candidates[] = { default_cfg, nullptr };
for (int i = 0; candidates[i]; ++i) {
FILE* f = nullptr;
#ifdef _WIN32
if (fopen_s(&f, candidates[i], "r") == 0 && f) {
#else
f = std::fopen(candidates[i], "r");
if (f) {
#endif
std::fclose(f);
return candidates[i];
}
}
return nullptr;
}
// ---- 主初始化 ----
int dstalk_frontend_init(FrontendServices& svc,
int argc, char* argv[],
const char* default_cfg,
const char* const* skip_flags)
{
// (1) 发现并加载配置
const char* cfg = discover_config(argc, argv, default_cfg, skip_flags);
if (dstalk_init(cfg) != 0) {
std::fprintf(stderr, "[dstalk] 初始化失败\n");
return 1;
}
// (2) 查询 AI 服务
const char* provider = dstalk_config_get("ai.provider");
if (!provider || provider[0] == '\0') provider = "ai.deepseek";
svc.provider = provider;
svc.ai = static_cast<const dstalk_ai_service_t*>(
dstalk_service_query(provider, 1));
svc.session = static_cast<const dstalk_session_service_t*>(
dstalk_service_query("session", 1));
svc.file_io = static_cast<const dstalk_file_io_service_t*>(
dstalk_service_query("file_io", 1));
svc.tools = static_cast<const dstalk_tools_service_t*>(
dstalk_service_query("tools", 1));
const dstalk_context_service_t* ctx_svc =
static_cast<const dstalk_context_service_t*>(
dstalk_service_query("context", 1));
(void)ctx_svc; // 不强制使用,保留以备将来使用
if (!svc.ai) {
std::fprintf(stderr, "[dstalk] AI 服务未找到(请检查插件目录)\n");
return 2;
}
if (!svc.session) {
std::fprintf(stderr, "[dstalk] Session 服务未找到\n");
return 3;
}
// (3) 配置 AI 服务的默认值
const char* base_url = dstalk_config_get("api.base_url");
const char* api_key = dstalk_config_get("api.api_key");
const char* model = dstalk_config_get("api.model");
if (!base_url || base_url[0] == '\0') base_url = "https://api.deepseek.com/v1";
if (!model || model[0] == '\0') model = "deepseek-v4-pro";
svc.base_url = base_url;
svc.api_key = api_key ? api_key : "";
svc.model = model;
svc.ai->configure(provider, base_url,
api_key ? api_key : "",
model, 4096, 0.7);
svc.initialized = true;
return 0;
}
// ---- 会话历史回放 ----
int dstalk_frontend_replay_history(FrontendServices& svc)
{
if (!svc.session) return 0;
int count = 0;
const dstalk_message_t* msgs = svc.session->history(&count);
return count; // 调用方自行遍历并重建前端消息列表
}

View File

@@ -16,5 +16,6 @@ set_target_properties(dstalk_gui PROPERTIES
target_link_libraries(dstalk_gui target_link_libraries(dstalk_gui
PRIVATE PRIVATE
dstalk dstalk
dstalk_frontend_common
SDL3::SDL3 SDL3::SDL3
) )

View File

@@ -744,8 +744,23 @@ static void processEvent(AppContext& ctx, SDL_Event& ev) {
break; break;
case SDLK_O: case SDLK_O:
if (ctrl) { if (ctrl) {
// Ctrl+O加载会话 / Ctrl+O: load session // Ctrl+O加载会话
if (g_session_svc && g_session_svc->load("session.json") == 0) { if (g_session_svc && g_session_svc->load("session.json") == 0) {
// BUGFIX: 从 session 历史重建 GUI 消息列表
int hcount = 0;
const dstalk_message_t* history = g_session_svc->history(&hcount);
gs.messages.clear();
for (int i = 0; i < hcount; ++i) {
ChatMessage::Role r;
if (std::strcmp(history[i].role, "user") == 0)
r = ChatMessage::USER;
else if (std::strcmp(history[i].role, "assistant") == 0)
r = ChatMessage::ASSISTANT;
else
r = ChatMessage::SYSTEM;
gs.messages.push_back(ChatMessage(r,
history[i].content ? history[i].content : ""));
}
gs.messages.push_back(ChatMessage( gs.messages.push_back(ChatMessage(
ChatMessage::SYSTEM, "Session loaded from session.json")); ChatMessage::SYSTEM, "Session loaded from session.json"));
} else { } else {
@@ -895,8 +910,30 @@ int main(int argc, char* argv[]) {
g_ai_svc->free_result(&result); g_ai_svc->free_result(&result);
} }
// 流式传输完成(或被取消) / Streaming completed (or cancelled) // 流式传输完成(或被取消)
if (rc != 0) { if (rc == 0) {
// BUGFIX: 将用户消息和 AI 回复持久化到 session 服务
if (g_session_svc) {
const std::string& userContent =
ctx.state.messages[ctx.state.messages.size() - 2].content;
const std::string& aiContent =
ctx.state.messages.back().content;
dstalk_message_t user_msg = {
"user",
userContent.c_str(),
nullptr, nullptr
};
g_session_svc->add(&user_msg);
dstalk_message_t ai_msg = {
"assistant",
aiContent.c_str(),
nullptr, nullptr
};
g_session_svc->add(&ai_msg);
}
} else {
if (!ctx.state.messages.empty() && if (!ctx.state.messages.empty() &&
ctx.state.messages.back().role == ChatMessage::ASSISTANT) { ctx.state.messages.back().role == ChatMessage::ASSISTANT) {
if (ctx.state.messages.back().content.empty()) { if (ctx.state.messages.back().content.empty()) {

View File

@@ -0,0 +1,22 @@
// ============================================================================
// lsp_internal.hpp — 内部声明:供单元测试访问的 LSP 工具函数
// ============================================================================
// 仅在 tests 中使用;非 plugin 公共 API
// ============================================================================
#ifndef LSP_INTERNAL_HPP
#define LSP_INTERNAL_HPP
#include <string>
#include <string_view>
// ---- 字符串 trim ----
std::string_view lsp_trim(std::string_view sv);
// ---- 构建 LSP frame (Content-Length header + body) ----
std::string lsp_frame_message(const std::string& body);
// ---- 解析 Content-Length header ----
int lsp_parse_content_length(const std::string& line);
#endif // LSP_INTERNAL_HPP

View File

@@ -12,6 +12,7 @@
#include "dstalk/dstalk_host.h" #include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h" #include "dstalk/dstalk_services.h"
#include "lsp_internal.hpp"
#include <boost/json.hpp> #include <boost/json.hpp>
#include <boost/json/src.hpp> #include <boost/json/src.hpp>
@@ -311,7 +312,7 @@ static LspState g_lsp;
// ============================================================================ // ============================================================================
// 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view. // 去除 string_view 首尾空白 / Trim leading and trailing whitespace from a string_view.
static std::string_view trim(std::string_view sv) { std::string_view lsp_trim(std::string_view sv) {
while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' || while (!sv.empty() && (sv.front() == ' ' || sv.front() == '\t' ||
sv.front() == '\r' || sv.front() == '\n')) sv.front() == '\r' || sv.front() == '\n'))
sv.remove_prefix(1); sv.remove_prefix(1);
@@ -322,7 +323,7 @@ static std::string_view trim(std::string_view sv) {
} }
// 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n). // 将 JSON-RPC 消息体包装在 LSP 头中 (Content-Length: ...\r\n\r\n) / Wrap a JSON-RPC message body in an LSP header (Content-Length: ...\r\n\r\n).
static std::string frame_message(const std::string& body) { std::string lsp_frame_message(const std::string& body) {
std::string frame; std::string frame;
frame.reserve(64 + body.size()); frame.reserve(64 + body.size());
frame += "Content-Length: "; frame += "Content-Length: ";
@@ -333,8 +334,8 @@ static std::string frame_message(const std::string& body) {
} }
// 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure. // 从 LSP 头行中解析 Content-Length 值。解析失败返回 -1 / Parse the Content-Length value from an LSP header line. Returns -1 on parse failure.
static int parse_content_length(const std::string& line) { int lsp_parse_content_length(const std::string& line) {
auto sv = trim(std::string_view(line)); auto sv = lsp_trim(std::string_view(line));
const char prefix[] = "Content-Length:"; const char prefix[] = "Content-Length:";
const size_t prefix_len = sizeof(prefix) - 1; const size_t prefix_len = sizeof(prefix) - 1;
@@ -368,7 +369,7 @@ static int send_request(const std::string& method, const json::object& params) {
msg["params"] = params; msg["params"] = params;
std::string body = json::serialize(msg); std::string body = json::serialize(msg);
g_lsp.proc.write(frame_message(body)); g_lsp.proc.write(lsp_frame_message(body));
return id; return id;
} }
@@ -380,7 +381,7 @@ static void send_notification(const std::string& method, const json::object& par
msg["params"] = params; msg["params"] = params;
std::string body = json::serialize(msg); std::string body = json::serialize(msg);
g_lsp.proc.write(frame_message(body)); g_lsp.proc.write(lsp_frame_message(body));
} }
// ============================================================================ // ============================================================================
@@ -457,11 +458,11 @@ static void reader_loop() {
} }
// header 块以空行结束 / header block ends with empty line // header 块以空行结束 / header block ends with empty line
auto sv = trim(std::string_view(line)); auto sv = lsp_trim(std::string_view(line));
if (sv.empty()) break; if (sv.empty()) break;
// 累积 Content-Length遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line // 累积 Content-Length遇到其他 header 不丢弃,继续读取下一行 / Accumulate Content-Length; don't discard other headers, continue reading next line
int len = parse_content_length(line); int len = lsp_parse_content_length(line);
if (len >= 0) content_length = len; if (len >= 0) content_length = len;
} }

View File

@@ -35,6 +35,54 @@ namespace asio = boost::asio;
namespace ssl = boost::asio::ssl; namespace ssl = boost::asio::ssl;
using tcp = asio::ip::tcp; using tcp = asio::ip::tcp;
// ============================================================
// 安全常量和输入验证辅助函数 / Security constants and input-validation helpers
// ============================================================
static constexpr size_t MAX_HEADER_KEY_LENGTH = 256;
static constexpr size_t MAX_HEADER_VALUE_LENGTH = 8192;
/// 如果字符串包含任何控制字符(< 0x20 或 0x7F DEL返回 true / Return true if the string contains any control character (< 0x20 or 0x7F DEL).
static bool contains_control_chars(const char* s) {
if (!s) return false;
for (const char* p = s; *p; ++p) {
unsigned char c = static_cast<unsigned char>(*p);
if (c < 0x20u || c == 0x7Fu) return true;
}
return false;
}
/// 基本端口/服务名验证。拒绝空值,对数字端口进行范围检查,
/// 对符号服务名要求字母数字加连字符RFC 6335/ Basic port / service-name validation.
/// Rejects empty, bounds-checks numeric ports, and requires alphanumeric+hyphen
/// for symbolic service names (RFC 6335).
static bool is_valid_port(const char* port) {
if (!port || !*port) return false;
bool all_digits = true;
for (const char* p = port; *p; ++p) {
if (static_cast<unsigned char>(*p) < '0' ||
static_cast<unsigned char>(*p) > '9') {
all_digits = false;
break;
}
}
if (all_digits) {
if (std::strlen(port) > 5) return false; // > 65535
long p = std::atol(port);
return p > 0 && p <= 65535;
}
// 服务名字母数字加连字符RFC 6335最长15个字符 / Service name: alphanumeric plus hyphen (RFC 6335); max 15 chars
for (const char* p = port; *p; ++p) {
unsigned char c = static_cast<unsigned char>(*p);
if (!((c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '-')) {
return false;
}
}
return std::strlen(port) <= 15;
}
// ============================================================ // ============================================================
// 全局状态 / Global state // 全局状态 / Global state
// ============================================================ // ============================================================
@@ -42,8 +90,11 @@ static const dstalk_host_api_t* g_host = nullptr;
static dstalk_config_service_t* g_config_svc = nullptr; static dstalk_config_service_t* g_config_svc = nullptr;
// ============================================================ // ============================================================
// 极简 JSON 头解析器 / Minimal JSON header parser // 极简 JSON 头解析器(含安全长度限制)/ Minimal JSON header parser (with security length limits)
// 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map // 将 {"key1":"value1","key2":"value2"} 解析到 unordered_map / Parses {"key1":"value1","key2":"value2"} into unordered_map
// 强制 MAX_HEADER_KEY_LENGTH (256) 和 MAX_HEADER_VALUE_LENGTH (8192) 限制,
// 防止恶意输入导致资源耗尽 / Enforces MAX_HEADER_KEY_LENGTH (256) and MAX_HEADER_VALUE_LENGTH
// (8192) to prevent resource exhaustion from malicious input.
// ============================================================ // ============================================================
// 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map. // 将扁平 JSON 对象中的字符串键值对解析到 unordered_map / Parse a flat JSON object of string key-value pairs into an unordered_map.
static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) { static std::unordered_map<std::string, std::string> parse_headers_json(const char* json) {
@@ -55,31 +106,53 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE; enum { OUTSIDE, IN_KEY, AFTER_KEY, IN_VALUE } state = OUTSIDE;
std::string current_key; std::string current_key;
std::string current_value; std::string current_value;
bool key_too_long = false;
for (size_t i = 0; i < s.size(); ++i) { for (size_t i = 0; i < s.size(); ++i) {
char c = s[i]; char c = s[i];
switch (state) { switch (state) {
case OUTSIDE: case OUTSIDE:
if (c == '"') { state = IN_KEY; current_key.clear(); } if (c == '"') { state = IN_KEY; current_key.clear(); key_too_long = false; }
break; break;
case IN_KEY: case IN_KEY:
if (c == '"') { state = AFTER_KEY; } if (c == '"') { state = AFTER_KEY; }
else if (c == '\\' && i + 1 < s.size()) { current_key += s[++i]; } else if (c == '\\' && i + 1 < s.size()) {
else { current_key += c; } if (current_key.size() < MAX_HEADER_KEY_LENGTH)
current_key += s[++i];
else { ++i; key_too_long = true; }
}
else {
if (current_key.size() < MAX_HEADER_KEY_LENGTH)
current_key += c;
else
key_too_long = true;
}
break; break;
case AFTER_KEY: case AFTER_KEY:
if (c == ':') { state = IN_VALUE; current_value.clear(); } if (c == ':') { state = IN_VALUE; current_value.clear(); }
// 跳过键和冒号之间的多余字符(例如 ',' 或 '}'/ Skip stray characters between key and colon (e.g. ',' or '}')
else if (c == '"' || c == ',' || c == '}') { /* 保持 AFTER_KEY 状态,忽略 / stay in AFTER_KEY, ignore */ }
break; break;
case IN_VALUE: case IN_VALUE:
if (c == '"') { if (c == '"') {
// 读取到闭合引号 / Read until closing quote // 读取到闭合引号 / Read until closing quote
++i; ++i;
while (i < s.size() && s[i] != '"') { while (i < s.size() && s[i] != '"') {
if (s[i] == '\\' && i + 1 < s.size()) { current_value += s[++i]; } if (s[i] == '\\' && i + 1 < s.size()) {
else { current_value += s[i]; } if (current_value.size() < MAX_HEADER_VALUE_LENGTH)
current_value += s[++i];
else
++i; // 跳过转义字符,值已截断 / skip escaped char, value truncated
}
else {
if (current_value.size() < MAX_HEADER_VALUE_LENGTH)
current_value += s[i];
}
++i; ++i;
} }
if (!key_too_long) {
headers[current_key] = current_value; headers[current_key] = current_value;
}
state = OUTSIDE; state = OUTSIDE;
} }
break; break;
@@ -93,19 +166,76 @@ static std::unordered_map<std::string, std::string> parse_headers_json(const cha
// ============================================================ // ============================================================
struct HttpClientCtx { struct HttpClientCtx {
asio::io_context ioc; asio::io_context ioc;
ssl::context ssl_ctx{ssl::context::tlsv12_client}; ssl::context ssl_ctx{ssl::context::tls_client};
int connect_timeout = 30; int connect_timeout = 30;
int request_timeout = 120; int request_timeout = 120;
HttpClientCtx() { HttpClientCtx() {
ssl_ctx.set_default_verify_paths(); // TLS 1.2+ 协商tls_client 允许 TLS 1.2 和 1.3/ TLS 1.2+ negotiation (tls_client allows TLS 1.2 and 1.3).
// 启用对等证书验证 (CVSS 7.4 修复) / Enable peer certificate verification (CVSS 7.4 fix). // 启用针对系统 CA 存储的对等证书验证。在 Windows 上
// set_default_verify_paths() 加载系统 CA 包;没有 verify_peer // set_default_verify_paths() 可能无法定位系统 CA
// CA 存储不会被查询——任何证书(自签名/过期)都将被接受 / set_default_verify_paths() loads system CA bundle; without verify_peer // 检测到这种情况时尝试回退源 / Enable peer certificate verification against system CA store.
// the CA store is never consulted — any cert (self-signed/expired) is accepted. // On Windows set_default_verify_paths() may not locate system CAs;
// TODO: Windows: set_default_verify_paths() 可能无法定位系统 CA // we detect that case and try fallback sources.
// 如果验证失败,设置 SSL_CERT_FILE 环境变量或捆绑 cacert.pem / Windows: set_default_verify_paths() may not locate system CAs;
// if verification fails, set SSL_CERT_FILE env or bundle a cacert.pem. boost::system::error_code ec;
ssl_ctx.set_default_verify_paths(ec);
if (ec) {
// 主路径失败——按顺序尝试回退源 / Primary path failed — try fallback sources in order
bool loaded = false;
// 回退 1SSL_CERT_FILE / SSL_CERT_DIROpenSSL 内部已查询,
// 但显式 load_verify_file 提供明确的错误码用于报告)/ Fallback 1: SSL_CERT_FILE / SSL_CERT_DIR (already consulted by
// OpenSSL internally, but an explicit load_verify_file gives us
// a clear error code to report).
const char* cert_file = std::getenv("SSL_CERT_FILE");
if (cert_file && *cert_file) {
ssl_ctx.load_verify_file(cert_file, ec);
if (!ec) loaded = true;
}
if (!loaded) {
const char* cert_dir = std::getenv("SSL_CERT_DIR");
if (cert_dir && *cert_dir) {
ssl_ctx.add_verify_path(cert_dir, ec);
if (!ec) loaded = true;
}
}
// 回退 2http.ca_cert_file 配置项 / Fallback 2: http.ca_cert_file config key
if (!loaded && g_config_svc) {
const char* cfg_cert = g_config_svc->get("http.ca_cert_file");
if (cfg_cert && *cfg_cert) {
ssl_ctx.load_verify_file(cfg_cert, ec);
if (!ec) loaded = true;
}
}
// 回退 3捆绑的 cacert.pem相对于常见安装路径 / Fallback 3: bundled cacert.pem relative to common install paths
if (!loaded) {
static const char* kBundlePaths[] = {
"cacert.pem",
"share/cacert.pem",
"../share/cacert.pem",
"certs/cacert.pem",
nullptr
};
for (int pi = 0; kBundlePaths[pi]; ++pi) {
ssl_ctx.load_verify_file(kBundlePaths[pi], ec);
if (!ec) { loaded = true; break; }
}
}
if (!loaded) {
if (g_host) g_host->log(DSTALK_LOG_WARN,
"TLS CA certificates not found. "
"set_default_verify_paths() failed: %s. "
"Set SSL_CERT_FILE=/path/to/cacert.pem or "
"http.ca_cert_file in config. "
"TLS verification will proceed but may fail at handshake.",
ec.message().c_str());
}
}
ssl_ctx.set_verify_mode(ssl::verify_peer); ssl_ctx.set_verify_mode(ssl::verify_peer);
} }
}; };
@@ -132,6 +262,33 @@ static int do_post_stream(
return -1; return -1;
} }
// ---- 输入验证(安全加固)/ Input validation (security hardening) ----
// 拒绝 host 和 target 中的控制字符CRLF 注入防护)/ Reject control characters in host and target (CRLF injection prevention)
if (contains_control_chars(host)) {
if (g_host) g_host->log(DSTALK_LOG_ERROR,
"do_post_stream: host contains control characters");
*response_body = nullptr;
*status_code = -1;
return -1;
}
if (contains_control_chars(target)) {
if (g_host) g_host->log(DSTALK_LOG_ERROR,
"do_post_stream: target contains control characters");
*response_body = nullptr;
*status_code = -1;
return -1;
}
// 验证端口:必须是数字 1-65535 或有效的服务名 / Validate port: must be numeric 1-65535 or a valid service name
if (!is_valid_port(port)) {
if (g_host) g_host->log(DSTALK_LOG_ERROR,
"do_post_stream: invalid port '%s'", port);
*response_body = nullptr;
*status_code = -1;
return -1;
}
// 初始化输出 / Initialize output // 初始化输出 / Initialize output
*response_body = nullptr; *response_body = nullptr;
*status_code = -1; *status_code = -1;

View File

@@ -0,0 +1,46 @@
// ============================================================================
// anthropic_internal.hpp — 内部声明:供单元测试访问的函数与数据结构
// ============================================================================
// 仅在 tests 中使用;非 plugin 公共 API
// ============================================================================
#ifndef ANTHROPIC_INTERNAL_HPP
#define ANTHROPIC_INTERNAL_HPP
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_services.h"
#include "ai_common.hpp"
#include <string>
#include <vector>
// ---- 从 dstalk_ai 命名空间导入共享类型与函数 ----
using dstalk_ai::PluginConfig;
using dstalk_ai::ToolCallAccum;
using dstalk_ai::StreamContext;
using dstalk_ai::secure_zero;
using dstalk_ai::extract_host_port;
// ---- 全局变量 ----
extern PluginConfig g_cfg;
extern std::string g_tools_json;
// ---- 测试用函数声明 ----
bool parse_sse_data(const std::string& data, std::string& token_out,
StreamContext* ctx);
std::string build_request_json(const dstalk_message_t* history, int history_len,
const std::string& user_input,
const std::string& tools_json,
bool stream);
std::string build_headers_json();
void my_free_result(dstalk_chat_result_t* result);
int my_configure(const char* provider, const char* base_url,
const char* api_key, const char* model,
int max_tokens, double temperature);
#endif // ANTHROPIC_INTERNAL_HPP

View File

@@ -212,6 +212,40 @@ target_link_libraries(dstalk_network_plugin_test
add_test(NAME dstalk_network_plugin_test COMMAND dstalk_network_plugin_test) add_test(NAME dstalk_network_plugin_test COMMAND dstalk_network_plugin_test)
# ============================================================
# dstalk_lsp_plugin_test — LSP 插件单元测试 (GoogleTest, 新增)
# 覆盖: lsp_trim / lsp_frame_message / lsp_parse_content_length
# ============================================================
add_executable(dstalk_lsp_plugin_test
lsp_plugin_test.cpp
${CMAKE_SOURCE_DIR}/plugins_base/lsp/src/lsp_plugin.cpp
)
target_include_directories(dstalk_lsp_plugin_test
PRIVATE
${CMAKE_SOURCE_DIR}/dstalk_core/include
${CMAKE_SOURCE_DIR}/plugins_base/lsp/src
)
target_compile_definitions(dstalk_lsp_plugin_test
PRIVATE
BOOST_JSON_HEADER_ONLY
BOOST_ALL_NO_LIB
)
target_compile_features(dstalk_lsp_plugin_test
PRIVATE cxx_std_17
)
target_link_libraries(dstalk_lsp_plugin_test
PRIVATE
dstalk
boost::boost
)
add_test(NAME dstalk_lsp_plugin_test COMMAND dstalk_lsp_plugin_test)
# ============================================================ # ============================================================
# coverage — gcovr 覆盖率报告 (HTML + 终端摘要) # coverage — gcovr 覆盖率报告 (HTML + 终端摘要)
# 用法: cmake --build <dir> --target coverage # 用法: cmake --build <dir> --target coverage

153
tests/lsp_plugin_test.cpp Normal file
View File

@@ -0,0 +1,153 @@
// ============================================================================
// lsp_plugin_test.cpp — LSP 插件单元测试 (轻量 CHECK 宏,离线环境无 GoogleTest)
// ============================================================================
// 测试 LSP 插件的可独立验证功能:
// - lsp_trim: 字符串 trim 逻辑
// - lsp_frame_message: Content-Length header 构建
// - lsp_parse_content_length: Content-Length header 解析
// ============================================================================
#include "lsp_internal.hpp"
#include "dstalk/dstalk_host.h"
#include <cstdarg>
#include <cstring>
#include <iostream>
#include <string>
static int g_failures = 0;
// Lightweight assertion macros (matches project pattern used by other tests)
#define CHECK(cond, msg) do { \
if (!(cond)) { \
std::cerr << "[FAIL] " << (msg) << "\n"; \
++g_failures; \
} \
} while (0)
#define CHECK_EQ(actual, expected, msg) do { \
auto _a = (actual); \
auto _e = (expected); \
if (!(_a == _e)) { \
std::cerr << "[FAIL] " << (msg) \
<< " (got=" << _a << " expected=" << _e << ")\n"; \
++g_failures; \
} \
} while (0)
// ----------------------------------------------------------------------------
// lsp_trim
// ----------------------------------------------------------------------------
static void test_lsp_trim() {
CHECK_EQ(lsp_trim(""), std::string(""), "trim empty string");
CHECK_EQ(lsp_trim("hello"), std::string("hello"), "trim no whitespace");
CHECK_EQ(lsp_trim(" hello"), std::string("hello"), "trim leading spaces");
CHECK_EQ(lsp_trim("hello "), std::string("hello"), "trim trailing spaces");
CHECK_EQ(lsp_trim(" hello "), std::string("hello"), "trim both sides");
CHECK_EQ(lsp_trim("\t\n\rhello\t\n\r"), std::string("hello"), "trim tabs/newlines");
CHECK_EQ(lsp_trim(" \t\n\r "), std::string(""), "trim only whitespace");
CHECK_EQ(lsp_trim("A"), std::string("A"), "trim single char");
CHECK_EQ(lsp_trim(" hello world "), std::string("hello world"),
"trim preserves internal whitespace");
}
// ----------------------------------------------------------------------------
// lsp_frame_message
// ----------------------------------------------------------------------------
static void test_lsp_frame_message() {
{
std::string frame = lsp_frame_message("");
CHECK(frame.find("Content-Length: 0") != std::string::npos,
"frame empty body has Content-Length: 0");
CHECK(frame.find("\r\n\r\n") != std::string::npos,
"frame empty body has header separator");
CHECK(frame.find("\r\n\r\n") + 4 == frame.size(),
"frame empty body ends right after separator");
}
{
std::string body = "{\"jsonrpc\":\"2.0\"}";
std::string frame = lsp_frame_message(body);
CHECK(frame.find("Content-Length: " + std::to_string(body.size())) != std::string::npos,
"frame simple body Content-Length matches");
CHECK(frame.find(body) != std::string::npos, "frame simple body present");
}
{
std::string body = "line1\nline2\nline3";
std::string frame = lsp_frame_message(body);
CHECK(frame.find("Content-Length: " + std::to_string(body.size())) != std::string::npos,
"frame multiline body Content-Length matches");
CHECK(frame.find(body) != std::string::npos, "frame multiline body present");
}
{
std::string body(std::string("\x00\x01\x02\x03\xFF", 5));
std::string frame = lsp_frame_message(body);
CHECK(frame.find("Content-Length: 5") != std::string::npos,
"frame binary body Content-Length: 5");
}
}
// ----------------------------------------------------------------------------
// lsp_parse_content_length
// ----------------------------------------------------------------------------
static void test_lsp_parse_content_length() {
CHECK_EQ(lsp_parse_content_length("Content-Length: 1234"), 1234, "parse valid header");
CHECK_EQ(lsp_parse_content_length(" Content-Length: 42"), 42, "parse leading spaces");
CHECK_EQ(lsp_parse_content_length("content-length: 99"), 99, "parse lowercase");
CHECK_EQ(lsp_parse_content_length("CONTENT-LENGTH: 77"), 77, "parse uppercase");
CHECK_EQ(lsp_parse_content_length("Content-Length: 0"), 0, "parse zero length");
CHECK_EQ(lsp_parse_content_length("Content-Length: 1048576"), 1048576, "parse large value");
CHECK_EQ(lsp_parse_content_length(""), -1, "parse empty string");
CHECK_EQ(lsp_parse_content_length("Content-Lengthh: 10"), -1, "parse misspelled (extra h)");
CHECK_EQ(lsp_parse_content_length("ContentLength: 10"), -1, "parse misspelled (no hyphen)");
CHECK_EQ(lsp_parse_content_length("Content-Length 10"), -1, "parse missing colon");
CHECK_EQ(lsp_parse_content_length("Content-Length: abc"), -1, "parse non-numeric");
CHECK_EQ(lsp_parse_content_length("Content-Length: -5"), -5, "parse negative");
CHECK_EQ(lsp_parse_content_length("Content-Length: 999999999999"), -1, "parse overflow");
CHECK_EQ(lsp_parse_content_length(std::string("\x00\x01\xFF", 3)), -1, "parse garbage input");
CHECK_EQ(lsp_parse_content_length("Content-Length: 1234abc"), 1234, "parse trailing garbage");
CHECK_EQ(lsp_parse_content_length("Content-Type: application/vscode-jsonrpc"), -1,
"parse other LSP header returns -1");
}
// ----------------------------------------------------------------------------
// frame + parse round-trip
// ----------------------------------------------------------------------------
static void test_round_trip() {
{
std::string body = "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\"}";
std::string frame = lsp_frame_message(body);
size_t header_end = frame.find("\r\n\r\n");
CHECK(header_end != std::string::npos, "round-trip simple finds header end");
if (header_end != std::string::npos) {
std::string header_block = frame.substr(0, header_end);
CHECK_EQ(lsp_parse_content_length(header_block), static_cast<int>(body.size()),
"round-trip simple Content-Length matches body size");
}
}
{
std::string body;
std::string frame = lsp_frame_message(body);
size_t header_end = frame.find("\r\n\r\n");
CHECK(header_end != std::string::npos, "round-trip empty finds header end");
if (header_end != std::string::npos) {
std::string header_block = frame.substr(0, header_end);
CHECK_EQ(lsp_parse_content_length(header_block), 0,
"round-trip empty Content-Length is 0");
}
}
}
int main() {
test_lsp_trim();
test_lsp_frame_message();
test_lsp_parse_content_length();
test_round_trip();
if (g_failures == 0) {
std::cout << "lsp_plugin_test: all checks passed\n";
return 0;
}
std::cerr << "lsp_plugin_test: " << g_failures << " check(s) failed\n";
return 1;
}