Files
MuMuAINovel/backend/app/services/ai_clients/gemini_client.py
T

141 lines
5.3 KiB
Python

"""Gemini 客户端"""
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

import httpx

from app.services.ai_config import AIClientConfig, default_config
class GeminiClient:
    """Async client for the Google Gemini REST API.

    Accepts OpenAI-style chat messages and tool definitions, translates them
    into Gemini ``generateContent`` requests, and maps responses back to an
    OpenAI-like dict (``content`` / ``tool_calls`` / ``finish_reason``).
    """

    # JSON-Schema keywords removed from tool parameter schemas (at every
    # nesting level) before they are sent to Gemini.
    _UNSUPPORTED_SCHEMA_KEYS = frozenset({"$schema", "additionalProperties"})

    # OpenAI ``tool_choice`` strings -> Gemini ``functionCallingConfig.mode``.
    _TOOL_CHOICE_MODES = {"auto": "AUTO", "none": "NONE", "required": "ANY"}

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        """Create the client and its pooled ``httpx.AsyncClient``.

        Args:
            api_key: Gemini API key, sent as the ``key`` query parameter.
            base_url: API root; defaults to the public ``v1beta`` endpoint.
                A trailing ``/`` is stripped.
            config: Client configuration whose ``http`` section supplies the
                connect/read/write/pool timeouts; ``default_config`` if omitted.
        """
        self.api_key = api_key
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
        self.config = config or default_config
        http_cfg = self.config.http
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout,
            )
        )

    async def aclose(self) -> None:
        """Close the underlying HTTP connection pool."""
        await self.client.aclose()

    @classmethod
    def _sanitize_schema(cls, schema: Any) -> Any:
        """Return a copy of *schema* without keywords Gemini rejects.

        Recurses through nested dicts and lists, so sub-schemas under
        ``properties``, ``items`` and similar keys are cleaned too. Keys
        directly inside ``properties`` are property names, not keywords, so
        they are never stripped. The input is not mutated.
        """
        if isinstance(schema, list):
            return [cls._sanitize_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema
        cleaned = {}
        for key, value in schema.items():
            if key in cls._UNSUPPORTED_SCHEMA_KEYS:
                continue
            if key == "properties" and isinstance(value, dict):
                cleaned[key] = {name: cls._sanitize_schema(sub) for name, sub in value.items()}
            else:
                cleaned[key] = cls._sanitize_schema(value)
        return cleaned

    def _convert_tools_to_gemini(self, tools: list) -> list:
        """Convert OpenAI-format tools into a Gemini ``tools`` list.

        Only entries with ``"type": "function"`` are converted; others are
        skipped. A missing description falls back to the function name.

        Returns:
            ``[{"functionDeclarations": [...]}]``, or ``[]`` when no function
            tools were found.
        """
        declarations = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            params = self._sanitize_schema(func.get("parameters") or {})
            if params and "type" not in params:
                params["type"] = "object"
            decl = {
                "name": func["name"],
                "description": func.get("description") or func["name"],
            }
            # Gemini rejects OBJECT schemas with empty ``properties``, so a
            # no-argument function is declared without ``parameters``.
            if params and (params.get("type") != "object" or params.get("properties")):
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}] if declarations else []

    @staticmethod
    def _build_payload(
        messages: list,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Build the request body shared by streaming and non-streaming calls."""
        # NOTE(review): every role other than "user" (including "system",
        # "assistant", "tool") is sent as a "model" turn; confirm callers pass
        # system text via ``system_prompt`` rather than inside ``messages``.
        contents = [
            {"role": "user" if msg["role"] == "user" else "model", "parts": [{"text": msg["content"]}]}
            for msg in messages
        ]
        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens},
        }
        if system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
        return payload

    @staticmethod
    def _candidate_parts(data: Any) -> List[Dict[str, Any]]:
        """Return the dict-valued parts of the first candidate, or ``[]``.

        Tolerates missing or non-dict fields (e.g. a candidate blocked by a
        safety filter that carries no ``content``).
        """
        if not isinstance(data, dict):
            return []
        candidates = data.get("candidates")
        if not isinstance(candidates, list) or not candidates:
            return []
        first = candidates[0]
        content = first.get("content") if isinstance(first, dict) else None
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            return []
        return [part for part in parts if isinstance(part, dict)]

    @classmethod
    def _parse_generate_response(cls, data: Any) -> Dict[str, Any]:
        """Map a ``generateContent`` response body to the OpenAI-like result.

        All text parts are concatenated. Each ``functionCall`` part becomes a
        tool call with a unique id (``call_<name>``, then ``call_<name>_1``,
        ... for repeats). An empty or missing candidate list yields empty
        content with ``finish_reason`` ``"stop"``, so the caller's flow
        continues rather than erroring.
        """
        text_chunks: List[str] = []
        tool_calls: List[Dict[str, Any]] = []
        used_ids = set()
        for part in cls._candidate_parts(data):
            if "text" in part:
                text_chunks.append(part["text"])
            elif "functionCall" in part:
                fc = part["functionCall"]
                base_id = f"call_{fc['name']}"
                call_id, suffix = base_id, 1
                while call_id in used_ids:
                    call_id = f"{base_id}_{suffix}"
                    suffix += 1
                used_ids.add(call_id)
                tool_calls.append({
                    "id": call_id,
                    "type": "function",
                    # NOTE(review): OpenAI's schema carries ``arguments`` as a
                    # JSON string; Gemini's ``args`` dict is passed through as-is,
                    # which callers presumably rely on.
                    "function": {"name": fc["name"], "arguments": fc.get("args", {})},
                })
        # NOTE(review): Gemini's ``finishReason`` (e.g. MAX_TOKENS) is not
        # propagated; truncated output is still reported as "stop".
        return {
            "content": "".join(text_chunks),
            "tool_calls": tool_calls or None,
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }

    @classmethod
    def _parse_stream_line(cls, line: str) -> str:
        """Extract the concatenated text from one SSE line (``""`` if none).

        Non-``data:`` lines and payloads that are not valid JSON are ignored.
        """
        if not line.startswith("data: "):
            return ""
        try:
            chunk = json.loads(line[6:])
        except json.JSONDecodeError:
            return ""
        return "".join(
            part["text"] for part in cls._candidate_parts(chunk) if isinstance(part.get("text"), str)
        )

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a non-streaming completion.

        Args:
            messages: OpenAI-style ``{"role", "content"}`` dicts.
            model: Gemini model id, inserted into the URL path.
            temperature: Sampling temperature.
            max_tokens: ``maxOutputTokens`` limit.
            system_prompt: Optional system instruction.
            tools: Optional OpenAI-format tool definitions.
            tool_choice: ``"auto"``, ``"none"`` or ``"required"``; mapped to
                Gemini's function-calling mode when tools are sent. Other
                values are ignored.

        Returns:
            ``{"content": str, "tool_calls": list | None, "finish_reason": str}``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        url = f"{self.base_url}/models/{model}:generateContent"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        gemini_tools = self._convert_tools_to_gemini(tools) if tools else []
        if gemini_tools:
            payload["tools"] = gemini_tools
            mode = self._TOOL_CHOICE_MODES.get(tool_choice) if isinstance(tool_choice, str) else None
            if mode:
                payload["toolConfig"] = {"functionCallingConfig": {"mode": mode}}
        # NOTE(review): the key travels in the query string, so it appears in
        # HTTPStatusError messages; consider the ``x-goog-api-key`` header if
        # every configured base_url supports it.
        response = await self.client.post(url, params={"key": self.api_key}, json=payload)
        response.raise_for_status()
        return self._parse_generate_response(response.json())

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE, yielding non-empty text fragments.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        url = f"{self.base_url}/models/{model}:streamGenerateContent"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        query = {"key": self.api_key, "alt": "sse"}
        async with self.client.stream("POST", url, params=query, json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                text = self._parse_stream_line(line)
                # The yield sits outside any try/except so that GeneratorExit
                # (consumer closed the stream) propagates and closes the response.
                if text:
                    yield text