fix:1.优化mcp插件功能,改用mcp sdk库

This commit is contained in:
xiamuceer
2025-11-08 12:32:32 +08:00
parent 88115a45c5
commit c7c1c1fdf3
9 changed files with 1278 additions and 660 deletions
+128 -326
View File
@@ -18,9 +18,9 @@ from app.schemas.mcp_plugin import (
import json
from app.user_manager import User
from app.mcp.registry import mcp_registry
from app.services.mcp_test_service import mcp_test_service
from app.services.mcp_tool_service import mcp_tool_service
from app.logger import get_logger
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings as UserSettings
logger = get_logger(__name__)
@@ -399,13 +399,8 @@ async def test_plugin(
"""
测试插件连接并调用工具验证功能
测试流程:
1. 测试MCP服务器连接
2. 获取工具列表
3. 自动选择一个工具进行实际调用测试
4. 返回完整测试结果
使用新的MCPTestService进行测试
"""
import time
result = await db.execute(
select(MCPPlugin).where(
@@ -426,356 +421,153 @@ async def test_plugin(
suggestions=["点击开关按钮启用插件"]
)
start_time = time.time()
# 使用新的测试服务
try:
# 1. 确保插件已加载
if not mcp_registry.get_client(user.user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
# 2. 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user.user_id, plugin.plugin_name)
if not test_result["success"]:
plugin.status = "error"
plugin.last_error = test_result.get("error", "连接测试失败")
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(**test_result)
tools = test_result.get("tools", [])
if not tools:
plugin.status = "error"
plugin.last_error = "插件没有提供任何工具"
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=test_result.get("response_time_ms"),
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 使用AI智能选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
# 获取用户的AI设置
settings_result = await db.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 如果没有AI配置,回退到简单测试
logger.warning("用户未配置AI服务,使用简单连接测试")
# 更新插件状态
if test_result.success:
plugin.status = "active"
plugin.last_error = None
plugin.last_test_at = datetime.now()
plugin.tools = tools
await db.commit()
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=test_result.get("response_time_ms"),
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 使用AI的标准Function Calling机制选择工具
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 将MCP工具格式转换为OpenAI Function Calling格式
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
# 将 inputSchema 转换为 parameters
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
logger.info(f"转换了 {len(openai_tools)} 个MCP工具为OpenAI格式")
logger.info(f"工具列表: {[t['function']['name'] for t in openai_tools]}")
# 使用标准的Function Calling,将转换后的工具传递给AI
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
现在开始测试这个插件。"""
system_prompt = "你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。"
# 调用AI的Function Calling
logger.info(f"📞 准备调用AI Function Calling")
logger.info(f" - Provider: {user_settings.api_provider}")
logger.info(f" - Model: {user_settings.llm_model}")
logger.info(f" - Tools count: {len(openai_tools)}")
logger.debug(f" - Tools: {json.dumps(openai_tools, ensure_ascii=False, indent=2)}")
ai_response = await ai_service.generate_text(
prompt=prompt,
system_prompt=system_prompt,
tools=openai_tools, # 传递转换后的OpenAI格式工具
tool_choice="required" # 要求AI必须选择一个工具
)
logger.info(f"📥 收到AI响应")
logger.info(f" - Response keys: {list(ai_response.keys())}")
logger.debug(f" - Full response: {json.dumps(ai_response, ensure_ascii=False, indent=2)}")
# 检查AI是否请求调用工具
if not ai_response.get("tool_calls"):
# AI未调用工具,记录详细信息
logger.error(f"❌ AI未返回工具调用")
logger.error(f" - Response: {ai_response}")
logger.error(f" - Content: {ai_response.get('content', 'N/A')}")
logger.error(f" - Finish reason: {ai_response.get('finish_reason', 'N/A')}")
plugin.status = "error"
plugin.last_error = "AI未返回工具调用请求"
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
"OpenAI: 需要gpt-4, gpt-3.5-turbo等模型",
"Anthropic: 需要claude-3系列模型",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}",
f"AI返回内容: {ai_response.get('content', 'N/A')[:100]}"
]
)
# 获取第一个工具调用
tool_call = ai_response["tool_calls"][0]
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
# AI返回的arguments可能是JSON字符串,需要解析
if isinstance(test_arguments, str):
try:
test_arguments = json.loads(test_arguments)
logger.info(f"✅ 解析AI返回的JSON字符串参数")
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools),
suggestions=["AI返回的参数不是有效的JSON格式"]
)
logger.info(f"🤖 AI通过Function Calling选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
logger.info(f"📝 参数类型: {type(test_arguments).__name__}")
# 4. 使用AI选择的工具和参数调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 6. 测试成功,更新插件状态
plugin.status = "active"
plugin.last_error = None
plugin.last_test_at = datetime.now()
plugin.tools = tools # 缓存工具列表
await db.commit()
# 格式化工具结果用于显示
result_str = str(tool_result)
# 如果结果太长,截取前800字符
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断,完整结果请查看日志)"
else:
result_preview = result_str
plugin.status = "error"
plugin.last_error = test_result.error
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI (Function Calling) 选择: {tool_name}",
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 调用耗时: {call_time}ms",
f"📊 返回结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
# 工具调用失败,但连接成功
plugin.status = "active" # 仍标记为active,因为连接是成功的
plugin.last_error = f"工具调用测试失败: {str(call_error)}"
plugin.last_test_at = datetime.now()
plugin.tools = tools
await db.commit()
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI (Function Calling) 选择: {tool_name}",
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
return test_result
except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
plugin.status = "error"
plugin.last_error = str(e)
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=["请检查服务器是否在线", "请确认配置正确", "请检查API Key是否有效"]
)
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
def _build_test_arguments(tool_name: str, input_schema: dict, plugin_name: str) -> dict:
async def _ensure_plugin_loaded(
plugin: MCPPlugin,
user_id: str
) -> bool:
"""
根据工具schema智能构造测试参数
确保插件已加载(共享逻辑)
Args:
tool_name: 工具名称
input_schema: 输入schema
plugin_name: 插件名称
plugin: 插件对象
user_id: 用户ID
Returns:
测试参数字典
是否加载成功
Raises:
HTTPException: 加载失败
"""
# 针对常见MCP工具的默认测试参数
test_cases = {
# Exa搜索工具
"search": {
"query": "AI technology",
"num_results": 3
},
"search_and_contents": {
"query": "artificial intelligence",
"num_results": 2
},
# Brave搜索
"brave_web_search": {
"query": "AI news",
"count": 3
},
# Filesystem工具
"read_file": {
"path": "README.md"
},
"list_directory": {
"path": "."
},
if not mcp_registry.get_client(user_id, plugin.plugin_name):
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
success = await mcp_registry.load_plugin(plugin)
if not success:
raise HTTPException(
status_code=500,
detail=f"插件加载失败: {plugin.plugin_name}"
)
return True
@router.get("/metrics")
async def get_metrics(
tool_name: Optional[str] = Query(None, description="工具名称(可选,获取特定工具的指标)"),
user: User = Depends(require_login)
):
"""
获取MCP工具调用指标
Query参数:
- tool_name: 可选,指定工具名称获取特定工具的指标
Returns:
工具调用指标字典,包含:
- total_calls: 总调用次数
- success_calls: 成功调用次数
- failed_calls: 失败调用次数
- success_rate: 成功率
- avg_duration_ms: 平均耗时(毫秒)
- last_call_time: 最后调用时间
"""
metrics = mcp_tool_service.get_metrics(tool_name)
return {
"metrics": metrics,
"tool_name": tool_name,
"timestamp": datetime.now().isoformat()
}
# 如果有针对特定工具的测试用例,使用它
if tool_name in test_cases:
logger.info(f"使用预定义测试参数: {test_cases[tool_name]}")
return test_cases[tool_name]
# 否则根据schema自动构造
properties = input_schema.get("properties", {})
required = input_schema.get("required", [])
@router.get("/cache/stats")
async def get_cache_stats(
user: User = Depends(require_login)
):
"""
获取工具缓存统计信息
test_args = {}
Returns:
缓存统计信息,包含:
- total_entries: 缓存条目总数
- total_hits: 缓存总命中次数
- cache_ttl_minutes: 缓存TTL(分钟)
- entries: 各缓存条目详情
"""
stats = mcp_tool_service.get_cache_stats()
for prop_name, prop_schema in properties.items():
# 只填充必需的参数
if prop_name not in required:
continue
return {
"cache_stats": stats,
"timestamp": datetime.now().isoformat()
}
prop_type = prop_schema.get("type", "string")
# 根据参数名称和类型猜测合适的测试值
if "query" in prop_name.lower() or "search" in prop_name.lower():
test_args[prop_name] = "test query"
elif "url" in prop_name.lower():
test_args[prop_name] = "https://example.com"
elif "path" in prop_name.lower():
test_args[prop_name] = "."
elif "count" in prop_name.lower() or "limit" in prop_name.lower() or "num" in prop_name.lower():
test_args[prop_name] = 3
elif prop_type == "string":
test_args[prop_name] = "test"
elif prop_type == "number" or prop_type == "integer":
test_args[prop_name] = 1
elif prop_type == "boolean":
test_args[prop_name] = True
elif prop_type == "array":
test_args[prop_name] = []
elif prop_type == "object":
test_args[prop_name] = {}
@router.post("/cache/clear")
async def clear_cache(
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
plugin_name: Optional[str] = Query(None, description="插件名称(可选)"),
user: User = Depends(require_login)
):
"""
清理工具缓存
logger.info(f"自动构造测试参数: {test_args}")
return test_args
Query参数:
- user_id: 可选,清理特定用户的缓存
- plugin_name: 可选,清理特定插件的缓存
说明:
- 不提供任何参数:清理所有缓存
- 只提供user_id:清理该用户的所有缓存
- 提供user_id和plugin_name:清理特定插件的缓存
"""
# 非管理员只能清理自己的缓存
if user_id and user_id != user.user_id:
raise HTTPException(status_code=403, detail="无权清理其他用户的缓存")
# 如果没有指定user_id,使用当前用户
target_user_id = user_id or user.user_id
mcp_tool_service.clear_cache(target_user_id, plugin_name)
message = "已清理"
if plugin_name:
message += f"插件 {plugin_name} 的缓存"
elif target_user_id:
message += f"用户 {target_user_id} 的所有缓存"
else:
message += "所有缓存"
logger.info(f"用户 {user.user_id} {message}")
return {
"success": True,
"message": message,
"timestamp": datetime.now().isoformat()
}
@router.get("/{plugin_id}/tools")
@@ -802,6 +594,9 @@ async def get_plugin_tools(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 更新缓存
@@ -813,6 +608,8 @@ async def get_plugin_tools(
"tools": tools,
"count": len(tools)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
@@ -843,6 +640,9 @@ async def call_mcp_tool(
raise HTTPException(status_code=400, detail="插件未启用")
try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 调用工具
result = await mcp_registry.call_tool(
user.user_id,
@@ -857,6 +657,8 @@ async def call_mcp_tool(
"tool_name": data.tool_name,
"result": result
}
except HTTPException:
raise
except Exception as e:
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
+2 -3
View File
@@ -12,6 +12,7 @@ from app.database import close_db, _session_stats
from app.logger import setup_logging, get_logger
from app.middleware import RequestIDMiddleware
from app.middleware.auth_middleware import AuthMiddleware
from app.mcp.registry import mcp_registry
setup_logging(
level=config_settings.log_level,
@@ -27,9 +28,7 @@ logger = get_logger(__name__)
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
logger.info("应用启动,等待用户登录...")
# 导入MCP注册表
from app.mcp.registry import mcp_registry
logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载")
yield
+42
View File
@@ -0,0 +1,42 @@
"""MCP模块配置常量"""
from dataclasses import dataclass
@dataclass(frozen=True)
class MCPConfig:
"""MCP模块配置常量(不可变)"""
# 连接池配置
MAX_CLIENTS: int = 1000 # 最大客户端数量
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
# 健康检查配置
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
# 清理任务配置
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
# 缓存配置
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
# 重试配置
MAX_RETRIES: int = 3 # 最大重试次数
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
# 超时配置
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
# 日志配置
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
# 全局配置实例
mcp_config = MCPConfig()
+178 -190
View File
@@ -1,8 +1,13 @@
"""HTTP MCP客户端 - 实现JSON-RPC 2.0协议"""
import httpx
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
import asyncio
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from app.logger import get_logger
import time
logger = get_logger(__name__)
@@ -13,15 +18,14 @@ class MCPError(Exception):
class HTTPMCPClient:
"""HTTP模式MCP客户端(类似Cursor/Claude Code实现"""
"""HTTP模式MCP客户端(基于官方 MCP Python SDK"""
def __init__(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0,
http_client: Optional[httpx.AsyncClient] = None
timeout: float = 60.0
):
"""
初始化HTTP MCP客户端
@@ -31,162 +35,79 @@ class HTTPMCPClient:
headers: HTTP请求头
env: 环境变量(用于API Key等)
timeout: 超时时间(秒)
http_client: 可选的共享HTTP客户端(用于连接池复用)
"""
self.url = url.rstrip('/')
self.headers = headers or {}
self.env = env or {}
self.timeout = timeout
# 设置MCP必需的Accept头
# MCP服务器要求客户端必须接受 application/json 和 text/event-stream
if 'Accept' not in self.headers:
self.headers['Accept'] = 'application/json, text/event-stream'
# 设置Content-Type
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = 'application/json'
# 如果env中有API Key,添加到headers
if 'API_KEY' in self.env:
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
# 使用共享客户端或创建新客户端
self._owns_client = http_client is None
if http_client:
self.client = http_client
else:
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(timeout),
headers=self.headers
)
self._request_id = 0
def _next_request_id(self) -> int:
"""获取下一个请求ID"""
self._request_id += 1
return self._request_id
async def _call_jsonrpc(
self,
method: str,
params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
调用JSON-RPC 2.0方法
Args:
method: 方法名
params: 参数
Returns:
响应结果
Raises:
MCPError: 调用失败时抛出
"""
request_id = self._next_request_id()
payload = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params or {}
}
self._session: Optional[ClientSession] = None
self._context_stack = [] # 保存上下文管理器栈
self._initialized = False
self._lock = asyncio.Lock()
async def _ensure_connected(self):
"""确保连接已建立"""
async with self._lock:
if self._session is None:
try:
logger.debug(f"MCP请求: {method} -> {self.url}")
logger.info(f"🔗 连接到MCP服务器: {self.url}")
response = await self.client.post(
self.url,
json=payload,
headers=self.headers # 显式传递headers(对于共享客户端很重要)
)
# 使用官方 SDK 的 streamable_http_client
# 保存上下文管理器以便后续正确清理
stream_context = streamablehttp_client(self.url)
read_stream, write_stream, _ = await stream_context.__aenter__()
self._context_stack.append(('stream', stream_context))
response.raise_for_status()
# 创建客户端会话
self._session = ClientSession(read_stream, write_stream)
session_context = self._session
await session_context.__aenter__()
self._context_stack.append(('session', session_context))
# 获取响应内容
response_text = response.text
content_type = response.headers.get('content-type', '')
# 初始化会话
await self._session.initialize()
self._initialized = True
# 如果是空响应
if not response_text or response_text.strip() == '':
raise MCPError("服务器返回空响应")
logger.info(f"✅ MCP会话初始化成功")
# 处理SSE格式响应
if 'text/event-stream' in content_type or response_text.startswith('event:'):
logger.debug("检测到SSE格式响应,开始解析")
data = self._parse_sse_response(response_text)
else:
# 标准JSON响应
try:
data = response.json()
except ValueError as e:
logger.error(f"JSON解析失败,响应内容: {response_text[:500]}")
raise MCPError(f"无法解析JSON响应: {str(e)}")
# 检查JSON-RPC错误
if "error" in data:
error = data["error"]
error_msg = error.get("message", "Unknown error")
error_code = error.get("code", -1)
logger.error(f"MCP错误 [{error_code}]: {error_msg}")
raise MCPError(f"[{error_code}] {error_msg}")
if "result" not in data:
raise MCPError("响应中缺少result字段")
return data["result"]
except httpx.HTTPStatusError as e:
logger.error(f"HTTP错误 {e.response.status_code}: {e.response.text}")
raise MCPError(f"HTTP错误 {e.response.status_code}: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
raise MCPError(f"请求错误: {str(e)}")
except MCPError:
raise
except Exception as e:
logger.error(f"未知错误: {str(e)}")
raise MCPError(f"未知错误: {str(e)}")
logger.error(f"❌ MCP连接失败: {e}")
await self._cleanup()
raise MCPError(f"连接MCP服务器失败: {str(e)}")
def _parse_sse_response(self, sse_text: str) -> Dict[str, Any]:
async def _cleanup(self):
"""清理连接资源(按照进入的相反顺序退出)"""
# 按照LIFO顺序清理上下文
while self._context_stack:
ctx_type, ctx = self._context_stack.pop()
try:
await ctx.__aexit__(None, None, None)
except RuntimeError as e:
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
else:
logger.error(f"清理{ctx_type}上下文失败: {e}")
except Exception as e:
logger.error(f"清理{ctx_type}上下文失败: {e}")
self._session = None
self._initialized = False
async def initialize(self) -> Dict[str, Any]:
"""
解析SSE格式的响应
SSE格式示例:
event: message
data: {"result": {...}}
Args:
sse_text: SSE格式的文本
初始化MCP会话
Returns:
解析后的JSON数据
初始化响应
"""
import json
lines = sse_text.strip().split('\n')
data_lines = []
for line in lines:
line = line.strip()
if line.startswith('data:'):
# 提取data后面的内容
data_content = line[5:].strip()
data_lines.append(data_content)
if not data_lines:
raise MCPError("SSE响应中没有找到data字段")
# 合并所有data行(某些SSE可能分多行)
full_data = ''.join(data_lines)
try:
return json.loads(full_data)
except json.JSONDecodeError as e:
logger.error(f"解析SSE data失败: {full_data[:200]}")
raise MCPError(f"SSE data不是有效的JSON: {str(e)}")
await self._ensure_connected()
return {"status": "initialized"}
async def list_tools(self) -> List[Dict[str, Any]]:
"""
@@ -196,13 +117,26 @@ class HTTPMCPClient:
工具列表
"""
try:
result = await self._call_jsonrpc("tools/list")
tools = result.get("tools", [])
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
logger.info(f"获取到 {len(tools)} 个工具")
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise
raise MCPError(f"获取工具列表失败: {str(e)}")
async def call_tool(
self,
@@ -220,33 +154,38 @@ class HTTPMCPClient:
工具执行结果
"""
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f"参数: {arguments}")
result = await self._call_jsonrpc(
"tools/call",
{
"name": tool_name,
"arguments": arguments
result = await self._session.call_tool(tool_name, arguments)
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
# 提取第一个content的文本
for content in result.content:
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
)
# 如果没有文本内容,返回原始内容
return result.content[0] if result.content else None
# MCP返回的result通常包含content数组
if isinstance(result, dict) and "content" in result:
content = result["content"]
if isinstance(content, list) and len(content) > 0:
# 提取第一个content项的text
first_content = content[0]
if isinstance(first_content, dict) and "text" in first_content:
return first_content["text"]
return first_content
return content
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
return result.structuredContent
return result
return None
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
raise
raise MCPError(f"调用工具失败: {str(e)}")
async def list_resources(self) -> List[Dict[str, Any]]:
"""
@@ -256,13 +195,27 @@ class HTTPMCPClient:
资源列表
"""
try:
result = await self._call_jsonrpc("resources/list")
resources = result.get("resources", [])
await self._ensure_connected()
result = await self._session.list_resources()
# 转换为字典格式
resources = []
for resource in result.resources:
resource_dict = {
"uri": str(resource.uri),
"name": resource.name,
"description": resource.description or "",
"mimeType": resource.mimeType or ""
}
resources.append(resource_dict)
logger.info(f"获取到 {len(resources)} 个资源")
return resources
except Exception as e:
logger.error(f"获取资源列表失败: {e}")
raise
raise MCPError(f"获取资源列表失败: {str(e)}")
async def read_resource(self, uri: str) -> Any:
"""
@@ -275,14 +228,33 @@ class HTTPMCPClient:
资源内容
"""
try:
result = await self._call_jsonrpc(
"resources/read",
{"uri": uri}
)
return result
await self._ensure_connected()
result = await self._session.read_resource(AnyUrl(uri))
# 提取资源内容
if result.contents:
content = result.contents[0]
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
elif isinstance(content, types.BlobResourceContents):
return {
"type": "blob",
"blob": content.blob,
"mimeType": content.mimeType
}
return None
except Exception as e:
logger.error(f"读取资源失败: {uri}, 错误: {e}")
raise
raise MCPError(f"读取资源失败: {str(e)}")
async def test_connection(self) -> Dict[str, Any]:
"""
@@ -291,10 +263,12 @@ class HTTPMCPClient:
Returns:
测试结果
"""
import time
start_time = time.time()
try:
# 尝试列举工具来测试连接
# 尝试连接并列举工具
await self._ensure_connected()
tools = await self.list_tools()
end_time = time.time()
@@ -307,22 +281,7 @@ class HTTPMCPClient:
"tools_count": len(tools),
"tools": tools
}
except MCPError as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
return {
"success": False,
"message": "连接测试失败",
"response_time_ms": response_time,
"error": str(e),
"error_type": "MCPError",
"suggestions": [
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接"
]
}
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
@@ -334,12 +293,41 @@ class HTTPMCPClient:
"error": str(e),
"error_type": type(e).__name__,
"suggestions": [
"请检查服务器是否在线",
"请确认配置是否正确"
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接",
"请确认MCP服务器是否在线"
]
}
async def close(self):
"""关闭客户端(仅在拥有客户端所有权时关闭)"""
if self._owns_client and self.client:
await self.client.aclose()
"""关闭客户端连接"""
logger.info(f"关闭MCP客户端: {self.url}")
await self._cleanup()
@asynccontextmanager
async def create_mcp_client(
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
创建MCP客户端的上下文管理器
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量
timeout: 超时时间
Yields:
HTTPMCPClient实例
"""
client = HTTPMCPClient(url, headers, env, timeout)
try:
await client.initialize()
yield client
finally:
await client.close()
+245 -89
View File
@@ -1,93 +1,153 @@
"""MCP插件注册表 - 管理运行时插件实例"""
import asyncio
import time
import httpx
from typing import Dict, Optional, Any, List, Tuple
from collections import OrderedDict
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from datetime import datetime
from app.mcp.http_client import HTTPMCPClient, MCPError
from app.mcp.config import mcp_config
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(多用户优化版)"""
@dataclass
class SessionInfo:
"""会话信息"""
client: HTTPMCPClient
created_at: float
last_access: float
request_count: int = 0
error_count: int = 0
status: str = "active" # active, degraded, error
def __init__(self, max_clients: int = 1000, client_ttl: int = 3600):
class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
def __init__(
self,
max_clients: Optional[int] = None,
client_ttl: Optional[int] = None
):
"""
初始化注册表
Args:
max_clients: 最大缓存客户端数量
client_ttl: 客户端过期时间(秒,默认1小时
max_clients: 最大缓存客户端数量(默认使用配置)
client_ttl: 客户端过期时间(秒,默认使用配置)
"""
# 存储格式: {plugin_id: (client, last_access_time)}
self._clients: OrderedDict[str, Tuple[HTTPMCPClient, float]] = OrderedDict()
# 存储格式: {plugin_id: SessionInfo}
self._sessions: Dict[str, SessionInfo] = {}
# 全局锁用于保护会话字典
self._sessions_lock = asyncio.Lock()
# 细粒度锁:每个用户一个锁
self._user_locks: Dict[str, asyncio.Lock] = {}
self._locks_lock = asyncio.Lock() # 保护locks字典本身
# 配置参数
self._max_clients = max_clients
self._client_ttl = client_ttl
# 共享HTTP客户端池(用于所有MCP HTTP请求)
self._shared_http_client = httpx.AsyncClient(
limits=httpx.Limits(
max_keepalive_connections=100,
max_connections=200,
keepalive_expiry=30.0
),
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=5.0),
headers={
"User-Agent": "MuMuAINovel-MCP-Client/1.0"
}
)
# 配置参数(使用配置常量)
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
# 启动后台清理任务
self._cleanup_task = None
self._start_cleanup_task()
self._health_check_task = None
self._start_background_tasks()
def _start_cleanup_task(self):
"""启动后台清理任务"""
def _start_background_tasks(self):
"""启动后台任务"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ MCP插件注册表后台清理任务已启动")
if self._health_check_task is None:
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("✅ MCP会话健康检查任务已启动")
async def _cleanup_loop(self):
"""后台清理过期客户端"""
while True:
try:
await asyncio.sleep(300) # 每5分钟清理一次
await self._cleanup_expired_clients()
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理任务异常: {e}")
async def _cleanup_expired_clients(self):
"""清理过期的客户端"""
async def _health_check_loop(self):
"""后台健康检查"""
while True:
try:
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
await self._check_session_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"健康检查任务异常: {e}")
async def _cleanup_expired_sessions(self):
"""清理过期的会话"""
now = time.time()
expired_ids = []
async with self._sessions_lock:
# 收集过期的plugin_id
for plugin_id, (client, last_access) in list(self._clients.items()):
if now - last_access > self._client_ttl:
for plugin_id, session in list(self._sessions.items()):
if now - session.last_access > self._client_ttl:
expired_ids.append(plugin_id)
if expired_ids:
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP客户端")
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
for plugin_id in expired_ids:
# 提取user_id来获取对应的锁
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
if plugin_id in self._clients:
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
async def _check_session_health(self):
"""增强的会话健康检查"""
async with self._sessions_lock:
for plugin_id, session in list(self._sessions.items()):
# 计算错误率
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
error_rate = session.error_count / session.request_count
# 动态调整状态(使用配置常量)
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
if session.status != "error":
session.status = "error"
logger.error(
f"❌ 会话 {plugin_id} 错误率过高 "
f"({error_rate:.1%}), 标记为error"
)
elif error_rate > mcp_config.ERROR_RATE_WARNING:
if session.status == "active":
session.status = "degraded"
logger.warning(
f"⚠️ 会话 {plugin_id} 健康状况下降 "
f"(错误率: {error_rate:.1%})"
)
elif session.status == "degraded":
# 错误率降低,恢复正常
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查长时间无活动的会话
idle_time = time.time() - session.last_access
if idle_time > mcp_config.IDLE_TIMEOUT_SECONDS:
logger.info(
f"💤 会话 {plugin_id} 空闲 {idle_time/60:.1f} 分钟,"
f"准备清理"
)
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
"""
获取用户专属的锁(细粒度锁)
@@ -103,25 +163,32 @@ class MCPPluginRegistry:
self._user_locks[user_id] = asyncio.Lock()
return self._user_locks[user_id]
def _touch_client(self, plugin_id: str):
def _touch_session(self, plugin_id: str):
"""
更新客户端的最后访问时间(LRU
更新会话的最后访问时间(需要在锁内调用
Args:
plugin_id: 插件ID
"""
if plugin_id in self._clients:
client, _ = self._clients[plugin_id]
self._clients[plugin_id] = (client, time.time())
# 移到末尾(LRU
self._clients.move_to_end(plugin_id)
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
session.last_access = time.time()
session.request_count += 1
async def _evict_lru_client(self):
"""驱逐最久未使用的客户端(当达到max_clients限制时)"""
if len(self._clients) >= self._max_clients:
# 获取最旧的plugin_id
oldest_id = next(iter(self._clients))
logger.info(f"📤 达到最大客户端数量限制,驱逐: {oldest_id}")
async def _evict_lru_session(self):
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
if len(self._sessions) >= self._max_clients:
# 找到最旧的会话
oldest_id = None
oldest_time = float('inf')
for plugin_id, session in self._sessions.items():
if session.last_access < oldest_time:
oldest_time = session.last_access
oldest_id = plugin_id
if oldest_id:
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def load_plugin(self, plugin: MCPPlugin) -> bool:
@@ -141,11 +208,12 @@ class MCPPluginRegistry:
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
# 如果已加载,先卸载
if plugin_id in self._clients:
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU客户端
await self._evict_lru_client()
# 检查是否需要驱逐LRU会话
await self._evict_lru_session()
# 目前只支持HTTP类型
if plugin.plugin_type == "http":
@@ -153,18 +221,30 @@ class MCPPluginRegistry:
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
return False
# 使用共享HTTP连接池创建客户端
# 为每个插件创建独立的HTTP客户端
client = HTTPMCPClient(
url=plugin.server_url,
headers=plugin.headers or {},
env=plugin.env or {},
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0,
http_client=self._shared_http_client # 传入共享连接池
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
)
# 存储客户端和当前时间戳
self._clients[plugin_id] = (client, time.time())
logger.info(f"✅ 加载MCP插件: {plugin_id}")
# 创建会话信息
now = time.time()
session = SessionInfo(
client=client,
created_at=now,
last_access=now,
request_count=0,
error_count=0,
status="active"
)
# 存储会话
async with self._sessions_lock:
self._sessions[plugin_id] = session
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
return True
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
@@ -186,18 +266,19 @@ class MCPPluginRegistry:
user_lock = await self._get_user_lock(user_id)
async with user_lock:
plugin_id = f"{user_id}:{plugin_name}"
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
async def _unload_plugin_unsafe(self, plugin_id: str):
"""卸载插件(不加锁,内部使用)"""
if plugin_id in self._clients:
client, _ = self._clients[plugin_id] # 解包 (client, timestamp)
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
try:
await client.close()
await session.client.close()
except Exception as e:
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
del self._clients[plugin_id]
del self._sessions[plugin_id]
logger.info(f"卸载MCP插件: {plugin_id}")
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
@@ -215,7 +296,7 @@ class MCPPluginRegistry:
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
"""
获取插件客户端(支持LRU访问时间更新)
获取插件客户端(线程安全,支持访问时间更新)
Args:
user_id: 用户ID
@@ -225,13 +306,68 @@ class MCPPluginRegistry:
客户端实例或None
"""
plugin_id = f"{user_id}:{plugin_name}"
entry = self._clients.get(plugin_id)
if entry:
# 更新访问时间(LRU
self._touch_client(plugin_id)
return entry[0] # 返回客户端对象
session = self._sessions.get(plugin_id)
if session:
# 检查会话状态
if session.status == "error":
logger.warning(
f"⚠️ 会话 {plugin_id} 处于错误状态,"
f"建议调用者重新加载插件"
)
# 不返回错误状态的客户端
return None
# ✅ 使用锁保护状态更新,避免并发问题
# 注意:这里使用原子操作更新简单字段,不需要异步锁
session.last_access = time.time()
session.request_count += 1
return session.client
return None
async def get_or_reconnect_client(
self,
user_id: str,
plugin_name: str,
plugin: MCPPlugin
) -> HTTPMCPClient:
"""
获取或重连客户端(自动处理错误状态)
Args:
user_id: 用户ID
plugin_name: 插件名称
plugin: 插件配置对象
Returns:
客户端实例
Raises:
ValueError: 插件加载失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取用户锁
user_lock = await self._get_user_lock(user_id)
async with user_lock:
session = self._sessions.get(plugin_id)
# 检查会话健康状态
if session and session.status == "error":
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
session = None
# 如果没有会话,加载插件
if not session:
success = await self.load_plugin(plugin)
if not success:
raise ValueError(f"插件加载失败: {plugin_name}")
session = self._sessions[plugin_id]
return session.client
async def call_tool(
self,
user_id: str,
@@ -240,7 +376,7 @@ class MCPPluginRegistry:
arguments: Dict[str, Any]
) -> Any:
"""
调用插件工具
调用插件工具(带错误计数和状态管理)
Args:
user_id: 用户ID
@@ -255,18 +391,39 @@ class MCPPluginRegistry:
ValueError: 插件不存在或未启用
MCPError: 工具调用失败
"""
client = self.get_client(user_id, plugin_name)
plugin_id = f"{user_id}:{plugin_name}"
if not client:
# 获取会话
session = self._sessions.get(plugin_id)
if not session:
raise ValueError(f"插件未加载: {plugin_name}")
try:
result = await client.call_tool(tool_name, arguments)
result = await session.client.call_tool(tool_name, arguments)
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
# logger.info(f"✅ 工具返回内容: {result}")
# 调用成功,重置状态(如果之前是degraded)
if session.status == "degraded":
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
return result
except Exception as e:
logger.error(f"❌ 工具调用失败: {plugin_name}.{tool_name}, 错误: {e}")
# 增加错误计数
session.error_count += 1
# 根据错误率更新状态
if session.request_count > 0:
error_rate = session.error_count / session.request_count
if error_rate > 0.5:
session.status = "error"
elif error_rate > 0.3:
session.status = "degraded"
logger.error(
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
)
raise
async def get_plugin_tools(
@@ -320,7 +477,7 @@ class MCPPluginRegistry:
async def cleanup_all(self):
"""清理所有插件和资源"""
# 停止后台清理任务
# 停止后台任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
@@ -328,19 +485,18 @@ class MCPPluginRegistry:
except asyncio.CancelledError:
pass
# 清理所有客户端
plugin_ids = list(self._clients.keys())
for plugin_id in plugin_ids:
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
await self._unload_plugin_unsafe(plugin_id)
# 关闭共享HTTP客户端
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._shared_http_client.aclose()
except Exception as e:
logger.error(f"关闭共享HTTP客户端失败: {e}")
await self._health_check_task
except asyncio.CancelledError:
pass
# 清理所有会话
async with self._sessions_lock:
plugin_ids = list(self._sessions.keys())
for plugin_id in plugin_ids:
await self._unload_plugin_unsafe(plugin_id)
logger.info("✅ 已清理所有MCP插件和资源")
+320
View File
@@ -0,0 +1,320 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
import time
import json
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp.registry import mcp_registry
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.logger import get_logger
from app.user_manager import User
logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(分离的测试逻辑)"""
async def test_plugin_connection(
self,
plugin: MCPPlugin,
user_id: str
) -> MCPTestResult:
"""
简单连接测试
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
测试结果
"""
start_time = time.time()
try:
# 确保插件已加载
if not mcp_registry.get_client(user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
if test_result["success"]:
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功",
response_time_ms=response_time,
tools_count=test_result.get("tools_count", 0),
suggestions=[
f"响应时间: {response_time}ms",
f"可用工具数: {test_result.get('tools_count', 0)}"
]
)
else:
return MCPTestResult(**test_result)
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=response_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
async def test_plugin_with_ai(
self,
plugin: MCPPlugin,
user: User,
db_session: AsyncSession
) -> MCPTestResult:
"""
使用AI进行智能工具调用测试
Args:
plugin: 插件配置
user: 用户对象
db_session: 数据库会话
Returns:
测试结果
"""
start_time = time.time()
try:
# 1. 先进行连接测试
connection_result = await self.test_plugin_connection(plugin, user.user_id)
if not connection_result.success:
return connection_result
# 2. 获取工具列表
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=connection_result.response_time_ms,
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 获取用户的AI设置
settings_result = await db_session.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 没有AI配置,返回简单测试结果
logger.warning("用户未配置AI服务,跳过智能测试")
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=connection_result.response_time_ms,
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 4. 使用AI选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 转换为OpenAI Function Calling格式
openai_tools = self._convert_tools_to_openai_format(tools)
# 调用AI选择工具
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
⚠️ 重要规则:生成参数时,必须严格使用工具 schema 中定义的原始参数名称,不要转换为 snake_case 或其他格式。
例如:如果 schema 中是 'nextThoughtNeeded',就必须使用 'nextThoughtNeeded',不能改成 'next_thought_needed'
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
现在开始测试这个插件。"""
system_prompt = """你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。
⚠️ 关键规则:调用工具时,必须严格使用 schema 中定义的原始参数名,不要自行转换命名风格。
- 如果参数名是 camelCase(如 nextThoughtNeeded),就使用 camelCase
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
ai_response = await ai_service.generate_text(
prompt=prompt,
system_prompt=system_prompt,
tools=openai_tools,
tool_choice="required"
)
# 5. 检查AI是否返回工具调用
if not ai_response.get("tool_calls"):
logger.error(f"❌ AI未返回工具调用")
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}"
]
)
# 6. 解析工具调用
tool_call = ai_response["tool_calls"][0]
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
try:
test_arguments = json.loads(test_arguments)
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools)
)
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 格式化结果
result_str = str(tool_result)
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断)"
else:
result_preview = result_str
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI选择: {tool_name}",
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 耗时: {call_time}ms",
f"📊 结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI选择: {tool_name}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
def _convert_tools_to_openai_format(self, tools: list) -> list:
"""将MCP工具格式转换为OpenAI Function Calling格式"""
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
return openai_tools
# 全局单例
mcp_test_service = MCPTestService()
+327 -19
View File
@@ -5,26 +5,100 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import json
from datetime import datetime
import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict
from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry
from app.mcp.config import mcp_config
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ToolMetrics:
"""工具调用指标"""
total_calls: int = 0
success_calls: int = 0
failed_calls: int = 0
total_duration_ms: float = 0.0
avg_duration_ms: float = 0.0
last_call_time: Optional[datetime] = None
def update_success(self, duration_ms: float):
"""更新成功调用指标"""
self.total_calls += 1
self.success_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
def update_failure(self, duration_ms: float):
"""更新失败调用指标"""
self.total_calls += 1
self.failed_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
@property
def success_rate(self) -> float:
"""成功率"""
if self.total_calls == 0:
return 0.0
return self.success_calls / self.total_calls
@dataclass
class ToolCacheEntry:
"""工具缓存条目"""
tools: List[Dict[str, Any]]
expire_time: datetime
hit_count: int = 0
class MCPToolServiceError(Exception):
"""MCP工具服务异常"""
pass
class MCPToolService:
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
def __init__(self):
self._tool_cache = {} # 工具定义缓存
self._result_cache = {} # 工具结果缓存(可选)
def __init__(
self,
cache_ttl_minutes: Optional[int] = None,
max_retries: Optional[int] = None
):
"""
初始化MCP工具服务
Args:
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
max_retries: 最大重试次数(默认使用配置)
"""
# 工具定义缓存: {cache_key: ToolCacheEntry}
self._tool_cache: Dict[str, ToolCacheEntry] = {}
self._cache_ttl = timedelta(
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
)
# 调用指标: {tool_key: ToolMetrics}
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
# 重试配置(使用配置常量)
self._max_retries = max_retries or mcp_config.MAX_RETRIES
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
logger.info(
f"✅ MCPToolService初始化完成 "
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
f"最大重试={self._max_retries}次)"
)
async def get_user_enabled_tools(
self,
@@ -61,7 +135,7 @@ class MCPToolService:
logger.info(f"用户 {user_id} 没有启用的MCP插件")
return []
# 2. 获取所有工具定义
# 2. 获取所有工具定义(使用缓存)
all_tools = []
for plugin in plugins:
try:
@@ -73,8 +147,8 @@ class MCPToolService:
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
continue
# 从registry获取该插件的工具列表
plugin_tools = await mcp_registry.get_plugin_tools(
# ✅ 使用缓存获取工具列表
plugin_tools = await self._get_plugin_tools_cached(
user_id=user_id,
plugin_name=plugin.plugin_name
)
@@ -82,7 +156,7 @@ class MCPToolService:
# 格式化为Function Calling格式
formatted_tools = self._format_tools_for_ai(
plugin_tools,
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name
plugin.plugin_name
)
all_tools.extend(formatted_tools)
@@ -139,12 +213,85 @@ class MCPToolService:
return formatted_tools
async def _get_plugin_tools_cached(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
带缓存的工具列表获取
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
cache_key = f"{user_id}:{plugin_name}"
now = datetime.now()
# 检查缓存
if cache_key in self._tool_cache:
entry = self._tool_cache[cache_key]
if now < entry.expire_time:
entry.hit_count += 1
logger.debug(
f"🎯 工具缓存命中: {cache_key} "
f"(命中次数: {entry.hit_count})"
)
return entry.tools
else:
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
del self._tool_cache[cache_key]
# 缓存未命中,从MCP获取
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
# 更新缓存
self._tool_cache[cache_key] = ToolCacheEntry(
tools=tools,
expire_time=now + self._cache_ttl,
hit_count=0
)
return tools
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
"""
清理缓存
Args:
user_id: 用户ID(可选,清理特定用户的缓存)
plugin_name: 插件名称(可选,清理特定插件的缓存)
"""
if user_id is None and plugin_name is None:
# 清理所有缓存
self._tool_cache.clear()
logger.info("🧹 已清理所有工具缓存")
elif user_id and plugin_name:
# 清理特定插件缓存
cache_key = f"{user_id}:{plugin_name}"
if cache_key in self._tool_cache:
del self._tool_cache[cache_key]
logger.info(f"🧹 已清理缓存: {cache_key}")
elif user_id:
# 清理用户所有缓存
keys_to_delete = [
key for key in self._tool_cache.keys()
if key.startswith(f"{user_id}:")
]
for key in keys_to_delete:
del self._tool_cache[key]
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
async def execute_tool_calls(
self,
user_id: str,
tool_calls: List[Dict[str, Any]],
db_session: AsyncSession,
timeout: float = 60.0
timeout: Optional[float] = None
) -> List[Dict[str, Any]]:
"""
批量执行AI请求的工具调用(并行执行)
@@ -153,7 +300,7 @@ class MCPToolService:
user_id: 用户ID
tool_calls: AI返回的工具调用列表
db_session: 数据库会话
timeout: 单个工具调用的超时时间(秒,默认30秒
timeout: 单个工具调用的超时时间(秒,默认使用配置
Returns:
工具调用结果列表
@@ -161,7 +308,10 @@ class MCPToolService:
if not tool_calls:
return []
logger.info(f"开始执行 {len(tool_calls)} 个工具调用")
# 使用配置的默认超时
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s)")
# 创建异步任务列表
tasks = [
@@ -169,7 +319,7 @@ class MCPToolService:
user_id=user_id,
tool_call=tool_call,
db_session=db_session,
timeout=timeout
timeout=actual_timeout
)
for tool_call in tool_calls
]
@@ -238,18 +388,28 @@ class MCPToolService:
f"参数: {arguments}"
)
# 设置超时
# ✅ 使用带重试的调用
tool_key = f"{plugin_name}.{tool_name}"
start_time = time.time()
try:
result = await asyncio.wait_for(
mcp_registry.call_tool(
result = await self._call_tool_with_retry(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
arguments=arguments,
timeout=timeout
)
# 记录成功指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_success(duration_ms)
logger.info(
f"✅ 工具调用成功: {tool_key} "
f"(耗时: {duration_ms:.2f}ms)"
)
# 成功返回
return {
"tool_call_id": tool_call_id,
@@ -261,13 +421,21 @@ class MCPToolService:
}
except asyncio.TimeoutError:
# 记录失败指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
raise MCPToolServiceError(
f"工具调用超时(>{timeout}秒)"
)
except Exception as e:
# 记录失败指标
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
logger.error(
f"工具 {function_name} 调用失败: {e}",
f"工具 {function_name} 调用失败: {e}",
exc_info=True
)
return {
@@ -279,6 +447,146 @@ class MCPToolService:
"error": str(e)
}
async def _call_tool_with_retry(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: float
) -> Any:
"""
带指数退避重试的工具调用
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
工具执行结果
Raises:
MCPToolServiceError: 工具调用失败
asyncio.TimeoutError: 调用超时
"""
last_exception = None
for attempt in range(self._max_retries):
try:
# 尝试调用工具
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
timeout=timeout
)
# 成功则返回
if attempt > 0:
logger.info(
f"✅ 重试成功: {plugin_name}.{tool_name} "
f"(第{attempt + 1}次尝试)"
)
return result
except asyncio.TimeoutError:
# 超时不重试,直接抛出
raise
except Exception as e:
last_exception = e
# 最后一次尝试失败
if attempt == self._max_retries - 1:
logger.error(
f"❌ 重试失败: {plugin_name}.{tool_name} "
f"(已尝试{self._max_retries}次): {e}"
)
raise MCPToolServiceError(
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
)
# 计算指数退避延迟
delay = min(
self._base_retry_delay * (2 ** attempt),
self._max_retry_delay
)
logger.warning(
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
f"(第{attempt + 1}/{self._max_retries}次): "
f"{plugin_name}.{tool_name} - {e}"
)
await asyncio.sleep(delay)
# 理论上不会到这里,但为了类型安全
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
获取工具调用指标
Args:
tool_name: 工具名称(可选,获取特定工具的指标)
Returns:
指标字典
"""
if tool_name:
if tool_name in self._metrics:
metric = self._metrics[tool_name]
return {
tool_name: {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": metric.success_rate,
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
}
return {}
# 返回所有工具的指标
result = {}
for tool_key, metric in self._metrics.items():
result[tool_key] = {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": round(metric.success_rate, 3),
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
return result
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
total_entries = len(self._tool_cache)
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
return {
"total_entries": total_entries,
"total_hits": total_hits,
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"key": key,
"tools_count": len(entry.tools),
"hit_count": entry.hit_count,
"expire_time": entry.expire_time.isoformat()
}
for key, entry in self._tool_cache.items()
]
}
async def build_tool_context(
self,
tool_results: List[Dict[str, Any]],
+2 -2
View File
@@ -226,8 +226,8 @@ class PlotAnalyzer:
)
# 🔍 添加调试日志:查看AI返回的原始内容
logger.info(f"🔍 AI返回类型: {type(response)}")
logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# logger.info(f"🔍 AI返回类型: {type(response)}")
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# 从返回的字典中提取content字段
if isinstance(response, dict):
+8 -5
View File
@@ -1,15 +1,15 @@
# Web框架
fastapi==0.109.0
uvicorn[standard]==0.27.0
python-multipart==0.0.6
fastapi==0.121.0
uvicorn[standard]==0.38.0
python-multipart==0.0.20
# 数据库
sqlalchemy==2.0.25
aiosqlite==0.19.0
# 数据验证
pydantic==2.5.3
pydantic-settings==2.1.0
pydantic==2.12.4
pydantic-settings==2.11.0
# AI服务
openai==2.7.0
@@ -18,6 +18,9 @@ anthropic==0.72.0
# 工具库
httpx==0.28.1
python-dotenv==1.0.0
# MCP官方库(Model Context Protocol Python SDK
mcp==1.21.0
# NumPy版本锁定(兼容性要求)
numpy==1.26.4