Files
MuMuAINovel/backend/app/services/ai_service.py
T

296 lines
13 KiB
Python

"""AI服务封装 - 统一的AI接口"""
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
from app.config import settings as app_settings
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
from app.services.ai_clients.openai_client import OpenAIClient
from app.services.ai_clients.anthropic_client import AnthropicClient
from app.services.ai_clients.gemini_client import GeminiClient
from app.services.ai_clients.base_client import cleanup_all_clients
from app.services.ai_providers.openai_provider import OpenAIProvider
from app.services.ai_providers.anthropic_provider import AnthropicProvider
from app.services.ai_providers.gemini_provider import GeminiProvider
from app.services.ai_providers.base_provider import BaseAIProvider
from app.services.json_helper import clean_json_response, parse_json
from app.mcp.adapters.universal import universal_mcp_adapter
logger = get_logger(__name__)

# Re-export base_client.cleanup_all_clients under the name cleanup_http_clients
# so callers can trigger client cleanup through this module.
cleanup_http_clients = cleanup_all_clients
class AIService:
    """Unified AI service facade.

    Holds per-instance defaults (provider, model, temperature, max tokens,
    system prompt) and routes each call to the OpenAI, Anthropic or Gemini
    provider object built in ``__init__``.
    """

    def __init__(
        self,
        api_provider: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base_url: Optional[str] = None,
        default_model: Optional[str] = None,
        default_temperature: Optional[float] = None,
        default_max_tokens: Optional[int] = None,
        default_system_prompt: Optional[str] = None,
        enable_mcp_adapter: bool = True,
        config: Optional[AIClientConfig] = None,
    ):
        """Configure defaults and build the provider objects this instance can use.

        Args:
            api_provider: "openai", "anthropic" or "gemini"; when omitted the
                instance uses ``settings.default_ai_provider``.
            api_key: Key for the provider named by ``api_provider``. Only used
                when ``api_provider`` is passed explicitly; other providers use
                the keys from settings (Gemini has no settings fallback).
            api_base_url: Base URL for the provider named by ``api_provider``.
            default_model: Default model name; falls back to settings.
            default_temperature: Default temperature; ``None`` falls back to
                settings, while an explicit ``0.0`` is kept.
            default_max_tokens: Default token limit; a falsy value falls back
                to settings.
            default_system_prompt: System prompt used when a call gives none.
            enable_mcp_adapter: Attach ``universal_mcp_adapter`` when True.
            config: Client configuration; defaults to ``default_config``.
        """
        self.api_provider = api_provider or app_settings.default_ai_provider
        self.default_model = default_model or app_settings.default_model
        # `is None` instead of `or`: 0.0 is a legitimate temperature and must
        # not be silently replaced by the settings default.
        self.default_temperature = (
            app_settings.default_temperature if default_temperature is None else default_temperature
        )
        # A falsy max_tokens (None / 0) keeps falling back to settings, as before.
        self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
        self.default_system_prompt = default_system_prompt
        self.config = config or default_config
        self.mcp_adapter = universal_mcp_adapter if enable_mcp_adapter else None
        self._openai_provider: Optional[OpenAIProvider] = None
        self._anthropic_provider: Optional[AnthropicProvider] = None
        self._gemini_provider: Optional[GeminiProvider] = None

        # Note: these checks compare the raw `api_provider` argument, not
        # self.api_provider, so an explicit key is ignored unless the provider
        # was named explicitly as well.
        # OpenAI
        openai_key = api_key if api_provider == "openai" else app_settings.openai_api_key
        if openai_key:
            base_url = api_base_url if api_provider == "openai" else app_settings.openai_base_url
            client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
            self._openai_provider = OpenAIProvider(client)
        # Anthropic
        anthropic_key = api_key if api_provider == "anthropic" else app_settings.anthropic_api_key
        if anthropic_key:
            base_url = api_base_url if api_provider == "anthropic" else app_settings.anthropic_base_url
            client = AnthropicClient(anthropic_key, base_url, self.config)
            self._anthropic_provider = AnthropicProvider(client)
        # Gemini: only available with an explicit provider + key.
        if api_provider == "gemini" and api_key:
            client = GeminiClient(api_key, api_base_url, self.config)
            self._gemini_provider = GeminiProvider(client)

    def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
        """Return the provider object for ``provider`` (default: self.api_provider).

        Raises:
            ValueError: if that provider was not initialised in ``__init__``.
        """
        p = provider or self.api_provider
        if p == "openai" and self._openai_provider:
            return self._openai_provider
        if p == "anthropic" and self._anthropic_provider:
            return self._anthropic_provider
        if p == "gemini" and self._gemini_provider:
            return self._gemini_provider
        raise ValueError(f"Provider {p} 未初始化")

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        """Return ``temperature`` unless it is None, else the instance default.

        Using ``or`` here would replace a deliberate 0.0 with the default.
        """
        return self.default_temperature if temperature is None else temperature

    async def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Generate a completion via the selected provider.

        Unset arguments fall back to the instance defaults. Returns the
        provider's result dict; code in this class reads its "content",
        "finish_reason" and "tool_calls" keys.
        """
        prov = self._get_provider(provider)
        return await prov.generate(
            prompt=prompt,
            model=model or self.default_model,
            temperature=self._resolve_temperature(temperature),
            max_tokens=max_tokens or self.default_max_tokens,
            system_prompt=system_prompt or self.default_system_prompt,
            tools=tools,
            tool_choice=tool_choice,
        )

    async def generate_text_stream(
        self,
        prompt: str,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream a completion chunk by chunk from the selected provider.

        Unset arguments fall back to the instance defaults.
        """
        prov = self._get_provider(provider)
        async for chunk in prov.generate_stream(
            prompt=prompt,
            model=model or self.default_model,
            temperature=self._resolve_temperature(temperature),
            max_tokens=max_tokens or self.default_max_tokens,
            system_prompt=system_prompt or self.default_system_prompt,
        ):
            yield chunk

    async def call_with_json_retry(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_retries: int = 3,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        expected_type: Optional[str] = None,
    ) -> Union[Dict, List]:
        """Call the model and parse its reply as JSON, retrying on failure.

        Retries re-send the original prompt plus a hint containing the start of
        the previous (unparseable) reply.

        Args:
            expected_type: "object" or "array" to also enforce the top-level
                JSON type; any other value skips the check.

        Raises:
            ValueError: when every attempt fails (chained to the last parse
                error), or when ``max_retries`` < 1.
        """
        last_response = ""
        for attempt in range(1, max_retries + 1):
            current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)
            result = await self.generate_text(
                prompt=current_prompt,
                provider=provider,
                model=model,
                temperature=temperature,
                max_tokens=max_tokens,
                system_prompt=system_prompt,
            )
            # A provider may report content=None; treat it as an empty reply.
            last_response = result.get("content") or ""
            try:
                data = parse_json(last_response)
                if expected_type == "object" and not isinstance(data, dict):
                    raise ValueError("期望对象")
                if expected_type == "array" and not isinstance(data, list):
                    raise ValueError("期望数组")
                return data
            except Exception as e:  # parse_json's failure types are not visible here
                if attempt == max_retries:
                    raise ValueError(f"JSON 解析失败: {e}") from e
        # Only reachable when max_retries < 1 (the loop never ran).
        raise ValueError("JSON 调用失败")

    @staticmethod
    def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
        """Append a retry hint with the first 200 chars of the previous reply ``failed``."""
        return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."

    @staticmethod
    def _clean_json_response(text: str) -> str:
        """Clean a JSON response (delegates to json_helper.clean_json_response)."""
        return clean_json_response(text)

    async def generate_text_with_mcp(
        self,
        prompt: str,
        user_id: str,
        db_session,
        enable_mcp: bool = True,
        max_tool_rounds: int = 3,
        tool_choice: str = "auto",
        **kwargs
    ) -> Dict[str, Any]:
        """Generate text while letting the model call the user's enabled MCP tools.

        Each round sends the current prompt. When the model requests tool
        calls, they are executed via ``mcp_tool_service``, their markdown
        context is appended to the ORIGINAL prompt, and the tool list is
        dropped for later rounds. If the model answers with empty content after
        tools were used, the prompt is re-sent with an explicit demand for an
        answer. If the rounds run out before a final answer was requested, one
        last tool-free request is sent with the most recently prepared prompt.

        Args:
            prompt: User prompt.
            user_id: Used to look up and execute the user's MCP tools.
            db_session: Passed through to ``mcp_tool_service`` (type not
                visible here).
            enable_mcp: When False, no tools are offered to the model.
            max_tool_rounds: Maximum number of model requests inside the loop.
            tool_choice: Initial tool_choice passed to the model.
            **kwargs: Forwarded to ``generate_text`` (model, temperature, ...).

        Returns:
            Dict with "content", "tool_calls_made", "tools_used",
            "finish_reason" ("tool_error" if executing tools raised) and
            "mcp_enhanced" (True when tools were offered).
        """
        from app.services.mcp_tool_service import mcp_tool_service, MCPToolServiceError
        result = {"content": "", "tool_calls_made": 0, "tools_used": [], "finish_reason": "", "mcp_enhanced": False}
        tools = None
        if enable_mcp:
            try:
                tools = await mcp_tool_service.get_user_enabled_tools(user_id=user_id, db_session=db_session)
                if tools:
                    result["mcp_enhanced"] = True
            except MCPToolServiceError:
                # Tool discovery failure is non-fatal: continue without tools.
                tools = None
        original_prompt = prompt  # tool context is always appended to the untouched prompt
        for round_num in range(max_tool_rounds):
            logger.debug(f"🔄 MCP工具调用 - 第{round_num+1}/{max_tool_rounds}")
            logger.debug(f" prompt长度: {len(prompt)}, tools数量: {len(tools) if tools else 0}, tool_choice: {tool_choice}")
            ai_response = await self.generate_text(prompt=prompt, tools=tools, tool_choice=tool_choice, **kwargs)
            # Providers may return content=None (e.g. alongside tool calls);
            # normalise so len()/strip()/slicing below cannot raise TypeError.
            response_content = ai_response.get("content") or ""
            logger.debug(f" AI响应: finish_reason={ai_response.get('finish_reason')}, content长度={len(response_content)}")
            tool_calls = ai_response.get("tool_calls") or []
            if not tool_calls:
                result["content"] = response_content
                result["finish_reason"] = ai_response.get("finish_reason", "stop")
                logger.debug(f" ✅ 无工具调用,返回内容长度: {len(response_content)}")
                # Empty answer after tools were used: ask again, explicitly
                # demanding an answer and forbidding further tool use.
                if not response_content.strip() and result["tool_calls_made"] > 0:
                    logger.warning(f"⚠️ AI在工具调用后返回空内容,尝试强制要求回答(第{round_num+1}轮)")
                    prompt = f"{prompt}\n\n⚠️ 请注意:你必须基于以上工具查询结果,给出完整的回答。不要返回空内容。"
                    tools = None
                    tool_choice = "none"
                    continue
                break
            logger.info(f"🔧 检测到 {len(tool_calls)} 个工具调用")
            for idx, tc in enumerate(tool_calls):
                logger.debug(f" 工具{idx+1}: {tc.get('function', {}).get('name')} - 参数: {tc.get('function', {}).get('arguments')}")
            try:
                logger.debug(" 开始执行工具调用...")
                tool_results = await mcp_tool_service.execute_tool_calls(user_id=user_id, tool_calls=tool_calls, db_session=db_session)
                logger.debug(f" 工具执行完成,结果数量: {len(tool_results)}")
                for idx, tr in enumerate(tool_results):
                    success = tr.get("success", False)
                    content_preview = tr.get("content", "")[:200] if tr.get("content") else "None"
                    logger.debug(f" 工具结果[{idx}]: success={success}, content预览={content_preview}")
                for tc in tool_calls:
                    name = tc["function"]["name"]
                    if name not in result["tools_used"]:
                        result["tools_used"].append(name)
                result["tool_calls_made"] += len(tool_calls)
                tool_context = await mcp_tool_service.build_tool_context(tool_results, format="markdown")
                logger.debug(f" 工具上下文长度: {len(tool_context)}")
                logger.debug(f" 工具上下文预览: {tool_context[:300]}")
                if round_num == max_tool_rounds - 1:
                    # Last round: demand a final answer without further tool calls.
                    logger.info(f"⚠️ 最后一轮,强制要求AI给出最终答案")
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:这是最后一轮,请基于以上工具查询的参考资料,给出完整详细的最终答案。不要再调用工具。"
                    tool_choice = "none"
                else:
                    prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
                logger.debug(f" 新prompt长度: {len(prompt)}")
                tools = None  # disable tools after one execution round to avoid repeated calls
                logger.debug(f" ✅ 工具调用成功,准备下一轮")
            except Exception as tool_error:
                logger.error(f"❌ 工具调用执行失败: {tool_error}", exc_info=True)
                logger.error(f" 错误类型: {type(tool_error).__name__}")
                logger.error(f" AI响应内容: {response_content[:200]}")
                result["content"] = response_content
                result["finish_reason"] = "tool_error"
                break
        else:
            # The loop ran out of rounds without `break`: the prompt prepared
            # for the next round (tool results or the "answer now" nudge) was
            # never sent. Send it once more, tool-free, so tool output is not
            # silently discarded (e.g. max_tool_rounds=1 with a tool call).
            logger.info("⚠️ 工具调用轮次已用尽,补充一次最终回答请求")
            final_response = await self.generate_text(prompt=prompt, tools=None, tool_choice=tool_choice, **kwargs)
            result["content"] = final_response.get("content") or ""
            result["finish_reason"] = final_response.get("finish_reason", "stop")
        return result
# Module-wide default instance; every option comes from app settings
# because no arguments are passed.
ai_service: AIService = AIService()
def create_user_ai_service(
    api_provider: str,
    api_key: str,
    api_base_url: str,
    model_name: str,
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = None,
) -> AIService:
    """Create an AIService configured with one user's own provider settings.

    Provider, key and base URL are passed straight through; ``model_name``,
    ``temperature``, ``max_tokens`` and ``system_prompt`` become the service's
    per-call defaults. The MCP adapter and client config keep AIService's
    defaults.
    """
    service_options = {
        "api_provider": api_provider,
        "api_key": api_key,
        "api_base_url": api_base_url,
        "default_model": model_name,
        "default_temperature": temperature,
        "default_max_tokens": max_tokens,
        "default_system_prompt": system_prompt,
    }
    return AIService(**service_options)