141 lines
5.3 KiB
Python
141 lines
5.3 KiB
Python
"""Gemini 客户端"""
|
|
from typing import Any, AsyncGenerator, Dict, List, Optional
|
|
import httpx
|
|
from app.services.ai_config import AIClientConfig, default_config
|
|
|
|
|
|
class GeminiClient:
    """Google Gemini API client.

    Accepts OpenAI-style chat messages and tool definitions, calls the Gemini
    REST ``generateContent`` / ``streamGenerateContent`` endpoints, and returns
    OpenAI-style results (``content`` / ``tool_calls`` / ``finish_reason``).

    The ``httpx.AsyncClient`` created in ``__init__`` holds pooled
    connections; release it with ``await client.aclose()`` or by using the
    instance as an ``async with`` context manager.
    """

    # JSON-Schema keys removed from tool parameter schemas at every nesting
    # level. Earlier code stripped them only at the top level; presumably
    # Gemini's schema subset rejects them anywhere -- TODO confirm if the API
    # starts accepting them.
    _UNSUPPORTED_SCHEMA_KEYS = ("$schema", "additionalProperties")

    # OpenAI-style ``tool_choice`` strings -> Gemini
    # ``toolConfig.functionCallingConfig.mode`` values.
    _TOOL_CHOICE_MODES = {"auto": "AUTO", "none": "NONE", "required": "ANY", "any": "ANY"}

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional["AIClientConfig"] = None):
        """Create a client.

        Args:
            api_key: Gemini API key, sent as the ``key`` query parameter.
            base_url: API root; defaults to the public ``v1beta`` endpoint.
                Any trailing slash is removed.
            config: client configuration; ``default_config`` when omitted.
                Only the ``config.http`` timeout values are read here.
        """
        self.api_key = api_key
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
        self.config = config or default_config
        http_cfg = self.config.http
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout,
            )
        )

    async def aclose(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self.client.aclose()

    async def __aenter__(self) -> "GeminiClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

    @classmethod
    def _sanitize_schema(cls, schema: Any) -> Any:
        """Return a cleaned copy of a JSON-Schema fragment.

        Drops ``_UNSUPPORTED_SCHEMA_KEYS`` from every schema object, however
        deeply nested. The keys of a ``properties`` mapping are property
        *names*, not schema keywords, so they are always preserved (a property
        literally called ``additionalProperties`` survives). The input is not
        mutated.
        """
        if isinstance(schema, list):
            return [cls._sanitize_schema(item) for item in schema]
        if not isinstance(schema, dict):
            return schema
        cleaned = {}
        for key, value in schema.items():
            if key in cls._UNSUPPORTED_SCHEMA_KEYS:
                continue
            if key == "properties" and isinstance(value, dict):
                cleaned[key] = {name: cls._sanitize_schema(sub) for name, sub in value.items()}
            else:
                cleaned[key] = cls._sanitize_schema(value)
        return cleaned

    def _convert_tools_to_gemini(self, tools: list) -> list:
        """Convert OpenAI-format tool definitions to Gemini format.

        Only entries whose ``type`` is ``"function"`` are converted; others are
        ignored. A missing description falls back to the function name, and a
        non-empty parameter schema without ``type`` is marked ``"object"``.

        Returns:
            ``[{"functionDeclarations": [...]}]``, or ``[]`` when nothing was
            converted.
        """
        declarations = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            params = self._sanitize_schema(func.get("parameters") or {})
            if params and "type" not in params:
                params["type"] = "object"
            decl = {
                "name": func["name"],
                "description": func.get("description") or func["name"],
            }
            if params:
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}] if declarations else []

    @staticmethod
    def _build_payload(
        messages: list,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str],
    ) -> Dict[str, Any]:
        """Build the request body shared by the streaming and non-streaming calls.

        Messages with role ``user`` keep the Gemini ``user`` role; every other
        role is sent as ``model``. Each message's ``content`` becomes a single
        text part.
        """
        contents = [
            {
                "role": "user" if msg["role"] == "user" else "model",
                "parts": [{"text": msg["content"]}],
            }
            for msg in messages
        ]
        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens},
        }
        if system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
        return payload

    @staticmethod
    def _parse_response(data: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a ``generateContent`` JSON body into an OpenAI-style result.

        Text parts of the first candidate are concatenated. Each
        ``functionCall`` part becomes a tool call whose ``arguments`` is the
        raw ``args`` dict (not a JSON string). ``finish_reason`` is
        ``"tool_calls"`` when any tool call is present, else ``"stop"``;
        Gemini's own finish reason is not consulted.
        """
        candidates = data.get("candidates") or []
        if not candidates:
            # Return empty content instead of raising so the caller's flow continues.
            return {"content": "", "tool_calls": None, "finish_reason": "stop"}

        parts = candidates[0].get("content", {}).get("parts", [])
        texts = []
        tool_calls = []
        used_ids = set()
        for part in parts:
            if "text" in part:
                texts.append(part["text"])
            elif "functionCall" in part:
                fc = part["functionCall"]
                name = fc["name"]
                # Gemini supplies no call id; synthesize one, adding a numeric
                # suffix so repeated calls to the same function stay unique.
                call_id = f"call_{name}"
                suffix = 1
                while call_id in used_ids:
                    suffix += 1
                    call_id = f"call_{name}_{suffix}"
                used_ids.add(call_id)
                tool_calls.append({
                    "id": call_id,
                    "type": "function",
                    "function": {"name": name, "arguments": fc.get("args", {})},
                })

        return {
            "content": "".join(texts),
            "tool_calls": tool_calls if tool_calls else None,
            "finish_reason": "tool_calls" if tool_calls else "stop",
        }

    @staticmethod
    def _extract_sse_text(line: str) -> str:
        """Return the text carried by one SSE line of a streaming response.

        Accepts ``data:`` with or without the single optional space (SSE
        spec). Non-data lines, unparseable JSON and chunks without candidates
        or text parts yield ``""``. All text parts of the first candidate are
        joined.
        """
        import json

        if not line.startswith("data:"):
            return ""
        raw = line[5:]
        if raw.startswith(" "):
            raw = raw[1:]
        try:
            data = json.loads(raw)
        except ValueError:  # includes json.JSONDecodeError
            return ""
        if not isinstance(data, dict):
            return ""
        candidates = data.get("candidates")
        if not isinstance(candidates, list) or not candidates or not isinstance(candidates[0], dict):
            return ""
        content = candidates[0].get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            return ""
        return "".join(
            part["text"]
            for part in parts
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a non-streaming completion.

        Args:
            messages: OpenAI-style ``{"role", "content"}`` dicts.
            model: Gemini model name, inserted into the URL path.
            temperature: sampling temperature.
            max_tokens: sent as ``maxOutputTokens``.
            system_prompt: optional system instruction.
            tools: optional OpenAI-format tool definitions.
            tool_choice: ``"auto"``, ``"none"`` or ``"required"`` (``"any"`` is
                an alias); sent as Gemini's function-calling mode when tools
                are present. Other values are ignored.

        Returns:
            ``{"content": str, "tool_calls": list | None, "finish_reason": str}``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (``raise_for_status``).
        """
        # NOTE(review): the key stays in the query string for compatibility
        # with custom base_url proxies, but httpx error messages include the
        # full URL -- avoid logging them verbatim.
        url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        if tools:
            gemini_tools = self._convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
                mode = self._TOOL_CHOICE_MODES.get(tool_choice.lower()) if isinstance(tool_choice, str) else None
                if mode:
                    payload["toolConfig"] = {"functionCallingConfig": {"mode": mode}}

        response = await self.client.post(url, json=payload)
        response.raise_for_status()
        return self._parse_response(response.json())

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding non-empty text fragments as they arrive.

        Uses ``streamGenerateContent`` with ``alt=sse``. Malformed or text-less
        chunks are skipped.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (``raise_for_status``).
        """
        url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)

        async with self.client.stream("POST", url, json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                text = self._extract_sse_text(line)
                # The yield is deliberately outside any try/except so that
                # GeneratorExit (aclose) and CancelledError propagate normally.
                if text:
                    yield text