235 lines
7.5 KiB
Python
235 lines
7.5 KiB
Python
|
|
"""MCP工具加载器 - 统一的工具获取入口
|
|||
|
|
|
|||
|
|
在AI请求之前,自动检查用户MCP配置并加载可用工具。
|
|||
|
|
"""
|
|||
|
|
from typing import Optional, List, Dict, Any
|
|||
|
|
from datetime import datetime, timedelta
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
|
|||
|
|
from sqlalchemy import select
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
|
|
|||
|
|
from app.logger import get_logger
|
|||
|
|
from app.models.mcp_plugin import MCPPlugin
|
|||
|
|
from app.mcp import mcp_client
|
|||
|
|
|
|||
|
|
logger = get_logger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class UserToolsCache:
|
|||
|
|
"""用户工具缓存条目"""
|
|||
|
|
tools: Optional[List[Dict[str, Any]]]
|
|||
|
|
expire_time: datetime
|
|||
|
|
hit_count: int = 0
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MCPToolsLoader:
|
|||
|
|
"""
|
|||
|
|
MCP工具加载器
|
|||
|
|
|
|||
|
|
负责:
|
|||
|
|
1. 检查用户是否配置并启用了MCP插件
|
|||
|
|
2. 从各个启用的插件加载工具列表
|
|||
|
|
3. 将工具转换为OpenAI Function Calling格式
|
|||
|
|
4. 缓存结果以提升性能
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
_instance: Optional['MCPToolsLoader'] = None
|
|||
|
|
|
|||
|
|
def __new__(cls):
|
|||
|
|
"""单例模式"""
|
|||
|
|
if cls._instance is None:
|
|||
|
|
cls._instance = super().__new__(cls)
|
|||
|
|
cls._instance._initialized = False
|
|||
|
|
return cls._instance
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
if self._initialized:
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
# 用户工具缓存: user_id -> UserToolsCache
|
|||
|
|
self._cache: Dict[str, UserToolsCache] = {}
|
|||
|
|
|
|||
|
|
# 缓存TTL(5分钟)
|
|||
|
|
self._cache_ttl = timedelta(minutes=5)
|
|||
|
|
|
|||
|
|
self._initialized = True
|
|||
|
|
logger.info("✅ MCPToolsLoader 初始化完成")
|
|||
|
|
|
|||
|
|
async def has_enabled_plugins(
|
|||
|
|
self,
|
|||
|
|
user_id: str,
|
|||
|
|
db_session: AsyncSession
|
|||
|
|
) -> bool:
|
|||
|
|
"""
|
|||
|
|
检查用户是否有启用的MCP插件
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
user_id: 用户ID
|
|||
|
|
db_session: 数据库会话
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
是否有启用的插件
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
query = select(MCPPlugin.id).where(
|
|||
|
|
MCPPlugin.user_id == user_id,
|
|||
|
|
MCPPlugin.enabled == True,
|
|||
|
|
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
|||
|
|
).limit(1)
|
|||
|
|
|
|||
|
|
result = await db_session.execute(query)
|
|||
|
|
return result.scalar() is not None
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"检查用户MCP插件失败: {e}")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
async def get_user_tools(
|
|||
|
|
self,
|
|||
|
|
user_id: str,
|
|||
|
|
db_session: AsyncSession,
|
|||
|
|
use_cache: bool = True,
|
|||
|
|
force_refresh: bool = False
|
|||
|
|
) -> Optional[List[Dict[str, Any]]]:
|
|||
|
|
"""
|
|||
|
|
获取用户的MCP工具列表(OpenAI格式)
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
user_id: 用户ID
|
|||
|
|
db_session: 数据库会话
|
|||
|
|
use_cache: 是否使用缓存
|
|||
|
|
force_refresh: 是否强制刷新
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
- None: 用户未配置或未启用任何MCP插件
|
|||
|
|
- []: 有配置但没有可用工具
|
|||
|
|
- List[Dict]: OpenAI Function Calling格式的工具列表
|
|||
|
|
"""
|
|||
|
|
now = datetime.now()
|
|||
|
|
|
|||
|
|
# 检查缓存
|
|||
|
|
if use_cache and not force_refresh and user_id in self._cache:
|
|||
|
|
cache_entry = self._cache[user_id]
|
|||
|
|
if now < cache_entry.expire_time:
|
|||
|
|
cache_entry.hit_count += 1
|
|||
|
|
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
|
|||
|
|
return cache_entry.tools
|
|||
|
|
else:
|
|||
|
|
del self._cache[user_id]
|
|||
|
|
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
|
|||
|
|
|
|||
|
|
# 从数据库加载
|
|||
|
|
try:
|
|||
|
|
tools = await self._load_user_tools(user_id, db_session)
|
|||
|
|
|
|||
|
|
# 更新缓存
|
|||
|
|
self._cache[user_id] = UserToolsCache(
|
|||
|
|
tools=tools,
|
|||
|
|
expire_time=now + self._cache_ttl
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if tools:
|
|||
|
|
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
|
|||
|
|
else:
|
|||
|
|
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
|
|||
|
|
|
|||
|
|
return tools
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"❌ 加载用户MCP工具失败: {e}")
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
async def _load_user_tools(
|
|||
|
|
self,
|
|||
|
|
user_id: str,
|
|||
|
|
db_session: AsyncSession
|
|||
|
|
) -> Optional[List[Dict[str, Any]]]:
|
|||
|
|
"""
|
|||
|
|
从数据库加载用户启用的MCP插件并获取工具
|
|||
|
|
"""
|
|||
|
|
# 查询启用的插件
|
|||
|
|
query = select(MCPPlugin).where(
|
|||
|
|
MCPPlugin.user_id == user_id,
|
|||
|
|
MCPPlugin.enabled == True,
|
|||
|
|
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
|||
|
|
).order_by(MCPPlugin.sort_order)
|
|||
|
|
|
|||
|
|
result = await db_session.execute(query)
|
|||
|
|
plugins = result.scalars().all()
|
|||
|
|
|
|||
|
|
if not plugins:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
all_tools = []
|
|||
|
|
|
|||
|
|
for plugin in plugins:
|
|||
|
|
try:
|
|||
|
|
# 确定插件类型
|
|||
|
|
plugin_type = plugin.plugin_type
|
|||
|
|
if plugin_type == "http":
|
|||
|
|
plugin_type = "streamable_http" # 默认使用streamable_http
|
|||
|
|
|
|||
|
|
# 确保插件已注册到MCP客户端
|
|||
|
|
await mcp_client.ensure_registered(
|
|||
|
|
user_id=user_id,
|
|||
|
|
plugin_name=plugin.plugin_name,
|
|||
|
|
url=plugin.server_url,
|
|||
|
|
plugin_type=plugin_type,
|
|||
|
|
headers=plugin.headers
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# 获取工具列表
|
|||
|
|
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
|
|||
|
|
|
|||
|
|
# 转换为OpenAI格式
|
|||
|
|
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
|
|||
|
|
all_tools.extend(formatted)
|
|||
|
|
|
|||
|
|
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
return all_tools if all_tools else None
|
|||
|
|
|
|||
|
|
def invalidate_cache(self, user_id: Optional[str] = None):
|
|||
|
|
"""
|
|||
|
|
使缓存失效
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
user_id: 用户ID,为None时清空所有缓存
|
|||
|
|
"""
|
|||
|
|
if user_id:
|
|||
|
|
if user_id in self._cache:
|
|||
|
|
del self._cache[user_id]
|
|||
|
|
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
|
|||
|
|
else:
|
|||
|
|
count = len(self._cache)
|
|||
|
|
self._cache.clear()
|
|||
|
|
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
|
|||
|
|
|
|||
|
|
def get_cache_stats(self) -> Dict[str, Any]:
|
|||
|
|
"""获取缓存统计"""
|
|||
|
|
now = datetime.now()
|
|||
|
|
return {
|
|||
|
|
"total_entries": len(self._cache),
|
|||
|
|
"total_hits": sum(e.hit_count for e in self._cache.values()),
|
|||
|
|
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
|||
|
|
"entries": [
|
|||
|
|
{
|
|||
|
|
"user_id": uid,
|
|||
|
|
"tools_count": len(e.tools) if e.tools else 0,
|
|||
|
|
"hit_count": e.hit_count,
|
|||
|
|
"expired": now >= e.expire_time,
|
|||
|
|
"expire_time": e.expire_time.isoformat()
|
|||
|
|
}
|
|||
|
|
for uid, e in self._cache.items()
|
|||
|
|
]
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# 全局单例
|
|||
|
|
mcp_tools_loader = MCPToolsLoader()
|