W17: extract ai_common shared module + fix anthropic data race + brace bugs

- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum,
  StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result
- Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common
- Fixed anthropic g_config raw pointer → std::atomic (data race)
- Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5)
- Fixed missing closing brace in both plugins' error-body catch block
- Updated test targets: ai_common include path + link, using namespace dstalk_ai
- plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service
- Includes pre-existing uncommitted changes from prior waves

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-05-31 16:58:25 +08:00
parent ba7382db2a
commit 8faa02c3d5
49 changed files with 1062 additions and 413 deletions

View File

@@ -1,8 +1,8 @@
add_library(plugin-context SHARED src/context_plugin.cpp)
add_library(plugin_context SHARED src/context_plugin.cpp)
target_link_libraries(plugin-context PRIVATE dstalk)
target_link_libraries(plugin_context PRIVATE dstalk)
set_target_properties(plugin-context PROPERTIES
set_target_properties(plugin_context PROPERTIES
PREFIX ""
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"

View File

@@ -5,7 +5,7 @@
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-context: 上下文管理服务插件 / Context management service plugin
// plugin_context: 上下文管理服务插件 / Context management service plugin
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
#include "dstalk/dstalk_host.h"
@@ -263,14 +263,21 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
system_tokens, max_tokens);
}
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
// Precompute per-message token counts for non_system_msgs so that the
// trim pass below is O(N) instead of O(N*K) (no re-counting per iteration).
std::vector<size_t> ns_token_counts;
ns_token_counts.reserve(non_system_msgs.size());
for (const auto& msg : non_system_msgs) {
size_t msg_tokens = count_tokens_trim(msg);
if (msg_tokens > max_tokens) {
ns_token_counts.push_back(count_tokens_trim(msg));
}
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
if (ns_token_counts[i] > max_tokens) {
std::fprintf(stderr, "[context] WARNING: single message "
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
"Returning empty list.\n",
msg.role.c_str(), msg_tokens, max_tokens);
non_system_msgs[i].role.c_str(), ns_token_counts[i], max_tokens);
*out = nullptr;
*out_count = 0;
return -1;
@@ -278,31 +285,53 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
}
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
while (!non_system_msgs.empty()) {
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
if (current <= max_tokens) break;
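// Precompute token counts once, then mark removal candidates in a forward pass and rebuild the vector once (no re-counting or per-element erase; typically O(N), but the assistant search can rescan the tail for each unpaired user message)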
{
size_t ns_total = 0;
for (size_t t : ns_token_counts) ns_total += t;
current = system_tokens + ns_total;
// 找第一个 "user" 消息 / Find first "user" message
auto user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") {
++user_it;
}
if (user_it == non_system_msgs.end()) break;
if (current > max_tokens) {
std::vector<bool> keep(non_system_msgs.size(), true);
size_t idx = 0;
while (idx < non_system_msgs.size() && current > max_tokens) {
// 找第一个 "user" 消息 / Find first "user" message
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "user") {
++idx;
}
if (idx >= non_system_msgs.size()) break;
// 找下一个 "assistant" / Find next "assistant"
auto assistant_it = user_it + 1;
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
++assistant_it;
}
size_t user_idx = idx;
++idx;
if (assistant_it == non_system_msgs.end()) {
non_system_msgs.erase(user_it);
} else {
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
non_system_msgs.erase(assistant_it);
user_it = non_system_msgs.begin();
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
// 找下一个 "assistant" / Find next "assistant"
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "assistant") {
++idx;
}
if (idx >= non_system_msgs.size()) {
// 没有配对的 assistant，只移除 user / No paired assistant, remove user only
keep[user_idx] = false;
current -= ns_token_counts[user_idx];
idx = user_idx + 1; // restart search after the removed message
} else {
// 移除 user + assistant 对 / Remove user + assistant pair
keep[user_idx] = false;
keep[idx] = false;
current -= ns_token_counts[user_idx] + ns_token_counts[idx];
++idx;
}
}
// Rebuild non_system_msgs with only kept messages (single O(N) pass)
std::vector<TrimMessage> kept;
kept.reserve(non_system_msgs.size());
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
if (keep[i]) {
kept.push_back(std::move(non_system_msgs[i]));
}
}
non_system_msgs = std::move(kept);
}
}
@@ -310,8 +339,10 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
{
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
if (max_msg_count < 1) max_msg_count = 1;
while (non_system_msgs.size() > max_msg_count) {
non_system_msgs.erase(non_system_msgs.begin());
// O(N) single range-erase instead of O(N²) repeated erase(begin())
if (non_system_msgs.size() > max_msg_count) {
size_t to_remove = non_system_msgs.size() - max_msg_count;
non_system_msgs.erase(non_system_msgs.begin(), non_system_msgs.begin() + to_remove);
}
}
@@ -397,17 +428,17 @@ static int on_init(const dstalk_host_api_t* host) {
// 查询依赖服务: session / Query dependency service: session
void* raw = host->query_service("session", 1);
if (!raw) {
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
host->log(DSTALK_LOG_ERROR, "[plugin_context] required service 'session' not found");
return -1;
}
g_session = static_cast<const dstalk_session_service_t*>(raw);
return host->register_service("context", 1, &g_context_service);
} catch (const std::exception& e) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init exception: %s", e.what());
return -1;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init unknown exception");
return -1;
}
}
@@ -419,11 +450,11 @@ static void on_shutdown() {
g_session = nullptr;
g_host = nullptr;
} catch (const std::exception& e) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: %s", e.what());
g_session = nullptr;
g_host = nullptr;
} catch (...) {
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: unknown exception");
g_session = nullptr;
g_host = nullptr;
}