feat: add MCP

This commit is contained in:
qixinbo
2026-03-27 22:06:00 +08:00
parent 5d013231bc
commit b24aff956a
8 changed files with 600 additions and 69 deletions
+82
View File
@@ -0,0 +1,82 @@
import json
import uuid
from typing import List, Optional
from pathlib import Path
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
from app.core.data_root import get_data_root
router = APIRouter()
def get_mcp_servers_file() -> Path:
return get_data_root() / "mcp_servers.json"
def read_mcp_servers() -> List[dict]:
file_path = get_mcp_servers_file()
if not file_path.exists():
return []
try:
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError:
return []
def write_mcp_servers(servers: List[dict]) -> None:
file_path = get_mcp_servers_file()
with open(file_path, "w", encoding="utf-8") as f:
json.dump(servers, f, indent=2, ensure_ascii=False)
@router.get("/mcp", response_model=List[MCPServer])
def list_mcp_servers(project_id: Optional[int] = None):
servers = read_mcp_servers()
if project_id is not None:
servers = [s for s in servers if s.get("project_id") == project_id]
return servers
@router.post("/mcp", response_model=MCPServer)
def create_mcp_server(server_in: MCPServerCreate):
servers = read_mcp_servers()
server_data = server_in.dict()
server_data["id"] = str(uuid.uuid4())
if "status" not in server_data or not server_data["status"]:
server_data["status"] = "disconnected"
servers.append(server_data)
write_mcp_servers(servers)
return server_data
@router.get("/mcp/{server_id}", response_model=MCPServer)
def get_mcp_server(server_id: str):
servers = read_mcp_servers()
for server in servers:
if server.get("id") == server_id:
return server
raise HTTPException(status_code=404, detail="MCP Server not found")
@router.put("/mcp/{server_id}", response_model=MCPServer)
def update_mcp_server(server_id: str, server_in: MCPServerUpdate):
servers = read_mcp_servers()
for i, server in enumerate(servers):
if server.get("id") == server_id:
update_data = server_in.dict(exclude_unset=True)
for key, value in update_data.items():
server[key] = value
servers[i] = server
write_mcp_servers(servers)
return server
raise HTTPException(status_code=404, detail="MCP Server not found")
@router.delete("/mcp/{server_id}")
def delete_mcp_server(server_id: str):
servers = read_mcp_servers()
filtered_servers = [s for s in servers if s.get("id") != server_id]
if len(servers) == len(filtered_servers):
raise HTTPException(status_code=404, detail="MCP Server not found")
write_mcp_servers(filtered_servers)
return {"status": "success"}
+55 -37
View File
@@ -44,7 +44,7 @@ class NanobotIntegration:
self.cron: CronService | None = None
self.config: Config | None = None
self._started = False
self._model_agent_cache: Dict[str, AgentLoop] = {}
self._model_agent_cache: Dict[tuple[str | None, int | None], AgentLoop] = {}
self._model_agent_lock = asyncio.Lock()
def initialize(self):
@@ -189,7 +189,7 @@ class NanobotIntegration:
self.cron.stop()
self._started = False
def _build_agent_for_provider(self, provider: Any) -> AgentLoop:
def _build_agent_for_provider(self, provider: Any, mcp_servers: dict | None = None) -> AgentLoop:
return AgentLoop(
bus=self.bus,
provider=provider,
@@ -205,26 +205,48 @@ class NanobotIntegration:
exec_config=self.config.tools.exec,
cron_service=self.cron,
restrict_to_workspace=self.config.tools.restrict_to_workspace,
session_manager=self.agent.sessions,
mcp_servers=self.config.tools.mcp_servers,
session_manager=self.agent.sessions if self.agent else None,
mcp_servers=mcp_servers if mcp_servers is not None else self.config.tools.mcp_servers,
channels_config=self.config.channels,
)
async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop:
async def _get_or_create_model_agent(self, model_id: str | None, target_config: Dict[str, Any] | None, project_id: int | None = None) -> AgentLoop:
cache_key = (model_id, project_id)
async with self._model_agent_lock:
cached = self._model_agent_cache.get(model_id)
cached = self._model_agent_cache.get(cache_key)
if cached:
return cached
provider = StreamingLiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=target_config.get("model"),
extra_headers=target_config.get("extra_headers"),
provider_name=target_config.get("provider"),
)
agent = self._build_agent_for_provider(provider)
if target_config:
provider = StreamingLiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=target_config.get("model"),
extra_headers=target_config.get("extra_headers"),
provider_name=target_config.get("provider"),
)
else:
provider = self._make_provider(self.config)
mcp_servers_dict = dict(self.config.tools.mcp_servers) if self.config.tools.mcp_servers else {}
if project_id is not None:
from app.api.mcp import list_mcp_servers
from nanobot.config.schema import MCPServerConfig
servers = list_mcp_servers(project_id=project_id)
for s in servers:
cfg = MCPServerConfig(
type=s.get("type"),
command=s.get("command") or "",
args=s.get("args") or [],
env=s.get("env") or {},
url=s.get("url") or "",
headers=s.get("headers") or {}
)
mcp_servers_dict[s["name"]] = cfg
agent = self._build_agent_for_provider(provider, mcp_servers=mcp_servers_dict)
self._register_custom_tools(agent)
self._model_agent_cache[model_id] = agent
self._model_agent_cache[cache_key] = agent
return agent
async def process_message(
@@ -233,6 +255,7 @@ class NanobotIntegration:
session_id: str = "api:default",
skill_ids: List[str] | None = None,
model_id: str | None = None,
project_id: int | None = None,
on_progress: Callable[[str], Awaitable[None]] | None = None,
):
if not self.agent:
@@ -240,32 +263,27 @@ class NanobotIntegration:
if not self._started:
await self.start()
# Handle dynamic model switching
# If model_id is provided, we need to fetch its config and create a temporary provider
# or update the current agent's provider context for this request.
# Since AgentLoop is stateful and tied to a provider, and we want to avoid recreating the whole agent for every request if possible,
# but changing the provider/model is a significant change.
#
# A simpler approach for this "stateless API" usage pattern:
# We can instantiate a lightweight version of the agent or provider just for this request if the model differs.
# OR, since we are using `process_direct`, we can check if `AgentLoop` supports overriding the model.
# Looking at `nanobot/agent/loop.py` (assumed), it uses `self.provider.completion(...)`.
# Strategy:
# 1. Load the model config from our JSON file using `model_id`.
# 2. Construct a temporary provider instance for this model.
# 3. Inject this provider into the agent for this request OR (cleaner) instantiate a temporary agent.
# Instantiating a whole AgentLoop might be heavy due to MCP/Cron etc.
# BUT `process_direct` is relatively isolated.
#
# Let's try to fetch the config first.
if project_id is None:
from app.core.session_alias_store import session_alias_store
alias_info = session_alias_store.get_alias(session_id)
if alias_info and alias_info.get("project_id"):
project_id = alias_info.get("project_id")
agent_to_use = self.agent
need_custom_agent = False
target_config = None
if model_id:
llm_configs = get_llm_configs()
target_config = next((item for item in llm_configs if item.get("id") == model_id), None)
if target_config:
if target_config.get("model") != self.agent.model:
agent_to_use = await self._get_or_create_model_agent(model_id, target_config)
if target_config and target_config.get("model") != self.agent.model:
need_custom_agent = True
if project_id is not None:
need_custom_agent = True
if need_custom_agent:
agent_to_use = await self._get_or_create_model_agent(model_id, target_config, project_id)
full_message = message
# We no longer inject the full skill content into the user's message here,
+30
View File
@@ -0,0 +1,30 @@
from typing import List, Dict, Optional, Literal
from pydantic import BaseModel, Field
class MCPServerBase(BaseModel):
name: str
type: Literal["stdio", "sse", "streamableHttp"]
command: Optional[str] = None
args: Optional[List[str]] = Field(default_factory=list)
env: Optional[Dict[str, str]] = Field(default_factory=dict)
url: Optional[str] = None
headers: Optional[Dict[str, str]] = Field(default_factory=dict)
project_id: int
status: str = "disconnected"
class MCPServerCreate(MCPServerBase):
pass
class MCPServerUpdate(BaseModel):
name: Optional[str] = None
type: Optional[Literal["stdio", "sse", "streamableHttp"]] = None
command: Optional[str] = None
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
url: Optional[str] = None
headers: Optional[Dict[str, str]] = None
project_id: Optional[int] = None
status: Optional[str] = None
class MCPServer(MCPServerBase):
id: str