feat: 重构MCP功能和AI服务提供者架构

This commit is contained in:
xiamuceer-j
2026-01-09 17:13:19 +08:00
parent f3c224261d
commit 77c5489ff8
49 changed files with 4763 additions and 4307 deletions
+35 -3
View File
@@ -1,4 +1,36 @@
"""MCP插件系统"""
from .registry import mcp_registry
"""MCP模块 - 统一的MCP客户端管理
__all__ = ["mcp_registry"]
本模块提供MCPModel Context Protocol)客户端的统一管理接口。
推荐使用方式:
from app.mcp import mcp_client, MCPPluginConfig
# 注册插件
await mcp_client.register(MCPPluginConfig(
user_id="user123",
plugin_name="exa-search",
url="http://localhost:8000/mcp"
))
# 获取工具
tools = await mcp_client.get_tools("user123", "exa-search")
# 调用工具
result = await mcp_client.call_tool("user123", "exa-search", "web_search", {"query": "..."})
# 注册状态变更回调
from app.mcp.status_sync import register_status_sync
register_status_sync()
"""
from .facade import mcp_client, MCPClientFacade, MCPPluginConfig, MCPError, PluginStatus
from .status_sync import register_status_sync
__all__ = [
"mcp_client",
"MCPClientFacade",
"MCPPluginConfig",
"MCPError",
"PluginStatus",
"register_status_sync",
]
-14
View File
@@ -1,14 +0,0 @@
"""MCP适配器模块 - 支持多种AI API的工具调用方式"""
from .base import BaseMCPAdapter, AdapterType
from .prompt_injection import PromptInjectionAdapter
from .function_calling import FunctionCallingAdapter
from .universal import UniversalMCPAdapter
__all__ = [
"BaseMCPAdapter",
"AdapterType",
"PromptInjectionAdapter",
"FunctionCallingAdapter",
"UniversalMCPAdapter",
]
-89
View File
@@ -1,89 +0,0 @@
"""MCP适配器基类"""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
class AdapterType(Enum):
"""适配器类型"""
FUNCTION_CALLING = "function_calling" # 标准Function Calling
PROMPT_INJECTION = "prompt_injection" # 提示词注入
REACT = "react" # ReAct模式
XML = "xml" # XML标记
@dataclass
class ToolCallResult:
"""工具调用结果"""
tool_calls: List[Dict[str, Any]] # 解析出的工具调用
raw_response: str # 原始AI响应
has_tool_calls: bool # 是否包含工具调用
needs_continuation: bool = False # 是否需要继续对话
class BaseMCPAdapter(ABC):
"""MCP适配器基类"""
def __init__(self):
self.adapter_type: AdapterType = AdapterType.PROMPT_INJECTION
@abstractmethod
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""
将工具列表格式化为提示词
Args:
tools: MCP工具列表
user_message: 用户消息
Returns:
格式化后的提示词
"""
pass
@abstractmethod
def parse_tool_calls(self, ai_response: str) -> ToolCallResult:
"""
从AI响应中解析工具调用
Args:
ai_response: AI的原始响应
Returns:
解析结果
"""
pass
@abstractmethod
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""
构建包含工具结果的继续对话提示词
Args:
original_message: 原始用户消息
ai_response: AI响应
tool_results: 工具执行结果
Returns:
继续对话的提示词
"""
pass
def supports_native_tools(self) -> bool:
"""是否支持原生工具调用(如Function Calling"""
return False
def get_adapter_type(self) -> AdapterType:
"""获取适配器类型"""
return self.adapter_type
@@ -1,171 +0,0 @@
"""Function Calling适配器 - 支持原生Function Calling的API"""
import json
from typing import Dict, Any, List
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.logger import get_logger
logger = get_logger(__name__)
class FunctionCallingAdapter(BaseMCPAdapter):
"""Function Calling适配器 - 用于支持原生工具调用的AI API(如OpenAI"""
def __init__(self):
super().__init__()
self.adapter_type = AdapterType.FUNCTION_CALLING
def supports_native_tools(self) -> bool:
"""支持原生工具调用"""
return True
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""
Function Calling模式下,工具通过API参数传递,不需要修改提示词
Returns:
原始用户消息
"""
return user_message
def get_tools_for_api(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
获取适用于API的工具格式
Args:
tools: MCP工具列表
Returns:
适用于OpenAI Function Calling的工具格式
"""
return tools
def parse_tool_calls(self, ai_response: Any) -> ToolCallResult:
"""
从AI响应中解析工具调用(Function Calling格式)
Args:
ai_response: AI响应对象(通常是OpenAI的ChatCompletion对象)
Returns:
解析结果
"""
try:
# 处理不同类型的响应
if isinstance(ai_response, dict):
# 字典格式(OpenAI API响应)
message = ai_response.get("choices", [{}])[0].get("message", {})
tool_calls = message.get("tool_calls", [])
content = message.get("content", "")
elif hasattr(ai_response, "choices"):
# 对象格式(OpenAI SDK响应)
message = ai_response.choices[0].message
tool_calls = getattr(message, "tool_calls", None) or []
content = getattr(message, "content", "") or ""
# 转换为字典格式
if tool_calls:
tool_calls = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments
}
}
for tc in tool_calls
]
else:
# 字符串格式(降级为文本响应)
return ToolCallResult(
tool_calls=[],
raw_response=str(ai_response),
has_tool_calls=False
)
has_tool_calls = len(tool_calls) > 0
if has_tool_calls:
logger.info(f"✅ Function Calling模式解析出 {len(tool_calls)} 个工具调用")
for tc in tool_calls:
logger.info(f" - {tc['function']['name']}")
return ToolCallResult(
tool_calls=tool_calls,
raw_response=content or "",
has_tool_calls=has_tool_calls,
needs_continuation=has_tool_calls
)
except Exception as e:
logger.error(f"❌ 解析Function Calling响应失败: {e}", exc_info=True)
return ToolCallResult(
tool_calls=[],
raw_response=str(ai_response),
has_tool_calls=False
)
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""
构建包含工具结果的继续对话提示词
在Function Calling模式下,这通常不需要,因为工具结果会作为消息历史的一部分
"""
# Function Calling模式下通常通过消息历史传递工具结果
# 这里提供一个降级方案
results_text = "\n\n".join([
f"工具 {r['name']} 的结果:\n{r['content']}"
for r in tool_results
])
return f"{original_message}\n\n工具执行结果:\n{results_text}\n\n请基于以上工具结果回答用户的问题。"
def build_messages_with_tool_results(
self,
messages: List[Dict[str, Any]],
tool_calls: List[Dict[str, Any]],
tool_results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
构建包含工具结果的消息历史(Function Calling标准格式)
Args:
messages: 原始消息历史
tool_calls: AI的工具调用
tool_results: 工具执行结果
Returns:
更新后的消息历史
"""
new_messages = messages.copy()
# 添加助手的工具调用消息
new_messages.append({
"role": "assistant",
"content": None,
"tool_calls": tool_calls
})
# 添加工具结果消息
for result in tool_results:
new_messages.append({
"role": "tool",
"tool_call_id": result.get("tool_call_id", ""),
"name": result.get("name", ""),
"content": result.get("content", "")
})
return new_messages
@@ -1,274 +0,0 @@
"""提示词注入适配器 - 最通用的MCP工具调用方式"""
import re
import json
from typing import Dict, Any, List
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.logger import get_logger
logger = get_logger(__name__)
class PromptInjectionAdapter(BaseMCPAdapter):
"""提示词注入适配器 - 将工具转换为文本描述,通过提示词引导AI调用"""
def __init__(self):
super().__init__()
self.adapter_type = AdapterType.PROMPT_INJECTION
def format_tools_for_prompt(
self,
tools: List[Dict[str, Any]],
user_message: str
) -> str:
"""将工具列表注入到提示词中"""
if not tools:
return user_message
# 格式化工具描述
tool_descriptions = self._format_tools_as_text(tools)
# 构建增强的提示词
enhanced_prompt = f"""你现在可以使用以下工具来帮助回答用户的问题。
## 可用工具
{tool_descriptions}
## 工具使用说明
当你需要使用工具时,请按以下XML格式输出(可以一次调用多个工具):
<tool_calls>
<tool_call>
<tool_name>工具名称</tool_name>
<arguments>
{{
"参数名1": "参数值1",
"参数名2": "参数值2"
}}
</arguments>
</tool_call>
</tool_calls>
## 重要提示
1. 只有在确实需要使用工具时才调用工具
2. 参数必须是有效的JSON格式
3. 仔细检查参数是否符合工具的要求
4. 可以在一个<tool_calls>标签内包含多个<tool_call>
5. 调用工具后,你会收到工具的执行结果,然后需要基于结果继续回答
---
用户问题:{user_message}
请分析问题,判断是否需要使用工具。如果需要,先输出工具调用,然后等待结果。如果不需要,直接回答问题。"""
return enhanced_prompt
def _format_tools_as_text(self, tools: List[Dict[str, Any]]) -> str:
"""将工具格式化为可读的文本描述"""
lines = []
for i, tool in enumerate(tools, 1):
func = tool.get("function", {})
name = func.get("name", "unknown")
description = func.get("description", "无描述")
parameters = func.get("parameters", {})
lines.append(f"### {i}. {name}")
lines.append(f"**描述**: {description}")
lines.append("")
# 格式化参数信息
if parameters and "properties" in parameters:
lines.append("**参数**:")
properties = parameters.get("properties", {})
required = parameters.get("required", [])
for param_name, param_info in properties.items():
param_type = param_info.get("type", "string")
param_desc = param_info.get("description", "")
is_required = "必填" if param_name in required else "可选"
lines.append(f" - `{param_name}` ({param_type}, {is_required}): {param_desc}")
lines.append("")
# 添加示例
if "example" in func:
lines.append(f"**示例**: {json.dumps(func['example'], ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
def parse_tool_calls(self, ai_response) -> ToolCallResult:
"""从AI响应中解析工具调用"""
tool_calls = []
try:
# 处理不同类型的响应
if isinstance(ai_response, dict):
# 如果是字典,提取content字段
ai_response = ai_response.get("choices", [{}])[0].get("message", {}).get("content", "")
if not ai_response:
return ToolCallResult(
tool_calls=[],
raw_response="",
has_tool_calls=False
)
elif not isinstance(ai_response, str):
# 转换为字符串
ai_response = str(ai_response)
# 使用正则提取 <tool_calls> 标签内容
tool_calls_match = re.search(
r'<tool_calls>(.*?)</tool_calls>',
ai_response,
re.DOTALL | re.IGNORECASE
)
if not tool_calls_match:
# 没有找到工具调用
return ToolCallResult(
tool_calls=[],
raw_response=ai_response,
has_tool_calls=False
)
tool_calls_content = tool_calls_match.group(1)
# 提取所有 <tool_call> 标签
tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
tool_call_matches = re.findall(
tool_call_pattern,
tool_calls_content,
re.DOTALL | re.IGNORECASE
)
for i, tool_call_content in enumerate(tool_call_matches):
# 提取工具名称
name_match = re.search(
r'<tool_name>(.*?)</tool_name>',
tool_call_content,
re.DOTALL | re.IGNORECASE
)
# 提取参数
args_match = re.search(
r'<arguments>(.*?)</arguments>',
tool_call_content,
re.DOTALL | re.IGNORECASE
)
if name_match and args_match:
tool_name = name_match.group(1).strip()
arguments_str = args_match.group(1).strip()
try:
# 解析JSON参数
arguments = json.loads(arguments_str)
# 构建标准格式的工具调用
tool_calls.append({
"id": f"call_{i}",
"type": "function",
"function": {
"name": tool_name,
"arguments": json.dumps(arguments, ensure_ascii=False)
}
})
logger.info(f"✅ 解析工具调用: {tool_name}")
except json.JSONDecodeError as e:
logger.error(f"❌ 解析工具参数失败: {arguments_str}, 错误: {e}")
continue
has_tool_calls = len(tool_calls) > 0
if has_tool_calls:
logger.info(f"✅ 从响应中解析出 {len(tool_calls)} 个工具调用")
return ToolCallResult(
tool_calls=tool_calls,
raw_response=ai_response,
has_tool_calls=has_tool_calls,
needs_continuation=has_tool_calls
)
except Exception as e:
logger.error(f"❌ 解析工具调用失败: {e}", exc_info=True)
return ToolCallResult(
tool_calls=[],
raw_response=ai_response,
has_tool_calls=False
)
def build_continuation_prompt(
self,
original_message: str,
ai_response: str,
tool_results: List[Dict[str, Any]]
) -> str:
"""构建包含工具结果的继续对话提示词"""
# 格式化工具结果
results_text = self._format_tool_results(tool_results)
continuation = f"""你之前尝试使用工具来回答用户的问题。
原始问题:{original_message}
你的工具调用:
{self._extract_tool_calls_text(ai_response)}
工具执行结果:
{results_text}
现在,请基于这些工具的执行结果,给出完整、详细的回答。不要重复调用工具,直接使用已有的结果来回答用户的问题。"""
return continuation
def _format_tool_results(self, tool_results: List[Dict[str, Any]]) -> str:
"""格式化工具结果为可读文本"""
lines = []
for i, result in enumerate(tool_results, 1):
tool_name = result.get("name", "unknown")
success = result.get("success", False)
content = result.get("content", "")
status = "✅ 成功" if success else "❌ 失败"
lines.append(f"{i}. {tool_name} - {status}")
if success:
# 尝试美化JSON内容
try:
if isinstance(content, str):
content_obj = json.loads(content)
content = json.dumps(content_obj, ensure_ascii=False, indent=2)
except:
pass
lines.append(f"```\n{content}\n```")
else:
error = result.get("error", "未知错误")
lines.append(f"错误信息: {error}")
lines.append("")
return "\n".join(lines)
def _extract_tool_calls_text(self, ai_response: str) -> str:
"""从AI响应中提取工具调用部分的文本"""
match = re.search(
r'<tool_calls>(.*?)</tool_calls>',
ai_response,
re.DOTALL | re.IGNORECASE
)
if match:
return match.group(0)
return "(未找到工具调用)"
-353
View File
@@ -1,353 +0,0 @@
"""通用MCP适配器 - 自动检测API能力并选择最佳适配器"""
import time
import asyncio
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
from app.mcp.adapters.base import BaseMCPAdapter, AdapterType, ToolCallResult
from app.mcp.adapters.prompt_injection import PromptInjectionAdapter
from app.mcp.adapters.function_calling import FunctionCallingAdapter
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class APICapability:
"""API能力检测结果"""
supports_function_calling: bool
tested_at: datetime
test_duration_ms: float
error_message: Optional[str] = None
class UniversalMCPAdapter:
"""
通用MCP适配器管理器
功能:
1. 自动检测API是否支持Function Calling
2. 缓存检测结果
3. 自动降级策略:FC失败时切换到提示词注入
4. 提供统一接口
"""
def __init__(
self,
cache_ttl_hours: int = 24,
enable_auto_fallback: bool = True
):
"""
初始化通用适配器
Args:
cache_ttl_hours: 能力检测缓存时长(小时)
enable_auto_fallback: 是否启用自动降级
"""
# 适配器实例
self.adapters = {
AdapterType.FUNCTION_CALLING: FunctionCallingAdapter(),
AdapterType.PROMPT_INJECTION: PromptInjectionAdapter()
}
# API能力缓存: {api_identifier: APICapability}
self._capability_cache: Dict[str, APICapability] = {}
self._cache_ttl = timedelta(hours=cache_ttl_hours)
self._cache_lock = asyncio.Lock()
# 配置
self._enable_auto_fallback = enable_auto_fallback
logger.info(
f"✅ UniversalMCPAdapter初始化完成 "
f"(缓存TTL={cache_ttl_hours}小时, 自动降级={'开启' if enable_auto_fallback else '关闭'})"
)
async def get_adapter(
self,
api_identifier: str,
test_function: Optional[callable] = None
) -> BaseMCPAdapter:
"""
获取适合当前API的适配器
Args:
api_identifier: API标识符(如"openai_official", "azure_openai"等)
test_function: 可选的测试函数,用于检测API能力
Returns:
最适合的适配器实例
"""
# 检查缓存
capability = await self._get_cached_capability(api_identifier)
if capability is None and test_function:
# 缓存未命中,执行检测
capability = await self._detect_capability(api_identifier, test_function)
# 选择适配器
if capability and capability.supports_function_calling:
logger.info(f"🎯 使用Function Calling适配器: {api_identifier}")
return self.adapters[AdapterType.FUNCTION_CALLING]
else:
logger.info(f"🎯 使用提示词注入适配器: {api_identifier}")
return self.adapters[AdapterType.PROMPT_INJECTION]
async def _get_cached_capability(
self,
api_identifier: str
) -> Optional[APICapability]:
"""获取缓存的能力检测结果"""
async with self._cache_lock:
if api_identifier not in self._capability_cache:
return None
capability = self._capability_cache[api_identifier]
# 检查是否过期
if datetime.now() - capability.tested_at > self._cache_ttl:
logger.info(f"⏰ API能力缓存过期: {api_identifier}")
del self._capability_cache[api_identifier]
return None
logger.debug(f"🎯 API能力缓存命中: {api_identifier}")
return capability
async def _detect_capability(
self,
api_identifier: str,
test_function: callable
) -> APICapability:
"""
检测API能力
Args:
api_identifier: API标识符
test_function: 测试函数,应该尝试使用Function Calling
Returns:
能力检测结果
"""
logger.info(f"🔍 开始检测API能力: {api_identifier}")
start_time = time.time()
try:
# 调用测试函数
result = await test_function()
# 判断是否成功
supports_fc = self._is_function_calling_response(result)
duration_ms = (time.time() - start_time) * 1000
capability = APICapability(
supports_function_calling=supports_fc,
tested_at=datetime.now(),
test_duration_ms=duration_ms
)
# 缓存结果
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
status = "✅ 支持" if supports_fc else "❌ 不支持"
logger.info(
f"{status} Function Calling: {api_identifier} "
f"(耗时: {duration_ms:.2f}ms)"
)
return capability
except Exception as e:
duration_ms = (time.time() - start_time) * 1000
logger.warning(
f"⚠️ API能力检测失败: {api_identifier}, 错误: {e}, "
f"将使用提示词注入模式"
)
capability = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=duration_ms,
error_message=str(e)
)
# 缓存失败结果(避免重复测试)
async with self._cache_lock:
self._capability_cache[api_identifier] = capability
return capability
def _is_function_calling_response(self, response: Any) -> bool:
"""
判断响应是否是Function Calling格式
Args:
response: API响应
Returns:
是否支持Function Calling
"""
try:
# 检查字典格式
if isinstance(response, dict):
message = response.get("choices", [{}])[0].get("message", {})
return "tool_calls" in message or "function_call" in message
# 检查对象格式(OpenAI SDK
if hasattr(response, "choices"):
message = response.choices[0].message
return hasattr(message, "tool_calls") or hasattr(message, "function_call")
return False
except Exception:
return False
async def call_with_fallback(
self,
api_identifier: str,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
test_function: Optional[callable] = None
) -> ToolCallResult:
"""
带降级策略的工具调用
Args:
api_identifier: API标识符
tools: MCP工具列表
user_message: 用户消息
call_function: 实际调用API的函数
test_function: 可选的测试函数
Returns:
工具调用结果
"""
# 获取适配器
adapter = await self.get_adapter(api_identifier, test_function)
# 首次尝试
try:
if adapter.supports_native_tools():
# Function Calling模式
logger.info("🚀 尝试使用Function Calling模式")
result = await self._try_function_calling(
tools, user_message, call_function, adapter
)
else:
# 提示词注入模式
logger.info("🚀 使用提示词注入模式")
result = await self._try_prompt_injection(
tools, user_message, call_function, adapter
)
return result
except Exception as e:
logger.error(f"❌ 工具调用失败: {e}")
# 自动降级
if self._enable_auto_fallback and adapter.supports_native_tools():
logger.warning("⚠️ Function Calling失败,降级到提示词注入模式")
# 更新缓存,标记为不支持
async with self._cache_lock:
self._capability_cache[api_identifier] = APICapability(
supports_function_calling=False,
tested_at=datetime.now(),
test_duration_ms=0,
error_message=str(e)
)
# 使用提示词注入重试
fallback_adapter = self.adapters[AdapterType.PROMPT_INJECTION]
return await self._try_prompt_injection(
tools, user_message, call_function, fallback_adapter
)
raise
async def _try_function_calling(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: FunctionCallingAdapter
) -> ToolCallResult:
"""尝试Function Calling模式"""
# Function Calling不需要修改提示词
response = await call_function(
message=user_message,
tools_param=tools,
tool_choice_param="auto"
)
return adapter.parse_tool_calls(response)
async def _try_prompt_injection(
self,
tools: List[Dict[str, Any]],
user_message: str,
call_function: callable,
adapter: PromptInjectionAdapter
) -> ToolCallResult:
"""尝试提示词注入模式"""
# 注入工具到提示词
enhanced_prompt = adapter.format_tools_for_prompt(tools, user_message)
# 调用API(不传tools参数)
response = await call_function(
message=enhanced_prompt,
tools_param=None,
tool_choice_param=None
)
# 从文本响应中解析工具调用
return adapter.parse_tool_calls(response)
def clear_cache(self, api_identifier: Optional[str] = None):
"""
清理能力缓存
Args:
api_identifier: 可选,只清理特定API的缓存
"""
if api_identifier:
if api_identifier in self._capability_cache:
del self._capability_cache[api_identifier]
logger.info(f"🧹 已清理API能力缓存: {api_identifier}")
else:
self._capability_cache.clear()
logger.info("🧹 已清理所有API能力缓存")
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
return {
"total_cached": len(self._capability_cache),
"cache_ttl_hours": self._cache_ttl.total_seconds() / 3600,
"cached_apis": [
{
"api_identifier": api_id,
"supports_fc": cap.supports_function_calling,
"tested_at": cap.tested_at.isoformat(),
"test_duration_ms": cap.test_duration_ms
}
for api_id, cap in self._capability_cache.items()
]
}
# 全局单例
universal_mcp_adapter = UniversalMCPAdapter()
File diff suppressed because it is too large Load Diff
-385
View File
@@ -1,385 +0,0 @@
"""HTTP MCP客户端 - 使用官方 MCP Python SDK 实现"""
import asyncio
from typing import Dict, Any, List, Optional
from contextlib import asynccontextmanager
from mcp import ClientSession, types
from mcp.client.streamable_http import streamablehttp_client
from pydantic import AnyUrl
from anyio import ClosedResourceError
from app.logger import get_logger
logger = get_logger(__name__)
class MCPError(Exception):
"""MCP错误"""
pass
class HTTPMCPClient:
"""HTTP模式MCP客户端(基于官方 MCP Python SDK"""
def __init__(
self,
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
初始化HTTP MCP客户端
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量(用于API Key等)
timeout: 超时时间(秒)
"""
self.url = url.rstrip('/')
self.headers = headers or {}
self.env = env or {}
self.timeout = timeout
# 如果env中有API Key,添加到headers
if 'API_KEY' in self.env:
self.headers['Authorization'] = f'Bearer {self.env["API_KEY"]}'
self._session: Optional[ClientSession] = None
self._context_stack = [] # 保存上下文管理器栈
self._initialized = False
self._lock = asyncio.Lock()
async def _ensure_connected(self):
"""确保连接已建立"""
async with self._lock:
if self._session is None:
try:
logger.info(f"🔗 连接到MCP服务器: {self.url}")
# 使用官方 SDK 的 streamable_http_client
# 保存上下文管理器以便后续正确清理
stream_context = streamablehttp_client(self.url)
read_stream, write_stream, _ = await stream_context.__aenter__()
self._context_stack.append(('stream', stream_context))
# 创建客户端会话
self._session = ClientSession(read_stream, write_stream)
session_context = self._session
await session_context.__aenter__()
self._context_stack.append(('session', session_context))
# 初始化会话
await self._session.initialize()
self._initialized = True
logger.info(f"✅ MCP会话初始化成功")
except Exception as e:
logger.error(f"❌ MCP连接失败: {e}")
await self._cleanup()
raise MCPError(f"连接MCP服务器失败: {str(e)}")
async def _cleanup(self):
"""清理连接资源(按照进入的相反顺序退出)"""
# 按照LIFO顺序清理上下文
while self._context_stack:
ctx_type, ctx = self._context_stack.pop()
try:
await ctx.__aexit__(None, None, None)
except RuntimeError as e:
# 忽略 anyio 的任务上下文错误(在关闭时可能发生)
if "cancel scope" in str(e).lower() or "different task" in str(e).lower():
logger.debug(f"忽略{ctx_type}上下文清理的任务切换警告: {e}")
else:
logger.error(f"清理{ctx_type}上下文失败: {e}")
except Exception as e:
logger.error(f"清理{ctx_type}上下文失败: {e}")
self._session = None
self._initialized = False
async def initialize(self) -> Dict[str, Any]:
"""
初始化MCP会话
Returns:
初始化响应
"""
await self._ensure_connected()
return {"status": "initialized"}
async def list_tools(self) -> List[Dict[str, Any]]:
"""
列举可用工具
Returns:
工具列表
"""
try:
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
logger.info(f"获取到 {len(tools)} 个工具")
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise MCPError(f"获取工具列表失败: {str(e)}")
async def call_tool(
self,
tool_name: str,
arguments: Dict[str, Any],
max_reconnect_attempts: int = 2
) -> Any:
"""
调用工具(带自动重连)
Args:
tool_name: 工具名称
arguments: 工具参数
max_reconnect_attempts: 最大重连尝试次数
Returns:
工具执行结果
"""
for attempt in range(max_reconnect_attempts + 1):
try:
await self._ensure_connected()
logger.info(f"调用工具: {tool_name}")
logger.debug(f" 参数类型: {type(arguments)}")
logger.debug(f" 参数内容: {arguments}")
logger.debug(f" 会话状态: initialized={self._initialized}, session={self._session is not None}")
result = await self._session.call_tool(tool_name, arguments)
logger.debug(f" 工具返回类型: {type(result)}")
logger.debug(f" 返回内容: {result}")
# 处理返回结果
# MCP SDK 返回 CallToolResult 对象
if result.content:
logger.debug(f" 返回content数量: {len(result.content)}")
# 提取第一个content的文本
for idx, content in enumerate(result.content):
logger.debug(f" content[{idx}]类型: {type(content)}")
if isinstance(content, types.TextContent):
logger.debug(f" ✅ 返回TextContent: {content.text[:100] if len(content.text) > 100 else content.text}")
return content.text
elif isinstance(content, types.ImageContent):
logger.debug(f" ✅ 返回ImageContent")
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
# 如果没有文本内容,返回原始内容
logger.debug(f" ⚠️ 返回原始content[0]")
return result.content[0] if result.content else None
# 如果有结构化内容(2025-06-18规范)
if hasattr(result, 'structuredContent') and result.structuredContent:
logger.debug(f" ✅ 返回structuredContent")
return result.structuredContent
logger.warning(f" ⚠️ 工具返回为None")
return None
except ClosedResourceError as e:
# 连接已关闭,尝试重连
if attempt < max_reconnect_attempts:
logger.warning(
f"⚠️ MCP连接已关闭,尝试重新连接 "
f"(第{attempt + 1}/{max_reconnect_attempts}次重连)"
)
await self._cleanup()
await asyncio.sleep(0.5) # 短暂延迟后重连
continue
else:
logger.error(f"❌ MCP连接重连失败,已达最大重试次数")
error_msg = f"连接已关闭且重连失败 (尝试了{max_reconnect_attempts}次)"
raise MCPError(error_msg)
except Exception as e:
logger.error(f"调用工具失败: {tool_name}, 错误: {e}", exc_info=True)
logger.error(f" 参数: {arguments}")
logger.error(f" 错误类型: {type(e).__name__}")
logger.error(f" 错误详情: {repr(e)}")
logger.error(f" 错误字符串: '{str(e)}'")
error_msg = str(e) or repr(e) or f"未知错误 ({type(e).__name__})"
raise MCPError(f"调用工具失败: {error_msg}")
# 理论上不会到这里
raise MCPError(f"工具调用失败: 未知错误")
async def list_resources(self) -> List[Dict[str, Any]]:
"""
列举可用资源
Returns:
资源列表
"""
try:
await self._ensure_connected()
result = await self._session.list_resources()
# 转换为字典格式
resources = []
for resource in result.resources:
resource_dict = {
"uri": str(resource.uri),
"name": resource.name,
"description": resource.description or "",
"mimeType": resource.mimeType or ""
}
resources.append(resource_dict)
logger.info(f"获取到 {len(resources)} 个资源")
return resources
except Exception as e:
logger.error(f"获取资源列表失败: {e}")
raise MCPError(f"获取资源列表失败: {str(e)}")
async def read_resource(self, uri: str) -> Any:
"""
读取资源
Args:
uri: 资源URI
Returns:
资源内容
"""
try:
await self._ensure_connected()
result = await self._session.read_resource(AnyUrl(uri))
# 提取资源内容
if result.contents:
content = result.contents[0]
if isinstance(content, types.TextContent):
return content.text
elif isinstance(content, types.ImageContent):
return {
"type": "image",
"data": content.data,
"mimeType": content.mimeType
}
elif isinstance(content, types.BlobResourceContents):
return {
"type": "blob",
"blob": content.blob,
"mimeType": content.mimeType
}
return None
except Exception as e:
logger.error(f"读取资源失败: {uri}, 错误: {e}")
raise MCPError(f"读取资源失败: {str(e)}")
async def test_connection(self) -> Dict[str, Any]:
"""
测试连接
Returns:
测试结果
"""
import time
start_time = time.time()
try:
# 尝试连接并列举工具(直接调用SDK,避免重复日志)
await self._ensure_connected()
result = await self._session.list_tools()
# 转换为字典格式
tools = []
for tool in result.tools:
tool_dict = {
"name": tool.name,
"description": tool.description or "",
"inputSchema": tool.inputSchema
}
tools.append(tool_dict)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.info(f"✅ 连接测试成功,获取到 {len(tools)} 个工具")
return {
"success": True,
"message": "连接测试成功",
"response_time_ms": response_time,
"tools_count": len(tools),
"tools": tools
}
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
return {
"success": False,
"message": "连接测试失败",
"response_time_ms": response_time,
"error": str(e),
"error_type": type(e).__name__,
"suggestions": [
"请检查服务器URL是否正确",
"请确认API Key是否有效",
"请检查网络连接",
"请确认MCP服务器是否在线"
]
}
async def close(self):
"""关闭客户端连接"""
logger.info(f"关闭MCP客户端: {self.url}")
await self._cleanup()
@asynccontextmanager
async def create_mcp_client(
url: str,
headers: Optional[Dict[str, str]] = None,
env: Optional[Dict[str, str]] = None,
timeout: float = 60.0
):
"""
创建MCP客户端的上下文管理器
Args:
url: MCP服务器URL
headers: HTTP请求头
env: 环境变量
timeout: 超时时间
Yields:
HTTPMCPClient实例
"""
client = HTTPMCPClient(url, headers, env, timeout)
try:
await client.initialize()
yield client
finally:
await client.close()
-527
View File
@@ -1,527 +0,0 @@
"""MCP插件注册表 - 管理运行时插件实例"""
import asyncio
import time
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from datetime import datetime
from app.mcp.http_client import HTTPMCPClient, MCPError
from app.mcp.config import mcp_config
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class SessionInfo:
"""会话信息"""
client: HTTPMCPClient
created_at: float
last_access: float
request_count: int = 0
error_count: int = 0
status: str = "active" # active, degraded, error
class MCPPluginRegistry:
"""MCP插件注册表 - 管理运行时插件实例(优化版)"""
def __init__(
self,
max_clients: Optional[int] = None,
client_ttl: Optional[int] = None
):
"""
初始化注册表
Args:
max_clients: 最大缓存客户端数量(默认使用配置)
client_ttl: 客户端过期时间(秒,默认使用配置)
"""
# 存储格式: {plugin_id: SessionInfo}
self._sessions: Dict[str, SessionInfo] = {}
# 全局锁用于保护会话字典
self._sessions_lock = asyncio.Lock()
# 细粒度锁:每个用户一个锁
self._user_locks: Dict[str, asyncio.Lock] = {}
self._locks_lock = asyncio.Lock() # 保护locks字典本身
# 配置参数(使用配置常量)
self._max_clients = max_clients or mcp_config.MAX_CLIENTS
self._client_ttl = client_ttl or mcp_config.CLIENT_TTL_SECONDS
# 启动后台清理任务
self._cleanup_task = None
self._health_check_task = None
self._tasks_started = False
def _ensure_background_tasks(self):
"""确保后台任务已启动(延迟初始化)"""
if not self._tasks_started:
try:
# 检查是否有运行中的事件循环
loop = asyncio.get_running_loop()
if self._cleanup_task is None:
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
logger.info("✅ MCP插件注册表后台清理任务已启动")
if self._health_check_task is None:
self._health_check_task = asyncio.create_task(self._health_check_loop())
logger.info("✅ MCP会话健康检查任务已启动")
self._tasks_started = True
except RuntimeError:
# 没有运行中的事件循环,稍后再试
pass
async def _cleanup_loop(self):
"""后台清理过期客户端"""
while True:
try:
await asyncio.sleep(mcp_config.CLEANUP_INTERVAL_SECONDS)
await self._cleanup_expired_sessions()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"清理任务异常: {e}")
async def _health_check_loop(self):
"""后台健康检查"""
while True:
try:
await asyncio.sleep(mcp_config.HEALTH_CHECK_INTERVAL_SECONDS)
await self._check_session_health()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"健康检查任务异常: {e}")
async def _cleanup_expired_sessions(self):
"""清理过期的会话"""
now = time.time()
expired_ids = []
async with self._sessions_lock:
# 收集过期的plugin_id
for plugin_id, session in list(self._sessions.items()):
if now - session.last_access > self._client_ttl:
expired_ids.append(plugin_id)
if expired_ids:
logger.info(f"🧹 清理 {len(expired_ids)} 个过期的MCP会话")
for plugin_id in expired_ids:
# 提取user_id来获取对应的锁
user_id = plugin_id.split(':', 1)[0]
user_lock = await self._get_user_lock(user_id)
async with user_lock:
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
async def _check_session_health(self):
"""增强的会话健康检查"""
async with self._sessions_lock:
for plugin_id, session in list(self._sessions.items()):
# 计算错误率
if session.request_count > mcp_config.MIN_REQUESTS_FOR_HEALTH_CHECK:
error_rate = session.error_count / session.request_count
# 动态调整状态(使用配置常量)
if error_rate > mcp_config.ERROR_RATE_CRITICAL:
if session.status != "error":
session.status = "error"
logger.error(
f"❌ 会话 {plugin_id} 错误率过高 "
f"({error_rate:.1%}), 标记为error"
)
elif error_rate > mcp_config.ERROR_RATE_WARNING:
if session.status == "active":
session.status = "degraded"
logger.warning(
f"⚠️ 会话 {plugin_id} 健康状况下降 "
f"(错误率: {error_rate:.1%})"
)
elif session.status == "degraded":
# 错误率降低,恢复正常
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
# 检查即将过期的会话(最后1分钟提醒)
idle_time = time.time() - session.last_access
time_until_expiry = self._client_ttl - idle_time
# 仅在最后1分钟(60秒)内提醒一次
if 0 < time_until_expiry <= 60:
# 使用会话属性避免重复提醒
if not hasattr(session, '_expiry_warned') or not session._expiry_warned:
logger.warning(
f"⏰ 会话 {plugin_id} 即将过期 "
f"(剩余 {time_until_expiry:.0f} 秒)"
)
session._expiry_warned = True
elif time_until_expiry > 60:
# 重置警告标志(如果会话被重新使用)
if hasattr(session, '_expiry_warned'):
session._expiry_warned = False
async def _get_user_lock(self, user_id: str) -> asyncio.Lock:
"""
获取用户专属的锁(细粒度锁)
Args:
user_id: 用户ID
Returns:
该用户的锁对象
"""
async with self._locks_lock:
if user_id not in self._user_locks:
self._user_locks[user_id] = asyncio.Lock()
return self._user_locks[user_id]
def _touch_session(self, plugin_id: str):
"""
更新会话的最后访问时间(需要在锁内调用)
Args:
plugin_id: 插件ID
"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
session.last_access = time.time()
session.request_count += 1
async def _evict_lru_session(self):
"""驱逐最久未使用的会话(当达到max_clients限制时)"""
if len(self._sessions) >= self._max_clients:
# 找到最旧的会话
oldest_id = None
oldest_time = float('inf')
for plugin_id, session in self._sessions.items():
if session.last_access < oldest_time:
oldest_time = session.last_access
oldest_id = plugin_id
if oldest_id:
logger.info(f"📤 达到最大会话数量限制,驱逐: {oldest_id}")
await self._unload_plugin_unsafe(oldest_id)
async def load_plugin(self, plugin: MCPPlugin) -> bool:
"""
从配置加载插件
Args:
plugin: 插件配置
Returns:
是否加载成功
"""
# 确保后台任务已启动
self._ensure_background_tasks()
# 使用细粒度锁(只锁定当前用户)
user_lock = await self._get_user_lock(plugin.user_id)
async with user_lock:
try:
plugin_id = f"{plugin.user_id}:{plugin.plugin_name}"
# 如果已加载,先卸载
async with self._sessions_lock:
if plugin_id in self._sessions:
await self._unload_plugin_unsafe(plugin_id)
# 检查是否需要驱逐LRU会话
await self._evict_lru_session()
# 目前只支持HTTP类型
if plugin.plugin_type == "http":
if not plugin.server_url:
logger.error(f"HTTP插件缺少server_url: {plugin.plugin_name}")
return False
# 为每个插件创建独立的HTTP客户端
client = HTTPMCPClient(
url=plugin.server_url,
headers=plugin.headers or {},
env=plugin.env or {},
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
)
# 创建会话信息
now = time.time()
session = SessionInfo(
client=client,
created_at=now,
last_access=now,
request_count=0,
error_count=0,
status="active"
)
# 存储会话
async with self._sessions_lock:
self._sessions[plugin_id] = session
logger.info(f"✅ 加载MCP插件: {plugin_id} (独立会话)")
return True
else:
logger.warning(f"暂不支持的插件类型: {plugin.plugin_type}")
return False
except Exception as e:
logger.error(f"加载插件失败 {plugin.plugin_name}: {e}")
return False
async def unload_plugin(self, user_id: str, plugin_name: str):
"""
卸载插件
Args:
user_id: 用户ID
plugin_name: 插件名称
"""
# 使用细粒度锁(只锁定当前用户)
user_lock = await self._get_user_lock(user_id)
async with user_lock:
plugin_id = f"{user_id}:{plugin_name}"
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
async def _unload_plugin_unsafe(self, plugin_id: str):
"""卸载插件(不加锁,内部使用,需要在sessions_lock内调用)"""
if plugin_id in self._sessions:
session = self._sessions[plugin_id]
try:
await session.client.close()
except Exception as e:
logger.error(f"关闭插件客户端失败 {plugin_id}: {e}")
del self._sessions[plugin_id]
logger.info(f"卸载MCP插件: {plugin_id}")
async def reload_plugin(self, plugin: MCPPlugin) -> bool:
"""
重新加载插件
Args:
plugin: 插件配置
Returns:
是否重载成功
"""
await self.unload_plugin(plugin.user_id, plugin.plugin_name)
return await self.load_plugin(plugin)
def get_client(self, user_id: str, plugin_name: str) -> Optional[HTTPMCPClient]:
"""
获取插件客户端(线程安全,支持访问时间更新)
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
客户端实例或None
"""
plugin_id = f"{user_id}:{plugin_name}"
session = self._sessions.get(plugin_id)
if session:
# 检查会话状态
if session.status == "error":
logger.warning(
f"⚠️ 会话 {plugin_id} 处于错误状态,"
f"建议调用者重新加载插件"
)
# 不返回错误状态的客户端
return None
# ✅ 使用锁保护状态更新,避免并发问题
# 注意:这里使用原子操作更新简单字段,不需要异步锁
session.last_access = time.time()
session.request_count += 1
return session.client
return None
async def get_or_reconnect_client(
self,
user_id: str,
plugin_name: str,
plugin: MCPPlugin
) -> HTTPMCPClient:
"""
获取或重连客户端(自动处理错误状态)
Args:
user_id: 用户ID
plugin_name: 插件名称
plugin: 插件配置对象
Returns:
客户端实例
Raises:
ValueError: 插件加载失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取用户锁
user_lock = await self._get_user_lock(user_id)
async with user_lock:
session = self._sessions.get(plugin_id)
# 检查会话健康状态
if session and session.status == "error":
logger.warning(f"会话 {plugin_id} 处于错误状态,尝试重连")
async with self._sessions_lock:
await self._unload_plugin_unsafe(plugin_id)
session = None
# 如果没有会话,加载插件
if not session:
success = await self.load_plugin(plugin)
if not success:
raise ValueError(f"插件加载失败: {plugin_name}")
session = self._sessions[plugin_id]
return session.client
async def call_tool(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any]
) -> Any:
"""
调用插件工具(带错误计数和状态管理)
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
Returns:
工具执行结果
Raises:
ValueError: 插件不存在或未启用
MCPError: 工具调用失败
"""
plugin_id = f"{user_id}:{plugin_name}"
# 获取会话
session = self._sessions.get(plugin_id)
if not session:
raise ValueError(f"插件未加载: {plugin_name}")
try:
result = await session.client.call_tool(tool_name, arguments)
logger.info(f"✅ 工具调用成功: {plugin_name}.{tool_name}")
# 调用成功,重置状态(如果之前是degraded)
if session.status == "degraded":
session.status = "active"
logger.info(f"✅ 会话 {plugin_id} 恢复正常")
return result
except Exception as e:
# 增加错误计数
session.error_count += 1
# 根据错误率更新状态
if session.request_count > 0:
error_rate = session.error_count / session.request_count
if error_rate > 0.5:
session.status = "error"
elif error_rate > 0.3:
session.status = "degraded"
logger.error(
f"❌ 工具调用失败: {plugin_name}.{tool_name}, "
f"错误: {e} (错误计数: {session.error_count}/{session.request_count})"
)
raise
async def get_plugin_tools(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
获取插件的工具列表
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
client = self.get_client(user_id, plugin_name)
if not client:
raise ValueError(f"插件未加载: {plugin_name}")
try:
tools = await client.list_tools()
return tools
except Exception as e:
logger.error(f"获取工具列表失败: {plugin_name}, 错误: {e}")
raise
async def test_plugin(
self,
user_id: str,
plugin_name: str
) -> Dict[str, Any]:
"""
测试插件连接
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
测试结果
"""
client = self.get_client(user_id, plugin_name)
if not client:
raise ValueError(f"插件未加载: {plugin_name}")
return await client.test_connection()
async def cleanup_all(self):
"""清理所有插件和资源"""
# 停止后台任务
if self._cleanup_task:
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
# 清理所有会话
async with self._sessions_lock:
plugin_ids = list(self._sessions.keys())
for plugin_id in plugin_ids:
await self._unload_plugin_unsafe(plugin_id)
logger.info("✅ 已清理所有MCP插件和资源")
# 全局注册表实例
mcp_registry = MCPPluginRegistry()
+50
View File
@@ -0,0 +1,50 @@
"""MCP插件状态同步服务
将内存中的会话状态变更同步到数据库,确保状态一致性。
"""
from typing import Dict, Any
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.models.mcp_plugin import MCPPlugin
from app.logger import get_logger
logger = get_logger(__name__)
async def sync_status_to_db(event: Dict[str, Any]):
"""
状态变更回调 - 同步到数据库
"""
user_id = event["user_id"]
plugin_name = event["plugin_name"]
new_status = event["new_status"]
reason = event.get("reason", "")
try:
from app.database import get_engine
engine = await get_engine(user_id)
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async with AsyncSessionLocal() as db:
stmt = (
update(MCPPlugin)
.where(MCPPlugin.user_id == user_id, MCPPlugin.plugin_name == plugin_name)
.values(status=new_status, last_error=reason if new_status == "error" else None)
)
await db.execute(stmt)
await db.commit()
logger.debug(f"✅ 状态已同步到数据库: {plugin_name} -> {new_status}")
except Exception as e:
logger.error(f"❌ 状态同步失败: {plugin_name}, 错误: {e}")
def register_status_sync():
"""注册状态同步回调到MCP客户端"""
from app.mcp import mcp_client
mcp_client.register_status_callback(sync_status_to_db)
logger.info("✅ MCP状态同步服务已注册")