/* * @file session_plugin.cpp * @brief Session plugin: conversation message history management with save/load. * 会话插件:对话消息历史管理,支持保存/加载。 * Copyright (c) 2026 dstalk contributors. GPLv3. */ // plugin-session: 会话管理服务插件 / Session management service plugin // 提供 dstalk_session_service_t vtable 实现 / Provides dstalk_session_service_t vtable implementation // 依赖: file_io (save/load 需要文件操作) / Depends on: file_io (save/load needs file operations) #include "dstalk/dstalk_host.h" #include "dstalk/dstalk_types.h" #include "dstalk/dstalk_services.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace json = boost::json; // ============================================================ // 内部 C++ 数据结构 / Internal C++ data structures // ============================================================ // W14.3: g_host / g_file_io 使用 atomic 指针,写入 acquire/release,读取无锁 / g_host / g_file_io use atomic pointers, write with acquire/release, read lock-free static std::atomic g_host{nullptr}; static std::atomic g_file_io{nullptr}; // 内部消息结构(C++ 易用,外部暴露 C struct) / Internal message struct (C++ friendly, externally exposed as C struct) struct InternalMessage { std::string role; std::string content; std::string tool_call_id; std::string tool_calls_json; }; // 会话历史 + 缓存 —— W14.3: mutex 保护读写 / Session history + cache — W14.3: mutex protects read/write static std::vector g_history; static std::vector g_cached_history; static std::mutex g_session_mutex; // ============================================================ // Token 计数工具(内联,避免硬依赖 context 头文件) / Token counting utilities (inline, avoids hard dep on context headers) // ============================================================ // 如果字节是 ASCII (0x00–0x7F) 则返回 true / Returns true if the byte is ASCII (0x00–0x7F). static bool is_ascii(unsigned char c) { return c < 0x80; } // 启发式判断:如果字节起始一个 UTF-8 CJK 统一表意文字 (0xE4–0xE9) 则返回 true / Heuristic: returns true if the byte starts a CJK Unified Ideograph in UTF-8 (0xE4–0xE9). static bool starts_cjk(unsigned char c) { return c >= 0xE4 && c <= 0xE9; } // 使用启发式 UTF-8 字节计数估算单条消息的 token 数 / Estimate token count for a single message using heuristic UTF-8 byte counting. static size_t count_tokens_one(const std::string& text) { size_t ascii_chars = 0; size_t chinese_chars = 0; size_t other_chars = 0; size_t i = 0; while (i < text.size()) { unsigned char c = static_cast(text[i]); if (is_ascii(c)) { ascii_chars++; i += 1; } else if (starts_cjk(c)) { chinese_chars++; i += 3; } else if (c >= 0xC0 && c < 0xE0) { other_chars++; i += 2; } else if (c >= 0xE0 && c < 0xF0) { other_chars++; i += 3; } else if (c >= 0xF0 && c < 0xF8) { other_chars++; i += 4; } else { other_chars++; i += 1; } } size_t content_tokens = (ascii_chars / 4) + (chinese_chars / 2) + (other_chars / 3); return content_tokens + 4; // +4 每条消息开销 / +4 per message overhead } // 估算所有消息的总 token 数 / Estimate total token count across all messages. static size_t count_tokens_all(const std::vector& msgs) { size_t total = 0; for (const auto& m : msgs) { total += count_tokens_one(m.content); } return total; } // ============================================================ // 辅助:刷新 C 缓存数组(调用方需持有 g_session_mutex) / Helper: rebuild C cached array (caller must hold g_session_mutex) // ============================================================ // 从内部消息 vector 重建 C 兼容的缓存历史数组。调用方必须持有 g_session_mutex / Rebuild the C-compatible cached history array from the internal message vector. // Caller must hold g_session_mutex. static void rebuild_cached_history_locked() { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); // 释放旧的字符串 / Free old strings for (auto& m : g_cached_history) { if (m.role) { host->free(const_cast(m.role)); } if (m.content) { host->free(const_cast(m.content)); } if (m.tool_call_id) { host->free(const_cast(m.tool_call_id)); } if (m.tool_calls_json){ host->free(const_cast(m.tool_calls_json)); } } g_cached_history.clear(); // 重建 / Rebuild g_cached_history.reserve(g_history.size()); for (const auto& im : g_history) { dstalk_message_t cm; cm.role = im.role.empty() ? nullptr : host->strdup(im.role.c_str()); cm.content = im.content.empty() ? nullptr : host->strdup(im.content.c_str()); cm.tool_call_id = im.tool_call_id.empty() ? nullptr : host->strdup(im.tool_call_id.c_str()); cm.tool_calls_json = im.tool_calls_json.empty() ? nullptr : host->strdup(im.tool_calls_json.c_str()); g_cached_history.push_back(cm); } } // ============================================================ // Session 服务 vtable 实现 (W14.3: try/catch + mutex) / Session service vtable implementation (W14.3: try/catch + mutex) // ============================================================ // 向对话历史追加一条消息 / Append a message to the conversation history. static void session_add(const dstalk_message_t* msg) { try { if (!msg) return; InternalMessage im; if (msg->role) im.role = msg->role; if (msg->content) im.content = msg->content; if (msg->tool_call_id) im.tool_call_id = msg->tool_call_id; if (msg->tool_calls_json) im.tool_calls_json = msg->tool_calls_json; std::lock_guard lock(g_session_mutex); g_history.push_back(std::move(im)); } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_add: %s", e.what()); } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_add: unknown exception"); } } // 清空对话历史中的所有消息 / Clear all messages from the conversation history. static void session_clear() { std::lock_guard lock(g_session_mutex); g_history.clear(); } // 将当前对话历史序列化为 JSON 行文件并保存到 path / Serialize the current conversation history to a JSON lines file at `path`. static int session_save(const char* path) { try { if (!path) return -1; const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire); if (!fio) return -1; std::string data; { std::lock_guard lock(g_session_mutex); for (const auto& m : g_history) { json::object entry; entry["role"] = m.role; entry["content"] = m.content; if (!m.tool_call_id.empty()) entry["tool_call_id"] = m.tool_call_id; if (!m.tool_calls_json.empty()) entry["tool_calls_json"] = m.tool_calls_json; data += json::serialize(entry); data += '\n'; } } return fio->write(path, data.c_str()); } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_save: %s", e.what()); return -1; } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_save: unknown exception"); return -1; } } // 从 JSON 行文件中加载对话历史,替换当前历史 / Load conversation history from a JSON lines file at `path`, replacing current history. static int session_load(const char* path) { try { if (!path) return -1; const dstalk_file_io_service_t* fio = g_file_io.load(std::memory_order_acquire); if (!fio) return -1; char* content = nullptr; int ret = fio->read(path, &content); if (ret != 0 || !content) return -1; const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); std::string data(content); host->free(content); std::vector parsed; size_t pos = 0; while (pos < data.size()) { size_t nl = data.find('\n', pos); std::string line = (nl != std::string::npos) ? data.substr(pos, nl - pos) : data.substr(pos); pos = (nl != std::string::npos) ? nl + 1 : data.size(); if (line.empty()) continue; auto obj = json::parse(line).as_object(); auto* role_j = obj.if_contains("role"); auto* content_j = obj.if_contains("content"); if (role_j && content_j && role_j->is_string() && content_j->is_string()) { InternalMessage im; im.role = json::value_to(*role_j); im.content = json::value_to(*content_j); auto* tci = obj.if_contains("tool_call_id"); if (tci && tci->is_string()) im.tool_call_id = json::value_to(*tci); auto* tcj = obj.if_contains("tool_calls_json"); if (tcj && tcj->is_string()) im.tool_calls_json = json::value_to(*tcj); parsed.push_back(std::move(im)); } } if (parsed.empty()) return -1; { std::lock_guard lock(g_session_mutex); g_history = std::move(parsed); } return 0; } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_load: %s", e.what()); return -1; } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_load: unknown exception"); return -1; } } // 返回指向缓存 C 消息数组的指针,并将 *out_count 设置为数组大小 / Return a pointer to the cached C-array of messages and set *out_count to its size. static const dstalk_message_t* session_history(int* out_count) { try { std::lock_guard lock(g_session_mutex); rebuild_cached_history_locked(); if (out_count) *out_count = static_cast(g_cached_history.size()); return g_cached_history.empty() ? nullptr : g_cached_history.data(); } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_history: %s", e.what()); if (out_count) *out_count = 0; return nullptr; } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_history: unknown exception"); if (out_count) *out_count = 0; return nullptr; } } // 返回当前对话历史的估算 token 数 / Return the estimated token count for the current conversation history. static int session_token_count() { try { std::lock_guard lock(g_session_mutex); return static_cast(count_tokens_all(g_history)); } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: %s", e.what()); return -1; } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "session_token_count: unknown exception"); return -1; } } static dstalk_session_service_t g_session_service = { session_add, session_clear, session_save, session_load, session_history, session_token_count }; // ============================================================ // W20.6: 默认会话保存路径(平台标准目录) / Default session save path (platform standard directory) // ============================================================ // 返回平台特定的默认会话保存路径,根据需要创建目录 / Return the platform-specific default session save path, creating directories as needed. static std::string get_default_session_path() { // W22.5: static 缓存 + mkdir 保障 + 失败 fallback 到当前目录 / static cache + mkdir guarantee + fallback to current dir on failure static std::string cached_path = []() -> std::string { #ifdef _WIN32 char* buf = nullptr; size_t len = 0; _dupenv_s(&buf, &len, "APPDATA"); std::string dir = buf ? std::string(buf) + "/dstalk" : "dstalk"; free(buf); #else const char* home = std::getenv("HOME"); std::string dir = home ? std::string(home) + "/.dstalk" : "/tmp/dstalk"; #endif std::error_code ec; std::filesystem::create_directories(dir, ec); if (ec) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_WARN, "get_default_session_path: cannot mkdir '%s' (%s), fallback to .", dir.c_str(), ec.message().c_str()); return std::string("./session.json"); } return dir + "/session.json"; }(); return cached_path; } // ============================================================ // 插件生命周期 / Plugin lifecycle // ============================================================ // 插件初始化:保存主机指针,查询 file_io 依赖,注册 session 服务, // 并从默认路径自动加载已有会话 / Plugin init: store host pointer, query file_io dependency, register session service, // and auto-load any existing session from the default path. static int on_init(const dstalk_host_api_t* host) { try { g_host.store(host, std::memory_order_release); // 查询依赖服务: file_io / Query dependency service: file_io void* raw = host->query_service("file_io", 1); if (!raw) { host->log(DSTALK_LOG_ERROR, "[plugin-session] required service 'file_io' not found"); return -1; } g_file_io.store(static_cast(raw), std::memory_order_release); // 注册自身服务 / Register own service int ret = host->register_service("session", 1, &g_session_service); if (ret != 0) return ret; // W20.6: 从默认路径恢复会话(文件不存在则静默失败) / Restore session from default path (silent fail if file missing) session_load(get_default_session_path().c_str()); return 0; } catch (const std::exception& e) { const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire); if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: %s", e.what()); return -1; } catch (...) { const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire); if (h) h->log(DSTALK_LOG_ERROR, "on_init[session]: unknown exception"); return -1; } } // 插件关闭:自动保存会话到默认路径,失败时回退到当前目录, // 然后释放缓存历史和清空状态 / Plugin shutdown: auto-save session to default path, fallback to current dir on failure, // then release cached history and clear state. static void on_shutdown() { try { // W20.6: 清空前自动保存到默认路径 / Auto-save to default path before clearing // W21.4: 失败告警 + 当前目录 fallback / Failure warning + current dir fallback int ret = session_save(get_default_session_path().c_str()); if (ret != 0) { const dstalk_host_api_t* h = g_host.load(std::memory_order_acquire); if (h) h->log(DSTALK_LOG_WARN, "on_shutdown[session]: auto-save failed (ret=%d), trying fallback", ret); int fret = session_save("./dstalk_session_backup.json"); if (fret != 0) { if (h) h->log(DSTALK_LOG_ERROR, "on_shutdown[session]: fallback also failed (ret=%d), data may be lost", fret); } } std::lock_guard lock(g_session_mutex); rebuild_cached_history_locked(); g_cached_history.clear(); g_history.clear(); } catch (const std::exception& e) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: %s", e.what()); } catch (...) { const dstalk_host_api_t* host = g_host.load(std::memory_order_acquire); if (host) host->log(DSTALK_LOG_ERROR, "on_shutdown[session]: unknown exception"); } g_file_io.store(nullptr, std::memory_order_release); g_host.store(nullptr, std::memory_order_release); } static dstalk_plugin_info_t g_info = { "session", "1.0.0", "Session management plugin with save/load support / 支持保存/加载的会话管理插件", DSTALK_API_VERSION, {"file_io", nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}, on_init, on_shutdown, nullptr }; // 必须入口点:返回插件描述符给主机 / Mandatory entry point: returns the plugin descriptor to the host. extern "C" DSTALK_PLUGIN_EXPORT dstalk_plugin_info_t* dstalk_plugin_init(void) { return &g_info; }