fix:1.优化mcp插件功能,改用mcp sdk库
This commit is contained in:
+137
-335
@@ -18,9 +18,9 @@ from app.schemas.mcp_plugin import (
|
||||
import json
|
||||
from app.user_manager import User
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.services.mcp_tool_service import mcp_tool_service
|
||||
from app.logger import get_logger
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.models.settings import Settings as UserSettings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -399,13 +399,8 @@ async def test_plugin(
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
测试流程:
|
||||
1. 测试MCP服务器连接
|
||||
2. 获取工具列表
|
||||
3. 自动选择一个工具进行实际调用测试
|
||||
4. 返回完整测试结果
|
||||
使用新的MCPTestService进行测试
|
||||
"""
|
||||
import time
|
||||
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
@@ -426,356 +421,153 @@ async def test_plugin(
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 使用新的测试服务
|
||||
try:
|
||||
# 1. 确保插件已加载
|
||||
if not mcp_registry.get_client(user.user_id, plugin.plugin_name):
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件加载失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
|
||||
|
||||
# 2. 测试连接并获取工具列表
|
||||
test_result = await mcp_registry.test_plugin(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not test_result["success"]:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = test_result.get("error", "连接测试失败")
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
return MCPTestResult(**test_result)
|
||||
|
||||
tools = test_result.get("tools", [])
|
||||
|
||||
if not tools:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "插件没有提供任何工具"
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件没有提供任何工具",
|
||||
error="工具列表为空",
|
||||
response_time_ms=test_result.get("response_time_ms"),
|
||||
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
|
||||
)
|
||||
|
||||
# 3. 使用AI智能选择工具并生成测试参数
|
||||
logger.info(f"使用AI分析工具并生成测试计划...")
|
||||
|
||||
# 获取用户的AI设置
|
||||
settings_result = await db.execute(
|
||||
select(UserSettings).where(UserSettings.user_id == user.user_id)
|
||||
)
|
||||
user_settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not user_settings or not user_settings.api_key:
|
||||
# 如果没有AI配置,回退到简单测试
|
||||
logger.warning("用户未配置AI服务,使用简单连接测试")
|
||||
# 更新插件状态
|
||||
if test_result.success:
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
|
||||
response_time_ms=test_result.get("response_time_ms"),
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"连接测试: 成功",
|
||||
f"可用工具数: {len(tools)}",
|
||||
"提示: 配置AI服务后可进行智能工具调用测试"
|
||||
]
|
||||
)
|
||||
|
||||
# 使用AI的标准Function Calling机制选择工具
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=user_settings.api_provider,
|
||||
api_key=user_settings.api_key,
|
||||
api_base_url=user_settings.api_base_url,
|
||||
model_name=user_settings.llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 将MCP工具格式转换为OpenAI Function Calling格式
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
}
|
||||
# 将 inputSchema 转换为 parameters
|
||||
if "inputSchema" in tool:
|
||||
openai_tool["function"]["parameters"] = tool["inputSchema"]
|
||||
|
||||
openai_tools.append(openai_tool)
|
||||
|
||||
logger.info(f"转换了 {len(openai_tools)} 个MCP工具为OpenAI格式")
|
||||
logger.info(f"工具列表: {[t['function']['name'] for t in openai_tools]}")
|
||||
|
||||
# 使用标准的Function Calling,将转换后的工具传递给AI
|
||||
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
|
||||
|
||||
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
|
||||
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
|
||||
|
||||
现在开始测试这个插件。"""
|
||||
|
||||
system_prompt = "你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。"
|
||||
|
||||
# 调用AI的Function Calling
|
||||
logger.info(f"📞 准备调用AI Function Calling")
|
||||
logger.info(f" - Provider: {user_settings.api_provider}")
|
||||
logger.info(f" - Model: {user_settings.llm_model}")
|
||||
logger.info(f" - Tools count: {len(openai_tools)}")
|
||||
logger.debug(f" - Tools: {json.dumps(openai_tools, ensure_ascii=False, indent=2)}")
|
||||
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
tools=openai_tools, # 传递转换后的OpenAI格式工具
|
||||
tool_choice="required" # 要求AI必须选择一个工具
|
||||
)
|
||||
|
||||
logger.info(f"📥 收到AI响应")
|
||||
logger.info(f" - Response keys: {list(ai_response.keys())}")
|
||||
logger.debug(f" - Full response: {json.dumps(ai_response, ensure_ascii=False, indent=2)}")
|
||||
|
||||
# 检查AI是否请求调用工具
|
||||
if not ai_response.get("tool_calls"):
|
||||
# AI未调用工具,记录详细信息
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
logger.error(f" - Response: {ai_response}")
|
||||
logger.error(f" - Content: {ai_response.get('content', 'N/A')}")
|
||||
logger.error(f" - Finish reason: {ai_response.get('finish_reason', 'N/A')}")
|
||||
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = "AI未返回工具调用请求"
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
"OpenAI: 需要gpt-4, gpt-3.5-turbo等模型",
|
||||
"Anthropic: 需要claude-3系列模型",
|
||||
f"当前Provider: {user_settings.api_provider}",
|
||||
f"当前模型: {user_settings.llm_model}",
|
||||
f"AI返回内容: {ai_response.get('content', 'N/A')[:100]}"
|
||||
]
|
||||
)
|
||||
plugin.last_error = test_result.error
|
||||
|
||||
# 获取第一个工具调用
|
||||
tool_call = ai_response["tool_calls"][0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
# AI返回的arguments可能是JSON字符串,需要解析
|
||||
if isinstance(test_arguments, str):
|
||||
try:
|
||||
test_arguments = json.loads(test_arguments)
|
||||
logger.info(f"✅ 解析AI返回的JSON字符串参数")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析AI参数失败: {e}")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI返回的参数格式错误",
|
||||
error=f"无法解析参数JSON: {str(e)}",
|
||||
tools_count=len(tools),
|
||||
suggestions=["AI返回的参数不是有效的JSON格式"]
|
||||
)
|
||||
|
||||
logger.info(f"🤖 AI通过Function Calling选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
logger.info(f"📝 参数类型: {type(test_arguments).__name__}")
|
||||
|
||||
# 4. 使用AI选择的工具和参数调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
tool_name,
|
||||
test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
call_time = round((call_end - call_start) * 1000, 2)
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
# 6. 测试成功,更新插件状态
|
||||
plugin.status = "active"
|
||||
plugin.last_error = None
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools # 缓存工具列表
|
||||
await db.commit()
|
||||
|
||||
# 格式化工具结果用于显示
|
||||
result_str = str(tool_result)
|
||||
|
||||
# 如果结果太长,截取前800字符
|
||||
if len(result_str) > 800:
|
||||
result_preview = result_str[:800] + "\n...(结果已截断,完整结果请查看日志)"
|
||||
else:
|
||||
result_preview = result_str
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"🤖 AI (Function Calling) 选择: {tool_name}",
|
||||
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"⏱️ 调用耗时: {call_time}ms",
|
||||
f"📊 返回结果:\n{result_preview}"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as call_error:
|
||||
call_end = time.time()
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
|
||||
|
||||
# 工具调用失败,但连接成功
|
||||
plugin.status = "active" # 仍标记为active,因为连接是成功的
|
||||
plugin.last_error = f"工具调用测试失败: {str(call_error)}"
|
||||
plugin.last_test_at = datetime.now()
|
||||
plugin.tools = tools
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=True, # 连接成功就算测试通过
|
||||
message=f"⚠️ 连接成功,但工具调用失败",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
|
||||
suggestions=[
|
||||
f"✅ 连接测试: 成功",
|
||||
f"❌ 工具调用测试: 失败",
|
||||
f"🤖 AI (Function Calling) 选择: {tool_name}",
|
||||
f"📝 AI生成的参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"❌ 错误: {str(call_error)}",
|
||||
"💡 可能原因: API Key无效、参数错误或服务限制"
|
||||
]
|
||||
)
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
total_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=total_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=["请检查服务器是否在线", "请确认配置正确", "请检查API Key是否有效"]
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"测试失败: {str(e)}")
|
||||
|
||||
|
||||
def _build_test_arguments(tool_name: str, input_schema: dict, plugin_name: str) -> dict:
|
||||
async def _ensure_plugin_loaded(
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
根据工具schema智能构造测试参数
|
||||
确保插件已加载(共享逻辑)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
input_schema: 输入schema
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件对象
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
测试参数字典
|
||||
"""
|
||||
# 针对常见MCP工具的默认测试参数
|
||||
test_cases = {
|
||||
# Exa搜索工具
|
||||
"search": {
|
||||
"query": "AI technology",
|
||||
"num_results": 3
|
||||
},
|
||||
"search_and_contents": {
|
||||
"query": "artificial intelligence",
|
||||
"num_results": 2
|
||||
},
|
||||
# Brave搜索
|
||||
"brave_web_search": {
|
||||
"query": "AI news",
|
||||
"count": 3
|
||||
},
|
||||
# Filesystem工具
|
||||
"read_file": {
|
||||
"path": "README.md"
|
||||
},
|
||||
"list_directory": {
|
||||
"path": "."
|
||||
},
|
||||
}
|
||||
|
||||
# 如果有针对特定工具的测试用例,使用它
|
||||
if tool_name in test_cases:
|
||||
logger.info(f"使用预定义测试参数: {test_cases[tool_name]}")
|
||||
return test_cases[tool_name]
|
||||
|
||||
# 否则根据schema自动构造
|
||||
properties = input_schema.get("properties", {})
|
||||
required = input_schema.get("required", [])
|
||||
|
||||
test_args = {}
|
||||
|
||||
for prop_name, prop_schema in properties.items():
|
||||
# 只填充必需的参数
|
||||
if prop_name not in required:
|
||||
continue
|
||||
|
||||
prop_type = prop_schema.get("type", "string")
|
||||
是否加载成功
|
||||
|
||||
# 根据参数名称和类型猜测合适的测试值
|
||||
if "query" in prop_name.lower() or "search" in prop_name.lower():
|
||||
test_args[prop_name] = "test query"
|
||||
elif "url" in prop_name.lower():
|
||||
test_args[prop_name] = "https://example.com"
|
||||
elif "path" in prop_name.lower():
|
||||
test_args[prop_name] = "."
|
||||
elif "count" in prop_name.lower() or "limit" in prop_name.lower() or "num" in prop_name.lower():
|
||||
test_args[prop_name] = 3
|
||||
elif prop_type == "string":
|
||||
test_args[prop_name] = "test"
|
||||
elif prop_type == "number" or prop_type == "integer":
|
||||
test_args[prop_name] = 1
|
||||
elif prop_type == "boolean":
|
||||
test_args[prop_name] = True
|
||||
elif prop_type == "array":
|
||||
test_args[prop_name] = []
|
||||
elif prop_type == "object":
|
||||
test_args[prop_name] = {}
|
||||
Raises:
|
||||
HTTPException: 加载失败
|
||||
"""
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
logger.info(f"插件 {plugin.plugin_name} 未加载,自动加载中...")
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"插件加载失败: {plugin.plugin_name}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(
|
||||
tool_name: Optional[str] = Query(None, description="工具名称(可选,获取特定工具的指标)"),
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取MCP工具调用指标
|
||||
|
||||
logger.info(f"自动构造测试参数: {test_args}")
|
||||
return test_args
|
||||
Query参数:
|
||||
- tool_name: 可选,指定工具名称获取特定工具的指标
|
||||
|
||||
Returns:
|
||||
工具调用指标字典,包含:
|
||||
- total_calls: 总调用次数
|
||||
- success_calls: 成功调用次数
|
||||
- failed_calls: 失败调用次数
|
||||
- success_rate: 成功率
|
||||
- avg_duration_ms: 平均耗时(毫秒)
|
||||
- last_call_time: 最后调用时间
|
||||
"""
|
||||
metrics = mcp_tool_service.get_metrics(tool_name)
|
||||
|
||||
return {
|
||||
"metrics": metrics,
|
||||
"tool_name": tool_name,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/cache/stats")
|
||||
async def get_cache_stats(
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
获取工具缓存统计信息
|
||||
|
||||
Returns:
|
||||
缓存统计信息,包含:
|
||||
- total_entries: 缓存条目总数
|
||||
- total_hits: 缓存总命中次数
|
||||
- cache_ttl_minutes: 缓存TTL(分钟)
|
||||
- entries: 各缓存条目详情
|
||||
"""
|
||||
stats = mcp_tool_service.get_cache_stats()
|
||||
|
||||
return {
|
||||
"cache_stats": stats,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
plugin_name: Optional[str] = Query(None, description="插件名称(可选)"),
|
||||
user: User = Depends(require_login)
|
||||
):
|
||||
"""
|
||||
清理工具缓存
|
||||
|
||||
Query参数:
|
||||
- user_id: 可选,清理特定用户的缓存
|
||||
- plugin_name: 可选,清理特定插件的缓存
|
||||
|
||||
说明:
|
||||
- 不提供任何参数:清理所有缓存
|
||||
- 只提供user_id:清理该用户的所有缓存
|
||||
- 提供user_id和plugin_name:清理特定插件的缓存
|
||||
"""
|
||||
# 非管理员只能清理自己的缓存
|
||||
if user_id and user_id != user.user_id:
|
||||
raise HTTPException(status_code=403, detail="无权清理其他用户的缓存")
|
||||
|
||||
# 如果没有指定user_id,使用当前用户
|
||||
target_user_id = user_id or user.user_id
|
||||
|
||||
mcp_tool_service.clear_cache(target_user_id, plugin_name)
|
||||
|
||||
message = "已清理"
|
||||
if plugin_name:
|
||||
message += f"插件 {plugin_name} 的缓存"
|
||||
elif target_user_id:
|
||||
message += f"用户 {target_user_id} 的所有缓存"
|
||||
else:
|
||||
message += "所有缓存"
|
||||
|
||||
logger.info(f"用户 {user.user_id} {message}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": message,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{plugin_id}/tools")
|
||||
@@ -802,6 +594,9 @@ async def get_plugin_tools(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
@@ -813,6 +608,8 @@ async def get_plugin_tools(
|
||||
"tools": tools,
|
||||
"count": len(tools)
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {plugin.plugin_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取工具列表失败: {str(e)}")
|
||||
@@ -843,6 +640,9 @@ async def call_mcp_tool(
|
||||
raise HTTPException(status_code=400, detail="插件未启用")
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
await _ensure_plugin_loaded(plugin, user.user_id)
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
@@ -857,6 +657,8 @@ async def call_mcp_tool(
|
||||
"tool_name": data.tool_name,
|
||||
"result": result
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
+2
-3
@@ -12,6 +12,7 @@ from app.database import close_db, _session_stats
|
||||
from app.logger import setup_logging, get_logger
|
||||
from app.middleware import RequestIDMiddleware
|
||||
from app.middleware.auth_middleware import AuthMiddleware
|
||||
from app.mcp.registry import mcp_registry
|
||||
|
||||
setup_logging(
|
||||
level=config_settings.log_level,
|
||||
@@ -27,9 +28,7 @@ logger = get_logger(__name__)
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
logger.info("应用启动,等待用户登录...")
|
||||
|
||||
# 导入MCP注册表
|
||||
from app.mcp.registry import mcp_registry
|
||||
logger.info("💡 MCP插件采用延迟加载策略,将在用户首次使用时自动加载")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""MCP模块配置常量"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MCPConfig:
|
||||
"""MCP模块配置常量(不可变)"""
|
||||
|
||||
# 连接池配置
|
||||
MAX_CLIENTS: int = 1000 # 最大客户端数量
|
||||
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
|
||||
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
|
||||
|
||||
# 健康检查配置
|
||||
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
|
||||
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
|
||||
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
|
||||
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
|
||||
|
||||
# 清理任务配置
|
||||
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
|
||||
|
||||
# 缓存配置
|
||||
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
|
||||
|
||||
# 重试配置
|
||||
MAX_RETRIES: int = 3 # 最大重试次数
|
||||
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
|
||||
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
|
||||
|
||||
# 超时配置
|
||||
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
|
||||
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
|
||||
|
||||
# 日志配置
|
||||
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
|
||||
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
mcp_config = MCPConfig()
|
||||
+185
-197
@@ -1,8 +1,13 @@
|
||||
"""HTTP MCP客户端 - 实现JSON-RPC 2.0协议"""
|
||||
import httpx
|
||||
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from app.logger import get_logger
|
||||
import time
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -13,15 +18,14 @@ class MCPError(Exception):
|
||||
|
||||
|
||||
class HTTPMCPClient:
|
||||
"""HTTP模式MCP客户端(类似Cursor/Claude Code实现)"""
|
||||
"""HTTP模式MCP客户端(基于官方 MCP Python SDK)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0,
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
初始化HTTP MCP客户端
|
||||
@@ -31,162 +35,79 @@ class HTTPMCPClient:
|
||||
headers: HTTP请求头
|
||||
env: 环境变量(用于API Key等)
|
||||
timeout: 超时时间(秒)
|
||||
http_client: 可选的共享HTTP客户端(用于连接池复用)
|
||||
"""
|
||||
self.url = url.rstrip('/')
|
||||
self.headers = headers or {}
|
||||
self.env = env or {}
|
||||
self.timeout = timeout
|
||||
|
||||
# 设置MCP必需的Accept头
|
||||
# MCP服务器要求客户端必须接受 application/json 和 text/event-stream
|
||||
if 'Accept' not in self.headers:
|
||||
self.headers['Accept'] = 'application/json, text/event-stream'
|
||||
|
||||
# 设置Content-Type
|
||||
if 'Content-Type' not in self.headers:
|
||||
self.headers['Content-Type'] = 'application/json'
|
||||
|
||||
# 如果env中有API Key,添加到headers
|
||||
if 'API_KEY' in self.env:
|
||||
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
|
||||
|
||||
# 使用共享客户端或创建新客户端
|
||||
self._owns_client = http_client is None
|
||||
if http_client:
|
||||
self.client = http_client
|
||||
else:
|
||||
self.client = httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(timeout),
|
||||
headers=self.headers
|
||||
)
|
||||
self._request_id = 0
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._context_stack = [] # 保存上下文管理器栈
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
"""获取下一个请求ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def _call_jsonrpc(
|
||||
self,
|
||||
method: str,
|
||||
params: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
调用JSON-RPC 2.0方法
|
||||
|
||||
Args:
|
||||
method: 方法名
|
||||
params: 参数
|
||||
|
||||
Returns:
|
||||
响应结果
|
||||
|
||||
Raises:
|
||||
MCPError: 调用失败时抛出
|
||||
"""
|
||||
request_id = self._next_request_id()
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params or {}
|
||||
}
|
||||
|
||||
try:
|
||||
logger.debug(f"MCP请求: {method} -> {self.url}")
|
||||
|
||||
response = await self.client.post(
|
||||
self.url,
|
||||
json=payload,
|
||||
headers=self.headers # 显式传递headers(对于共享客户端很重要)
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
# 获取响应内容
|
||||
response_text = response.text
|
||||
content_type = response.headers.get('content-type', '')
|
||||
|
||||
# 如果是空响应
|
||||
if not response_text or response_text.strip() == '':
|
||||
raise MCPError("服务器返回空响应")
|
||||
|
||||
# 处理SSE格式响应
|
||||
if 'text/event-stream' in content_type or response_text.startswith('event:'):
|
||||
logger.debug("检测到SSE格式响应,开始解析")
|
||||
data = self._parse_sse_response(response_text)
|
||||
else:
|
||||
# 标准JSON响应
|
||||
async def _ensure_connected(self):
|
||||
"""确保连接已建立"""
|
||||
async with self._lock:
|
||||
if self._session is None:
|
||||
try:
|
||||
data = response.json()
|
||||
except ValueError as e:
|
||||
logger.error(f"JSON解析失败,响应内容: {response_text[:500]}")
|
||||
raise MCPError(f"无法解析JSON响应: {str(e)}")
|
||||
|
||||
# 检查JSON-RPC错误
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
error_msg = error.get("message", "Unknown error")
|
||||
error_code = error.get("code", -1)
|
||||
logger.error(f"MCP错误 [{error_code}]: {error_msg}")
|
||||
raise MCPError(f"[{error_code}] {error_msg}")
|
||||
|
||||
if "result" not in data:
|
||||
raise MCPError("响应中缺少result字段")
|
||||
|
||||
return data["result"]
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP错误 {e.response.status_code}: {e.response.text}")
|
||||
raise MCPError(f"HTTP错误 {e.response.status_code}: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"请求错误: {str(e)}")
|
||||
raise MCPError(f"请求错误: {str(e)}")
|
||||
except MCPError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"未知错误: {str(e)}")
|
||||
raise MCPError(f"未知错误: {str(e)}")
|
||||
logger.info(f"🔗 连接到MCP服务器: {self.url}")
|
||||
|
||||
# 使用官方 SDK 的 streamable_http_client
|
||||
# 保存上下文管理器以便后续正确清理
|
||||
stream_context = streamablehttp_client(self.url)
|
||||
read_stream, write_stream, _ = await stream_context.__aenter__()
|
||||
self._context_stack.append(('stream', stream_context))
|
||||
|
||||
# 创建客户端会话
|
||||
self._session = ClientSession(read_stream, write_stream)
|
||||
session_context = self._session
|
||||
await session_context.__aenter__()
|
||||
self._context_stack.append(('session', session_context))
|
||||
|
||||
# 初始化会话
|
||||
await self._session.initialize()
|
||||
self._initialized = True
|
||||
|
||||
logger.info(f"✅ MCP会话初始化成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MCP连接失败: {e}")
|
||||
await self._cleanup()
|
||||
raise MCPError(f"连接MCP服务器失败: {str(e)}")
|
||||
|
||||
def _parse_sse_response(self, sse_text: str) -> Dict[str, Any]:
|
||||
async def _cleanup(self):
|
||||
"""清理连接资源(按照进入的相反顺序退出)"""
|
||||
# 按照LIFO顺序清理上下文
|
||||
while self._context_stack:
|
||||
ctx_type, ctx = self._context_stack.pop()
|
||||
try:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
except RuntimeError as e:
|
||||
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
|
||||
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
|
||||
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
|
||||
else:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> Dict[str, Any]:
|
||||
"""
|
||||
解析SSE格式的响应
|
||||
初始化MCP会话
|
||||
|
||||
SSE格式示例:
|
||||
event: message
|
||||
data: {"result": {...}}
|
||||
|
||||
Args:
|
||||
sse_text: SSE格式的文本
|
||||
|
||||
Returns:
|
||||
解析后的JSON数据
|
||||
初始化响应
|
||||
"""
|
||||
import json
|
||||
|
||||
lines = sse_text.strip().split('\n')
|
||||
data_lines = []
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('data:'):
|
||||
# 提取data后面的内容
|
||||
data_content = line[5:].strip()
|
||||
data_lines.append(data_content)
|
||||
|
||||
if not data_lines:
|
||||
raise MCPError("SSE响应中没有找到data字段")
|
||||
|
||||
# 合并所有data行(某些SSE可能分多行)
|
||||
full_data = ''.join(data_lines)
|
||||
|
||||
try:
|
||||
return json.loads(full_data)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析SSE data失败: {full_data[:200]}")
|
||||
raise MCPError(f"SSE data不是有效的JSON: {str(e)}")
|
||||
await self._ensure_connected()
|
||||
return {"status": "initialized"}
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -196,13 +117,26 @@ class HTTPMCPClient:
|
||||
工具列表
|
||||
"""
|
||||
try:
|
||||
result = await self._call_jsonrpc("tools/list")
|
||||
tools = result.get("tools", [])
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools = []
|
||||
for tool in result.tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
|
||||
logger.info(f"获取到 {len(tools)} 个工具")
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise
|
||||
raise MCPError(f"获取工具列表失败: {str(e)}")
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
@@ -220,33 +154,38 @@ class HTTPMCPClient:
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
logger.info(f"调用工具: {tool_name}")
|
||||
logger.debug(f"参数: {arguments}")
|
||||
|
||||
result = await self._call_jsonrpc(
|
||||
"tools/call",
|
||||
{
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
)
|
||||
result = await self._session.call_tool(tool_name, arguments)
|
||||
|
||||
# MCP返回的result通常包含content数组
|
||||
if isinstance(result, dict) and "content" in result:
|
||||
content = result["content"]
|
||||
if isinstance(content, list) and len(content) > 0:
|
||||
# 提取第一个content项的text
|
||||
first_content = content[0]
|
||||
if isinstance(first_content, dict) and "text" in first_content:
|
||||
return first_content["text"]
|
||||
return first_content
|
||||
return content
|
||||
# 处理返回结果
|
||||
# MCP SDK 返回 CallToolResult 对象
|
||||
if result.content:
|
||||
# 提取第一个content的文本
|
||||
for content in result.content:
|
||||
if isinstance(content, types.TextContent):
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
# 如果没有文本内容,返回原始内容
|
||||
return result.content[0] if result.content else None
|
||||
|
||||
return result
|
||||
# 如果有结构化内容(2025-06-18规范)
|
||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||
return result.structuredContent
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
raise MCPError(f"调用工具失败: {str(e)}")
|
||||
|
||||
async def list_resources(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -256,13 +195,27 @@ class HTTPMCPClient:
|
||||
资源列表
|
||||
"""
|
||||
try:
|
||||
result = await self._call_jsonrpc("resources/list")
|
||||
resources = result.get("resources", [])
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_resources()
|
||||
|
||||
# 转换为字典格式
|
||||
resources = []
|
||||
for resource in result.resources:
|
||||
resource_dict = {
|
||||
"uri": str(resource.uri),
|
||||
"name": resource.name,
|
||||
"description": resource.description or "",
|
||||
"mimeType": resource.mimeType or ""
|
||||
}
|
||||
resources.append(resource_dict)
|
||||
|
||||
logger.info(f"获取到 {len(resources)} 个资源")
|
||||
return resources
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取资源列表失败: {e}")
|
||||
raise
|
||||
raise MCPError(f"获取资源列表失败: {str(e)}")
|
||||
|
||||
async def read_resource(self, uri: str) -> Any:
|
||||
"""
|
||||
@@ -275,14 +228,33 @@ class HTTPMCPClient:
|
||||
资源内容
|
||||
"""
|
||||
try:
|
||||
result = await self._call_jsonrpc(
|
||||
"resources/read",
|
||||
{"uri": uri}
|
||||
)
|
||||
return result
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.read_resource(AnyUrl(uri))
|
||||
|
||||
# 提取资源内容
|
||||
if result.contents:
|
||||
content = result.contents[0]
|
||||
if isinstance(content, types.TextContent):
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
elif isinstance(content, types.BlobResourceContents):
|
||||
return {
|
||||
"type": "blob",
|
||||
"blob": content.blob,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取资源失败: {uri}, 错误: {e}")
|
||||
raise
|
||||
raise MCPError(f"读取资源失败: {str(e)}")
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -291,10 +263,12 @@ class HTTPMCPClient:
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 尝试列举工具来测试连接
|
||||
# 尝试连接并列举工具
|
||||
await self._ensure_connected()
|
||||
tools = await self.list_tools()
|
||||
|
||||
end_time = time.time()
|
||||
@@ -307,22 +281,7 @@ class HTTPMCPClient:
|
||||
"tools_count": len(tools),
|
||||
"tools": tools
|
||||
}
|
||||
except MCPError as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": "连接测试失败",
|
||||
"response_time_ms": response_time,
|
||||
"error": str(e),
|
||||
"error_type": "MCPError",
|
||||
"suggestions": [
|
||||
"请检查服务器URL是否正确",
|
||||
"请确认API Key是否有效",
|
||||
"请检查网络连接"
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
@@ -334,12 +293,41 @@ class HTTPMCPClient:
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"suggestions": [
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置是否正确"
|
||||
"请检查服务器URL是否正确",
|
||||
"请确认API Key是否有效",
|
||||
"请检查网络连接",
|
||||
"请确认MCP服务器是否在线"
|
||||
]
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端(仅在拥有客户端所有权时关闭)"""
|
||||
if self._owns_client and self.client:
|
||||
await self.client.aclose()
|
||||
"""关闭客户端连接"""
|
||||
logger.info(f"关闭MCP客户端: {self.url}")
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_client(
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
创建MCP客户端的上下文管理器
|
||||
|
||||
Args:
|
||||
url: MCP服务器URL
|
||||
headers: HTTP请求头
|
||||
env: 环境变量
|
||||
timeout: 超时时间
|
||||
|
||||
Yields:
|
||||
HTTPMCPClient实例
|
||||
"""
|
||||
client = HTTPMCPClient(url, headers, env, timeout)
|
||||
try:
|
||||
await client.initialize()
|
||||
yield client
|
||||
finally:
|
||||
await client.close()
|
||||
+251
-95
@@ -1,92 +1,152 @@
|
||||
"""MCP插件注册表 - 管理运行时插件实例"""
|
||||
import asyncio
|
||||
import time
|
||||
import httpx
|
||||
from typing import Dict, Optional, Any, List, Tuple
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, Optional, Any, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from app.mcp.http_client import HTTPMCPClient, MCPError
|
||||
from app.mcp.config import mcp_config
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""会话信息"""
|
||||
client: HTTPMCPClient
|
||||
created_at: float
|
||||
last_access: float
|
||||
request_count: int = 0
|
||||
error_count: int = 0
|
||||
status: str = "active" # active, degraded, error
|
||||
|
||||
|
||||
class MCPPluginRegistry:
|
||||
"""MCP插件注册表 - 管理运行时插件实例(多用户优化版)"""
|
||||
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
|
||||
|
||||
def __init__(self, max_clients: int = 1000, client_ttl: int = 3600):
|
||||
def __init__(
|
||||
self,
|
||||
max_clients: Optional[int] = None,
|
||||
client_ttl: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化注册表
|
||||
|
||||
Args:
|
||||
max_clients: 最大缓存客户端数量
|
||||
client_ttl: 客户端过期时间(秒),默认1小时
|
||||
max_clients: 最大缓存客户端数量(默认使用配置)
|
||||
client_ttl: 客户端过期时间(秒,默认使用配置)
|
||||
"""
|
||||
# 存储格式: {plugin_id: (client, last_access_time)}
|
||||
self._clients: OrderedDict[str, Tuple[HTTPMCPClient, float]] = OrderedDict()
|
||||
# 存储格式: {plugin_id: SessionInfo}
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
|
||||
# 全局锁用于保护会话字典
|
||||
self._sessions_lock = asyncio.Lock()
|
||||
|
||||
# 细粒度锁:每个用户一个锁
|
||||
self._user_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._locks_lock = asyncio.Lock() # 保护locks字典本身
|
||||
|
||||
# 配置参数
|
||||
self._max_clients = max_clients
|
||||
self._client_ttl = client_ttl
|
||||
|
||||
# 共享HTTP客户端池(用于所有MCP HTTP请求)
|
||||
self._shared_http_client = httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100,
|
||||
max_connections=200,
|
||||
keepalive_expiry=30.0
|
||||
),
|
||||
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=5.0),
|
||||
headers={
|
||||
"User-Agent": "MuMuAINovel-MCP-Client/1.0"
|
||||
}
|
||||
)
|
||||
# 配置参数(使用配置常量)
|
||||
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
|
||||
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
|
||||
|
||||
# 启动后台清理任务
|
||||
self._cleanup_task = None
|
||||
self._start_cleanup_task()
|
||||
self._health_check_task = None
|
||||
self._start_background_tasks()
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动后台清理任务"""
|
||||
def _start_background_tasks(self):
|
||||
"""启动后台任务"""
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("✅ MCP插件注册表后台清理任务已启动")
|
||||
|
||||
if self._health_check_task is None:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
logger.info("✅ MCP会话健康检查任务已启动")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""后台清理过期客户端"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # 每5分钟清理一次
|
||||
await self._cleanup_expired_clients()
|
||||
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
|
||||
await self._cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务异常: {e}")
|
||||
|
||||
async def _cleanup_expired_clients(self):
|
||||
"""清理过期的客户端"""
|
||||
async def _health_check_loop(self):
|
||||
"""后台健康检查"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
|
||||
await self._check_session_health()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查任务异常: {e}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理过期的会话"""
|
||||
now = time.time()
|
||||
expired_ids = []
|
||||
|
||||
# 收集过期的plugin_id
|
||||
for plugin_id, (client, last_access) in list(self._clients.items()):
|
||||
if now - last_access > self._client_ttl:
|
||||
expired_ids.append(plugin_id)
|
||||
async with self._sessions_lock:
|
||||
# 收集过期的plugin_id
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
if now - session.last_access > self._client_ttl:
|
||||
expired_ids.append(plugin_id)
|
||||
|
||||
if expired_ids:
|
||||
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP客户端")
|
||||
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
|
||||
for plugin_id in expired_ids:
|
||||
# 提取user_id来获取对应的锁
|
||||
user_id = plugin_id.split(':', 1)[0]
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
|
||||
async with user_lock:
|
||||
if plugin_id in self._clients:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _check_session_health(self):
|
||||
"""增强的会话健康检查"""
|
||||
async with self._sessions_lock:
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
# 计算错误率
|
||||
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
|
||||
error_rate = session.error_count / session.request_count
|
||||
|
||||
# 动态调整状态(使用配置常量)
|
||||
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
|
||||
if session.status != "error":
|
||||
session.status = "error"
|
||||
logger.error(
|
||||
f"❌ 会话 {plugin_id} 错误率过高 "
|
||||
f"({error_rate:.1%}), 标记为error"
|
||||
)
|
||||
elif error_rate > mcp_config.ERROR_RATE_WARNING:
|
||||
if session.status == "active":
|
||||
session.status = "degraded"
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 健康状况下降 "
|
||||
f"(错误率: {error_rate:.1%})"
|
||||
)
|
||||
elif session.status == "degraded":
|
||||
# 错误率降低,恢复正常
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
# 检查长时间无活动的会话
|
||||
idle_time = time.time() - session.last_access
|
||||
if idle_time > mcp_config.IDLE_TIMEOUT_SECONDS:
|
||||
logger.info(
|
||||
f"💤 会话 {plugin_id} 空闲 {idle_time/60:.1f} 分钟,"
|
||||
f"准备清理"
|
||||
)
|
||||
|
||||
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
@@ -103,26 +163,33 @@ class MCPPluginRegistry:
|
||||
self._user_locks[user_id] = asyncio.Lock()
|
||||
return self._user_locks[user_id]
|
||||
|
||||
def _touch_client(self, plugin_id: str):
|
||||
def _touch_session(self, plugin_id: str):
|
||||
"""
|
||||
更新客户端的最后访问时间(LRU)
|
||||
更新会话的最后访问时间(需要在锁内调用)
|
||||
|
||||
Args:
|
||||
plugin_id: 插件ID
|
||||
"""
|
||||
if plugin_id in self._clients:
|
||||
client, _ = self._clients[plugin_id]
|
||||
self._clients[plugin_id] = (client, time.time())
|
||||
# 移到末尾(LRU)
|
||||
self._clients.move_to_end(plugin_id)
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
|
||||
async def _evict_lru_client(self):
|
||||
"""驱逐最久未使用的客户端(当达到max_clients限制时)"""
|
||||
if len(self._clients) >= self._max_clients:
|
||||
# 获取最旧的plugin_id
|
||||
oldest_id = next(iter(self._clients))
|
||||
logger.info(f"📤 达到最大客户端数量限制,驱逐: {oldest_id}")
|
||||
await self._unload_plugin_unsafe(oldest_id)
|
||||
async def _evict_lru_session(self):
|
||||
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
|
||||
if len(self._sessions) >= self._max_clients:
|
||||
# 找到最旧的会话
|
||||
oldest_id = None
|
||||
oldest_time = float('inf')
|
||||
|
||||
for plugin_id, session in self._sessions.items():
|
||||
if session.last_access < oldest_time:
|
||||
oldest_time = session.last_access
|
||||
oldest_id = plugin_id
|
||||
|
||||
if oldest_id:
|
||||
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
|
||||
await self._unload_plugin_unsafe(oldest_id)
|
||||
|
||||
async def load_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
"""
|
||||
@@ -141,11 +208,12 @@ class MCPPluginRegistry:
|
||||
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
|
||||
|
||||
# 如果已加载,先卸载
|
||||
if plugin_id in self._clients:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
# 检查是否需要驱逐LRU客户端
|
||||
await self._evict_lru_client()
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
# 检查是否需要驱逐LRU会话
|
||||
await self._evict_lru_session()
|
||||
|
||||
# 目前只支持HTTP类型
|
||||
if plugin.plugin_type == "http":
|
||||
@@ -153,18 +221,30 @@ class MCPPluginRegistry:
|
||||
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
|
||||
return False
|
||||
|
||||
# 使用共享HTTP连接池创建客户端
|
||||
# 为每个插件创建独立的HTTP客户端
|
||||
client = HTTPMCPClient(
|
||||
url=plugin.server_url,
|
||||
headers=plugin.headers or {},
|
||||
env=plugin.env or {},
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0,
|
||||
http_client=self._shared_http_client # 传入共享连接池
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
)
|
||||
|
||||
# 存储客户端和当前时间戳
|
||||
self._clients[plugin_id] = (client, time.time())
|
||||
logger.info(f"✅ 加载MCP插件: {plugin_id}")
|
||||
# 创建会话信息
|
||||
now = time.time()
|
||||
session = SessionInfo(
|
||||
client=client,
|
||||
created_at=now,
|
||||
last_access=now,
|
||||
request_count=0,
|
||||
error_count=0,
|
||||
status="active"
|
||||
)
|
||||
|
||||
# 存储会话
|
||||
async with self._sessions_lock:
|
||||
self._sessions[plugin_id] = session
|
||||
|
||||
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
@@ -186,18 +266,19 @@ class MCPPluginRegistry:
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _unload_plugin_unsafe(self, plugin_id: str):
|
||||
"""卸载插件(不加锁,内部使用)"""
|
||||
if plugin_id in self._clients:
|
||||
client, _ = self._clients[plugin_id] # 解包 (client, timestamp)
|
||||
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用)"""
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
try:
|
||||
await client.close()
|
||||
await session.client.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
|
||||
|
||||
del self._clients[plugin_id]
|
||||
del self._sessions[plugin_id]
|
||||
logger.info(f"卸载MCP插件: {plugin_id}")
|
||||
|
||||
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
@@ -215,7 +296,7 @@ class MCPPluginRegistry:
|
||||
|
||||
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
|
||||
"""
|
||||
获取插件客户端(支持LRU访问时间更新)
|
||||
获取插件客户端(线程安全,支持访问时间更新)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
@@ -225,13 +306,68 @@ class MCPPluginRegistry:
|
||||
客户端实例或None
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
entry = self._clients.get(plugin_id)
|
||||
if entry:
|
||||
# 更新访问时间(LRU)
|
||||
self._touch_client(plugin_id)
|
||||
return entry[0] # 返回客户端对象
|
||||
|
||||
session = self._sessions.get(plugin_id)
|
||||
if session:
|
||||
# 检查会话状态
|
||||
if session.status == "error":
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 处于错误状态,"
|
||||
f"建议调用者重新加载插件"
|
||||
)
|
||||
# 不返回错误状态的客户端
|
||||
return None
|
||||
|
||||
# ✅ 使用锁保护状态更新,避免并发问题
|
||||
# 注意:这里使用原子操作更新简单字段,不需要异步锁
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
return session.client
|
||||
return None
|
||||
|
||||
async def get_or_reconnect_client(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
plugin: MCPPlugin
|
||||
) -> HTTPMCPClient:
|
||||
"""
|
||||
获取或重连客户端(自动处理错误状态)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件配置对象
|
||||
|
||||
Returns:
|
||||
客户端实例
|
||||
|
||||
Raises:
|
||||
ValueError: 插件加载失败
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
# 获取用户锁
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
session = self._sessions.get(plugin_id)
|
||||
|
||||
# 检查会话健康状态
|
||||
if session and session.status == "error":
|
||||
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
session = None
|
||||
|
||||
# 如果没有会话,加载插件
|
||||
if not session:
|
||||
success = await self.load_plugin(plugin)
|
||||
if not success:
|
||||
raise ValueError(f"插件加载失败: {plugin_name}")
|
||||
session = self._sessions[plugin_id]
|
||||
|
||||
return session.client
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -240,7 +376,7 @@ class MCPPluginRegistry:
|
||||
arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
调用插件工具
|
||||
调用插件工具(带错误计数和状态管理)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
@@ -255,18 +391,39 @@ class MCPPluginRegistry:
|
||||
ValueError: 插件不存在或未启用
|
||||
MCPError: 工具调用失败
|
||||
"""
|
||||
client = self.get_client(user_id, plugin_name)
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
if not client:
|
||||
# 获取会话
|
||||
session = self._sessions.get(plugin_id)
|
||||
if not session:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
try:
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
result = await session.client.call_tool(tool_name, arguments)
|
||||
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
|
||||
# logger.info(f"✅ 工具返回内容: {result}")
|
||||
|
||||
# 调用成功,重置状态(如果之前是degraded)
|
||||
if session.status == "degraded":
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {plugin_name}.{tool_name}, 错误: {e}")
|
||||
# 增加错误计数
|
||||
session.error_count += 1
|
||||
|
||||
# 根据错误率更新状态
|
||||
if session.request_count > 0:
|
||||
error_rate = session.error_count / session.request_count
|
||||
if error_rate > 0.5:
|
||||
session.status = "error"
|
||||
elif error_rate > 0.3:
|
||||
session.status = "degraded"
|
||||
|
||||
logger.error(
|
||||
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
|
||||
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_plugin_tools(
|
||||
@@ -320,7 +477,7 @@ class MCPPluginRegistry:
|
||||
|
||||
async def cleanup_all(self):
|
||||
"""清理所有插件和资源"""
|
||||
# 停止后台清理任务
|
||||
# 停止后台任务
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
@@ -328,19 +485,18 @@ class MCPPluginRegistry:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 清理所有客户端
|
||||
plugin_ids = list(self._clients.keys())
|
||||
for plugin_id in plugin_ids:
|
||||
user_id = plugin_id.split(':', 1)[0]
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 关闭共享HTTP客户端
|
||||
try:
|
||||
await self._shared_http_client.aclose()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭共享HTTP客户端失败: {e}")
|
||||
# 清理所有会话
|
||||
async with self._sessions_lock:
|
||||
plugin_ids = list(self._sessions.keys())
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
logger.info("✅ 已清理所有MCP插件和资源")
|
||||
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
|
||||
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.models.settings import Settings as UserSettings
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.services.ai_service import create_user_ai_service
|
||||
from app.schemas.mcp_plugin import MCPTestResult
|
||||
from app.logger import get_logger
|
||||
from app.user_manager import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPTestService:
|
||||
"""MCP插件测试服务(分离的测试逻辑)"""
|
||||
|
||||
async def test_plugin_connection(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> MCPTestResult:
|
||||
"""
|
||||
简单连接测试
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保插件已加载
|
||||
if not mcp_registry.get_client(user_id, plugin.plugin_name):
|
||||
success = await mcp_registry.load_plugin(plugin)
|
||||
if not success:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件加载失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
)
|
||||
|
||||
# 测试连接并获取工具列表
|
||||
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
if test_result["success"]:
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功",
|
||||
response_time_ms=response_time,
|
||||
tools_count=test_result.get("tools_count", 0),
|
||||
suggestions=[
|
||||
f"响应时间: {response_time}ms",
|
||||
f"可用工具数: {test_result.get('tools_count', 0)}"
|
||||
]
|
||||
)
|
||||
else:
|
||||
return MCPTestResult(**test_result)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=response_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
async def test_plugin_with_ai(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user: User,
|
||||
db_session: AsyncSession
|
||||
) -> MCPTestResult:
|
||||
"""
|
||||
使用AI进行智能工具调用测试
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user: 用户对象
|
||||
db_session: 数据库会话
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. 先进行连接测试
|
||||
connection_result = await self.test_plugin_connection(plugin, user.user_id)
|
||||
|
||||
if not connection_result.success:
|
||||
return connection_result
|
||||
|
||||
# 2. 获取工具列表
|
||||
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not tools:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件没有提供任何工具",
|
||||
error="工具列表为空",
|
||||
response_time_ms=connection_result.response_time_ms,
|
||||
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
|
||||
)
|
||||
|
||||
# 3. 获取用户的AI设置
|
||||
settings_result = await db_session.execute(
|
||||
select(UserSettings).where(UserSettings.user_id == user.user_id)
|
||||
)
|
||||
user_settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not user_settings or not user_settings.api_key:
|
||||
# 没有AI配置,返回简单测试结果
|
||||
logger.warning("用户未配置AI服务,跳过智能测试")
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
|
||||
response_time_ms=connection_result.response_time_ms,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"连接测试: 成功",
|
||||
f"可用工具数: {len(tools)}",
|
||||
"提示: 配置AI服务后可进行智能工具调用测试"
|
||||
]
|
||||
)
|
||||
|
||||
# 4. 使用AI选择工具并生成测试参数
|
||||
logger.info(f"使用AI分析工具并生成测试计划...")
|
||||
|
||||
ai_service = create_user_ai_service(
|
||||
api_provider=user_settings.api_provider,
|
||||
api_key=user_settings.api_key,
|
||||
api_base_url=user_settings.api_base_url,
|
||||
model_name=user_settings.llm_model,
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
# 转换为OpenAI Function Calling格式
|
||||
openai_tools = self._convert_tools_to_openai_format(tools)
|
||||
|
||||
# 调用AI选择工具
|
||||
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
|
||||
|
||||
⚠️ 重要规则:生成参数时,必须严格使用工具 schema 中定义的原始参数名称,不要转换为 snake_case 或其他格式。
|
||||
例如:如果 schema 中是 'nextThoughtNeeded',就必须使用 'nextThoughtNeeded',不能改成 'next_thought_needed'。
|
||||
|
||||
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
|
||||
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
|
||||
|
||||
现在开始测试这个插件。"""
|
||||
|
||||
system_prompt = """你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。
|
||||
|
||||
⚠️ 关键规则:调用工具时,必须严格使用 schema 中定义的原始参数名,不要自行转换命名风格。
|
||||
- 如果参数名是 camelCase(如 nextThoughtNeeded),就使用 camelCase
|
||||
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
|
||||
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
|
||||
|
||||
ai_response = await ai_service.generate_text(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
tools=openai_tools,
|
||||
tool_choice="required"
|
||||
)
|
||||
|
||||
# 5. 检查AI是否返回工具调用
|
||||
if not ai_response.get("tool_calls"):
|
||||
logger.error(f"❌ AI未返回工具调用")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI Function Calling失败",
|
||||
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
"请确认使用的AI模型支持Function Calling",
|
||||
f"当前Provider: {user_settings.api_provider}",
|
||||
f"当前模型: {user_settings.llm_model}"
|
||||
]
|
||||
)
|
||||
|
||||
# 6. 解析工具调用
|
||||
tool_call = ai_response["tool_calls"][0]
|
||||
function = tool_call["function"]
|
||||
tool_name = function["name"]
|
||||
test_arguments = function["arguments"]
|
||||
|
||||
if isinstance(test_arguments, str):
|
||||
try:
|
||||
test_arguments = json.loads(test_arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析AI参数失败: {e}")
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ AI返回的参数格式错误",
|
||||
error=f"无法解析参数JSON: {str(e)}",
|
||||
tools_count=len(tools)
|
||||
)
|
||||
|
||||
logger.info(f"🤖 AI选择的工具: {tool_name}")
|
||||
logger.info(f"📝 AI生成的参数: {test_arguments}")
|
||||
|
||||
# 7. 调用MCP工具
|
||||
call_start = time.time()
|
||||
try:
|
||||
tool_result = await mcp_registry.call_tool(
|
||||
user.user_id,
|
||||
plugin.plugin_name,
|
||||
tool_name,
|
||||
test_arguments
|
||||
)
|
||||
|
||||
call_end = time.time()
|
||||
call_time = round((call_end - call_start) * 1000, 2)
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
# 格式化结果
|
||||
result_str = str(tool_result)
|
||||
if len(result_str) > 800:
|
||||
result_preview = result_str[:800] + "\n...(结果已截断)"
|
||||
else:
|
||||
result_preview = result_str
|
||||
|
||||
return MCPTestResult(
|
||||
success=True,
|
||||
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
suggestions=[
|
||||
f"🤖 AI选择: {tool_name}",
|
||||
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
|
||||
f"⏱️ 耗时: {call_time}ms",
|
||||
f"📊 结果:\n{result_preview}"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as call_error:
|
||||
call_end = time.time()
|
||||
total_time = round((call_end - start_time) * 1000, 2)
|
||||
|
||||
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=True, # 连接成功就算测试通过
|
||||
message=f"⚠️ 连接成功,但工具调用失败",
|
||||
response_time_ms=total_time,
|
||||
tools_count=len(tools),
|
||||
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
|
||||
suggestions=[
|
||||
f"✅ 连接测试: 成功",
|
||||
f"❌ 工具调用测试: 失败",
|
||||
f"🤖 AI选择: {tool_name}",
|
||||
f"❌ 错误: {str(call_error)}",
|
||||
"💡 可能原因: API Key无效、参数错误或服务限制"
|
||||
]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
total_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="❌ 测试失败",
|
||||
response_time_ms=total_time,
|
||||
error=str(e),
|
||||
error_type=type(e).__name__,
|
||||
suggestions=[
|
||||
"请检查服务器是否在线",
|
||||
"请确认配置正确",
|
||||
"请检查API Key是否有效"
|
||||
]
|
||||
)
|
||||
|
||||
def _convert_tools_to_openai_format(self, tools: list) -> list:
|
||||
"""将MCP工具格式转换为OpenAI Function Calling格式"""
|
||||
openai_tools = []
|
||||
for tool in tools:
|
||||
openai_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool["name"],
|
||||
"description": tool.get("description", ""),
|
||||
}
|
||||
}
|
||||
if "inputSchema" in tool:
|
||||
openai_tool["function"]["parameters"] = tool["inputSchema"]
|
||||
openai_tools.append(openai_tool)
|
||||
return openai_tools
|
||||
|
||||
|
||||
# 全局单例
|
||||
mcp_test_service = MCPTestService()
|
||||
@@ -5,26 +5,100 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.mcp.registry import mcp_registry
|
||||
from app.mcp.config import mcp_config
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolMetrics:
|
||||
"""工具调用指标"""
|
||||
total_calls: int = 0
|
||||
success_calls: int = 0
|
||||
failed_calls: int = 0
|
||||
total_duration_ms: float = 0.0
|
||||
avg_duration_ms: float = 0.0
|
||||
last_call_time: Optional[datetime] = None
|
||||
|
||||
def update_success(self, duration_ms: float):
|
||||
"""更新成功调用指标"""
|
||||
self.total_calls += 1
|
||||
self.success_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
def update_failure(self, duration_ms: float):
|
||||
"""更新失败调用指标"""
|
||||
self.total_calls += 1
|
||||
self.failed_calls += 1
|
||||
self.total_duration_ms += duration_ms
|
||||
self.avg_duration_ms = self.total_duration_ms / self.total_calls
|
||||
self.last_call_time = datetime.now()
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""成功率"""
|
||||
if self.total_calls == 0:
|
||||
return 0.0
|
||||
return self.success_calls / self.total_calls
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCacheEntry:
|
||||
"""工具缓存条目"""
|
||||
tools: List[Dict[str, Any]]
|
||||
expire_time: datetime
|
||||
hit_count: int = 0
|
||||
|
||||
|
||||
class MCPToolServiceError(Exception):
|
||||
"""MCP工具服务异常"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolService:
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
|
||||
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
|
||||
|
||||
def __init__(self):
|
||||
self._tool_cache = {} # 工具定义缓存
|
||||
self._result_cache = {} # 工具结果缓存(可选)
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_minutes: Optional[int] = None,
|
||||
max_retries: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化MCP工具服务
|
||||
|
||||
Args:
|
||||
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
|
||||
max_retries: 最大重试次数(默认使用配置)
|
||||
"""
|
||||
# 工具定义缓存: {cache_key: ToolCacheEntry}
|
||||
self._tool_cache: Dict[str, ToolCacheEntry] = {}
|
||||
self._cache_ttl = timedelta(
|
||||
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
|
||||
)
|
||||
|
||||
# 调用指标: {tool_key: ToolMetrics}
|
||||
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
|
||||
|
||||
# 重试配置(使用配置常量)
|
||||
self._max_retries = max_retries or mcp_config.MAX_RETRIES
|
||||
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
|
||||
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
|
||||
|
||||
logger.info(
|
||||
f"✅ MCPToolService初始化完成 "
|
||||
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
|
||||
f"最大重试={self._max_retries}次)"
|
||||
)
|
||||
|
||||
async def get_user_enabled_tools(
|
||||
self,
|
||||
@@ -61,7 +135,7 @@ class MCPToolService:
|
||||
logger.info(f"用户 {user_id} 没有启用的MCP插件")
|
||||
return []
|
||||
|
||||
# 2. 获取所有工具定义
|
||||
# 2. 获取所有工具定义(使用缓存)
|
||||
all_tools = []
|
||||
for plugin in plugins:
|
||||
try:
|
||||
@@ -73,8 +147,8 @@ class MCPToolService:
|
||||
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
|
||||
continue
|
||||
|
||||
# 从registry获取该插件的工具列表
|
||||
plugin_tools = await mcp_registry.get_plugin_tools(
|
||||
# ✅ 使用缓存获取工具列表
|
||||
plugin_tools = await self._get_plugin_tools_cached(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name
|
||||
)
|
||||
@@ -82,7 +156,7 @@ class MCPToolService:
|
||||
# 格式化为Function Calling格式
|
||||
formatted_tools = self._format_tools_for_ai(
|
||||
plugin_tools,
|
||||
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name
|
||||
plugin.plugin_name
|
||||
)
|
||||
all_tools.extend(formatted_tools)
|
||||
|
||||
@@ -139,12 +213,85 @@ class MCPToolService:
|
||||
|
||||
return formatted_tools
|
||||
|
||||
async def _get_plugin_tools_cached(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
带缓存的工具列表获取
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
now = datetime.now()
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in self._tool_cache:
|
||||
entry = self._tool_cache[cache_key]
|
||||
if now < entry.expire_time:
|
||||
entry.hit_count += 1
|
||||
logger.debug(
|
||||
f"🎯 工具缓存命中: {cache_key} "
|
||||
f"(命中次数: {entry.hit_count})"
|
||||
)
|
||||
return entry.tools
|
||||
else:
|
||||
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
|
||||
del self._tool_cache[cache_key]
|
||||
|
||||
# 缓存未命中,从MCP获取
|
||||
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
|
||||
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
|
||||
|
||||
# 更新缓存
|
||||
self._tool_cache[cache_key] = ToolCacheEntry(
|
||||
tools=tools,
|
||||
expire_time=now + self._cache_ttl,
|
||||
hit_count=0
|
||||
)
|
||||
|
||||
return tools
|
||||
|
||||
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
|
||||
"""
|
||||
清理缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(可选,清理特定用户的缓存)
|
||||
plugin_name: 插件名称(可选,清理特定插件的缓存)
|
||||
"""
|
||||
if user_id is None and plugin_name is None:
|
||||
# 清理所有缓存
|
||||
self._tool_cache.clear()
|
||||
logger.info("🧹 已清理所有工具缓存")
|
||||
elif user_id and plugin_name:
|
||||
# 清理特定插件缓存
|
||||
cache_key = f"{user_id}:{plugin_name}"
|
||||
if cache_key in self._tool_cache:
|
||||
del self._tool_cache[cache_key]
|
||||
logger.info(f"🧹 已清理缓存: {cache_key}")
|
||||
elif user_id:
|
||||
# 清理用户所有缓存
|
||||
keys_to_delete = [
|
||||
key for key in self._tool_cache.keys()
|
||||
if key.startswith(f"{user_id}:")
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._tool_cache[key]
|
||||
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
|
||||
|
||||
async def execute_tool_calls(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
db_session: AsyncSession,
|
||||
timeout: float = 60.0
|
||||
timeout: Optional[float] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量执行AI请求的工具调用(并行执行)
|
||||
@@ -153,7 +300,7 @@ class MCPToolService:
|
||||
user_id: 用户ID
|
||||
tool_calls: AI返回的工具调用列表
|
||||
db_session: 数据库会话
|
||||
timeout: 单个工具调用的超时时间(秒,默认30秒)
|
||||
timeout: 单个工具调用的超时时间(秒,默认使用配置)
|
||||
|
||||
Returns:
|
||||
工具调用结果列表
|
||||
@@ -161,7 +308,10 @@ class MCPToolService:
|
||||
if not tool_calls:
|
||||
return []
|
||||
|
||||
logger.info(f"开始执行 {len(tool_calls)} 个工具调用")
|
||||
# 使用配置的默认超时
|
||||
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
|
||||
|
||||
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s)")
|
||||
|
||||
# 创建异步任务列表
|
||||
tasks = [
|
||||
@@ -169,7 +319,7 @@ class MCPToolService:
|
||||
user_id=user_id,
|
||||
tool_call=tool_call,
|
||||
db_session=db_session,
|
||||
timeout=timeout
|
||||
timeout=actual_timeout
|
||||
)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
@@ -238,18 +388,28 @@ class MCPToolService:
|
||||
f"参数: {arguments}"
|
||||
)
|
||||
|
||||
# 设置超时
|
||||
# ✅ 使用带重试的调用
|
||||
tool_key = f"{plugin_name}.{tool_name}"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
mcp_registry.call_tool(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments
|
||||
),
|
||||
result = await self._call_tool_with_retry(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 记录成功指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_success(duration_ms)
|
||||
|
||||
logger.info(
|
||||
f"✅ 工具调用成功: {tool_key} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
# 成功返回
|
||||
return {
|
||||
"tool_call_id": tool_call_id,
|
||||
@@ -261,13 +421,21 @@ class MCPToolService:
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 记录失败指标
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用超时(>{timeout}秒)"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 记录失败指标
|
||||
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self._metrics[tool_key].update_failure(duration_ms)
|
||||
|
||||
logger.error(
|
||||
f"工具 {function_name} 调用失败: {e}",
|
||||
f"❌ 工具 {function_name} 调用失败: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
@@ -279,6 +447,146 @@ class MCPToolService:
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _call_tool_with_retry(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: float
|
||||
) -> Any:
|
||||
"""
|
||||
带指数退避重试的工具调用
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPToolServiceError: 工具调用失败
|
||||
asyncio.TimeoutError: 调用超时
|
||||
"""
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
# 尝试调用工具
|
||||
result = await asyncio.wait_for(
|
||||
mcp_registry.call_tool(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
tool_name=tool_name,
|
||||
arguments=arguments
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 成功则返回
|
||||
if attempt > 0:
|
||||
logger.info(
|
||||
f"✅ 重试成功: {plugin_name}.{tool_name} "
|
||||
f"(第{attempt + 1}次尝试)"
|
||||
)
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时不重试,直接抛出
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
# 最后一次尝试失败
|
||||
if attempt == self._max_retries - 1:
|
||||
logger.error(
|
||||
f"❌ 重试失败: {plugin_name}.{tool_name} "
|
||||
f"(已尝试{self._max_retries}次): {e}"
|
||||
)
|
||||
raise MCPToolServiceError(
|
||||
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
|
||||
)
|
||||
|
||||
# 计算指数退避延迟
|
||||
delay = min(
|
||||
self._base_retry_delay * (2 ** attempt),
|
||||
self._max_retry_delay
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
|
||||
f"(第{attempt + 1}/{self._max_retries}次): "
|
||||
f"{plugin_name}.{tool_name} - {e}"
|
||||
)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# 理论上不会到这里,但为了类型安全
|
||||
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
|
||||
|
||||
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取工具调用指标
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称(可选,获取特定工具的指标)
|
||||
|
||||
Returns:
|
||||
指标字典
|
||||
"""
|
||||
if tool_name:
|
||||
if tool_name in self._metrics:
|
||||
metric = self._metrics[tool_name]
|
||||
return {
|
||||
tool_name: {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": metric.success_rate,
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
}
|
||||
return {}
|
||||
|
||||
# 返回所有工具的指标
|
||||
result = {}
|
||||
for tool_key, metric in self._metrics.items():
|
||||
result[tool_key] = {
|
||||
"total_calls": metric.total_calls,
|
||||
"success_calls": metric.success_calls,
|
||||
"failed_calls": metric.failed_calls,
|
||||
"success_rate": round(metric.success_rate, 3),
|
||||
"avg_duration_ms": round(metric.avg_duration_ms, 2),
|
||||
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
|
||||
}
|
||||
return result
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
total_entries = len(self._tool_cache)
|
||||
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
|
||||
|
||||
return {
|
||||
"total_entries": total_entries,
|
||||
"total_hits": total_hits,
|
||||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||||
"entries": [
|
||||
{
|
||||
"key": key,
|
||||
"tools_count": len(entry.tools),
|
||||
"hit_count": entry.hit_count,
|
||||
"expire_time": entry.expire_time.isoformat()
|
||||
}
|
||||
for key, entry in self._tool_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
async def build_tool_context(
|
||||
self,
|
||||
tool_results: List[Dict[str, Any]],
|
||||
|
||||
@@ -226,8 +226,8 @@ class PlotAnalyzer:
|
||||
)
|
||||
|
||||
# 🔍 添加调试日志:查看AI返回的原始内容
|
||||
logger.info(f"🔍 AI返回类型: {type(response)}")
|
||||
logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
|
||||
# logger.info(f"🔍 AI返回类型: {type(response)}")
|
||||
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
|
||||
|
||||
# 从返回的字典中提取content字段
|
||||
if isinstance(response, dict):
|
||||
|
||||
Reference in New Issue
Block a user