"""OpenAI 客户端""" import json from typing import Any, AsyncGenerator, Dict, Optional from app.logger import get_logger from .base_client import BaseAIClient logger = get_logger(__name__) class OpenAIClient(BaseAIClient): """OpenAI API 客户端""" def _build_headers(self) -> Dict[str, str]: return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", } def _build_payload( self, messages: list, model: str, temperature: float, max_tokens: int, tools: Optional[list] = None, tool_choice: Optional[str] = None, stream: bool = False, ) -> Dict[str, Any]: payload = { "model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, } if stream: payload["stream"] = True if tools: # 清理 $schema 字段 cleaned = [] for t in tools: tc = t.copy() if "function" in tc and "parameters" in tc["function"]: tc["function"]["parameters"] = { k: v for k, v in tc["function"]["parameters"].items() if k != "$schema" } cleaned.append(tc) payload["tools"] = cleaned if tool_choice: payload["tool_choice"] = tool_choice return payload async def chat_completion( self, messages: list, model: str, temperature: float, max_tokens: int, tools: Optional[list] = None, tool_choice: Optional[str] = None, ) -> Dict[str, Any]: payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice) logger.debug(f"📤 OpenAI 请求 payload: {json.dumps(payload, ensure_ascii=False, indent=2)}") data = await self._request_with_retry("POST", "/chat/completions", payload) # 调试日志:输出原始响应 logger.debug(f"📥 OpenAI 原始响应: {json.dumps(data, ensure_ascii=False, indent=2)}") choices = data.get("choices", []) if not choices or len(choices) == 0: raise ValueError("API 返回空 choices 或 choices 为空列表") choice = choices[0] message = choice.get("message", {}) return { "content": message.get("content", ""), "tool_calls": message.get("tool_calls"), "finish_reason": choice.get("finish_reason"), } async def chat_completion_stream( self, messages: list, model: str, temperature: float, max_tokens: int, tools: Optional[list] = None, tool_choice: Optional[str] = None, ) -> AsyncGenerator[Dict[str, Any], None]: """ 流式生成,支持工具调用 Yields: Dict with keys: - content: str - 文本内容块 - tool_calls: list - 工具调用列表(如果有) - done: bool - 是否结束 """ payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice, stream=True) tool_calls_buffer = {} # 收集工具调用块 try: async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response: response.raise_for_status() try: async for line in response.aiter_lines(): if line.startswith("data: "): data_str = line[6:] if data_str.strip() == "[DONE]": # 流结束,检查是否有工具调用需要处理 if tool_calls_buffer: yield {"tool_calls": list(tool_calls_buffer.values()), "done": True} yield {"done": True} break try: data = json.loads(data_str) choices = data.get("choices", []) if choices and len(choices) > 0: delta = choices[0].get("delta", {}) content = delta.get("content", "") # 检查工具调用 tc_list = delta.get("tool_calls") if tc_list: for tc in tc_list: index = tc.get("index", 0) if index not in tool_calls_buffer: tool_calls_buffer[index] = tc else: existing = tool_calls_buffer[index] # 合并 function.arguments if "function" in tc and "function" in existing: if tc["function"].get("arguments"): existing["function"]["arguments"] = ( existing["function"].get("arguments", "") + tc["function"]["arguments"] ) if content: yield {"content": content} except json.JSONDecodeError: continue except GeneratorExit: # 生成器被关闭,这是正常的清理过程 logger.debug("流式响应生成器被关闭(GeneratorExit)") raise except Exception as iter_error: logger.error(f"流式响应迭代出错: {str(iter_error)}") raise except GeneratorExit: # 重新抛出GeneratorExit,让调用方处理 raise except Exception as e: logger.error(f"流式请求出错: {str(e)}") raise