W17: extract ai_common shared module + fix anthropic data race + brace bugs
- New plugins_upper/ai_common/ static library: shared PluginConfig, ToolCallAccum, StreamContext, secure_zero, extract_host_port, serialize_tool_calls, free_chat_result - Refactored openai/anthropic plugins to use dstalk_ai:: namespace from ai_common - Fixed anthropic g_config raw pointer → std::atomic (data race) - Added SSE parse error counter with threshold abort (kMaxSseParseErrors=5) - Fixed missing closing brace in both plugins' error-body catch block - Updated test targets: ai_common include path + link, using namespace dstalk_ai - plugin_loader_test: added stub_unreg + service_registry.cpp for unregister_service - Includes pre-existing uncommitted changes from prior waves Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
add_library(plugin-context SHARED src/context_plugin.cpp)
|
||||
add_library(plugin_context SHARED src/context_plugin.cpp)
|
||||
|
||||
target_link_libraries(plugin-context PRIVATE dstalk)
|
||||
target_link_libraries(plugin_context PRIVATE dstalk)
|
||||
|
||||
set_target_properties(plugin-context PROPERTIES
|
||||
set_target_properties(plugin_context PROPERTIES
|
||||
PREFIX ""
|
||||
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/plugins"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
* Copyright (c) 2026 dstalk contributors. GPLv3.
|
||||
*/
|
||||
|
||||
// plugin-context: 上下文管理服务插件 / Context management service plugin
|
||||
// plugin_context: 上下文管理服务插件 / Context management service plugin
|
||||
// 提供 dstalk_context_service_t vtable 实现 / Provides dstalk_context_service_t vtable implementation
|
||||
// 依赖: session (获取历史消息做 token 计数) / Depends on: session (get history messages for token counting)
|
||||
#include "dstalk/dstalk_host.h"
|
||||
@@ -263,14 +263,21 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
system_tokens, max_tokens);
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
// Precompute per-message token counts for non_system_msgs so that the
|
||||
// trim pass below is O(N) instead of O(N*K) (no re-counting per iteration).
|
||||
std::vector<size_t> ns_token_counts;
|
||||
ns_token_counts.reserve(non_system_msgs.size());
|
||||
for (const auto& msg : non_system_msgs) {
|
||||
size_t msg_tokens = count_tokens_trim(msg);
|
||||
if (msg_tokens > max_tokens) {
|
||||
ns_token_counts.push_back(count_tokens_trim(msg));
|
||||
}
|
||||
|
||||
// 检查是否有单条消息超过限制 / Check if any single message exceeds the limit
|
||||
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
|
||||
if (ns_token_counts[i] > max_tokens) {
|
||||
std::fprintf(stderr, "[context] WARNING: single message "
|
||||
"(%s, %zu tokens) exceeds max_context_tokens (%zu). "
|
||||
"Returning empty list.\n",
|
||||
msg.role.c_str(), msg_tokens, max_tokens);
|
||||
non_system_msgs[i].role.c_str(), ns_token_counts[i], max_tokens);
|
||||
*out = nullptr;
|
||||
*out_count = 0;
|
||||
return -1;
|
||||
@@ -278,31 +285,53 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
}
|
||||
|
||||
// 从最早的非 system 消息开始裁剪,确保 user/assistant 成对移除 / Trim from earliest non-system messages, ensuring user/assistant pairs are removed together
|
||||
while (!non_system_msgs.empty()) {
|
||||
current = system_tokens + count_tokens_trim_vec(non_system_msgs);
|
||||
if (current <= max_tokens) break;
|
||||
// O(N): precompute token counts once, then mark removal candidates in a single forward pass
|
||||
{
|
||||
size_t ns_total = 0;
|
||||
for (size_t t : ns_token_counts) ns_total += t;
|
||||
current = system_tokens + ns_total;
|
||||
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
auto user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") {
|
||||
++user_it;
|
||||
}
|
||||
if (user_it == non_system_msgs.end()) break;
|
||||
if (current > max_tokens) {
|
||||
std::vector<bool> keep(non_system_msgs.size(), true);
|
||||
size_t idx = 0;
|
||||
while (idx < non_system_msgs.size() && current > max_tokens) {
|
||||
// 找第一个 "user" 消息 / Find first "user" message
|
||||
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "user") {
|
||||
++idx;
|
||||
}
|
||||
if (idx >= non_system_msgs.size()) break;
|
||||
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
auto assistant_it = user_it + 1;
|
||||
while (assistant_it != non_system_msgs.end() && assistant_it->role != "assistant") {
|
||||
++assistant_it;
|
||||
}
|
||||
size_t user_idx = idx;
|
||||
++idx;
|
||||
|
||||
if (assistant_it == non_system_msgs.end()) {
|
||||
non_system_msgs.erase(user_it);
|
||||
} else {
|
||||
// 先删 assistant 再删 user 避免迭代器失效 / Delete assistant first then user to avoid iterator invalidation
|
||||
non_system_msgs.erase(assistant_it);
|
||||
user_it = non_system_msgs.begin();
|
||||
while (user_it != non_system_msgs.end() && user_it->role != "user") ++user_it;
|
||||
if (user_it != non_system_msgs.end()) non_system_msgs.erase(user_it);
|
||||
// 找下一个 "assistant" / Find next "assistant"
|
||||
while (idx < non_system_msgs.size() && non_system_msgs[idx].role != "assistant") {
|
||||
++idx;
|
||||
}
|
||||
|
||||
if (idx >= non_system_msgs.size()) {
|
||||
// 没有配对的 assistant,只移除 user / No paired assistant, remove user only
|
||||
keep[user_idx] = false;
|
||||
current -= ns_token_counts[user_idx];
|
||||
idx = user_idx + 1; // restart search after the removed message
|
||||
} else {
|
||||
// 移除 user + assistant 对 / Remove user + assistant pair
|
||||
keep[user_idx] = false;
|
||||
keep[idx] = false;
|
||||
current -= ns_token_counts[user_idx] + ns_token_counts[idx];
|
||||
++idx;
|
||||
}
|
||||
}
|
||||
|
||||
// Rebuild non_system_msgs with only kept messages (single O(N) pass)
|
||||
std::vector<TrimMessage> kept;
|
||||
kept.reserve(non_system_msgs.size());
|
||||
for (size_t i = 0; i < non_system_msgs.size(); ++i) {
|
||||
if (keep[i]) {
|
||||
kept.push_back(std::move(non_system_msgs[i]));
|
||||
}
|
||||
}
|
||||
non_system_msgs = std::move(kept);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -310,8 +339,10 @@ static int trim_impl(const dstalk_message_t* in, int in_count,
|
||||
{
|
||||
size_t max_msg_count = (max_tokens + 99) / 100; // ceil(max_tokens / 100)
|
||||
if (max_msg_count < 1) max_msg_count = 1;
|
||||
while (non_system_msgs.size() > max_msg_count) {
|
||||
non_system_msgs.erase(non_system_msgs.begin());
|
||||
// O(N) single range-erase instead of O(N²) repeated erase(begin())
|
||||
if (non_system_msgs.size() > max_msg_count) {
|
||||
size_t to_remove = non_system_msgs.size() - max_msg_count;
|
||||
non_system_msgs.erase(non_system_msgs.begin(), non_system_msgs.begin() + to_remove);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,17 +428,17 @@ static int on_init(const dstalk_host_api_t* host) {
|
||||
// 查询依赖服务: session / Query dependency service: session
|
||||
void* raw = host->query_service("session", 1);
|
||||
if (!raw) {
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin-context] required service 'session' not found");
|
||||
host->log(DSTALK_LOG_ERROR, "[plugin_context] required service 'session' not found");
|
||||
return -1;
|
||||
}
|
||||
g_session = static_cast<const dstalk_session_service_t*>(raw);
|
||||
|
||||
return host->register_service("context", 1, &g_context_service);
|
||||
} catch (const std::exception& e) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init exception: %s", e.what());
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init exception: %s", e.what());
|
||||
return -1;
|
||||
} catch (...) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_init unknown exception");
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_init unknown exception");
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -419,11 +450,11 @@ static void on_shutdown() {
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
} catch (const std::exception& e) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: %s", e.what());
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: %s", e.what());
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
} catch (...) {
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin-context] on_shutdown: unknown exception");
|
||||
if (g_host) g_host->log(DSTALK_LOG_ERROR, "[plugin_context] on_shutdown: unknown exception");
|
||||
g_session = nullptr;
|
||||
g_host = nullptr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user