+162
-48
@@ -2,13 +2,14 @@
|
||||
|
||||
重构后使用统一的MCPClientFacade门面来管理所有MCP操作。
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import asyncio
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Request, BackgroundTasks
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import select, update
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import get_db
|
||||
from app.database import get_db, get_engine
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.schemas.mcp_plugin import (
|
||||
MCPPluginCreate,
|
||||
@@ -36,6 +37,81 @@ def require_login(request: Request) -> User:
|
||||
return request.state.user
|
||||
|
||||
|
||||
async def _register_plugin_background(
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
plugin_type: str,
|
||||
server_url: str,
|
||||
headers: Optional[dict],
|
||||
config: Optional[dict]
|
||||
):
|
||||
"""
|
||||
后台任务:注册MCP插件并更新数据库状态
|
||||
|
||||
在独立的任务中执行MCP连接,避免阻塞请求处理
|
||||
"""
|
||||
try:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
url=server_url,
|
||||
plugin_type=plugin_type,
|
||||
headers=headers,
|
||||
timeout=config.get('timeout', 60.0) if config else 60.0
|
||||
))
|
||||
else:
|
||||
success = False
|
||||
|
||||
# 更新数据库状态
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(
|
||||
status="active" if success else "error",
|
||||
last_error=None if success else "连接失败"
|
||||
)
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
if success:
|
||||
logger.info(f"后台注册MCP插件成功: {plugin_name}")
|
||||
else:
|
||||
logger.warning(f"后台注册MCP插件失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"后台注册MCP插件异常: {plugin_name}, 错误: {e}")
|
||||
try:
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status="error", last_error=str(e))
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
except Exception as db_error:
|
||||
logger.error(f"更新插件状态失败: {db_error}")
|
||||
|
||||
|
||||
async def _unregister_plugin_safe(user_id: str, plugin_name: str):
|
||||
"""安全地在后台注销MCP插件"""
|
||||
try:
|
||||
await mcp_client.unregister(user_id, plugin_name)
|
||||
logger.info(f"后台注销MCP插件成功: {plugin_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"后台注销MCP插件出错: {plugin_name}, 错误: {e}")
|
||||
|
||||
|
||||
async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
"""
|
||||
将插件注册到统一门面
|
||||
@@ -228,30 +304,31 @@ async def create_plugin_simple(
|
||||
# 更新字段
|
||||
for key, value in plugin_data.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
|
||||
# 设置为pending状态,等待后台连接
|
||||
if plugin_data.get("enabled"):
|
||||
existing.status = "pending"
|
||||
|
||||
plugin = existing
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库完成后进行MCP操作
|
||||
|
||||
# 后台执行MCP操作(不阻塞请求)
|
||||
if old_enabled:
|
||||
try:
|
||||
await mcp_client.unregister(user.user_id, old_plugin_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"注销旧插件出错: {e}")
|
||||
|
||||
# 注销旧插件(使用create_task在后台执行)
|
||||
asyncio.create_task(_unregister_plugin_safe(user.user_id, old_plugin_name))
|
||||
|
||||
if plugin.enabled:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
# 后台注册新插件
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
plugin_type=plugin.plugin_type,
|
||||
server_url=plugin.server_url,
|
||||
headers=plugin.headers,
|
||||
config=plugin.config
|
||||
))
|
||||
|
||||
logger.info(f"用户 {user.user_id} 更新插件: {plugin_name}")
|
||||
else:
|
||||
# 创建新插件
|
||||
@@ -259,24 +336,26 @@ async def create_plugin_simple(
|
||||
user_id=user.user_id,
|
||||
**plugin_data
|
||||
)
|
||||
|
||||
|
||||
# 设置为pending状态,等待后台连接
|
||||
if plugin_data.get("enabled"):
|
||||
plugin.status = "pending"
|
||||
|
||||
db.add(plugin)
|
||||
await db.commit()
|
||||
await db.refresh(plugin)
|
||||
|
||||
# 数据库完成后进行MCP操作
|
||||
|
||||
# 后台执行MCP注册(不阻塞请求)
|
||||
if plugin.enabled:
|
||||
try:
|
||||
success = await _register_plugin_to_facade(plugin, user.user_id)
|
||||
plugin.status = "active" if success else "error"
|
||||
plugin.last_error = None if success else "加载失败"
|
||||
await db.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件失败: {e}")
|
||||
plugin.status = "error"
|
||||
plugin.last_error = str(e)
|
||||
await db.commit()
|
||||
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
plugin_type=plugin.plugin_type,
|
||||
server_url=plugin.server_url,
|
||||
headers=plugin.headers,
|
||||
config=plugin.config
|
||||
))
|
||||
|
||||
logger.info(f"用户 {user.user_id} 通过简化配置创建插件: {plugin_name}")
|
||||
|
||||
return plugin
|
||||
@@ -465,10 +544,11 @@ async def test_plugin(
|
||||
):
|
||||
"""
|
||||
测试插件连接并调用工具验证功能
|
||||
|
||||
使用MCPTestService进行测试
|
||||
|
||||
使用MCPTestService进行测试。
|
||||
如果插件会话尚未建立,会先在后台注册,需要再次调用测试。
|
||||
"""
|
||||
|
||||
|
||||
result = await db.execute(
|
||||
select(MCPPlugin).where(
|
||||
MCPPlugin.id == plugin_id,
|
||||
@@ -476,10 +556,10 @@ async def test_plugin(
|
||||
)
|
||||
)
|
||||
plugin = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if not plugin:
|
||||
raise HTTPException(status_code=404, detail="插件不存在")
|
||||
|
||||
|
||||
if not plugin.enabled:
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
@@ -487,11 +567,45 @@ async def test_plugin(
|
||||
error="请先启用插件",
|
||||
suggestions=["点击开关按钮启用插件"]
|
||||
)
|
||||
|
||||
# 使用测试服务
|
||||
|
||||
# 检查会话是否已注册
|
||||
is_registered = mcp_client.is_registered(user.user_id, plugin.plugin_name)
|
||||
session_status = mcp_client.get_session_status(user.user_id, plugin.plugin_name)
|
||||
|
||||
if not is_registered:
|
||||
# 会话不存在或状态异常,需要在后台注册
|
||||
logger.info(f"插件 {plugin.plugin_name} 会话不存在(状态: {session_status}),启动后台注册")
|
||||
|
||||
# 更新数据库状态为pending
|
||||
plugin.status = "pending"
|
||||
plugin.last_error = None
|
||||
await db.commit()
|
||||
|
||||
# 在后台注册插件
|
||||
asyncio.create_task(_register_plugin_background(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
plugin_type=plugin.plugin_type,
|
||||
server_url=plugin.server_url,
|
||||
headers=plugin.headers,
|
||||
config=plugin.config
|
||||
))
|
||||
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="正在建立连接...",
|
||||
error="插件会话正在初始化,请稍后重试",
|
||||
suggestions=[
|
||||
"插件正在连接MCP服务器",
|
||||
"请等待2-3秒后再次点击测试",
|
||||
"如果持续失败,请检查服务器地址是否正确"
|
||||
]
|
||||
)
|
||||
|
||||
# 会话已存在,直接执行测试
|
||||
try:
|
||||
test_result = await mcp_test_service.test_plugin_with_ai(plugin, user, db)
|
||||
|
||||
|
||||
# 更新插件状态
|
||||
if test_result.success:
|
||||
plugin.status = "active"
|
||||
@@ -499,12 +613,12 @@ async def test_plugin(
|
||||
else:
|
||||
plugin.status = "error"
|
||||
plugin.last_error = test_result.error
|
||||
|
||||
|
||||
plugin.last_test_at = datetime.now()
|
||||
await db.commit()
|
||||
|
||||
|
||||
return test_result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
|
||||
plugin.status = "error"
|
||||
|
||||
@@ -307,18 +307,18 @@ class MCPClientFacade:
|
||||
是否注册成功
|
||||
"""
|
||||
self._ensure_background_tasks()
|
||||
|
||||
|
||||
key = self._get_key(config.user_id, config.plugin_name)
|
||||
user_lock = await self._get_user_lock(config.user_id)
|
||||
|
||||
|
||||
async with user_lock:
|
||||
# 如果已存在,先关闭
|
||||
if key in self._sessions:
|
||||
await self._close_session_unsafe(key)
|
||||
|
||||
|
||||
try:
|
||||
logger.info(f"🔗 连接MCP服务器: {config.plugin_name} -> {config.url} (类型: {config.plugin_type})")
|
||||
|
||||
|
||||
# 根据类型选择客户端
|
||||
if config.plugin_type == "sse":
|
||||
# SSE 客户端 - 返回 2 个值
|
||||
@@ -357,9 +357,19 @@ class MCPClientFacade:
|
||||
logger.info(f"✅ MCP会话建立成功: {key}")
|
||||
await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "active", "连接成功")
|
||||
return True
|
||||
|
||||
|
||||
except ExceptionGroup as eg:
|
||||
# 处理 TaskGroup 的异常组,提取详细错误信息
|
||||
error_details = []
|
||||
for exc in eg.exceptions:
|
||||
error_details.append(f"{type(exc).__name__}: {exc}")
|
||||
error_msg = "; ".join(error_details)
|
||||
logger.error(f"❌ MCP连接失败 {key}: TaskGroup异常 - {error_msg}")
|
||||
await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "error", error_msg)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MCP连接失败 {key}: {e}")
|
||||
logger.error(f"❌ MCP连接失败 {key}: {type(e).__name__}: {e}")
|
||||
await self._emit_status_change(config.user_id, config.plugin_name, "inactive", "error", str(e))
|
||||
return False
|
||||
|
||||
@@ -428,7 +438,37 @@ class MCPClientFacade:
|
||||
info.last_access = time.time()
|
||||
info.request_count += 1
|
||||
return info.session
|
||||
|
||||
|
||||
def is_registered(self, user_id: str, plugin_name: str) -> bool:
|
||||
"""
|
||||
检查插件是否已注册(同步方法,仅检查内存状态)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
是否已注册且状态正常
|
||||
"""
|
||||
key = self._get_key(user_id, plugin_name)
|
||||
info = self._sessions.get(key)
|
||||
return info is not None and info.status != "error"
|
||||
|
||||
def get_session_status(self, user_id: str, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取会话状态(同步方法)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
会话状态,如果不存在返回 None
|
||||
"""
|
||||
key = self._get_key(user_id, plugin_name)
|
||||
info = self._sessions.get(key)
|
||||
return info.status if info else None
|
||||
|
||||
async def ensure_registered(
|
||||
self,
|
||||
user_id: str,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
将内存中的会话状态变更同步到数据库,确保状态一致性。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
@@ -12,22 +13,42 @@ from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 状态同步队列
|
||||
_sync_queue: asyncio.Queue = None
|
||||
_sync_task: asyncio.Task = None
|
||||
|
||||
async def sync_status_to_db(event: Dict[str, Any]):
|
||||
"""
|
||||
状态变更回调 - 同步到数据库
|
||||
"""
|
||||
|
||||
async def _sync_worker():
|
||||
"""后台状态同步工作线程"""
|
||||
global _sync_queue
|
||||
|
||||
while True:
|
||||
try:
|
||||
event = await _sync_queue.get()
|
||||
if event is None: # 停止信号
|
||||
break
|
||||
|
||||
await _do_sync_status(event)
|
||||
_sync_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"状态同步工作线程异常: {e}")
|
||||
|
||||
|
||||
async def _do_sync_status(event: Dict[str, Any]):
|
||||
"""实际执行状态同步"""
|
||||
user_id = event["user_id"]
|
||||
plugin_name = event["plugin_name"]
|
||||
new_status = event["new_status"]
|
||||
reason = event.get("reason", "")
|
||||
|
||||
|
||||
try:
|
||||
from app.database import get_engine
|
||||
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
@@ -36,15 +57,38 @@ async def sync_status_to_db(event: Dict[str, Any]):
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
|
||||
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
|
||||
|
||||
|
||||
async def sync_status_to_db(event: Dict[str, Any]):
|
||||
"""
|
||||
状态变更回调 - 将事件加入队列异步同步到数据库
|
||||
|
||||
使用队列异步处理,避免在请求处理过程中阻塞或产生数据库连接冲突
|
||||
"""
|
||||
global _sync_queue, _sync_task
|
||||
|
||||
# 延迟初始化队列和工作线程
|
||||
if _sync_queue is None:
|
||||
_sync_queue = asyncio.Queue()
|
||||
|
||||
if _sync_task is None or _sync_task.done():
|
||||
_sync_task = asyncio.create_task(_sync_worker())
|
||||
logger.info("✅ MCP状态同步工作线程已启动")
|
||||
|
||||
# 将事件加入队列(非阻塞)
|
||||
try:
|
||||
_sync_queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(f"状态同步队列已满,丢弃事件: {event['plugin_name']}")
|
||||
|
||||
|
||||
def register_status_sync():
|
||||
"""注册状态同步回调到MCP客户端"""
|
||||
from app.mcp import mcp_client
|
||||
mcp_client.register_status_callback(sync_status_to_db)
|
||||
logger.info("✅ MCP状态同步服务已注册")
|
||||
logger.info("✅ MCP状态同步服务已注册")
|
||||
|
||||
@@ -25,32 +25,20 @@ logger = get_logger(__name__)
|
||||
|
||||
class MCPTestService:
|
||||
"""MCP插件测试服务(使用统一门面重构)"""
|
||||
|
||||
async def _ensure_plugin_registered(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
user_id: str
|
||||
) -> bool:
|
||||
|
||||
def _check_plugin_registered(self, plugin: MCPPlugin, user_id: str) -> bool:
|
||||
"""
|
||||
确保插件已注册到统一门面
|
||||
|
||||
检查插件是否已注册(同步方法,不触发新的连接)
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
是否已注册
|
||||
"""
|
||||
if plugin.plugin_type in ("http", "streamable_http", "sse") and plugin.server_url:
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
return False
|
||||
|
||||
return mcp_client.is_registered(user_id, plugin.plugin_name)
|
||||
|
||||
async def test_plugin_connection(
|
||||
self,
|
||||
plugin: MCPPlugin,
|
||||
@@ -58,27 +46,28 @@ class MCPTestService:
|
||||
) -> MCPTestResult:
|
||||
"""
|
||||
简单连接测试
|
||||
|
||||
|
||||
注意:调用此方法前,需要确保插件已通过后台任务注册。
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
user_id: 用户ID
|
||||
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# 确保插件已注册
|
||||
registered = await self._ensure_plugin_registered(plugin, user_id)
|
||||
if not registered:
|
||||
# 检查插件是否已注册(不触发新连接)
|
||||
if not self._check_plugin_registered(plugin, user_id):
|
||||
return MCPTestResult(
|
||||
success=False,
|
||||
message="插件注册失败",
|
||||
error="无法创建MCP客户端",
|
||||
suggestions=["请检查插件配置", "请确认服务器URL正确"]
|
||||
message="插件未注册",
|
||||
error="MCP会话不存在,请先启用插件",
|
||||
suggestions=["请先启用插件", "如果已启用,请稍等片刻后重试"]
|
||||
)
|
||||
|
||||
|
||||
# 使用统一门面测试连接
|
||||
test_result = await mcp_client.test_connection(user_id, plugin.plugin_name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user