235 lines
7.5 KiB
Python
235 lines
7.5 KiB
Python
"""MCP工具加载器 - 统一的工具获取入口
|
||
|
||
在AI请求之前,自动检查用户MCP配置并加载可用工具。
|
||
"""
|
||
from typing import Optional, List, Dict, Any
|
||
from datetime import datetime, timedelta
|
||
from dataclasses import dataclass, field
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.logger import get_logger
|
||
from app.models.mcp_plugin import MCPPlugin
|
||
from app.mcp import mcp_client
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class UserToolsCache:
|
||
"""用户工具缓存条目"""
|
||
tools: Optional[List[Dict[str, Any]]]
|
||
expire_time: datetime
|
||
hit_count: int = 0
|
||
|
||
|
||
class MCPToolsLoader:
|
||
"""
|
||
MCP工具加载器
|
||
|
||
负责:
|
||
1. 检查用户是否配置并启用了MCP插件
|
||
2. 从各个启用的插件加载工具列表
|
||
3. 将工具转换为OpenAI Function Calling格式
|
||
4. 缓存结果以提升性能
|
||
"""
|
||
|
||
_instance: Optional['MCPToolsLoader'] = None
|
||
|
||
def __new__(cls):
|
||
"""单例模式"""
|
||
if cls._instance is None:
|
||
cls._instance = super().__new__(cls)
|
||
cls._instance._initialized = False
|
||
return cls._instance
|
||
|
||
def __init__(self):
|
||
if self._initialized:
|
||
return
|
||
|
||
# 用户工具缓存: user_id -> UserToolsCache
|
||
self._cache: Dict[str, UserToolsCache] = {}
|
||
|
||
# 缓存TTL(5分钟)
|
||
self._cache_ttl = timedelta(minutes=5)
|
||
|
||
self._initialized = True
|
||
logger.info("✅ MCPToolsLoader 初始化完成")
|
||
|
||
async def has_enabled_plugins(
|
||
self,
|
||
user_id: str,
|
||
db_session: AsyncSession
|
||
) -> bool:
|
||
"""
|
||
检查用户是否有启用的MCP插件
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
db_session: 数据库会话
|
||
|
||
Returns:
|
||
是否有启用的插件
|
||
"""
|
||
try:
|
||
query = select(MCPPlugin.id).where(
|
||
MCPPlugin.user_id == user_id,
|
||
MCPPlugin.enabled == True,
|
||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||
).limit(1)
|
||
|
||
result = await db_session.execute(query)
|
||
return result.scalar() is not None
|
||
|
||
except Exception as e:
|
||
logger.warning(f"检查用户MCP插件失败: {e}")
|
||
return False
|
||
|
||
async def get_user_tools(
|
||
self,
|
||
user_id: str,
|
||
db_session: AsyncSession,
|
||
use_cache: bool = True,
|
||
force_refresh: bool = False
|
||
) -> Optional[List[Dict[str, Any]]]:
|
||
"""
|
||
获取用户的MCP工具列表(OpenAI格式)
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
db_session: 数据库会话
|
||
use_cache: 是否使用缓存
|
||
force_refresh: 是否强制刷新
|
||
|
||
Returns:
|
||
- None: 用户未配置或未启用任何MCP插件
|
||
- []: 有配置但没有可用工具
|
||
- List[Dict]: OpenAI Function Calling格式的工具列表
|
||
"""
|
||
now = datetime.now()
|
||
|
||
# 检查缓存
|
||
if use_cache and not force_refresh and user_id in self._cache:
|
||
cache_entry = self._cache[user_id]
|
||
if now < cache_entry.expire_time:
|
||
cache_entry.hit_count += 1
|
||
logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
|
||
return cache_entry.tools
|
||
else:
|
||
del self._cache[user_id]
|
||
logger.debug(f"⏰ 用户工具缓存过期: {user_id}")
|
||
|
||
# 从数据库加载
|
||
try:
|
||
tools = await self._load_user_tools(user_id, db_session)
|
||
|
||
# 更新缓存
|
||
self._cache[user_id] = UserToolsCache(
|
||
tools=tools,
|
||
expire_time=now + self._cache_ttl
|
||
)
|
||
|
||
if tools:
|
||
logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
|
||
else:
|
||
logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
|
||
|
||
return tools
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 加载用户MCP工具失败: {e}")
|
||
return None
|
||
|
||
async def _load_user_tools(
|
||
self,
|
||
user_id: str,
|
||
db_session: AsyncSession
|
||
) -> Optional[List[Dict[str, Any]]]:
|
||
"""
|
||
从数据库加载用户启用的MCP插件并获取工具
|
||
"""
|
||
# 查询启用的插件
|
||
query = select(MCPPlugin).where(
|
||
MCPPlugin.user_id == user_id,
|
||
MCPPlugin.enabled == True,
|
||
MCPPlugin.plugin_type.in_(["http", "streamable_http", "sse"])
|
||
).order_by(MCPPlugin.sort_order)
|
||
|
||
result = await db_session.execute(query)
|
||
plugins = result.scalars().all()
|
||
|
||
if not plugins:
|
||
return None
|
||
|
||
all_tools = []
|
||
|
||
for plugin in plugins:
|
||
try:
|
||
# 确定插件类型
|
||
plugin_type = plugin.plugin_type
|
||
if plugin_type == "http":
|
||
plugin_type = "streamable_http" # 默认使用streamable_http
|
||
|
||
# 确保插件已注册到MCP客户端
|
||
await mcp_client.ensure_registered(
|
||
user_id=user_id,
|
||
plugin_name=plugin.plugin_name,
|
||
url=plugin.server_url,
|
||
plugin_type=plugin_type,
|
||
headers=plugin.headers
|
||
)
|
||
|
||
# 获取工具列表
|
||
plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
|
||
|
||
# 转换为OpenAI格式
|
||
formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
|
||
all_tools.extend(formatted)
|
||
|
||
logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
|
||
|
||
except Exception as e:
|
||
logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
|
||
continue
|
||
|
||
return all_tools if all_tools else None
|
||
|
||
def invalidate_cache(self, user_id: Optional[str] = None):
|
||
"""
|
||
使缓存失效
|
||
|
||
Args:
|
||
user_id: 用户ID,为None时清空所有缓存
|
||
"""
|
||
if user_id:
|
||
if user_id in self._cache:
|
||
del self._cache[user_id]
|
||
logger.debug(f"🧹 清理用户工具缓存: {user_id}")
|
||
else:
|
||
count = len(self._cache)
|
||
self._cache.clear()
|
||
logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")
|
||
|
||
def get_cache_stats(self) -> Dict[str, Any]:
|
||
"""获取缓存统计"""
|
||
now = datetime.now()
|
||
return {
|
||
"total_entries": len(self._cache),
|
||
"total_hits": sum(e.hit_count for e in self._cache.values()),
|
||
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
|
||
"entries": [
|
||
{
|
||
"user_id": uid,
|
||
"tools_count": len(e.tools) if e.tools else 0,
|
||
"hit_count": e.hit_count,
|
||
"expired": now >= e.expire_time,
|
||
"expire_time": e.expire_time.isoformat()
|
||
}
|
||
for uid, e in self._cache.items()
|
||
]
|
||
}
|
||
|
||
|
||
# 全局单例
|
||
mcp_tools_loader = MCPToolsLoader() |