fix:1.优化mcp插件功能,改用mcp sdk库

This commit is contained in:
xiamuceer
2025-11-08 12:32:32 +08:00
parent 88115a45c5
commit c7c1c1fdf3
9 changed files with 1278 additions and 660 deletions
+42
View File
@@ -0,0 +1,42 @@
"""MCP模块配置常量"""
from dataclasses import dataclass
@dataclass(frozen=True)
class MCPConfig:
"""MCP模块配置常量(不可变)"""
# 连接池配置
MAX_CLIENTS: int = 1000 # 最大客户端数量
CLIENT_TTL_SECONDS: int = 3600 # 客户端过期时间(1小时)
IDLE_TIMEOUT_SECONDS: int = 1800 # 空闲超时(30分钟)
# 健康检查配置
HEALTH_CHECK_INTERVAL_SECONDS: int = 60 # 健康检查间隔
ERROR_RATE_CRITICAL: float = 0.7 # 严重错误率阈值
ERROR_RATE_WARNING: float = 0.4 # 警告错误率阈值
MIN_REQUESTS_FOR_HEALTH_CHECK: int = 10 # 进行健康检查的最小请求数
# 清理任务配置
CLEANUP_INTERVAL_SECONDS: int = 300 # 清理任务间隔(5分钟)
# 缓存配置
TOOL_CACHE_TTL_MINUTES: int = 10 # 工具定义缓存TTL
# 重试配置
MAX_RETRIES: int = 3 # 最大重试次数
BASE_RETRY_DELAY_SECONDS: float = 1.0 # 基础重试延迟
MAX_RETRY_DELAY_SECONDS: float = 10.0 # 最大重试延迟
# 超时配置
DEFAULT_TIMEOUT_SECONDS: float = 60.0 # 默认超时时间
TOOL_CALL_TIMEOUT_SECONDS: float = 60.0 # 工具调用超时时间
# 日志配置
LOG_TOOL_ARGUMENTS: bool = True # 是否记录工具参数
LOG_TOOL_RESULTS: bool = False # 是否记录工具结果(可能很大)
# 全局配置实例
mcp_config = MCPConfig()
+185 -197
View File
@@ -1,8 +1,13 @@
"""HTTP MCP客户端 - 实现JSON-RPC 2.0协议"""
import httpx
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
import asyncio
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from app.logger import get_logger
import time
logger = get_logger(__name__)
@@ -13,15 +18,14 @@ class MCPError(Exception):
class HTTPMCPClient:
"""HTTP模式MCP客户端(类似Cursor/Claude Code实现"""
"""HTTP模式MCP客户端(基于官方 MCP Python SDK"""
def __init__(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0,
http_client: Optional[httpx.AsyncClient] = None
timeout: float = 60.0
):
"""
初始化HTTP MCP客户端
@@ -31,162 +35,79 @@ class HTTPMCPClient:
headers: HTTP请求头
env: 环境变量(用于API Key等)
timeout: 超时时间(秒)
http_client: 可选的共享HTTP客户端(用于连接池复用)
"""
self.url = url.rstrip('/')
self.headers = headers or {}
self.env = env or {}
self.timeout = timeout
# 设置MCP必需的Accept头
# MCP服务器要求客户端必须接受 application/json 和 text/event-stream
if 'Accept' not in self.headers:
self.headers['Accept'] = 'application/json, text/event-stream'
# 设置Content-Type
if 'Content-Type' not in self.headers:
self.headers['Content-Type'] = 'application/json'
# 如果env中有API Key,添加到headers
if 'API_KEY' in self.env:
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
# 使用共享客户端或创建新客户端
self._owns_client = http_client is None
if http_client:
self.client = http_client
else:
self.client = httpx.AsyncClient(
timeout=httpx.Timeout(timeout),
headers=self.headers
)
self._request_id = 0
self._session: Optional[ClientSession] = None
self._context_stack = [] # 保存上下文管理器栈
self._initialized = False
self._lock = asyncio.Lock()
def _next_request_id(self) -> int:
"""获取下一个请求ID"""
self._request_id += 1
return self._request_id
async def _call_jsonrpc(
self,
method: str,
params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
调用JSON-RPC 2.0方法
Args:
method: 方法名
params: 参数
Returns:
响应结果
Raises:
MCPError: 调用失败时抛出
"""
request_id = self._next_request_id()
payload = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": params or {}
}
try:
logger.debug(f"MCP请求: {method} -> {self.url}")
response = await self.client.post(
self.url,
json=payload,
headers=self.headers # 显式传递headers(对于共享客户端很重要)
)
response.raise_for_status()
# 获取响应内容
response_text = response.text
content_type = response.headers.get('content-type', '')
# 如果是空响应
if not response_text or response_text.strip() == '':
raise MCPError("服务器返回空响应")
# 处理SSE格式响应
if 'text/event-stream' in content_type or response_text.startswith('event:'):
logger.debug("检测到SSE格式响应,开始解析")
data = self._parse_sse_response(response_text)
else:
# 标准JSON响应
async def _ensure_connected(self):
"""确保连接已建立"""
async with self._lock:
if self._session is None:
try:
data = response.json()
except ValueError as e:
logger.error(f"JSON解析失败,响应内容: {response_text[:500]}")
raise MCPError(f"无法解析JSON响应: {str(e)}")
# 检查JSON-RPC错误
if "error" in data:
error = data["error"]
error_msg = error.get("message", "Unknown error")
error_code = error.get("code", -1)
logger.error(f"MCP错误 [{error_code}]: {error_msg}")
raise MCPError(f"[{error_code}] {error_msg}")
if "result" not in data:
raise MCPError("响应中缺少result字段")
return data["result"]
except httpx.HTTPStatusError as e:
logger.error(f"HTTP错误 {e.response.status_code}: {e.response.text}")
raise MCPError(f"HTTP错误 {e.response.status_code}: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
raise MCPError(f"请求错误: {str(e)}")
except MCPError:
raise
except Exception as e:
logger.error(f"未知错误: {str(e)}")
raise MCPError(f"未知错误: {str(e)}")
logger.info(f"🔗 连接到MCP服务器: {self.url}")
# 使用官方 SDK 的 streamable_http_client
# 保存上下文管理器以便后续正确清理
stream_context = streamablehttp_client(self.url)
read_stream, write_stream, _ = await stream_context.__aenter__()
self._context_stack.append(('stream', stream_context))
# 创建客户端会话
self._session = ClientSession(read_stream, write_stream)
session_context = self._session
await session_context.__aenter__()
self._context_stack.append(('session', session_context))
# 初始化会话
await self._session.initialize()
self._initialized = True
logger.info(f"✅ MCP会话初始化成功")
except Exception as e:
logger.error(f"❌ MCP连接失败: {e}")
await self._cleanup()
raise MCPError(f"连接MCP服务器失败: {str(e)}")
def _parse_sse_response(self, sse_text: str) -> Dict[str, Any]:
async def _cleanup(self):
"""清理连接资源(按照进入的相反顺序退出)"""
# 按照LIFO顺序清理上下文
while self._context_stack:
ctx_type, ctx = self._context_stack.pop()
try:
await ctx.__aexit__(None, None, None)
except RuntimeError as e:
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
else:
logger.error(f"清理{ctx_type}上下文失败: {e}")
except Exception as e:
logger.error(f"清理{ctx_type}上下文失败: {e}")
self._session = None
self._initialized = False
async def initialize(self) -> Dict[str, Any]:
"""
解析SSE格式的响应
初始化MCP会话
SSE格式示例:
event: message
data: {"result": {...}}
Args:
sse_text: SSE格式的文本
Returns:
解析后的JSON数据
初始化响应
"""
import json
lines = sse_text.strip().split('\n')
data_lines = []
for line in lines:
line = line.strip()
if line.startswith('data:'):
# 提取data后面的内容
data_content = line[5:].strip()
data_lines.append(data_content)
if not data_lines:
raise MCPError("SSE响应中没有找到data字段")
# 合并所有data行(某些SSE可能分多行)
full_data = ''.join(data_lines)
try:
return json.loads(full_data)
except json.JSONDecodeError as e:
logger.error(f"解析SSE data失败: {full_data[:200]}")
raise MCPError(f"SSE data不是有效的JSON: {str(e)}")
await self._ensure_connected()
return {"status": "initialized"}
async def list_tools(self) -> List[Dict[str, Any]]:
"""
@@ -196,13 +117,26 @@ class HTTPMCPClient:
工具列表
"""
try:
result = await self._call_jsonrpc("tools/list")
tools = result.get("tools", [])
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
logger.info(f"获取到 {len(tools)} 个工具")
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise
raise MCPError(f"获取工具列表失败: {str(e)}")
async def call_tool(
self,
@@ -220,33 +154,38 @@ class HTTPMCPClient:
工具执行结果
"""
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f"参数: {arguments}")
result = await self._call_jsonrpc(
"tools/call",
{
"name": tool_name,
"arguments": arguments
}
)
result = await self._session.call_tool(tool_name, arguments)
# MCP返回的result通常包含content数组
if isinstance(result, dict) and "content" in result:
content = result["content"]
if isinstance(content, list) and len(content) > 0:
# 提取第一个content项的text
first_content = content[0]
if isinstance(first_content, dict) and "text" in first_content:
return first_content["text"]
return first_content
return content
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
# 提取第一个content的文本
for content in result.content:
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
return result.content[0] if result.content else None
return result
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
return result.structuredContent
return None
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}")
raise
raise MCPError(f"调用工具失败: {str(e)}")
async def list_resources(self) -> List[Dict[str, Any]]:
"""
@@ -256,13 +195,27 @@ class HTTPMCPClient:
资源列表
"""
try:
result = await self._call_jsonrpc("resources/list")
resources = result.get("resources", [])
await self._ensure_connected()
result = await self._session.list_resources()
# 转换为字典格式
resources = []
for resource in result.resources:
resource_dict = {
"uri": str(resource.uri),
"name": resource.name,
"description": resource.description or "",
"mimeType": resource.mimeType or ""
}
resources.append(resource_dict)
logger.info(f"获取到 {len(resources)} 个资源")
return resources
except Exception as e:
logger.error(f"获取资源列表失败: {e}")
raise
raise MCPError(f"获取资源列表失败: {str(e)}")
async def read_resource(self, uri: str) -> Any:
"""
@@ -275,14 +228,33 @@ class HTTPMCPClient:
资源内容
"""
try:
result = await self._call_jsonrpc(
"resources/read",
{"uri": uri}
)
return result
await self._ensure_connected()
result = await self._session.read_resource(AnyUrl(uri))
# 提取资源内容
if result.contents:
content = result.contents[0]
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
elif isinstance(content, types.BlobResourceContents):
return {
"type": "blob",
"blob": content.blob,
"mimeType": content.mimeType
}
return None
except Exception as e:
logger.error(f"读取资源失败: {uri}, 错误: {e}")
raise
raise MCPError(f"读取资源失败: {str(e)}")
async def test_connection(self) -> Dict[str, Any]:
"""
@@ -291,10 +263,12 @@ class HTTPMCPClient:
Returns:
测试结果
"""
import time
start_time = time.time()
try:
# 尝试列举工具来测试连接
# 尝试连接并列举工具
await self._ensure_connected()
tools = await self.list_tools()
end_time = time.time()
@@ -307,22 +281,7 @@ class HTTPMCPClient:
"tools_count": len(tools),
"tools": tools
}
except MCPError as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
return {
"success": False,
"message": "连接测试失败",
"response_time_ms": response_time,
"error": str(e),
"error_type": "MCPError",
"suggestions": [
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接"
]
}
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
@@ -334,12 +293,41 @@ class HTTPMCPClient:
"error": str(e),
"error_type": type(e).__name__,
"suggestions": [
"请检查服务器是否在线",
"请确认配置是否正确"
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接",
"请确认MCP服务器是否在线"
]
}
async def close(self):
"""关闭客户端(仅在拥有客户端所有权时关闭)"""
if self._owns_client and self.client:
await self.client.aclose()
"""关闭客户端连接"""
logger.info(f"关闭MCP客户端: {self.url}")
await self._cleanup()
@asynccontextmanager
async def create_mcp_client(
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
创建MCP客户端的上下文管理器
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量
timeout: 超时时间
Yields:
HTTPMCPClient实例
"""
client = HTTPMCPClient(url, headers, env, timeout)
try:
await client.initialize()
yield client
finally:
await client.close()
+251 -95
View File
@@ -1,92 +1,152 @@
"""MCP插件注册表 - 管理运行时插件实例"""
import asyncio
import time
import httpx
from typing import Dict, Optional, Any, List, Tuple
from collections import OrderedDict
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from datetime import datetime
from app.mcp.http_client import HTTPMCPClient, MCPError
from app.mcp.config import mcp_config
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class SessionInfo:
"""会话信息"""
client: HTTPMCPClient
created_at: float
last_access: float
request_count: int = 0
error_count: int = 0
status: str = "active" # active, degraded, error
class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(多用户优化版)"""
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
def __init__(self, max_clients: int = 1000, client_ttl: int = 3600):
def __init__(
self,
max_clients: Optional[int] = None,
client_ttl: Optional[int] = None
):
"""
初始化注册表
Args:
max_clients: 最大缓存客户端数量
client_ttl: 客户端过期时间(秒,默认1小时
max_clients: 最大缓存客户端数量(默认使用配置)
client_ttl: 客户端过期时间(秒,默认使用配置)
"""
# 存储格式: {plugin_id: (client, last_access_time)}
self._clients: OrderedDict[str, Tuple[HTTPMCPClient, float]] = OrderedDict()
# 存储格式: {plugin_id: SessionInfo}
self._sessions: Dict[str, SessionInfo] = {}
# 全局锁用于保护会话字典
self._sessions_lock = asyncio.Lock()
# 细粒度锁:每个用户一个锁
self._user_locks: Dict[str, asyncio.Lock] = {}
self._locks_lock = asyncio.Lock() # 保护locks字典本身
# 配置参数
self._max_clients = max_clients
self._client_ttl = client_ttl
# 共享HTTP客户端池(用于所有MCP HTTP请求)
self._shared_http_client = httpx.AsyncClient(
limits=httpx.Limits(
max_keepalive_connections=100,
max_connections=200,
keepalive_expiry=30.0
),
timeout=httpx.Timeout(connect=10.0, read=60.0, write=10.0, pool=5.0),
headers={
"User-Agent": "MuMuAINovel-MCP-Client/1.0"
}
)
# 配置参数(使用配置常量)
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
# 启动后台清理任务
self._cleanup_task = None
self._start_cleanup_task()
self._health_check_task = None
self._start_background_tasks()
def _start_cleanup_task(self):
"""启动后台清理任务"""
def _start_background_tasks(self):
"""启动后台任务"""
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ MCP插件注册表后台清理任务已启动")
if self._health_check_task is None:
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("✅ MCP会话健康检查任务已启动")
async def _cleanup_loop(self):
"""后台清理过期客户端"""
while True:
try:
await asyncio.sleep(300) # 每5分钟清理一次
await self._cleanup_expired_clients()
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理任务异常: {e}")
async def _cleanup_expired_clients(self):
"""清理过期的客户端"""
async def _health_check_loop(self):
"""后台健康检查"""
while True:
try:
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
await self._check_session_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"健康检查任务异常: {e}")
async def _cleanup_expired_sessions(self):
"""清理过期的会话"""
now = time.time()
expired_ids = []
# 收集过期的plugin_id
for plugin_id, (client, last_access) in list(self._clients.items()):
if now - last_access > self._client_ttl:
expired_ids.append(plugin_id)
async with self._sessions_lock:
# 收集过期的plugin_id
for plugin_id, session in list(self._sessions.items()):
if now - session.last_access > self._client_ttl:
expired_ids.append(plugin_id)
if expired_ids:
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP客户端")
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
for plugin_id in expired_ids:
# 提取user_id来获取对应的锁
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
if plugin_id in self._clients:
await self._unload_plugin_unsafe(plugin_id)
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
async def _check_session_health(self):
"""增强的会话健康检查"""
async with self._sessions_lock:
for plugin_id, session in list(self._sessions.items()):
# 计算错误率
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
error_rate = session.error_count / session.request_count
# 动态调整状态(使用配置常量)
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
if session.status != "error":
session.status = "error"
logger.error(
f"❌ 会话 {plugin_id} 错误率过高 "
f"({error_rate:.1%}), 标记为error"
)
elif error_rate > mcp_config.ERROR_RATE_WARNING:
if session.status == "active":
session.status = "degraded"
logger.warning(
f"⚠️ 会话 {plugin_id} 健康状况下降 "
f"(错误率: {error_rate:.1%})"
)
elif session.status == "degraded":
# 错误率降低,恢复正常
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查长时间无活动的会话
idle_time = time.time() - session.last_access
if idle_time > mcp_config.IDLE_TIMEOUT_SECONDS:
logger.info(
f"💤 会话 {plugin_id} 空闲 {idle_time/60:.1f} 分钟,"
f"准备清理"
)
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
"""
@@ -103,26 +163,33 @@ class MCPPluginRegistry:
self._user_locks[user_id] = asyncio.Lock()
return self._user_locks[user_id]
def _touch_client(self, plugin_id: str):
def _touch_session(self, plugin_id: str):
"""
更新客户端的最后访问时间(LRU
更新会话的最后访问时间(需要在锁内调用
Args:
plugin_id: 插件ID
"""
if plugin_id in self._clients:
client, _ = self._clients[plugin_id]
self._clients[plugin_id] = (client, time.time())
# 移到末尾(LRU
self._clients.move_to_end(plugin_id)
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
session.last_access = time.time()
session.request_count += 1
async def _evict_lru_client(self):
"""驱逐最久未使用的客户端(当达到max_clients限制时)"""
if len(self._clients) >= self._max_clients:
# 获取最旧的plugin_id
oldest_id = next(iter(self._clients))
logger.info(f"📤 达到最大客户端数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def _evict_lru_session(self):
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
if len(self._sessions) >= self._max_clients:
# 找到最旧的会话
oldest_id = None
oldest_time = float('inf')
for plugin_id, session in self._sessions.items():
if session.last_access < oldest_time:
oldest_time = session.last_access
oldest_id = plugin_id
if oldest_id:
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def load_plugin(self, plugin: MCPPlugin) -> bool:
"""
@@ -141,11 +208,12 @@ class MCPPluginRegistry:
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
# 如果已加载,先卸载
if plugin_id in self._clients:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU客户端
await self._evict_lru_client()
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU会话
await self._evict_lru_session()
# 目前只支持HTTP类型
if plugin.plugin_type == "http":
@@ -153,18 +221,30 @@ class MCPPluginRegistry:
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
return False
# 使用共享HTTP连接池创建客户端
# 为每个插件创建独立的HTTP客户端
client = HTTPMCPClient(
url=plugin.server_url,
headers=plugin.headers or {},
env=plugin.env or {},
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0,
http_client=self._shared_http_client # 传入共享连接池
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
)
# 存储客户端和当前时间戳
self._clients[plugin_id] = (client, time.time())
logger.info(f"✅ 加载MCP插件: {plugin_id}")
# 创建会话信息
now = time.time()
session = SessionInfo(
client=client,
created_at=now,
last_access=now,
request_count=0,
error_count=0,
status="active"
)
# 存储会话
async with self._sessions_lock:
self._sessions[plugin_id] = session
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
return True
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
@@ -186,18 +266,19 @@ class MCPPluginRegistry:
user_lock = await self._get_user_lock(user_id)
async with user_lock:
plugin_id = f"{user_id}:{plugin_name}"
await self._unload_plugin_unsafe(plugin_id)
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
async def _unload_plugin_unsafe(self, plugin_id: str):
"""卸载插件(不加锁,内部使用)"""
if plugin_id in self._clients:
client, _ = self._clients[plugin_id] # 解包 (client, timestamp)
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
try:
await client.close()
await session.client.close()
except Exception as e:
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
del self._clients[plugin_id]
del self._sessions[plugin_id]
logger.info(f"卸载MCP插件: {plugin_id}")
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
@@ -215,7 +296,7 @@ class MCPPluginRegistry:
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
"""
获取插件客户端(支持LRU访问时间更新)
获取插件客户端(线程安全,支持访问时间更新)
Args:
user_id: 用户ID
@@ -225,13 +306,68 @@ class MCPPluginRegistry:
客户端实例或None
"""
plugin_id = f"{user_id}:{plugin_name}"
entry = self._clients.get(plugin_id)
if entry:
# 更新访问时间(LRU
self._touch_client(plugin_id)
return entry[0] # 返回客户端对象
session = self._sessions.get(plugin_id)
if session:
# 检查会话状态
if session.status == "error":
logger.warning(
f"⚠️ 会话 {plugin_id} 处于错误状态,"
f"建议调用者重新加载插件"
)
# 不返回错误状态的客户端
return None
# ✅ 使用锁保护状态更新,避免并发问题
# 注意:这里使用原子操作更新简单字段,不需要异步锁
session.last_access = time.time()
session.request_count += 1
return session.client
return None
async def get_or_reconnect_client(
self,
user_id: str,
plugin_name: str,
plugin: MCPPlugin
) -> HTTPMCPClient:
"""
获取或重连客户端(自动处理错误状态)
Args:
user_id: 用户ID
plugin_name: 插件名称
plugin: 插件配置对象
Returns:
客户端实例
Raises:
ValueError: 插件加载失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取用户锁
user_lock = await self._get_user_lock(user_id)
async with user_lock:
session = self._sessions.get(plugin_id)
# 检查会话健康状态
if session and session.status == "error":
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
session = None
# 如果没有会话,加载插件
if not session:
success = await self.load_plugin(plugin)
if not success:
raise ValueError(f"插件加载失败: {plugin_name}")
session = self._sessions[plugin_id]
return session.client
async def call_tool(
self,
user_id: str,
@@ -240,7 +376,7 @@ class MCPPluginRegistry:
arguments: Dict[str, Any]
) -> Any:
"""
调用插件工具
调用插件工具(带错误计数和状态管理)
Args:
user_id: 用户ID
@@ -255,18 +391,39 @@ class MCPPluginRegistry:
ValueError: 插件不存在或未启用
MCPError: 工具调用失败
"""
client = self.get_client(user_id, plugin_name)
plugin_id = f"{user_id}:{plugin_name}"
if not client:
# 获取会话
session = self._sessions.get(plugin_id)
if not session:
raise ValueError(f"插件未加载: {plugin_name}")
try:
result = await client.call_tool(tool_name, arguments)
result = await session.client.call_tool(tool_name, arguments)
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
# logger.info(f"✅ 工具返回内容: {result}")
# 调用成功,重置状态(如果之前是degraded)
if session.status == "degraded":
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
return result
except Exception as e:
logger.error(f"❌ 工具调用失败: {plugin_name}.{tool_name}, 错误: {e}")
# 增加错误计数
session.error_count += 1
# 根据错误率更新状态
if session.request_count > 0:
error_rate = session.error_count / session.request_count
if error_rate > 0.5:
session.status = "error"
elif error_rate > 0.3:
session.status = "degraded"
logger.error(
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
)
raise
async def get_plugin_tools(
@@ -320,7 +477,7 @@ class MCPPluginRegistry:
async def cleanup_all(self):
"""清理所有插件和资源"""
# 停止后台清理任务
# 停止后台任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
@@ -328,19 +485,18 @@ class MCPPluginRegistry:
except asyncio.CancelledError:
pass
# 清理所有客户端
plugin_ids = list(self._clients.keys())
for plugin_id in plugin_ids:
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
await self._unload_plugin_unsafe(plugin_id)
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
# 关闭共享HTTP客户端
try:
await self._shared_http_client.aclose()
except Exception as e:
logger.error(f"关闭共享HTTP客户端失败: {e}")
# 清理所有会话
async with self._sessions_lock:
plugin_ids = list(self._sessions.keys())
for plugin_id in plugin_ids:
await self._unload_plugin_unsafe(plugin_id)
logger.info("✅ 已清理所有MCP插件和资源")