Files
MuMuAINovel/backend/app/services/mcp_tool_service.py
T

355 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
from typing import List, Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import json
from datetime import datetime
from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry
from app.logger import get_logger
logger = get_logger(__name__)
class MCPToolServiceError(Exception):
"""MCP工具服务异常"""
pass
class MCPToolService:
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
def __init__(self):
self._tool_cache = {} # 工具定义缓存
self._result_cache = {} # 工具结果缓存(可选)
async def get_user_enabled_tools(
self,
user_id: str,
db_session: AsyncSession,
category: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
获取用户启用的MCP工具列表
Args:
user_id: 用户ID
db_session: 数据库会话
category: 工具类别筛选(search/analysis/filesystem等)
Returns:
工具定义列表,格式符合OpenAI Function Calling规范
"""
try:
# 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active
# 因为新启用的插件status可能还是inactive,需要给它机会被调用
query = select(MCPPlugin).where(
MCPPlugin.user_id == user_id,
MCPPlugin.enabled == True
)
if category:
query = query.where(MCPPlugin.category == category)
result = await db_session.execute(query)
plugins = result.scalars().all()
if not plugins:
logger.info(f"用户 {user_id} 没有启用的MCP插件")
return []
# 2. 获取所有工具定义
all_tools = []
for plugin in plugins:
try:
# 确保插件已加载到注册表
if not mcp_registry.get_client(user_id, plugin.plugin_name):
logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
success = await mcp_registry.load_plugin(plugin)
if not success:
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
continue
# 从registry获取该插件的工具列表
plugin_tools = await mcp_registry.get_plugin_tools(
user_id=user_id,
plugin_name=plugin.plugin_name
)
# 格式化为Function Calling格式
formatted_tools = self._format_tools_for_ai(
plugin_tools,
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name
)
all_tools.extend(formatted_tools)
logger.info(
f"从插件 {plugin.plugin_name} 加载了 "
f"{len(formatted_tools)} 个工具"
)
except Exception as e:
logger.error(
f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
exc_info=True
)
continue
logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
return all_tools
except Exception as e:
logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}")
def _format_tools_for_ai(
self,
plugin_tools: List[Dict[str, Any]],
plugin_name: str
) -> List[Dict[str, Any]]:
"""
将MCP工具定义格式化为AI Function Calling格式
Args:
plugin_tools: MCP插件的工具列表
plugin_name: 插件名称
Returns:
格式化后的工具列表
"""
formatted_tools = []
for tool in plugin_tools:
formatted_tool = {
"type": "function",
"function": {
"name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突
"description": tool.get("description", ""),
"parameters": tool.get("inputSchema", {
"type": "object",
"properties": {},
"required": []
})
}
}
formatted_tools.append(formatted_tool)
return formatted_tools
async def execute_tool_calls(
self,
user_id: str,
tool_calls: List[Dict[str, Any]],
db_session: AsyncSession,
timeout: float = 60.0
) -> List[Dict[str, Any]]:
"""
批量执行AI请求的工具调用(并行执行)
Args:
user_id: 用户ID
tool_calls: AI返回的工具调用列表
db_session: 数据库会话
timeout: 单个工具调用的超时时间(秒,默认30秒)
Returns:
工具调用结果列表
"""
if not tool_calls:
return []
logger.info(f"开始执行 {len(tool_calls)} 个工具调用")
# 创建异步任务列表
tasks = [
self._execute_single_tool(
user_id=user_id,
tool_call=tool_call,
db_session=db_session,
timeout=timeout
)
for tool_call in tool_calls
]
# 并行执行所有工具调用
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理结果
formatted_results = []
for i, result in enumerate(results):
tool_call = tool_calls[i]
if isinstance(result, Exception):
# 工具调用异常
formatted_results.append({
"tool_call_id": tool_call.get("id", f"call_{i}"),
"role": "tool",
"name": tool_call["function"]["name"],
"content": f"工具调用失败: {str(result)}",
"success": False,
"error": str(result)
})
else:
formatted_results.append(result)
return formatted_results
async def _execute_single_tool(
self,
user_id: str,
tool_call: Dict[str, Any],
db_session: AsyncSession,
timeout: float
) -> Dict[str, Any]:
"""
执行单个工具调用
Args:
user_id: 用户ID
tool_call: 工具调用信息
db_session: 数据库会话
timeout: 超时时间
Returns:
工具调用结果
"""
tool_call_id = tool_call.get("id", "unknown")
function_name = tool_call["function"]["name"]
try:
# 解析插件名和工具名
if "_" in function_name:
plugin_name, tool_name = function_name.split("_", 1)
else:
raise ValueError(f"无效的工具名称格式: {function_name}")
# 解析参数
arguments_str = tool_call["function"]["arguments"]
if isinstance(arguments_str, str):
arguments = json.loads(arguments_str)
else:
arguments = arguments_str
logger.info(
f"执行工具: {plugin_name}.{tool_name}, "
f"参数: {arguments}"
)
# 设置超时
try:
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
timeout=timeout
)
# 成功返回
return {
"tool_call_id": tool_call_id,
"role": "tool",
"name": function_name,
"content": json.dumps(result, ensure_ascii=False),
"success": True,
"error": None
}
except asyncio.TimeoutError:
raise MCPToolServiceError(
f"工具调用超时(>{timeout}秒)"
)
except Exception as e:
logger.error(
f"工具 {function_name} 调用失败: {e}",
exc_info=True
)
return {
"tool_call_id": tool_call_id,
"role": "tool",
"name": function_name,
"content": f"工具调用失败: {str(e)}",
"success": False,
"error": str(e)
}
async def build_tool_context(
self,
tool_results: List[Dict[str, Any]],
format: str = "markdown"
) -> str:
"""
将工具调用结果格式化为上下文文本
Args:
tool_results: 工具调用结果列表
format: 输出格式(markdown/json/plain
Returns:
格式化的上下文字符串
"""
if not tool_results:
return ""
if format == "markdown":
return self._build_markdown_context(tool_results)
elif format == "json":
return json.dumps(tool_results, ensure_ascii=False, indent=2)
else: # plain
return self._build_plain_context(tool_results)
def _build_markdown_context(
self,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建Markdown格式的工具上下文"""
lines = ["## 🔧 工具调用结果\n"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status_emoji = "" if success else ""
lines.append(f"### {status_emoji} {i}. {tool_name}\n")
if success:
# 尝试美化JSON内容
try:
content_obj = json.loads(content)
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
except:
pass
lines.append(f"```json\n{content}\n```\n")
else:
lines.append(f"**错误**: {content}\n")
return "\n".join(lines)
def _build_plain_context(
self,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建纯文本格式的工具上下文"""
lines = ["=== 工具调用结果 ===\n"]
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status = "成功" if success else "失败"
lines.append(f"{i}. {tool_name} - {status}")
lines.append(f" 结果: {content}\n")
return "\n".join(lines)
# 全局单例
mcp_tool_service = MCPToolService()