update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词

This commit is contained in:
xiamuceer
2025-12-28 19:35:23 +08:00
parent f32e51b594
commit 89848e2258
40 changed files with 2752 additions and 1824 deletions
@@ -0,0 +1,6 @@
"""AI 客户端模块"""
from .base_client import BaseAIClient
from .openai_client import OpenAIClient
from .anthropic_client import AnthropicClient
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
@@ -0,0 +1,86 @@
"""Anthropic 客户端"""
from typing import Any, AsyncGenerator, Dict, Optional
from anthropic import AsyncAnthropic
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
class AnthropicClient:
    """Async client for the Anthropic Messages API.

    Responses are normalised into a plain dict with ``content``,
    ``tool_calls`` and ``finish_reason`` keys so callers never touch SDK
    response objects.
    """

    # Generic tool-choice keyword -> Anthropic ``tool_choice`` type.
    _TOOL_CHOICE_TYPES = {"required": "any", "auto": "auto"}

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        """Create the underlying ``AsyncAnthropic`` SDK client.

        Args:
            api_key: Key passed to the SDK.
            base_url: Optional endpoint override; omitted from the SDK call
                when falsy.
            config: Client configuration; falls back to ``default_config``.
        """
        self.config = config or default_config
        client_kwargs = {"api_key": api_key}
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = AsyncAnthropic(**client_kwargs)

    def _build_request_kwargs(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments shared by both request styles.

        ``system`` is only included for a non-empty prompt.  ``tool_choice``
        is only forwarded together with ``tools`` and only for the keywords
        ``"required"`` and ``"auto"``; any other value is left out.
        """
        kwargs: Dict[str, Any] = {
            "model": model,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": messages,
        }
        if system_prompt:
            kwargs["system"] = system_prompt
        if tools:
            kwargs["tools"] = tools
            choice_type = (
                self._TOOL_CHOICE_TYPES.get(tool_choice)
                if isinstance(tool_choice, str)
                else None
            )
            if choice_type:
                kwargs["tool_choice"] = {"type": choice_type}
        return kwargs

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming completion.

        Returns:
            A dict with ``content`` (all text blocks concatenated),
            ``tool_calls`` (a list of function-call dicts whose ``arguments``
            is the block's ``input`` value, or ``None`` when there are no
            tool-use blocks) and ``finish_reason`` (the response's
            ``stop_reason``).
        """
        kwargs = self._build_request_kwargs(
            messages, model, temperature, max_tokens, system_prompt, tools, tool_choice
        )
        response = await self.client.messages.create(**kwargs)

        text_parts = []
        tool_calls = []
        for block in response.content:
            if block.type == "tool_use":
                tool_calls.append({
                    "id": block.id,
                    "type": "function",
                    "function": {"name": block.name, "arguments": block.input},
                })
            elif block.type == "text":
                text_parts.append(block.text)

        return {
            "content": "".join(text_parts),
            "tool_calls": tool_calls or None,
            "finish_reason": response.stop_reason,
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion, yielding text chunks as they arrive."""
        kwargs = self._build_request_kwargs(
            messages, model, temperature, max_tokens, system_prompt
        )
        async with self.client.messages.stream(**kwargs) as stream:
            async for text in stream.text_stream:
                yield text
@@ -0,0 +1,154 @@
"""AI 客户端基类"""
import asyncio
import hashlib
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, Optional
import httpx
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
logger = get_logger(__name__)
# 全局 HTTP 客户端池
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
_global_semaphore: Optional[asyncio.Semaphore] = None


def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
    """Return the module-wide request semaphore, creating it on first use.

    The semaphore is built once and then reused, so ``max_concurrent`` only
    takes effect on the call that creates it; values passed on later calls
    are ignored.

    Raises:
        ValueError: If the semaphore has to be created and ``max_concurrent``
            is below 1 (a zero-slot semaphore could never be acquired, so
            every request would wait forever).
    """
    global _global_semaphore
    if _global_semaphore is None:
        if max_concurrent < 1:
            raise ValueError(
                f"max_concurrent must be at least 1, got {max_concurrent}"
            )
        _global_semaphore = asyncio.Semaphore(max_concurrent)
    return _global_semaphore
class BaseAIClient(ABC):
    """Abstract base class for HTTP-based AI clients.

    Instances share ``httpx.AsyncClient`` objects through the module-level
    pool, keyed by concrete class name, base URL and a short hash of the API
    key.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str,
        config: Optional[AIClientConfig] = None,
    ):
        """Store credentials and attach a pooled HTTP client.

        Args:
            api_key: Key used by subclasses when building request headers.
            base_url: Endpoint root; trailing slashes are removed.
            config: Client configuration; falls back to ``default_config``.
        """
        self.api_key = api_key
        self.base_url = base_url.rstrip("/")
        self.config = config or default_config
        self.http_client = self._get_or_create_client()

    def _get_client_key(self) -> str:
        """Return the pool key for this client.

        Only the first 8 hex characters of the key's MD5 digest are used, so
        the raw API key never appears in the pool key (which is logged).
        """
        key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
        return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"

    def _get_or_create_client(self) -> httpx.AsyncClient:
        """Return the pooled HTTP client for this key, creating it if needed.

        A pooled client that has been closed is evicted and replaced.
        """
        client_key = self._get_client_key()
        cached = _http_client_pool.get(client_key)
        if cached is not None:
            if not cached.is_closed:
                return cached
            del _http_client_pool[client_key]

        http_cfg = self.config.http
        client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout,
            ),
            limits=httpx.Limits(
                max_keepalive_connections=http_cfg.max_keepalive_connections,
                max_connections=http_cfg.max_connections,
                keepalive_expiry=http_cfg.keepalive_expiry,
            ),
        )
        _http_client_pool[client_key] = client
        logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
        return client

    @abstractmethod
    def _build_headers(self) -> Dict[str, str]:
        """Build the HTTP headers sent with every request."""

    async def _request_with_retry(
        self,
        method: str,
        endpoint: str,
        payload: Dict[str, Any],
        stream: bool = False,
    ) -> Any:
        """Send a JSON request, retrying transient failures with backoff.

        The call runs under the shared semaphore and first sleeps for the
        configured ``request_delay``.  At least one attempt is always made,
        even if ``max_retries`` is configured as zero or negative.

        Args:
            method: HTTP method.
            endpoint: Path appended to ``base_url``.
            payload: JSON body.
            stream: When true, return the object produced by
                ``http_client.stream(...)`` immediately instead of a decoded
                body.  The caller enters it with ``async with``, so the retry
                loop and the semaphore do not cover the streamed request.

        Returns:
            The decoded JSON body, or the stream context manager.

        Raises:
            httpx.HTTPStatusError: For a non-retryable status code, or when
                the last attempt still fails with an error status.
            httpx.ConnectError, httpx.TimeoutException: When the last attempt
                still fails to connect or times out.
        """
        url = f"{self.base_url}{endpoint}"
        headers = self._build_headers()
        retry_cfg = self.config.retry
        rate_cfg = self.config.rate_limit
        semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
        # Guarantee one attempt; an empty loop would silently return None.
        max_attempts = max(1, retry_cfg.max_retries)

        async with semaphore:
            await asyncio.sleep(rate_cfg.request_delay)
            for attempt in range(max_attempts):
                is_last_attempt = attempt == max_attempts - 1
                if attempt > 0:
                    delay = min(
                        retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
                        retry_cfg.max_delay,
                    )
                    logger.warning(f"⚠️ 重试 {attempt + 1}/{max_attempts},等待 {delay}s")
                    await asyncio.sleep(delay)
                try:
                    if stream:
                        return self.http_client.stream(method, url, headers=headers, json=payload)
                    response = await self.http_client.request(method, url, headers=headers, json=payload)
                    response.raise_for_status()
                    return response.json()
                except httpx.HTTPStatusError as exc:
                    status = exc.response.status_code
                    if status in retry_cfg.non_retryable_status_codes or is_last_attempt:
                        raise
                    logger.warning(f"⚠️ 请求失败 (HTTP {status}),准备重试")
                except (httpx.ConnectError, httpx.TimeoutException) as exc:
                    if is_last_attempt:
                        raise
                    logger.warning(f"⚠️ 请求失败 ({type(exc).__name__}: {exc}),准备重试")

    @abstractmethod
    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming chat completion and return the result dict."""

    @abstractmethod
    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> AsyncGenerator[str, None]:
        """Run a streaming chat completion, yielding text chunks."""
async def cleanup_all_clients():
    """Close every pooled HTTP client and empty the pool.

    Closing is best-effort: a client whose ``aclose()`` fails is logged and
    skipped so the remaining clients are still closed and the pool is still
    cleared.
    """
    # Iterate over a snapshot so the pool can be cleared safely afterwards.
    for key, client in list(_http_client_pool.items()):
        if client.is_closed:
            continue
        try:
            await client.aclose()
        except Exception as exc:
            logger.warning(f"⚠️ 关闭 HTTP 客户端失败: {key}: {exc}")
    _http_client_pool.clear()
    logger.info("✅ HTTP 客户端池已清理")
@@ -0,0 +1,141 @@
"""Gemini 客户端"""
import json
from typing import Any, AsyncGenerator, Dict, List, Optional

import httpx

from app.services.ai_config import AIClientConfig, default_config
class GeminiClient:
    """Async client for the Google Gemini ``generateContent`` REST API.

    Accepts OpenAI-style messages and tools and returns results as a plain
    dict with ``content``, ``tool_calls`` and ``finish_reason`` keys.
    """

    # Generic tool-choice keyword -> Gemini ``functionCallingConfig`` mode.
    _TOOL_CHOICE_MODES = {"required": "ANY", "auto": "AUTO"}

    def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
        """Store credentials and create a dedicated ``httpx.AsyncClient``.

        Args:
            api_key: Gemini API key.
            base_url: Endpoint root; defaults to the public v1beta endpoint.
                Trailing slashes are removed.
            config: Client configuration; falls back to ``default_config``.
        """
        self.api_key = api_key
        self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
        self.config = config or default_config
        http_cfg = self.config.http
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(
                connect=http_cfg.connect_timeout,
                read=http_cfg.read_timeout,
                write=http_cfg.write_timeout,
                pool=http_cfg.pool_timeout
            )
        )

    async def aclose(self) -> None:
        """Close the HTTP client owned by this instance."""
        await self.client.aclose()

    def _convert_tools_to_gemini(self, tools: list) -> list:
        """Convert OpenAI-format tools into Gemini ``functionDeclarations``.

        Only entries of type ``"function"`` are converted.  The top-level
        ``$schema`` and ``additionalProperties`` keys are dropped from a copy
        of the parameter schema, and ``type`` defaults to ``"object"`` when a
        non-empty schema lacks one.  Returns ``[]`` when nothing converts.
        """
        declarations = []
        for tool in tools:
            if tool.get("type") != "function":
                continue
            func = tool["function"]
            # Work on a copy so the caller's schema is left untouched.
            params = dict(func.get("parameters") or {})
            params.pop("$schema", None)
            params.pop("additionalProperties", None)
            if params and "type" not in params:
                params["type"] = "object"
            decl = {
                "name": func["name"],
                "description": func.get("description") or func["name"],
            }
            if params:
                decl["parameters"] = params
            declarations.append(decl)
        return [{"functionDeclarations": declarations}] if declarations else []

    def _build_payload(
        self,
        messages: list,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Build the JSON request body shared by both request styles.

        Every message whose role is not ``"user"`` is sent with the
        ``"model"`` role.  ``tools`` is only included when at least one tool
        converts, and ``toolConfig`` only for the ``"required"`` / ``"auto"``
        tool-choice keywords.
        """
        contents = [
            {
                "role": "user" if msg["role"] == "user" else "model",
                "parts": [{"text": msg["content"]}],
            }
            for msg in messages
        ]
        payload: Dict[str, Any] = {
            "contents": contents,
            "generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens},
        }
        if system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
        if tools:
            gemini_tools = self._convert_tools_to_gemini(tools)
            if gemini_tools:
                payload["tools"] = gemini_tools
                mode = (
                    self._TOOL_CHOICE_MODES.get(tool_choice)
                    if isinstance(tool_choice, str)
                    else None
                )
                if mode:
                    payload["toolConfig"] = {"functionCallingConfig": {"mode": mode}}
        return payload

    @staticmethod
    def _parse_parts(parts: list) -> tuple:
        """Split response parts into ``(text, tool_calls)``.

        Text parts are concatenated in order.  Each ``functionCall`` part
        becomes an OpenAI-style tool call with id ``call_<name>``; repeated
        calls to the same function get a numeric suffix so ids stay unique
        within one response.
        """
        text_parts = []
        tool_calls = []
        seen_names: Dict[str, int] = {}
        for part in parts:
            if "text" in part:
                text_parts.append(part["text"])
            elif "functionCall" in part:
                call = part["functionCall"]
                name = call["name"]
                repeats = seen_names.get(name, 0)
                seen_names[name] = repeats + 1
                call_id = f"call_{name}" if repeats == 0 else f"call_{name}_{repeats}"
                tool_calls.append({
                    "id": call_id,
                    "type": "function",
                    "function": {"name": name, "arguments": call.get("args", {})},
                })
        return "".join(text_parts), tool_calls

    @staticmethod
    def _extract_stream_texts(line: str) -> List[str]:
        """Return the non-empty text pieces carried by one SSE line.

        Lines that are not ``data:`` events, are not valid JSON, or do not
        have the expected candidate/parts structure yield an empty list, so
        a single odd chunk does not abort the stream.
        """
        if not line.startswith("data:"):
            return []
        try:
            data = json.loads(line[5:])
            candidates = data.get("candidates") or []
            if not candidates:
                return []
            parts = candidates[0].get("content", {}).get("parts") or []
            return [part["text"] for part in parts if part.get("text")]
        except (ValueError, AttributeError, TypeError, KeyError, IndexError):
            return []

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming completion.

        Returns:
            A dict with ``content``, ``tool_calls`` (``None`` when there are
            none) and ``finish_reason`` (``"tool_calls"`` when any function
            call was returned, otherwise ``"stop"``).

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        # NOTE(review): the API key travels in the query string, so it can
        # surface wherever the request URL is logged or embedded in an error
        # message; consider sending it in a header instead.
        url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
        payload = self._build_payload(
            messages, temperature, max_tokens, system_prompt, tools, tool_choice
        )
        response = await self.client.post(url, json=payload)
        response.raise_for_status()
        data = response.json()

        candidates = data.get("candidates") or []
        if not candidates:
            # Return empty content instead of raising so the calling flow
            # can continue.
            return {
                "content": "",
                "tool_calls": None,
                "finish_reason": "stop"
            }

        parts = candidates[0].get("content", {}).get("parts", [])
        text, tool_calls = self._parse_parts(parts)
        return {
            "content": text,
            "tool_calls": tool_calls or None,
            "finish_reason": "tool_calls" if tool_calls else "stop"
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE, yielding text chunks as they arrive.

        Raises:
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
        payload = self._build_payload(messages, temperature, max_tokens, system_prompt)
        async with self.client.stream("POST", url, json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                # Parsing happens in the helper so that no exception handler
                # surrounds the yield and generator shutdown is never caught.
                for text in self._extract_stream_texts(line):
                    yield text
@@ -0,0 +1,101 @@
"""OpenAI 客户端"""
import json
from typing import Any, AsyncGenerator, Dict, Optional
from app.logger import get_logger
from .base_client import BaseAIClient
logger = get_logger(__name__)
class OpenAIClient(BaseAIClient):
    """Client for OpenAI-compatible ``/chat/completions`` endpoints."""

    def _build_headers(self) -> Dict[str, str]:
        """Return bearer-token and JSON content-type headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _strip_schema_key(tool: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``tool`` without ``$schema`` in its parameters.

        Both the tool dict and its nested ``function`` dict are copied before
        the parameters are replaced, so the caller's tool definition is never
        modified.
        """
        cleaned = dict(tool)
        function = cleaned.get("function")
        if isinstance(function, dict) and "parameters" in function:
            function = dict(function)
            function["parameters"] = {
                k: v for k, v in function["parameters"].items() if k != "$schema"
            }
            cleaned["function"] = function
        return cleaned

    def _build_payload(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the JSON request body.

        ``stream`` is only set when true.  ``tool_choice`` is only forwarded
        together with ``tools``.
        """
        payload: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if stream:
            payload["stream"] = True
        if tools:
            # Strip the $schema field from each tool's parameter schema.
            payload["tools"] = [self._strip_schema_key(tool) for tool in tools]
            if tool_choice:
                payload["tool_choice"] = tool_choice
        return payload

    async def chat_completion(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
        tools: Optional[list] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one non-streaming completion.

        Returns:
            A dict with ``content`` (``""`` when the API sends none),
            ``tool_calls`` and ``finish_reason`` taken from the first choice.

        Raises:
            ValueError: If the response contains no choices.
        """
        payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
        data = await self._request_with_retry("POST", "/chat/completions", payload)

        choices = data.get("choices", [])
        if not choices:
            raise ValueError("API 返回空 choices 或 choices 为空列表")

        choice = choices[0]
        message = choice.get("message", {})
        return {
            # The API sends null content alongside tool calls; normalise to
            # an empty string.
            "content": message.get("content") or "",
            "tool_calls": message.get("tool_calls"),
            "finish_reason": choice.get("finish_reason"),
        }

    async def chat_completion_stream(
        self,
        messages: list,
        model: str,
        temperature: float,
        max_tokens: int,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion over SSE, yielding content deltas.

        Stops at the ``[DONE]`` sentinel.  Lines that are not ``data:``
        events or do not decode to a JSON object are skipped.
        """
        payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
        async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                # The space after "data:" is optional in server-sent events.
                if not line.startswith("data:"):
                    continue
                data_str = line[5:].strip()
                if data_str == "[DONE]":
                    break
                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    continue
                choices = data.get("choices") or []
                if not choices:
                    continue
                content = (choices[0].get("delta") or {}).get("content") or ""
                if content:
                    yield content