154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
"""AI 客户端基类"""
|
|
import asyncio
|
|
import hashlib
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, AsyncGenerator, Dict, Optional
|
|
|
|
import httpx
|
|
|
|
from app.logger import get_logger
|
|
from app.services.ai_config import AIClientConfig, default_config
|
|
|
|
logger = get_logger(__name__)  # module-level logger, obtained via the project's get_logger helper
|
|
|
|
# Global HTTP client pool (module-level, shared by all client instances)
|
|
_http_client_pool: Dict[str, httpx.AsyncClient] = {}  # keyed by BaseAIClient._get_client_key()
|
|
_global_semaphore: Optional[asyncio.Semaphore] = None  # created lazily by _get_semaphore()
|
|
|
|
|
|
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
    """Return the module-level request semaphore, creating it on first use.

    Args:
        max_concurrent: Capacity used when the semaphore is first created.
            Once a semaphore exists, this argument is ignored and the
            existing instance is returned unchanged.

    Returns:
        The single shared ``asyncio.Semaphore`` stored in ``_global_semaphore``.
    """
    global _global_semaphore

    shared = _global_semaphore
    if shared is not None:
        return shared

    # First caller decides the capacity for everyone that follows.
    shared = asyncio.Semaphore(max_concurrent)
    _global_semaphore = shared
    return shared
|
|
|
|
|
|
class BaseAIClient(ABC):
    """Abstract base class for AI HTTP clients.

    Provides a pooled ``httpx.AsyncClient`` and a request helper with
    rate limiting and retry. Subclasses implement header construction and
    the chat-completion calls.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        config: Optional[AIClientConfig] = None,
    ):
        """Store connection settings and attach a pooled HTTP client.

        Args:
            api_key: API key; a short hash of it is part of the pool key.
            base_url: Service base URL; trailing ``/`` characters are stripped.
            config: Client configuration; ``default_config`` is used when this
                is omitted (or otherwise falsy).
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.config = config or default_config
        self.http_client = self._get_or_create_client()

    def _get_client_key(self) -> str:
        """Build the pool key for this client.

        The key combines the concrete class name, the base URL and the first
        8 hex characters of the MD5 digest of the API key, so the raw key does
        not appear in the pool key (which is also written to the log).
        """
        key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
        return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"

    def _get_or_create_client(self) -> httpx.AsyncClient:
        """Return the pooled HTTP client for this key, creating one if needed.

        A pooled client that is still open is reused as-is. The pool key does
        not include the configuration, so a reused client keeps the timeouts
        and limits it was originally created with. A pooled client that has
        been closed is evicted and replaced.
        """
        client_key = self._get_client_key()

        pooled = _http_client_pool.get(client_key)
        if pooled is not None:
            if not pooled.is_closed:
                return pooled
            # Closed clients cannot be reused; drop the stale entry.
            del _http_client_pool[client_key]

        http_cfg = self.config.http
        client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout,
            ),
            limits=httpx.Limits(
                max_keepalive_connections=http_cfg.max_keepalive_connections,
                max_connections=http_cfg.max_connections,
                keepalive_expiry=http_cfg.keepalive_expiry,
            ),
        )
        _http_client_pool[client_key] = client
        logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
        return client

    @abstractmethod
    def _build_headers(self) -> Dict[str, str]:
        """Build the HTTP headers sent with every request (subclass-defined)."""

    async def _request_with_retry(
        self,
        method: str,
        endpoint: str,
        payload: Dict[str, Any],
        stream: bool = False,
    ) -> Any:
        """Send an HTTP request under the shared semaphore, retrying on failure.

        The shared semaphore is held for the whole call, including the initial
        ``request_delay`` sleep and any backoff sleeps. Before attempt index
        ``n`` (``n >= 1``) the call waits
        ``min(base_delay * exponential_base ** n, max_delay)`` seconds.

        Args:
            method: HTTP method passed to the underlying client.
            endpoint: Path appended directly to ``self.base_url``.
            payload: Request body, sent as JSON.
            stream: When true, return ``self.http_client.stream(...)`` without
                entering it. The return leaves this method, so errors raised
                while the caller consumes the stream are not retried here.

        Returns:
            The decoded JSON body (``response.json()``) for non-streaming
            requests, or the object returned by ``http_client.stream`` when
            ``stream`` is true.

        Raises:
            httpx.HTTPStatusError: If the status code is listed in
                ``non_retryable_status_codes``, or on the final attempt.
            httpx.ConnectError, httpx.TimeoutException: On the final attempt.
        """
        url = f"{self.base_url}{endpoint}"
        headers = self._build_headers()
        retry_cfg = self.config.retry
        rate_cfg = self.config.rate_limit

        # ``max_retries`` is used as the total number of attempts. Clamp it to
        # at least one: with a value <= 0 the loop below would never run, the
        # request would never be sent and the method would return None.
        total_attempts = max(1, retry_cfg.max_retries)
        final_attempt = total_attempts - 1

        semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)

        async with semaphore:
            await asyncio.sleep(rate_cfg.request_delay)

            for attempt in range(total_attempts):
                try:
                    if attempt > 0:
                        delay = min(
                            retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
                            retry_cfg.max_delay,
                        )
                        logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
                        await asyncio.sleep(delay)

                    if stream:
                        return self.http_client.stream(method, url, headers=headers, json=payload)

                    response = await self.http_client.request(method, url, headers=headers, json=payload)
                    response.raise_for_status()
                    return response.json()

                except httpx.HTTPStatusError as e:
                    # Configured non-retryable statuses fail immediately.
                    if e.response.status_code in retry_cfg.non_retryable_status_codes:
                        raise
                    if attempt == final_attempt:
                        raise
                except (httpx.ConnectError, httpx.TimeoutException):
                    if attempt == final_attempt:
                        raise

    @abstractmethod
    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a chat completion and return the result as a dict.

        Subclasses must implement this; the message, tool and response formats
        are defined by the concrete implementation.
        """

    @abstractmethod
    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> AsyncGenerator[str, None]:
        """Run a streaming chat completion, yielding string chunks.

        Subclasses must implement this.
        """
|
|
|
|
|
|
async def cleanup_all_clients():
    """Close every pooled HTTP client that is still open, then empty the pool.

    An exception raised by a client's ``aclose()`` propagates to the caller;
    in that case the pool is not cleared and the final log line is not emitted.
    """
    # Iterate over a snapshot so the pool dict is never iterated across an await.
    for pooled_client in list(_http_client_pool.values()):
        if pooled_client.is_closed:
            continue
        await pooled_client.aclose()

    _http_client_pool.clear()
    logger.info("✅ HTTP 客户端池已清理")