update:1.优化 AI 流式生成和进度显示系统 2.新增写作风格系统提示词支持 3.灵感模式功能增强,支持灵感重写 4.设置页面功能扩展,新增Gemini适配器 5.提示词模板系统优化,调整灵感模式提示词
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
"""AI 客户端模块"""
|
||||
from .base_client import BaseAIClient
|
||||
from .openai_client import OpenAIClient
|
||||
from .anthropic_client import AnthropicClient
|
||||
|
||||
__all__ = ["BaseAIClient", "OpenAIClient", "AnthropicClient"]
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Anthropic 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient:
|
||||
"""Anthropic API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.config = config or default_config
|
||||
kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
self.client = AsyncAnthropic(**kwargs)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if tool_choice == "required":
|
||||
kwargs["tool_choice"] = {"type": "any"}
|
||||
elif tool_choice == "auto":
|
||||
kwargs["tool_choice"] = {"type": "auto"}
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
tool_calls = []
|
||||
content = ""
|
||||
for block in response.content:
|
||||
if block.type == "tool_use":
|
||||
tool_calls.append({
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
})
|
||||
elif block.type == "text":
|
||||
content += block.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": response.stop_reason,
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
kwargs = {
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"messages": messages,
|
||||
}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
async with self.client.messages.stream(**kwargs) as stream:
|
||||
async for text in stream.text_stream:
|
||||
yield text
|
||||
@@ -0,0 +1,154 @@
|
||||
"""AI 客户端基类"""
|
||||
import asyncio
|
||||
import hashlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 全局 HTTP 客户端池
|
||||
_http_client_pool: Dict[str, httpx.AsyncClient] = {}
|
||||
_global_semaphore: Optional[asyncio.Semaphore] = None
|
||||
|
||||
|
||||
def _get_semaphore(max_concurrent: int) -> asyncio.Semaphore:
|
||||
"""获取全局信号量"""
|
||||
global _global_semaphore
|
||||
if _global_semaphore is None:
|
||||
_global_semaphore = asyncio.Semaphore(max_concurrent)
|
||||
return _global_semaphore
|
||||
|
||||
|
||||
class BaseAIClient(ABC):
|
||||
"""AI HTTP 客户端基类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
config: Optional[AIClientConfig] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.config = config or default_config
|
||||
self.http_client = self._get_or_create_client()
|
||||
|
||||
def _get_client_key(self) -> str:
|
||||
"""生成客户端唯一键"""
|
||||
key_hash = hashlib.md5(self.api_key.encode()).hexdigest()[:8]
|
||||
return f"{self.__class__.__name__}_{self.base_url}_{key_hash}"
|
||||
|
||||
def _get_or_create_client(self) -> httpx.AsyncClient:
|
||||
"""获取或创建 HTTP 客户端"""
|
||||
client_key = self._get_client_key()
|
||||
|
||||
if client_key in _http_client_pool:
|
||||
client = _http_client_pool[client_key]
|
||||
if not client.is_closed:
|
||||
return client
|
||||
del _http_client_pool[client_key]
|
||||
|
||||
http_cfg = self.config.http
|
||||
client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout,
|
||||
),
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=http_cfg.max_keepalive_connections,
|
||||
max_connections=http_cfg.max_connections,
|
||||
keepalive_expiry=http_cfg.keepalive_expiry,
|
||||
),
|
||||
)
|
||||
_http_client_pool[client_key] = client
|
||||
logger.info(f"✅ 创建 HTTP 客户端: {client_key}")
|
||||
return client
|
||||
|
||||
@abstractmethod
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
pass
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
payload: Dict[str, Any],
|
||||
stream: bool = False,
|
||||
) -> Any:
|
||||
"""带重试的 HTTP 请求"""
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
headers = self._build_headers()
|
||||
retry_cfg = self.config.retry
|
||||
rate_cfg = self.config.rate_limit
|
||||
|
||||
semaphore = _get_semaphore(rate_cfg.max_concurrent_requests)
|
||||
|
||||
async with semaphore:
|
||||
await asyncio.sleep(rate_cfg.request_delay)
|
||||
|
||||
for attempt in range(retry_cfg.max_retries):
|
||||
try:
|
||||
if attempt > 0:
|
||||
delay = min(
|
||||
retry_cfg.base_delay * (retry_cfg.exponential_base ** attempt),
|
||||
retry_cfg.max_delay,
|
||||
)
|
||||
logger.warning(f"⚠️ 重试 {attempt + 1}/{retry_cfg.max_retries},等待 {delay}s")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if stream:
|
||||
return self.http_client.stream(method, url, headers=headers, json=payload)
|
||||
|
||||
response = await self.http_client.request(method, url, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code in retry_cfg.non_retryable_status_codes:
|
||||
raise
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
except (httpx.ConnectError, httpx.TimeoutException):
|
||||
if attempt == retry_cfg.max_retries - 1:
|
||||
raise
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天补全"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式聊天补全"""
|
||||
pass
|
||||
|
||||
|
||||
async def cleanup_all_clients():
|
||||
"""清理所有 HTTP 客户端"""
|
||||
for key, client in list(_http_client_pool.items()):
|
||||
if not client.is_closed:
|
||||
await client.aclose()
|
||||
_http_client_pool.clear()
|
||||
logger.info("✅ HTTP 客户端池已清理")
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Gemini 客户端"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import httpx
|
||||
from app.services.ai_config import AIClientConfig, default_config
|
||||
|
||||
|
||||
class GeminiClient:
|
||||
"""Google Gemini API 客户端"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None, config: Optional[AIClientConfig] = None):
|
||||
self.api_key = api_key
|
||||
self.base_url = (base_url or "https://generativelanguage.googleapis.com/v1beta").rstrip("/")
|
||||
self.config = config or default_config
|
||||
http_cfg = self.config.http
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(
|
||||
connect=http_cfg.connect_timeout,
|
||||
read=http_cfg.read_timeout,
|
||||
write=http_cfg.write_timeout,
|
||||
pool=http_cfg.pool_timeout
|
||||
)
|
||||
)
|
||||
|
||||
def _convert_tools_to_gemini(self, tools: list) -> list:
|
||||
"""将 OpenAI 格式工具转换为 Gemini 格式"""
|
||||
gemini_tools = []
|
||||
for tool in tools:
|
||||
if tool.get("type") == "function":
|
||||
func = tool["function"]
|
||||
params = func.get("parameters", {}).copy() if func.get("parameters") else {}
|
||||
params.pop("$schema", None)
|
||||
params.pop("additionalProperties", None)
|
||||
if params and "type" not in params:
|
||||
params["type"] = "object"
|
||||
decl = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description") or func["name"],
|
||||
}
|
||||
if params:
|
||||
decl["parameters"] = params
|
||||
gemini_tools.append(decl)
|
||||
return [{"functionDeclarations": gemini_tools}] if gemini_tools else []
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}/models/{model}:generateContent?key={self.api_key}"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
if tools:
|
||||
payload["tools"] = self._convert_tools_to_gemini(tools)
|
||||
|
||||
response = await self.client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
candidates = data.get("candidates", [])
|
||||
if not candidates or len(candidates) == 0:
|
||||
# 返回空内容而不是报错,保持流程继续
|
||||
return {
|
||||
"content": "",
|
||||
"tool_calls": None,
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
text = ""
|
||||
tool_calls = []
|
||||
|
||||
for part in parts:
|
||||
if "text" in part:
|
||||
text += part["text"]
|
||||
elif "functionCall" in part:
|
||||
fc = part["functionCall"]
|
||||
tool_calls.append({
|
||||
"id": f"call_{fc['name']}",
|
||||
"type": "function",
|
||||
"function": {"name": fc["name"], "arguments": fc.get("args", {})}
|
||||
})
|
||||
|
||||
return {
|
||||
"content": text,
|
||||
"tool_calls": tool_calls if tool_calls else None,
|
||||
"finish_reason": "tool_calls" if tool_calls else "stop"
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
url = f"{self.base_url}/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
|
||||
|
||||
contents = []
|
||||
for msg in messages:
|
||||
role = "user" if msg["role"] == "user" else "model"
|
||||
contents.append({"role": role, "parts": [{"text": msg["content"]}]})
|
||||
|
||||
payload = {
|
||||
"contents": contents,
|
||||
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}
|
||||
}
|
||||
if system_prompt:
|
||||
payload["systemInstruction"] = {"parts": [{"text": system_prompt}]}
|
||||
|
||||
async with self.client.stream("POST", url, json=payload) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
import json
|
||||
try:
|
||||
data = json.loads(line[6:])
|
||||
candidates = data.get("candidates", [])
|
||||
if candidates and len(candidates) > 0:
|
||||
parts = candidates[0].get("content", {}).get("parts", [])
|
||||
if parts and len(parts) > 0:
|
||||
text = parts[0].get("text", "")
|
||||
if text:
|
||||
yield text
|
||||
except:
|
||||
continue
|
||||
@@ -0,0 +1,101 @@
|
||||
"""OpenAI 客户端"""
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from app.logger import get_logger
|
||||
from .base_client import BaseAIClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient(BaseAIClient):
|
||||
"""OpenAI API 客户端"""
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _build_payload(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
stream: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
if stream:
|
||||
payload["stream"] = True
|
||||
if tools:
|
||||
# 清理 $schema 字段
|
||||
cleaned = []
|
||||
for t in tools:
|
||||
tc = t.copy()
|
||||
if "function" in tc and "parameters" in tc["function"]:
|
||||
tc["function"]["parameters"] = {
|
||||
k: v for k, v in tc["function"]["parameters"].items() if k != "$schema"
|
||||
}
|
||||
cleaned.append(tc)
|
||||
payload["tools"] = cleaned
|
||||
if tool_choice:
|
||||
payload["tool_choice"] = tool_choice
|
||||
return payload
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, tools, tool_choice)
|
||||
data = await self._request_with_retry("POST", "/chat/completions", payload)
|
||||
|
||||
choices = data.get("choices", [])
|
||||
if not choices or len(choices) == 0:
|
||||
raise ValueError("API 返回空 choices 或 choices 为空列表")
|
||||
|
||||
choice = choices[0]
|
||||
message = choice.get("message", {})
|
||||
return {
|
||||
"content": message.get("content", ""),
|
||||
"tool_calls": message.get("tool_calls"),
|
||||
"finish_reason": choice.get("finish_reason"),
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
payload = self._build_payload(messages, model, temperature, max_tokens, stream=True)
|
||||
|
||||
async with await self._request_with_retry("POST", "/chat/completions", payload, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
async for line in response.aiter_lines():
|
||||
if line.startswith("data: "):
|
||||
data_str = line[6:]
|
||||
if data_str.strip() == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
choices = data.get("choices", [])
|
||||
if choices and len(choices) > 0:
|
||||
content = choices[0].get("delta", {}).get("content", "")
|
||||
if content:
|
||||
yield content
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
@@ -0,0 +1,44 @@
|
||||
"""AI 服务配置管理"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class HTTPClientConfig:
|
||||
"""HTTP 客户端配置"""
|
||||
connect_timeout: float = 90.0
|
||||
read_timeout: float = 300.0
|
||||
write_timeout: float = 90.0
|
||||
pool_timeout: float = 90.0
|
||||
max_keepalive_connections: int = 50
|
||||
max_connections: int = 100
|
||||
keepalive_expiry: float = 60.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetryConfig:
|
||||
"""重试配置"""
|
||||
max_retries: int = 3
|
||||
base_delay: float = 0.2
|
||||
max_delay: float = 10.0
|
||||
exponential_base: int = 2
|
||||
non_retryable_status_codes: tuple = field(default_factory=lambda: (401, 403, 404))
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitConfig:
|
||||
"""限流配置"""
|
||||
max_concurrent_requests: int = 5
|
||||
request_delay: float = 0.2
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIClientConfig:
|
||||
"""AI 客户端完整配置"""
|
||||
http: HTTPClientConfig = field(default_factory=HTTPClientConfig)
|
||||
retry: RetryConfig = field(default_factory=RetryConfig)
|
||||
rate_limit: RateLimitConfig = field(default_factory=RateLimitConfig)
|
||||
|
||||
|
||||
# 全局默认配置
|
||||
default_config = AIClientConfig()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""AI Provider 模块"""
|
||||
from .base_provider import BaseAIProvider
|
||||
from .openai_provider import OpenAIProvider
|
||||
from .anthropic_provider import AnthropicProvider
|
||||
|
||||
__all__ = ["BaseAIProvider", "OpenAIProvider", "AnthropicProvider"]
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Anthropic Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.services.ai_clients.anthropic_client import AnthropicClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class AnthropicProvider(BaseAIProvider):
|
||||
"""Anthropic 提供商"""
|
||||
|
||||
def __init__(self, client: AnthropicClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,33 @@
|
||||
"""AI Provider 基类"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
|
||||
class BaseAIProvider(ABC):
|
||||
"""AI 提供商抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成文本"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""流式生成"""
|
||||
pass
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Gemini Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from app.services.ai_clients.gemini_client import GeminiClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class GeminiProvider(BaseAIProvider):
|
||||
def __init__(self, client: GeminiClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
):
|
||||
yield chunk
|
||||
@@ -0,0 +1,57 @@
|
||||
"""OpenAI Provider"""
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.services.ai_clients.openai_client import OpenAIClient
|
||||
from .base_provider import BaseAIProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseAIProvider):
|
||||
"""OpenAI 提供商"""
|
||||
|
||||
def __init__(self, client: OpenAIClient):
|
||||
self.client = client
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
tool_choice: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
return await self.client.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def generate_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
system_prompt: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
async for chunk in self.client.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
yield chunk
|
||||
+177
-1295
File diff suppressed because it is too large
Load Diff
@@ -263,7 +263,7 @@ class AutoCharacterService:
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
@@ -362,7 +362,7 @@ class AutoCharacterService:
|
||||
user_id=user_id,
|
||||
db_session=db,
|
||||
enable_mcp=True,
|
||||
max_tool_rounds=1
|
||||
max_tool_rounds=2
|
||||
)
|
||||
content = result.get("content", "")
|
||||
# 使用统一的JSON清洗方法
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"""JSON 处理工具类"""
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, List, Union
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def clean_json_response(text: str) -> str:
|
||||
"""清洗 AI 返回的 JSON(改进版 - 流式安全)"""
|
||||
try:
|
||||
if not text:
|
||||
logger.warning("⚠️ clean_json_response: 输入为空")
|
||||
return text
|
||||
|
||||
original_length = len(text)
|
||||
logger.debug(f"🔍 开始清洗JSON,原始长度: {original_length}")
|
||||
|
||||
# 去除 markdown 代码块
|
||||
text = re.sub(r'^```json\s*\n?', '', text, flags=re.MULTILINE | re.IGNORECASE)
|
||||
text = re.sub(r'^```\s*\n?', '', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'\n?```\s*$', '', text, flags=re.MULTILINE)
|
||||
text = text.strip()
|
||||
|
||||
if len(text) != original_length:
|
||||
logger.debug(f" 移除markdown后长度: {len(text)}")
|
||||
|
||||
# 尝试直接解析(快速路径)
|
||||
try:
|
||||
json.loads(text)
|
||||
logger.debug(f"✅ 直接解析成功,无需清洗")
|
||||
return text
|
||||
except:
|
||||
pass
|
||||
|
||||
# 找到第一个 { 或 [
|
||||
start = -1
|
||||
for i, c in enumerate(text):
|
||||
if c in ('{', '['):
|
||||
start = i
|
||||
break
|
||||
|
||||
if start == -1:
|
||||
logger.warning(f"⚠️ 未找到JSON起始符号 {{ 或 [")
|
||||
logger.debug(f" 文本预览: {text[:200]}")
|
||||
return text
|
||||
|
||||
if start > 0:
|
||||
logger.debug(f" 跳过前{start}个字符")
|
||||
text = text[start:]
|
||||
|
||||
# 改进的括号匹配算法(更严格的字符串处理)
|
||||
stack = []
|
||||
i = 0
|
||||
end = -1
|
||||
|
||||
while i < len(text):
|
||||
c = text[i]
|
||||
|
||||
# 处理字符串(关键:正确处理转义)
|
||||
if c == '"':
|
||||
# 计算前面有多少个连续的反斜杠
|
||||
num_backslashes = 0
|
||||
j = i - 1
|
||||
while j >= 0 and text[j] == '\\':
|
||||
num_backslashes += 1
|
||||
j -= 1
|
||||
|
||||
# 偶数个反斜杠(包括0)表示引号未被转义
|
||||
if num_backslashes % 2 == 0:
|
||||
# 这是字符串边界,跳过整个字符串
|
||||
i += 1
|
||||
while i < len(text):
|
||||
if text[i] == '"':
|
||||
# 再次检查转义
|
||||
num_backslashes = 0
|
||||
j = i - 1
|
||||
while j >= 0 and text[j] == '\\':
|
||||
num_backslashes += 1
|
||||
j -= 1
|
||||
if num_backslashes % 2 == 0:
|
||||
# 字符串结束
|
||||
break
|
||||
i += 1
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 处理括号(只有在字符串外部才有效)
|
||||
if c == '{' or c == '[':
|
||||
stack.append(c)
|
||||
elif c == '}':
|
||||
if len(stack) > 0 and stack[-1] == '{':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
else:
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 }} 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||
elif c == ']':
|
||||
if len(stack) > 0 and stack[-1] == '[':
|
||||
stack.pop()
|
||||
if len(stack) == 0:
|
||||
end = i + 1
|
||||
logger.debug(f"✅ 找到JSON结束位置: {end}")
|
||||
break
|
||||
else:
|
||||
logger.warning(f"⚠️ 括号不匹配:遇到 ] 但栈顶是 {stack[-1] if stack else 'empty'}")
|
||||
|
||||
i += 1
|
||||
|
||||
# 提取结果
|
||||
if end > 0:
|
||||
result = text[:end]
|
||||
logger.debug(f"✅ JSON清洗完成,结果长度: {len(result)}")
|
||||
else:
|
||||
result = text
|
||||
logger.warning(f"⚠️ 未找到JSON结束位置,返回全部内容(长度: {len(result)})")
|
||||
logger.debug(f" 栈状态: {stack}")
|
||||
|
||||
# 验证清洗后的结果
|
||||
try:
|
||||
json.loads(result)
|
||||
logger.debug(f"✅ 清洗后JSON验证成功")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 清洗后JSON仍然无效: {e}")
|
||||
logger.debug(f" 结果预览: {result[:500]}")
|
||||
logger.debug(f" 结果结尾: ...{result[-200:]}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ clean_json_response 出错: {e}")
|
||||
logger.error(f" 文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 文本预览: {text[:200] if text else 'None'}")
|
||||
raise
|
||||
|
||||
|
||||
def parse_json(text: str) -> Union[Dict, List]:
|
||||
"""解析 JSON"""
|
||||
try:
|
||||
cleaned = clean_json_response(text)
|
||||
return json.loads(cleaned)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ parse_json 出错: {e}")
|
||||
logger.error(f" 原始文本长度: {len(text) if text else 0}")
|
||||
logger.error(f" 清洗后文本长度: {len(cleaned) if cleaned else 0}")
|
||||
raise
|
||||
@@ -175,20 +175,34 @@ class MCPTestService:
|
||||
db=db_session
|
||||
)
|
||||
|
||||
ai_response = await ai_service.generate_text(
|
||||
# 注意: generate_text_stream 返回的是异步生成器,但在 tool_choice="required" 模式下
|
||||
# AI服务会直接返回包含 tool_calls 的完整响应,而不是流式chunks
|
||||
# 因此这里需要特殊处理
|
||||
accumulated_text = ""
|
||||
tool_calls = None
|
||||
|
||||
async for chunk in ai_service.generate_text_stream(
|
||||
prompt=prompts["user"],
|
||||
system_prompt=prompts["system"],
|
||||
tools=openai_tools,
|
||||
tool_choice="required"
|
||||
)
|
||||
):
|
||||
# 在 function calling 模式下,chunk 可能是字典格式包含 tool_calls
|
||||
if isinstance(chunk, dict):
|
||||
if "tool_calls" in chunk:
|
||||
tool_calls = chunk["tool_calls"]
|
||||
if "content" in chunk:
|
||||
accumulated_text += chunk.get("content", "")
|
||||
else:
|
||||
accumulated_text += chunk
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not ai_response.get("tool_calls"):
|
||||
if not tool_calls:
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
||||
error=f"AI未返回工具调用请求。响应: {accumulated_text[:200] if accumulated_text else 'N/A'}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
@@ -198,7 +212,7 @@ class MCPTestService:
|
||||
)
|
||||
|
||||
# 6. 解析工具调用
|
||||
tool_call = ai_response["tool_calls"][0]
|
||||
tool_call = tool_calls[0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
@@ -386,17 +386,30 @@ class MCPToolService:
|
||||
|
||||
try:
|
||||
# 解析插件名和工具名
|
||||
logger.debug(f"🔍 解析工具名称: {function_name}")
|
||||
if "_" in function_name:
|
||||
plugin_name, tool_name = function_name.split("_", 1)
|
||||
logger.debug(f" 插件: {plugin_name}, 工具: {tool_name}")
|
||||
else:
|
||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||
|
||||
# 解析参数
|
||||
arguments_str = tool_call["function"]["arguments"]
|
||||
logger.debug(f"🔍 解析参数:")
|
||||
logger.debug(f" 原始类型: {type(arguments_str)}")
|
||||
logger.debug(f" 原始内容: {arguments_str}")
|
||||
|
||||
if isinstance(arguments_str, str):
|
||||
arguments = json.loads(arguments_str)
|
||||
try:
|
||||
arguments = json.loads(arguments_str)
|
||||
logger.debug(f" ✅ JSON解析成功: {arguments}")
|
||||
except json.JSONDecodeError as je:
|
||||
logger.error(f" ❌ JSON解析失败: {je}")
|
||||
logger.error(f" 原始字符串: '{arguments_str}'")
|
||||
raise ValueError(f"参数JSON解析失败: {je}")
|
||||
else:
|
||||
arguments = arguments_str
|
||||
logger.debug(f" 直接使用dict类型参数")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {plugin_name}.{tool_name}, "
|
||||
|
||||
@@ -71,24 +71,15 @@ class PlotAnalyzer:
|
||||
# 调用AI进行分析
|
||||
# 注意:不指定max_tokens,使用用户在设置中配置的值
|
||||
logger.info(f" 调用AI分析(内容长度: {len(analysis_content)}字)...")
|
||||
response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
temperature=0.3 # 降低温度以获得更稳定的JSON输出
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 🔍 添加调试日志:查看AI返回的原始内容
|
||||
# logger.info(f"🔍 AI返回类型: {type(response)}")
|
||||
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
|
||||
|
||||
# 从返回的字典中提取content字段
|
||||
if isinstance(response, dict):
|
||||
response_text = response.get('content', '')
|
||||
if not response_text:
|
||||
logger.error("❌ AI返回的字典中没有content字段或content为空")
|
||||
return None
|
||||
else:
|
||||
# 兼容旧的字符串返回格式
|
||||
response_text = response
|
||||
# 提取内容
|
||||
response_text = accumulated_text
|
||||
|
||||
# 解析JSON结果
|
||||
analysis_result = self._parse_analysis_response(response_text)
|
||||
|
||||
@@ -133,14 +133,16 @@ class PlotExpansionService:
|
||||
|
||||
# 调用AI生成章节规划
|
||||
logger.info(f"调用AI生成章节规划...")
|
||||
ai_response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
chapter_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
@@ -236,14 +238,16 @@ class PlotExpansionService:
|
||||
|
||||
# 调用AI生成当前批次
|
||||
logger.info(f"调用AI生成第{batch_num + 1}批...")
|
||||
ai_response = await self.ai_service.generate_text(
|
||||
accumulated_text = ""
|
||||
async for chunk in self.ai_service.generate_text_stream(
|
||||
prompt=prompt,
|
||||
provider=provider,
|
||||
model=model
|
||||
)
|
||||
):
|
||||
accumulated_text += chunk
|
||||
|
||||
# 提取内容
|
||||
ai_content = ai_response.get("content", "") if isinstance(ai_response, dict) else ai_response
|
||||
ai_content = accumulated_text
|
||||
|
||||
# 解析AI响应
|
||||
batch_plans = self._parse_expansion_response(ai_content, outline.id)
|
||||
|
||||
@@ -6,142 +6,6 @@ import json
|
||||
class WritingStyleManager:
|
||||
"""写作风格管理器"""
|
||||
|
||||
# 预设风格配置
|
||||
PRESET_STYLES = {
|
||||
"natural": {
|
||||
"name": "自然沉浸 (Natural & Immersive)",
|
||||
"description": "祛除翻译腔,强调生活质感,像呼吸一样自然的叙事",
|
||||
"prompt_content": """
|
||||
### 核心指令:自然沉浸风格
|
||||
请模拟人类作家在放松状态下的写作,通过以下规则消除“AI味”:
|
||||
|
||||
1. **拒绝翻译腔与书面化**:
|
||||
- 严禁使用“一种...的感觉”、“随着...”、“与此同时”等连接词。
|
||||
- 多用短句和“流水句”,模拟人类视线的移动和思维的跳跃。
|
||||
- 口语化叙述,但不要滥用语气词,而是通过句子的长短节奏来体现语气。
|
||||
|
||||
2. **生活化的颗粒度**:
|
||||
- 描写不要宏大,要聚焦在具体的、微小的生活细节(如:杯子上的水渍、衣服的褶皱)。
|
||||
- 允许逻辑上的适度“松散”,不要让每句话都像说明书一样严丝合缝。
|
||||
|
||||
3. **具体的“展示”**:
|
||||
- 不要写“他很生气”,要写他“把烟头按灭在还没吃完的米饭里”。
|
||||
- 避免使用抽象的形容词(如:巨大的、美丽的、悲伤的),必须用名词和动词来承载画面。
|
||||
"""
|
||||
},
|
||||
"classical": {
|
||||
"name": "古典雅致 (Classical & Elegant)",
|
||||
"description": "白话文与古典韵味的结合,强调留白与炼字",
|
||||
"prompt_content": """
|
||||
### 核心指令:古典雅致风格
|
||||
请模仿民国时期或古典白话小说的笔触,构建端庄且富有余味的叙事:
|
||||
|
||||
1. **炼字与韵律**:
|
||||
- 尽量使用双音节词或四字短语,但严禁堆砌辞藻。
|
||||
- 注重句子的声调韵律,读起来要有金石之声或流水之韵。
|
||||
- 适当使用倒装句或定语后置,增加古雅感。
|
||||
|
||||
2. **克制的修辞**:
|
||||
- 少用现代的比喻(如“像机器一样”),多用取自自然的比喻(如“如风过林”)。
|
||||
- **意在言外**:不要把话说透,留三分余地。写景即是写情,不要将情感直接剖白。
|
||||
|
||||
3. **禁忌**:
|
||||
- 严禁使用现代科技词汇(除非题材需要)、网络用语或过于西化的句式(如长定语从句)。
|
||||
- 避免滥用“之乎者也”,追求的是“神似”而非生硬的半文半白。
|
||||
"""
|
||||
},
|
||||
"modern": {
|
||||
"name": "冷硬现代 (Modern & Hard-boiled)",
|
||||
"description": "海明威式的冰山理论,节奏极快,零度情感",
|
||||
"prompt_content": """
|
||||
### 核心指令:冷硬现代风格
|
||||
请采用“极简主义”和“零度写作”手法,去除所有矫饰:
|
||||
|
||||
1. **冰山理论**:
|
||||
- **只写动作和对话,完全剔除心理描写和形容词堆砌。**
|
||||
- 不要告诉读者角色感觉如何,通过角色的反应和环境的冷峻反馈来体现。
|
||||
|
||||
2. **电影蒙太奇节奏**:
|
||||
- 句子要短、脆、硬。像手术刀一样切开场景。
|
||||
- 段落之间快速切换,不要用过渡句连接,直接跳切。
|
||||
|
||||
3. **高信息密度**:
|
||||
- 删除所有废话。如果一个词删掉不影响理解,就删掉它。
|
||||
- 多用名词和强动词(Strong Verbs),少用副词(Adverbs)。例如:不要写“他重重地关上门”,写“他摔上了门”。
|
||||
"""
|
||||
},
|
||||
"poetic": {
|
||||
"name": "意识流 (Stream of Consciousness)",
|
||||
"description": "注重感官通感与内心独白,打破现实与幻想的边界",
|
||||
"prompt_content": """
|
||||
### 核心指令:意识流/诗意风格
|
||||
请侧重于主观感受的流动,而非客观事实的记录:
|
||||
|
||||
1. **通感与陌生化**:
|
||||
- 打通五感(如:听到了颜色的声音,闻到了悲伤的气味)。
|
||||
- 使用“陌生化”的语言,把熟悉的事物写得陌生,迫使读者重新审视。
|
||||
|
||||
2. **情绪的具象化**:
|
||||
- **绝对禁止**直接出现“开心”、“痛苦”等抽象词汇。
|
||||
- 必须寻找“客观对应物”(Objective Correlative),将情绪投射到具体的景物上(如:生锈的铁轨、发霉的橘子)。
|
||||
|
||||
3. **流动的句式**:
|
||||
- 句子可以很长,包含多重意象的叠加。
|
||||
- 允许思维的非线性跳跃,模拟梦境或深层潜意识的逻辑。
|
||||
"""
|
||||
},
|
||||
"concise": {
|
||||
"name": "白描速写 (Sketch & Concise)",
|
||||
"description": "只有骨架的叙事,强调绝对的精准和功能性",
|
||||
"prompt_content": """
|
||||
### 核心指令:白描速写风格
|
||||
请像速写画家一样,只勾勒线条,不涂抹色彩:
|
||||
|
||||
1. **功能性第一**:
|
||||
- 每一句话必须推动情节,或者揭示关键信息。
|
||||
- 如果一句话只是为了渲染气氛,删掉它。
|
||||
|
||||
2. **主谓宾结构**:
|
||||
- 尽量使用简单的主谓宾结构,减少修饰语。
|
||||
- 避免复杂的从句和嵌套结构。
|
||||
|
||||
3. **直击核心**:
|
||||
- 对话直接进入主题,去除寒暄和废话。
|
||||
- 环境描写仅限于对情节有物理影响的物体(如:挡路的石头、藏在桌下的枪)。
|
||||
"""
|
||||
},
|
||||
"vivid": {
|
||||
"name": "感官特写 (Sensory & Vivid)",
|
||||
"description": "高分辨率的描写,强调材质、光影和微观细节",
|
||||
"prompt_content": """
|
||||
### 核心指令:感官特写风格
|
||||
请将镜头推到特写级别(Macro Lens),捕捉常人忽略的细节:
|
||||
|
||||
1. **反套路细节**:
|
||||
- 不要写大众化的细节(如:蓝天白云),要写具有**独特性**的细节(如:云层边缘那抹像淤青一样的灰紫色)。
|
||||
- 关注物体的**质感(Texture)**:粗糙的、粘稠的、冰凉的、颗粒感的。
|
||||
|
||||
2. **动态捕捉**:
|
||||
- 不要写静止的画面,要写光影的流变、灰尘的飞舞、肌肉的抽动。
|
||||
- 让读者产生生理性的反应(如:痛感、饥饿感、窒息感)。
|
||||
|
||||
3. **禁用词汇**:
|
||||
- 禁止使用“映入眼帘”、“宛如画卷”等陈词滥调。
|
||||
- 必须用具体的动词带动感官描写。
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_preset_style(cls, preset_id: str) -> Optional[Dict[str, str]]:
|
||||
"""获取预设风格配置"""
|
||||
return cls.PRESET_STYLES.get(preset_id)
|
||||
|
||||
@classmethod
|
||||
def get_all_presets(cls) -> Dict[str, Dict[str, str]]:
|
||||
"""获取所有预设风格"""
|
||||
return cls.PRESET_STYLES
|
||||
|
||||
@staticmethod
|
||||
def apply_style_to_prompt(base_prompt: str, style_content: str) -> str:
|
||||
"""
|
||||
@@ -692,9 +556,8 @@ class PromptService:
|
||||
|
||||
6. **承上启下**:
|
||||
- 开头自然衔接上一章结尾(但不重复上一章内容)
|
||||
- 结尾为下一章做好铺垫
|
||||
|
||||
6. **记忆系统使用指南**:
|
||||
7. **记忆系统使用指南**:
|
||||
- **最近章节记忆**:保持情节连贯,注意角色状态和剧情发展
|
||||
- **语义相关记忆**:参考相似情节的处理方式
|
||||
- **未完结伏笔**:适当时机可以回收伏笔,制造呼应效果
|
||||
@@ -1308,16 +1171,15 @@ class PromptService:
|
||||
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
|
||||
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
|
||||
|
||||
# 灵感模式提示词字典
|
||||
INSPIRATION_PROMPTS = {
|
||||
"title": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
# 灵感模式 - 书名生成(系统提示词)
|
||||
INSPIRATION_TITLE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
|
||||
请根据用户的想法,生成6个吸引人的书名建议,要求:
|
||||
1. 紧扣用户的原始想法和核心故事构思
|
||||
2. 富有创意和吸引力
|
||||
3. 涵盖不同的风格倾向
|
||||
4. 书名中不要带有"《》"符号
|
||||
|
||||
返回JSON格式:
|
||||
{{
|
||||
@@ -1325,11 +1187,13 @@ class PromptService:
|
||||
"options": ["书名1", "书名2", "书名3", "书名4", "书名5", "书名6"]
|
||||
}}
|
||||
|
||||
只返回纯JSON,不要有其他文字。""",
|
||||
"user": "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||
},
|
||||
"description": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字。"""
|
||||
|
||||
# 灵感模式 - 书名生成(用户提示词)
|
||||
INSPIRATION_TITLE_USER = "用户的想法:{initial_idea}\n请生成6个书名建议"
|
||||
|
||||
# 灵感模式 - 简介生成(系统提示词)
|
||||
INSPIRATION_DESCRIPTION_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
已确定的书名:{title}
|
||||
|
||||
@@ -1343,11 +1207,13 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择一个简介:","options":["简介1","简介2","简介3","简介4","简介5","简介6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||
},
|
||||
"theme": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||
|
||||
# 灵感模式 - 简介生成(用户提示词)
|
||||
INSPIRATION_DESCRIPTION_USER = "原始想法:{initial_idea}\n书名:{title}\n请生成6个简介选项"
|
||||
|
||||
# 灵感模式 - 主题生成(系统提示词)
|
||||
INSPIRATION_THEME_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
@@ -1363,11 +1229,13 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"这本书的核心主题是什么?","options":["主题1","主题2","主题3","主题4","主题5","主题6"]}}
|
||||
|
||||
只返回纯JSON,不要有其他文字,不要换行。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||
},
|
||||
"genre": {
|
||||
"system": """你是一位专业的小说创作顾问。
|
||||
只返回纯JSON,不要有其他文字,不要换行。"""
|
||||
|
||||
# 灵感模式 - 主题生成(用户提示词)
|
||||
INSPIRATION_THEME_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n请生成6个主题选项"
|
||||
|
||||
# 灵感模式 - 类型生成(系统提示词)
|
||||
INSPIRATION_GENRE_SYSTEM = """你是一位专业的小说创作顾问。
|
||||
用户的原始想法:{initial_idea}
|
||||
小说信息:
|
||||
- 书名:{title}
|
||||
@@ -1384,10 +1252,10 @@ class PromptService:
|
||||
返回JSON格式:
|
||||
{{"prompt":"选择类型标签(可多选):","options":["类型1","类型2","类型3","类型4","类型5","类型6"]}}
|
||||
|
||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。""",
|
||||
"user": "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||
}
|
||||
}
|
||||
只返回紧凑的纯JSON,不要换行,不要有其他文字。"""
|
||||
|
||||
# 灵感模式 - 类型生成(用户提示词)
|
||||
INSPIRATION_GENRE_USER = "原始想法:{initial_idea}\n书名:{title}\n简介:{description}\n主题:{theme}\n请生成6个类型标签"
|
||||
|
||||
# 灵感模式智能补全提示词
|
||||
INSPIRATION_QUICK_COMPLETE = """你是一位专业的小说创作顾问。用户提供了部分小说信息,请补全缺失的字段。
|
||||
@@ -1887,7 +1755,26 @@ class PromptService:
|
||||
@classmethod
|
||||
def get_inspiration_prompt(cls, step: str) -> Optional[Dict[str, str]]:
|
||||
"""获取灵感模式指定步骤的提示词"""
|
||||
return cls.INSPIRATION_PROMPTS.get(step)
|
||||
# 根据步骤名称返回对应的system和user提示词
|
||||
step_map = {
|
||||
"title": {
|
||||
"system": cls.INSPIRATION_TITLE_SYSTEM,
|
||||
"user": cls.INSPIRATION_TITLE_USER
|
||||
},
|
||||
"description": {
|
||||
"system": cls.INSPIRATION_DESCRIPTION_SYSTEM,
|
||||
"user": cls.INSPIRATION_DESCRIPTION_USER
|
||||
},
|
||||
"theme": {
|
||||
"system": cls.INSPIRATION_THEME_SYSTEM,
|
||||
"user": cls.INSPIRATION_THEME_USER
|
||||
},
|
||||
"genre": {
|
||||
"system": cls.INSPIRATION_GENRE_SYSTEM,
|
||||
"user": cls.INSPIRATION_GENRE_USER
|
||||
}
|
||||
}
|
||||
return step_map.get(step)
|
||||
|
||||
@classmethod
|
||||
def get_inspiration_quick_complete_prompt(cls, existing: str) -> Dict[str, str]:
|
||||
@@ -1997,17 +1884,12 @@ class PromptService:
|
||||
# 2. 降级到系统默认模板
|
||||
logger.info(f"⚪ 使用系统默认提示词: user_id={user_id}, template_key={template_key} (未找到自定义模板)")
|
||||
|
||||
# 特殊处理灵感模式的提示词(存储在INSPIRATION_PROMPTS字典中)
|
||||
# 特殊处理灵感模式的提示词(直接从类属性获取)
|
||||
if template_key.startswith("INSPIRATION_"):
|
||||
# 提取步骤名称(如 INSPIRATION_TITLE -> title)
|
||||
step = template_key.replace("INSPIRATION_", "").lower()
|
||||
inspiration_prompt = cls.INSPIRATION_PROMPTS.get(step)
|
||||
if inspiration_prompt:
|
||||
# 返回JSON格式的提示词
|
||||
return json.dumps(inspiration_prompt, ensure_ascii=False)
|
||||
# 如果是INSPIRATION_QUICK_COMPLETE
|
||||
if template_key == "INSPIRATION_QUICK_COMPLETE":
|
||||
return cls.INSPIRATION_QUICK_COMPLETE
|
||||
# 直接从类属性获取
|
||||
template_content = getattr(cls, template_key, None)
|
||||
if template_content:
|
||||
return template_content
|
||||
|
||||
# 其他模板直接从类属性获取
|
||||
template_content = getattr(cls, template_key, None)
|
||||
@@ -2182,6 +2064,60 @@ class PromptService:
|
||||
"category": "世界构建",
|
||||
"description": "根据世界观自动生成完整的职业体系,包括主职业和副职业",
|
||||
"parameters": ["title", "genre", "theme", "time_period", "location", "atmosphere", "rules"]
|
||||
},
|
||||
"INSPIRATION_TITLE_SYSTEM": {
|
||||
"name": "灵感模式-书名生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户的原始想法生成6个书名建议的系统提示词",
|
||||
"parameters": ["initial_idea"]
|
||||
},
|
||||
"INSPIRATION_TITLE_USER": {
|
||||
"name": "灵感模式-书名生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户的原始想法生成6个书名建议的用户提示词",
|
||||
"parameters": ["initial_idea"]
|
||||
},
|
||||
"INSPIRATION_DESCRIPTION_SYSTEM": {
|
||||
"name": "灵感模式-简介生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户想法和书名生成6个简介选项的系统提示词",
|
||||
"parameters": ["initial_idea", "title"]
|
||||
},
|
||||
"INSPIRATION_DESCRIPTION_USER": {
|
||||
"name": "灵感模式-简介生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户想法和书名生成6个简介选项的用户提示词",
|
||||
"parameters": ["initial_idea", "title"]
|
||||
},
|
||||
"INSPIRATION_THEME_SYSTEM": {
|
||||
"name": "灵感模式-主题生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据书名和简介生成6个深刻的主题选项的系统提示词",
|
||||
"parameters": ["initial_idea", "title", "description"]
|
||||
},
|
||||
"INSPIRATION_THEME_USER": {
|
||||
"name": "灵感模式-主题生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据书名和简介生成6个深刻的主题选项的用户提示词",
|
||||
"parameters": ["initial_idea", "title", "description"]
|
||||
},
|
||||
"INSPIRATION_GENRE_SYSTEM": {
|
||||
"name": "灵感模式-类型生成(系统提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据小说信息生成6个合适的类型标签的系统提示词",
|
||||
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||
},
|
||||
"INSPIRATION_GENRE_USER": {
|
||||
"name": "灵感模式-类型生成(用户提示词)",
|
||||
"category": "灵感模式",
|
||||
"description": "根据小说信息生成6个合适的类型标签的用户提示词",
|
||||
"parameters": ["initial_idea", "title", "description", "theme"]
|
||||
},
|
||||
"INSPIRATION_QUICK_COMPLETE": {
|
||||
"name": "灵感模式-智能补全",
|
||||
"category": "灵感模式",
|
||||
"description": "根据用户提供的部分信息智能补全完整的小说方案",
|
||||
"parameters": ["existing"]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user