feat: add MCP health check

This commit is contained in:
qixinbo
2026-03-29 23:30:32 +08:00
parent aea97b6342
commit 72d25826a1
4 changed files with 181 additions and 15 deletions
+54 -1
View File
@@ -1,10 +1,15 @@
import json
import uuid
import asyncio
from typing import List, Optional
from pathlib import Path
from contextlib import AsyncExitStack
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
from app.core.data_root import get_data_root
@@ -29,11 +34,59 @@ def write_mcp_servers(servers: List[dict]) -> None:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(servers, f, indent=2, ensure_ascii=False)
async def _check_single_mcp_health(server: dict) -> str:
try:
async with AsyncExitStack() as stack:
server_type = server.get("type")
if server_type == "stdio":
params = StdioServerParameters(
command=server.get("command", ""),
args=server.get("args", []),
env=server.get("env")
)
read, write = await stack.enter_async_context(stdio_client(params))
elif server_type in ["sse", "streamableHttp"]:
read, write = await stack.enter_async_context(sse_client(server.get("url", "")))
else:
return "error: unsupported type"
session = await stack.enter_async_context(ClientSession(read, write))
await asyncio.wait_for(session.initialize(), timeout=5.0)
return "connected"
except Exception as e:
err_msg = str(e)
if "unhandled errors in a TaskGroup" in err_msg:
return "error: connection refused"
return f"error: {err_msg or 'unknown'}"
@router.get("/mcp", response_model=List[MCPServer])
def list_mcp_servers(project_id: Optional[int] = None):
async def list_mcp_servers(project_id: Optional[int] = None):
servers = read_mcp_servers()
if project_id is not None:
servers = [s for s in servers if s.get("project_id") == project_id]
if not servers:
return []
tasks = [_check_single_mcp_health(s) for s in servers]
statuses = await asyncio.gather(*tasks, return_exceptions=True)
needs_update = False
for server, status in zip(servers, statuses):
new_status = status if isinstance(status, str) else f"error: {str(status)}"
if server.get("status") != new_status:
server["status"] = new_status
needs_update = True
if needs_update:
# Write back to persist the new statuses
all_servers = read_mcp_servers()
for s in all_servers:
for checked_s in servers:
if s.get("id") == checked_s.get("id"):
s["status"] = checked_s["status"]
write_mcp_servers(all_servers)
return servers
@router.post("/mcp", response_model=MCPServer)