Files
dstalk/dstalk-core/src/ai/deepseek_api.cpp
XiuChengWu 16475ca3fe Fix streaming and file IO edge cases
Repair streaming callback/error handling and make file/session handling safer so the core API behaves correctly under real usage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-25 19:26:47 +08:00

227 lines
6.5 KiB
C++

#include "ai/deepseek_api.hpp"

#include <cstring>
#include <exception>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <boost/json.hpp>

#include "net/http_client.hpp"
namespace json = boost::json;
namespace dstalk {
namespace ai {
// ---- JSON 构造 ----
static std::string build_request_json(
const ApiConfig& cfg,
const std::vector<Message>& history,
const std::string& user_input,
bool stream)
{
json::object root;
root["model"] = cfg.model;
root["max_tokens"] = cfg.max_tokens;
root["temperature"] = cfg.temperature;
root["stream"] = stream;
json::array msgs;
for (const auto& m : history) {
json::object obj;
obj["role"] = m.role;
obj["content"] = m.content;
msgs.push_back(obj);
}
// 追加当前用户输入
{
json::object obj;
obj["role"] = "user";
obj["content"] = user_input;
msgs.push_back(obj);
}
root["messages"] = msgs;
return json::serialize(root);
}
// ---- JSON response parsing ----

// Converts a non-streaming chat-completions HTTP response into a ChatResult.
//
// @param body         raw response body.
// @param http_status  HTTP status code; values <= 0 are treated as a transport failure
//                     (same convention as chat_stream).
// @return On 2xx: ok == true and content taken from choices[0].message.content, or
//         ok == false with "empty response" / "json parse: <what>" when the body does
//         not have that shape.
//         On anything else: ok == false and error set to the server's error.message
//         when present, otherwise "transport error" (status <= 0) or "HTTP <status>".
//         The error is never left empty.
static ChatResult parse_response(const std::string& body, int http_status)
{
    ChatResult r;
    r.http_status = http_status;

    if (http_status < 200 || http_status >= 300) {
        r.ok = false;
        // Best effort: pull the message out of an {"error": {"message": ...}} body.
        try {
            const json::value jv = json::parse(body);
            if (const json::value* err = jv.as_object().if_contains("error")) {
                r.error = json::value_to<std::string>(err->as_object().at("message"));
            }
        } catch (const std::exception&) {
            // Not JSON, or not the expected shape: fall back to a generic message.
        }
        // A JSON body without an "error" key used to leave the error empty.
        if (r.error.empty()) {
            if (http_status <= 0) {
                r.error = "transport error";
            } else {
                r.error = "HTTP " + std::to_string(http_status);
            }
        }
        return r;
    }

    try {
        const json::value jv = json::parse(body);
        // const refs + at(): no object copies, and no keys inserted while reading.
        const json::array& choices = jv.as_object().at("choices").as_array();
        if (choices.empty()) {
            r.ok = false;
            r.error = "empty response";
            return r;
        }
        const json::object& msg = choices[0].as_object().at("message").as_object();
        r.content = json::value_to<std::string>(msg.at("content"));
        r.ok = true;
    } catch (const std::exception& e) {
        r.ok = false;
        r.error = std::string("json parse: ") + e.what();
    }
    return r;
}
// ---- SSE 行解析 ----
static bool parse_sse_line(const std::string& line, std::string& token_out)
{
// SSE 格式: "data: <json>" 或 "data: [DONE]"
if (line.rfind("data: ", 0) != 0) return false;
std::string data = line.substr(6);
if (data == "[DONE]") {
token_out.clear();
return true; // 流结束信号
}
try {
auto jv = json::parse(data);
auto obj = jv.as_object();
auto choices = obj["choices"].as_array();
if (!choices.empty()) {
auto delta = choices[0].as_object()["delta"].as_object();
if (delta.contains("content")) {
token_out = json::value_to<std::string>(delta["content"]);
return true;
}
}
} catch (...) {
// 忽略解析失败的行
}
return false;
}
// ---- Impl ----

// Private state behind DeepSeekClient: the HTTP transport and the current config.
struct DeepSeekClient::Impl {
    net::HttpClient http;
    ApiConfig config;

