Files
dstalk/plugins_middle/session/src/session_plugin.cpp
XiuChengWu ba7382db2a feat: add OpenAI-compatible AI provider plugin with SSE streaming support
- Implemented the OpenAI-compatible AI provider plugin, including configuration, chat, and chat_stream functionalities.
- Added support for SSE streaming and tool calls.
- Integrated Boost.JSON for JSON handling.
- Created CMake configuration for the plugin.
- Added error handling and logging throughout the plugin.
2026-05-31 05:37:04 +08:00

430 lines
18 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/*
* @file session_plugin.cpp
* @brief Session plugin: conversation message history management with save/load.
* 会话插件:对话消息历史管理,支持保存/加载。
* Copyright (c) 2026 dstalk contributors. GPLv3.
*/
// plugin-session: 会话管理服务插件 / Session management service plugin
// 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation
// 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations)
#include "dstalk/dstalk_host.h"
#include "dstalk/dstalk_types.h"
#include "dstalk/dstalk_services.h"
#include <boost/json.hpp>
#include <boost/json/src.hpp>
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <filesystem>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
// Shorthand for the Boost.JSON namespace used by session_save / session_load.
namespace json = boost::json;
// ============================================================
// Internal C++ data structures
// ============================================================
// W14.3: g_host / g_file_io are atomic pointers. They are written with
// memory_order_release (in on_init / on_shutdown) and read with
// memory_order_acquire, so readers do not need to take a lock.
// Both are nullptr until on_init runs and are reset to nullptr by on_shutdown.
static std::atomic<const dstalk_host_api_t*> g_host{nullptr};
static std::atomic<const dstalk_file_io_service_t*> g_file_io{nullptr};
// Owning, C++-side representation of one conversation message.
// The plugin stores these internally and converts them to the C struct
// dstalk_message_t only when handing history out to callers.
struct InternalMessage {
    std::string role{};             // copied from dstalk_message_t::role
    std::string content{};          // message body; the only field used for token estimation
    std::string tool_call_id{};     // optional; omitted from the save file when empty
    std::string tool_calls_json{};  // optional; stored and round-tripped as an opaque string
};
// Session history and its C-facing cache — W14.3: g_session_mutex protects reads and writes of both.
// g_history holds the owning C++ messages.
static std::vector<InternalMessage> g_history;
// g_cached_history holds the C structs rebuilt from g_history by
// rebuild_cached_history_locked(); their string fields are host->strdup() copies.
static std::vector<dstalk_message_t> g_cached_history;
static std::mutex g_session_mutex;
// ============================================================
// Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers)
// ============================================================
// Returns true when `c` is a 7-bit ASCII byte (0x00-0x7F), i.e. its high bit is clear.
static bool is_ascii(unsigned char c) {
    return (c & 0x80u) == 0u;
}
// Heuristic: returns true when `c` is a UTF-8 lead byte in the range 0xE4-0xE9,
// which this file treats as the start of a 3-byte CJK Unified Ideograph.
static bool starts_cjk(unsigned char c) {
    constexpr unsigned char kFirstLead = 0xE4;
    constexpr unsigned char kLastLead = 0xE9;
    return !(c < kFirstLead || c > kLastLead);
}
// Estimate the token count of a single message body with a byte-level UTF-8 heuristic.
//
// The text is walked lead byte by lead byte and each character is put in one of
// three buckets: ASCII, CJK (lead byte accepted by starts_cjk) or "other".
// The estimate is ascii/4 + cjk/2 + other/3 (integer division) plus a fixed
// overhead of 4 tokens per message.
//
// Bytes that are not a recognised lead byte (e.g. stray continuation bytes) are
// counted as one "other" character each. A sequence truncated at the end of the
// string simply ends the scan.
static size_t count_tokens_one(const std::string& text) {
    size_t n_ascii = 0;
    size_t n_cjk = 0;
    size_t n_other = 0;

    const size_t len = text.size();
    for (size_t pos = 0; pos < len;) {
        const auto lead = static_cast<unsigned char>(text[pos]);
        size_t step = 1;  // bytes consumed by the current character

        if (is_ascii(lead)) {
            ++n_ascii;
        } else if (starts_cjk(lead)) {
            ++n_cjk;
            step = 3;
        } else {
            ++n_other;
            if (lead >= 0xC0 && lead < 0xE0) {
                step = 2;
            } else if (lead >= 0xE0 && lead < 0xF0) {
                step = 3;
            } else if (lead >= 0xF0 && lead < 0xF8) {
                step = 4;
            }
        }
        pos += step;
    }

    constexpr size_t kPerMessageOverhead = 4;
    return (n_ascii / 4) + (n_cjk / 2) + (n_other / 3) + kPerMessageOverhead;
}
// 估算所有消息的总 token 数 / Estimate total token count across all messages.
static size_t count_tokens_all(const std::vector<InternalMessage>& msgs) {
size_t total = 0;
for (const auto& m : msgs) {
total += count_tokens_one(m.content);
}
return total;
}
// ============================================================
// 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex) / Helper: rebuild the cached C array (caller must hold g_session_mutex)
// ============================================================
// Rebuild g_cached_history (the C-compatible array handed out by session_history)
// from g_history.
//
// Precondition: the caller holds g_session_mutex.
//
// Every non-null string in the old cache is released with host->free(); then each
// InternalMessage is converted to a dstalk_message_t whose string fields are
// host->strdup() copies. Empty strings are exposed as nullptr rather than "".
//
// If no host API is registered (g_host is null) there is no allocator to free or
// duplicate strings with, so the cache is left untouched instead of dereferencing
// a null pointer.
static void rebuild_cached_history_locked() {
    const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
    if (!host) return;

    // Release the strings owned by the previous cache.
    for (auto& m : g_cached_history) {
        if (m.role)            { host->free(const_cast<char*>(m.role)); }
        if (m.content)         { host->free(const_cast<char*>(m.content)); }
        if (m.tool_call_id)    { host->free(const_cast<char*>(m.tool_call_id)); }
        if (m.tool_calls_json) { host->free(const_cast<char*>(m.tool_calls_json)); }
    }
    g_cached_history.clear();

    // Reserve up front so push_back below cannot reallocate (and therefore cannot
    // throw after strings have been duplicated).
    g_cached_history.reserve(g_history.size());
    for (const auto& im : g_history) {
        // Value-initialize so any field not set below is zero, not indeterminate.
        dstalk_message_t cm{};
        cm.role            = im.role.empty()            ? nullptr : host->strdup(im.role.c_str());
        cm.content         = im.content.empty()         ? nullptr : host->strdup(im.content.c_str());
        cm.tool_call_id    = im.tool_call_id.empty()    ? nullptr : host->strdup(im.tool_call_id.c_str());
        cm.tool_calls_json = im.tool_calls_json.empty() ? nullptr : host->strdup(im.tool_calls_json.c_str());
        g_cached_history.push_back(cm);
    }
}
// ============================================================
// Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex)
// ============================================================
// 向对话历史追加一条消息 / Append a message to the conversation history.
static void session_add(const dstalk_message_t* msg) {
try {
if (!msg) return;
InternalMessage im;
if (msg->role) im.role = msg->role;
if (msg->content) im.content = msg->content;
if (msg->tool_call_id) im.tool_call_id = msg->tool_call_id;
if (msg->tool_calls_json) im.tool_calls_json = msg->tool_calls_json;
std::lock_guard<std::mutex> lock(g_session_mutex);
g_history.push_back(std::move(im));
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_add: %s", e.what());
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_add: unknown exception");
}
}
// 清空对话历史中的所有消息 / Clear all messages from the conversation history.
static void session_clear() {
std::lock_guard<std::mutex> lock(g_session_mutex);
g_history.clear();
}
// 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`.
static int session_save(const char* path) {
try {
if (!path) return -1;
const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
if (!fio) return -1;
std::string data;
{
std::lock_guard<std::mutex> lock(g_session_mutex);
for (const auto& m : g_history) {
json::object entry;
entry["role"] = m.role;
entry["content"] = m.content;
if (!m.tool_call_id.empty())
entry["tool_call_id"] = m.tool_call_id;
if (!m.tool_calls_json.empty())
entry["tool_calls_json"] = m.tool_calls_json;
data += json::serialize(entry);
data += '\n';
}
}
return fio->write(path, data.c_str());
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_save: %s", e.what());
return -1;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_save: unknown exception");
return -1;
}
}
// Load the conversation history from the JSON Lines file at `path`, replacing the
// current history.
//
// Each non-empty line must be a JSON object. Objects that have string "role" and
// "content" members become messages (with optional string "tool_call_id" /
// "tool_calls_json"); objects lacking them are skipped. A line that is not valid
// JSON or not an object raises an exception, which aborts the whole load.
//
// The current history is replaced only on success, i.e. when at least one message
// was parsed.
//
// Returns 0 on success and -1 on any failure: null `path`, file_io or host API
// unavailable, read failure, no valid messages, or a caught exception (logged
// through the host).
static int session_load(const char* path) {
    try {
        if (!path) return -1;
        const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire);
        const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
        // The host API is needed to release the buffer returned by read(), so bail
        // out before reading rather than dereference a null host afterwards.
        if (!fio || !host) return -1;

        char* content = nullptr;
        int ret = fio->read(path, &content);
        if (ret != 0 || !content) return -1;

        // Copy into a std::string and release the host buffer on every path,
        // including when the copy itself throws.
        std::string data;
        try {
            data.assign(content);
        } catch (...) {
            host->free(content);
            throw;
        }
        host->free(content);

        std::vector<InternalMessage> parsed;
        size_t pos = 0;
        while (pos < data.size()) {
            size_t nl = data.find('\n', pos);
            std::string line = (nl != std::string::npos)
                ? data.substr(pos, nl - pos) : data.substr(pos);
            pos = (nl != std::string::npos) ? nl + 1 : data.size();
            if (line.empty()) continue;

            auto obj = json::parse(line).as_object();
            auto* role_j = obj.if_contains("role");
            auto* content_j = obj.if_contains("content");
            if (role_j && content_j && role_j->is_string() && content_j->is_string()) {
                InternalMessage im;
                im.role = json::value_to<std::string>(*role_j);
                im.content = json::value_to<std::string>(*content_j);
                auto* tci = obj.if_contains("tool_call_id");
                if (tci && tci->is_string())
                    im.tool_call_id = json::value_to<std::string>(*tci);
                auto* tcj = obj.if_contains("tool_calls_json");
                if (tcj && tcj->is_string())
                    im.tool_calls_json = json::value_to<std::string>(*tcj);
                parsed.push_back(std::move(im));
            }
        }
        if (parsed.empty()) return -1;

        {
            std::lock_guard<std::mutex> lock(g_session_mutex);
            g_history = std::move(parsed);
        }
        return 0;
    } catch (const std::exception& e) {
        const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
        if (host) host->log(DSTALK_LOG_ERROR, "session_load: %s", e.what());
        return -1;
    } catch (...) {
        const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
        if (host) host->log(DSTALK_LOG_ERROR, "session_load: unknown exception");
        return -1;
    }
}
// Rebuild the cached C array from the current history and return a pointer to it.
//
// `out_count` (optional) receives the number of messages. The return value is
// nullptr when the history is empty or when an exception was caught; in the
// exception case `*out_count` is set to 0 and the error is logged through the host.
//
// The returned pointer refers to g_cached_history's storage, which is rebuilt
// (and its strings freed) by the next call to this function.
static const dstalk_message_t* session_history(int* out_count) {
    const auto report_count = [out_count](int n) {
        if (out_count) *out_count = n;
    };

    try {
        std::lock_guard<std::mutex> guard(g_session_mutex);
        rebuild_cached_history_locked();
        report_count(static_cast<int>(g_cached_history.size()));
        if (g_cached_history.empty()) {
            return nullptr;
        }
        return g_cached_history.data();
    } catch (const std::exception& ex) {
        if (const dstalk_host_api_t* api = g_host.load(std::memory_order_acquire)) {
            api->log(DSTALK_LOG_ERROR, "session_history: %s", ex.what());
        }
        report_count(0);
        return nullptr;
    } catch (...) {
        if (const dstalk_host_api_t* api = g_host.load(std::memory_order_acquire)) {
            api->log(DSTALK_LOG_ERROR, "session_history: unknown exception");
        }
        report_count(0);
        return nullptr;
    }
}
// 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history.
static int session_token_count() {
try {
std::lock_guard<std::mutex> lock(g_session_mutex);
return static_cast<int>(count_tokens_all(g_history));
} catch (const std::exception& e) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: %s", e.what());
return -1;
} catch (...) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: unknown exception");
return -1;
}
}
// Service vtable registered with the host under the name "session" (see on_init).
// The initializers are positional, so their order presumably has to match the
// member order of dstalk_session_service_t, which is declared in a dstalk header
// not visible in this file — verify there before reordering.
static dstalk_session_service_t g_session_service = {
// Append one message to the history.
session_add,
// Drop all messages.
session_clear,
// Write the history to a JSON Lines file.
session_save,
// Replace the history from a JSON Lines file.
session_load,
// Expose the history as a C array.
session_history,
// Estimated token count of the history.
session_token_count
};
// ============================================================
// W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory)
// ============================================================
// 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed.
static std::string get_default_session_path() {
// W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure
static std::string cached_path = []() -> std::string {
#ifdef _WIN32
char* buf = nullptr;
size_t len = 0;
_dupenv_s(&buf, &len, "APPDATA");
std::string dir = buf ? std::string(buf) + "/dstalk" : "dstalk";
free(buf);
#else
const char* home = std::getenv("HOME");
std::string dir = home ? std::string(home) + "/.dstalk" : "/tmp/dstalk";
#endif
std::error_code ec;
std::filesystem::create_directories(dir, ec);
if (ec) {
const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
if (host) host->log(DSTALK_LOG_WARN,
"get_default_session_path: cannot mkdir '%s' (%s), fallback to .",
dir.c_str(), ec.message().c_str());
return std::string("./session.json");
}
return dir + "/session.json";
}();
return cached_path;
}
// ============================================================
// 插件生命周期 / Plugin lifecycle
// ============================================================
// Plugin init: store the host pointer, resolve the required "file_io" service,
// register the "session" service, then try to restore a previous session from the
// default path.
//
// Returns 0 on success. Returns -1 when `host` is null, when "file_io" is not
// available, or when an exception is caught; returns the host's error code when
// registering the service fails. A failed session restore (e.g. no saved file)
// is not treated as an init failure.
static int on_init(const dstalk_host_api_t* host) {
    try {
        // Without a host there is nothing to query, register with, or log to.
        if (!host) return -1;
        g_host.store(host, std::memory_order_release);

        // Query the dependency service: file_io
        void* raw = host->query_service("file_io", 1);
        if (!raw) {
            host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found");
            return -1;
        }
        g_file_io.store(static_cast<const dstalk_file_io_service_t*>(raw), std::memory_order_release);

        // Register this plugin's own service
        int ret = host->register_service("session", 1, &g_session_service);
        if (ret != 0) return ret;

        // W20.6: restore the session from the default path; the result is
        // deliberately ignored so a missing file does not fail init.
        session_load(get_default_session_path().c_str());
        return 0;
    } catch (const std::exception& e) {
        const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
        if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: %s", e.what());
        return -1;
    } catch (...) {
        const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
        if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: unknown exception");
        return -1;
    }
}
// Plugin shutdown: auto-save the session, release the cached C history, clear all
// state and drop the host / file_io pointers.
//
// W20.6: the history is saved to the default path before being cleared.
// W21.4: if that save fails, a warning is logged and a second save to
//        "./dstalk_session_backup.json" is attempted; if that fails too an error
//        is logged.
// Exceptions are logged through the host and never propagate.
static void on_shutdown() {
    try {
        int ret = session_save(get_default_session_path().c_str());
        if (ret != 0) {
            const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire);
            if (h) h->log(DSTALK_LOG_WARN, "on_shutdown[session]: auto-save failed (ret=%d), trying fallback", ret);
            int fret = session_save("./dstalk_session_backup.json");
            if (fret != 0) {
                if (h) h->log(DSTALK_LOG_ERROR, "on_shutdown[session]: fallback also failed (ret=%d), data may be lost", fret);
            }
        }

        std::lock_guard<std::mutex> lock(g_session_mutex);
        // Empty the history BEFORE rebuilding the cache: the rebuild then frees the
        // strings owned by the old cache and produces an empty one. Rebuilding while
        // g_history still held messages would strdup a fresh copy of every string,
        // and the clear() below would drop those pointers without freeing them.
        g_history.clear();
        // The rebuild needs the host allocator; skip it when init never ran.
        if (g_host.load(std::memory_order_acquire)) {
            rebuild_cached_history_locked();
        }
        g_cached_history.clear();
    } catch (const std::exception& e) {
        const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
        if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: %s", e.what());
    } catch (...) {
        const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire);
        if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: unknown exception");
    }
    g_file_io.store(nullptr, std::memory_order_release);
    g_host.store(nullptr, std::memory_order_release);
}
// Plugin descriptor returned to the host by dstalk_plugin_init().
// The initializers are positional; the per-field notes below are inferred from the
// values themselves. dstalk_plugin_info_t is declared in a dstalk header not
// visible in this file — confirm field meanings there.
static dstalk_plugin_info_t g_info = {
// Plugin name (presumably).
"session",
// Plugin version string (presumably).
"1.0.0",
// Human-readable description, English / Chinese.
"Session management plugin with save/load support / 支持保存/加载的会话管理插件",
// API version constant supplied by the dstalk headers.
DSTALK_API_VERSION,
// Dependency list: only "file_io" is named; the remaining seven slots are null.
{"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr},
// Lifecycle callbacks defined in this file.
on_init,
on_shutdown,
// NOTE(review): unused trailing slot; its purpose is not visible here.
nullptr
};
// Plugin entry point, exported with C linkage: hands the host a pointer to this
// plugin's static descriptor (g_info).
extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) {
    dstalk_plugin_info_t* const descriptor = &g_info;
    return descriptor;
}