#include "ai/deepseek_api.hpp"
#include "net/http_client.hpp"

#include <boost/json.hpp>

#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): the copy of this file under review had every `<...>` stripped
// (system include targets and template arguments). They are restored here on
// a best-guess basis -- in particular the history element type `Message` and
// the header map type `std::unordered_map<std::string, std::string>`. Confirm
// both against ai/deepseek_api.hpp and net/http_client.hpp.

namespace json = boost::json;

namespace dstalk {
namespace ai {

namespace {

// Host, port and request target derived from ApiConfig::base_url.
struct Endpoint {
    std::string host;
    std::string port;
    std::string target;
};

// Classification of a single line of a server-sent-events stream.
enum class SseEvent {
    kIgnore,  // not a data line, not parsable, or a chunk without delta text
    kToken,   // a chunk carrying non-empty delta text
    kDone,    // the "data: [DONE]" terminator
};

// Path appended to the path prefix of base_url for every chat request.
constexpr char kChatCompletionsPath[] = "/chat/completions";

}  // namespace

// ---- JSON helpers (non-throwing lookups) ----

// Returns v[key] when v is an object containing key; otherwise nullptr.
static const json::value* member(const json::value* v, const char* key) {
    if (v == nullptr) return nullptr;
    const json::object* obj = v->if_object();
    return obj != nullptr ? obj->if_contains(key) : nullptr;
}

// Returns the first element when v is a non-empty array; otherwise nullptr.
static const json::value* first_element(const json::value* v) {
    if (v == nullptr) return nullptr;
    const json::array* arr = v->if_array();
    return (arr != nullptr && !arr->empty()) ? &arr->front() : nullptr;
}

// Copies the string held by v into out. Returns false (leaving out untouched)
// when v is null or not a JSON string (e.g. a JSON null).
static bool read_string(const json::value* v, std::string& out) {
    const json::string* s = v != nullptr ? v->if_string() : nullptr;
    if (s == nullptr) return false;
    out.assign(s->data(), s->size());
    return true;
}

// ---- Endpoint / header construction ----

// Splits base_url ("[scheme://]host[:port][/prefix]") into host, port and the
// request target `prefix + path`.
//  - "https://" and "http://" prefixes are stripped.
//  - The port defaults to "443", the value this client has always used; an
//    explicit ":port" in base_url overrides it.
//    NOTE(review): http:// URLs also default to 443 -- whether net::HttpClient
//    can speak plain HTTP is not visible from this file.
//  - Trailing '/' characters on the prefix are dropped, so
//    "https://api.deepseek.com" and "https://api.deepseek.com/v1/" do not
//    produce a "//chat/completions"-style target.
static Endpoint parse_endpoint(const std::string& base_url, const std::string& path) {
    std::string rest = base_url;
    if (rest.rfind("https://", 0) == 0)
        rest.erase(0, 8);
    else if (rest.rfind("http://", 0) == 0)
        rest.erase(0, 7);

    const std::size_t slash = rest.find('/');
    std::string authority = rest.substr(0, slash);
    std::string prefix = (slash == std::string::npos) ? std::string() : rest.substr(slash);
    while (!prefix.empty() && prefix.back() == '/')
        prefix.pop_back();

    Endpoint ep;
    ep.port = "443";

    // A ':' that comes after any IPv6 closing bracket introduces the port.
    const std::size_t colon = authority.rfind(':');
    const std::size_t bracket = authority.rfind(']');
    if (colon != std::string::npos &&
        (bracket == std::string::npos || colon > bracket)) {
        if (colon + 1 < authority.size())
            ep.port = authority.substr(colon + 1);
        authority.erase(colon);
    }
    // "[::1]" -> "::1" so the resolver receives a bare address.
    if (authority.size() >= 2 && authority.front() == '[' && authority.back() == ']')
        authority = authority.substr(1, authority.size() - 2);

    ep.host = std::move(authority);
    ep.target = prefix + path;
    return ep;
}

// Headers sent with every request: bearer-token authentication.
static std::unordered_map<std::string, std::string> auth_headers(const ApiConfig& cfg) {
    std::unordered_map<std::string, std::string> headers;
    headers["Authorization"] = "Bearer " + cfg.api_key;
    return headers;
}

// ---- Request JSON ----

// Serializes a chat-completions request: model parameters from cfg, the prior
// conversation in `history`, then `user_input` appended as a "user" message.
// `stream` selects SSE streaming vs. a single JSON response.
static std::string build_request_json(
    const ApiConfig& cfg,
    const std::vector<Message>& history,
    const std::string& user_input,
    bool stream) {
    json::object root;
    root["model"] = cfg.model;
    root["max_tokens"] = cfg.max_tokens;
    root["temperature"] = cfg.temperature;
    root["stream"] = stream;

    json::array msgs;
    msgs.reserve(history.size() + 1);
    for (const auto& m : history) {
        json::object obj;
        obj["role"] = m.role;
        obj["content"] = m.content;
        msgs.push_back(std::move(obj));
    }
    // Append the current user input as the final message.
    {
        json::object obj;
        obj["role"] = "user";
        obj["content"] = user_input;
        msgs.push_back(std::move(obj));
    }
    root["messages"] = std::move(msgs);
    return json::serialize(root);
}

// ---- Response JSON ----

// Returns the API-supplied error text from a response body -- `error.message`,
// or `error` itself when it is a bare string -- or "" when the body is not
// JSON or carries no such field.
static std::string extract_api_error(const std::string& body) {
    json::error_code ec;
    const json::value jv = json::parse(body, ec);
    if (ec) return {};

    const json::value* err = member(&jv, "error");
    std::string message;
    if (read_string(member(err, "message"), message)) return message;
    read_string(err, message);
    return message;
}

// Error text for a failed request: the API's own message when present,
// otherwise "transport error" for a non-positive status (no HTTP response),
// otherwise "HTTP <status>".
static std::string describe_failure(const std::string& body, int http_status) {
    std::string message = extract_api_error(body);
    if (!message.empty()) return message;
    if (http_status <= 0) return "transport error";
    return "HTTP " + std::to_string(http_status);
}

// Converts a non-streaming response into a ChatResult. On 2xx, content is
// taken from choices[0].message.content; every failure sets ok = false and a
// non-empty error.
static ChatResult parse_response(const std::string& body, int http_status) {
    ChatResult r;
    r.http_status = http_status;
    r.ok = false;

    if (http_status < 200 || http_status >= 300) {
        r.error = describe_failure(body, http_status);
        return r;
    }

    json::error_code ec;
    const json::value jv = json::parse(body, ec);
    if (ec) {
        r.error = "json parse: " + ec.message();
        return r;
    }

    const json::value* choices = member(&jv, "choices");
    if (choices == nullptr || !choices->is_array()) {
        r.error = "json parse: missing \"choices\" array";
        return r;
    }
    const json::value* first = first_element(choices);
    if (first == nullptr) {
        r.error = "empty response";
        return r;
    }
    if (!read_string(member(member(first, "message"), "content"), r.content)) {
        r.error = "json parse: missing message content";
        return r;
    }
    r.ok = true;
    return r;
}

// ---- SSE line parsing ----

// Classifies one SSE line ("data:<payload>", one optional space after the
// colon, optional trailing '\r'). On kToken, token_out holds the non-empty
// choices[0].delta.content. Chunks whose delta text is empty or missing (e.g.
// a role-only first chunk, or content: null) are kIgnore, so they are never
// mistaken for the end of the stream.
static SseEvent parse_sse_line(const std::string& line, std::string& token_out) {
    if (line.rfind("data:", 0) != 0) return SseEvent::kIgnore;

    std::size_t begin = 5;  // length of "data:"
    if (begin < line.size() && line[begin] == ' ') ++begin;
    std::size_t end = line.size();
    if (end > begin && line[end - 1] == '\r') --end;
    const std::string data = line.substr(begin, end - begin);

    if (data == "[DONE]") {
        token_out.clear();
        return SseEvent::kDone;
    }

    json::error_code ec;
    const json::value jv = json::parse(data, ec);
    if (ec) return SseEvent::kIgnore;  // skip lines that are not valid JSON

    const json::value* content =
        member(member(first_element(member(&jv, "choices")), "delta"), "content");
    if (!read_string(content, token_out) || token_out.empty())
        return SseEvent::kIgnore;
    return SseEvent::kToken;
}

// ---- Impl ----

// Private state behind DeepSeekClient.
struct DeepSeekClient::Impl {
    net::HttpClient http;
    ApiConfig config;
};

// NOTE(review): impl_ is released with a raw delete. Unless the header deletes
// DeepSeekClient's copy operations, copying a client would double-delete;
// consider std::unique_ptr<Impl> in the header.
DeepSeekClient::DeepSeekClient() : impl_(new Impl{}) {}

DeepSeekClient::~DeepSeekClient() { delete impl_; }

// Replaces the configuration used by subsequent requests.
void DeepSeekClient::configure(const ApiConfig& config) {
    impl_->config = config;
}

// Sends one non-streaming chat-completions request and returns the parsed
// result (see parse_response).
ChatResult DeepSeekClient::chat(
    const std::vector<Message>& history,
    const std::string& user_input) {
    const Endpoint ep = parse_endpoint(impl_->config.base_url, kChatCompletionsPath);
    const std::string body = build_request_json(impl_->config, history, user_input, false);
    auto headers = auth_headers(impl_->config);

    auto resp = impl_->http.post_json(ep.host, ep.port, ep.target, body, headers);
    return parse_response(resp.body, resp.status_code);
}

// Sends a streaming chat-completions request. Each non-empty delta is appended
// to result.content and passed to on_token (when non-null); on_token returning
// false stops reading. The stream also stops at "[DONE]". On 2xx, ok is true
// iff some content was received.
ChatResult DeepSeekClient::chat_stream(
    const std::vector<Message>& history,
    const std::string& user_input,
    bool (*on_token)(const std::string& token, void* userdata),
    void* userdata) {
    const Endpoint ep = parse_endpoint(impl_->config.base_url, kChatCompletionsPath);
    const std::string body = build_request_json(impl_->config, history, user_input, true);
    auto headers = auth_headers(impl_->config);

    ChatResult result;
    auto resp = impl_->http.post_stream(
        ep.host, ep.port, ep.target, body, headers,
        [&](const std::string& line) -> bool {
            std::string token;
            switch (parse_sse_line(line, token)) {
                case SseEvent::kDone:
                    return false;  // server signalled end of stream
                case SseEvent::kIgnore:
                    return true;   // keep reading
                case SseEvent::kToken:
                    break;
            }
            result.content += token;
            return on_token ? on_token(token, userdata) : true;
        });

    result.http_status = resp.status_code;

    // Transport failure or non-2xx status.
    if (resp.status_code < 200 || resp.status_code >= 300) {
        result.ok = false;
        result.error = describe_failure(resp.body, resp.status_code);
        return result;
    }

    if (result.content.empty()) {
        result.ok = false;
        result.error = "no content received";
    } else {
        result.ok = true;
    }
    return result;
}

}  // namespace ai
}  // namespace dstalk