fix:1.优化mcp插件功能,改用mcp sdk库

This commit is contained in:
xiamuceer
2025-11-08 12:32:32 +08:00
parent 88115a45c5
commit c7c1c1fdf3
9 changed files with 1278 additions and 660 deletions
+320
View File
@@ -0,0 +1,320 @@
"""MCP插件测试服务 - 专门处理插件测试逻辑"""
import time
import json
from typing import Dict, Any, Optional
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.mcp_plugin import MCPPlugin
from app.models.settings import Settings as UserSettings
from app.mcp.registry import mcp_registry
from app.services.ai_service import create_user_ai_service
from app.schemas.mcp_plugin import MCPTestResult
from app.logger import get_logger
from app.user_manager import User
logger = get_logger(__name__)
class MCPTestService:
"""MCP插件测试服务(分离的测试逻辑)"""
async def test_plugin_connection(
self,
plugin: MCPPlugin,
user_id: str
) -> MCPTestResult:
"""
简单连接测试
Args:
plugin: 插件配置
user_id: 用户ID
Returns:
测试结果
"""
start_time = time.time()
try:
# 确保插件已加载
if not mcp_registry.get_client(user_id, plugin.plugin_name):
success = await mcp_registry.load_plugin(plugin)
if not success:
return MCPTestResult(
success=False,
message="插件加载失败",
error="无法创建MCP客户端",
suggestions=["请检查插件配置", "请确认服务器URL正确"]
)
# 测试连接并获取工具列表
test_result = await mcp_registry.test_plugin(user_id, plugin.plugin_name)
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
if test_result["success"]:
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功",
response_time_ms=response_time,
tools_count=test_result.get("tools_count", 0),
suggestions=[
f"响应时间: {response_time}ms",
f"可用工具数: {test_result.get('tools_count', 0)}"
]
)
else:
return MCPTestResult(**test_result)
except Exception as e:
end_time = time.time()
response_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=response_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
async def test_plugin_with_ai(
self,
plugin: MCPPlugin,
user: User,
db_session: AsyncSession
) -> MCPTestResult:
"""
使用AI进行智能工具调用测试
Args:
plugin: 插件配置
user: 用户对象
db_session: 数据库会话
Returns:
测试结果
"""
start_time = time.time()
try:
# 1. 先进行连接测试
connection_result = await self.test_plugin_connection(plugin, user.user_id)
if not connection_result.success:
return connection_result
# 2. 获取工具列表
tools = await mcp_registry.get_plugin_tools(user.user_id, plugin.plugin_name)
if not tools:
return MCPTestResult(
success=False,
message="插件没有提供任何工具",
error="工具列表为空",
response_time_ms=connection_result.response_time_ms,
suggestions=["请检查插件配置", "请确认MCP服务器正常运行"]
)
# 3. 获取用户的AI设置
settings_result = await db_session.execute(
select(UserSettings).where(UserSettings.user_id == user.user_id)
)
user_settings = settings_result.scalar_one_or_none()
if not user_settings or not user_settings.api_key:
# 没有AI配置,返回简单测试结果
logger.warning("用户未配置AI服务,跳过智能测试")
return MCPTestResult(
success=True,
message=f"✅ 连接测试成功(未配置AI,跳过工具调用测试)",
response_time_ms=connection_result.response_time_ms,
tools_count=len(tools),
suggestions=[
f"连接测试: 成功",
f"可用工具数: {len(tools)}",
"提示: 配置AI服务后可进行智能工具调用测试"
]
)
# 4. 使用AI选择工具并生成测试参数
logger.info(f"使用AI分析工具并生成测试计划...")
ai_service = create_user_ai_service(
api_provider=user_settings.api_provider,
api_key=user_settings.api_key,
api_base_url=user_settings.api_base_url,
model_name=user_settings.llm_model,
temperature=0.3,
max_tokens=1000
)
# 转换为OpenAI Function Calling格式
openai_tools = self._convert_tools_to_openai_format(tools)
# 调用AI选择工具
prompt = f"""你是MCP插件测试助手,需要测试插件 '{plugin.plugin_name}' 的功能。
⚠️ 重要规则:生成参数时,必须严格使用工具 schema 中定义的原始参数名称,不要转换为 snake_case 或其他格式。
例如:如果 schema 中是 'nextThoughtNeeded',就必须使用 'nextThoughtNeeded',不能改成 'next_thought_needed'
请选择一个合适的工具进行测试,优先选择搜索、查询类工具。
生成真实有效的测试参数(例如搜索"人工智能最新进展"而不是"test")。
现在开始测试这个插件。"""
system_prompt = """你是专业的API测试工具。当给定工具列表时,选择一个工具并使用合适的参数调用它。
⚠️ 关键规则:调用工具时,必须严格使用 schema 中定义的原始参数名,不要自行转换命名风格。
- 如果参数名是 camelCase(如 nextThoughtNeeded),就使用 camelCase
- 如果参数名是 snake_case(如 next_thought),就使用 snake_case
- 保持与 schema 中定义的完全一致,包括大小写和命名风格"""
ai_response = await ai_service.generate_text(
prompt=prompt,
system_prompt=system_prompt,
tools=openai_tools,
tool_choice="required"
)
# 5. 检查AI是否返回工具调用
if not ai_response.get("tool_calls"):
logger.error(f"❌ AI未返回工具调用")
return MCPTestResult(
success=False,
message="❌ AI Function Calling失败",
error=f"AI未返回工具调用请求。响应: {ai_response.get('content', 'N/A')[:200]}",
tools_count=len(tools),
suggestions=[
"请确认使用的AI模型支持Function Calling",
f"当前Provider: {user_settings.api_provider}",
f"当前模型: {user_settings.llm_model}"
]
)
# 6. 解析工具调用
tool_call = ai_response["tool_calls"][0]
function = tool_call["function"]
tool_name = function["name"]
test_arguments = function["arguments"]
if isinstance(test_arguments, str):
try:
test_arguments = json.loads(test_arguments)
except json.JSONDecodeError as e:
logger.error(f"❌ 解析AI参数失败: {e}")
return MCPTestResult(
success=False,
message="❌ AI返回的参数格式错误",
error=f"无法解析参数JSON: {str(e)}",
tools_count=len(tools)
)
logger.info(f"🤖 AI选择的工具: {tool_name}")
logger.info(f"📝 AI生成的参数: {test_arguments}")
# 7. 调用MCP工具
call_start = time.time()
try:
tool_result = await mcp_registry.call_tool(
user.user_id,
plugin.plugin_name,
tool_name,
test_arguments
)
call_end = time.time()
call_time = round((call_end - call_start) * 1000, 2)
total_time = round((call_end - start_time) * 1000, 2)
# 格式化结果
result_str = str(tool_result)
if len(result_str) > 800:
result_preview = result_str[:800] + "\n...(结果已截断)"
else:
result_preview = result_str
return MCPTestResult(
success=True,
message=f"✅ Function Calling测试成功!工具 '{tool_name}' 调用正常",
response_time_ms=total_time,
tools_count=len(tools),
suggestions=[
f"🤖 AI选择: {tool_name}",
f"📝 参数: {json.dumps(test_arguments, ensure_ascii=False)}",
f"⏱️ 耗时: {call_time}ms",
f"📊 结果:\n{result_preview}"
]
)
except Exception as call_error:
call_end = time.time()
total_time = round((call_end - start_time) * 1000, 2)
logger.warning(f"工具调用失败: {tool_name}, 错误: {call_error}")
return MCPTestResult(
success=True, # 连接成功就算测试通过
message=f"⚠️ 连接成功,但工具调用失败",
response_time_ms=total_time,
tools_count=len(tools),
error=f"工具 '{tool_name}' 调用失败: {str(call_error)}",
suggestions=[
f"✅ 连接测试: 成功",
f"❌ 工具调用测试: 失败",
f"🤖 AI选择: {tool_name}",
f"❌ 错误: {str(call_error)}",
"💡 可能原因: API Key无效、参数错误或服务限制"
]
)
except Exception as e:
end_time = time.time()
total_time = round((end_time - start_time) * 1000, 2)
logger.error(f"测试插件失败: {plugin.plugin_name}, 错误: {e}")
return MCPTestResult(
success=False,
message="❌ 测试失败",
response_time_ms=total_time,
error=str(e),
error_type=type(e).__name__,
suggestions=[
"请检查服务器是否在线",
"请确认配置正确",
"请检查API Key是否有效"
]
)
def _convert_tools_to_openai_format(self, tools: list) -> list:
"""将MCP工具格式转换为OpenAI Function Calling格式"""
openai_tools = []
for tool in tools:
openai_tool = {
"type": "function",
"function": {
"name": tool["name"],
"description": tool.get("description", ""),
}
}
if "inputSchema" in tool:
openai_tool["function"]["parameters"] = tool["inputSchema"]
openai_tools.append(openai_tool)
return openai_tools
# 全局单例
mcp_test_service = MCPTestService()
+330 -22
View File
@@ -5,26 +5,100 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import asyncio
import json
from datetime import datetime
import time
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from collections import defaultdict
from app.models.mcp_plugin import MCPPlugin
from app.mcp.registry import mcp_registry
from app.mcp.config import mcp_config
from app.logger import get_logger
logger = get_logger(__name__)
@dataclass
class ToolMetrics:
"""工具调用指标"""
total_calls: int = 0
success_calls: int = 0
failed_calls: int = 0
total_duration_ms: float = 0.0
avg_duration_ms: float = 0.0
last_call_time: Optional[datetime] = None
def update_success(self, duration_ms: float):
"""更新成功调用指标"""
self.total_calls += 1
self.success_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
def update_failure(self, duration_ms: float):
"""更新失败调用指标"""
self.total_calls += 1
self.failed_calls += 1
self.total_duration_ms += duration_ms
self.avg_duration_ms = self.total_duration_ms / self.total_calls
self.last_call_time = datetime.now()
@property
def success_rate(self) -> float:
"""成功率"""
if self.total_calls == 0:
return 0.0
return self.success_calls / self.total_calls
@dataclass
class ToolCacheEntry:
"""工具缓存条目"""
tools: List[Dict[str, Any]]
expire_time: datetime
hit_count: int = 0
class MCPToolServiceError(Exception):
"""MCP工具服务异常"""
pass
class MCPToolService:
"""MCP工具服务 - 统一管理MCP工具的注入和执行"""
"""MCP工具服务 - 统一管理MCP工具的注入和执行(优化版)"""
def __init__(self):
self._tool_cache = {} # 工具定义缓存
self._result_cache = {} # 工具结果缓存(可选)
def __init__(
self,
cache_ttl_minutes: Optional[int] = None,
max_retries: Optional[int] = None
):
"""
初始化MCP工具服务
Args:
cache_ttl_minutes: 工具缓存TTL(分钟,默认使用配置)
max_retries: 最大重试次数(默认使用配置)
"""
# 工具定义缓存: {cache_key: ToolCacheEntry}
self._tool_cache: Dict[str, ToolCacheEntry] = {}
self._cache_ttl = timedelta(
minutes=cache_ttl_minutes or mcp_config.TOOL_CACHE_TTL_MINUTES
)
# 调用指标: {tool_key: ToolMetrics}
self._metrics: Dict[str, ToolMetrics] = defaultdict(ToolMetrics)
# 重试配置(使用配置常量)
self._max_retries = max_retries or mcp_config.MAX_RETRIES
self._base_retry_delay = mcp_config.BASE_RETRY_DELAY_SECONDS
self._max_retry_delay = mcp_config.MAX_RETRY_DELAY_SECONDS
logger.info(
f"✅ MCPToolService初始化完成 "
f"(缓存TTL={self._cache_ttl.total_seconds()/60:.1f}分钟, "
f"最大重试={self._max_retries}次)"
)
async def get_user_enabled_tools(
self,
@@ -61,7 +135,7 @@ class MCPToolService:
logger.info(f"用户 {user_id} 没有启用的MCP插件")
return []
# 2. 获取所有工具定义
# 2. 获取所有工具定义(使用缓存)
all_tools = []
for plugin in plugins:
try:
@@ -73,8 +147,8 @@ class MCPToolService:
logger.warning(f"插件 {plugin.plugin_name} 加载失败,跳过")
continue
# 从registry获取该插件的工具列表
plugin_tools = await mcp_registry.get_plugin_tools(
# ✅ 使用缓存获取工具列表
plugin_tools = await self._get_plugin_tools_cached(
user_id=user_id,
plugin_name=plugin.plugin_name
)
@@ -82,7 +156,7 @@ class MCPToolService:
# 格式化为Function Calling格式
formatted_tools = self._format_tools_for_ai(
plugin_tools,
plugin.plugin_name # ✅ 修复:使用正确的属性名plugin_name
plugin.plugin_name
)
all_tools.extend(formatted_tools)
@@ -139,12 +213,85 @@ class MCPToolService:
return formatted_tools
async def _get_plugin_tools_cached(
self,
user_id: str,
plugin_name: str
) -> List[Dict[str, Any]]:
"""
带缓存的工具列表获取
Args:
user_id: 用户ID
plugin_name: 插件名称
Returns:
工具列表
"""
cache_key = f"{user_id}:{plugin_name}"
now = datetime.now()
# 检查缓存
if cache_key in self._tool_cache:
entry = self._tool_cache[cache_key]
if now < entry.expire_time:
entry.hit_count += 1
logger.debug(
f"🎯 工具缓存命中: {cache_key} "
f"(命中次数: {entry.hit_count})"
)
return entry.tools
else:
logger.debug(f"⏰ 工具缓存过期: {cache_key}")
del self._tool_cache[cache_key]
# 缓存未命中,从MCP获取
logger.debug(f"🔍 工具缓存未命中,从MCP获取: {cache_key}")
tools = await mcp_registry.get_plugin_tools(user_id, plugin_name)
# 更新缓存
self._tool_cache[cache_key] = ToolCacheEntry(
tools=tools,
expire_time=now + self._cache_ttl,
hit_count=0
)
return tools
def clear_cache(self, user_id: Optional[str] = None, plugin_name: Optional[str] = None):
"""
清理缓存
Args:
user_id: 用户ID(可选,清理特定用户的缓存)
plugin_name: 插件名称(可选,清理特定插件的缓存)
"""
if user_id is None and plugin_name is None:
# 清理所有缓存
self._tool_cache.clear()
logger.info("🧹 已清理所有工具缓存")
elif user_id and plugin_name:
# 清理特定插件缓存
cache_key = f"{user_id}:{plugin_name}"
if cache_key in self._tool_cache:
del self._tool_cache[cache_key]
logger.info(f"🧹 已清理缓存: {cache_key}")
elif user_id:
# 清理用户所有缓存
keys_to_delete = [
key for key in self._tool_cache.keys()
if key.startswith(f"{user_id}:")
]
for key in keys_to_delete:
del self._tool_cache[key]
logger.info(f"🧹 已清理用户缓存: {user_id} ({len(keys_to_delete)}个)")
async def execute_tool_calls(
self,
user_id: str,
tool_calls: List[Dict[str, Any]],
db_session: AsyncSession,
timeout: float = 60.0
timeout: Optional[float] = None
) -> List[Dict[str, Any]]:
"""
批量执行AI请求的工具调用(并行执行)
@@ -153,7 +300,7 @@ class MCPToolService:
user_id: 用户ID
tool_calls: AI返回的工具调用列表
db_session: 数据库会话
timeout: 单个工具调用的超时时间(秒,默认30秒
timeout: 单个工具调用的超时时间(秒,默认使用配置
Returns:
工具调用结果列表
@@ -161,7 +308,10 @@ class MCPToolService:
if not tool_calls:
return []
logger.info(f"开始执行 {len(tool_calls)} 个工具调用")
# 使用配置的默认超时
actual_timeout = timeout or mcp_config.TOOL_CALL_TIMEOUT_SECONDS
logger.info(f"开始执行 {len(tool_calls)} 个工具调用 (超时={actual_timeout}s)")
# 创建异步任务列表
tasks = [
@@ -169,7 +319,7 @@ class MCPToolService:
user_id=user_id,
tool_call=tool_call,
db_session=db_session,
timeout=timeout
timeout=actual_timeout
)
for tool_call in tool_calls
]
@@ -238,18 +388,28 @@ class MCPToolService:
f"参数: {arguments}"
)
# 设置超时
# ✅ 使用带重试的调用
tool_key = f"{plugin_name}.{tool_name}"
start_time = time.time()
try:
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
result = await self._call_tool_with_retry(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments,
timeout=timeout
)
# 记录成功指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_success(duration_ms)
logger.info(
f"✅ 工具调用成功: {tool_key} "
f"(耗时: {duration_ms:.2f}ms)"
)
# 成功返回
return {
"tool_call_id": tool_call_id,
@@ -261,13 +421,21 @@ class MCPToolService:
}
except asyncio.TimeoutError:
# 记录失败指标
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
raise MCPToolServiceError(
f"工具调用超时(>{timeout}秒)"
)
except Exception as e:
# 记录失败指标
tool_key = f"{plugin_name}.{tool_name}" if 'plugin_name' in locals() else function_name
duration_ms = (time.time() - start_time) * 1000
self._metrics[tool_key].update_failure(duration_ms)
logger.error(
f"工具 {function_name} 调用失败: {e}",
f"工具 {function_name} 调用失败: {e}",
exc_info=True
)
return {
@@ -279,6 +447,146 @@ class MCPToolService:
"error": str(e)
}
async def _call_tool_with_retry(
self,
user_id: str,
plugin_name: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: float
) -> Any:
"""
带指数退避重试的工具调用
Args:
user_id: 用户ID
plugin_name: 插件名称
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
工具执行结果
Raises:
MCPToolServiceError: 工具调用失败
asyncio.TimeoutError: 调用超时
"""
last_exception = None
for attempt in range(self._max_retries):
try:
# 尝试调用工具
result = await asyncio.wait_for(
mcp_registry.call_tool(
user_id=user_id,
plugin_name=plugin_name,
tool_name=tool_name,
arguments=arguments
),
timeout=timeout
)
# 成功则返回
if attempt > 0:
logger.info(
f"✅ 重试成功: {plugin_name}.{tool_name} "
f"(第{attempt + 1}次尝试)"
)
return result
except asyncio.TimeoutError:
# 超时不重试,直接抛出
raise
except Exception as e:
last_exception = e
# 最后一次尝试失败
if attempt == self._max_retries - 1:
logger.error(
f"❌ 重试失败: {plugin_name}.{tool_name} "
f"(已尝试{self._max_retries}次): {e}"
)
raise MCPToolServiceError(
f"工具调用失败(已重试{self._max_retries}次): {str(e)}"
)
# 计算指数退避延迟
delay = min(
self._base_retry_delay * (2 ** attempt),
self._max_retry_delay
)
logger.warning(
f"⚠️ 工具调用失败,{delay:.1f}秒后重试 "
f"(第{attempt + 1}/{self._max_retries}次): "
f"{plugin_name}.{tool_name} - {e}"
)
await asyncio.sleep(delay)
# 理论上不会到这里,但为了类型安全
raise MCPToolServiceError(f"工具调用失败: {last_exception}")
def get_metrics(self, tool_name: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
"""
获取工具调用指标
Args:
tool_name: 工具名称(可选,获取特定工具的指标)
Returns:
指标字典
"""
if tool_name:
if tool_name in self._metrics:
metric = self._metrics[tool_name]
return {
tool_name: {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": metric.success_rate,
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
}
return {}
# 返回所有工具的指标
result = {}
for tool_key, metric in self._metrics.items():
result[tool_key] = {
"total_calls": metric.total_calls,
"success_calls": metric.success_calls,
"failed_calls": metric.failed_calls,
"success_rate": round(metric.success_rate, 3),
"avg_duration_ms": round(metric.avg_duration_ms, 2),
"last_call_time": metric.last_call_time.isoformat() if metric.last_call_time else None
}
return result
def get_cache_stats(self) -> Dict[str, Any]:
"""获取缓存统计信息"""
total_entries = len(self._tool_cache)
total_hits = sum(entry.hit_count for entry in self._tool_cache.values())
return {
"total_entries": total_entries,
"total_hits": total_hits,
"cache_ttl_minutes": self._cache_ttl.total_seconds() / 60,
"entries": [
{
"key": key,
"tools_count": len(entry.tools),
"hit_count": entry.hit_count,
"expire_time": entry.expire_time.isoformat()
}
for key, entry in self._tool_cache.items()
]
}
async def build_tool_context(
self,
tool_results: List[Dict[str, Any]],
+2 -2
View File
@@ -226,8 +226,8 @@ class PlotAnalyzer:
)
# 🔍 添加调试日志:查看AI返回的原始内容
logger.info(f"🔍 AI返回类型: {type(response)}")
logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# logger.info(f"🔍 AI返回类型: {type(response)}")
# logger.info(f"🔍 AI返回内容(前500字符): {str(response)}")
# 从返回的字典中提取content字段
if isinstance(response, dict):