Files
DataClaw/dataclaw-api/app/api/mcp.py
T

136 lines
4.8 KiB
Python
Raw Normal View History

2026-03-27 22:06:00 +08:00
import json
import uuid
2026-03-29 23:30:32 +08:00
import asyncio
2026-03-27 22:06:00 +08:00
from typing import List, Optional
from pathlib import Path
2026-03-29 23:30:32 +08:00
from contextlib import AsyncExitStack
2026-03-27 22:06:00 +08:00
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
2026-03-29 23:30:32 +08:00
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
2026-03-27 22:06:00 +08:00
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
from app.core.data_root import get_data_root
# Router carrying the MCP server configuration endpoints declared below
# (list/create/get/update/delete under "/mcp"); presumably mounted by the
# application elsewhere — not visible in this module.
router = APIRouter()
def get_mcp_servers_file() -> Path:
    """Return the location of the JSON file holding MCP server configs.

    The file is named ``mcp_servers.json`` and sits directly inside the
    directory reported by ``get_data_root()``.
    """
    data_root = get_data_root()
    return data_root.joinpath("mcp_servers.json")
def read_mcp_servers(file_path: Optional[Path] = None) -> List[dict]:
    """Load the persisted list of MCP server configs.

    Args:
        file_path: JSON file to read. Defaults to ``get_mcp_servers_file()``;
            exposed so callers and tests can point at another location.

    Returns:
        The stored server dicts. An empty list is returned when the file is
        missing, is not valid UTF-8 JSON, or does not hold a JSON array.
        Array entries that are not JSON objects are skipped.

    Note:
        Unreadable content is treated as "no servers", so a later
        ``write_mcp_servers`` call will overwrite the damaged file.
    """
    if file_path is None:
        file_path = get_mcp_servers_file()
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # Both mean "the file is corrupt"; treat them the same way.
        return []
    # Every caller iterates the result and calls ``.get`` on each entry, so
    # anything other than a list of objects would fail later with a 500.
    if not isinstance(data, list):
        return []
    return [entry for entry in data if isinstance(entry, dict)]
