Files
MuMuAINovel/backend/app/mcp/adapters/universal.py
T
xiamuceer 69e3e46c96 feat: 优化MCP工具调用体验并集成通用适配器
- 静默检查MCP工具可用性,支持提示词注入调用mcp
- 集成UniversalMCPAdapter,支持自动API能力检测和智能降级
- 新增MCP适配器配置项,增强系统兼容性和健壮性
2025-11-24 11:30:27 +08:00

353 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
import time
import asyncio
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
from app.mcp.adapters.function_calling import FunctionCallingAdapter
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class APICapability:
"""API能力检测结果"""
supports_function_calling: bool
tested_at: datetime
test_duration_ms: float
error_message: Optional[str] = None
class UniversalMCPAdapter:
"""
通用MCP适配器管理器
功能:
1. 自动检测API是否支持Function Calling
2. 缓存检测结果
3. 自动降级策略:FC失败时切换到提示词注入
4. 提供统一接口
"""
def __init__(
self,
cache_ttl_hours: int = 24,
enable_auto_fallback: bool = True
):
"""
初始化通用适配器
Args:
cache_ttl_hours: 能力检测缓存时长(小时)
enable_auto_fallback: 是否启用自动降级
"""
# 适配器实例
self.adapters = {
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
}
# API能力缓存: {api_identifier: APICapability}
self._capability_cache: Dict[str, APICapability] = {}
self._cache_ttl = timedelta(hours=cache_ttl_hours)
self._cache_lock = asyncio.Lock()
# 配置
self._enable_auto_fallback = enable_auto_fallback
logger.info(
f"✅ UniversalMCPAdapter初始化完成 "
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
)
async def get_adapter(
self,
api_identifier: str,
test_function: Optional[callable] = None
) -> BaseMCPAdapter:
"""
获取适合当前API的适配器
Args:
api_identifier: API标识符(如"openai_official", "azure_openai"等)
test_function: 可选的测试函数,用于检测API能力
Returns:
最适合的适配器实例
"""
# 检查缓存
capability = await self._get_cached_capability(api_identifier)
if capability is None and test_function:
# 缓存未命中,执行检测
capability = await self._detect_capability(api_identifier, test_function)
# 选择适配器
if capability and capability.supports_function_calling:
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
return self.adapters[AdapterType.FUNCTION_CALLING]
else:
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
return self.adapters[AdapterType.PROMPT_INJECTION]
async def _get_cached_capability(
self,
api_identifier: str
) -> Optional[APICapability]:
"""获取缓存的能力检测结果"""
async with self._cache_lock:
if api_identifier not in self._capability_cache:
return None
capability = self._capability_cache[api_identifier]
# 检查是否过期
if datetime.now() - capability.tested_at > self._cache_ttl:
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
del self._capability_cache[api_identifier]
return None
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
return capability
async def _detect_capability(
self,
api_identifier: str,
test_function: callable
) -> APICapability:
"""
检测API能力
Args:
api_identifier: API标识符
test_function: 测试函数,应该尝试使用Function Calling
Returns:
能力检测结果
"""
logger.info(f"🔍 开始检测API能力: {api_identifier}")
start_time = time.time()
try:
# 调用测试函数
result = await test_function()
# 判断是否成功
supports_fc = self._is_function_calling_response(result)
duration_ms = (time.time() - start_time) * 1000
capability = APICapability(
supports_function_calling=supports_fc,
tested_at=datetime.now(),
test_duration_ms=duration_ms
)
# 缓存结果
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
status = "✅ 支持" if supports_fc else "❌ 不支持"
logger.info(
f"{status} Function Calling: {api_identifier} "
f"(耗时: {duration_ms:.2f}ms)"
)
return capability
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.warning(
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
f"将使用提示词注入模式"
)
capability = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=duration_ms,
error_message=str(e)
)
# 缓存失败结果(避免重复测试)
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
return capability
def _is_function_calling_response(self, response: Any) -> bool:
"""
判断响应是否是Function Calling格式
Args:
response: API响应
Returns:
是否支持Function Calling
"""
try:
# 检查字典格式
if isinstance(response, dict):
message = response.get("choices", [{}])[0].get("message", {})
return "tool_calls" in message or "function_call" in message
# 检查对象格式(OpenAI SDK
if hasattr(response, "choices"):
message = response.choices[0].message
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
return False
except Exception:
return False
async def call_with_fallback(
self,
api_identifier: str,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
test_function: Optional[callable] = None
) -> ToolCallResult:
"""
带降级策略的工具调用
Args:
api_identifier: API标识符
tools: MCP工具列表
user_message: 用户消息
call_function: 实际调用API的函数
test_function: 可选的测试函数
Returns:
工具调用结果
"""
# 获取适配器
adapter = await self.get_adapter(api_identifier, test_function)
# 首次尝试
try:
if adapter.supports_native_tools():
# Function Calling模式
logger.info("🚀 尝试使用Function Calling模式")
result = await self._try_function_calling(
tools, user_message, call_function, adapter
)
else:
# 提示词注入模式
logger.info("🚀 使用提示词注入模式")
result = await self._try_prompt_injection(
tools, user_message, call_function, adapter
)
return result
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
# 自动降级
if self._enable_auto_fallback and adapter.supports_native_tools():
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
# 更新缓存,标记为不支持
async with self._cache_lock:
self._capability_cache[api_identifier] = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=0,
error_message=str(e)
)
# 使用提示词注入重试
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
return await self._try_prompt_injection(
tools, user_message, call_function, fallback_adapter
)
raise
async def _try_function_calling(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: FunctionCallingAdapter
) -> ToolCallResult:
"""尝试Function Calling模式"""
# Function Calling不需要修改提示词
response = await call_function(
message=user_message,
tools_param=tools,
tool_choice_param="auto"
)
return adapter.parse_tool_calls(response)
async def _try_prompt_injection(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: PromptInjectionAdapter
) -> ToolCallResult:
"""尝试提示词注入模式"""
# 注入工具到提示词
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
# 调用API(不传tools参数)
response = await call_function(
message=enhanced_prompt,
tools_param=None,
tool_choice_param=None
)
# 从文本响应中解析工具调用
return adapter.parse_tool_calls(response)
def clear_cache(self, api_identifier: Optional[str] = None):
"""
清理能力缓存
Args:
api_identifier: 可选,只清理特定API的缓存
"""
if api_identifier:
if api_identifier in self._capability_cache:
del self._capability_cache[api_identifier]
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
else:
self._capability_cache.clear()
logger.info("🧹 已清理所有API能力缓存")
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
"total_cached": len(self._capability_cache),
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
"cached_apis": [
{
"api_identifier": api_id,
"supports_fc": cap.supports_function_calling,
"tested_at": cap.tested_at.isoformat(),
"test_duration_ms": cap.test_duration_ms
}
for api_id, cap in self._capability_cache.items()
]
}
# 全局单例
universal_mcp_adapter = UniversalMCPAdapter()