update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
"""AI 客户端模块"""
|
||||
from .base_client import BaseAIClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Anthropic 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient:
|
||||
"""Anthropic API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.config = config or default_config
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncAnthropic(**kwargs)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
})
|
||||
elif block.type == "text":
|
||||
content += block.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
@@ -0,0 +1,154 @@
|
||||
"""AI 客户端基类"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局 HTTP 客户端池
|
||||
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||
"""获取全局信号量"""
|
||||
global _global_semaphore
|
||||
if _global_semaphore is None:
|
||||
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
return _global_semaphore
|
||||
|
||||
|
||||
class BaseAIClient(ABC):
|
||||
"""AI HTTP 客户端基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or default_config
|
||||
self.http_client = self._get_or_create_client()
|
||||
|
||||
def _get_client_key(self) -> str:
|
||||
"""生成客户端唯一键"""
|
||||
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||
|
||||
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
client_key = self._get_client_key()
|
||||
|
||||
if client_key in _http_client_pool:
|
||||
client = _http_client_pool[client_key]
|
||||
if not client.is_closed:
|
||||
return client
|
||||
del _http_client_pool[client_key]
|
||||
|
||||
http_cfg = self.config.http
|
||||
client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||
max_connections=http_cfg.max_connections,
|
||||
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||
),
|
||||
)
|
||||
_http_client_pool[client_key] = client
|
||||
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||
return client
|
||||
|
||||
@abstractmethod
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
stream: bool = False,
|
||||
) -> Any:
|
||||
"""带重试的 HTTP 请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = self._build_headers()
|
||||
retry_cfg = self.config.retry
|
||||
rate_cfg = self.config.rate_limit
|
||||
|
||||
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||
|
||||
async with semaphore:
|
||||
await asyncio.sleep(rate_cfg.request_delay)
|
||||
|
||||
for attempt in range(retry_cfg.max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
delay = min(
|
||||
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||
retry_cfg.max_delay,
|
||||
)
|
||||
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if stream:
|
||||
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||
|
||||
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||
raise
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天补全"""
|
||||
pass
|
||||
|
||||
|
||||
async def cleanup_all_clients():
|
||||
"""清理所有 HTTP 客户端"""
|
||||
for key, client in list(_http_client_pool.items()):
|
||||
if not client.is_closed:
|
||||
await client.aclose()
|
||||
_http_client_pool.clear()
|
||||
logger.info("✅ HTTP 客户端池已清理")
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Gemini 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Google Gemini API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
self.config = config or default_config
|
||||
http_cfg = self.config.http
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||
gemini_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||
params.pop("$schema", None)
|
||||
params.pop("additionalProperties", None)
|
||||
if params and "type" not in params:
|
||||
params["type"] = "object"
|
||||
decl = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description") or func["name"],
|
||||
}
|
||||
if params:
|
||||
decl["parameters"] = params
|
||||
gemini_tools.append(decl)
|
||||
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates or len(candidates) == 0:
|
||||
# 返回空内容而不是报错,保持流程继续
|
||||
return {
|
||||
"content": "",
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text = ""
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||
})
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
except:
|
||||
continue
|
||||
@@ -0,0 +1,101 @@
|
||||
"""OpenAI 客户端"""
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from .base_client import BaseAIClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(BaseAIClient):
|
||||
"""OpenAI API 客户端"""
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
if tools:
|
||||
# 清理 $schema 字段
|
||||
cleaned = []
|
||||
for t in tools:
|
||||
tc = t.copy()
|
||||
if "function" in tc and "parameters" in tc["function"]:
|
||||
tc["function"]["parameters"] = {
|
||||
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
|
||||
}
|
||||
cleaned.append(tc)
|
||||
payload["tools"] = cleaned
|
||||
if tool_choice:
|
||||
payload["tool_choice"] = tool_choice
|
||||
return payload
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
|
||||
data = await self._request_with_retry("POST", "/chat/completions", payload)
|
||||
|
||||
choices = data.get("choices", [])
|
||||
if not choices or len(choices) == 0:
|
||||
raise ValueError("API 返回空 choices 或 choices 为空列表")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
return {
|
||||
"content": message.get("content", ""),
|
||||
"tool_calls": message.get("tool_calls"),
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
Reference in New Issue
Block a user