fix:1.优化mcp插件功能,改用mcp sdk库

This commit is contained in:
xiamuceer
2025-11-08 12:32:32 +08:00
parent 88115a45c5
commit c7c1c1fdf3
9 changed files with 1278 additions and 660 deletions
+137 -335
View File
@@ -18,9 +18,9 @@ from app.schemas.mcp_plugin import (
import json import json
from app.user_manager import User from app.user_manager import User
from app.mcp.registry import mcp_registry from app.mcp.registry import mcp_registry
from app.services.mcp_test_service import mcp_test_service
from app.services.mcp_tool_service import mcp_tool_service
from app.logger import get_logger from app.logger import get_logger
from app.services.ai_service import create_user_ai_service
from app.models.settings import Settings as UserSettings
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -399,13 +399,8 @@ async def test_plugin(
""" """
测试插件连接并调用工具验证功能 测试插件连接并调用工具验证功能
测试流程: 使用新的MCPTestService进行测试
1. 测试MCP服务器连接
2. 获取工具列表
3. 自动选择一个工具进行实际调用测试
4. 返回完整测试结果
""" """
import time
result = await db.execute( result = await db.execute(
select(MCPPlugin).where( select(MCPPlugin).where(
@@ -426,356 +421,153 @@ async def test_plugin(
suggestions=["点击开关按钮启用插件"] suggestions=["点击开关按钮启用插件"]
) )
start_time = time.time() # 使用新的测试服务
try: try:
# 1. 确保插件已加载 test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
if not mcp_registry.get_client(user.user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 2. 测试连接并获取工具列表 # 更新插件状态
test_result = await mcp_registry.test_plugin(user.user_id, plugin.plugin_name) if test_result.success:
if not test_result["success"]:
plugin.status = "error"
plugin.last_error = test_result.get("error", "连接测试失败")
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(**test_result)
tools = test_result.get("tools", [])
if not tools:
plugin.status = "error"
plugin.last_error = "插件没有提供任何工具"
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=test_result.get("response_time_ms"),
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 使用AI智能选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
# 获取用户的AI设置
settings_result = await db.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 如果没有AI配置,回退到简单测试
logger.warning("用户未配置AI服务,使用简单连接测试")
plugin.status = "active" plugin.status = "active"
plugin.last_error = None plugin.last_error = None
plugin.last_test_at = datetime.now() else:
plugin.tools = tools
await db.commit()
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=test_result.get("response_time_ms"),
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 使用AI的标准Function Calling机制选择工具
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 将MCP工具格式转换为OpenAI Function Calling格式
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
# 将 inputSchema 转换为 parameters
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
logger.info(f"转换了 {len(openai_tools)} 个MCP工具为OpenAI格式")
logger.info(f"工具列表: {[t['function']['name'] for t in openai_tools]}")
# 使用标准的Function Calling,将转换后的工具传递给AI
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
现在开始测试这个插件。"""
system_prompt = "你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。"
# 调用AI的Function Calling
logger.info(f"📞 准备调用AI Function Calling")
logger.info(f" - Provider: {user_settings.api_provider}")
logger.info(f" - Model: {user_settings.llm_model}")
logger.info(f" - Tools count: {len(openai_tools)}")
logger.debug(f" - Tools: {json.dumps(openai_tools, ensure_ascii=False, indent=2)}")
ai_response = await ai_service.generate_text(
prompt=prompt,
system_prompt=system_prompt,
tools=openai_tools, # 传递转换后的OpenAI格式工具
tool_choice="required" # 要求AI必须选择一个工具
)
logger.info(f"📥 收到AI响应")
logger.info(f" - Response keys: {list(ai_response.keys())}")
logger.debug(f" - Full response: {json.dumps(ai_response, ensure_ascii=False, indent=2)}")
# 检查AI是否请求调用工具
if not ai_response.get("tool_calls"):
# AI未调用工具,记录详细信息
logger.error(f"❌ AI未返回工具调用")
logger.error(f" - Response: {ai_response}")
logger.error(f" - Content: {ai_response.get('content', 'N/A')}")
logger.error(f" - Finish reason: {ai_response.get('finish_reason', 'N/A')}")
plugin.status = "error" plugin.status = "error"
plugin.last_error = "AI未返回工具调用请求" plugin.last_error = test_result.error
plugin.last_test_at = datetime.now()
await db.commit()
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
"OpenAI: 需要gpt-4, gpt-3.5-turbo等模型",
"Anthropic: 需要claude-3系列模型",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}",
f"AI返回内容: {ai_response.get('content', 'N/A')[:100]}"
]
)
# 获取第一个工具调用 plugin.last_test_at = datetime.now()
tool_call = ai_response["tool_calls"][0] await db.commit()
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
# AI返回的arguments可能是JSON字符串,需要解析 return test_result
if isinstance(test_arguments, str):
try:
test_arguments = json.loads(test_arguments)
logger.info(f"✅ 解析AI返回的JSON字符串参数")
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools),
suggestions=["AI返回的参数不是有效的JSON格式"]
)
logger.info(f"🤖 AI通过Function Calling选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
logger.info(f"📝 参数类型: {type(test_arguments).__name__}")
# 4. 使用AI选择的工具和参数调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 6. 测试成功,更新插件状态
plugin.status = "active"
plugin.last_error = None
plugin.last_test_at = datetime.now()
plugin.tools = tools # 缓存工具列表
await db.commit()
# 格式化工具结果用于显示
result_str = str(tool_result)
# 如果结果太长,截取前800字符
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断,完整结果请查看日志)"
else:
result_preview = result_str
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI (Function Calling) 选择: {tool_name}",
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 调用耗时: {call_time}ms",
f"📊 返回结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
# 工具调用失败,但连接成功
plugin.status = "active" # 仍标记为active,因为连接是成功的
plugin.last_error = f"工具调用测试失败: {str(call_error)}"
plugin.last_test_at = datetime.now()
plugin.tools = tools
await db.commit()
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI (Function Calling) 选择: {tool_name}",
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
except Exception as e: except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}") logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
plugin.status = "error" plugin.status = "error"
plugin.last_error = str(e) plugin.last_error = str(e)
plugin.last_test_at = datetime.now() plugin.last_test_at = datetime.now()
await db.commit() await db.commit()
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=["请检查服务器是否在线", "请确认配置正确", "请检查API Key是否有效"]
)
def _build_test_arguments(tool_name: str, input_schema: dict, plugin_name: str) -> dict: async def _ensure_plugin_loaded(
plugin: MCPPlugin,
user_id: str
) -> bool:
""" """
根据工具schema智能构造测试参数 确保插件已加载(共享逻辑)
Args: Args:
tool_name: 工具名称 plugin: 插件对象
input_schema: 输入schema user_id: 用户ID
plugin_name: 插件名称
Returns: Returns:
测试参数字典 是否加载成功
"""
# 针对常见MCP工具的默认测试参数
test_cases = {
# Exa搜索工具
"search": {
"query": "AI technology",
"num_results": 3
},
"search_and_contents": {
"query": "artificial intelligence",
"num_results": 2
},
# Brave搜索
"brave_web_search": {
"query": "AI news",
"count": 3
},
# Filesystem工具
"read_file": {
"path": "README.md"
},
"list_directory": {
"path": "."
},
}
# 如果有针对特定工具的测试用例,使用它
if tool_name in test_cases:
logger.info(f"使用预定义测试参数: {test_cases[tool_name]}")
return test_cases[tool_name]
# 否则根据schema自动构造
properties = input_schema.get("properties", {})
required = input_schema.get("required", [])
test_args = {}
for prop_name, prop_schema in properties.items():
# 只填充必需的参数
if prop_name not in required:
continue
prop_type = prop_schema.get("type", "string")
# 根据参数名称和类型猜测合适的测试值 Raises:
if "query" in prop_name.lower() or "search" in prop_name.lower(): HTTPException: 加载失败
test_args[prop_name] = "test query" """
elif "url" in prop_name.lower(): if not mcp_registry.get_client(user_id, plugin.plugin_name):
test_args[prop_name] = "https://example.com" logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
elif "path" in prop_name.lower(): success = await mcp_registry.load_plugin(plugin)
test_args[prop_name] = "." if not success:
elif "count" in prop_name.lower() or "limit" in prop_name.lower() or "num" in prop_name.lower(): raise HTTPException(
test_args[prop_name] = 3 status_code=500,
elif prop_type == "string": detail=f"插件加载失败: {plugin.plugin_name}"
test_args[prop_name] = "test" )
elif prop_type == "number" or prop_type == "integer": return True
test_args[prop_name] = 1
elif prop_type == "boolean":
test_args[prop_name] = True @router.get("/metrics")
elif prop_type == "array": async def get_metrics(
test_args[prop_name] = [] tool_name: Optional[str] = Query(None, description="工具名称(可选,获取特定工具的指标)"),
elif prop_type == "object": user: User = Depends(require_login)
test_args[prop_name] = {} ):
"""
获取MCP工具调用指标
logger.info(f"自动构造测试参数: {test_args}") Query参数:
return test_args - tool_name: 可选,指定工具名称获取特定工具的指标
Returns:
工具调用指标字典,包含:
- total_calls: 总调用次数
- success_calls: 成功调用次数
- failed_calls: 失败调用次数
- success_rate: 成功率
- avg_duration_ms: 平均耗时(毫秒)
- last_call_time: 最后调用时间
"""
metrics = mcp_tool_service.get_metrics(tool_name)
return {
"metrics": metrics,
"tool_name": tool_name,
"timestamp": datetime.now().isoformat()
}
@router.get("/cache/stats")
async def get_cache_stats(
user: User = Depends(require_login)
):
"""
获取工具缓存统计信息
Returns:
缓存统计信息,包含:
- total_entries: 缓存条目总数
- total_hits: 缓存总命中次数
- cache_ttl_minutes: 缓存TTL(分钟)
- entries: 各缓存条目详情
"""
stats = mcp_tool_service.get_cache_stats()
return {
"cache_stats": stats,
"timestamp": datetime.now().isoformat()
}
@router.post("/cache/clear")
async def clear_cache(
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
plugin_name: Optional[str] = Query(None, description="插件名称(可选)"),
user: User = Depends(require_login)
):
"""
清理工具缓存
Query参数:
- user_id: 可选,清理特定用户的缓存
- plugin_name: 可选,清理特定插件的缓存
说明:
- 不提供任何参数:清理所有缓存
- 只提供user_id:清理该用户的所有缓存
- 提供user_id和plugin_name:清理特定插件的缓存
"""
# 非管理员只能清理自己的缓存
if user_id and user_id != user.user_id:
raise HTTPException(status_code=403, detail="无权清理其他用户的缓存")
# 如果没有指定user_id,使用当前用户
target_user_id = user_id or user.user_id
mcp_tool_service.clear_cache(target_user_id, plugin_name)
message = "已清理"
if plugin_name:
message += f"插件 {plugin_name} 的缓存"
elif target_user_id:
message += f"用户 {target_user_id} 的所有缓存"
else:
message += "所有缓存"
logger.info(f"用户 {user.user_id} {message}")
return {
"success": True,
"message": message,
"timestamp": datetime.now().isoformat()
}
@router.get("/{plugin_id}/tools") @router.get("/{plugin_id}/tools")
@@ -802,6 +594,9 @@ async def get_plugin_tools(
raise HTTPException(status_code=400, detail="插件未启用") raise HTTPException(status_code=400, detail="插件未启用")
try: try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name) tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
# 更新缓存 # 更新缓存
@@ -813,6 +608,8 @@ async def get_plugin_tools(
"tools": tools, "tools": tools,
"count": len(tools) "count": len(tools)
} }
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}") logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
@@ -843,6 +640,9 @@ async def call_mcp_tool(
raise HTTPException(status_code=400, detail="插件未启用") raise HTTPException(status_code=400, detail="插件未启用")
try: try:
# 确保插件已加载
await _ensure_plugin_loaded(plugin, user.user_id)
# 调用工具 # 调用工具
result = await mcp_registry.call_tool( result = await mcp_registry.call_tool(
user.user_id, user.user_id,
@@ -857,6 +657,8 @@ async def call_mcp_tool(
"tool_name": data.tool_name, "tool_name": data.tool_name,
"result": result "result": result
} }
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}") logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}") raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
+2 -3
View File
@@ -12,6 +12,7 @@ from app.database import close_db, _session_stats
from app.logger import setup_logging, get_logger from app.logger import setup_logging, get_logger
from app.middleware import RequestIDMiddleware from app.middleware import RequestIDMiddleware
from app.middleware.auth_middleware import AuthMiddleware from app.middleware.auth_middleware import AuthMiddleware
from app.mcp.registry import mcp_registry
setup_logging( setup_logging(
level=config_settings.log_level, level=config_settings.log_level,
@@ -27,9 +28,7 @@ logger = get_logger(__name__)
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""应用生命周期管理""" """应用生命周期管理"""
logger.info("应用启动,等待用户登录...") logger.info("应用启动,等待用户登录...")
logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载")
# 导入MCP注册表
from app.mcp.registry import mcp_registry
yield yield
+42
View File
@@ -0,0 +1,42 @@
"""MCP模块配置常量"""
from dataclasses import dataclass
@dataclass(frozen=True)
class MCPConfig:
"""MCP模块配置常量(不可变)"""
# 连接池配置
MAX_CLIENTS: int = 1000 # 最大客户端数量
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
# 健康检查配置
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
# 清理任务配置
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
# 缓存配置
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
# 重试配置
MAX_RETRIES: int = 3 # 最大重试次数
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
# 超时配置
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
# 日志配置
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
# 全局配置实例
mcp_config = MCPConfig()
+185 -197
View File
@@ -1,8 +1,13 @@
"""HTTP MCP客户端 - 实现JSON-RPC 2.0协议""" """HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
import httpx import asyncio
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from app.logger import get_logger from app.logger import get_logger
import time
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -13,15 +18,14 @@ class MCPError(Exception):
class HTTPMCPClient: class HTTPMCPClient:
"""HTTP模式MCP客户端(类似Cursor/Claude Code实现""" """HTTP模式MCP客户端(基于官方 MCP Python SDK"""
def __init__( def __init__(
self, self,
url: str, url: str,
headers: Optional[Dict[str, str]] = None, headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None, env: Optional[Dict[str, str]] = None,
timeout: float = 60.0, timeout: float = 60.0
http_client: Optional[httpx.AsyncClient] = None
): ):
""" """
初始化HTTP MCP客户端 初始化HTTP MCP客户端
@@ -31,162 +35,79 @@ class HTTPMCPClient:
headers: HTTP请求头 headers: HTTP请求头
env: 环境变量(用于API Key等) env: 环境变量(用于API Key等)
timeout: 超时时间(秒) timeout: 超时时间(秒)
http_client: 可选的共享HTTP客户端(用于连接池复用)
""" """
self.url = url.rstrip('/') self.url = url.rstrip('/')
self.headers = headers or {} self.headers = headers or {}
self.env = env or {} self.env = env or {}
self.timeout = timeout self.timeout = timeout
# 设置MCP必需的Accept头
# MCP服务器要求客户端必须接受 application/json 和 text/event-stream
if 'Accept' not in self.headers:
self.headers['Accept'] = 'application/json, text/event-stream'
# 设置Content-Type
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = 'application/json'
# 如果env中有API Key,添加到headers # 如果env中有API Key,添加到headers
if 'API_KEY' in self.env: if 'API_KEY' in self.env:
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}' self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
# 使用共享客户端或创建新客户端 self._session: Optional[ClientSession] = None
self._owns_client = http_client is None self._context_stack = [] # 保存上下文管理器栈
if http_client: self._initialized = False
self.client = http_client self._lock = asyncio.Lock()
else:
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(timeout),
headers=self.headers
)
self._request_id = 0
def _next_request_id(self) -> int: async def _ensure_connected(self):
"""获取下一个请求ID""" """确保连接已建立"""
self._request_id += 1 async with self._lock:
return self._request_id if self._session is None:
async def _call_jsonrpc(
self,
method: str,
params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
调用JSON-RPC 2.0方法
Args:
method: 方法名
params: 参数
Returns:
响应结果
Raises:
MCPError: 调用失败时抛出
"""
request_id = self._next_request_id()
payload = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params or {}
}
try:
logger.debug(f"MCP请求: {method} -> {self.url}")
response = await self.client.post(
self.url,
json=payload,
headers=self.headers # 显式传递headers(对于共享客户端很重要)
)
response.raise_for_status()
# 获取响应内容
response_text = response.text
content_type = response.headers.get('content-type', '')
# 如果是空响应
if not response_text or response_text.strip() == '':
raise MCPError("服务器返回空响应")
# 处理SSE格式响应
if 'text/event-stream' in content_type or response_text.startswith('event:'):
logger.debug("检测到SSE格式响应,开始解析")
data = self._parse_sse_response(response_text)
else:
# 标准JSON响应
try: try:
data = response.json() logger.info(f"🔗 连接到MCP服务器: {self.url}")
except ValueError as e:
logger.error(f"JSON解析失败,响应内容: {response_text[:500]}") # 使用官方 SDK 的 streamable_http_client
raise MCPError(f"无法解析JSON响应: {str(e)}") # 保存上下文管理器以便后续正确清理
stream_context = streamablehttp_client(self.url)
# 检查JSON-RPC错误 read_stream, write_stream, _ = await stream_context.__aenter__()
if "error" in data: self._context_stack.append(('stream', stream_context))
error = data["error"]
error_msg = error.get("message", "Unknown error") # 创建客户端会话
error_code = error.get("code", -1) self._session = ClientSession(read_stream, write_stream)
logger.error(f"MCP错误 [{error_code}]: {error_msg}") session_context = self._session
raise MCPError(f"[{error_code}] {error_msg}") await session_context.__aenter__()
self._context_stack.append(('session', session_context))
if "result" not in data:
raise MCPError("响应中缺少result字段") # 初始化会话
await self._session.initialize()
return data["result"] self._initialized = True
except httpx.HTTPStatusError as e: logger.info(f"✅ MCP会话初始化成功")
logger.error(f"HTTP错误 {e.response.status_code}: {e.response.text}")
raise MCPError(f"HTTP错误 {e.response.status_code}: {e.response.text}") except Exception as e:
except httpx.RequestError as e: logger.error(f"❌ MCP连接失败: {e}")
logger.error(f"请求错误: {str(e)}") await self._cleanup()
raise MCPError(f"请求错误: {str(e)}") raise MCPError(f"连接MCP服务器失败: {str(e)}")
except MCPError:
raise
except Exception as e:
logger.error(f"未知错误: {str(e)}")
raise MCPError(f"未知错误: {str(e)}")
def _parse_sse_response(self, sse_text: str) -> Dict[str, Any]: async def _cleanup(self):
"""清理连接资源(按照进入的相反顺序退出)"""
# 按照LIFO顺序清理上下文
while self._context_stack:
ctx_type, ctx = self._context_stack.pop()
try:
await ctx.__aexit__(None, None, None)
except RuntimeError as e:
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
else:
logger.error(f"清理{ctx_type}上下文失败: {e}")
except Exception as e:
logger.error(f"清理{ctx_type}上下文失败: {e}")
self._session = None
self._initialized = False
async def initialize(self) -> Dict[str, Any]:
""" """
解析SSE格式的响应 初始化MCP会话
SSE格式示例:
event: message
data: {"result": {...}}
Args:
sse_text: SSE格式的文本
Returns: Returns:
解析后的JSON数据 初始化响应
""" """
import json await self._ensure_connected()
return {"status": "initialized"}
lines = sse_text.strip().split('\n')
data_lines = []
for line in lines:
line = line.strip()
if line.startswith('data:'):
# 提取data后面的内容
data_content = line[5:].strip()
data_lines.append(data_content)
if not data_lines:
raise MCPError("SSE响应中没有找到data字段")
# 合并所有data行(某些SSE可能分多行)
full_data = ''.join(data_lines)
try:
return json.loads(full_data)
except json.JSONDecodeError as e:
logger.error(f"解析SSE data失败: {full_data[:200]}")
raise MCPError(f"SSE data不是有效的JSON: {str(e)}")
async def list_tools(self) -> List[Dict[str, Any]]: async def list_tools(self) -> List[Dict[str, Any]]:
""" """
@@ -196,13 +117,26 @@ class HTTPMCPClient:
工具列表 工具列表
""" """
try: try:
result = await self._call_jsonrpc("tools/list") await self._ensure_connected()
tools = result.get("tools", [])
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
logger.info(f"获取到 {len(tools)} 个工具") logger.info(f"获取到 {len(tools)} 个工具")
return tools return tools
except Exception as e: except Exception as e:
logger.error(f"获取工具列表失败: {e}") logger.error(f"获取工具列表失败: {e}")
raise raise MCPError(f"获取工具列表失败: {str(e)}")
async def call_tool( async def call_tool(
self, self,
@@ -220,33 +154,38 @@ class HTTPMCPClient:
工具执行结果 工具执行结果
""" """
try: try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}") logger.info(f"调用工具: {tool_name}")
logger.debug(f"参数: {arguments}") logger.debug(f"参数: {arguments}")
result = await self._call_jsonrpc( result = await self._session.call_tool(tool_name, arguments)
"tools/call",
{
"name": tool_name,
"arguments": arguments
}
)
# MCP返回的result通常包含content数组 # 处理返回结果
if isinstance(result, dict) and "content" in result: # MCP SDK 返回 CallToolResult 对象
content = result["content"] if result.content:
if isinstance(content, list) and len(content) > 0: # 提取第一个content的文本
# 提取第一个content项的text for content in result.content:
first_content = content[0] if isinstance(content, types.TextContent):
if isinstance(first_content, dict) and "text" in first_content: return content.text
return first_content["text"] elif isinstance(content, types.ImageContent):
return first_content return {
return content "type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
return result.content[0] if result.content else None
return result # 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
return result.structuredContent
return None
except Exception as e: except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}") logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
raise raise MCPError(f"调用工具失败: {str(e)}")
async def list_resources(self) -> List[Dict[str, Any]]: async def list_resources(self) -> List[Dict[str, Any]]:
""" """
@@ -256,13 +195,27 @@ class HTTPMCPClient:
资源列表 资源列表
""" """
try: try:
result = await self._call_jsonrpc("resources/list") await self._ensure_connected()
resources = result.get("resources", [])
result = await self._session.list_resources()
# 转换为字典格式
resources = []
for resource in result.resources:
resource_dict = {
"uri": str(resource.uri),
"name": resource.name,
"description": resource.description or "",
"mimeType": resource.mimeType or ""
}
resources.append(resource_dict)
logger.info(f"获取到 {len(resources)} 个资源") logger.info(f"获取到 {len(resources)} 个资源")
return resources return resources
except Exception as e: except Exception as e:
logger.error(f"获取资源列表失败: {e}") logger.error(f"获取资源列表失败: {e}")
raise raise MCPError(f"获取资源列表失败: {str(e)}")
async def read_resource(self, uri: str) -> Any: async def read_resource(self, uri: str) -> Any:
""" """
@@ -275,14 +228,33 @@ class HTTPMCPClient:
资源内容 资源内容
""" """
try: try:
result = await self._call_jsonrpc( await self._ensure_connected()
"resources/read",
{"uri": uri} result = await self._session.read_resource(AnyUrl(uri))
)
return result # 提取资源内容
if result.contents:
content = result.contents[0]
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
elif isinstance(content, types.BlobResourceContents):
return {
"type": "blob",
"blob": content.blob,
"mimeType": content.mimeType
}
return None
except Exception as e: except Exception as e:
logger.error(f"读取资源失败: {uri}, 错误: {e}") logger.error(f"读取资源失败: {uri}, 错误: {e}")
raise raise MCPError(f"读取资源失败: {str(e)}")
async def test_connection(self) -> Dict[str, Any]: async def test_connection(self) -> Dict[str, Any]:
""" """
@@ -291,10 +263,12 @@ class HTTPMCPClient:
Returns: Returns:
测试结果 测试结果
""" """
import time
start_time = time.time() start_time = time.time()
try: try:
# 尝试列举工具来测试连接 # 尝试连接并列举工具
await self._ensure_connected()
tools = await self.list_tools() tools = await self.list_tools()
end_time = time.time() end_time = time.time()
@@ -307,22 +281,7 @@ class HTTPMCPClient:
"tools_count": len(tools), "tools_count": len(tools),
"tools": tools "tools": tools
} }
except MCPError as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
return {
"success": False,
"message": "连接测试失败",
"response_time_ms": response_time,
"error": str(e),
"error_type": "MCPError",
"suggestions": [
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接"
]
}
except Exception as e: except Exception as e:
end_time = time.time() end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2) response_time = round((end_time - start_time) * 1000, 2)
@@ -334,12 +293,41 @@ class HTTPMCPClient:
"error": str(e), "error": str(e),
"error_type": type(e).__name__, "error_type": type(e).__name__,
"suggestions": [ "suggestions": [
"请检查服务器是否在线", "请检查服务器URL是否正确",
"请确认配置是否正确" "请确认API Key是否有效",
"请检查网络连接",
"请确认MCP服务器是否在线"
] ]
} }
async def close(self): async def close(self):
"""关闭客户端(仅在拥有客户端所有权时关闭)""" """关闭客户端连接"""
if self._owns_client and self.client: logger.info(f"关闭MCP客户端: {self.url}")
await self.client.aclose() await self._cleanup()
@asynccontextmanager
async def create_mcp_client(
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
创建MCP客户端的上下文管理器
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量
timeout: 超时时间
Yields:
HTTPMCPClient实例
"""
client = HTTPMCPClient(url, headers, env, timeout)
try:
await client.initialize()
yield client
finally:
await client.close()
+251 -95
View File
@@ -1,92 +1,152 @@
"""MCP插件注册表 - 管理运行时插件实例""" """MCP插件注册表 - 管理运行时插件实例"""
import asyncio import asyncio
import time import time
import httpx from typing import Dict, Optional, Any, List
from typing import Dict, Optional, Any, List, Tuple from dataclasses import dataclass
from collections import OrderedDict from datetime import datetime
from app.mcp.http_client import HTTPMCPClient, MCPError from app.mcp.http_client import HTTPMCPClient, MCPError
from app.mcp.config import mcp_config
from app.models.mcp_plugin import MCPPlugin from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger from app.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass
class SessionInfo:
"""会话信息"""
client: HTTPMCPClient
created_at: float
last_access: float
request_count: int = 0
error_count: int = 0
status: str = "active" # active, degraded, error
class MCPPluginRegistry: class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(多用户优化版)""" """MCP插件注册表 - 管理运行时插件实例(优化版)"""
def __init__(self, max_clients: int = 1000, client_ttl: int = 3600): def __init__(
self,
max_clients: Optional[int] = None,
client_ttl: Optional[int] = None
):
""" """
初始化注册表 初始化注册表
Args: Args:
max_clients: 最大缓存客户端数量 max_clients: 最大缓存客户端数量(默认使用配置)
client_ttl: 客户端过期时间(秒,默认1小时 client_ttl: 客户端过期时间(秒,默认使用配置)
""" """
# 存储格式: {plugin_id: (client, last_access_time)} # 存储格式: {plugin_id: SessionInfo}
self._clients: OrderedDict[str, Tuple[HTTPMCPClient, float]] = OrderedDict() self._sessions: Dict[str, SessionInfo] = {}
# 全局锁用于保护会话字典
self._sessions_lock = asyncio.Lock()
# 细粒度锁:每个用户一个锁 # 细粒度锁:每个用户一个锁
self._user_locks: Dict[str, asyncio.Lock] = {} self._user_locks: Dict[str, asyncio.Lock] = {}
self._locks_lock = asyncio.Lock() # 保护locks字典本身 self._locks_lock = asyncio.Lock() # 保护locks字典本身
# 配置参数 # 配置参数(使用配置常量)
self._max_clients = max_clients self._max_clients = max_clients or mcp_config.MAX_CLIENTS
self._client_ttl = client_ttl self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
# 共享HTTP客户端池(用于所有MCP HTTP请求)
self._shared_http_client = httpx.AsyncClient(
limits=httpx.Limits(
max_keepalive_connections=100,
max_connections=200,
keepalive_expiry=30.0
),
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=5.0),
headers={
"User-Agent": "MuMuAINovel-MCP-Client/1.0"
}
)
# 启动后台清理任务 # 启动后台清理任务
self._cleanup_task = None self._cleanup_task = None
self._start_cleanup_task() self._health_check_task = None
self._start_background_tasks()
def _start_cleanup_task(self): def _start_background_tasks(self):
"""启动后台清理任务""" """启动后台任务"""
if self._cleanup_task is None: if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop()) self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ MCP插件注册表后台清理任务已启动") logger.info("✅ MCP插件注册表后台清理任务已启动")
if self._health_check_task is None:
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("✅ MCP会话健康检查任务已启动")
async def _cleanup_loop(self): async def _cleanup_loop(self):
"""后台清理过期客户端""" """后台清理过期客户端"""
while True: while True:
try: try:
await asyncio.sleep(300) # 每5分钟清理一次 await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_clients() await self._cleanup_expired_sessions()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"清理任务异常: {e}") logger.error(f"清理任务异常: {e}")
async def _cleanup_expired_clients(self): async def _health_check_loop(self):
"""清理过期的客户端""" """后台健康检查"""
while True:
try:
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
await self._check_session_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"健康检查任务异常: {e}")
async def _cleanup_expired_sessions(self):
"""清理过期的会话"""
now = time.time() now = time.time()
expired_ids = [] expired_ids = []
# 收集过期的plugin_id async with self._sessions_lock:
for plugin_id, (client, last_access) in list(self._clients.items()): # 收集过期的plugin_id
if now - last_access > self._client_ttl: for plugin_id, session in list(self._sessions.items()):
expired_ids.append(plugin_id) if now - session.last_access > self._client_ttl:
expired_ids.append(plugin_id)
if expired_ids: if expired_ids:
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP客户端") logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
for plugin_id in expired_ids: for plugin_id in expired_ids:
# 提取user_id来获取对应的锁 # 提取user_id来获取对应的锁
user_id = plugin_id.split(':', 1)[0] user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id) user_lock = await self._get_user_lock(user_id)
async with user_lock: async with user_lock:
if plugin_id in self._clients: async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id) if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
async def _check_session_health(self):
"""增强的会话健康检查"""
async with self._sessions_lock:
for plugin_id, session in list(self._sessions.items()):
# 计算错误率
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
error_rate = session.error_count / session.request_count
# 动态调整状态(使用配置常量)
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
if session.status != "error":
session.status = "error"
logger.error(
f"❌ 会话 {plugin_id} 错误率过高 "
f"({error_rate:.1%}), 标记为error"
)
elif error_rate > mcp_config.ERROR_RATE_WARNING:
if session.status == "active":
session.status = "degraded"
logger.warning(
f"⚠️ 会话 {plugin_id} 健康状况下降 "
f"(错误率: {error_rate:.1%})"
)
elif session.status == "degraded":
# 错误率降低,恢复正常
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查长时间无活动的会话
idle_time = time.time() - session.last_access
if idle_time > mcp_config.IDLE_TIMEOUT_SECONDS:
logger.info(
f"💤 会话 {plugin_id} 空闲 {idle_time/60:.1f} 分钟,"
f"准备清理"
)
async def _get_user_lock(self, user_id: str) -> asyncio.Lock: async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
""" """
@@ -103,26 +163,33 @@ class MCPPluginRegistry:
self._user_locks[user_id] = asyncio.Lock() self._user_locks[user_id] = asyncio.Lock()
return self._user_locks[user_id] return self._user_locks[user_id]
def _touch_client(self, plugin_id: str): def _touch_session(self, plugin_id: str):
""" """
更新客户端的最后访问时间(LRU 更新会话的最后访问时间(需要在锁内调用
Args: Args:
plugin_id: 插件ID plugin_id: 插件ID
""" """
if plugin_id in self._clients: if plugin_id in self._sessions:
client, _ = self._clients[plugin_id] session = self._sessions[plugin_id]
self._clients[plugin_id] = (client, time.time()) session.last_access = time.time()
# 移到末尾(LRU session.request_count += 1
self._clients.move_to_end(plugin_id)
async def _evict_lru_client(self): async def _evict_lru_session(self):
"""驱逐最久未使用的客户端(当达到max_clients限制时)""" """驱逐最久未使用的会话(当达到max_clients限制时)"""
if len(self._clients) >= self._max_clients: if len(self._sessions) >= self._max_clients:
# 获取最旧的plugin_id # 找到最旧的会话
oldest_id = next(iter(self._clients)) oldest_id = None
logger.info(f"📤 达到最大客户端数量限制,驱逐: {oldest_id}") oldest_time = float('inf')
await self._unload_plugin_unsafe(oldest_id)
for plugin_id, session in self._sessions.items():
if session.last_access < oldest_time:
oldest_time = session.last_access
oldest_id = plugin_id
if oldest_id:
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def load_plugin(self, plugin: MCPPlugin) -> bool: async def load_plugin(self, plugin: MCPPlugin) -> bool:
""" """
@@ -141,11 +208,12 @@ class MCPPluginRegistry:
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}" plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
# 如果已加载,先卸载 # 如果已加载,先卸载
if plugin_id in self._clients: async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id) if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU客户端
await self._evict_lru_client() # 检查是否需要驱逐LRU会话
await self._evict_lru_session()
# 目前只支持HTTP类型 # 目前只支持HTTP类型
if plugin.plugin_type == "http": if plugin.plugin_type == "http":
@@ -153,18 +221,30 @@ class MCPPluginRegistry:
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}") logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
return False return False
# 使用共享HTTP连接池创建客户端 # 为每个插件创建独立的HTTP客户端
client = HTTPMCPClient( client = HTTPMCPClient(
url=plugin.server_url, url=plugin.server_url,
headers=plugin.headers or {}, headers=plugin.headers or {},
env=plugin.env or {}, env=plugin.env or {},
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0, timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
http_client=self._shared_http_client # 传入共享连接池
) )
# 存储客户端和当前时间戳 # 创建会话信息
self._clients[plugin_id] = (client, time.time()) now = time.time()
logger.info(f"✅ 加载MCP插件: {plugin_id}") session = SessionInfo(
client=client,
created_at=now,
last_access=now,
request_count=0,
error_count=0,
status="active"
)
# 存储会话
async with self._sessions_lock:
self._sessions[plugin_id] = session
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
return True return True
else: else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}") logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
@@ -186,18 +266,19 @@ class MCPPluginRegistry:
user_lock = await self._get_user_lock(user_id) user_lock = await self._get_user_lock(user_id)
async with user_lock: async with user_lock:
plugin_id = f"{user_id}:{plugin_name}" plugin_id = f"{user_id}:{plugin_name}"
await self._unload_plugin_unsafe(plugin_id) async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
async def _unload_plugin_unsafe(self, plugin_id: str): async def _unload_plugin_unsafe(self, plugin_id: str):
"""卸载插件(不加锁,内部使用)""" """卸载插件(不加锁,内部使用,需要在sessions_lock内调用"""
if plugin_id in self._clients: if plugin_id in self._sessions:
client, _ = self._clients[plugin_id] # 解包 (client, timestamp) session = self._sessions[plugin_id]
try: try:
await client.close() await session.client.close()
except Exception as e: except Exception as e:
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}") logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
del self._clients[plugin_id] del self._sessions[plugin_id]
logger.info(f"卸载MCP插件: {plugin_id}") logger.info(f"卸载MCP插件: {plugin_id}")
async def reload_plugin(self, plugin: MCPPlugin) -> bool: async def reload_plugin(self, plugin: MCPPlugin) -> bool:
@@ -215,7 +296,7 @@ class MCPPluginRegistry:
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]: def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
""" """
获取插件客户端(支持LRU访问时间更新) 获取插件客户端(线程安全,支持访问时间更新)
Args: Args:
user_id: 用户ID user_id: 用户ID
@@ -225,13 +306,68 @@ class MCPPluginRegistry:
客户端实例或None 客户端实例或None
""" """
plugin_id = f"{user_id}:{plugin_name}" plugin_id = f"{user_id}:{plugin_name}"
entry = self._clients.get(plugin_id)
if entry: session = self._sessions.get(plugin_id)
# 更新访问时间(LRU if session:
self._touch_client(plugin_id) # 检查会话状态
return entry[0] # 返回客户端对象 if session.status == "error":
logger.warning(
f"⚠️ 会话 {plugin_id} 处于错误状态,"
f"建议调用者重新加载插件"
)
# 不返回错误状态的客户端
return None
# ✅ 使用锁保护状态更新,避免并发问题
# 注意:这里使用原子操作更新简单字段,不需要异步锁
session.last_access = time.time()
session.request_count += 1
return session.client
return None return None
async def get_or_reconnect_client(
self,
user_id: str,
plugin_name: str,
plugin: MCPPlugin
) -> HTTPMCPClient:
"""
获取或重连客户端(自动处理错误状态)
Args:
user_id: 用户ID
plugin_name: 插件名称
plugin: 插件配置对象
Returns:
客户端实例
Raises:
ValueError: 插件加载失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取用户锁
user_lock = await self._get_user_lock(user_id)
async with user_lock:
session = self._sessions.get(plugin_id)
# 检查会话健康状态
if session and session.status == "error":
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
session = None
# 如果没有会话,加载插件
if not session:
success = await self.load_plugin(plugin)
if not success:
raise ValueError(f"插件加载失败: {plugin_name}")
session = self._sessions[plugin_id]
return session.client
async def call_tool( async def call_tool(
self, self,
user_id: str, user_id: str,
@@ -240,7 +376,7 @@ class MCPPluginRegistry:
arguments: Dict[str, Any] arguments: Dict[str, Any]
) -> Any: ) -> Any:
""" """
调用插件工具 调用插件工具(带错误计数和状态管理)
Args: Args:
user_id: 用户ID user_id: 用户ID
@@ -255,18 +391,39 @@ class MCPPluginRegistry:
ValueError: 插件不存在或未启用 ValueError: 插件不存在或未启用
MCPError: 工具调用失败 MCPError: 工具调用失败
""" """
client = self.get_client(user_id, plugin_name) plugin_id = f"{user_id}:{plugin_name}"
if not client: # 获取会话
session = self._sessions.get(plugin_id)
if not session:
raise ValueError(f"插件未加载: {plugin_name}") raise ValueError(f"插件未加载: {plugin_name}")
try: try:
result = await client.call_tool(tool_name, arguments) result = await session.client.call_tool(tool_name, arguments)
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}") logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
# logger.info(f"✅ 工具返回内容: {result}")
# 调用成功,重置状态(如果之前是degraded)
if session.status == "degraded":
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
return result return result
except Exception as e: except Exception as e:
logger.error(f"❌ 工具调用失败: {plugin_name}.{tool_name}, 错误: {e}") # 增加错误计数
session.error_count += 1
# 根据错误率更新状态
if session.request_count > 0:
error_rate = session.error_count / session.request_count
if error_rate > 0.5:
session.status = "error"
elif error_rate > 0.3:
session.status = "degraded"
logger.error(
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
)
raise raise
async def get_plugin_tools( async def get_plugin_tools(
@@ -320,7 +477,7 @@ class MCPPluginRegistry:
async def cleanup_all(self): async def cleanup_all(self):
"""清理所有插件和资源""" """清理所有插件和资源"""
# 停止后台清理任务 # 停止后台任务
if self._cleanup_task: if self._cleanup_task:
self._cleanup_task.cancel() self._cleanup_task.cancel()
try: try:
@@ -328,19 +485,18 @@ class MCPPluginRegistry:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# 清理所有客户端 if self._health_check_task:
plugin_ids = list(self._clients.keys()) self._health_check_task.cancel()
for plugin_id in plugin_ids: try:
user_id = plugin_id.split(':', 1)[0] await self._health_check_task
user_lock = await self._get_user_lock(user_id) except asyncio.CancelledError:
async with user_lock: pass
await self._unload_plugin_unsafe(plugin_id)
# 关闭共享HTTP客户端 # 清理所有会话
try: async with self._sessions_lock:
await self._shared_http_client.aclose() plugin_ids = list(self._sessions.keys())
except Exception as e: for plugin_id in plugin_ids:
logger.error(f"关闭共享HTTP客户端失败: {e}") await self._unload_plugin_unsafe(plugin_id)
logger.info("✅ 已清理所有MCP插件和资源") logger.info("✅ 已清理所有MCP插件和资源")
+320
View File
@@ -0,0 +1,320 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
import time
import json
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp.registry import mcp_registry
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.logger import get_logger
from app.user_manager import User
logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(分离的测试逻辑)"""
async def test_plugin_connection(
self,
plugin: MCPPlugin,
user_id: str
) -> MCPTestResult:
"""
简单连接测试
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
测试结果
"""
start_time = time.time()
try:
# 确保插件已加载
if not mcp_registry.get_client(user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
if test_result["success"]:
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功",
response_time_ms=response_time,
tools_count=test_result.get("tools_count", 0),
suggestions=[
f"响应时间: {response_time}ms",
f"可用工具数: {test_result.get('tools_count', 0)}"
]
)
else:
return MCPTestResult(**test_result)
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=response_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
async def test_plugin_with_ai(
self,
plugin: MCPPlugin,
user: User,
db_session: AsyncSession
) -> MCPTestResult:
"""
使用AI进行智能工具调用测试
Args:
plugin: 插件配置
user: 用户对象
db_session: 数据库会话
Returns:
测试结果
"""
start_time = time.time()
try:
# 1. 先进行连接测试
connection_result = await self.test_plugin_connection(plugin, user.user_id)
if not connection_result.success:
return connection_result
# 2. 获取工具列表
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=connection_result.response_time_ms,
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 获取用户的AI设置
settings_result = await db_session.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 没有AI配置,返回简单测试结果
logger.warning("用户未配置AI服务,跳过智能测试")
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=connection_result.response_time_ms,
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 4. 使用AI选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 转换为OpenAI Function Calling格式
openai_tools = self._convert_tools_to_openai_format(tools)
# 调用AI选择工具
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
⚠️ 重要规则:生成参数时,必须严格使用工具 schema 中定义的原始参数名称,不要转换为 snake_case 或其他格式。
例如:如果 schema 中是 'nextThoughtNeeded',就必须使用 'nextThoughtNeeded',不能改成 'next_thought_needed'
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
现在开始测试这个插件。"""
system_prompt = """你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。
⚠️ 关键规则:调用工具时,必须严格使用 schema 中定义的原始参数名,不要自行转换命名风格。
- 如果参数名是 camelCase(如 nextThoughtNeeded),就使用 camelCase
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
ai_response = await ai_service.generate_text(
prompt=prompt,
system_prompt=system_prompt,
tools=openai_tools,
tool_choice="required"
)
# 5. 检查AI是否返回工具调用
if not ai_response.get("tool_calls"):
logger.error(f"❌ AI未返回工具调用")
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}"
]
)
# 6. 解析工具调用
tool_call = ai_response["tool_calls"][0]
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
try:
test_arguments = json.loads(test_arguments)
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools)
)
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 格式化结果
result_str = str(tool_result)
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断)"
else:
result_preview = result_str
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI选择: {tool_name}",
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 耗时: {call_time}ms",
f"📊 结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI选择: {tool_name}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
def _convert_tools_to_openai_format(self, tools: list) -> list:
"""将MCP工具格式转换为OpenAI Function Calling格式"""
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
return openai_tools
# 全局单例
mcp_test_service = MCPTestService()
+330 -22
View File
@@ -5,26 +5,100 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
import asyncio import asyncio
import json import json
from datetime import datetime import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict
from app.models.mcp_plugin import MCPPlugin from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry from app.mcp.registry import mcp_registry
from app.mcp.config import mcp_config
from app.logger import get_logger from app.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
@dataclass
class ToolMetrics:
"""工具调用指标"""
total_calls: int = 0
success_calls: int = 0
failed_calls: int = 0
total_duration_ms: float = 0.0
avg_duration_ms: float = 0.0
last_call_time: Optional[datetime] = None
def update_success(self, duration_ms: float):
"""更新成功调用指标"""
self.total_calls += 1
self.success_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
def update_failure(self, duration_ms: float):
"""更新失败调用指标"""
self.total_calls += 1
self.failed_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
@property
def success_rate(self) -> float:
"""成功率"""
if self.total_calls == 0:
return 0.0
return self.success_calls / self.total_calls
@dataclass
class ToolCacheEntry:
"""工具缓存条目"""
tools: List[Dict[str, Any]]
expire_time: datetime
hit_count: int = 0
class MCPToolServiceError(Exception): class MCPToolServiceError(Exception):
"""MCP工具服务异常""" """MCP工具服务异常"""
pass pass
class MCPToolService: class MCPToolService:
"""MCP工具服务 - 统一管理MCP工具的注入和执行""" """MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
def __init__(self): def __init__(
self._tool_cache = {} # 工具定义缓存 self,
self._result_cache = {} # 工具结果缓存(可选) cache_ttl_minutes: Optional[int] = None,
max_retries: Optional[int] = None
):
"""
初始化MCP工具服务
Args:
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
max_retries: 最大重试次数(默认使用配置)
"""
# 工具定义缓存: {cache_key: ToolCacheEntry}
self._tool_cache: Dict[str, ToolCacheEntry] = {}
self._cache_ttl = timedelta(
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
)
# 调用指标: {tool_key: ToolMetrics}
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
# 重试配置(使用配置常量)
self._max_retries = max_retries or mcp_config.MAX_RETRIES
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
logger.info(
f"✅ MCPToolService初始化完成 "
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
f"最大重试={self._max_retries}次)"
)
async def get_user_enabled_tools( async def get_user_enabled_tools(
self, self,
@@ -61,7 +135,7 @@ class MCPToolService:
logger.info(f"用户 {user_id} 没有启用的MCP插件") logger.info(f"用户 {user_id} 没有启用的MCP插件")
return [] return []
# 2. 获取所有工具定义 # 2. 获取所有工具定义(使用缓存)
all_tools = [] all_tools = []
for plugin in plugins: for plugin in plugins:
try: try:
@@ -73,8 +147,8 @@ class MCPToolService:
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过") logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
continue continue
# 从registry获取该插件的工具列表 # ✅ 使用缓存获取工具列表
plugin_tools = await mcp_registry.get_plugin_tools( plugin_tools = await self._get_plugin_tools_cached(
user_id=user_id, user_id=user_id,
plugin_name=plugin.plugin_name plugin_name=plugin.plugin_name
) )
@@ -82,7 +156,7 @@ class MCPToolService:
# 格式化为Function Calling格式 # 格式化为Function Calling格式
formatted_tools = self._format_tools_for_ai( formatted_tools = self._format_tools_for_ai(
plugin_tools, plugin_tools,
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name plugin.plugin_name
) )
all_tools.extend(formatted_tools) all_tools.extend(formatted_tools)
@@ -139,12 +213,85 @@ class MCPToolService:
return formatted_tools return formatted_tools
async def _get_plugin_tools_cached(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
带缓存的工具列表获取
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
cache_key = f"{user_id}:{plugin_name}"
now = datetime.now()
# 检查缓存
if cache_key in self._tool_cache:
entry = self._tool_cache[cache_key]
if now < entry.expire_time:
entry.hit_count += 1
logger.debug(
f"🎯 工具缓存命中: {cache_key} "
f"(命中次数: {entry.hit_count})"
)
return entry.tools
else:
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
del self._tool_cache[cache_key]
# 缓存未命中,从MCP获取
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
# 更新缓存
self._tool_cache[cache_key] = ToolCacheEntry(
tools=tools,
expire_time=now + self._cache_ttl,
hit_count=0
)
return tools
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
"""
清理缓存
Args:
user_id: 用户ID(可选,清理特定用户的缓存)
plugin_name: 插件名称(可选,清理特定插件的缓存)
"""
if user_id is None and plugin_name is None:
# 清理所有缓存
self._tool_cache.clear()
logger.info("🧹 已清理所有工具缓存")
elif user_id and plugin_name:
# 清理特定插件缓存
cache_key = f"{user_id}:{plugin_name}"
if cache_key in self._tool_cache:
del self._tool_cache[cache_key]
logger.info(f"🧹 已清理缓存: {cache_key}")
elif user_id:
# 清理用户所有缓存
keys_to_delete = [
key for key in self._tool_cache.keys()
if key.startswith(f"{user_id}:")
]
for key in keys_to_delete:
del self._tool_cache[key]
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
async def execute_tool_calls( async def execute_tool_calls(
self, self,
user_id: str, user_id: str,
tool_calls: List[Dict[str, Any]], tool_calls: List[Dict[str, Any]],
db_session: AsyncSession, db_session: AsyncSession,
timeout: float = 60.0 timeout: Optional[float] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
批量执行AI请求的工具调用(并行执行) 批量执行AI请求的工具调用(并行执行)
@@ -153,7 +300,7 @@ class MCPToolService:
user_id: 用户ID user_id: 用户ID
tool_calls: AI返回的工具调用列表 tool_calls: AI返回的工具调用列表
db_session: 数据库会话 db_session: 数据库会话
timeout: 单个工具调用的超时时间(秒,默认30秒 timeout: 单个工具调用的超时时间(秒,默认使用配置
Returns: Returns:
工具调用结果列表 工具调用结果列表
@@ -161,7 +308,10 @@ class MCPToolService:
if not tool_calls: if not tool_calls:
return [] return []
logger.info(f"开始执行 {len(tool_calls)} 个工具调用") # 使用配置的默认超时
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s)")
# 创建异步任务列表 # 创建异步任务列表
tasks = [ tasks = [
@@ -169,7 +319,7 @@ class MCPToolService:
user_id=user_id, user_id=user_id,
tool_call=tool_call, tool_call=tool_call,
db_session=db_session, db_session=db_session,
timeout=timeout timeout=actual_timeout
) )
for tool_call in tool_calls for tool_call in tool_calls
] ]
@@ -238,18 +388,28 @@ class MCPToolService:
f"参数: {arguments}" f"参数: {arguments}"
) )
# 设置超时 # ✅ 使用带重试的调用
tool_key = f"{plugin_name}.{tool_name}"
start_time = time.time()
try: try:
result = await asyncio.wait_for( result = await self._call_tool_with_retry(
mcp_registry.call_tool( user_id=user_id,
user_id=user_id, plugin_name=plugin_name,
plugin_name=plugin_name, tool_name=tool_name,
tool_name=tool_name, arguments=arguments,
arguments=arguments
),
timeout=timeout timeout=timeout
) )
# 记录成功指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_success(duration_ms)
logger.info(
f"✅ 工具调用成功: {tool_key} "
f"(耗时: {duration_ms:.2f}ms)"
)
# 成功返回 # 成功返回
return { return {
"tool_call_id": tool_call_id, "tool_call_id": tool_call_id,
@@ -261,13 +421,21 @@ class MCPToolService:
} }
except asyncio.TimeoutError: except asyncio.TimeoutError:
# 记录失败指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
raise MCPToolServiceError( raise MCPToolServiceError(
f"工具调用超时(>{timeout}秒)" f"工具调用超时(>{timeout}秒)"
) )
except Exception as e: except Exception as e:
# 记录失败指标
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
logger.error( logger.error(
f"工具 {function_name} 调用失败: {e}", f"工具 {function_name} 调用失败: {e}",
exc_info=True exc_info=True
) )
return { return {
@@ -279,6 +447,146 @@ class MCPToolService:
"error": str(e) "error": str(e)
} }
async def _call_tool_with_retry(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: float
) -> Any:
"""
带指数退避重试的工具调用
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
工具执行结果
Raises:
MCPToolServiceError: 工具调用失败
asyncio.TimeoutError: 调用超时
"""
last_exception = None
for attempt in range(self._max_retries):
try:
# 尝试调用工具
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
timeout=timeout
)
# 成功则返回
if attempt > 0:
logger.info(
f"✅ 重试成功: {plugin_name}.{tool_name} "
f"(第{attempt + 1}次尝试)"
)
return result
except asyncio.TimeoutError:
# 超时不重试,直接抛出
raise
except Exception as e:
last_exception = e
# 最后一次尝试失败
if attempt == self._max_retries - 1:
logger.error(
f"❌ 重试失败: {plugin_name}.{tool_name} "
f"(已尝试{self._max_retries}次): {e}"
)
raise MCPToolServiceError(
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
)
# 计算指数退避延迟
delay = min(
self._base_retry_delay * (2 ** attempt),
self._max_retry_delay
)
logger.warning(
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
f"(第{attempt + 1}/{self._max_retries}次): "
f"{plugin_name}.{tool_name} - {e}"
)
await asyncio.sleep(delay)
# 理论上不会到这里,但为了类型安全
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
获取工具调用指标
Args:
tool_name: 工具名称(可选,获取特定工具的指标)
Returns:
指标字典
"""
if tool_name:
if tool_name in self._metrics:
metric = self._metrics[tool_name]
return {
tool_name: {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": metric.success_rate,
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
}
return {}
# 返回所有工具的指标
result = {}
for tool_key, metric in self._metrics.items():
result[tool_key] = {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": round(metric.success_rate, 3),
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
return result
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
total_entries = len(self._tool_cache)
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
return {
"total_entries": total_entries,
"total_hits": total_hits,
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"key": key,
"tools_count": len(entry.tools),
"hit_count": entry.hit_count,
"expire_time": entry.expire_time.isoformat()
}
for key, entry in self._tool_cache.items()
]
}
async def build_tool_context( async def build_tool_context(
self, self,
tool_results: List[Dict[str, Any]], tool_results: List[Dict[str, Any]],
+2 -2
View File
@@ -226,8 +226,8 @@ class PlotAnalyzer:
) )
# 🔍 添加调试日志:查看AI返回的原始内容 # 🔍 添加调试日志:查看AI返回的原始内容
logger.info(f"🔍 AI返回类型: {type(response)}") # logger.info(f"🔍 AI返回类型: {type(response)}")
logger.info(f"🔍 AI返回内容(前500字符): {str(response)}") # logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# 从返回的字典中提取content字段 # 从返回的字典中提取content字段
if isinstance(response, dict): if isinstance(response, dict):
+9 -6
View File
@@ -1,15 +1,15 @@
# Web框架 # Web框架
fastapi==0.109.0 fastapi==0.121.0
uvicorn[standard]==0.27.0 uvicorn[standard]==0.38.0
python-multipart==0.0.6 python-multipart==0.0.20
# 数据库 # 数据库
sqlalchemy==2.0.25 sqlalchemy==2.0.25
aiosqlite==0.19.0 aiosqlite==0.19.0
# 数据验证 # 数据验证
pydantic==2.5.3 pydantic==2.12.4
pydantic-settings==2.1.0 pydantic-settings==2.11.0
# AI服务 # AI服务
openai==2.7.0 openai==2.7.0
@@ -18,6 +18,9 @@ anthropic==0.72.0
# 工具库 # 工具库
httpx==0.28.1 httpx==0.28.1
python-dotenv==1.0.0 python-dotenv==1.0.0
# MCP官方库(Model Context Protocol Python SDK
mcp==1.21.0
# NumPy版本锁定(兼容性要求) # NumPy版本锁定(兼容性要求)
numpy==1.26.4 numpy==1.26.4
@@ -30,4 +33,4 @@ chromadb==1.3.2
transformers==4.35.2 transformers==4.35.2
# Sentence Transformers(基于PyTorch的文本embedding库) # Sentence Transformers(基于PyTorch的文本embedding库)
sentence-transformers==2.3.1 sentence-transformers==2.3.1