feat: 重构MCP功能和AI服务提供者架构
This commit is contained in:
@@ -1,4 +1,36 @@
|
||||
"""MCP插件系统"""
|
||||
from .registry import mcp_registry
|
||||
"""MCP模块 - 统一的MCP客户端管理
|
||||
|
||||
__all__ = ["mcp_registry"]
|
||||
本模块提供MCP(Model Context Protocol)客户端的统一管理接口。
|
||||
|
||||
推荐使用方式:
|
||||
from app.mcp import mcp_client, MCPPluginConfig
|
||||
|
||||
# 注册插件
|
||||
await mcp_client.register(MCPPluginConfig(
|
||||
user_id="user123",
|
||||
plugin_name="exa-search",
|
||||
url="http://localhost:8000/mcp"
|
||||
))
|
||||
|
||||
# 获取工具
|
||||
tools = await mcp_client.get_tools("user123", "exa-search")
|
||||
|
||||
# 调用工具
|
||||
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
|
||||
|
||||
# 注册状态变更回调
|
||||
from app.mcp.status_sync import register_status_sync
|
||||
register_status_sync()
|
||||
"""
|
||||
|
||||
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
|
||||
from .status_sync import register_status_sync
|
||||
|
||||
__all__ = [
|
||||
"mcp_client",
|
||||
"MCPClientFacade",
|
||||
"MCPPluginConfig",
|
||||
"MCPError",
|
||||
"PluginStatus",
|
||||
"register_status_sync",
|
||||
]
|
||||
@@ -1,14 +0,0 @@
|
||||
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
|
||||
|
||||
from .base import BaseMCPAdapter, AdapterType
|
||||
from .prompt_injection import PromptInjectionAdapter
|
||||
from .function_calling import FunctionCallingAdapter
|
||||
from .universal import UniversalMCPAdapter
|
||||
|
||||
__all__ = [
|
||||
"BaseMCPAdapter",
|
||||
"AdapterType",
|
||||
"PromptInjectionAdapter",
|
||||
"FunctionCallingAdapter",
|
||||
"UniversalMCPAdapter",
|
||||
]
|
||||
@@ -1,89 +0,0 @@
|
||||
"""MCP适配器基类"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class AdapterType(Enum):
|
||||
"""适配器类型"""
|
||||
FUNCTION_CALLING = "function_calling" # 标准Function Calling
|
||||
PROMPT_INJECTION = "prompt_injection" # 提示词注入
|
||||
REACT = "react" # ReAct模式
|
||||
XML = "xml" # XML标记
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallResult:
|
||||
"""工具调用结果"""
|
||||
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
|
||||
raw_response: str # 原始AI响应
|
||||
has_tool_calls: bool # 是否包含工具调用
|
||||
needs_continuation: bool = False # 是否需要继续对话
|
||||
|
||||
|
||||
class BaseMCPAdapter(ABC):
|
||||
"""MCP适配器基类"""
|
||||
|
||||
def __init__(self):
|
||||
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
|
||||
|
||||
@abstractmethod
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
将工具列表格式化为提示词
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
格式化后的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用
|
||||
|
||||
Args:
|
||||
ai_response: AI的原始响应
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
Args:
|
||||
original_message: 原始用户消息
|
||||
ai_response: AI响应
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
继续对话的提示词
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""是否支持原生工具调用(如Function Calling)"""
|
||||
return False
|
||||
|
||||
def get_adapter_type(self) -> AdapterType:
|
||||
"""获取适配器类型"""
|
||||
return self.adapter_type
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Function Calling适配器 - 支持原生Function Calling的API"""
|
||||
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class FunctionCallingAdapter(BaseMCPAdapter):
|
||||
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI)"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.FUNCTION_CALLING
|
||||
|
||||
def supports_native_tools(self) -> bool:
|
||||
"""支持原生工具调用"""
|
||||
return True
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
Function Calling模式下,工具通过API参数传递,不需要修改提示词
|
||||
|
||||
Returns:
|
||||
原始用户消息
|
||||
"""
|
||||
return user_message
|
||||
|
||||
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取适用于API的工具格式
|
||||
|
||||
Args:
|
||||
tools: MCP工具列表
|
||||
|
||||
Returns:
|
||||
适用于OpenAI Function Calling的工具格式
|
||||
"""
|
||||
return tools
|
||||
|
||||
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
|
||||
"""
|
||||
从AI响应中解析工具调用(Function Calling格式)
|
||||
|
||||
Args:
|
||||
ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
|
||||
|
||||
Returns:
|
||||
解析结果
|
||||
"""
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 字典格式(OpenAI API响应)
|
||||
message = ai_response.get("choices", [{}])[0].get("message", {})
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
content = message.get("content", "")
|
||||
|
||||
elif hasattr(ai_response, "choices"):
|
||||
# 对象格式(OpenAI SDK响应)
|
||||
message = ai_response.choices[0].message
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
content = getattr(message, "content", "") or ""
|
||||
|
||||
# 转换为字典格式
|
||||
if tool_calls:
|
||||
tool_calls = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in tool_calls
|
||||
]
|
||||
else:
|
||||
# 字符串格式(降级为文本响应)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
|
||||
for tc in tool_calls:
|
||||
logger.info(f" - {tc['function']['name']}")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=content or "",
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=str(ai_response),
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""
|
||||
构建包含工具结果的继续对话提示词
|
||||
|
||||
在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
|
||||
"""
|
||||
# Function Calling模式下通常通过消息历史传递工具结果
|
||||
# 这里提供一个降级方案
|
||||
results_text = "\n\n".join([
|
||||
f"工具 {r['name']} 的结果:\n{r['content']}"
|
||||
for r in tool_results
|
||||
])
|
||||
|
||||
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
|
||||
|
||||
def build_messages_with_tool_results(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建包含工具结果的消息历史(Function Calling标准格式)
|
||||
|
||||
Args:
|
||||
messages: 原始消息历史
|
||||
tool_calls: AI的工具调用
|
||||
tool_results: 工具执行结果
|
||||
|
||||
Returns:
|
||||
更新后的消息历史
|
||||
"""
|
||||
|
||||
new_messages = messages.copy()
|
||||
|
||||
# 添加助手的工具调用消息
|
||||
new_messages.append({
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls
|
||||
})
|
||||
|
||||
# 添加工具结果消息
|
||||
for result in tool_results:
|
||||
new_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result.get("tool_call_id", ""),
|
||||
"name": result.get("name", ""),
|
||||
"content": result.get("content", "")
|
||||
})
|
||||
|
||||
return new_messages
|
||||
@@ -1,274 +0,0 @@
|
||||
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
|
||||
|
||||
import re
|
||||
import json
|
||||
from typing import Dict, Any, List
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PromptInjectionAdapter(BaseMCPAdapter):
|
||||
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.adapter_type = AdapterType.PROMPT_INJECTION
|
||||
|
||||
def format_tools_for_prompt(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""将工具列表注入到提示词中"""
|
||||
|
||||
if not tools:
|
||||
return user_message
|
||||
|
||||
# 格式化工具描述
|
||||
tool_descriptions = self._format_tools_as_text(tools)
|
||||
|
||||
# 构建增强的提示词
|
||||
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
|
||||
|
||||
## 可用工具
|
||||
|
||||
{tool_descriptions}
|
||||
|
||||
## 工具使用说明
|
||||
|
||||
当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
|
||||
|
||||
<tool_calls>
|
||||
<tool_call>
|
||||
<tool_name>工具名称</tool_name>
|
||||
<arguments>
|
||||
{{
|
||||
"参数名1": "参数值1",
|
||||
"参数名2": "参数值2"
|
||||
}}
|
||||
</arguments>
|
||||
</tool_call>
|
||||
</tool_calls>
|
||||
|
||||
## 重要提示
|
||||
|
||||
1. 只有在确实需要使用工具时才调用工具
|
||||
2. 参数必须是有效的JSON格式
|
||||
3. 仔细检查参数是否符合工具的要求
|
||||
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
|
||||
5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
|
||||
|
||||
---
|
||||
|
||||
用户问题:{user_message}
|
||||
|
||||
请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
|
||||
|
||||
return enhanced_prompt
|
||||
|
||||
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
|
||||
"""将工具格式化为可读的文本描述"""
|
||||
lines = []
|
||||
|
||||
for i, tool in enumerate(tools, 1):
|
||||
func = tool.get("function", {})
|
||||
name = func.get("name", "unknown")
|
||||
description = func.get("description", "无描述")
|
||||
parameters = func.get("parameters", {})
|
||||
|
||||
lines.append(f"### {i}. {name}")
|
||||
lines.append(f"**描述**: {description}")
|
||||
lines.append("")
|
||||
|
||||
# 格式化参数信息
|
||||
if parameters and "properties" in parameters:
|
||||
lines.append("**参数**:")
|
||||
properties = parameters.get("properties", {})
|
||||
required = parameters.get("required", [])
|
||||
|
||||
for param_name, param_info in properties.items():
|
||||
param_type = param_info.get("type", "string")
|
||||
param_desc = param_info.get("description", "")
|
||||
is_required = "必填" if param_name in required else "可选"
|
||||
|
||||
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
|
||||
lines.append("")
|
||||
|
||||
# 添加示例
|
||||
if "example" in func:
|
||||
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def parse_tool_calls(self, ai_response) -> ToolCallResult:
|
||||
"""从AI响应中解析工具调用"""
|
||||
|
||||
tool_calls = []
|
||||
|
||||
try:
|
||||
# 处理不同类型的响应
|
||||
if isinstance(ai_response, dict):
|
||||
# 如果是字典,提取content字段
|
||||
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
if not ai_response:
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response="",
|
||||
has_tool_calls=False
|
||||
)
|
||||
elif not isinstance(ai_response, str):
|
||||
# 转换为字符串
|
||||
ai_response = str(ai_response)
|
||||
|
||||
# 使用正则提取 <tool_calls> 标签内容
|
||||
tool_calls_match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if not tool_calls_match:
|
||||
# 没有找到工具调用
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
tool_calls_content = tool_calls_match.group(1)
|
||||
|
||||
# 提取所有 <tool_call> 标签
|
||||
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
|
||||
tool_call_matches = re.findall(
|
||||
tool_call_pattern,
|
||||
tool_calls_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
for i, tool_call_content in enumerate(tool_call_matches):
|
||||
# 提取工具名称
|
||||
name_match = re.search(
|
||||
r'<tool_name>(.*?)</tool_name>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
# 提取参数
|
||||
args_match = re.search(
|
||||
r'<arguments>(.*?)</arguments>',
|
||||
tool_call_content,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if name_match and args_match:
|
||||
tool_name = name_match.group(1).strip()
|
||||
arguments_str = args_match.group(1).strip()
|
||||
|
||||
try:
|
||||
# 解析JSON参数
|
||||
arguments = json.loads(arguments_str)
|
||||
|
||||
# 构建标准格式的工具调用
|
||||
tool_calls.append({
|
||||
"id": f"call_{i}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": json.dumps(arguments, ensure_ascii=False)
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(f"✅ 解析工具调用: {tool_name}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
|
||||
continue
|
||||
|
||||
has_tool_calls = len(tool_calls) > 0
|
||||
|
||||
if has_tool_calls:
|
||||
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
|
||||
|
||||
return ToolCallResult(
|
||||
tool_calls=tool_calls,
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=has_tool_calls,
|
||||
needs_continuation=has_tool_calls
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
|
||||
return ToolCallResult(
|
||||
tool_calls=[],
|
||||
raw_response=ai_response,
|
||||
has_tool_calls=False
|
||||
)
|
||||
|
||||
def build_continuation_prompt(
|
||||
self,
|
||||
original_message: str,
|
||||
ai_response: str,
|
||||
tool_results: List[Dict[str, Any]]
|
||||
) -> str:
|
||||
"""构建包含工具结果的继续对话提示词"""
|
||||
|
||||
# 格式化工具结果
|
||||
results_text = self._format_tool_results(tool_results)
|
||||
|
||||
continuation = f"""你之前尝试使用工具来回答用户的问题。
|
||||
|
||||
原始问题:{original_message}
|
||||
|
||||
你的工具调用:
|
||||
{self._extract_tool_calls_text(ai_response)}
|
||||
|
||||
工具执行结果:
|
||||
{results_text}
|
||||
|
||||
现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
|
||||
|
||||
return continuation
|
||||
|
||||
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
|
||||
"""格式化工具结果为可读文本"""
|
||||
lines = []
|
||||
|
||||
for i, result in enumerate(tool_results, 1):
|
||||
tool_name = result.get("name", "unknown")
|
||||
success = result.get("success", False)
|
||||
content = result.get("content", "")
|
||||
|
||||
status = "✅ 成功" if success else "❌ 失败"
|
||||
lines.append(f"{i}. {tool_name} - {status}")
|
||||
|
||||
if success:
|
||||
# 尝试美化JSON内容
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
content_obj = json.loads(content)
|
||||
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
|
||||
except:
|
||||
pass
|
||||
lines.append(f"```\n{content}\n```")
|
||||
else:
|
||||
error = result.get("error", "未知错误")
|
||||
lines.append(f"错误信息: {error}")
|
||||
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _extract_tool_calls_text(self, ai_response: str) -> str:
|
||||
"""从AI响应中提取工具调用部分的文本"""
|
||||
match = re.search(
|
||||
r'<tool_calls>(.*?)</tool_calls>',
|
||||
ai_response,
|
||||
re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
|
||||
if match:
|
||||
return match.group(0)
|
||||
return "(未找到工具调用)"
|
||||
@@ -1,353 +0,0 @@
|
||||
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
|
||||
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
|
||||
from app.mcp.adapters.function_calling import FunctionCallingAdapter
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class APICapability:
|
||||
"""API能力检测结果"""
|
||||
supports_function_calling: bool
|
||||
tested_at: datetime
|
||||
test_duration_ms: float
|
||||
error_message: Optional[str] = None
|
||||
|
||||
|
||||
class UniversalMCPAdapter:
|
||||
"""
|
||||
通用MCP适配器管理器
|
||||
|
||||
功能:
|
||||
1. 自动检测API是否支持Function Calling
|
||||
2. 缓存检测结果
|
||||
3. 自动降级策略:FC失败时切换到提示词注入
|
||||
4. 提供统一接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_ttl_hours: int = 24,
|
||||
enable_auto_fallback: bool = True
|
||||
):
|
||||
"""
|
||||
初始化通用适配器
|
||||
|
||||
Args:
|
||||
cache_ttl_hours: 能力检测缓存时长(小时)
|
||||
enable_auto_fallback: 是否启用自动降级
|
||||
"""
|
||||
# 适配器实例
|
||||
self.adapters = {
|
||||
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
|
||||
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
|
||||
}
|
||||
|
||||
# API能力缓存: {api_identifier: APICapability}
|
||||
self._capability_cache: Dict[str, APICapability] = {}
|
||||
self._cache_ttl = timedelta(hours=cache_ttl_hours)
|
||||
self._cache_lock = asyncio.Lock()
|
||||
|
||||
# 配置
|
||||
self._enable_auto_fallback = enable_auto_fallback
|
||||
|
||||
logger.info(
|
||||
f"✅ UniversalMCPAdapter初始化完成 "
|
||||
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
|
||||
)
|
||||
|
||||
async def get_adapter(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: Optional[callable] = None
|
||||
) -> BaseMCPAdapter:
|
||||
"""
|
||||
获取适合当前API的适配器
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符(如"openai_official", "azure_openai"等)
|
||||
test_function: 可选的测试函数,用于检测API能力
|
||||
|
||||
Returns:
|
||||
最适合的适配器实例
|
||||
"""
|
||||
|
||||
# 检查缓存
|
||||
capability = await self._get_cached_capability(api_identifier)
|
||||
|
||||
if capability is None and test_function:
|
||||
# 缓存未命中,执行检测
|
||||
capability = await self._detect_capability(api_identifier, test_function)
|
||||
|
||||
# 选择适配器
|
||||
if capability and capability.supports_function_calling:
|
||||
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.FUNCTION_CALLING]
|
||||
else:
|
||||
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
|
||||
return self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
|
||||
async def _get_cached_capability(
|
||||
self,
|
||||
api_identifier: str
|
||||
) -> Optional[APICapability]:
|
||||
"""获取缓存的能力检测结果"""
|
||||
|
||||
async with self._cache_lock:
|
||||
if api_identifier not in self._capability_cache:
|
||||
return None
|
||||
|
||||
capability = self._capability_cache[api_identifier]
|
||||
|
||||
# 检查是否过期
|
||||
if datetime.now() - capability.tested_at > self._cache_ttl:
|
||||
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
|
||||
del self._capability_cache[api_identifier]
|
||||
return None
|
||||
|
||||
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
|
||||
return capability
|
||||
|
||||
async def _detect_capability(
|
||||
self,
|
||||
api_identifier: str,
|
||||
test_function: callable
|
||||
) -> APICapability:
|
||||
"""
|
||||
检测API能力
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
test_function: 测试函数,应该尝试使用Function Calling
|
||||
|
||||
Returns:
|
||||
能力检测结果
|
||||
"""
|
||||
|
||||
logger.info(f"🔍 开始检测API能力: {api_identifier}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 调用测试函数
|
||||
result = await test_function()
|
||||
|
||||
# 判断是否成功
|
||||
supports_fc = self._is_function_calling_response(result)
|
||||
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=supports_fc,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms
|
||||
)
|
||||
|
||||
# 缓存结果
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
status = "✅ 支持" if supports_fc else "❌ 不支持"
|
||||
logger.info(
|
||||
f"{status} Function Calling: {api_identifier} "
|
||||
f"(耗时: {duration_ms:.2f}ms)"
|
||||
)
|
||||
|
||||
return capability
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
|
||||
logger.warning(
|
||||
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
|
||||
f"将使用提示词注入模式"
|
||||
)
|
||||
|
||||
capability = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=duration_ms,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 缓存失败结果(避免重复测试)
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = capability
|
||||
|
||||
return capability
|
||||
|
||||
def _is_function_calling_response(self, response: Any) -> bool:
|
||||
"""
|
||||
判断响应是否是Function Calling格式
|
||||
|
||||
Args:
|
||||
response: API响应
|
||||
|
||||
Returns:
|
||||
是否支持Function Calling
|
||||
"""
|
||||
|
||||
try:
|
||||
# 检查字典格式
|
||||
if isinstance(response, dict):
|
||||
message = response.get("choices", [{}])[0].get("message", {})
|
||||
return "tool_calls" in message or "function_call" in message
|
||||
|
||||
# 检查对象格式(OpenAI SDK)
|
||||
if hasattr(response, "choices"):
|
||||
message = response.choices[0].message
|
||||
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
|
||||
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def call_with_fallback(
|
||||
self,
|
||||
api_identifier: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
test_function: Optional[callable] = None
|
||||
) -> ToolCallResult:
|
||||
"""
|
||||
带降级策略的工具调用
|
||||
|
||||
Args:
|
||||
api_identifier: API标识符
|
||||
tools: MCP工具列表
|
||||
user_message: 用户消息
|
||||
call_function: 实际调用API的函数
|
||||
test_function: 可选的测试函数
|
||||
|
||||
Returns:
|
||||
工具调用结果
|
||||
"""
|
||||
|
||||
# 获取适配器
|
||||
adapter = await self.get_adapter(api_identifier, test_function)
|
||||
|
||||
# 首次尝试
|
||||
try:
|
||||
if adapter.supports_native_tools():
|
||||
# Function Calling模式
|
||||
logger.info("🚀 尝试使用Function Calling模式")
|
||||
result = await self._try_function_calling(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
else:
|
||||
# 提示词注入模式
|
||||
logger.info("🚀 使用提示词注入模式")
|
||||
result = await self._try_prompt_injection(
|
||||
tools, user_message, call_function, adapter
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 工具调用失败: {e}")
|
||||
|
||||
# 自动降级
|
||||
if self._enable_auto_fallback and adapter.supports_native_tools():
|
||||
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
|
||||
|
||||
# 更新缓存,标记为不支持
|
||||
async with self._cache_lock:
|
||||
self._capability_cache[api_identifier] = APICapability(
|
||||
supports_function_calling=False,
|
||||
tested_at=datetime.now(),
|
||||
test_duration_ms=0,
|
||||
error_message=str(e)
|
||||
)
|
||||
|
||||
# 使用提示词注入重试
|
||||
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
|
||||
return await self._try_prompt_injection(
|
||||
tools, user_message, call_function, fallback_adapter
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
async def _try_function_calling(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: FunctionCallingAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试Function Calling模式"""
|
||||
|
||||
# Function Calling不需要修改提示词
|
||||
response = await call_function(
|
||||
message=user_message,
|
||||
tools_param=tools,
|
||||
tool_choice_param="auto"
|
||||
)
|
||||
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
async def _try_prompt_injection(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
user_message: str,
|
||||
call_function: callable,
|
||||
adapter: PromptInjectionAdapter
|
||||
) -> ToolCallResult:
|
||||
"""尝试提示词注入模式"""
|
||||
|
||||
# 注入工具到提示词
|
||||
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
|
||||
|
||||
# 调用API(不传tools参数)
|
||||
response = await call_function(
|
||||
message=enhanced_prompt,
|
||||
tools_param=None,
|
||||
tool_choice_param=None
|
||||
)
|
||||
|
||||
# 从文本响应中解析工具调用
|
||||
return adapter.parse_tool_calls(response)
|
||||
|
||||
def clear_cache(self, api_identifier: Optional[str] = None):
|
||||
"""
|
||||
清理能力缓存
|
||||
|
||||
Args:
|
||||
api_identifier: 可选,只清理特定API的缓存
|
||||
"""
|
||||
if api_identifier:
|
||||
if api_identifier in self._capability_cache:
|
||||
del self._capability_cache[api_identifier]
|
||||
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
|
||||
else:
|
||||
self._capability_cache.clear()
|
||||
logger.info("🧹 已清理所有API能力缓存")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""获取缓存统计信息"""
|
||||
return {
|
||||
"total_cached": len(self._capability_cache),
|
||||
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
|
||||
"cached_apis": [
|
||||
{
|
||||
"api_identifier": api_id,
|
||||
"supports_fc": cap.supports_function_calling,
|
||||
"tested_at": cap.tested_at.isoformat(),
|
||||
"test_duration_ms": cap.test_duration_ms
|
||||
}
|
||||
for api_id, cap in self._capability_cache.items()
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# 全局单例
|
||||
universal_mcp_adapter = UniversalMCPAdapter()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,385 +0,0 @@
|
||||
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from mcp import ClientSession, types
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from pydantic import AnyUrl
|
||||
from anyio import ClosedResourceError
|
||||
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""MCP错误"""
|
||||
pass
|
||||
|
||||
|
||||
class HTTPMCPClient:
|
||||
"""HTTP模式MCP客户端(基于官方 MCP Python SDK)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
初始化HTTP MCP客户端
|
||||
|
||||
Args:
|
||||
url: MCP服务器URL
|
||||
headers: HTTP请求头
|
||||
env: 环境变量(用于API Key等)
|
||||
timeout: 超时时间(秒)
|
||||
"""
|
||||
self.url = url.rstrip('/')
|
||||
self.headers = headers or {}
|
||||
self.env = env or {}
|
||||
self.timeout = timeout
|
||||
|
||||
# 如果env中有API Key,添加到headers
|
||||
if 'API_KEY' in self.env:
|
||||
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
|
||||
|
||||
self._session: Optional[ClientSession] = None
|
||||
self._context_stack = [] # 保存上下文管理器栈
|
||||
self._initialized = False
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _ensure_connected(self):
|
||||
"""确保连接已建立"""
|
||||
async with self._lock:
|
||||
if self._session is None:
|
||||
try:
|
||||
logger.info(f"🔗 连接到MCP服务器: {self.url}")
|
||||
|
||||
# 使用官方 SDK 的 streamable_http_client
|
||||
# 保存上下文管理器以便后续正确清理
|
||||
stream_context = streamablehttp_client(self.url)
|
||||
read_stream, write_stream, _ = await stream_context.__aenter__()
|
||||
self._context_stack.append(('stream', stream_context))
|
||||
|
||||
# 创建客户端会话
|
||||
self._session = ClientSession(read_stream, write_stream)
|
||||
session_context = self._session
|
||||
await session_context.__aenter__()
|
||||
self._context_stack.append(('session', session_context))
|
||||
|
||||
# 初始化会话
|
||||
await self._session.initialize()
|
||||
self._initialized = True
|
||||
|
||||
logger.info(f"✅ MCP会话初始化成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MCP连接失败: {e}")
|
||||
await self._cleanup()
|
||||
raise MCPError(f"连接MCP服务器失败: {str(e)}")
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理连接资源(按照进入的相反顺序退出)"""
|
||||
# 按照LIFO顺序清理上下文
|
||||
while self._context_stack:
|
||||
ctx_type, ctx = self._context_stack.pop()
|
||||
try:
|
||||
await ctx.__aexit__(None, None, None)
|
||||
except RuntimeError as e:
|
||||
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
|
||||
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
|
||||
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
|
||||
else:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理{ctx_type}上下文失败: {e}")
|
||||
|
||||
self._session = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> Dict[str, Any]:
|
||||
"""
|
||||
初始化MCP会话
|
||||
|
||||
Returns:
|
||||
初始化响应
|
||||
"""
|
||||
await self._ensure_connected()
|
||||
return {"status": "initialized"}
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列举可用工具
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools = []
|
||||
for tool in result.tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
|
||||
logger.info(f"获取到 {len(tools)} 个工具")
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {e}")
|
||||
raise MCPError(f"获取工具列表失败: {str(e)}")
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
max_reconnect_attempts: int = 2
|
||||
) -> Any:
|
||||
"""
|
||||
调用工具(带自动重连)
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
max_reconnect_attempts: 最大重连尝试次数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
for attempt in range(max_reconnect_attempts + 1):
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
logger.info(f"调用工具: {tool_name}")
|
||||
logger.debug(f" 参数类型: {type(arguments)}")
|
||||
logger.debug(f" 参数内容: {arguments}")
|
||||
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
|
||||
|
||||
result = await self._session.call_tool(tool_name, arguments)
|
||||
|
||||
logger.debug(f" 工具返回类型: {type(result)}")
|
||||
logger.debug(f" 返回内容: {result}")
|
||||
|
||||
# 处理返回结果
|
||||
# MCP SDK 返回 CallToolResult 对象
|
||||
if result.content:
|
||||
logger.debug(f" 返回content数量: {len(result.content)}")
|
||||
# 提取第一个content的文本
|
||||
for idx, content in enumerate(result.content):
|
||||
logger.debug(f" content[{idx}]类型: {type(content)}")
|
||||
if isinstance(content, types.TextContent):
|
||||
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
logger.debug(f" ✅ 返回ImageContent")
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
# 如果没有文本内容,返回原始内容
|
||||
logger.debug(f" ⚠️ 返回原始content[0]")
|
||||
return result.content[0] if result.content else None
|
||||
|
||||
# 如果有结构化内容(2025-06-18规范)
|
||||
if hasattr(result, 'structuredContent') and result.structuredContent:
|
||||
logger.debug(f" ✅ 返回structuredContent")
|
||||
return result.structuredContent
|
||||
|
||||
logger.warning(f" ⚠️ 工具返回为None")
|
||||
return None
|
||||
|
||||
except ClosedResourceError as e:
|
||||
# 连接已关闭,尝试重连
|
||||
if attempt < max_reconnect_attempts:
|
||||
logger.warning(
|
||||
f"⚠️ MCP连接已关闭,尝试重新连接 "
|
||||
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
|
||||
)
|
||||
await self._cleanup()
|
||||
await asyncio.sleep(0.5) # 短暂延迟后重连
|
||||
continue
|
||||
else:
|
||||
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
|
||||
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
|
||||
raise MCPError(error_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
|
||||
logger.error(f" 参数: {arguments}")
|
||||
logger.error(f" 错误类型: {type(e).__name__}")
|
||||
logger.error(f" 错误详情: {repr(e)}")
|
||||
logger.error(f" 错误字符串: '{str(e)}'")
|
||||
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
|
||||
raise MCPError(f"调用工具失败: {error_msg}")
|
||||
|
||||
# 理论上不会到这里
|
||||
raise MCPError(f"工具调用失败: 未知错误")
|
||||
|
||||
async def list_resources(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列举可用资源
|
||||
|
||||
Returns:
|
||||
资源列表
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_resources()
|
||||
|
||||
# 转换为字典格式
|
||||
resources = []
|
||||
for resource in result.resources:
|
||||
resource_dict = {
|
||||
"uri": str(resource.uri),
|
||||
"name": resource.name,
|
||||
"description": resource.description or "",
|
||||
"mimeType": resource.mimeType or ""
|
||||
}
|
||||
resources.append(resource_dict)
|
||||
|
||||
logger.info(f"获取到 {len(resources)} 个资源")
|
||||
return resources
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取资源列表失败: {e}")
|
||||
raise MCPError(f"获取资源列表失败: {str(e)}")
|
||||
|
||||
async def read_resource(self, uri: str) -> Any:
|
||||
"""
|
||||
读取资源
|
||||
|
||||
Args:
|
||||
uri: 资源URI
|
||||
|
||||
Returns:
|
||||
资源内容
|
||||
"""
|
||||
try:
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.read_resource(AnyUrl(uri))
|
||||
|
||||
# 提取资源内容
|
||||
if result.contents:
|
||||
content = result.contents[0]
|
||||
if isinstance(content, types.TextContent):
|
||||
return content.text
|
||||
elif isinstance(content, types.ImageContent):
|
||||
return {
|
||||
"type": "image",
|
||||
"data": content.data,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
elif isinstance(content, types.BlobResourceContents):
|
||||
return {
|
||||
"type": "blob",
|
||||
"blob": content.blob,
|
||||
"mimeType": content.mimeType
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"读取资源失败: {uri}, 错误: {e}")
|
||||
raise MCPError(f"读取资源失败: {str(e)}")
|
||||
|
||||
async def test_connection(self) -> Dict[str, Any]:
|
||||
"""
|
||||
测试连接
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 尝试连接并列举工具(直接调用SDK,避免重复日志)
|
||||
await self._ensure_connected()
|
||||
|
||||
result = await self._session.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools = []
|
||||
for tool in result.tools:
|
||||
tool_dict = {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"inputSchema": tool.inputSchema
|
||||
}
|
||||
tools.append(tool_dict)
|
||||
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "连接测试成功",
|
||||
"response_time_ms": response_time,
|
||||
"tools_count": len(tools),
|
||||
"tools": tools
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
response_time = round((end_time - start_time) * 1000, 2)
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"message": "连接测试失败",
|
||||
"response_time_ms": response_time,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
"suggestions": [
|
||||
"请检查服务器URL是否正确",
|
||||
"请确认API Key是否有效",
|
||||
"请检查网络连接",
|
||||
"请确认MCP服务器是否在线"
|
||||
]
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端连接"""
|
||||
logger.info(f"关闭MCP客户端: {self.url}")
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_client(
|
||||
url: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 60.0
|
||||
):
|
||||
"""
|
||||
创建MCP客户端的上下文管理器
|
||||
|
||||
Args:
|
||||
url: MCP服务器URL
|
||||
headers: HTTP请求头
|
||||
env: 环境变量
|
||||
timeout: 超时时间
|
||||
|
||||
Yields:
|
||||
HTTPMCPClient实例
|
||||
"""
|
||||
client = HTTPMCPClient(url, headers, env, timeout)
|
||||
try:
|
||||
await client.initialize()
|
||||
yield client
|
||||
finally:
|
||||
await client.close()
|
||||
@@ -1,527 +0,0 @@
|
||||
"""MCP插件注册表 - 管理运行时插件实例"""
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Any, List
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from app.mcp.http_client import HTTPMCPClient, MCPError
|
||||
from app.mcp.config import mcp_config
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""会话信息"""
|
||||
client: HTTPMCPClient
|
||||
created_at: float
|
||||
last_access: float
|
||||
request_count: int = 0
|
||||
error_count: int = 0
|
||||
status: str = "active" # active, degraded, error
|
||||
|
||||
|
||||
class MCPPluginRegistry:
|
||||
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_clients: Optional[int] = None,
|
||||
client_ttl: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
初始化注册表
|
||||
|
||||
Args:
|
||||
max_clients: 最大缓存客户端数量(默认使用配置)
|
||||
client_ttl: 客户端过期时间(秒,默认使用配置)
|
||||
"""
|
||||
# 存储格式: {plugin_id: SessionInfo}
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
|
||||
# 全局锁用于保护会话字典
|
||||
self._sessions_lock = asyncio.Lock()
|
||||
|
||||
# 细粒度锁:每个用户一个锁
|
||||
self._user_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._locks_lock = asyncio.Lock() # 保护locks字典本身
|
||||
|
||||
# 配置参数(使用配置常量)
|
||||
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
|
||||
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
|
||||
|
||||
# 启动后台清理任务
|
||||
self._cleanup_task = None
|
||||
self._health_check_task = None
|
||||
self._tasks_started = False
|
||||
|
||||
def _ensure_background_tasks(self):
|
||||
"""确保后台任务已启动(延迟初始化)"""
|
||||
if not self._tasks_started:
|
||||
try:
|
||||
# 检查是否有运行中的事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info("✅ MCP插件注册表后台清理任务已启动")
|
||||
|
||||
if self._health_check_task is None:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
logger.info("✅ MCP会话健康检查任务已启动")
|
||||
|
||||
self._tasks_started = True
|
||||
except RuntimeError:
|
||||
# 没有运行中的事件循环,稍后再试
|
||||
pass
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""后台清理过期客户端"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
|
||||
await self._cleanup_expired_sessions()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理任务异常: {e}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""后台健康检查"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
|
||||
await self._check_session_health()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查任务异常: {e}")
|
||||
|
||||
async def _cleanup_expired_sessions(self):
|
||||
"""清理过期的会话"""
|
||||
now = time.time()
|
||||
expired_ids = []
|
||||
|
||||
async with self._sessions_lock:
|
||||
# 收集过期的plugin_id
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
if now - session.last_access > self._client_ttl:
|
||||
expired_ids.append(plugin_id)
|
||||
|
||||
if expired_ids:
|
||||
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
|
||||
for plugin_id in expired_ids:
|
||||
# 提取user_id来获取对应的锁
|
||||
user_id = plugin_id.split(':', 1)[0]
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
|
||||
async with user_lock:
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _check_session_health(self):
|
||||
"""增强的会话健康检查"""
|
||||
async with self._sessions_lock:
|
||||
for plugin_id, session in list(self._sessions.items()):
|
||||
# 计算错误率
|
||||
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
|
||||
error_rate = session.error_count / session.request_count
|
||||
|
||||
# 动态调整状态(使用配置常量)
|
||||
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
|
||||
if session.status != "error":
|
||||
session.status = "error"
|
||||
logger.error(
|
||||
f"❌ 会话 {plugin_id} 错误率过高 "
|
||||
f"({error_rate:.1%}), 标记为error"
|
||||
)
|
||||
elif error_rate > mcp_config.ERROR_RATE_WARNING:
|
||||
if session.status == "active":
|
||||
session.status = "degraded"
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 健康状况下降 "
|
||||
f"(错误率: {error_rate:.1%})"
|
||||
)
|
||||
elif session.status == "degraded":
|
||||
# 错误率降低,恢复正常
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
# 检查即将过期的会话(最后1分钟提醒)
|
||||
idle_time = time.time() - session.last_access
|
||||
time_until_expiry = self._client_ttl - idle_time
|
||||
|
||||
# 仅在最后1分钟(60秒)内提醒一次
|
||||
if 0 < time_until_expiry <= 60:
|
||||
# 使用会话属性避免重复提醒
|
||||
if not hasattr(session, '_expiry_warned') or not session._expiry_warned:
|
||||
logger.warning(
|
||||
f"⏰ 会话 {plugin_id} 即将过期 "
|
||||
f"(剩余 {time_until_expiry:.0f} 秒)"
|
||||
)
|
||||
session._expiry_warned = True
|
||||
elif time_until_expiry > 60:
|
||||
# 重置警告标志(如果会话被重新使用)
|
||||
if hasattr(session, '_expiry_warned'):
|
||||
session._expiry_warned = False
|
||||
|
||||
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
|
||||
"""
|
||||
获取用户专属的锁(细粒度锁)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
该用户的锁对象
|
||||
"""
|
||||
async with self._locks_lock:
|
||||
if user_id not in self._user_locks:
|
||||
self._user_locks[user_id] = asyncio.Lock()
|
||||
return self._user_locks[user_id]
|
||||
|
||||
def _touch_session(self, plugin_id: str):
|
||||
"""
|
||||
更新会话的最后访问时间(需要在锁内调用)
|
||||
|
||||
Args:
|
||||
plugin_id: 插件ID
|
||||
"""
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
|
||||
async def _evict_lru_session(self):
|
||||
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
|
||||
if len(self._sessions) >= self._max_clients:
|
||||
# 找到最旧的会话
|
||||
oldest_id = None
|
||||
oldest_time = float('inf')
|
||||
|
||||
for plugin_id, session in self._sessions.items():
|
||||
if session.last_access < oldest_time:
|
||||
oldest_time = session.last_access
|
||||
oldest_id = plugin_id
|
||||
|
||||
if oldest_id:
|
||||
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
|
||||
await self._unload_plugin_unsafe(oldest_id)
|
||||
|
||||
async def load_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
"""
|
||||
从配置加载插件
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
"""
|
||||
# 确保后台任务已启动
|
||||
self._ensure_background_tasks()
|
||||
|
||||
# 使用细粒度锁(只锁定当前用户)
|
||||
user_lock = await self._get_user_lock(plugin.user_id)
|
||||
async with user_lock:
|
||||
try:
|
||||
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
|
||||
|
||||
# 如果已加载,先卸载
|
||||
async with self._sessions_lock:
|
||||
if plugin_id in self._sessions:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
# 检查是否需要驱逐LRU会话
|
||||
await self._evict_lru_session()
|
||||
|
||||
# 目前只支持HTTP类型
|
||||
if plugin.plugin_type == "http":
|
||||
if not plugin.server_url:
|
||||
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
|
||||
return False
|
||||
|
||||
# 为每个插件创建独立的HTTP客户端
|
||||
client = HTTPMCPClient(
|
||||
url=plugin.server_url,
|
||||
headers=plugin.headers or {},
|
||||
env=plugin.env or {},
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
)
|
||||
|
||||
# 创建会话信息
|
||||
now = time.time()
|
||||
session = SessionInfo(
|
||||
client=client,
|
||||
created_at=now,
|
||||
last_access=now,
|
||||
request_count=0,
|
||||
error_count=0,
|
||||
status="active"
|
||||
)
|
||||
|
||||
# 存储会话
|
||||
async with self._sessions_lock:
|
||||
self._sessions[plugin_id] = session
|
||||
|
||||
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件失败 {plugin.plugin_name}: {e}")
|
||||
return False
|
||||
|
||||
async def unload_plugin(self, user_id: str, plugin_name: str):
|
||||
"""
|
||||
卸载插件
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
"""
|
||||
# 使用细粒度锁(只锁定当前用户)
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
async def _unload_plugin_unsafe(self, plugin_id: str):
|
||||
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用)"""
|
||||
if plugin_id in self._sessions:
|
||||
session = self._sessions[plugin_id]
|
||||
try:
|
||||
await session.client.close()
|
||||
except Exception as e:
|
||||
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
|
||||
|
||||
del self._sessions[plugin_id]
|
||||
logger.info(f"卸载MCP插件: {plugin_id}")
|
||||
|
||||
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
|
||||
"""
|
||||
重新加载插件
|
||||
|
||||
Args:
|
||||
plugin: 插件配置
|
||||
|
||||
Returns:
|
||||
是否重载成功
|
||||
"""
|
||||
await self.unload_plugin(plugin.user_id, plugin.plugin_name)
|
||||
return await self.load_plugin(plugin)
|
||||
|
||||
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
|
||||
"""
|
||||
获取插件客户端(线程安全,支持访问时间更新)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
客户端实例或None
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
session = self._sessions.get(plugin_id)
|
||||
if session:
|
||||
# 检查会话状态
|
||||
if session.status == "error":
|
||||
logger.warning(
|
||||
f"⚠️ 会话 {plugin_id} 处于错误状态,"
|
||||
f"建议调用者重新加载插件"
|
||||
)
|
||||
# 不返回错误状态的客户端
|
||||
return None
|
||||
|
||||
# ✅ 使用锁保护状态更新,避免并发问题
|
||||
# 注意:这里使用原子操作更新简单字段,不需要异步锁
|
||||
session.last_access = time.time()
|
||||
session.request_count += 1
|
||||
return session.client
|
||||
return None
|
||||
|
||||
async def get_or_reconnect_client(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
plugin: MCPPlugin
|
||||
) -> HTTPMCPClient:
|
||||
"""
|
||||
获取或重连客户端(自动处理错误状态)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
plugin: 插件配置对象
|
||||
|
||||
Returns:
|
||||
客户端实例
|
||||
|
||||
Raises:
|
||||
ValueError: 插件加载失败
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
# 获取用户锁
|
||||
user_lock = await self._get_user_lock(user_id)
|
||||
async with user_lock:
|
||||
session = self._sessions.get(plugin_id)
|
||||
|
||||
# 检查会话健康状态
|
||||
if session and session.status == "error":
|
||||
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
|
||||
async with self._sessions_lock:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
session = None
|
||||
|
||||
# 如果没有会话,加载插件
|
||||
if not session:
|
||||
success = await self.load_plugin(plugin)
|
||||
if not success:
|
||||
raise ValueError(f"插件加载失败: {plugin_name}")
|
||||
session = self._sessions[plugin_id]
|
||||
|
||||
return session.client
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
调用插件工具(带错误计数和状态管理)
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
ValueError: 插件不存在或未启用
|
||||
MCPError: 工具调用失败
|
||||
"""
|
||||
plugin_id = f"{user_id}:{plugin_name}"
|
||||
|
||||
# 获取会话
|
||||
session = self._sessions.get(plugin_id)
|
||||
if not session:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
try:
|
||||
result = await session.client.call_tool(tool_name, arguments)
|
||||
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
|
||||
|
||||
# 调用成功,重置状态(如果之前是degraded)
|
||||
if session.status == "degraded":
|
||||
session.status = "active"
|
||||
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
# 增加错误计数
|
||||
session.error_count += 1
|
||||
|
||||
# 根据错误率更新状态
|
||||
if session.request_count > 0:
|
||||
error_rate = session.error_count / session.request_count
|
||||
if error_rate > 0.5:
|
||||
session.status = "error"
|
||||
elif error_rate > 0.3:
|
||||
session.status = "degraded"
|
||||
|
||||
logger.error(
|
||||
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
|
||||
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
|
||||
)
|
||||
raise
|
||||
|
||||
async def get_plugin_tools(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取插件的工具列表
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
client = self.get_client(user_id, plugin_name)
|
||||
|
||||
if not client:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
return tools
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具列表失败: {plugin_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def test_plugin(
|
||||
self,
|
||||
user_id: str,
|
||||
plugin_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
测试插件连接
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
测试结果
|
||||
"""
|
||||
client = self.get_client(user_id, plugin_name)
|
||||
|
||||
if not client:
|
||||
raise ValueError(f"插件未加载: {plugin_name}")
|
||||
|
||||
return await client.test_connection()
|
||||
|
||||
async def cleanup_all(self):
|
||||
"""清理所有插件和资源"""
|
||||
# 停止后台任务
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 清理所有会话
|
||||
async with self._sessions_lock:
|
||||
plugin_ids = list(self._sessions.keys())
|
||||
for plugin_id in plugin_ids:
|
||||
await self._unload_plugin_unsafe(plugin_id)
|
||||
|
||||
logger.info("✅ 已清理所有MCP插件和资源")
|
||||
|
||||
|
||||
# 全局注册表实例
|
||||
mcp_registry = MCPPluginRegistry()
|
||||
@@ -0,0 +1,50 @@
|
||||
"""MCP插件状态同步服务
|
||||
|
||||
将内存中的会话状态变更同步到数据库,确保状态一致性。
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.models.mcp_plugin import MCPPlugin
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def sync_status_to_db(event: Dict[str, Any]):
|
||||
"""
|
||||
状态变更回调 - 同步到数据库
|
||||
"""
|
||||
user_id = event["user_id"]
|
||||
plugin_name = event["plugin_name"]
|
||||
new_status = event["new_status"]
|
||||
reason = event.get("reason", "")
|
||||
|
||||
try:
|
||||
from app.database import get_engine
|
||||
|
||||
engine = await get_engine(user_id)
|
||||
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = (
|
||||
update(MCPPlugin)
|
||||
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
|
||||
.values(status=new_status, last_error=reason if new_status == "error" else None)
|
||||
)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
|
||||
|
||||
|
||||
def register_status_sync():
|
||||
"""注册状态同步回调到MCP客户端"""
|
||||
from app.mcp import mcp_client
|
||||
mcp_client.register_status_callback(sync_status_to_db)
|
||||
logger.info("✅ MCP状态同步服务已注册")
|
||||
Reference in New Issue
Block a user