update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
"""AI 客户端基类"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局 HTTP 客户端池
|
||||
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||
"""获取全局信号量"""
|
||||
global _global_semaphore
|
||||
if _global_semaphore is None:
|
||||
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
return _global_semaphore
|
||||
|
||||
|
||||
class BaseAIClient(ABC):
|
||||
"""AI HTTP 客户端基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or default_config
|
||||
self.http_client = self._get_or_create_client()
|
||||
|
||||
def _get_client_key(self) -> str:
|
||||
"""生成客户端唯一键"""
|
||||
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||
|
||||
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
client_key = self._get_client_key()
|
||||
|
||||
if client_key in _http_client_pool:
|
||||
client = _http_client_pool[client_key]
|
||||
if not client.is_closed:
|
||||
return client
|
||||
del _http_client_pool[client_key]
|
||||
|
||||
http_cfg = self.config.http
|
||||
client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||
max_connections=http_cfg.max_connections,
|
||||
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||
),
|
||||
)
|
||||
_http_client_pool[client_key] = client
|
||||
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||
return client
|
||||
|
||||
@abstractmethod
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
stream: bool = False,
|
||||
) -> Any:
|
||||
"""带重试的 HTTP 请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = self._build_headers()
|
||||
retry_cfg = self.config.retry
|
||||
rate_cfg = self.config.rate_limit
|
||||
|
||||
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||
|
||||
async with semaphore:
|
||||
await asyncio.sleep(rate_cfg.request_delay)
|
||||
|
||||
for attempt in range(retry_cfg.max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
delay = min(
|
||||
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||
retry_cfg.max_delay,
|
||||
)
|
||||
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if stream:
|
||||
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||
|
||||
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||
raise
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天补全"""
|
||||
pass
|
||||
|
||||
|
||||
async def cleanup_all_clients():
|
||||
"""清理所有 HTTP 客户端"""
|
||||
for key, client in list(_http_client_pool.items()):
|
||||
if not client.is_closed:
|
||||
await client.aclose()
|
||||
_http_client_pool.clear()
|
||||
logger.info("✅ HTTP 客户端池已清理")
|
||||
Reference in New Issue
Block a user