update:1.更新mcp插件功能,目前只支持remote调用
This commit is contained in:
@@ -0,0 +1,355 @@
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPToolServiceError(Exception):
|
||||
"""MCP工具服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolService:
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
|
||||
|
||||
def __init__(self):
|
||||
self._tool_cache = {} # 工具定义缓存
|
||||
self._result_cache = {} # 工具结果缓存(可选)
|
||||
|
||||
async def get_user_enabled_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
db_session: AsyncSession,
|
||||
category: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户启用的MCP工具列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
db_session: 数据库会话
|
||||
category: 工具类别筛选(search/analysis/filesystem等)
|
||||
|
||||
Returns:
|
||||
工具定义列表,格式符合OpenAI Function Calling规范
|
||||
"""
|
||||
try:
|
||||
# 1. 查询用户启用的插件(enabled=True即可,不强制要求status=active)
|
||||
# 因为新启用的插件status可能还是inactive,需要给它机会被调用
|
||||
query = select(MCPPlugin).where(
|
||||
MCPPlugin.user_id == user_id,
|
||||
MCPPlugin.enabled == True
|
||||
)
|
||||
|
||||
if category:
|
||||
query = query.where(MCPPlugin.category == category)
|
||||
|
||||
result = await db_session.execute(query)
|
||||
plugins = result.scalars().all()
|
||||
|
||||
if not plugins:
|
||||
logger.info(f"用户 {user_id} 没有启用的MCP插件")
|
||||
return []
|
||||
|
||||
# 2. 获取所有工具定义
|
||||
all_tools = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
# 确保插件已加载到注册表
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,尝试加载...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
|
||||
continue
|
||||
|
||||
# 从registry获取该插件的工具列表
|
||||
plugin_tools = await mcp_registry.get_plugin_tools(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name
|
||||
)
|
||||
|
||||
# 格式化为Function Calling格式
|
||||
formatted_tools = self._format_tools_for_ai(
|
||||
plugin_tools,
|
||||
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name
|
||||
)
|
||||
all_tools.extend(formatted_tools)
|
||||
|
||||
logger.info(
|
||||
f"从插件 {plugin.plugin_name} 加载了 "
|
||||
f"{len(formatted_tools)} 个工具"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"获取插件 {plugin.plugin_name} 的工具失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {user_id} 共加载 {len(all_tools)} 个MCP工具")
|
||||
return all_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户MCP工具失败: {e}", exc_info=True)
|
||||
raise MCPToolServiceError(f"获取MCP工具失败: {str(e)}")
|
||||
|
||||
def _format_tools_for_ai(
|
||||
self,
|
||||
plugin_tools: List[Dict[str, Any]],
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
将MCP工具定义格式化为AI Function Calling格式
|
||||
|
||||
Args:
|
||||
plugin_tools: MCP插件的工具列表
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
格式化后的工具列表
|
||||
"""
|
||||
formatted_tools = []
|
||||
|
||||
for tool in plugin_tools:
|
||||
formatted_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f"{plugin_name}_{tool['name']}", # 加插件前缀避免冲突
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("inputSchema", {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
})
|
||||
}
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
db_session: AsyncSession,
|
||||
timeout: float = 60.0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量执行AI请求的工具调用(并行执行)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_calls: AI返回的工具调用列表
|
||||
db_session: 数据库会话
|
||||
timeout: 单个工具调用的超时时间(秒,默认30秒)
|
||||
|
||||
Returns:
|
||||
工具调用结果列表
|
||||
"""
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.info(f"开始执行 {len(tool_calls)} 个工具调用")
|
||||
|
||||
# 创建异步任务列表
|
||||
tasks = [
|
||||
self._execute_single_tool(
|
||||
user_id=user_id,
|
||||
tool_call=tool_call,
|
||||
db_session=db_session,
|
||||
timeout=timeout
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
# 并行执行所有工具调用
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
formatted_results = []
|
||||
for i, result in enumerate(results):
|
||||
tool_call = tool_calls[i]
|
||||
|
||||
if isinstance(result, Exception):
|
||||
# 工具调用异常
|
||||
formatted_results.append({
|
||||
"tool_call_id": tool_call.get("id", f"call_{i}"),
|
||||
"role": "tool",
|
||||
"name": tool_call["function"]["name"],
|
||||
"content": f"工具调用失败: {str(result)}",
|
||||
"success": False,
|
||||
"error": str(result)
|
||||
})
|
||||
else:
|
||||
formatted_results.append(result)
|
||||
|
||||
return formatted_results
|
||||
|
||||
async def _execute_single_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_call: Dict[str, Any],
|
||||
db_session: AsyncSession,
|
||||
timeout: float
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行单个工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
tool_call: 工具调用信息
|
||||
db_session: 数据库会话
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
tool_call_id = tool_call.get("id", "unknown")
|
||||
function_name = tool_call["function"]["name"]
|
||||
|
||||
try:
|
||||
# 解析插件名和工具名
|
||||
if "_" in function_name:
|
||||
plugin_name, tool_name = function_name.split("_", 1)
|
||||
else:
|
||||
raise ValueError(f"无效的工具名称格式: {function_name}")
|
||||
|
||||
# 解析参数
|
||||
arguments_str = tool_call["function"]["arguments"]
|
||||
if isinstance(arguments_str, str):
|
||||
arguments = json.loads(arguments_str)
|
||||
else:
|
||||
arguments = arguments_str
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {plugin_name}.{tool_name}, "
|
||||
f"参数: {arguments}"
|
||||
)
|
||||
|
||||
# 设置超时
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
mcp_registry.call_tool(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 成功返回
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": json.dumps(result, ensure_ascii=False),
|
||||
"success": True,
|
||||
"error": None
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用超时(>{timeout}秒)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"工具 {function_name} 调用失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": f"工具调用失败: {str(e)}",
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def build_tool_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]],
|
||||
format: str = "markdown"
|
||||
) -> str:
|
||||
"""
|
||||
将工具调用结果格式化为上下文文本
|
||||
|
||||
Args:
|
||||
tool_results: 工具调用结果列表
|
||||
format: 输出格式(markdown/json/plain)
|
||||
|
||||
Returns:
|
||||
格式化的上下文字符串
|
||||
"""
|
||||
if not tool_results:
|
||||
return ""
|
||||
|
||||
if format == "markdown":
|
||||
return self._build_markdown_context(tool_results)
|
||||
elif format == "json":
|
||||
return json.dumps(tool_results, ensure_ascii=False, indent=2)
|
||||
else: # plain
|
||||
return self._build_plain_context(tool_results)
|
||||
|
||||
def _build_markdown_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建Markdown格式的工具上下文"""
|
||||
lines = ["## 🔧 工具调用结果\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status_emoji = "✅" if success else "❌"
|
||||
lines.append(f"### {status_emoji} {i}. {tool_name}\n")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```json\n{content}\n```\n")
|
||||
else:
|
||||
lines.append(f"**错误**: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _build_plain_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建纯文本格式的工具上下文"""
|
||||
lines = ["=== 工具调用结果 ===\n"]
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "成功" if success else "失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
lines.append(f" 结果: {content}\n")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_tool_service = MCPToolService()
|
||||
Reference in New Issue
Block a user