69e3e46c96
- 静默检查MCP工具可用性,支持提示词注入调用mcp - 集成UniversalMCPAdapter,支持自动API能力检测和智能降级 - 新增MCP适配器配置项,增强系统兼容性和健壮性
353 lines
12 KiB
Python
353 lines
12 KiB
Python
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
|
||
|
||
import time
|
||
import asyncio
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime, timedelta
|
||
from dataclasses import dataclass
|
||
|
||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
|
||
from app.mcp.adapters.function_calling import FunctionCallingAdapter
|
||
from app.logger import get_logger
|
||
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class APICapability:
|
||
"""API能力检测结果"""
|
||
supports_function_calling: bool
|
||
tested_at: datetime
|
||
test_duration_ms: float
|
||
error_message: Optional[str] = None
|
||
|
||
|
||
class UniversalMCPAdapter:
|
||
"""
|
||
通用MCP适配器管理器
|
||
|
||
功能:
|
||
1. 自动检测API是否支持Function Calling
|
||
2. 缓存检测结果
|
||
3. 自动降级策略:FC失败时切换到提示词注入
|
||
4. 提供统一接口
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
cache_ttl_hours: int = 24,
|
||
enable_auto_fallback: bool = True
|
||
):
|
||
"""
|
||
初始化通用适配器
|
||
|
||
Args:
|
||
cache_ttl_hours: 能力检测缓存时长(小时)
|
||
enable_auto_fallback: 是否启用自动降级
|
||
"""
|
||
# 适配器实例
|
||
self.adapters = {
|
||
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
|
||
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
|
||
}
|
||
|
||
# API能力缓存: {api_identifier: APICapability}
|
||
self._capability_cache: Dict[str, APICapability] = {}
|
||
self._cache_ttl = timedelta(hours=cache_ttl_hours)
|
||
self._cache_lock = asyncio.Lock()
|
||
|
||
# 配置
|
||
self._enable_auto_fallback = enable_auto_fallback
|
||
|
||
logger.info(
|
||
f"✅ UniversalMCPAdapter初始化完成 "
|
||
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
|
||
)
|
||
|
||
async def get_adapter(
|
||
self,
|
||
api_identifier: str,
|
||
test_function: Optional[callable] = None
|
||
) -> BaseMCPAdapter:
|
||
"""
|
||
获取适合当前API的适配器
|
||
|
||
Args:
|
||
api_identifier: API标识符(如"openai_official", "azure_openai"等)
|
||
test_function: 可选的测试函数,用于检测API能力
|
||
|
||
Returns:
|
||
最适合的适配器实例
|
||
"""
|
||
|
||
# 检查缓存
|
||
capability = await self._get_cached_capability(api_identifier)
|
||
|
||
if capability is None and test_function:
|
||
# 缓存未命中,执行检测
|
||
capability = await self._detect_capability(api_identifier, test_function)
|
||
|
||
# 选择适配器
|
||
if capability and capability.supports_function_calling:
|
||
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
|
||
return self.adapters[AdapterType.FUNCTION_CALLING]
|
||
else:
|
||
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
|
||
return self.adapters[AdapterType.PROMPT_INJECTION]
|
||
|
||
async def _get_cached_capability(
|
||
self,
|
||
api_identifier: str
|
||
) -> Optional[APICapability]:
|
||
"""获取缓存的能力检测结果"""
|
||
|
||
async with self._cache_lock:
|
||
if api_identifier not in self._capability_cache:
|
||
return None
|
||
|
||
capability = self._capability_cache[api_identifier]
|
||
|
||
# 检查是否过期
|
||
if datetime.now() - capability.tested_at > self._cache_ttl:
|
||
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
|
||
del self._capability_cache[api_identifier]
|
||
return None
|
||
|
||
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
|
||
return capability
|
||
|
||
async def _detect_capability(
|
||
self,
|
||
api_identifier: str,
|
||
test_function: callable
|
||
) -> APICapability:
|
||
"""
|
||
检测API能力
|
||
|
||
Args:
|
||
api_identifier: API标识符
|
||
test_function: 测试函数,应该尝试使用Function Calling
|
||
|
||
Returns:
|
||
能力检测结果
|
||
"""
|
||
|
||
logger.info(f"🔍 开始检测API能力: {api_identifier}")
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 调用测试函数
|
||
result = await test_function()
|
||
|
||
# 判断是否成功
|
||
supports_fc = self._is_function_calling_response(result)
|
||
|
||
duration_ms = (time.time() - start_time) * 1000
|
||
|
||
capability = APICapability(
|
||
supports_function_calling=supports_fc,
|
||
tested_at=datetime.now(),
|
||
test_duration_ms=duration_ms
|
||
)
|
||
|
||
# 缓存结果
|
||
async with self._cache_lock:
|
||
self._capability_cache[api_identifier] = capability
|
||
|
||
status = "✅ 支持" if supports_fc else "❌ 不支持"
|
||
logger.info(
|
||
f"{status} Function Calling: {api_identifier} "
|
||
f"(耗时: {duration_ms:.2f}ms)"
|
||
)
|
||
|
||
return capability
|
||
|
||
except Exception as e:
|
||
duration_ms = (time.time() - start_time) * 1000
|
||
|
||
logger.warning(
|
||
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
|
||
f"将使用提示词注入模式"
|
||
)
|
||
|
||
capability = APICapability(
|
||
supports_function_calling=False,
|
||
tested_at=datetime.now(),
|
||
test_duration_ms=duration_ms,
|
||
error_message=str(e)
|
||
)
|
||
|
||
# 缓存失败结果(避免重复测试)
|
||
async with self._cache_lock:
|
||
self._capability_cache[api_identifier] = capability
|
||
|
||
return capability
|
||
|
||
def _is_function_calling_response(self, response: Any) -> bool:
|
||
"""
|
||
判断响应是否是Function Calling格式
|
||
|
||
Args:
|
||
response: API响应
|
||
|
||
Returns:
|
||
是否支持Function Calling
|
||
"""
|
||
|
||
try:
|
||
# 检查字典格式
|
||
if isinstance(response, dict):
|
||
message = response.get("choices", [{}])[0].get("message", {})
|
||
return "tool_calls" in message or "function_call" in message
|
||
|
||
# 检查对象格式(OpenAI SDK)
|
||
if hasattr(response, "choices"):
|
||
message = response.choices[0].message
|
||
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
|
||
|
||
return False
|
||
|
||
except Exception:
|
||
return False
|
||
|
||
async def call_with_fallback(
|
||
self,
|
||
api_identifier: str,
|
||
tools: List[Dict[str, Any]],
|
||
user_message: str,
|
||
call_function: callable,
|
||
test_function: Optional[callable] = None
|
||
) -> ToolCallResult:
|
||
"""
|
||
带降级策略的工具调用
|
||
|
||
Args:
|
||
api_identifier: API标识符
|
||
tools: MCP工具列表
|
||
user_message: 用户消息
|
||
call_function: 实际调用API的函数
|
||
test_function: 可选的测试函数
|
||
|
||
Returns:
|
||
工具调用结果
|
||
"""
|
||
|
||
# 获取适配器
|
||
adapter = await self.get_adapter(api_identifier, test_function)
|
||
|
||
# 首次尝试
|
||
try:
|
||
if adapter.supports_native_tools():
|
||
# Function Calling模式
|
||
logger.info("🚀 尝试使用Function Calling模式")
|
||
result = await self._try_function_calling(
|
||
tools, user_message, call_function, adapter
|
||
)
|
||
else:
|
||
# 提示词注入模式
|
||
logger.info("🚀 使用提示词注入模式")
|
||
result = await self._try_prompt_injection(
|
||
tools, user_message, call_function, adapter
|
||
)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 工具调用失败: {e}")
|
||
|
||
# 自动降级
|
||
if self._enable_auto_fallback and adapter.supports_native_tools():
|
||
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
|
||
|
||
# 更新缓存,标记为不支持
|
||
async with self._cache_lock:
|
||
self._capability_cache[api_identifier] = APICapability(
|
||
supports_function_calling=False,
|
||
tested_at=datetime.now(),
|
||
test_duration_ms=0,
|
||
error_message=str(e)
|
||
)
|
||
|
||
# 使用提示词注入重试
|
||
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
|
||
return await self._try_prompt_injection(
|
||
tools, user_message, call_function, fallback_adapter
|
||
)
|
||
|
||
raise
|
||
|
||
async def _try_function_calling(
|
||
self,
|
||
tools: List[Dict[str, Any]],
|
||
user_message: str,
|
||
call_function: callable,
|
||
adapter: FunctionCallingAdapter
|
||
) -> ToolCallResult:
|
||
"""尝试Function Calling模式"""
|
||
|
||
# Function Calling不需要修改提示词
|
||
response = await call_function(
|
||
message=user_message,
|
||
tools_param=tools,
|
||
tool_choice_param="auto"
|
||
)
|
||
|
||
return adapter.parse_tool_calls(response)
|
||
|
||
async def _try_prompt_injection(
|
||
self,
|
||
tools: List[Dict[str, Any]],
|
||
user_message: str,
|
||
call_function: callable,
|
||
adapter: PromptInjectionAdapter
|
||
) -> ToolCallResult:
|
||
"""尝试提示词注入模式"""
|
||
|
||
# 注入工具到提示词
|
||
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
|
||
|
||
# 调用API(不传tools参数)
|
||
response = await call_function(
|
||
message=enhanced_prompt,
|
||
tools_param=None,
|
||
tool_choice_param=None
|
||
)
|
||
|
||
# 从文本响应中解析工具调用
|
||
return adapter.parse_tool_calls(response)
|
||
|
||
def clear_cache(self, api_identifier: Optional[str] = None):
|
||
"""
|
||
清理能力缓存
|
||
|
||
Args:
|
||
api_identifier: 可选,只清理特定API的缓存
|
||
"""
|
||
if api_identifier:
|
||
if api_identifier in self._capability_cache:
|
||
del self._capability_cache[api_identifier]
|
||
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
|
||
else:
|
||
self._capability_cache.clear()
|
||
logger.info("🧹 已清理所有API能力缓存")
|
||
|
||
def get_cache_stats(self) -> Dict[str, Any]:
|
||
"""获取缓存统计信息"""
|
||
return {
|
||
"total_cached": len(self._capability_cache),
|
||
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
|
||
"cached_apis": [
|
||
{
|
||
"api_identifier": api_id,
|
||
"supports_fc": cap.supports_function_calling,
|
||
"tested_at": cap.tested_at.isoformat(),
|
||
"test_duration_ms": cap.test_duration_ms
|
||
}
|
||
for api_id, cap in self._capability_cache.items()
|
||
]
|
||
}
|
||
|
||
|
||
# 全局单例
|
||
universal_mcp_adapter = UniversalMCPAdapter() |