"""Gemini client."""
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

import httpx

from app.services.ai_config import AIClientConfig, default_config

# JSON-Schema keywords removed from tool parameter schemas, at every nesting
# level, before the schemas are sent to Gemini.
_UNSUPPORTED_SCHEMA_KEYS = frozenset({"$schema", "additionalProperties"})

# OpenAI-style ``tool_choice`` strings mapped to Gemini
# ``toolConfig.functionCallingConfig.mode`` values.
_TOOL_CHOICE_MODES = {"auto": "AUTO", "none": "NONE", "required": "ANY"}

# Prefix of the data lines in the server-sent-event stream (``alt=sse``).
_SSE_DATA_PREFIX = "data:"


class GeminiClient:
    """Async client for the Google Gemini REST API (``generateContent``).

    All requests go through one shared ``httpx.AsyncClient`` whose timeouts
    come from ``config.http``. Call :meth:`aclose` (or use ``async with``)
    to release its connections when the client is no longer needed.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        """Create the client.

        Args:
            api_key: Gemini API key, sent in the ``x-goog-api-key`` header.
            base_url: API root. Defaults to the public ``v1beta`` endpoint;
                a trailing ``/`` is stripped.
            config: Client configuration; ``default_config`` when omitted.
        """
        self.api_key = api_key
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
        self.config = config or default_config
        http_cfg = self.config.http
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout,
            )
        )

    async def aclose(self) -> None:
        """Close the underlying ``httpx.AsyncClient`` and its connections."""
        await self.client.aclose()

    async def __aenter__(self) -> "GeminiClient":
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self.aclose()

    # ------------------------------------------------------------------ #
    # Request building helpers
    # ------------------------------------------------------------------ #

    def _auth_headers(self) -> Dict[str, str]:
        """Return the authentication header for a request.

        The key is sent as a header rather than a ``?key=`` query parameter,
        so it never becomes part of request URLs. httpx includes those URLs
        in ``HTTPStatusError`` messages and in its request logs.
        """
        return {"x-goog-api-key": self.api_key}

    @staticmethod
    def _convert_messages(messages: list) -> List[Dict[str, Any]]:
        """Convert chat messages (``role``/``content`` dicts) to Gemini ``contents``."""
        contents = []
        for msg in messages:
            # Gemini only distinguishes "user" and "model"; every other role is
            # sent as "model".
            # NOTE(review): this includes "system"/"tool" messages, which are
            # probably better handled separately - confirm with callers.
            role = "user" if msg["role"] == "user" else "model"
            contents.append({"role": role, "parts": [{"text": msg["content"]}]})
        return contents

    def _build_payload(
        self,
        messages: list,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Build the request body shared by the blocking and streaming calls."""
        payload: Dict[str, Any] = {
            "contents": self._convert_messages(messages),
            "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens},
        }
        if system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
        return payload

    @classmethod
    def _sanitize_schema(cls, schema: Any) -> Any:
        """Return a copy of ``schema`` with unsupported keywords removed at any depth."""
        if isinstance(schema, list):
            return [cls._sanitize_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema
        cleaned: Dict[str, Any] = {}
        for key, value in schema.items():
            if key in _UNSUPPORTED_SCHEMA_KEYS:
                continue
            if key == "properties" and isinstance(value, dict):
                # Keys under "properties" are property names, not schema
                # keywords, so they are kept; only their sub-schemas are cleaned.
                cleaned[key] = {name: cls._sanitize_schema(sub) for name, sub in value.items()}
            else:
                cleaned[key] = cls._sanitize_schema(value)
        return cleaned

    def _convert_tools_to_gemini(self, tools: list) -> list:
        """Convert OpenAI-format tools to a Gemini ``tools`` list.

        Only ``{"type": "function"}`` entries are converted. Returns ``[]``
        when none are present; otherwise returns one
        ``{"functionDeclarations": [...]}`` entry.
        """
        declarations = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            params = self._sanitize_schema(func.get("parameters") or {})
            if params and "type" not in params:
                params["type"] = "object"
            decl = {
                "name": func["name"],
                "description": func.get("description") or func["name"],
            }
            if params:
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}] if declarations else []

    @staticmethod
    def _tool_choice_mode(tool_choice: Optional[str]) -> Optional[str]:
        """Map an OpenAI ``tool_choice`` string to a Gemini mode, or ``None`` if unknown."""
        if not isinstance(tool_choice, str):
            return None
        return _TOOL_CHOICE_MODES.get(tool_choice.lower())

    @staticmethod
    def _unique_call_id(name: str, used_ids: set) -> str:
        """Return ``call_<name>``, suffixed with ``_N`` if that id is already in ``used_ids``."""
        base_id = f"call_{name}"
        call_id = base_id
        suffix = 1
        while call_id in used_ids:
            call_id = f"{base_id}_{suffix}"
            suffix += 1
        used_ids.add(call_id)
        return call_id

    @staticmethod
    def _extract_stream_texts(line: str) -> List[str]:
        """Return the non-empty text parts of the first candidate in one SSE line.

        Lines that are not ``data:`` events, are not valid JSON, or do not
        have the expected shape yield an empty list (best effort).
        """
        if not line.startswith(_SSE_DATA_PREFIX):
            return []
        try:
            data = json.loads(line[len(_SSE_DATA_PREFIX):])
        except json.JSONDecodeError:
            return []
        if not isinstance(data, dict):
            return []
        candidates = data.get("candidates")
        if not candidates or not isinstance(candidates, list) or not isinstance(candidates[0], dict):
            return []
        content = candidates[0].get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            return []
        return [part["text"] for part in parts if isinstance(part, dict) and part.get("text")]

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a non-streaming ``generateContent`` request.

        Args:
            messages: Chat messages with ``role`` and ``content`` keys.
            model: Gemini model name, inserted into the request path.
            temperature: Sampling temperature.
            max_tokens: Sent as ``maxOutputTokens``.
            system_prompt: Optional system instruction.
            tools: Optional OpenAI-format tool definitions.
            tool_choice: Optional ``"auto"`` / ``"none"`` / ``"required"``.
                Applied only when ``tools`` yields at least one declaration;
                other values are ignored.

        Returns:
            A dict with ``content`` (concatenated text parts), ``tool_calls``
            (list of OpenAI-style calls with dict ``arguments``, or ``None``)
            and ``finish_reason`` (``"tool_calls"`` or ``"stop"``). A response
            without candidates yields empty content instead of an error.

        Raises:
            httpx.HTTPStatusError: If the API returns an error status.
        """
        url = f"{self.base_url}/models/{model}:generateContent"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        if tools:
            gemini_tools = self._convert_tools_to_gemini(tools)
            payload["tools"] = gemini_tools
            mode = self._tool_choice_mode(tool_choice)
            if gemini_tools and mode:
                payload["toolConfig"] = {"functionCallingConfig": {"mode": mode}}

        response = await self.client.post(url, json=payload, headers=self._auth_headers())
        response.raise_for_status()
        data = response.json()

        candidates = data.get("candidates") or []
        if not candidates:
            # Return empty content instead of raising so the caller's flow continues.
            return {"content": "", "tool_calls": None, "finish_reason": "stop"}

        parts = (candidates[0].get("content") or {}).get("parts") or []
        text_chunks: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        used_ids: set = set()
        for part in parts:
            if "text" in part:
                text_chunks.append(part["text"])
            elif "functionCall" in part:
                fc = part["functionCall"]
                tool_calls.append({
                    "id": self._unique_call_id(fc["name"], used_ids),
                    "type": "function",
                    "function": {"name": fc["name"], "arguments": fc.get("args", {})},
                })

        return {
            "content": "".join(text_chunks),
            "tool_calls": tool_calls or None,
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream text from ``streamGenerateContent`` (SSE mode).

        Yields each non-empty text part of the first candidate in every event.
        Malformed events are skipped.

        Raises:
            httpx.HTTPStatusError: If the API returns an error status.
        """
        url = f"{self.base_url}/models/{model}:streamGenerateContent?alt=sse"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)

        async with self.client.stream("POST", url, json=payload, headers=self._auth_headers()) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                # Parsing is kept outside any try block around ``yield`` so that
                # GeneratorExit / cancellation from the consumer propagates.
                for text in self._extract_stream_texts(line):
                    yield text