Fix streaming and file IO edge cases

Repair streaming callback/error handling and make file/session handling safer so the core API behaves correctly under real usage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-25 19:26:47 +08:00
parent c9fb924a1c
commit 16475ca3fe
7 changed files with 132 additions and 36 deletions

View File

@@ -29,7 +29,7 @@ DSTALK_API void dstalk_set_model(const char* model);
/* 同步对话: 发送 input返回完整 AI 回复 (调用方通过 dstalk_free_string 释放) */
DSTALK_API int dstalk_chat(const char* input, char** output);
/* 流式对话: 每收到一个 token 调用回调,回调返回 0 可提前取消 */
/* 流式对话: 每收到一个 token 调用回调,回调返回 0 继续,非 0 取消 */
typedef int (*dstalk_stream_cb)(const char* token, void* userdata);
DSTALK_API int dstalk_chat_stream(const char* input, dstalk_stream_cb cb, void* userdata);

View File

@@ -177,9 +177,8 @@ ChatResult DeepSeekClient::chat_stream(
headers["Authorization"] = "Bearer " + impl_->config.api_key;
ChatResult result;
result.ok = true;
impl_->http.post_stream(host, "443", target_path, body, headers,
auto resp = impl_->http.post_stream(host, "443", target_path, body, headers,
[&](const std::string& line) -> bool {
if (line.empty()) return true;
std::string token;
@@ -189,9 +188,36 @@ ChatResult DeepSeekClient::chat_stream(
return on_token ? on_token(token, userdata) : true;
});
result.http_status = resp.status_code;
// 检查传输层错误或非 2xx 状态
if (resp.status_code < 200 || resp.status_code >= 300) {
result.ok = false;
// 尝试从响应 body 提取错误信息(与 parse_response 等同逻辑)
try {
auto jv = json::parse(resp.body);
auto obj = jv.as_object();
if (obj.contains("error")) {
auto err = obj["error"].as_object();
result.error = json::value_to<std::string>(err["message"]);
}
} catch (...) {
}
if (result.error.empty()) {
if (resp.status_code <= 0) {
result.error = "transport error";
} else {
result.error = "HTTP " + std::to_string(resp.status_code);
}
}
return result;
}
if (result.content.empty()) {
result.ok = false;
result.error = "no content received";
} else {
result.ok = true;
}
return result;
}

View File

@@ -3,9 +3,11 @@
#include "file/file_io.hpp"
#include "net/http_client.hpp"
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <utility>
#include <vector>
// ---- 内部状态 ----
@@ -169,29 +171,40 @@ DSTALK_API int dstalk_chat(const char* input, char** output)
return 0;
}
// 流式回调上下文
struct StreamCtx {
std::string* buf;
dstalk_stream_cb cb;
void* ud;
bool cancelled;
};
static bool on_token_proxy(const std::string& token, void* userdata)
{
auto* ctx = static_cast<StreamCtx*>(userdata);
*ctx->buf += token;
int ret = ctx->cb(token.c_str(), ctx->ud);
if (ret == 0) return true;
ctx->cancelled = true;
return false;
}
DSTALK_API int dstalk_chat_stream(const char* input,
dstalk_stream_cb cb, void* userdata)
{
if (!g_initialized || !input || !cb) return -1;
std::string full_reply;
auto result = g_ai.chat_stream(g_history, input,
[](const std::string& token, void* ud) -> bool {
auto* buf = static_cast<std::string*>(ud);
*buf += token;
return true;
}, &full_reply);
StreamCtx ctx{&full_reply, cb, userdata, false};
auto result = g_ai.chat_stream(g_history, input, on_token_proxy, &ctx);
if (!result.ok) return -1;
if (!result.ok && !ctx.cancelled) return -1;
// 更新历史
g_history.push_back({"user", input});
g_history.push_back({"assistant", full_reply});
// 手动回调每个 token (简化实现:收集完后再回调)
// 真正的流式需要在 chat_stream 层回调
(void)cb;
(void)userdata;
return 0;
}
@@ -204,6 +217,7 @@ DSTALK_API void dstalk_free_string(char* str)
DSTALK_API void dstalk_session_clear(void)
{
if (!g_initialized) return;
g_history.clear();
}
@@ -213,14 +227,29 @@ DSTALK_API int dstalk_session_save(const char* path)
// 简单格式: 每行 JSON {"role":"...","content":"..."}
std::string data;
for (const auto& m : g_history) {
// 转义基本字符
// 转义 JSON 特殊字符和控制字符
auto escape = [](const std::string& s) -> std::string {
std::string out;
for (char c : s) {
if (c == '"') out += "\\\"";
else if (c == '\\') out += "\\\\";
else if (c == '\n') out += "\\n";
else out += c;
switch (c) {
case '"': out += "\\\""; break;
case '\\': out += "\\\\"; break;
case '\n': out += "\\n"; break;
case '\r': out += "\\r"; break;
case '\t': out += "\\t"; break;
case '\b': out += "\\b"; break;
case '\f': out += "\\f"; break;
default:
if (static_cast<unsigned char>(c) < 0x20) {
char buf[8];
std::snprintf(buf, sizeof(buf), "\\u%04x",
static_cast<unsigned char>(c));
out += buf;
} else {
out += c;
}
break;
}
}
return out;
};
@@ -237,10 +266,11 @@ DSTALK_API int dstalk_session_load(const char* path)
char* content = file_read_all(path, &len);
if (!content) return -1;
g_history.clear();
std::string data(content, len);
std::free(content);
std::vector<dstalk::ai::Message> parsed;
// 逐行解析简化的 JSON
size_t pos = 0;
while (pos < data.size()) {
@@ -267,9 +297,12 @@ DSTALK_API int dstalk_session_load(const char* path)
std::string role = extract("role");
std::string content_val = extract("content");
if (!role.empty() && !content_val.empty()) {
g_history.push_back({role, content_val});
parsed.push_back({role, content_val});
}
}
if (parsed.empty()) return -1;
g_history = std::move(parsed);
return 0;
}
@@ -277,6 +310,7 @@ DSTALK_API int dstalk_session_load(const char* path)
DSTALK_API int dstalk_file_read(const char* path, char** content)
{
if (!g_initialized || !path || !content) return -1;
size_t len = 0;
char* buf = file_read_all(path, &len);
if (!buf) return -1;

View File

@@ -29,12 +29,24 @@ char* file_read_all(const char* path, size_t* out_len)
fseek(f, 0, SEEK_END);
long sz = ftell(f);
fseek(f, 0, SEEK_SET);
if (sz <= 0) {
if (sz < 0) {
fclose(f);
*out_len = 0;
return nullptr;
}
if (sz == 0) {
fclose(f);
char* buf = (char*)std::malloc(1);
if (!buf) {
*out_len = 0;
return nullptr;
}
buf[0] = '\0';
*out_len = 0;
return buf;
}
char* buf = (char*)std::malloc(static_cast<size_t>(sz) + 1);
if (!buf) {
fclose(f);
@@ -55,9 +67,9 @@ int file_write_all(const char* path, const char* content)
FILE* f = nullptr;
#ifdef _WIN32
fopen_s(&f, path, "w");
fopen_s(&f, path, "wb");
#else
f = fopen(path, "w");
f = fopen(path, "wb");
#endif
if (!f) return -1;

View File

@@ -102,27 +102,52 @@ HttpResponse HttpClient::post_stream(
beast::error_code ec;
if (on_line) {
std::string fragment = result.body;
size_t pos = 0;
while (pos < fragment.size()) {
size_t nl = fragment.find('\n', pos);
if (nl == std::string::npos) break;
std::string line = fragment.substr(pos, nl - pos);
if (!line.empty() && line.back() == '\r')
line.pop_back();
if (!on_line(line)) goto done;
pos = nl + 1;
}
if (pos > 0)
fragment = fragment.substr(pos);
size_t processed = result.body.size();
while (!parser.is_done()) {
http::read_some(stream, buffer, parser, ec);
if (ec) break;
std::string chunk = parser.get().body();
if (!chunk.empty()) {
result.body += chunk;
size_t pos = 0;
while (pos < chunk.size()) {
size_t nl = chunk.find('\n', pos);
std::string line = (nl != std::string::npos)
? chunk.substr(pos, nl - pos)
: chunk.substr(pos);
const std::string& full_body = parser.get().body();
if (full_body.size() > processed) {
std::string new_data = full_body.substr(processed);
result.body += new_data;
processed = full_body.size();
fragment += new_data;
pos = 0;
while (pos < fragment.size()) {
size_t nl = fragment.find('\n', pos);
if (nl == std::string::npos) break;
std::string line = fragment.substr(pos, nl - pos);
if (!line.empty() && line.back() == '\r')
line.pop_back();
if (!on_line(line)) goto done;
if (nl == std::string::npos) break;
pos = nl + 1;
}
if (pos > 0)
fragment = fragment.substr(pos);
}
}
if (!fragment.empty()) {
if (fragment.back() == '\r')
fragment.pop_back();
if (!fragment.empty())
on_line(fragment);
}
} else {
while (!parser.is_done()) {
http::read_some(stream, buffer, parser, ec);