    // Splits config.base_url into the host (returned) and the path prefix (written
    // to `target`). A leading "https://" or "http://" is dropped.
    //   "https://api.deepseek.com/v1"  -> host "api.deepseek.com", target "/v1"
    //   "https://api.deepseek.com/v1/" -> host "api.deepseek.com", target "/v1"
    //   "https://api.deepseek.com"     -> host "api.deepseek.com", target ""
    //
    // `target` never ends with '/', so callers can append "/chat/completions"
    // without producing "//chat/completions".
    //
    // NOTE(review): any ":port" in base_url stays inside the returned host, and the
    // scheme is discarded. The callers in this file always connect to port "443".
    std::string extract_host_port(std::string& target) const {
        std::string url = config.base_url;
        if (url.rfind("https://", 0) == 0) url.erase(0, 8);
        else if (url.rfind("http://", 0) == 0) url.erase(0, 7);

        const size_t slash = url.find('/');
        if (slash == std::string::npos) {
            target.clear();  // no path component
            return url;
        }
        target = url.substr(slash);
        while (!target.empty() && target.back() == '/') target.pop_back();
        url.resize(slash);
        return url;
    }
};
// Allocates the private implementation (HTTP client + default config).
// NOTE(review): impl_ is a raw owning pointer. Unless the header deletes or defines
// the copy operations, copying a DeepSeekClient would double-delete it.
DeepSeekClient::DeepSeekClient()
    : impl_(new Impl{})
{
}

// Releases the implementation allocated by the constructor.
DeepSeekClient::~DeepSeekClient()
{
    delete impl_;
}
void DeepSeekClient::configure(const ApiConfig& config)
{
impl_->config = config;
}
// Sends a non-streaming chat-completions request and returns the parsed result.
//
// @param history     prior conversation turns.
// @param user_input  the new user message, appended after history.
// @return the ChatResult built by parse_response from the HTTP status and body.
// NOTE(review): the port is always "443", regardless of base_url's scheme.
ChatResult DeepSeekClient::chat(
    const std::vector<Message>& history,
    const std::string& user_input)
{
    std::string path_prefix;
    std::string host = impl_->extract_host_port(path_prefix);
    std::string endpoint = path_prefix + "/chat/completions";

    std::string payload = build_request_json(
        impl_->config, history, user_input, /*stream=*/false);

    std::unordered_map<std::string, std::string> request_headers{
        {"Authorization", "Bearer " + impl_->config.api_key},
    };

    auto response = impl_->http.post_json(host, "443", endpoint, payload, request_headers);
    return parse_response(response.body, response.status_code);
}
// Sends a streaming ("stream": true) chat-completions request and consumes the SSE reply.
//
// @param history     prior conversation turns.
// @param user_input  the new user message, appended after history.
// @param on_token    optional; called with each text delta. Its return value is
//                    handed back to post_stream, so returning false asks the
//                    transport to stop reading. NOTE(review): presumably post_stream
//                    stops on a false return — confirm in net::HttpClient.
// @param userdata    passed through unchanged to on_token.
// @return http_status from the transport; content accumulates every delta received.
//         On non-2xx: ok == false and error is the server's error.message, otherwise
//         "transport error" (status <= 0) or "HTTP <status>".
//         On 2xx: ok == true if any text was received, else "no content received".
ChatResult DeepSeekClient::chat_stream(
    const std::vector<Message>& history,
    const std::string& user_input,
    bool (*on_token)(const std::string& token, void* userdata),
    void* userdata)
{
    std::string target;
    std::string host = impl_->extract_host_port(target);
    std::string target_path = target + "/chat/completions";
    std::string body = build_request_json(
        impl_->config, history, user_input, true);
    std::unordered_map<std::string, std::string> headers;
    headers["Authorization"] = "Bearer " + impl_->config.api_key;

    // True only for the SSE terminator ("data: [DONE]", also "data:[DONE]"),
    // ignoring trailing CR/LF/spaces.
    const auto is_done_line = [](std::string s) {
        while (!s.empty() && (s.back() == '\r' || s.back() == '\n' || s.back() == ' '))
            s.pop_back();
        return s == "data: [DONE]" || s == "data:[DONE]";
    };

    ChatResult result;
    auto resp = impl_->http.post_stream(host, "443", target_path, body, headers,
        [&](const std::string& line) -> bool {
            if (line.empty()) return true;  // blank line between SSE events
            std::string token;
            if (!parse_sse_line(line, token)) return true;  // not a content line; keep reading
            if (token.empty()) {
                // An empty token may be the [DONE] terminator, or a chunk whose
                // delta content is "" (the role-only first chunk). Only the
                // terminator may end the stream. Stopping on the first chunk
                // used to discard the entire reply.
                return !is_done_line(line);
            }
            result.content += token;
            return on_token ? on_token(token, userdata) : true;
        });
    result.http_status = resp.status_code;

    // Transport failure or non-2xx status.
    if (resp.status_code < 200 || resp.status_code >= 300) {
        result.ok = false;
        // Best effort: {"error": {"message": ...}}, same as parse_response.
        try {
            const json::value jv = json::parse(resp.body);
            if (const json::value* err = jv.as_object().if_contains("error")) {
                result.error = json::value_to<std::string>(err->as_object().at("message"));
            }
        } catch (const std::exception&) {
            // Not JSON, or not the expected shape: use a generic message below.
        }
        if (result.error.empty()) {
            if (resp.status_code <= 0) {
                result.error = "transport error";
            } else {
                result.error = "HTTP " + std::to_string(resp.status_code);
            }
        }
        return result;
    }

    if (result.content.empty()) {
        result.ok = false;
        result.error = "no content received";
    } else {
        result.ok = true;
    }
    return result;
}
} // namespace ai
} // namespace dstalk