update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -0,0 +1,141 @@
|
||||
"""Gemini 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Google Gemini API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
self.config = config or default_config
|
||||
http_cfg = self.config.http
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||
gemini_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||
params.pop("$schema", None)
|
||||
params.pop("additionalProperties", None)
|
||||
if params and "type" not in params:
|
||||
params["type"] = "object"
|
||||
decl = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description") or func["name"],
|
||||
}
|
||||
if params:
|
||||
decl["parameters"] = params
|
||||
gemini_tools.append(decl)
|
||||
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates or len(candidates) == 0:
|
||||
# 返回空内容而不是报错,保持流程继续
|
||||
return {
|
||||
"content": "",
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text = ""
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||
})
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
except:
|
||||
continue
|
||||
Reference in New Issue
Block a user