update: 修复基于长亭monkeycode扫描结果的12处安全漏洞
This commit is contained in:
@@ -24,11 +24,22 @@ from app.user_manager import User
|
||||
from app.mcp import mcp_client, MCPPluginConfig, PluginStatus
|
||||
from app.services.mcp_test_service import mcp_test_service
|
||||
from app.logger import get_logger
|
||||
from app.security import validate_public_http_url
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/mcp/plugins", tags=["MCP插件管理"])
|
||||
|
||||
HTTP_PLUGIN_TYPES = {"http", "streamable_http", "sse"}
|
||||
|
||||
|
||||
def _validate_mcp_server_url(plugin_type: str, server_url: Optional[str]) -> Optional[str]:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES:
|
||||
if not server_url:
|
||||
raise HTTPException(status_code=400, detail=f"{plugin_type}类型插件必须提供server_url")
|
||||
return validate_public_http_url(server_url)
|
||||
return server_url
|
||||
|
||||
|
||||
def require_login(request: Request) -> User:
|
||||
"""依赖:要求用户已登录"""
|
||||
@@ -53,7 +64,8 @@ async def _register_plugin_background(
|
||||
try:
|
||||
logger.info(f"后台注册MCP插件: {plugin_name}")
|
||||
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -123,11 +135,12 @@ async def _register_plugin_to_facade(plugin: MCPPlugin, user_id: str) -> bool:
|
||||
Returns:
|
||||
是否注册成功
|
||||
"""
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers,
|
||||
timeout=plugin.config.get('timeout', 60.0) if plugin.config else 60.0
|
||||
@@ -187,6 +200,10 @@ async def create_plugin(
|
||||
|
||||
# 创建插件数据
|
||||
plugin_data = data.model_dump()
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(
|
||||
plugin_data.get("plugin_type", "http"),
|
||||
plugin_data.get("server_url")
|
||||
)
|
||||
|
||||
# 如果没有提供display_name,使用plugin_name作为默认值
|
||||
if not plugin_data.get("display_name"):
|
||||
@@ -278,12 +295,9 @@ async def create_plugin_simple(
|
||||
"sort_order": 0
|
||||
}
|
||||
|
||||
if server_type in ["http", "streamable_http", "sse"]:
|
||||
plugin_data["server_url"] = server_config.get("url")
|
||||
if server_type in HTTP_PLUGIN_TYPES:
|
||||
plugin_data["server_url"] = _validate_mcp_server_url(server_type, server_config.get("url"))
|
||||
plugin_data["headers"] = server_config.get("headers", {})
|
||||
|
||||
if not plugin_data["server_url"]:
|
||||
raise HTTPException(status_code=400, detail=f"{server_type}类型插件必须提供url字段")
|
||||
|
||||
elif server_type == "stdio":
|
||||
plugin_data["command"] = server_config.get("command")
|
||||
@@ -415,6 +429,12 @@ async def update_plugin(
|
||||
|
||||
# 更新字段
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
target_type = update_data.get("plugin_type", plugin.plugin_type)
|
||||
if "server_url" in update_data or target_type in HTTP_PLUGIN_TYPES:
|
||||
update_data["server_url"] = _validate_mcp_server_url(
|
||||
target_type,
|
||||
update_data.get("server_url", plugin.server_url)
|
||||
)
|
||||
for key, value in update_data.items():
|
||||
setattr(plugin, key, value)
|
||||
|
||||
@@ -501,7 +521,8 @@ async def toggle_plugin(
|
||||
if enabled:
|
||||
# 启用:注册到统一门面
|
||||
try:
|
||||
if plugin_type in ["http", "streamable_http", "sse"] and server_url:
|
||||
if plugin_type in HTTP_PLUGIN_TYPES and server_url:
|
||||
server_url = _validate_mcp_server_url(plugin_type, server_url)
|
||||
success = await mcp_client.register(MCPPluginConfig(
|
||||
user_id=user.user_id,
|
||||
plugin_name=plugin_name,
|
||||
@@ -647,11 +668,12 @@ async def _ensure_plugin_registered(
|
||||
"""
|
||||
try:
|
||||
# 使用ensure_registered方法,它会检查是否已注册
|
||||
if plugin.plugin_type in ["http", "streamable_http", "sse"] and plugin.server_url:
|
||||
if plugin.plugin_type in HTTP_PLUGIN_TYPES and plugin.server_url:
|
||||
server_url = _validate_mcp_server_url(plugin.plugin_type, plugin.server_url)
|
||||
return await mcp_client.ensure_registered(
|
||||
user_id=user_id,
|
||||
plugin_name=plugin.plugin_name,
|
||||
url=plugin.server_url,
|
||||
url=server_url,
|
||||
plugin_type=plugin.plugin_type,
|
||||
headers=plugin.headers
|
||||
)
|
||||
@@ -912,4 +934,4 @@ async def call_mcp_tool(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具失败: {plugin.plugin_name}.{data.tool_name}, 错误: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"工具调用失败: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user