feat: add MCP
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.schemas.mcp import MCPServer, MCPServerCreate, MCPServerUpdate
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def get_mcp_servers_file() -> Path:
|
||||
return get_data_root() / "mcp_servers.json"
|
||||
|
||||
def read_mcp_servers() -> List[dict]:
|
||||
file_path = get_mcp_servers_file()
|
||||
if not file_path.exists():
|
||||
return []
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
def write_mcp_servers(servers: List[dict]) -> None:
|
||||
file_path = get_mcp_servers_file()
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(servers, f, indent=2, ensure_ascii=False)
|
||||
|
||||
@router.get("/mcp", response_model=List[MCPServer])
|
||||
def list_mcp_servers(project_id: Optional[int] = None):
|
||||
servers = read_mcp_servers()
|
||||
if project_id is not None:
|
||||
servers = [s for s in servers if s.get("project_id") == project_id]
|
||||
return servers
|
||||
|
||||
@router.post("/mcp", response_model=MCPServer)
|
||||
def create_mcp_server(server_in: MCPServerCreate):
|
||||
servers = read_mcp_servers()
|
||||
|
||||
server_data = server_in.dict()
|
||||
server_data["id"] = str(uuid.uuid4())
|
||||
if "status" not in server_data or not server_data["status"]:
|
||||
server_data["status"] = "disconnected"
|
||||
|
||||
servers.append(server_data)
|
||||
write_mcp_servers(servers)
|
||||
return server_data
|
||||
|
||||
@router.get("/mcp/{server_id}", response_model=MCPServer)
|
||||
def get_mcp_server(server_id: str):
|
||||
servers = read_mcp_servers()
|
||||
for server in servers:
|
||||
if server.get("id") == server_id:
|
||||
return server
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
@router.put("/mcp/{server_id}", response_model=MCPServer)
|
||||
def update_mcp_server(server_id: str, server_in: MCPServerUpdate):
|
||||
servers = read_mcp_servers()
|
||||
for i, server in enumerate(servers):
|
||||
if server.get("id") == server_id:
|
||||
update_data = server_in.dict(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
server[key] = value
|
||||
servers[i] = server
|
||||
write_mcp_servers(servers)
|
||||
return server
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
@router.delete("/mcp/{server_id}")
|
||||
def delete_mcp_server(server_id: str):
|
||||
servers = read_mcp_servers()
|
||||
filtered_servers = [s for s in servers if s.get("id") != server_id]
|
||||
|
||||
if len(servers) == len(filtered_servers):
|
||||
raise HTTPException(status_code=404, detail="MCP Server not found")
|
||||
|
||||
write_mcp_servers(filtered_servers)
|
||||
return {"status": "success"}
|
||||
+55
-37
@@ -44,7 +44,7 @@ class NanobotIntegration:
|
||||
self.cron: CronService | None = None
|
||||
self.config: Config | None = None
|
||||
self._started = False
|
||||
self._model_agent_cache: Dict[str, AgentLoop] = {}
|
||||
self._model_agent_cache: Dict[tuple[str | None, int | None], AgentLoop] = {}
|
||||
self._model_agent_lock = asyncio.Lock()
|
||||
|
||||
def initialize(self):
|
||||
@@ -189,7 +189,7 @@ class NanobotIntegration:
|
||||
self.cron.stop()
|
||||
self._started = False
|
||||
|
||||
def _build_agent_for_provider(self, provider: Any) -> AgentLoop:
|
||||
def _build_agent_for_provider(self, provider: Any, mcp_servers: dict | None = None) -> AgentLoop:
|
||||
return AgentLoop(
|
||||
bus=self.bus,
|
||||
provider=provider,
|
||||
@@ -205,26 +205,48 @@ class NanobotIntegration:
|
||||
exec_config=self.config.tools.exec,
|
||||
cron_service=self.cron,
|
||||
restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
session_manager=self.agent.sessions,
|
||||
mcp_servers=self.config.tools.mcp_servers,
|
||||
session_manager=self.agent.sessions if self.agent else None,
|
||||
mcp_servers=mcp_servers if mcp_servers is not None else self.config.tools.mcp_servers,
|
||||
channels_config=self.config.channels,
|
||||
)
|
||||
|
||||
async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop:
|
||||
async def _get_or_create_model_agent(self, model_id: str | None, target_config: Dict[str, Any] | None, project_id: int | None = None) -> AgentLoop:
|
||||
cache_key = (model_id, project_id)
|
||||
async with self._model_agent_lock:
|
||||
cached = self._model_agent_cache.get(model_id)
|
||||
cached = self._model_agent_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
provider = StreamingLiteLLMProvider(
|
||||
api_key=target_config.get("api_key"),
|
||||
api_base=target_config.get("api_base"),
|
||||
default_model=target_config.get("model"),
|
||||
extra_headers=target_config.get("extra_headers"),
|
||||
provider_name=target_config.get("provider"),
|
||||
)
|
||||
agent = self._build_agent_for_provider(provider)
|
||||
|
||||
if target_config:
|
||||
provider = StreamingLiteLLMProvider(
|
||||
api_key=target_config.get("api_key"),
|
||||
api_base=target_config.get("api_base"),
|
||||
default_model=target_config.get("model"),
|
||||
extra_headers=target_config.get("extra_headers"),
|
||||
provider_name=target_config.get("provider"),
|
||||
)
|
||||
else:
|
||||
provider = self._make_provider(self.config)
|
||||
|
||||
mcp_servers_dict = dict(self.config.tools.mcp_servers) if self.config.tools.mcp_servers else {}
|
||||
if project_id is not None:
|
||||
from app.api.mcp import list_mcp_servers
|
||||
from nanobot.config.schema import MCPServerConfig
|
||||
servers = list_mcp_servers(project_id=project_id)
|
||||
for s in servers:
|
||||
cfg = MCPServerConfig(
|
||||
type=s.get("type"),
|
||||
command=s.get("command") or "",
|
||||
args=s.get("args") or [],
|
||||
env=s.get("env") or {},
|
||||
url=s.get("url") or "",
|
||||
headers=s.get("headers") or {}
|
||||
)
|
||||
mcp_servers_dict[s["name"]] = cfg
|
||||
|
||||
agent = self._build_agent_for_provider(provider, mcp_servers=mcp_servers_dict)
|
||||
self._register_custom_tools(agent)
|
||||
self._model_agent_cache[model_id] = agent
|
||||
self._model_agent_cache[cache_key] = agent
|
||||
return agent
|
||||
|
||||
async def process_message(
|
||||
@@ -233,6 +255,7 @@ class NanobotIntegration:
|
||||
session_id: str = "api:default",
|
||||
skill_ids: List[str] | None = None,
|
||||
model_id: str | None = None,
|
||||
project_id: int | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
):
|
||||
if not self.agent:
|
||||
@@ -240,32 +263,27 @@ class NanobotIntegration:
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
# Handle dynamic model switching
|
||||
# If model_id is provided, we need to fetch its config and create a temporary provider
|
||||
# or update the current agent's provider context for this request.
|
||||
# Since AgentLoop is stateful and tied to a provider, and we want to avoid recreating the whole agent for every request if possible,
|
||||
# but changing the provider/model is a significant change.
|
||||
#
|
||||
# A simpler approach for this "stateless API" usage pattern:
|
||||
# We can instantiate a lightweight version of the agent or provider just for this request if the model differs.
|
||||
# OR, since we are using `process_direct`, we can check if `AgentLoop` supports overriding the model.
|
||||
# Looking at `nanobot/agent/loop.py` (assumed), it uses `self.provider.completion(...)`.
|
||||
|
||||
# Strategy:
|
||||
# 1. Load the model config from our JSON file using `model_id`.
|
||||
# 2. Construct a temporary provider instance for this model.
|
||||
# 3. Inject this provider into the agent for this request OR (cleaner) instantiate a temporary agent.
|
||||
# Instantiating a whole AgentLoop might be heavy due to MCP/Cron etc.
|
||||
# BUT `process_direct` is relatively isolated.
|
||||
#
|
||||
# Let's try to fetch the config first.
|
||||
if project_id is None:
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
alias_info = session_alias_store.get_alias(session_id)
|
||||
if alias_info and alias_info.get("project_id"):
|
||||
project_id = alias_info.get("project_id")
|
||||
|
||||
agent_to_use = self.agent
|
||||
need_custom_agent = False
|
||||
target_config = None
|
||||
|
||||
if model_id:
|
||||
llm_configs = get_llm_configs()
|
||||
target_config = next((item for item in llm_configs if item.get("id") == model_id), None)
|
||||
if target_config:
|
||||
if target_config.get("model") != self.agent.model:
|
||||
agent_to_use = await self._get_or_create_model_agent(model_id, target_config)
|
||||
if target_config and target_config.get("model") != self.agent.model:
|
||||
need_custom_agent = True
|
||||
|
||||
if project_id is not None:
|
||||
need_custom_agent = True
|
||||
|
||||
if need_custom_agent:
|
||||
agent_to_use = await self._get_or_create_model_agent(model_id, target_config, project_id)
|
||||
|
||||
full_message = message
|
||||
# We no longer inject the full skill content into the user's message here,
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import List, Dict, Optional, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class MCPServerBase(BaseModel):
|
||||
name: str
|
||||
type: Literal["stdio", "sse", "streamableHttp"]
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = Field(default_factory=list)
|
||||
env: Optional[Dict[str, str]] = Field(default_factory=dict)
|
||||
url: Optional[str] = None
|
||||
headers: Optional[Dict[str, str]] = Field(default_factory=dict)
|
||||
project_id: int
|
||||
status: str = "disconnected"
|
||||
|
||||
class MCPServerCreate(MCPServerBase):
|
||||
pass
|
||||
|
||||
class MCPServerUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
type: Optional[Literal["stdio", "sse", "streamableHttp"]] = None
|
||||
command: Optional[str] = None
|
||||
args: Optional[List[str]] = None
|
||||
env: Optional[Dict[str, str]] = None
|
||||
url: Optional[str] = None
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
project_id: Optional[int] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
class MCPServer(MCPServerBase):
|
||||
id: str
|
||||
+2
-1
@@ -16,7 +16,7 @@ import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.artifacts import extract_artifacts
|
||||
@@ -59,6 +59,7 @@ app.include_router(users.router, prefix="/api/v1")
|
||||
app.include_router(projects.router, prefix="/api/v1")
|
||||
app.include_router(datasources.router, prefix="/api/v1")
|
||||
app.include_router(semantic.router, prefix="/api/v1")
|
||||
app.include_router(mcp.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
PREVIEWABLE_TEXT_EXTENSIONS = {
|
||||
|
||||
Submodule
+1
Submodule backend/mcp-sse added at 64f9400214
Reference in New Issue
Block a user