def write_mcp_servers(servers: List[dict], file_path: Optional[Path] = None) -> None:
    """Persist the full list of MCP server configs as pretty-printed JSON.

    The data is written to a temporary file in the same directory and then
    moved over the target with ``Path.replace``. A failed or interrupted dump
    (e.g. a value ``json`` cannot serialize) therefore leaves the previous
    file intact. Without this, ``read_mcp_servers`` would see a truncated
    file, return ``[]``, and the next write would erase every server.

    Args:
        servers: Server dicts to store; replaces the file's previous content.
        file_path: Destination file. Defaults to ``get_mcp_servers_file()``.

    Raises:
        Whatever ``json.dump`` or the filesystem raises; the target file is
        left unchanged in that case.
    """
    if file_path is None:
        file_path = get_mcp_servers_file()
    file_path.parent.mkdir(parents=True, exist_ok=True)
    # Unique scratch name so concurrent writers never share a temp file.
    tmp_path = file_path.with_name(f".{file_path.name}.{uuid.uuid4().hex}.tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(servers, f, indent=2, ensure_ascii=False)
        tmp_path.replace(file_path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def _first_leaf_exception(exc: BaseException) -> BaseException:
    """Descend through (possibly nested) exception groups to the first leaf."""
    # TaskGroups wrap the real failure (refused connection, HTTP error, ...)
    # in a group whose own message is generic; report the actual cause.
    while True:
        nested = getattr(exc, "exceptions", None)
        if not isinstance(nested, (tuple, list)) or not nested:
            return exc
        exc = nested[0]


async def _probe_mcp_server(server: dict) -> str:
    """Connect to ``server`` over its configured transport and run ``initialize``.

    Returns ``"connected"`` on success or ``"error: unsupported type"`` for an
    unknown ``type``; connection and handshake failures propagate to the caller.
    """
    async with AsyncExitStack() as stack:
        server_type = server.get("type")
        url = server.get("url", "")
        if server_type == "stdio":
            params = StdioServerParameters(
                command=server.get("command", ""),
                args=server.get("args", []),
                env=server.get("env"),
            )
            read, write = await stack.enter_async_context(stdio_client(params))
        elif server_type == "sse":
            read, write = await stack.enter_async_context(sse_client(url))
        elif server_type == "streamableHttp":
            try:
                # Local import: this client only ships with newer ``mcp`` releases.
                from mcp.client.streamable_http import streamablehttp_client
            except ImportError:
                # Older SDK: keep the previous behaviour of probing over SSE.
                read, write = await stack.enter_async_context(sse_client(url))
            else:
                read, write, _ = await stack.enter_async_context(
                    streamablehttp_client(url)
                )
        else:
            return "error: unsupported type"
        session = await stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        return "connected"


async def _check_single_mcp_health(server: dict, timeout: float = 5.0) -> str:
    """Probe one MCP server and summarise the outcome as a status string.

    Possible results:

    * ``"connected"`` if the transport opened and ``initialize`` succeeded;
    * ``"error: timeout"`` if connect + handshake took longer than ``timeout``;
    * ``"error: unsupported type"`` if ``server["type"]`` is not ``"stdio"``,
      ``"sse"`` or ``"streamableHttp"``;
    * ``"error: <cause>"`` for any other ``Exception``. Exception groups are
      unwrapped so ``<cause>`` is the underlying error's message, or its
      class name when the message is empty.

    Args:
        server: Stored server config dict.
        timeout: Overall budget in seconds for transport setup plus handshake.
    """
    try:
        # The budget covers transport setup as well: opening a transport can
        # itself block (e.g. an HTTP endpoint that never answers), not just
        # the handshake.
        return await asyncio.wait_for(_probe_mcp_server(server), timeout=timeout)
    except asyncio.TimeoutError:
        # str() of a TimeoutError is empty, which used to surface as "unknown".
        return "error: timeout"
    except Exception as e:
        leaf = _first_leaf_exception(e)
        return f"error: {str(leaf) or type(leaf).__name__}"
@router.get("/mcp", response_model=List[MCPServer])
async def list_mcp_servers(project_id: Optional[int] = None):
    """List configured MCP servers with a freshly probed ``status``.

    All listed servers are health-checked concurrently with
    ``_check_single_mcp_health``. If any status differs from the stored one,
    the refreshed statuses are merged into a fresh read of the store and
    written back before responding.

    Args:
        project_id: When given, only servers whose ``project_id`` equals this
            value are listed (and probed).

    Returns:
        The (possibly filtered) server dicts with updated ``status`` values.
    """
    servers = read_mcp_servers()
    if project_id is not None:
        servers = [s for s in servers if s.get("project_id") == project_id]
    if not servers:
        return []

    statuses = await asyncio.gather(
        *(_check_single_mcp_health(s) for s in servers),
        return_exceptions=True,
    )

    needs_update = False
    for server, status in zip(servers, statuses):
        # With return_exceptions=True, anything the probe did not catch
        # itself arrives here as the exception object.
        new_status = status if isinstance(status, str) else f"error: {status}"
        if server.get("status") != new_status:
            server["status"] = new_status
            needs_update = True

    if needs_update:
        # Index refreshed statuses by id once, instead of rescanning the
        # checked list for every stored server (previously O(n*m)).
        refreshed = {s.get("id"): s["status"] for s in servers}
        # Re-read the store so servers outside the project filter, and any
        # edits made by other requests while the probes ran, are preserved.
        all_servers = read_mcp_servers()
        for stored in all_servers:
            sid = stored.get("id")
            if sid in refreshed:
                stored["status"] = refreshed[sid]
        write_mcp_servers(all_servers)

    return servers
@router.post("/mcp", response_model=MCPServer)
def create_mcp_server(server_in: MCPServerCreate):
    """Register a new MCP server and persist it.

    The server gets a fresh UUID4 string as its ``id``. A missing or empty
    ``status`` is defaulted to ``"disconnected"``. Returns the stored dict.
    """
    new_server = server_in.dict()
    new_server["id"] = str(uuid.uuid4())
    if not new_server.get("status"):
        new_server["status"] = "disconnected"
    stored = read_mcp_servers()
    stored.append(new_server)
    write_mcp_servers(stored)
    return new_server
@router.get("/mcp/{server_id}", response_model=MCPServer)
def get_mcp_server(server_id: str):
    """Return the first stored MCP server whose ``id`` equals ``server_id``.

    Raises:
        HTTPException: 404 when no stored server has that id.
    """
    match = next(
        (entry for entry in read_mcp_servers() if entry.get("id") == server_id),
        None,
    )
    if match is None:
        raise HTTPException(status_code=404, detail="MCP Server not found")
    return match
@router.put("/mcp/{server_id}", response_model=MCPServer)
def update_mcp_server(server_id: str, server_in: MCPServerUpdate):
    """Apply a partial update to a stored MCP server and persist the list.

    Only the fields explicitly set on ``server_in`` are copied onto the first
    stored record whose ``id`` equals ``server_id``; the updated record is
    returned.

    Raises:
        HTTPException: 404 when no stored server has that id.
    """
    servers = read_mcp_servers()
    target = next((entry for entry in servers if entry.get("id") == server_id), None)
    if target is None:
        raise HTTPException(status_code=404, detail="MCP Server not found")
    # ``target`` is the same dict object held in ``servers``, so updating it
    # in place is enough for the write below to include the change.
    target.update(server_in.dict(exclude_unset=True))
    write_mcp_servers(servers)
    return target
@router.delete("/mcp/{server_id}")
def delete_mcp_server(server_id: str):
    """Remove every stored MCP server whose ``id`` equals ``server_id``.

    Returns ``{"status": "success"}`` after writing the remaining servers.

    Raises:
        HTTPException: 404 when nothing matched; the store is not rewritten.
    """
    servers = read_mcp_servers()
    remaining = [entry for entry in servers if entry.get("id") != server_id]
    removed_any = len(remaining) < len(servers)
    if not removed_any:
        raise HTTPException(status_code=404, detail="MCP Server not found")
    write_mcp_servers(remaining)
    return {"status": "success"}