Files
2026-01-09 17:13:19 +08:00

235 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""MCP工具加载器 - 统一的工具获取入口
在AI请求之前,自动检查用户MCP配置并加载可用工具。
"""
from typing import Optional, List, Dict, Any
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.logger import get_logger
from app.models.mcp_plugin import MCPPlugin
from app.mcp import mcp_client
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger = get_logger(__name__)
@dataclass
class UserToolsCache:
    """Cache entry holding the tool list loaded for a single user."""

    # Cached tool list; None is a legitimate cached value.
    tools: Optional[List[Dict[str, Any]]]
    # Point in time from which the entry is considered expired.
    expire_time: datetime
    # Number of times this entry has been served from the cache.
    hit_count: int = 0
class MCPToolsLoader:
    """Load a user's MCP tools and expose them in OpenAI function-calling format.

    Responsibilities:
    1. Check whether the user has configured and enabled any MCP plugin.
    2. Load the tool list from every enabled plugin.
    3. Convert the tools to the OpenAI Function Calling format.
    4. Cache the per-user result to improve performance.

    The class is a singleton: every ``MCPToolsLoader()`` call returns the same
    instance, and the setup in ``__init__`` runs only once.
    """

    _instance: Optional['MCPToolsLoader'] = None

    # Plugin types this loader queries for; shared by both database queries.
    _SUPPORTED_PLUGIN_TYPES = ("http", "streamable_http", "sse")

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            # Read by __init__ so that initialization happens only once.
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up the cache on first construction; later calls are no-ops."""
        if self._initialized:
            return
        # Per-user tool cache: user_id -> UserToolsCache
        self._cache: Dict[str, UserToolsCache] = {}
        # Cache TTL (5 minutes)
        self._cache_ttl = timedelta(minutes=5)
        self._initialized = True
        logger.info("✅ MCPToolsLoader 初始化完成")

    async def has_enabled_plugins(
        self,
        user_id: str,
        db_session: AsyncSession
    ) -> bool:
        """Check whether the user has at least one enabled MCP plugin.

        Args:
            user_id: User ID.
            db_session: Database session.

        Returns:
            True if an enabled plugin of a supported type exists. False if
            none exists or if the query raised (the error is logged as a
            warning and not propagated).
        """
        try:
            query = select(MCPPlugin.id).where(
                MCPPlugin.user_id == user_id,
                MCPPlugin.enabled == True,
                MCPPlugin.plugin_type.in_(self._SUPPORTED_PLUGIN_TYPES)
            ).limit(1)
            result = await db_session.execute(query)
            return result.scalar() is not None
        except Exception as e:
            logger.warning(f"检查用户MCP插件失败: {e}")
            return False

    async def get_user_tools(
        self,
        user_id: str,
        db_session: AsyncSession,
        use_cache: bool = True,
        force_refresh: bool = False
    ) -> Optional[List[Dict[str, Any]]]:
        """Return the user's MCP tools in OpenAI function-calling format.

        Args:
            user_id: User ID.
            db_session: Database session.
            use_cache: Whether a cached result may be returned.
            force_refresh: When True, skip the cache lookup and reload.

        Returns:
            - None: the user has no enabled plugin, no tool could be loaded
              from any plugin, or loading raised an exception.
            - List[Dict]: the tools in OpenAI Function Calling format.

            An empty list is not produced at present, because
            ``_load_user_tools`` turns an empty result into None.
        """
        now = datetime.now()

        # Serve from the cache when allowed and the entry is still fresh.
        if use_cache and not force_refresh:
            cache_entry = self._cache.get(user_id)
            if cache_entry is not None:
                if now < cache_entry.expire_time:
                    cache_entry.hit_count += 1
                    logger.debug(f"🎯 用户工具缓存命中: {user_id} (命中次数: {cache_entry.hit_count})")
                    return cache_entry.tools
                # Expired: drop the entry and fall through to a reload.
                del self._cache[user_id]
                logger.debug(f"⏰ 用户工具缓存过期: {user_id}")

        # Load from the database / plugins.
        try:
            tools = await self._load_user_tools(user_id, db_session)
            # The result is cached even when it is None; only an exception
            # raised by the load skips the cache update.
            self._cache[user_id] = UserToolsCache(
                tools=tools,
                expire_time=now + self._cache_ttl
            )
            if tools:
                logger.info(f"🔧 用户 {user_id} 加载了 {len(tools)} 个MCP工具")
            else:
                logger.debug(f"📭 用户 {user_id} 没有可用的MCP工具")
            return tools
        except Exception as e:
            logger.error(f"❌ 加载用户MCP工具失败: {e}")
            return None

    async def _load_user_tools(
        self,
        user_id: str,
        db_session: AsyncSession
    ) -> Optional[List[Dict[str, Any]]]:
        """Load the user's enabled MCP plugins and collect their tools.

        A plugin whose registration, tool listing or formatting raises is
        logged as a warning and skipped; the remaining plugins are still
        processed.

        Returns:
            The combined tool list, or None when the user has no enabled
            plugin or no tool was collected.
        """
        # Query the enabled plugins, in their configured sort order.
        query = select(MCPPlugin).where(
            MCPPlugin.user_id == user_id,
            MCPPlugin.enabled == True,
            MCPPlugin.plugin_type.in_(self._SUPPORTED_PLUGIN_TYPES)
        ).order_by(MCPPlugin.sort_order)
        result = await db_session.execute(query)
        plugins = result.scalars().all()
        if not plugins:
            return None

        all_tools = []
        for plugin in plugins:
            try:
                # Determine the plugin type; plain "http" defaults to
                # "streamable_http".
                plugin_type = plugin.plugin_type
                if plugin_type == "http":
                    plugin_type = "streamable_http"
                # Make sure the plugin is registered with the MCP client.
                await mcp_client.ensure_registered(
                    user_id=user_id,
                    plugin_name=plugin.plugin_name,
                    url=plugin.server_url,
                    plugin_type=plugin_type,
                    headers=plugin.headers
                )
                # Fetch the plugin's tool list.
                plugin_tools = await mcp_client.get_tools(user_id, plugin.plugin_name)
                # Convert to the OpenAI format.
                formatted = mcp_client.format_tools_for_openai(plugin_tools, plugin.plugin_name)
                all_tools.extend(formatted)
                logger.debug(f"✅ 从插件 {plugin.plugin_name} 加载了 {len(formatted)} 个工具")
            except Exception as e:
                logger.warning(f"⚠️ 加载插件 {plugin.plugin_name} 工具失败: {e}")
                continue
        return all_tools if all_tools else None

    def invalidate_cache(self, user_id: Optional[str] = None):
        """Invalidate cached tools.

        Args:
            user_id: User whose entry should be dropped. Only None clears the
                whole cache; any other value (including an empty string)
                affects at most that single key.
        """
        # Compare against None explicitly: a falsy id such as "" must not be
        # treated as a request to wipe every user's cache.
        if user_id is not None:
            # Cache values are always UserToolsCache objects, so a non-None
            # pop result means an entry was actually removed.
            if self._cache.pop(user_id, None) is not None:
                logger.debug(f"🧹 清理用户工具缓存: {user_id}")
            return
        count = len(self._cache)
        self._cache.clear()
        logger.info(f"🧹 清理所有用户工具缓存 ({count}个)")

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics: totals plus one summary per cached user."""
        now = datetime.now()
        return {
            "total_entries": len(self._cache),
            "total_hits": sum(e.hit_count for e in self._cache.values()),
            "cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
            "entries": [
                {
                    "user_id": uid,
                    "tools_count": len(e.tools) if e.tools else 0,
                    "hit_count": e.hit_count,
                    "expired": now >= e.expire_time,
                    "expire_time": e.expire_time.isoformat()
                }
                for uid, e in self._cache.items()
            ]
        }
# Global singleton instance, created when this module is imported
mcp_tools_loader = MCPToolsLoader()