Files
MuMuAINovel/backend/app/services/ai_service.py
T
2026-03-17 17:31:08 +08:00

582 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""AI服务封装 - 统一的AI接口
重构后支持自动MCP工具加载:
- 所有AI方法在请求前自动检查用户MCP配置
- 如果有启用的MCP插件且有可用工具,自动发送tools
- 通过 auto_mcp 参数控制是否启用自动工具加载
"""
from typing import Optional, AsyncGenerator, List, Dict, Any, Union
from app.config import settings as app_settings
from app.logger import get_logger
from app.services.ai_config import AIClientConfig, default_config
from app.services.ai_clients.openai_client import OpenAIClient
from app.services.ai_clients.anthropic_client import AnthropicClient
from app.services.ai_clients.gemini_client import GeminiClient
from app.services.ai_clients.base_client import cleanup_all_clients
from app.services.ai_providers.openai_provider import OpenAIProvider
from app.services.ai_providers.anthropic_provider import AnthropicProvider
from app.services.ai_providers.gemini_provider import GeminiProvider
from app.services.ai_providers.base_provider import BaseAIProvider
from app.services.json_helper import clean_json_response, parse_json
# 导出清理函数
cleanup_http_clients = cleanup_all_clients
logger = get_logger(__name__)
def normalize_provider(provider: Optional[str]) -> Optional[str]:
"""标准化 provider 名称,兼容渠道别名。"""
if provider == "mumu":
return "openai"
return provider
class AIService:
"""
AI服务统一接口
MCP工具支持:
- 在创建服务时传入 user_id 和 db_session
- 根据用户MCP插件的enabled状态自动决定是否启用MCP
- 如果有任意一个MCP插件启用,则加载并使用工具
- 如果所有插件都关闭,则不使用任何MCP工具
- 通过 auto_mcp=False 可临时禁用自动工具加载
- 通过 mcp_max_rounds 控制工具调用轮数
- 通过 clear_mcp_cache() 可清理MCP工具缓存
MCP启用逻辑(backend/app/api/settings.py 中的 get_user_ai_service):
- 查询用户的所有MCP插件
- 如果有启用的插件 (enabled=True),则 enable_mcp=True
- 如果所有插件都关闭或没有插件,则 enable_mcp=False
使用示例:
# 创建支持MCP的AI服务(根据插件状态自动决定是否启用)
ai_service = create_user_ai_service_with_mcp(
api_provider="openai",
api_key="...",
user_id="user123",
db_session=db
)
# 自动加载MCP工具(如果有启用的插件)
result = await ai_service.generate_text(prompt="...")
# 临时禁用MCP工具
result = await ai_service.generate_text(prompt="...", auto_mcp=False)
# 自定义轮数
result = await ai_service.generate_text(prompt="...", mcp_max_rounds=3)
"""
def __init__(
self,
api_provider: Optional[str] = None,
api_key: Optional[str] = None,
api_base_url: Optional[str] = None,
default_model: Optional[str] = None,
default_temperature: Optional[float] = None,
default_max_tokens: Optional[int] = None,
default_system_prompt: Optional[str] = None,
config: Optional[AIClientConfig] = None,
# MCP支持参数
user_id: Optional[str] = None,
db_session: Optional[Any] = None,
enable_mcp: bool = True,
):
self.api_provider = normalize_provider(api_provider or app_settings.default_ai_provider)
self.default_model = default_model or app_settings.default_model
self.default_temperature = default_temperature or app_settings.default_temperature
self.default_max_tokens = default_max_tokens or app_settings.default_max_tokens
self.default_system_prompt = default_system_prompt
self.config = config or default_config
# MCP配置
self.user_id = user_id
self.db_session = db_session
self._enable_mcp = enable_mcp
self._cached_tools: Optional[List[Dict]] = None
self._tools_loaded = False
self._openai_provider: Optional[OpenAIProvider] = None
self._anthropic_provider: Optional[AnthropicProvider] = None
self._gemini_provider: Optional[GeminiProvider] = None
# 初始化 OpenAI
openai_key = api_key if self.api_provider == "openai" else app_settings.openai_api_key
if openai_key:
base_url = api_base_url if self.api_provider == "openai" else app_settings.openai_base_url
client = OpenAIClient(openai_key, base_url or "https://api.openai.com/v1", self.config)
self._openai_provider = OpenAIProvider(client)
# 初始化 Anthropic
anthropic_key = api_key if self.api_provider == "anthropic" else app_settings.anthropic_api_key
if anthropic_key:
base_url = api_base_url if self.api_provider == "anthropic" else app_settings.anthropic_base_url
client = AnthropicClient(anthropic_key, base_url, self.config)
self._anthropic_provider = AnthropicProvider(client)
# 初始化 Gemini
if self.api_provider == "gemini" and api_key:
client = GeminiClient(api_key, api_base_url, self.config)
self._gemini_provider = GeminiProvider(client)
@property
def enable_mcp(self) -> bool:
"""是否启用MCP工具"""
return self._enable_mcp
@enable_mcp.setter
def enable_mcp(self, value: bool):
"""设置MCP启用状态,如果禁用则清理缓存"""
if value is False and self._enable_mcp is True:
# 从启用变为禁用,清理缓存
self.clear_mcp_cache()
self._enable_mcp = value
def clear_mcp_cache(self):
"""
清理MCP工具缓存
当禁用MCP时调用此方法,确保后续AI调用不会使用缓存的工具。
同时更新 _tools_loaded 状态,使下次调用时重新检查。
"""
if self._cached_tools is not None:
logger.info(f"🔧 清理MCP工具缓存,移除 {len(self._cached_tools)} 个工具")
self._cached_tools = None
else:
logger.debug(f"🔧 MCP工具缓存已经是空,无需清理")
# 更新加载状态,确保下次调用会重新检查
self._tools_loaded = False
logger.debug(f"🔧 MCP工具状态已重置: enable_mcp={self._enable_mcp}, _tools_loaded=False")
def _get_provider(self, provider: Optional[str] = None) -> BaseAIProvider:
"""获取对应的 Provider"""
p = normalize_provider(provider or self.api_provider)
if p == "openai" and self._openai_provider:
return self._openai_provider
if p == "anthropic" and self._anthropic_provider:
return self._anthropic_provider
if p == "gemini" and self._gemini_provider:
return self._gemini_provider
raise ValueError(f"Provider {p} 未初始化")
async def _prepare_mcp_tools(self, auto_mcp: bool = True, force_refresh: bool = False) -> Optional[List[Dict]]:
"""
预处理MCP工具
检查用户MCP配置并加载可用工具。
结果会被缓存,避免重复加载。
Args:
auto_mcp: 是否自动加载MCP工具(来自调用方参数)
force_refresh: 是否强制刷新缓存
Returns:
- None: 无可用工具(未配置/未启用/加载失败)
- List[Dict]: OpenAI格式的工具列表
"""
# 前置条件检查
if not self._enable_mcp:
logger.debug(f"🔧 MCP工具未启用 (_enable_mcp=False)")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not auto_mcp:
logger.debug(f"🔧 auto_mcp=False,跳过MCP工具加载")
# 即使有缓存也清理掉,确保不使用
self._cached_tools = None
self._tools_loaded = False
return None
if not self.user_id:
logger.debug(f"🔧 MCP工具加载跳过: user_id未设置")
return None
if not self.db_session:
logger.debug(f"🔧 MCP工具加载跳过: db_session未设置")
return None
# 使用缓存(只有 enable_mcp=True 时才使用缓存)
if self._tools_loaded and not force_refresh:
if self._cached_tools:
logger.debug(f"🔧 使用缓存的MCP工具 ({len(self._cached_tools)}个)")
return self._cached_tools
try:
from app.services.mcp_tools_loader import mcp_tools_loader
self._cached_tools = await mcp_tools_loader.get_user_tools(
user_id=self.user_id,
db_session=self.db_session,
use_cache=True,
force_refresh=force_refresh
)
self._tools_loaded = True
if self._cached_tools:
logger.info(f"🔧 已加载 {len(self._cached_tools)} 个MCP工具")
else:
logger.debug(f"📭 用户 {self.user_id} 没有可用的MCP工具")
return self._cached_tools
except Exception as e:
logger.warning(f"⚠️ 加载MCP工具失败: {e}")
self._tools_loaded = True
self._cached_tools = None
return None
async def _handle_tool_calls(
self,
original_prompt: str,
response: Dict[str, Any],
max_rounds: int = 2,
**kwargs
) -> Dict[str, Any]:
"""
处理AI返回的工具调用
Args:
original_prompt: 原始提示词
response: AI响应(包含tool_calls
max_rounds: 最大工具调用轮数
**kwargs: 传递给generate_text的其他参数
Returns:
最终的AI响应
"""
from app.mcp import mcp_client
tool_calls = response.get("tool_calls", [])
if not tool_calls or not self.user_id:
return response
result = {
"content": response.get("content", ""),
"tool_calls_made": 0,
"tools_used": [],
"finish_reason": response.get("finish_reason", ""),
"mcp_enhanced": True
}
prompt = original_prompt
for round_num in range(max_rounds):
logger.info(f"🔧 工具调用 - 第{round_num+1}/{max_rounds}轮,{len(tool_calls)}个工具")
try:
# 批量执行工具调用
tool_results = await mcp_client.batch_call_tools(
user_id=self.user_id,
tool_calls=tool_calls
)
# 记录使用的工具
for tc in tool_calls:
name = tc["function"]["name"]
if name not in result["tools_used"]:
result["tools_used"].append(name)
result["tool_calls_made"] += len(tool_calls)
# 构建工具上下文
tool_context = mcp_client.build_tool_context(tool_results, format="markdown")
# 更新提示词
if round_num == max_rounds - 1:
# 最后一轮,强制要求回答
prompt = f"{original_prompt}\n\n{tool_context}\n\n⚠️ 重要:请基于以上工具查询结果,给出完整详细的最终答案。不要再调用工具。"
tool_choice = "none"
else:
prompt = f"{original_prompt}\n\n{tool_context}\n\n请基于以上工具查询结果,继续完成任务。"
tool_choice = kwargs.get("tool_choice", "auto")
# 继续调用AI
prov = self._get_provider(kwargs.get("provider"))
next_response = await prov.generate(
prompt=prompt,
model=kwargs.get("model") or self.default_model,
temperature=kwargs.get("temperature") or self.default_temperature,
max_tokens=kwargs.get("max_tokens") or self.default_max_tokens,
system_prompt=kwargs.get("system_prompt") or self.default_system_prompt,
tools=None if tool_choice == "none" else self._cached_tools,
tool_choice=tool_choice,
)
tool_calls = next_response.get("tool_calls", [])
if not tool_calls:
# 没有更多工具调用,返回结果
result["content"] = next_response.get("content", "")
result["finish_reason"] = next_response.get("finish_reason", "stop")
break
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
result["content"] = response.get("content", "")
result["finish_reason"] = "tool_error"
break
return result
async def generate_text(
self,
prompt: str,
provider: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
tools: Optional[List[Dict]] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
handle_tool_calls: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> Dict[str, Any]:
"""
生成文本(自动支持MCP工具)
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tools: 手动指定的工具列表(优先级高于自动加载)
tool_choice: 工具选择策略
auto_mcp: 是否自动加载MCP工具(默认True)
handle_tool_calls: 是否自动处理工具调用(默认True)
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
Returns:
包含生成内容的字典
"""
# 使用全局配置的MCP轮数(如果未指定)
if mcp_max_rounds is None:
mcp_max_rounds = app_settings.mcp_max_rounds
# 自动加载MCP工具
if auto_mcp and tools is None:
tools = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
prov = self._get_provider(provider)
response = await prov.generate(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools,
tool_choice=tool_choice,
)
# 处理工具调用
if handle_tool_calls and response.get("tool_calls"):
return await self._handle_tool_calls(
original_prompt=prompt,
response=response,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
tool_choice=tool_choice,
max_rounds=mcp_max_rounds,
)
return response
async def generate_text_stream(
self,
prompt: str,
provider: Optional[str] = None,
model: Optional[str] = None,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
tool_choice: Optional[str] = None,
auto_mcp: bool = True,
mcp_max_rounds: Optional[int] = None,
) -> AsyncGenerator[str, None]:
"""
流式生成文本(自动支持MCP工具)
工具调用在 Provider 层通过流式方式处理,支持真正的流式工具调用。
Args:
prompt: 用户提示词
provider: AI提供商
model: 模型名称
temperature: 温度
max_tokens: 最大令牌数
system_prompt: 系统提示词
tool_choice: 工具选择策略("auto"/"none"/"required"
auto_mcp: 是否自动加载MCP工具
mcp_max_rounds: 最大工具调用轮数(None使用默认值3)
Yields:
生成的文本块
"""
logger.debug(f"🔧 generate_text_stream: auto_mcp={auto_mcp}, tool_choice={tool_choice}")
tools_to_use = None
# 加载MCP工具
if auto_mcp:
tools_to_use = await self._prepare_mcp_tools(auto_mcp=auto_mcp)
if tools_to_use:
logger.info(f"🔧 已获取 {len(tools_to_use)} 个MCP工具")
# 流式生成(Provider 层处理工具调用)
prov = self._get_provider(provider)
logger.debug(f"🔧 开始流式生成,provider={provider or self.api_provider}, tools_count={len(tools_to_use) if tools_to_use else 0}")
async for chunk in prov.generate_stream(
prompt=prompt,
model=model or self.default_model,
temperature=temperature or self.default_temperature,
max_tokens=max_tokens or self.default_max_tokens,
system_prompt=system_prompt or self.default_system_prompt,
tools=tools_to_use,
tool_choice=tool_choice,
user_id=self.user_id,
):
yield chunk
async def call_with_json_retry(
self,
prompt: str,
system_prompt: Optional[str] = None,
max_retries: int = 3,
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
provider: Optional[str] = None,
model: Optional[str] = None,
expected_type: Optional[str] = None,
auto_mcp: bool = True,
) -> Union[Dict, List]:
"""
带重试的 JSON 调用(自动支持MCP工具)
Args:
prompt: 用户提示词
system_prompt: 系统提示词
max_retries: 最大重试次数
temperature: 温度
max_tokens: 最大令牌数
provider: AI提供商
model: 模型名称
expected_type: 期望的返回类型("object""array"
auto_mcp: 是否自动加载MCP工具
Returns:
解析后的JSON数据
"""
last_response = ""
for attempt in range(1, max_retries + 1):
current_prompt = prompt if attempt == 1 else self._add_json_hint(prompt, last_response, attempt)
result = await self.generate_text(
prompt=current_prompt,
provider=provider,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
auto_mcp=auto_mcp,
handle_tool_calls=True,
)
last_response = result.get("content", "")
try:
data = parse_json(last_response)
if expected_type == "object" and not isinstance(data, dict):
raise ValueError("期望对象")
if expected_type == "array" and not isinstance(data, list):
raise ValueError("期望数组")
return data
except Exception as e:
if attempt == max_retries:
raise ValueError(f"JSON 解析失败: {e}")
raise ValueError("JSON 调用失败")
@staticmethod
def _add_json_hint(prompt: str, failed: str, attempt: int) -> str:
return f"{prompt}\n\n⚠️ 第{attempt}次重试,请返回纯JSON,不要markdown包裹。上次错误: {failed[:200]}..."
@staticmethod
def _clean_json_response(text: str) -> str:
"""清洗 JSON 响应"""
return clean_json_response(text)
def create_user_ai_service(
api_provider: str,
api_key: str,
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int,
system_prompt: Optional[str] = None,
) -> AIService:
"""创建用户 AI 服务(不带MCP支持)"""
return AIService(
api_provider=api_provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
)
def create_user_ai_service_with_mcp(
api_provider: str,
api_key: str,
api_base_url: str,
model_name: str,
temperature: float,
max_tokens: int,
user_id: str,
db_session,
system_prompt: Optional[str] = None,
enable_mcp: bool = True,
) -> AIService:
"""
创建支持MCP的用户AI服务
Args:
api_provider: AI提供商
api_key: API密钥
api_base_url: API基础URL
model_name: 模型名称
temperature: 温度
max_tokens: 最大令牌数
user_id: 用户ID(用于加载MCP工具)
db_session: 数据库会话
system_prompt: 系统提示词
enable_mcp: 是否启用MCP工具
Returns:
配置好的AIService实例
"""
return AIService(
api_provider=api_provider,
api_key=api_key,
api_base_url=api_base_url,
default_model=model_name,
default_temperature=temperature,
default_max_tokens=max_tokens,
default_system_prompt=system_prompt,
user_id=user_id,
db_session=db_session,
enable_mcp=enable_mcp,
)