First build
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
- [x] Directory structure is correct (`backend/`, `frontend/`, `nanobot/`).
|
||||
- [x] FastAPI server running.
|
||||
- [x] React frontend running.
|
||||
- [x] LLM Configuration UI works.
|
||||
- [x] Users can select configured LLM for agents.
|
||||
- [x] Data Sources (PG, CH, MinIO, CSV) connect successfully.
|
||||
- [x] NL2SQL generates correct SQL.
|
||||
- [x] "View SQL" button works.
|
||||
- [x] "Add to Dashboard" button works.
|
||||
- [x] Dashboard displays pinned charts.
|
||||
- [x] Dashboard panels are resizable and draggable (Grafana-like).
|
||||
- [x] Internal Skills CRUD works.
|
||||
- [x] Users can select specific skills for a chat session.
|
||||
- [x] Selected skills are used by the agent.
|
||||
- [x] Multi-agent workflow functions as expected.
|
||||
@@ -0,0 +1,89 @@
|
||||
# Data Analysis Platform Spec
|
||||
|
||||
## Why
|
||||
Currently, users need a unified platform to perform data analysis using natural language across multiple data sources (PostgreSQL, ClickHouse, MinIO, CSV). Existing solutions may lack the specific agentic capabilities, skill extensibility, or the desired user experience (WrenAI-like). Building this platform will democratize data access and analysis for non-technical users through LLM-powered SQL generation and agent-based workflows.
|
||||
|
||||
## What Changes
|
||||
- **Directory Structure**:
|
||||
- `frontend/`: React application.
|
||||
- `backend/`: FastAPI application.
|
||||
- `nanobot/`: Existing nanobot source code (to be integrated/referenced by backend).
|
||||
- **Backend Architecture**:
|
||||
- Implement a FastAPI server in `backend/` to handle API requests.
|
||||
- Integrate `nanobot` framework for agent management, session handling, and memory.
|
||||
- Implement connectors for PostgreSQL, ClickHouse, MinIO, and CSV file handling.
|
||||
- Implement NL2SQL logic using LLM.
|
||||
- Implement internal Skills management system.
|
||||
- Implement LLM Integration Module for custom model configuration.
|
||||
- **Frontend Architecture**:
|
||||
- Create a React application in `frontend/`.
|
||||
- Implement a layout inspired by WrenAI (Sidebar for threads, Main chat area, Visualization pane).
|
||||
- **New**: Implement a dynamic Dashboard view with resizable/draggable panels (Grafana-like).
|
||||
- Implement a Skills management interface.
|
||||
- Implement an LLM Configuration interface (Settings page).
|
||||
- **Agent System**:
|
||||
- Configure `nanobot` to support multiple agents.
|
||||
- Allow users to explicitly select skills for an agent to use.
|
||||
- Allow users to select which LLM model to use for the agent.
|
||||
|
||||
## Impact
|
||||
- **Affected specs**: N/A (New Project)
|
||||
- **Affected code**: `backend/`, `frontend/`, and integration with `nanobot`.
|
||||
|
||||
## ADDED Requirements
|
||||
|
||||
### Requirement: LLM Integration Module
|
||||
The system SHALL:
|
||||
- Allow users to configure multiple LLM providers (e.g., OpenAI, Anthropic, Custom/Local via OpenAI-compatible API).
|
||||
- Allow users to set API Keys, Base URLs, and Model Names.
|
||||
- Persist these configurations.
|
||||
- Allow users to select which model/provider to use for the agent/NL2SQL tasks.
|
||||
|
||||
### Requirement: Data Source Connectivity
|
||||
The system SHALL allow users to connect to:
|
||||
- PostgreSQL databases.
|
||||
- ClickHouse databases.
|
||||
- MinIO object storage.
|
||||
- Upload CSV files directly.
|
||||
|
||||
### Requirement: Natural Language to SQL (NL2SQL) & Visualization
|
||||
The system SHALL:
|
||||
- Accept natural language queries from users.
|
||||
- Use an LLM to convert these queries into executable SQL.
|
||||
- Execute the SQL and return results.
|
||||
- Visualize the results (charts/tables).
|
||||
- Provide a "View SQL" button for each chart/result to show the underlying SQL.
|
||||
- Provide an "Add to Dashboard" button for each chart to pin it to a global Dashboard view.
|
||||
|
||||
### Requirement: Interactive Dashboard (Grafana-like)
|
||||
The system SHALL:
|
||||
- Provide a Dashboard view where pinned charts are displayed as panels.
|
||||
- Allow users to resize and drag panels to customize the layout (Grid layout).
|
||||
- Persist the dashboard layout.
|
||||
|
||||
### Requirement: Internal Skills Management
|
||||
The system SHALL:
|
||||
- Allow users to define custom skills (name, description, instructions/code).
|
||||
- Store these skills in the system.
|
||||
- Allow users to view, edit, and delete skills.
|
||||
- Allow users to select specific skills to be active for a conversation or agent.
|
||||
|
||||
### Requirement: Multi-Agent Support
|
||||
The system SHALL:
|
||||
- Support multiple specialized agents (e.g., Data Analyst, SQL Generator).
|
||||
- Use `nanobot` for orchestrating these agents.
|
||||
|
||||
### Requirement: User Interface (WrenAI Style)
|
||||
The system SHALL:
|
||||
- Have a sidebar for managing chat threads/history.
|
||||
- Have a main chat interface.
|
||||
- Have a dedicated area for displaying data visualizations.
|
||||
- Have a "Dashboard" page/view with resizable panels.
|
||||
- Have a "Skills" page/view to manage custom skills.
|
||||
- Have a "Settings" page/view to manage LLM configurations.
|
||||
|
||||
## MODIFIED Requirements
|
||||
N/A
|
||||
|
||||
## REMOVED Requirements
|
||||
N/A
|
||||
@@ -0,0 +1,47 @@
|
||||
# Tasks
|
||||
|
||||
- [x] Task 1: Project Initialization & Structure
|
||||
- [x] SubTask 1.1: Verify/Create `backend` and `frontend` directories.
|
||||
- [x] SubTask 1.2: Ensure `nanobot` source code is correctly placed/linked.
|
||||
- [x] SubTask 1.3: Set up FastAPI in `backend/` and React in `frontend/`.
|
||||
|
||||
- [x] Task 2: Backend - Core & Data Sources
|
||||
- [x] SubTask 2.1: Configure `nanobot` integration within FastAPI.
|
||||
- [x] SubTask 2.2: Implement PostgreSQL connector.
|
||||
- [x] SubTask 2.3: Implement ClickHouse connector.
|
||||
- [x] SubTask 2.4: Implement MinIO connector.
|
||||
- [x] SubTask 2.5: Implement CSV upload handling.
|
||||
|
||||
- [x] Task 3: Backend - Agent, Skills & LLM
|
||||
- [x] SubTask 3.1: Implement LLM Configuration API (CRUD for providers/models).
|
||||
- [x] SubTask 3.2: Implement NL2SQL agent logic using `nanobot` (using configured LLM).
|
||||
- [x] SubTask 3.3: Implement Internal Skills CRUD API.
|
||||
- [x] SubTask 3.4: Implement Skill Selection logic.
|
||||
|
||||
- [x] Task 4: Frontend - Core & UI Components
|
||||
- [x] SubTask 4.1: Setup React project with Tailwind/Shadcn.
|
||||
- [x] SubTask 4.2: Implement Sidebar (Threads/History).
|
||||
- [x] SubTask 4.3: Implement Main Chat Interface.
|
||||
- [x] SubTask 4.4: Implement Visualization Component (Charts/Tables).
|
||||
- [x] SubTask 4.5: Implement "View SQL" button and modal/popover.
|
||||
|
||||
- [x] Task 5: Frontend - Dashboard & Management
|
||||
- [x] SubTask 5.1: Implement Dashboard Page with Grid Layout (using `react-grid-layout` or similar).
|
||||
- [x] SubTask 5.2: Implement "Add to Dashboard" functionality (persist chart config to dashboard state).
|
||||
- [x] SubTask 5.3: Implement Skills Management UI (List/Edit).
|
||||
- [x] SubTask 5.4: Implement LLM Settings UI (Configure providers/keys).
|
||||
- [x] SubTask 5.5: Implement Skill Selection selector in Chat interface.
|
||||
|
||||
- [x] Task 6: Integration & Polish
|
||||
- [x] SubTask 6.1: Connect Frontend to Backend.
|
||||
- [x] SubTask 6.2: Test LLM Configuration (add provider -> use in chat).
|
||||
- [x] SubTask 6.3: Test NL2SQL flow with SQL view and Dashboard pinning.
|
||||
- [x] SubTask 6.4: Test Dashboard interactivity (resize/drag panels).
|
||||
- [x] SubTask 6.5: Test Skill creation and usage.
|
||||
- [x] SubTask 6.6: Verify multi-agent coordination.
|
||||
|
||||
# Task Dependencies
|
||||
- Task 3 depends on Task 1 and Task 2.
|
||||
- Task 4 depends on Task 1.
|
||||
- Task 5 depends on Task 3 and Task 4.
|
||||
- Task 6 depends on all previous tasks.
|
||||
@@ -0,0 +1,27 @@
|
||||
# DataClaw
|
||||
|
||||
Data Analysis Platform.
|
||||
|
||||
## Structure
|
||||
|
||||
- `backend/`: FastAPI backend
|
||||
- `frontend/`: React frontend (Vite + TailwindCSS + Shadcn UI)
|
||||
- `nanobot/`: Core AI agent framework
|
||||
|
||||
## Setup
|
||||
|
||||
### Backend
|
||||
|
||||
```bash
|
||||
cd backend
|
||||
pip install -r requirements.txt
|
||||
uvicorn main:app --reload
|
||||
```
|
||||
|
||||
### Frontend
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
@@ -0,0 +1 @@
|
||||
3.11
|
||||
@@ -0,0 +1,106 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.append(str(PROJECT_ROOT))
|
||||
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse)")
|
||||
|
||||
class NL2SQLResponse(BaseModel):
|
||||
sql: str
|
||||
result: List[Dict[str, Any]]
|
||||
error: Optional[str] = None
|
||||
|
||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
# 1. Get the connector and schema
|
||||
connector = None
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
connector = clickhouse_connector
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
if not connector.test_connection():
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
|
||||
schema = connector.get_schema()
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
|
||||
if not active_config:
|
||||
return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found")
|
||||
|
||||
# 3. Initialize Provider
|
||||
try:
|
||||
provider = LiteLLMProvider(
|
||||
api_key=active_config.get("api_key"),
|
||||
api_base=active_config.get("api_base"),
|
||||
default_model=active_config.get("model"),
|
||||
extra_headers=active_config.get("extra_headers")
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to initialize LLM provider: {e}")
|
||||
|
||||
# 4. Construct Prompt
|
||||
prompt = f"""You are an expert SQL generator.
|
||||
Given the following database schema for a {request.source} database:
|
||||
{schema_str}
|
||||
|
||||
Write a SQL query to answer the following question:
|
||||
"{request.query}"
|
||||
|
||||
Return ONLY the SQL query. Do not include any markdown formatting, explanations, or code blocks. Just the raw SQL string.
|
||||
"""
|
||||
|
||||
# 5. Call LLM
|
||||
try:
|
||||
# provider.complete returns a string
|
||||
response = await provider.complete(prompt)
|
||||
sql_query = response.strip()
|
||||
# Remove potential markdown code blocks if the LLM ignores instructions
|
||||
if sql_query.startswith("```sql"):
|
||||
sql_query = sql_query[6:]
|
||||
if sql_query.startswith("```"):
|
||||
sql_query = sql_query[3:]
|
||||
if sql_query.endswith("```"):
|
||||
sql_query = sql_query[:-3]
|
||||
sql_query = sql_query.strip()
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
# ClickHouse returns list of tuples, we need column names
|
||||
# But execute_query in ClickHouseConnector just returns raw results from client.execute
|
||||
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
|
||||
# Actually, without column names it's hard to format as dict.
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Body
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
DATA_FILE = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "llm_config.json")
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier for the LLM configuration")
|
||||
provider: str = Field(..., description="Provider name (e.g., openai, azure, anthropic)")
|
||||
model: str = Field(..., description="Model name (e.g., gpt-4, claude-3-opus)")
|
||||
api_key: Optional[str] = Field(None, description="API Key for the provider")
|
||||
api_base: Optional[str] = Field(None, description="Base URL for the API")
|
||||
extra_headers: Optional[Dict[str, str]] = Field(None, description="Extra headers for the request")
|
||||
is_active: bool = Field(True, description="Whether this configuration is active")
|
||||
|
||||
class LLMConfigCreate(BaseModel):
|
||||
id: str
|
||||
provider: str
|
||||
model: str
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
is_active: bool = True
|
||||
|
||||
class LLMConfigUpdate(BaseModel):
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
api_base: Optional[str] = None
|
||||
extra_headers: Optional[Dict[str, str]] = None
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
def _load_data() -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(DATA_FILE):
|
||||
return []
|
||||
try:
|
||||
with open(DATA_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
def _save_data(data: List[Dict[str, Any]]):
|
||||
os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
|
||||
with open(DATA_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
@router.get("/llm", response_model=List[LLMConfig])
|
||||
def list_llm_configs():
|
||||
data = _load_data()
|
||||
return [LLMConfig(**item) for item in data]
|
||||
|
||||
@router.get("/llm/{config_id}", response_model=LLMConfig)
|
||||
def get_llm_config(config_id: str):
|
||||
data = _load_data()
|
||||
for item in data:
|
||||
if item["id"] == config_id:
|
||||
return LLMConfig(**item)
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
@router.post("/llm", response_model=LLMConfig)
|
||||
def create_llm_config(config: LLMConfigCreate):
|
||||
data = _load_data()
|
||||
if any(item["id"] == config.id for item in data):
|
||||
raise HTTPException(status_code=400, detail="LLM configuration with this ID already exists")
|
||||
|
||||
new_config = config.dict()
|
||||
data.append(new_config)
|
||||
_save_data(data)
|
||||
return LLMConfig(**new_config)
|
||||
|
||||
@router.put("/llm/{config_id}", response_model=LLMConfig)
|
||||
def update_llm_config(config_id: str, config: LLMConfigUpdate):
|
||||
data = _load_data()
|
||||
for i, item in enumerate(data):
|
||||
if item["id"] == config_id:
|
||||
updated_item = item.copy()
|
||||
update_data = config.dict(exclude_unset=True)
|
||||
updated_item.update(update_data)
|
||||
data[i] = updated_item
|
||||
_save_data(data)
|
||||
return LLMConfig(**updated_item)
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
|
||||
@router.delete("/llm/{config_id}")
|
||||
def delete_llm_config(config_id: str):
|
||||
data = _load_data()
|
||||
initial_len = len(data)
|
||||
data = [item for item in data if item["id"] != config_id]
|
||||
if len(data) == initial_len:
|
||||
raise HTTPException(status_code=404, detail="LLM configuration not found")
|
||||
_save_data(data)
|
||||
return {"message": "LLM configuration deleted successfully"}
|
||||
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
import os
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
DATA_FILE = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "data", "skills.json")
|
||||
|
||||
class Skill(BaseModel):
|
||||
id: str = Field(..., description="Unique identifier for the skill")
|
||||
name: str = Field(..., description="Name of the skill")
|
||||
description: Optional[str] = Field(None, description="Description of what the skill does")
|
||||
content: str = Field(..., description="The content/prompt/logic of the skill")
|
||||
type: str = Field("python", description="Type of the skill (python, sql, api)")
|
||||
|
||||
class SkillCreate(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
content: str
|
||||
type: str = "python"
|
||||
|
||||
class SkillUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
type: Optional[str] = None
|
||||
|
||||
def _load_data() -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(DATA_FILE):
|
||||
return []
|
||||
try:
|
||||
with open(DATA_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
return []
|
||||
|
||||
def _save_data(data: List[Dict[str, Any]]):
|
||||
os.makedirs(os.path.dirname(DATA_FILE), exist_ok=True)
|
||||
with open(DATA_FILE, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def load_skills() -> List[Dict[str, Any]]:
|
||||
return _load_data()
|
||||
|
||||
@router.get("/skills", response_model=List[Skill])
|
||||
def list_skills():
|
||||
data = load_skills()
|
||||
return [Skill(**item) for item in data]
|
||||
|
||||
@router.get("/skills/{skill_id}", response_model=Skill)
|
||||
def get_skill(skill_id: str):
|
||||
data = _load_data()
|
||||
for item in data:
|
||||
if item["id"] == skill_id:
|
||||
return Skill(**item)
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
|
||||
@router.post("/skills", response_model=Skill)
|
||||
def create_skill(skill: SkillCreate):
|
||||
data = _load_data()
|
||||
if any(item["id"] == skill.id for item in data):
|
||||
raise HTTPException(status_code=400, detail="Skill with this ID already exists")
|
||||
|
||||
new_skill = skill.dict()
|
||||
data.append(new_skill)
|
||||
_save_data(data)
|
||||
return Skill(**new_skill)
|
||||
|
||||
@router.put("/skills/{skill_id}", response_model=Skill)
|
||||
def update_skill(skill_id: str, skill: SkillUpdate):
|
||||
data = _load_data()
|
||||
for i, item in enumerate(data):
|
||||
if item["id"] == skill_id:
|
||||
updated_item = item.copy()
|
||||
update_data = skill.dict(exclude_unset=True)
|
||||
updated_item.update(update_data)
|
||||
data[i] = updated_item
|
||||
_save_data(data)
|
||||
return Skill(**updated_item)
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
|
||||
@router.delete("/skills/{skill_id}")
|
||||
def delete_skill(skill_id: str):
|
||||
data = _load_data()
|
||||
initial_len = len(data)
|
||||
data = [item for item in data if item["id"] != skill_id]
|
||||
if len(data) == initial_len:
|
||||
raise HTTPException(status_code=404, detail="Skill not found")
|
||||
_save_data(data)
|
||||
return {"message": "Skill deleted successfully"}
|
||||
@@ -0,0 +1,53 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, BackgroundTasks
|
||||
from app.connectors.minio import minio_connector
|
||||
import pandas as pd
|
||||
import duckdb
|
||||
import io
|
||||
import uuid
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/upload/csv")
|
||||
async def upload_csv(file: UploadFile = File(...), background_tasks: BackgroundTasks = None):
|
||||
if not file.filename.endswith('.csv'):
|
||||
raise HTTPException(status_code=400, detail="Invalid file type. Only CSV allowed.")
|
||||
|
||||
try:
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
file_obj = io.BytesIO(content)
|
||||
|
||||
# Generate a unique filename
|
||||
unique_filename = f"{uuid.uuid4()}-{file.filename}"
|
||||
|
||||
# Upload to MinIO
|
||||
minio_url = minio_connector.upload_file(unique_filename, file_obj, file_size, content_type="text/csv")
|
||||
|
||||
# Reset file pointer for analysis
|
||||
file_obj.seek(0)
|
||||
|
||||
# Load into DuckDB (in-memory) for quick analysis
|
||||
try:
|
||||
df = pd.read_csv(file_obj)
|
||||
duckdb_conn = duckdb.connect(database=':memory:')
|
||||
duckdb_conn.register('uploaded_csv', df)
|
||||
summary = duckdb_conn.execute("DESCRIBE uploaded_csv").fetchall()
|
||||
row_count = len(df)
|
||||
columns = list(df.columns)
|
||||
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"url": minio_url,
|
||||
"rows": row_count,
|
||||
"columns": columns,
|
||||
"summary": str(summary)
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"filename": unique_filename,
|
||||
"url": minio_url,
|
||||
"analysis_error": str(e)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,50 @@
|
||||
from clickhouse_driver import Client
|
||||
import os
|
||||
|
||||
class ClickHouseConnector:
|
||||
def __init__(self, host: str = None, port: int = 9000, user: str = 'default', password: str = '', database: str = 'default'):
|
||||
self.host = host or os.getenv("CLICKHOUSE_HOST", "localhost")
|
||||
self.port = port or int(os.getenv("CLICKHOUSE_PORT", 9000))
|
||||
self.user = user or os.getenv("CLICKHOUSE_USER", "default")
|
||||
self.password = password or os.getenv("CLICKHOUSE_PASSWORD", "")
|
||||
self.database = database or os.getenv("CLICKHOUSE_DB", "default")
|
||||
|
||||
self.client = Client(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
user=self.user,
|
||||
password=self.password,
|
||||
database=self.database
|
||||
)
|
||||
|
||||
def execute_query(self, query: str):
|
||||
try:
|
||||
return self.client.execute(query)
|
||||
except Exception as e:
|
||||
print(f"ClickHouse Query Error: {e}")
|
||||
raise e
|
||||
|
||||
def get_schema(self):
|
||||
query = "SELECT table, name, type FROM system.columns WHERE database = currentDatabase()"
|
||||
try:
|
||||
results = self.client.execute(query)
|
||||
schema = {}
|
||||
for row in results:
|
||||
table = row[0]
|
||||
if table not in schema:
|
||||
schema[table] = []
|
||||
schema[table].append(f"{row[1]} ({row[2]})")
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
return {}
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
self.client.execute("SELECT 1")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"ClickHouse Connection Error: {e}")
|
||||
return False
|
||||
|
||||
clickhouse_connector = ClickHouseConnector()
|
||||
@@ -0,0 +1,51 @@
|
||||
from minio import Minio
|
||||
from minio.error import S3Error
|
||||
import os
|
||||
from typing import BinaryIO
|
||||
|
||||
class MinioConnector:
|
||||
def __init__(self):
|
||||
self.endpoint = os.getenv("MINIO_ENDPOINT", "localhost:9000")
|
||||
self.access_key = os.getenv("MINIO_ACCESS_KEY", "minioadmin")
|
||||
self.secret_key = os.getenv("MINIO_SECRET_KEY", "minioadmin")
|
||||
self.secure = os.getenv("MINIO_SECURE", "False").lower() == "true"
|
||||
self.bucket_name = os.getenv("MINIO_BUCKET", "dataclaw")
|
||||
|
||||
self.client = Minio(
|
||||
self.endpoint,
|
||||
access_key=self.access_key,
|
||||
secret_key=self.secret_key,
|
||||
secure=self.secure
|
||||
)
|
||||
self._ensure_bucket_exists()
|
||||
|
||||
def _ensure_bucket_exists(self):
|
||||
try:
|
||||
if not self.client.bucket_exists(self.bucket_name):
|
||||
self.client.make_bucket(self.bucket_name)
|
||||
except S3Error as e:
|
||||
print(f"MinIO Bucket Error: {e}")
|
||||
|
||||
def upload_file(self, object_name: str, file_data: BinaryIO, length: int, content_type: str = "application/octet-stream"):
|
||||
try:
|
||||
self.client.put_object(
|
||||
self.bucket_name,
|
||||
object_name,
|
||||
file_data,
|
||||
length,
|
||||
content_type=content_type
|
||||
)
|
||||
return f"http{'s' if self.secure else ''}://{self.endpoint}/{self.bucket_name}/{object_name}"
|
||||
except S3Error as e:
|
||||
print(f"MinIO Upload Error: {e}")
|
||||
raise e
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
self.client.list_buckets()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"MinIO Connection Error: {e}")
|
||||
return False
|
||||
|
||||
minio_connector = MinioConnector()
|
||||
@@ -0,0 +1,53 @@
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import Generator
|
||||
import os
|
||||
|
||||
class PostgresConnector:
|
||||
def __init__(self, db_url: str = None):
|
||||
self.db_url = db_url or os.getenv("POSTGRES_URL", "postgresql://user:password@localhost:5432/dbname")
|
||||
self.engine = create_engine(self.db_url)
|
||||
self.SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
|
||||
|
||||
def get_db(self) -> Generator:
|
||||
db = self.SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def execute_query(self, query: str):
|
||||
with self.engine.connect() as connection:
|
||||
result = connection.execute(text(query))
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
def get_schema(self):
|
||||
query = """
|
||||
SELECT table_name, column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
ORDER BY table_name, ordinal_position;
|
||||
"""
|
||||
try:
|
||||
results = self.execute_query(query)
|
||||
schema = {}
|
||||
for row in results:
|
||||
table = row['table_name']
|
||||
if table not in schema:
|
||||
schema[table] = []
|
||||
schema[table].append(f"{row['column_name']} ({row['data_type']})")
|
||||
return schema
|
||||
except Exception as e:
|
||||
print(f"Error getting schema: {e}")
|
||||
return {}
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
with self.engine.connect() as connection:
|
||||
connection.execute(text("SELECT 1"))
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"PostgreSQL Connection Error: {e}")
|
||||
return False
|
||||
|
||||
postgres_connector = PostgresConnector()
|
||||
@@ -0,0 +1,149 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
# Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root
|
||||
# This path calculation seems correct for backend/app/core/nanobot.py relative to backend/
|
||||
# BUT nanobot package is in ../nanobot relative to backend/
|
||||
# So we need to go up one more level to reach the parent of backend/
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT / "nanobot") not in sys.path:
|
||||
sys.path.append(str(PROJECT_ROOT / "nanobot"))
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.session.manager import SessionManager
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
# Import skills loader
|
||||
# We use a lazy import inside the method to avoid potential circular dependencies if any arise,
|
||||
# or just import here if we are confident.
|
||||
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
|
||||
from app.api.skills import load_skills
|
||||
|
||||
class NanobotIntegration:
|
||||
def __init__(self):
|
||||
self.agent: AgentLoop | None = None
|
||||
self.bus: MessageBus | None = None
|
||||
self.cron: CronService | None = None
|
||||
self.config: Config | None = None
|
||||
|
||||
def initialize(self):
|
||||
self.config = load_config()
|
||||
self.bus = MessageBus()
|
||||
provider = self._make_provider(self.config)
|
||||
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
self.cron = CronService(cron_store_path)
|
||||
|
||||
session_manager = SessionManager(self.config.workspace_path)
|
||||
|
||||
self.agent = AgentLoop(
|
||||
bus=self.bus,
|
||||
provider=provider,
|
||||
workspace=self.config.workspace_path,
|
||||
model=self.config.agents.defaults.model,
|
||||
temperature=self.config.agents.defaults.temperature,
|
||||
max_tokens=self.config.agents.defaults.max_tokens,
|
||||
max_iterations=self.config.agents.defaults.max_tool_iterations,
|
||||
memory_window=self.config.agents.defaults.memory_window,
|
||||
reasoning_effort=self.config.agents.defaults.reasoning_effort,
|
||||
brave_api_key=self.config.tools.web.search.api_key or None,
|
||||
web_proxy=self.config.tools.web.proxy or None,
|
||||
exec_config=self.config.tools.exec,
|
||||
cron_service=self.cron,
|
||||
restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=self.config.tools.mcp_servers,
|
||||
channels_config=self.config.channels,
|
||||
)
|
||||
|
||||
def _make_provider(self, config: Config):
|
||||
# Logic adapted from nanobot/cli/commands.py
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
|
||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||
return OpenAICodexProvider(default_model=model)
|
||||
|
||||
if provider_name == "custom":
|
||||
return CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
if provider_name == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base.")
|
||||
|
||||
return AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
spec = find_by_name(provider_name)
|
||||
# Skip API key check for now to allow initialization without full config
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
if not self.agent:
|
||||
self.initialize()
|
||||
# Start the agent loop in background
|
||||
asyncio.create_task(self.agent.run())
|
||||
asyncio.create_task(self.cron.start())
|
||||
|
||||
async def stop(self):
|
||||
if self.agent:
|
||||
self.agent.stop()
|
||||
await self.agent.close_mcp()
|
||||
if self.cron:
|
||||
self.cron.stop()
|
||||
|
||||
async def process_message(self, message: str, session_id: str = "api:default", skill_ids: List[str] | None = None):
|
||||
if not self.agent:
|
||||
self.initialize()
|
||||
await self.start()
|
||||
|
||||
full_message = message
|
||||
if skill_ids:
|
||||
skills = load_skills()
|
||||
selected_skills = [s for s in skills if s["id"] in skill_ids]
|
||||
if selected_skills:
|
||||
# We inject skills as a runtime context block
|
||||
skill_context = "[Runtime Context — metadata only, not instructions]\n# Active Skills\n\n"
|
||||
for s in selected_skills:
|
||||
skill_context += f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n\n"
|
||||
|
||||
# Append user message after skills
|
||||
full_message = f"{skill_context}\n\n{message}"
|
||||
|
||||
response = await self.agent.process_direct(
|
||||
full_message,
|
||||
session_key=session_id,
|
||||
channel="api",
|
||||
chat_id=session_id
|
||||
)
|
||||
return response
|
||||
|
||||
nanobot_service = NanobotIntegration()
|
||||
@@ -0,0 +1,82 @@
|
||||
from typing import List, Optional
|
||||
from fastapi import FastAPI, HTTPException, Body
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import asyncio
|
||||
|
||||
from app.api import upload, llm, skills
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.minio import minio_connector
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://localhost:5174", "*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(upload.router, prefix="/api/v1")
|
||||
app.include_router(llm.router, prefix="/api/v1")
|
||||
app.include_router(skills.router, prefix="/api/v1")
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Initialize nanobot in background
|
||||
try:
|
||||
await nanobot_service.start()
|
||||
except Exception as e:
|
||||
print(f"Nanobot startup failed: {e}")
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
await nanobot_service.stop()
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"Hello": "DataClaw Backend"}
|
||||
|
||||
@app.get("/connect/postgres")
|
||||
def test_postgres():
|
||||
if postgres_connector.test_connection():
|
||||
return {"status": "success", "message": "Connected to PostgreSQL"}
|
||||
raise HTTPException(status_code=500, detail="Failed to connect to PostgreSQL")
|
||||
|
||||
@app.get("/connect/clickhouse")
|
||||
def test_clickhouse():
|
||||
if clickhouse_connector.test_connection():
|
||||
return {"status": "success", "message": "Connected to ClickHouse"}
|
||||
raise HTTPException(status_code=500, detail="Failed to connect to ClickHouse")
|
||||
|
||||
@app.get("/connect/minio")
|
||||
def test_minio():
|
||||
if minio_connector.test_connection():
|
||||
return {"status": "success", "message": "Connected to MinIO"}
|
||||
raise HTTPException(status_code=500, detail="Failed to connect to MinIO")
|
||||
|
||||
@app.get("/nanobot/status")
|
||||
def nanobot_status():
|
||||
if nanobot_service.agent:
|
||||
return {"status": "running", "model": nanobot_service.agent.model}
|
||||
return {"status": "stopped"}
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
skill_ids: Optional[List[str]] = None
|
||||
|
||||
@app.post("/nanobot/chat")
|
||||
async def nanobot_chat(request: ChatRequest):
|
||||
try:
|
||||
response = await nanobot_service.process_message(request.message, skill_ids=request.skill_ids)
|
||||
return {"response": response}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse)
|
||||
async def run_nl2sql(request: NL2SQLRequest):
|
||||
return await process_nl2sql(request)
|
||||
@@ -0,0 +1,48 @@
|
||||
[project]
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
dependencies = [
|
||||
"chardet>=3.0.2,<6.0.0",
|
||||
"clickhouse-driver>=0.2.10",
|
||||
"croniter>=6.0.0,<7.0.0",
|
||||
"dingtalk-stream>=0.24.0,<1.0.0",
|
||||
"duckdb>=1.5.0",
|
||||
"fastapi>=0.135.1",
|
||||
"httpx>=0.28.0,<1.0.0",
|
||||
"json-repair>=0.57.0,<1.0.0",
|
||||
"lark-oapi>=1.5.0,<2.0.0",
|
||||
"litellm>=1.81.5,<2.0.0",
|
||||
"loguru>=0.7.3,<1.0.0",
|
||||
"mcp>=1.26.0,<2.0.0",
|
||||
"minio>=7.2.20",
|
||||
"msgpack>=1.1.0,<2.0.0",
|
||||
"nanobot-ai",
|
||||
"oauth-cli-kit>=0.1.3,<1.0.0",
|
||||
"openai>=2.8.0",
|
||||
"pandas>=3.0.1",
|
||||
"prompt-toolkit>=3.0.50,<4.0.0",
|
||||
"psycopg2-binary>=2.9.11",
|
||||
"pydantic>=2.12.0,<3.0.0",
|
||||
"pydantic-settings>=2.12.0,<3.0.0",
|
||||
"python-multipart>=0.0.22",
|
||||
"python-socketio>=5.16.0,<6.0.0",
|
||||
"python-socks[asyncio]>=2.8.0,<3.0.0",
|
||||
"python-telegram-bot[socks]>=22.6,<23.0",
|
||||
"qq-botpy>=1.2.0,<2.0.0",
|
||||
"readability-lxml>=0.8.4,<1.0.0",
|
||||
"rich>=14.0.0,<15.0.0",
|
||||
"slack-sdk>=3.39.0,<4.0.0",
|
||||
"slackify-markdown>=0.2.0,<1.0.0",
|
||||
"socksio>=1.0.0,<2.0.0",
|
||||
"sqlalchemy>=2.0.48",
|
||||
"typer>=0.20.0,<1.0.0",
|
||||
"uvicorn>=0.41.0",
|
||||
"websocket-client>=1.9.0,<2.0.0",
|
||||
"websockets>=16.0,<17.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
nanobot-ai = { path = "../nanobot" }
|
||||
Generated
+3307
File diff suppressed because it is too large
Load Diff
Submodule
+1
Submodule frontend added at f4a923cdbb
@@ -0,0 +1,13 @@
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.egg-info
|
||||
dist/
|
||||
build/
|
||||
.git
|
||||
.env
|
||||
.assets
|
||||
node_modules/
|
||||
bridge/dist/
|
||||
workspace/
|
||||
@@ -0,0 +1,23 @@
|
||||
.worktrees/
|
||||
.assets
|
||||
.env
|
||||
*.pyc
|
||||
dist/
|
||||
build/
|
||||
docs/
|
||||
*.egg-info/
|
||||
*.egg
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.pyw
|
||||
*.pyz
|
||||
*.pywz
|
||||
*.pyzz
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
poetry.lock
|
||||
.pytest_cache/
|
||||
botpy.log
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
We provide QR codes for joining the HKUDS discussion groups on **WeChat** and **Feishu**.
|
||||
|
||||
You can join by scanning the QR codes below:
|
||||
|
||||
<img src="https://github.com/HKUDS/.github/blob/main/profile/QR.png" alt="WeChat QR Code" width="400"/>
|
||||
@@ -0,0 +1,40 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
||||
|
||||
# Install Node.js 20 for the WhatsApp bridge
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends curl ca-certificates gnupg git && \
|
||||
mkdir -p /etc/apt/keyrings && \
|
||||
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends nodejs && \
|
||||
apt-get purge -y gnupg && \
|
||||
apt-get autoremove -y && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies first (cached layer)
|
||||
COPY pyproject.toml README.md LICENSE ./
|
||||
RUN mkdir -p nanobot bridge && touch nanobot/__init__.py && \
|
||||
uv pip install --system --no-cache . && \
|
||||
rm -rf nanobot bridge
|
||||
|
||||
# Copy the full source and install
|
||||
COPY nanobot/ nanobot/
|
||||
COPY bridge/ bridge/
|
||||
RUN uv pip install --system --no-cache .
|
||||
|
||||
# Build the WhatsApp bridge
|
||||
WORKDIR /app/bridge
|
||||
RUN npm install && npm run build
|
||||
WORKDIR /app
|
||||
|
||||
# Create config directory
|
||||
RUN mkdir -p /root/.nanobot
|
||||
|
||||
# Gateway default port
|
||||
EXPOSE 18790
|
||||
|
||||
ENTRYPOINT ["nanobot"]
|
||||
CMD ["status"]
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
+1205
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,263 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in nanobot, please report it by:
|
||||
|
||||
1. **DO NOT** open a public GitHub issue
|
||||
2. Create a private security advisory on GitHub or contact the repository maintainers (xubinrencs@gmail.com)
|
||||
3. Include:
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce
|
||||
- Potential impact
|
||||
- Suggested fix (if any)
|
||||
|
||||
We aim to respond to security reports within 48 hours.
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
### 1. API Key Management
|
||||
|
||||
**CRITICAL**: Never commit API keys to version control.
|
||||
|
||||
```bash
|
||||
# ✅ Good: Store in config file with restricted permissions
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
|
||||
# ❌ Bad: Hardcoding keys in code or committing them
|
||||
```
|
||||
|
||||
**Recommendations:**
|
||||
- Store API keys in `~/.nanobot/config.json` with file permissions set to `0600`
|
||||
- Consider using environment variables for sensitive keys
|
||||
- Use OS keyring/credential manager for production deployments
|
||||
- Rotate API keys regularly
|
||||
- Use separate API keys for development and production
|
||||
|
||||
### 2. Channel Access Control
|
||||
|
||||
**IMPORTANT**: Always configure `allowFrom` lists for production use.
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["123456789", "987654321"]
|
||||
},
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"allowFrom": ["+1234567890"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Security Notes:**
|
||||
- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone.
|
||||
- Get your Telegram user ID from `@userinfobot`
|
||||
- Use full phone numbers with country code for WhatsApp
|
||||
- Review access logs regularly for unauthorized access attempts
|
||||
|
||||
### 3. Shell Command Execution
|
||||
|
||||
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
|
||||
|
||||
- ✅ Review all tool usage in agent logs
|
||||
- ✅ Understand what commands the agent is running
|
||||
- ✅ Use a dedicated user account with limited privileges
|
||||
- ✅ Never run nanobot as root
|
||||
- ❌ Don't disable security checks
|
||||
- ❌ Don't run on systems with sensitive data without careful review
|
||||
|
||||
**Blocked patterns:**
|
||||
- `rm -rf /` - Root filesystem deletion
|
||||
- Fork bombs
|
||||
- Filesystem formatting (`mkfs.*`)
|
||||
- Raw disk writes
|
||||
- Other destructive operations
|
||||
|
||||
### 4. File System Access
|
||||
|
||||
File operations have path traversal protection, but:
|
||||
|
||||
- ✅ Run nanobot with a dedicated user account
|
||||
- ✅ Use filesystem permissions to protect sensitive directories
|
||||
- ✅ Regularly audit file operations in logs
|
||||
- ❌ Don't give unrestricted access to sensitive files
|
||||
|
||||
### 5. Network Security
|
||||
|
||||
**API Calls:**
|
||||
- All external API calls use HTTPS by default
|
||||
- Timeouts are configured to prevent hanging requests
|
||||
- Consider using a firewall to restrict outbound connections if needed
|
||||
|
||||
**WhatsApp Bridge:**
|
||||
- The bridge binds to `127.0.0.1:3001` (localhost only, not accessible from external network)
|
||||
- Set `bridgeToken` in config to enable shared-secret authentication between Python and Node.js
|
||||
- Keep authentication data in `~/.nanobot/whatsapp-auth` secure (mode 0700)
|
||||
|
||||
### 6. Dependency Security
|
||||
|
||||
**Critical**: Keep dependencies updated!
|
||||
|
||||
```bash
|
||||
# Check for vulnerable dependencies
|
||||
pip install pip-audit
|
||||
pip-audit
|
||||
|
||||
# Update to latest secure versions
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
For Node.js dependencies (WhatsApp bridge):
|
||||
```bash
|
||||
cd bridge
|
||||
npm audit
|
||||
npm audit fix
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- Keep `litellm` updated to the latest version for security fixes
|
||||
- We've updated `ws` to `>=8.17.1` to fix DoS vulnerability
|
||||
- Run `pip-audit` or `npm audit` regularly
|
||||
- Subscribe to security advisories for nanobot and its dependencies
|
||||
|
||||
### 7. Production Deployment
|
||||
|
||||
For production use:
|
||||
|
||||
1. **Isolate the Environment**
|
||||
```bash
|
||||
# Run in a container or VM
|
||||
docker run --rm -it python:3.11
|
||||
pip install nanobot-ai
|
||||
```
|
||||
|
||||
2. **Use a Dedicated User**
|
||||
```bash
|
||||
sudo useradd -m -s /bin/bash nanobot
|
||||
sudo -u nanobot nanobot gateway
|
||||
```
|
||||
|
||||
3. **Set Proper Permissions**
|
||||
```bash
|
||||
chmod 700 ~/.nanobot
|
||||
chmod 600 ~/.nanobot/config.json
|
||||
chmod 700 ~/.nanobot/whatsapp-auth
|
||||
```
|
||||
|
||||
4. **Enable Logging**
|
||||
```bash
|
||||
# Configure log monitoring
|
||||
tail -f ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
|
||||
5. **Use Rate Limiting**
|
||||
- Configure rate limits on your API providers
|
||||
- Monitor usage for anomalies
|
||||
- Set spending limits on LLM APIs
|
||||
|
||||
6. **Regular Updates**
|
||||
```bash
|
||||
# Check for updates weekly
|
||||
pip install --upgrade nanobot-ai
|
||||
```
|
||||
|
||||
### 8. Development vs Production
|
||||
|
||||
**Development:**
|
||||
- Use separate API keys
|
||||
- Test with non-sensitive data
|
||||
- Enable verbose logging
|
||||
- Use a test Telegram bot
|
||||
|
||||
**Production:**
|
||||
- Use dedicated API keys with spending limits
|
||||
- Restrict file system access
|
||||
- Enable audit logging
|
||||
- Regular security reviews
|
||||
- Monitor for unusual activity
|
||||
|
||||
### 9. Data Privacy
|
||||
|
||||
- **Logs may contain sensitive information** - secure log files appropriately
|
||||
- **LLM providers see your prompts** - review their privacy policies
|
||||
- **Chat history is stored locally** - protect the `~/.nanobot` directory
|
||||
- **API keys are in plain text** - use OS keyring for production
|
||||
|
||||
### 10. Incident Response
|
||||
|
||||
If you suspect a security breach:
|
||||
|
||||
1. **Immediately revoke compromised API keys**
|
||||
2. **Review logs for unauthorized access**
|
||||
```bash
|
||||
grep "Access denied" ~/.nanobot/logs/nanobot.log
|
||||
```
|
||||
3. **Check for unexpected file modifications**
|
||||
4. **Rotate all credentials**
|
||||
5. **Update to latest version**
|
||||
6. **Report the incident** to maintainers
|
||||
|
||||
## Security Features
|
||||
|
||||
### Built-in Security Controls
|
||||
|
||||
✅ **Input Validation**
|
||||
- Path traversal protection on file operations
|
||||
- Dangerous command pattern detection
|
||||
- Input length limits on HTTP requests
|
||||
|
||||
✅ **Authentication**
|
||||
- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all)
|
||||
- Failed authentication attempt logging
|
||||
|
||||
✅ **Resource Protection**
|
||||
- Command execution timeouts (60s default)
|
||||
- Output truncation (10KB limit)
|
||||
- HTTP request timeouts (10-30s)
|
||||
|
||||
✅ **Secure Communication**
|
||||
- HTTPS for all external API calls
|
||||
- TLS for Telegram API
|
||||
- WhatsApp bridge: localhost-only binding + optional token auth
|
||||
|
||||
## Known Limitations
|
||||
|
||||
⚠️ **Current Security Limitations:**
|
||||
|
||||
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
|
||||
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
|
||||
3. **No Session Management** - No automatic session expiry
|
||||
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
|
||||
5. **No Audit Trail** - Limited security event logging (enhance as needed)
|
||||
|
||||
## Security Checklist
|
||||
|
||||
Before deploying nanobot:
|
||||
|
||||
- [ ] API keys stored securely (not in code)
|
||||
- [ ] Config file permissions set to 0600
|
||||
- [ ] `allowFrom` lists configured for all channels
|
||||
- [ ] Running as non-root user
|
||||
- [ ] File system permissions properly restricted
|
||||
- [ ] Dependencies updated to latest secure versions
|
||||
- [ ] Logs monitored for security events
|
||||
- [ ] Rate limits configured on API providers
|
||||
- [ ] Backup and disaster recovery plan in place
|
||||
- [ ] Security review of custom skills/tools
|
||||
|
||||
## Updates
|
||||
|
||||
**Last Updated**: 2026-02-03
|
||||
|
||||
For the latest security updates and announcements, check:
|
||||
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
|
||||
- Release Notes: https://github.com/HKUDS/nanobot/releases
|
||||
|
||||
## License
|
||||
|
||||
See LICENSE file for details.
|
||||
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"name": "nanobot-whatsapp-bridge",
|
||||
"version": "0.1.0",
|
||||
"description": "WhatsApp bridge for nanobot using Baileys",
|
||||
"type": "module",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsc",
|
||||
"start": "node dist/index.js",
|
||||
"dev": "tsc && node dist/index.js"
|
||||
},
|
||||
"dependencies": {
|
||||
"@whiskeysockets/baileys": "7.0.0-rc.9",
|
||||
"ws": "^8.17.1",
|
||||
"qrcode-terminal": "^0.12.0",
|
||||
"pino": "^9.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20.14.0",
|
||||
"@types/ws": "^8.5.10",
|
||||
"typescript": "^5.4.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env node
|
||||
/**
|
||||
* nanobot WhatsApp Bridge
|
||||
*
|
||||
* This bridge connects WhatsApp Web to nanobot's Python backend
|
||||
* via WebSocket. It handles authentication, message forwarding,
|
||||
* and reconnection logic.
|
||||
*
|
||||
* Usage:
|
||||
* npm run build && npm start
|
||||
*
|
||||
* Or with custom settings:
|
||||
* BRIDGE_PORT=3001 AUTH_DIR=~/.nanobot/whatsapp npm start
|
||||
*/
|
||||
|
||||
// Polyfill crypto for Baileys in ESM
|
||||
import { webcrypto } from 'crypto';
|
||||
if (!globalThis.crypto) {
|
||||
(globalThis as any).crypto = webcrypto;
|
||||
}
|
||||
|
||||
import { BridgeServer } from './server.js';
|
||||
import { homedir } from 'os';
|
||||
import { join } from 'path';
|
||||
|
||||
const PORT = parseInt(process.env.BRIDGE_PORT || '3001', 10);
|
||||
const AUTH_DIR = process.env.AUTH_DIR || join(homedir(), '.nanobot', 'whatsapp-auth');
|
||||
const TOKEN = process.env.BRIDGE_TOKEN || undefined;
|
||||
|
||||
console.log('🐈 nanobot WhatsApp Bridge');
|
||||
console.log('========================\n');
|
||||
|
||||
const server = new BridgeServer(PORT, AUTH_DIR, TOKEN);
|
||||
|
||||
// Handle graceful shutdown
|
||||
process.on('SIGINT', async () => {
|
||||
console.log('\n\nShutting down...');
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
process.on('SIGTERM', async () => {
|
||||
await server.stop();
|
||||
process.exit(0);
|
||||
});
|
||||
|
||||
// Start the server
|
||||
server.start().catch((error) => {
|
||||
console.error('Failed to start bridge:', error);
|
||||
process.exit(1);
|
||||
});
|
||||
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* WebSocket server for Python-Node.js bridge communication.
|
||||
* Security: binds to 127.0.0.1 only; optional BRIDGE_TOKEN auth.
|
||||
*/
|
||||
|
||||
import { WebSocketServer, WebSocket } from 'ws';
|
||||
import { WhatsAppClient, InboundMessage } from './whatsapp.js';
|
||||
|
||||
interface SendCommand {
|
||||
type: 'send';
|
||||
to: string;
|
||||
text: string;
|
||||
}
|
||||
|
||||
interface BridgeMessage {
|
||||
type: 'message' | 'status' | 'qr' | 'error';
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export class BridgeServer {
|
||||
private wss: WebSocketServer | null = null;
|
||||
private wa: WhatsAppClient | null = null;
|
||||
private clients: Set<WebSocket> = new Set();
|
||||
|
||||
constructor(private port: number, private authDir: string, private token?: string) {}
|
||||
|
||||
async start(): Promise<void> {
|
||||
// Bind to localhost only — never expose to external network
|
||||
this.wss = new WebSocketServer({ host: '127.0.0.1', port: this.port });
|
||||
console.log(`🌉 Bridge server listening on ws://127.0.0.1:${this.port}`);
|
||||
if (this.token) console.log('🔒 Token authentication enabled');
|
||||
|
||||
// Initialize WhatsApp client
|
||||
this.wa = new WhatsAppClient({
|
||||
authDir: this.authDir,
|
||||
onMessage: (msg) => this.broadcast({ type: 'message', ...msg }),
|
||||
onQR: (qr) => this.broadcast({ type: 'qr', qr }),
|
||||
onStatus: (status) => this.broadcast({ type: 'status', status }),
|
||||
});
|
||||
|
||||
// Handle WebSocket connections
|
||||
this.wss.on('connection', (ws) => {
|
||||
if (this.token) {
|
||||
// Require auth handshake as first message
|
||||
const timeout = setTimeout(() => ws.close(4001, 'Auth timeout'), 5000);
|
||||
ws.once('message', (data) => {
|
||||
clearTimeout(timeout);
|
||||
try {
|
||||
const msg = JSON.parse(data.toString());
|
||||
if (msg.type === 'auth' && msg.token === this.token) {
|
||||
console.log('🔗 Python client authenticated');
|
||||
this.setupClient(ws);
|
||||
} else {
|
||||
ws.close(4003, 'Invalid token');
|
||||
}
|
||||
} catch {
|
||||
ws.close(4003, 'Invalid auth message');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.log('🔗 Python client connected');
|
||||
this.setupClient(ws);
|
||||
}
|
||||
});
|
||||
|
||||
// Connect to WhatsApp
|
||||
await this.wa.connect();
|
||||
}
|
||||
|
||||
private setupClient(ws: WebSocket): void {
|
||||
this.clients.add(ws);
|
||||
|
||||
ws.on('message', async (data) => {
|
||||
try {
|
||||
const cmd = JSON.parse(data.toString()) as SendCommand;
|
||||
await this.handleCommand(cmd);
|
||||
ws.send(JSON.stringify({ type: 'sent', to: cmd.to }));
|
||||
} catch (error) {
|
||||
console.error('Error handling command:', error);
|
||||
ws.send(JSON.stringify({ type: 'error', error: String(error) }));
|
||||
}
|
||||
});
|
||||
|
||||
ws.on('close', () => {
|
||||
console.log('🔌 Python client disconnected');
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
|
||||
ws.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
this.clients.delete(ws);
|
||||
});
|
||||
}
|
||||
|
||||
private async handleCommand(cmd: SendCommand): Promise<void> {
|
||||
if (cmd.type === 'send' && this.wa) {
|
||||
await this.wa.sendMessage(cmd.to, cmd.text);
|
||||
}
|
||||
}
|
||||
|
||||
private broadcast(msg: BridgeMessage): void {
|
||||
const data = JSON.stringify(msg);
|
||||
for (const client of this.clients) {
|
||||
if (client.readyState === WebSocket.OPEN) {
|
||||
client.send(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
// Close all client connections
|
||||
for (const client of this.clients) {
|
||||
client.close();
|
||||
}
|
||||
this.clients.clear();
|
||||
|
||||
// Close WebSocket server
|
||||
if (this.wss) {
|
||||
this.wss.close();
|
||||
this.wss = null;
|
||||
}
|
||||
|
||||
// Disconnect WhatsApp
|
||||
if (this.wa) {
|
||||
await this.wa.disconnect();
|
||||
this.wa = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
Vendored
+3
@@ -0,0 +1,3 @@
|
||||
declare module 'qrcode-terminal' {
|
||||
export function generate(text: string, options?: { small?: boolean }): void;
|
||||
}
|
||||
@@ -0,0 +1,239 @@
|
||||
/**
|
||||
* WhatsApp client wrapper using Baileys.
|
||||
* Based on OpenClaw's working implementation.
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import makeWASocket, {
|
||||
DisconnectReason,
|
||||
useMultiFileAuthState,
|
||||
fetchLatestBaileysVersion,
|
||||
makeCacheableSignalKeyStore,
|
||||
downloadMediaMessage,
|
||||
extractMessageContent as baileysExtractMessageContent,
|
||||
} from '@whiskeysockets/baileys';
|
||||
|
||||
import { Boom } from '@hapi/boom';
|
||||
import qrcode from 'qrcode-terminal';
|
||||
import pino from 'pino';
|
||||
import { writeFile, mkdir } from 'fs/promises';
|
||||
import { join } from 'path';
|
||||
import { randomBytes } from 'crypto';
|
||||
|
||||
const VERSION = '0.1.0';
|
||||
|
||||
export interface InboundMessage {
|
||||
id: string;
|
||||
sender: string;
|
||||
pn: string;
|
||||
content: string;
|
||||
timestamp: number;
|
||||
isGroup: boolean;
|
||||
media?: string[];
|
||||
}
|
||||
|
||||
export interface WhatsAppClientOptions {
|
||||
authDir: string;
|
||||
onMessage: (msg: InboundMessage) => void;
|
||||
onQR: (qr: string) => void;
|
||||
onStatus: (status: string) => void;
|
||||
}
|
||||
|
||||
export class WhatsAppClient {
|
||||
private sock: any = null;
|
||||
private options: WhatsAppClientOptions;
|
||||
private reconnecting = false;
|
||||
|
||||
constructor(options: WhatsAppClientOptions) {
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
async connect(): Promise<void> {
|
||||
const logger = pino({ level: 'silent' });
|
||||
const { state, saveCreds } = await useMultiFileAuthState(this.options.authDir);
|
||||
const { version } = await fetchLatestBaileysVersion();
|
||||
|
||||
console.log(`Using Baileys version: ${version.join('.')}`);
|
||||
|
||||
// Create socket following OpenClaw's pattern
|
||||
this.sock = makeWASocket({
|
||||
auth: {
|
||||
creds: state.creds,
|
||||
keys: makeCacheableSignalKeyStore(state.keys, logger),
|
||||
},
|
||||
version,
|
||||
logger,
|
||||
printQRInTerminal: false,
|
||||
browser: ['nanobot', 'cli', VERSION],
|
||||
syncFullHistory: false,
|
||||
markOnlineOnConnect: false,
|
||||
});
|
||||
|
||||
// Handle WebSocket errors
|
||||
if (this.sock.ws && typeof this.sock.ws.on === 'function') {
|
||||
this.sock.ws.on('error', (err: Error) => {
|
||||
console.error('WebSocket error:', err.message);
|
||||
});
|
||||
}
|
||||
|
||||
// Handle connection updates
|
||||
this.sock.ev.on('connection.update', async (update: any) => {
|
||||
const { connection, lastDisconnect, qr } = update;
|
||||
|
||||
if (qr) {
|
||||
// Display QR code in terminal
|
||||
console.log('\n📱 Scan this QR code with WhatsApp (Linked Devices):\n');
|
||||
qrcode.generate(qr, { small: true });
|
||||
this.options.onQR(qr);
|
||||
}
|
||||
|
||||
if (connection === 'close') {
|
||||
const statusCode = (lastDisconnect?.error as Boom)?.output?.statusCode;
|
||||
const shouldReconnect = statusCode !== DisconnectReason.loggedOut;
|
||||
|
||||
console.log(`Connection closed. Status: ${statusCode}, Will reconnect: ${shouldReconnect}`);
|
||||
this.options.onStatus('disconnected');
|
||||
|
||||
if (shouldReconnect && !this.reconnecting) {
|
||||
this.reconnecting = true;
|
||||
console.log('Reconnecting in 5 seconds...');
|
||||
setTimeout(() => {
|
||||
this.reconnecting = false;
|
||||
this.connect();
|
||||
}, 5000);
|
||||
}
|
||||
} else if (connection === 'open') {
|
||||
console.log('✅ Connected to WhatsApp');
|
||||
this.options.onStatus('connected');
|
||||
}
|
||||
});
|
||||
|
||||
// Save credentials on update
|
||||
this.sock.ev.on('creds.update', saveCreds);
|
||||
|
||||
// Handle incoming messages
|
||||
this.sock.ev.on('messages.upsert', async ({ messages, type }: { messages: any[]; type: string }) => {
|
||||
if (type !== 'notify') return;
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.key.fromMe) continue;
|
||||
if (msg.key.remoteJid === 'status@broadcast') continue;
|
||||
|
||||
const unwrapped = baileysExtractMessageContent(msg.message);
|
||||
if (!unwrapped) continue;
|
||||
|
||||
const content = this.getTextContent(unwrapped);
|
||||
let fallbackContent: string | null = null;
|
||||
const mediaPaths: string[] = [];
|
||||
|
||||
if (unwrapped.imageMessage) {
|
||||
fallbackContent = '[Image]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.imageMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.documentMessage) {
|
||||
fallbackContent = '[Document]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.documentMessage.mimetype ?? undefined,
|
||||
unwrapped.documentMessage.fileName ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
} else if (unwrapped.videoMessage) {
|
||||
fallbackContent = '[Video]';
|
||||
const path = await this.downloadMedia(msg, unwrapped.videoMessage.mimetype ?? undefined);
|
||||
if (path) mediaPaths.push(path);
|
||||
}
|
||||
|
||||
const finalContent = content || (mediaPaths.length === 0 ? fallbackContent : '') || '';
|
||||
if (!finalContent && mediaPaths.length === 0) continue;
|
||||
|
||||
const isGroup = msg.key.remoteJid?.endsWith('@g.us') || false;
|
||||
|
||||
this.options.onMessage({
|
||||
id: msg.key.id || '',
|
||||
sender: msg.key.remoteJid || '',
|
||||
pn: msg.key.remoteJidAlt || '',
|
||||
content: finalContent,
|
||||
timestamp: msg.messageTimestamp as number,
|
||||
isGroup,
|
||||
...(mediaPaths.length > 0 ? { media: mediaPaths } : {}),
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private async downloadMedia(msg: any, mimetype?: string, fileName?: string): Promise<string | null> {
|
||||
try {
|
||||
const mediaDir = join(this.options.authDir, '..', 'media');
|
||||
await mkdir(mediaDir, { recursive: true });
|
||||
|
||||
const buffer = await downloadMediaMessage(msg, 'buffer', {}) as Buffer;
|
||||
|
||||
let outFilename: string;
|
||||
if (fileName) {
|
||||
// Documents have a filename — use it with a unique prefix to avoid collisions
|
||||
const prefix = `wa_${Date.now()}_${randomBytes(4).toString('hex')}_`;
|
||||
outFilename = prefix + fileName;
|
||||
} else {
|
||||
const mime = mimetype || 'application/octet-stream';
|
||||
// Derive extension from mimetype subtype (e.g. "image/png" → ".png", "application/pdf" → ".pdf")
|
||||
const ext = '.' + (mime.split('/').pop()?.split(';')[0] || 'bin');
|
||||
outFilename = `wa_${Date.now()}_${randomBytes(4).toString('hex')}${ext}`;
|
||||
}
|
||||
|
||||
const filepath = join(mediaDir, outFilename);
|
||||
await writeFile(filepath, buffer);
|
||||
|
||||
return filepath;
|
||||
} catch (err) {
|
||||
console.error('Failed to download media:', err);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private getTextContent(message: any): string | null {
|
||||
// Text message
|
||||
if (message.conversation) {
|
||||
return message.conversation;
|
||||
}
|
||||
|
||||
// Extended text (reply, link preview)
|
||||
if (message.extendedTextMessage?.text) {
|
||||
return message.extendedTextMessage.text;
|
||||
}
|
||||
|
||||
// Image with optional caption
|
||||
if (message.imageMessage) {
|
||||
return message.imageMessage.caption || '';
|
||||
}
|
||||
|
||||
// Video with optional caption
|
||||
if (message.videoMessage) {
|
||||
return message.videoMessage.caption || '';
|
||||
}
|
||||
|
||||
// Document with optional caption
|
||||
if (message.documentMessage) {
|
||||
return message.documentMessage.caption || '';
|
||||
}
|
||||
|
||||
// Voice/Audio message
|
||||
if (message.audioMessage) {
|
||||
return `[Voice Message]`;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async sendMessage(to: string, text: string): Promise<void> {
|
||||
if (!this.sock) {
|
||||
throw new Error('Not connected');
|
||||
}
|
||||
|
||||
await this.sock.sendMessage(to, { text });
|
||||
}
|
||||
|
||||
async disconnect(): Promise<void> {
|
||||
if (this.sock) {
|
||||
this.sock.end(undefined);
|
||||
this.sock = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2022",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "node",
|
||||
"esModuleInterop": true,
|
||||
"strict": true,
|
||||
"skipLibCheck": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"declaration": true,
|
||||
"resolveJsonModule": true
|
||||
},
|
||||
"include": ["src/**/*"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 12 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 5.6 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 6.8 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 6.0 MiB |
Executable
+21
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
echo "================================"
|
||||
echo ""
|
||||
|
||||
for dir in agent agent/tools bus config cron heartbeat session utils; do
|
||||
count=$(find "nanobot/$dir" -maxdepth 1 -name "*.py" -exec cat {} + | wc -l)
|
||||
printf " %-16s %5s lines\n" "$dir/" "$count"
|
||||
done
|
||||
|
||||
root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/providers/*" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, providers/)"
|
||||
@@ -0,0 +1,31 @@
|
||||
x-common-config: &common-config
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
volumes:
|
||||
- ~/.nanobot:/root/.nanobot
|
||||
|
||||
services:
|
||||
nanobot-gateway:
|
||||
container_name: nanobot-gateway
|
||||
<<: *common-config
|
||||
command: ["gateway"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- 18790:18790
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
cpus: '1'
|
||||
memory: 1G
|
||||
reservations:
|
||||
cpus: '0.25'
|
||||
memory: 256M
|
||||
|
||||
nanobot-cli:
|
||||
<<: *common-config
|
||||
profiles:
|
||||
- cli
|
||||
command: ["status"]
|
||||
stdin_open: true
|
||||
tty: true
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
nanobot - A lightweight AI agent framework
|
||||
"""
|
||||
|
||||
__version__ = "0.1.4.post4"
|
||||
__logo__ = "🐈"
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Entry point for running nanobot as a module: python -m nanobot
|
||||
"""
|
||||
|
||||
from nanobot.cli.commands import app
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
@@ -0,0 +1,193 @@
|
||||
"""Context builder for assembling agent prompts."""
|
||||
|
||||
import base64
|
||||
import mimetypes
|
||||
import platform
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.utils.helpers import detect_image_mime
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.memory = MemoryStore(workspace)
|
||||
self.skills = SkillsLoader(workspace)
|
||||
|
||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||
parts = [self._get_identity()]
|
||||
|
||||
bootstrap = self._load_bootstrap_files()
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory:
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
|
||||
always_skills = self.skills.get_always_skills()
|
||||
if always_skills:
|
||||
always_content = self.skills.load_skills_for_context(always_skills)
|
||||
if always_content:
|
||||
parts.append(f"# Active Skills\n\n{always_content}")
|
||||
|
||||
skills_summary = self.skills.build_skills_summary()
|
||||
if skills_summary:
|
||||
parts.append(f"""# Skills
|
||||
|
||||
The following skills extend your capabilities. To use a skill, read its SKILL.md file using the read_file tool.
|
||||
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
||||
|
||||
{skills_summary}""")
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def _get_identity(self) -> str:
|
||||
"""Get the core identity section."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
|
||||
platform_policy = ""
|
||||
if system == "Windows":
|
||||
platform_policy = """## Platform Policy (Windows)
|
||||
- You are running on Windows. Do not assume GNU tools like `grep`, `sed`, or `awk` exist.
|
||||
- Prefer Windows-native commands or file tools when they are more reliable.
|
||||
- If terminal output is garbled, retry with UTF-8 output enabled.
|
||||
"""
|
||||
else:
|
||||
platform_policy = """## Platform Policy (POSIX)
|
||||
- You are running on a POSIX system. Prefer UTF-8 and standard shell tools.
|
||||
- Use file tools when they are simpler or more reliable than shell commands.
|
||||
"""
|
||||
|
||||
return f"""# nanobot 🐈
|
||||
|
||||
You are nanobot, a helpful AI assistant.
|
||||
|
||||
## Runtime
|
||||
{runtime}
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
{platform_policy}
|
||||
|
||||
## nanobot Guidelines
|
||||
- State intent before tool calls, but NEVER predict or claim results before receiving them.
|
||||
- Before modifying a file, read it first. Do not assume files or directories exist.
|
||||
- After writing or editing a file, re-read it if accuracy matters.
|
||||
- If a tool call fails, analyze the error before retrying with a different approach.
|
||||
- Ask for clarification when the request is ambiguous.
|
||||
|
||||
Reply directly with text for conversations. Only use the 'message' tool to send to a specific chat channel."""
|
||||
|
||||
@staticmethod
|
||||
def _build_runtime_context(channel: str | None, chat_id: str | None) -> str:
|
||||
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
|
||||
tz = time.strftime("%Z") or "UTC"
|
||||
lines = [f"Current Time: {now} ({tz})"]
|
||||
if channel and chat_id:
|
||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""Load all bootstrap files from workspace."""
|
||||
parts = []
|
||||
|
||||
for filename in self.BOOTSTRAP_FILES:
|
||||
file_path = self.workspace / filename
|
||||
if file_path.exists():
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
parts.append(f"## {filename}\n\n{content}")
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
current_message: str,
|
||||
skill_names: list[str] | None = None,
|
||||
media: list[str] | None = None,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call."""
|
||||
runtime_ctx = self._build_runtime_context(channel, chat_id)
|
||||
user_content = self._build_user_content(current_message, media)
|
||||
|
||||
# Merge runtime context and user content into a single user message
|
||||
# to avoid consecutive same-role messages that some providers reject.
|
||||
if isinstance(user_content, str):
|
||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
*history,
|
||||
{"role": "user", "content": merged},
|
||||
]
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional base64-encoded images."""
|
||||
if not media:
|
||||
return text
|
||||
|
||||
images = []
|
||||
for path in media:
|
||||
p = Path(path)
|
||||
if not p.is_file():
|
||||
continue
|
||||
raw = p.read_bytes()
|
||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||
if not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
|
||||
|
||||
if not images:
|
||||
return text
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
def add_tool_result(
|
||||
self, messages: list[dict[str, Any]],
|
||||
tool_call_id: str, tool_name: str, result: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add a tool result to the message list."""
|
||||
messages.append({"role": "tool", "tool_call_id": tool_call_id, "name": tool_name, "content": result})
|
||||
return messages
|
||||
|
||||
def add_assistant_message(
|
||||
self, messages: list[dict[str, Any]],
|
||||
content: str | None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
thinking_blocks: list[dict] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Add an assistant message to the message list."""
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
if thinking_blocks:
|
||||
msg["thinking_blocks"] = thinking_blocks
|
||||
messages.append(msg)
|
||||
return messages
|
||||
@@ -0,0 +1,509 @@
|
||||
"""Agent loop: the core processing engine."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import weakref
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
|
||||
It:
|
||||
1. Receives messages from the bus
|
||||
2. Builds context with history, memory, skills
|
||||
3. Calls the LLM
|
||||
4. Executes tool calls
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 500
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bus: MessageBus,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
memory_window: int = 100,
|
||||
reasoning_effort: str | None = None,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
session_manager: SessionManager | None = None,
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.memory_window = memory_window
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.cron_service = cron_service
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
|
||||
self.context = ContextBuilder(workspace)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=reasoning_effort,
|
||||
brave_api_key=brave_api_key,
|
||||
web_proxy=web_proxy,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._consolidating: set[str] = set() # Session keys with consolidation in progress
|
||||
self._consolidation_tasks: set[asyncio.Task] = set() # Strong refs to in-flight tasks
|
||||
self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._processing_lock = asyncio.Lock()
|
||||
self._register_default_tools()
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
for cls in (ReadFileTool, WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
|
||||
self.tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
if self.cron_service:
|
||||
self.tools.register(CronTool(self.cron_service))
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
"""Connect to configured MCP servers (one-time, lazy)."""
|
||||
if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_connected = True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
for name in ("message", "spawn", "cron"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
"""Remove <think>…</think> blocks that some models embed in content."""
|
||||
if not text:
|
||||
return None
|
||||
return re.sub(r"<think>[\s\S]*?</think>", "", text).strip() or None
|
||||
|
||||
@staticmethod
|
||||
def _tool_hint(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
|
||||
def _fmt(tc):
|
||||
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
return ", ".join(_fmt(tc) for tc in tool_calls)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_progress:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False)
|
||||
}
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
for tool_call in response.tool_calls:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
clean = self._strip_think(response.content)
|
||||
# Don't persist error responses to session history — they can
|
||||
# poison the context and cause permanent 400 loops (#1303).
|
||||
if response.finish_reason == "error":
|
||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
final_content = (
|
||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
self._running = True
|
||||
await self._connect_mcp()
|
||||
logger.info("Agent loop started")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if msg.content.strip().lower() == "/stop":
|
||||
await self._handle_stop(msg)
|
||||
else:
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _handle_stop(self, msg: InboundMessage) -> None:
|
||||
"""Cancel all active tasks and subagents for the session."""
|
||||
tasks = self._active_tasks.pop(msg.session_key, [])
|
||||
cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
|
||||
for t in tasks:
|
||||
try:
|
||||
await t
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
sub_cancelled = await self.subagents.cancel_by_session(msg.session_key)
|
||||
total = cancelled + sub_cancelled
|
||||
content = f"⏹ Stopped {total} task(s)." if total else "No active task to stop."
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content,
|
||||
))
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message under the global lock."""
|
||||
async with self._processing_lock:
|
||||
try:
|
||||
response = await self._process_message(msg)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Close MCP connections."""
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the agent loop."""
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
|
||||
# Slash commands
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
self._consolidating.add(session.key)
|
||||
try:
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if snapshot:
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("/new archival failed for {}", session.key)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
|
||||
session.clear()
|
||||
self.sessions.save(session)
|
||||
self.sessions.invalidate(session.key)
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="New session started.")
|
||||
if cmd == "/help":
|
||||
return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="🐈 nanobot commands:\n/new — Start a new conversation\n/stop — Stop the current task\n/help — Show available commands")
|
||||
|
||||
unconsolidated = len(session.messages) - session.last_consolidated
|
||||
if (unconsolidated >= self.memory_window and session.key not in self._consolidating):
|
||||
self._consolidating.add(session.key)
|
||||
lock = self._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
_task = asyncio.current_task()
|
||||
if _task is not None:
|
||||
self._consolidation_tasks.discard(_task)
|
||||
|
||||
_task = asyncio.create_task(_consolidate_and_unlock())
|
||||
self._consolidation_tasks.add(_task)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=self.memory_window)
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool" and isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
# Strip the runtime-context prefix, keep only the user text.
|
||||
parts = content.split("\n\n", 1)
|
||||
if len(parts) > 1 and parts[1].strip():
|
||||
entry["content"] = parts[1]
|
||||
else:
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = []
|
||||
for c in content:
|
||||
if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
continue # Strip runtime context from multimodal messages
|
||||
if (c.get("type") == "image_url"
|
||||
and c.get("image_url", {}).get("url", "").startswith("data:image/")):
|
||||
filtered.append({"type": "text", "text": "[image]"})
|
||||
else:
|
||||
filtered.append(c)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
||||
return await MemoryStore(self.workspace).consolidate(
|
||||
session, self.provider, self.model,
|
||||
archive_all=archive_all, memory_window=self.memory_window,
|
||||
)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> str:
|
||||
"""Process a message directly (for CLI or cron usage)."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
return response.content if response else ""
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Memory system for persistent agent memory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
_SAVE_MEMORY_TOOL = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "save_memory",
|
||||
"description": "Save the memory consolidation result to persistent storage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"history_entry": {
|
||||
"type": "string",
|
||||
"description": "A paragraph (2-5 sentences) summarizing key events/decisions/topics. "
|
||||
"Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.",
|
||||
},
|
||||
"memory_update": {
|
||||
"type": "string",
|
||||
"description": "Full updated long-term memory as markdown. Include all existing "
|
||||
"facts plus new ones. Return unchanged if nothing new.",
|
||||
},
|
||||
},
|
||||
"required": ["history_entry", "memory_update"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log)."""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.memory_dir = ensure_dir(workspace / "memory")
|
||||
self.memory_file = self.memory_dir / "MEMORY.md"
|
||||
self.history_file = self.memory_dir / "HISTORY.md"
|
||||
|
||||
def read_long_term(self) -> str:
|
||||
if self.memory_file.exists():
|
||||
return self.memory_file.read_text(encoding="utf-8")
|
||||
return ""
|
||||
|
||||
def write_long_term(self, content: str) -> None:
|
||||
self.memory_file.write_text(content, encoding="utf-8")
|
||||
|
||||
def append_history(self, entry: str) -> None:
|
||||
with open(self.history_file, "a", encoding="utf-8") as f:
|
||||
f.write(entry.rstrip() + "\n\n")
|
||||
|
||||
def get_memory_context(self) -> str:
|
||||
long_term = self.read_long_term()
|
||||
return f"## Long-term Memory\n{long_term}" if long_term else ""
|
||||
|
||||
async def consolidate(
|
||||
self,
|
||||
session: Session,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
*,
|
||||
archive_all: bool = False,
|
||||
memory_window: int = 50,
|
||||
) -> bool:
|
||||
"""Consolidate old messages into MEMORY.md + HISTORY.md via LLM tool call.
|
||||
|
||||
Returns True on success (including no-op), False on failure.
|
||||
"""
|
||||
if archive_all:
|
||||
old_messages = session.messages
|
||||
keep_count = 0
|
||||
logger.info("Memory consolidation (archive_all): {} messages", len(session.messages))
|
||||
else:
|
||||
keep_count = memory_window // 2
|
||||
if len(session.messages) <= keep_count:
|
||||
return True
|
||||
if len(session.messages) - session.last_consolidated <= 0:
|
||||
return True
|
||||
old_messages = session.messages[session.last_consolidated:-keep_count]
|
||||
if not old_messages:
|
||||
return True
|
||||
logger.info("Memory consolidation: {} to consolidate, {} keep", len(old_messages), keep_count)
|
||||
|
||||
lines = []
|
||||
for m in old_messages:
|
||||
if not m.get("content"):
|
||||
continue
|
||||
tools = f" [tools: {', '.join(m['tools_used'])}]" if m.get("tools_used") else ""
|
||||
lines.append(f"[{m.get('timestamp', '?')[:16]}] {m['role'].upper()}{tools}: {m['content']}")
|
||||
|
||||
current_memory = self.read_long_term()
|
||||
prompt = f"""Process this conversation and call the save_memory tool with your consolidation.
|
||||
|
||||
## Current Long-term Memory
|
||||
{current_memory or "(empty)"}
|
||||
|
||||
## Conversation to Process
|
||||
{chr(10).join(lines)}"""
|
||||
|
||||
try:
|
||||
response = await provider.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
tools=_SAVE_MEMORY_TOOL,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
logger.warning("Memory consolidation: LLM did not call save_memory, skipping")
|
||||
return False
|
||||
|
||||
args = response.tool_calls[0].arguments
|
||||
# Some providers return arguments as a JSON string instead of dict
|
||||
if isinstance(args, str):
|
||||
args = json.loads(args)
|
||||
# Some providers return arguments as a list (handle edge case)
|
||||
if isinstance(args, list):
|
||||
if args and isinstance(args[0], dict):
|
||||
args = args[0]
|
||||
else:
|
||||
logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
|
||||
return False
|
||||
if not isinstance(args, dict):
|
||||
logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
|
||||
return False
|
||||
|
||||
if entry := args.get("history_entry"):
|
||||
if not isinstance(entry, str):
|
||||
entry = json.dumps(entry, ensure_ascii=False)
|
||||
self.append_history(entry)
|
||||
if update := args.get("memory_update"):
|
||||
if not isinstance(update, str):
|
||||
update = json.dumps(update, ensure_ascii=False)
|
||||
if update != current_memory:
|
||||
self.write_long_term(update)
|
||||
|
||||
session.last_consolidated = 0 if archive_all else len(session.messages) - keep_count
|
||||
logger.info("Memory consolidation done: {} messages, last_consolidated={}", len(session.messages), session.last_consolidated)
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Memory consolidation failed")
|
||||
return False
|
||||
@@ -0,0 +1,228 @@
|
||||
"""Skills loader for agent capabilities."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
# Default builtin skills directory (relative to this file)
|
||||
BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""
|
||||
Loader for agent skills.
|
||||
|
||||
Skills are markdown files (SKILL.md) that teach the agent how to use
|
||||
specific tools or perform certain tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
||||
self.workspace = workspace
|
||||
self.workspace_skills = workspace / "skills"
|
||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||
|
||||
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
||||
"""
|
||||
List all available skills.
|
||||
|
||||
Args:
|
||||
filter_unavailable: If True, filter out skills with unmet requirements.
|
||||
|
||||
Returns:
|
||||
List of skill info dicts with 'name', 'path', 'source'.
|
||||
"""
|
||||
skills = []
|
||||
|
||||
# Workspace skills (highest priority)
|
||||
if self.workspace_skills.exists():
|
||||
for skill_dir in self.workspace_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists():
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
||||
|
||||
# Built-in skills
|
||||
if self.builtin_skills and self.builtin_skills.exists():
|
||||
for skill_dir in self.builtin_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
||||
|
||||
# Filter by requirements
|
||||
if filter_unavailable:
|
||||
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
||||
return skills
|
||||
|
||||
def load_skill(self, name: str) -> str | None:
|
||||
"""
|
||||
Load a skill by name.
|
||||
|
||||
Args:
|
||||
name: Skill name (directory name).
|
||||
|
||||
Returns:
|
||||
Skill content or None if not found.
|
||||
"""
|
||||
# Check workspace first
|
||||
workspace_skill = self.workspace_skills / name / "SKILL.md"
|
||||
if workspace_skill.exists():
|
||||
return workspace_skill.read_text(encoding="utf-8")
|
||||
|
||||
# Check built-in
|
||||
if self.builtin_skills:
|
||||
builtin_skill = self.builtin_skills / name / "SKILL.md"
|
||||
if builtin_skill.exists():
|
||||
return builtin_skill.read_text(encoding="utf-8")
|
||||
|
||||
return None
|
||||
|
||||
def load_skills_for_context(self, skill_names: list[str]) -> str:
|
||||
"""
|
||||
Load specific skills for inclusion in agent context.
|
||||
|
||||
Args:
|
||||
skill_names: List of skill names to load.
|
||||
|
||||
Returns:
|
||||
Formatted skills content.
|
||||
"""
|
||||
parts = []
|
||||
for name in skill_names:
|
||||
content = self.load_skill(name)
|
||||
if content:
|
||||
content = self._strip_frontmatter(content)
|
||||
parts.append(f"### Skill: {name}\n\n{content}")
|
||||
|
||||
return "\n\n---\n\n".join(parts) if parts else ""
|
||||
|
||||
def build_skills_summary(self) -> str:
|
||||
"""
|
||||
Build a summary of all skills (name, description, path, availability).
|
||||
|
||||
This is used for progressive loading - the agent can read the full
|
||||
skill content using read_file when needed.
|
||||
|
||||
Returns:
|
||||
XML-formatted skills summary.
|
||||
"""
|
||||
all_skills = self.list_skills(filter_unavailable=False)
|
||||
if not all_skills:
|
||||
return ""
|
||||
|
||||
def escape_xml(s: str) -> str:
|
||||
return s.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
lines = ["<skills>"]
|
||||
for s in all_skills:
|
||||
name = escape_xml(s["name"])
|
||||
path = s["path"]
|
||||
desc = escape_xml(self._get_skill_description(s["name"]))
|
||||
skill_meta = self._get_skill_meta(s["name"])
|
||||
available = self._check_requirements(skill_meta)
|
||||
|
||||
lines.append(f" <skill available=\"{str(available).lower()}\">")
|
||||
lines.append(f" <name>{name}</name>")
|
||||
lines.append(f" <description>{desc}</description>")
|
||||
lines.append(f" <location>{path}</location>")
|
||||
|
||||
# Show missing requirements for unavailable skills
|
||||
if not available:
|
||||
missing = self._get_missing_requirements(skill_meta)
|
||||
if missing:
|
||||
lines.append(f" <requires>{escape_xml(missing)}</requires>")
|
||||
|
||||
lines.append(" </skill>")
|
||||
lines.append("</skills>")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
||||
"""Get a description of missing requirements."""
|
||||
missing = []
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
missing.append(f"CLI: {b}")
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
missing.append(f"ENV: {env}")
|
||||
return ", ".join(missing)
|
||||
|
||||
def _get_skill_description(self, name: str) -> str:
|
||||
"""Get the description of a skill from its frontmatter."""
|
||||
meta = self.get_skill_metadata(name)
|
||||
if meta and meta.get("description"):
|
||||
return meta["description"]
|
||||
return name # Fallback to skill name
|
||||
|
||||
def _strip_frontmatter(self, content: str) -> str:
|
||||
"""Remove YAML frontmatter from markdown content."""
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL)
|
||||
if match:
|
||||
return content[match.end():].strip()
|
||||
return content
|
||||
|
||||
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
||||
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
|
||||
def _check_requirements(self, skill_meta: dict) -> bool:
|
||||
"""Check if skill requirements are met (bins, env vars)."""
|
||||
requires = skill_meta.get("requires", {})
|
||||
for b in requires.get("bins", []):
|
||||
if not shutil.which(b):
|
||||
return False
|
||||
for env in requires.get("env", []):
|
||||
if not os.environ.get(env):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_skill_meta(self, name: str) -> dict:
|
||||
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
||||
meta = self.get_skill_metadata(name) or {}
|
||||
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
|
||||
def get_always_skills(self) -> list[str]:
|
||||
"""Get skills marked as always=true that meet requirements."""
|
||||
result = []
|
||||
for s in self.list_skills(filter_unavailable=True):
|
||||
meta = self.get_skill_metadata(s["name"]) or {}
|
||||
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
if skill_meta.get("always") or meta.get("always"):
|
||||
result.append(s["name"])
|
||||
return result
|
||||
|
||||
def get_skill_metadata(self, name: str) -> dict | None:
|
||||
"""
|
||||
Get metadata from a skill's frontmatter.
|
||||
|
||||
Args:
|
||||
name: Skill name.
|
||||
|
||||
Returns:
|
||||
Metadata dict or None.
|
||||
"""
|
||||
content = self.load_skill(name)
|
||||
if not content:
|
||||
return None
|
||||
|
||||
if content.startswith("---"):
|
||||
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
|
||||
if match:
|
||||
# Simple YAML parsing
|
||||
metadata = {}
|
||||
for line in match.group(1).split("\n"):
|
||||
if ":" in line:
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
return metadata
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Subagent manager for background task execution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
model: str | None = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4096,
|
||||
reasoning_effort: str | None = None,
|
||||
brave_api_key: str | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.brave_api_key = brave_api_key
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
task: str,
|
||||
label: str | None = None,
|
||||
origin_channel: str = "cli",
|
||||
origin_chat_id: str = "direct",
|
||||
session_key: str | None = None,
|
||||
) -> str:
|
||||
"""Spawn a subagent to execute a task in the background."""
|
||||
task_id = str(uuid.uuid4())[:8]
|
||||
display_label = label or task[:30] + ("..." if len(task) > 30 else "")
|
||||
origin = {"channel": origin_channel, "chat_id": origin_chat_id}
|
||||
|
||||
bg_task = asyncio.create_task(
|
||||
self._run_subagent(task_id, task, display_label, origin)
|
||||
)
|
||||
self._running_tasks[task_id] = bg_task
|
||||
if session_key:
|
||||
self._session_tasks.setdefault(session_key, set()).add(task_id)
|
||||
|
||||
def _cleanup(_: asyncio.Task) -> None:
|
||||
self._running_tasks.pop(task_id, None)
|
||||
if session_key and (ids := self._session_tasks.get(session_key)):
|
||||
ids.discard(task_id)
|
||||
if not ids:
|
||||
del self._session_tasks[session_key]
|
||||
|
||||
bg_task.add_done_callback(_cleanup)
|
||||
|
||||
logger.info("Spawned subagent [{}]: {}", task_id, display_label)
|
||||
return f"Subagent [{display_label}] started (id: {task_id}). I'll notify you when it completes."
|
||||
|
||||
async def _run_subagent(
|
||||
self,
|
||||
task_id: str,
|
||||
task: str,
|
||||
label: str,
|
||||
origin: dict[str, str],
|
||||
) -> None:
|
||||
"""Execute the subagent task and announce the result."""
|
||||
logger.info("Subagent [{}] starting task: {}", task_id, label)
|
||||
|
||||
try:
|
||||
# Build subagent tools (no message tool, no spawn tool)
|
||||
tools = ToolRegistry()
|
||||
allowed_dir = self.workspace if self.restrict_to_workspace else None
|
||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
# Run agent loop (limited iterations)
|
||||
max_iterations = 15
|
||||
iteration = 0
|
||||
final_result: str | None = None
|
||||
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_effort=self.reasoning_effort,
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
# Add assistant message with tool calls
|
||||
tool_call_dicts = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.name,
|
||||
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
|
||||
},
|
||||
}
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": response.content or "",
|
||||
"tool_calls": tool_call_dicts,
|
||||
})
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
else:
|
||||
final_result = response.content
|
||||
break
|
||||
|
||||
if final_result is None:
|
||||
final_result = "Task completed but no final response was generated."
|
||||
|
||||
logger.info("Subagent [{}] completed successfully", task_id)
|
||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error: {str(e)}"
|
||||
logger.error("Subagent [{}] failed: {}", task_id, e)
|
||||
await self._announce_result(task_id, label, task, error_msg, origin, "error")
|
||||
|
||||
async def _announce_result(
|
||||
self,
|
||||
task_id: str,
|
||||
label: str,
|
||||
task: str,
|
||||
result: str,
|
||||
origin: dict[str, str],
|
||||
status: str,
|
||||
) -> None:
|
||||
"""Announce the subagent result to the main agent via the message bus."""
|
||||
status_text = "completed successfully" if status == "ok" else "failed"
|
||||
|
||||
announce_content = f"""[Subagent '{label}' {status_text}]
|
||||
|
||||
Task: {task}
|
||||
|
||||
Result:
|
||||
{result}
|
||||
|
||||
Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not mention technical details like "subagent" or task IDs."""
|
||||
|
||||
# Inject as system message to trigger main agent
|
||||
msg = InboundMessage(
|
||||
channel="system",
|
||||
sender_id="subagent",
|
||||
chat_id=f"{origin['channel']}:{origin['chat_id']}",
|
||||
content=announce_content,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
||||
parts = [f"""# Subagent
|
||||
|
||||
{time_ctx}
|
||||
|
||||
You are a subagent spawned by the main agent to complete a specific task.
|
||||
Stay focused on the assigned task. Your final response will be reported back to the main agent.
|
||||
|
||||
## Workspace
|
||||
{self.workspace}"""]
|
||||
|
||||
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
|
||||
if skills_summary:
|
||||
parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
async def cancel_by_session(self, session_key: str) -> int:
|
||||
"""Cancel all subagents for the given session. Returns count cancelled."""
|
||||
tasks = [self._running_tasks[tid] for tid in self._session_tasks.get(session_key, [])
|
||||
if tid in self._running_tasks and not self._running_tasks[tid].done()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
return len(tasks)
|
||||
|
||||
def get_running_count(self) -> int:
|
||||
"""Return the number of currently running subagents."""
|
||||
return len(self._running_tasks)
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Agent tools module."""
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
__all__ = ["Tool", "ToolRegistry"]
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Base class for agent tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for agent tools.
|
||||
|
||||
Tools are capabilities that the agent can use to interact with
|
||||
the environment, such as reading files, executing commands, etc.
|
||||
"""
|
||||
|
||||
_TYPE_MAP = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Tool name used in function calls."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def description(self) -> str:
|
||||
"""Description of what the tool does."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
"""
|
||||
Execute the tool with given parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific parameters.
|
||||
|
||||
Returns:
|
||||
String result of the tool execution.
|
||||
"""
|
||||
pass
|
||||
|
||||
def cast_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply safe schema-driven casts before validation."""
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
return params
|
||||
|
||||
return self._cast_object(params, schema)
|
||||
|
||||
def _cast_object(self, obj: Any, schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Cast an object (dict) according to schema."""
|
||||
if not isinstance(obj, dict):
|
||||
return obj
|
||||
|
||||
props = schema.get("properties", {})
|
||||
result = {}
|
||||
|
||||
for key, value in obj.items():
|
||||
if key in props:
|
||||
result[key] = self._cast_value(value, props[key])
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
def _cast_value(self, val: Any, schema: dict[str, Any]) -> Any:
|
||||
"""Cast a single value according to schema."""
|
||||
target_type = schema.get("type")
|
||||
|
||||
if target_type == "boolean" and isinstance(val, bool):
|
||||
return val
|
||||
if target_type == "integer" and isinstance(val, int) and not isinstance(val, bool):
|
||||
return val
|
||||
if target_type in self._TYPE_MAP and target_type not in ("boolean", "integer", "array", "object"):
|
||||
expected = self._TYPE_MAP[target_type]
|
||||
if isinstance(val, expected):
|
||||
return val
|
||||
|
||||
if target_type == "integer" and isinstance(val, str):
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "number" and isinstance(val, str):
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
if target_type == "string":
|
||||
return val if val is None else str(val)
|
||||
|
||||
if target_type == "boolean" and isinstance(val, str):
|
||||
val_lower = val.lower()
|
||||
if val_lower in ("true", "1", "yes"):
|
||||
return True
|
||||
if val_lower in ("false", "0", "no"):
|
||||
return False
|
||||
return val
|
||||
|
||||
if target_type == "array" and isinstance(val, list):
|
||||
item_schema = schema.get("items")
|
||||
return [self._cast_value(item, item_schema) for item in val] if item_schema else val
|
||||
|
||||
if target_type == "object" and isinstance(val, dict):
|
||||
return self._cast_object(val, schema)
|
||||
|
||||
return val
|
||||
|
||||
def validate_params(self, params: dict[str, Any]) -> list[str]:
|
||||
"""Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
|
||||
if not isinstance(params, dict):
|
||||
return [f"parameters must be an object, got {type(params).__name__}"]
|
||||
schema = self.parameters or {}
|
||||
if schema.get("type", "object") != "object":
|
||||
raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
|
||||
return self._validate(params, {**schema, "type": "object"}, "")
|
||||
|
||||
def _validate(self, val: Any, schema: dict[str, Any], path: str) -> list[str]:
|
||||
t, label = schema.get("type"), path or "parameter"
|
||||
if t == "integer" and (not isinstance(val, int) or isinstance(val, bool)):
|
||||
return [f"{label} should be integer"]
|
||||
if t == "number" and (
|
||||
not isinstance(val, self._TYPE_MAP[t]) or isinstance(val, bool)
|
||||
):
|
||||
return [f"{label} should be number"]
|
||||
if t in self._TYPE_MAP and t not in ("integer", "number") and not isinstance(val, self._TYPE_MAP[t]):
|
||||
return [f"{label} should be {t}"]
|
||||
|
||||
errors = []
|
||||
if "enum" in schema and val not in schema["enum"]:
|
||||
errors.append(f"{label} must be one of {schema['enum']}")
|
||||
if t in ("integer", "number"):
|
||||
if "minimum" in schema and val < schema["minimum"]:
|
||||
errors.append(f"{label} must be >= {schema['minimum']}")
|
||||
if "maximum" in schema and val > schema["maximum"]:
|
||||
errors.append(f"{label} must be <= {schema['maximum']}")
|
||||
if t == "string":
|
||||
if "minLength" in schema and len(val) < schema["minLength"]:
|
||||
errors.append(f"{label} must be at least {schema['minLength']} chars")
|
||||
if "maxLength" in schema and len(val) > schema["maxLength"]:
|
||||
errors.append(f"{label} must be at most {schema['maxLength']} chars")
|
||||
if t == "object":
|
||||
props = schema.get("properties", {})
|
||||
for k in schema.get("required", []):
|
||||
if k not in val:
|
||||
errors.append(f"missing required {path + '.' + k if path else k}")
|
||||
for k, v in val.items():
|
||||
if k in props:
|
||||
errors.extend(self._validate(v, props[k], path + "." + k if path else k))
|
||||
if t == "array" and "items" in schema:
|
||||
for i, item in enumerate(val):
|
||||
errors.extend(
|
||||
self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
|
||||
)
|
||||
return errors
|
||||
|
||||
def to_schema(self) -> dict[str, Any]:
|
||||
"""Convert tool to OpenAI function schema format."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Cron tool for scheduling reminders and tasks."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronSchedule
|
||||
|
||||
|
||||
class CronTool(Tool):
|
||||
"""Tool to schedule reminders and recurring tasks."""
|
||||
|
||||
def __init__(self, cron_service: CronService):
|
||||
self._cron = cron_service
|
||||
self._channel = ""
|
||||
self._chat_id = ""
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Set the current session context for delivery."""
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
return self._in_cron_context.set(active)
|
||||
|
||||
def reset_cron_context(self, token) -> None:
|
||||
"""Restore previous cron context."""
|
||||
self._in_cron_context.reset(token)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "cron"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Schedule reminders and recurring tasks. Actions: add, list, remove."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["add", "list", "remove"],
|
||||
"description": "Action to perform",
|
||||
},
|
||||
"message": {"type": "string", "description": "Reminder message (for add)"},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Interval in seconds (for recurring tasks)",
|
||||
},
|
||||
"cron_expr": {
|
||||
"type": "string",
|
||||
"description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
|
||||
},
|
||||
"tz": {
|
||||
"type": "string",
|
||||
"description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
|
||||
},
|
||||
"at": {
|
||||
"type": "string",
|
||||
"description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
|
||||
},
|
||||
"job_id": {"type": "string", "description": "Job ID (for remove)"},
|
||||
},
|
||||
"required": ["action"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
action: str,
|
||||
message: str = "",
|
||||
every_seconds: int | None = None,
|
||||
cron_expr: str | None = None,
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if action == "add":
|
||||
if self._in_cron_context.get():
|
||||
return "Error: cannot schedule new jobs from within a cron job execution"
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
elif action == "remove":
|
||||
return self._remove_job(job_id)
|
||||
return f"Unknown action: {action}"
|
||||
|
||||
def _add_job(
|
||||
self,
|
||||
message: str,
|
||||
every_seconds: int | None,
|
||||
cron_expr: str | None,
|
||||
tz: str | None,
|
||||
at: str | None,
|
||||
) -> str:
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
if not self._channel or not self._chat_id:
|
||||
return "Error: no session context (channel/chat_id)"
|
||||
if tz and not cron_expr:
|
||||
return "Error: tz can only be used with cron_expr"
|
||||
if tz:
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
try:
|
||||
ZoneInfo(tz)
|
||||
except (KeyError, Exception):
|
||||
return f"Error: unknown timezone '{tz}'"
|
||||
|
||||
# Build schedule
|
||||
delete_after = False
|
||||
if every_seconds:
|
||||
schedule = CronSchedule(kind="every", every_ms=every_seconds * 1000)
|
||||
elif cron_expr:
|
||||
schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
|
||||
elif at:
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
dt = datetime.fromisoformat(at)
|
||||
except ValueError:
|
||||
return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
|
||||
at_ms = int(dt.timestamp() * 1000)
|
||||
schedule = CronSchedule(kind="at", at_ms=at_ms)
|
||||
delete_after = True
|
||||
else:
|
||||
return "Error: either every_seconds, cron_expr, or at is required"
|
||||
|
||||
job = self._cron.add_job(
|
||||
name=message[:30],
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
deliver=True,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
def _list_jobs(self) -> str:
|
||||
jobs = self._cron.list_jobs()
|
||||
if not jobs:
|
||||
return "No scheduled jobs."
|
||||
lines = [f"- {j.name} (id: {j.id}, {j.schedule.kind})" for j in jobs]
|
||||
return "Scheduled jobs:\n" + "\n".join(lines)
|
||||
|
||||
def _remove_job(self, job_id: str | None) -> str:
|
||||
if not job_id:
|
||||
return "Error: job_id is required for remove"
|
||||
if self._cron.remove_job(job_id):
|
||||
return f"Removed job {job_id}"
|
||||
return f"Job {job_id} not found"
|
||||
@@ -0,0 +1,238 @@
|
||||
"""File system tools: read, write, edit."""
|
||||
|
||||
import difflib
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
def _resolve_path(
|
||||
path: str, workspace: Path | None = None, allowed_dir: Path | None = None
|
||||
) -> Path:
|
||||
"""Resolve path against workspace (if relative) and enforce directory restriction."""
|
||||
p = Path(path).expanduser()
|
||||
if not p.is_absolute() and workspace:
|
||||
p = workspace / p
|
||||
resolved = p.resolve()
|
||||
if allowed_dir:
|
||||
try:
|
||||
resolved.relative_to(allowed_dir.resolve())
|
||||
except ValueError:
|
||||
raise PermissionError(f"Path {path} is outside allowed directory {allowed_dir}")
|
||||
return resolved
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
"""Tool to read file contents."""
|
||||
|
||||
_MAX_CHARS = 128_000 # ~128 KB — prevents OOM from reading huge files into LLM context
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Read the contents of a file at the given path."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string", "description": "The file path to read"}},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not file_path.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
size = file_path.stat().st_size
|
||||
if size > self._MAX_CHARS * 4: # rough upper bound (UTF-8 chars ≤ 4 bytes)
|
||||
return (
|
||||
f"Error: File too large ({size:,} bytes). "
|
||||
f"Use exec tool with head/tail/grep to read portions."
|
||||
)
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
if len(content) > self._MAX_CHARS:
|
||||
return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
|
||||
return content
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
"""Tool to write content to a file."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Write content to a file at the given path. Creates parent directories if needed."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to write to"},
|
||||
"content": {"type": "string", "description": "The content to write"},
|
||||
},
|
||||
"required": ["path", "content"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, content: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
return f"Successfully wrote {len(content)} bytes to {file_path}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error writing file: {str(e)}"
|
||||
|
||||
|
||||
class EditFileTool(Tool):
|
||||
"""Tool to edit a file by replacing text."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Edit a file by replacing old_text with new_text. The old_text must exist exactly in the file."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "The file path to edit"},
|
||||
"old_text": {"type": "string", "description": "The exact text to find and replace"},
|
||||
"new_text": {"type": "string", "description": "The text to replace with"},
|
||||
},
|
||||
"required": ["path", "old_text", "new_text"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
file_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not file_path.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
|
||||
if old_text not in content:
|
||||
return self._not_found_message(old_text, content, path)
|
||||
|
||||
# Count occurrences
|
||||
count = content.count(old_text)
|
||||
if count > 1:
|
||||
return f"Warning: old_text appears {count} times. Please provide more context to make it unique."
|
||||
|
||||
new_content = content.replace(old_text, new_text, 1)
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
return f"Successfully edited {file_path}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def _not_found_message(old_text: str, content: str, path: str) -> str:
|
||||
"""Build a helpful error when old_text is not found."""
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
|
||||
best_ratio, best_start = 0.0, 0
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(
|
||||
difflib.unified_diff(
|
||||
old_lines,
|
||||
lines[best_start : best_start + window],
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
)
|
||||
)
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
return (
|
||||
f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
)
|
||||
|
||||
|
||||
class ListDirTool(Tool):
|
||||
"""Tool to list directory contents."""
|
||||
|
||||
def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
|
||||
self._workspace = workspace
|
||||
self._allowed_dir = allowed_dir
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_dir"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "List the contents of a directory."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {"path": {"type": "string", "description": "The directory path to list"}},
|
||||
"required": ["path"],
|
||||
}
|
||||
|
||||
async def execute(self, path: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
dir_path = _resolve_path(path, self._workspace, self._allowed_dir)
|
||||
if not dir_path.exists():
|
||||
return f"Error: Directory not found: {path}"
|
||||
if not dir_path.is_dir():
|
||||
return f"Error: Not a directory: {path}"
|
||||
|
||||
items = []
|
||||
for item in sorted(dir_path.iterdir()):
|
||||
prefix = "📁 " if item.is_dir() else "📄 "
|
||||
items.append(f"{prefix}{item.name}")
|
||||
|
||||
if not items:
|
||||
return f"Directory {path} is empty"
|
||||
|
||||
return "\n".join(items)
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error listing directory: {str(e)}"
|
||||
@@ -0,0 +1,148 @@
|
||||
"""MCP client: connects to MCP servers and wraps their tools as native nanobot tools."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
|
||||
|
||||
class MCPToolWrapper(Tool):
|
||||
"""Wraps a single MCP server tool as a nanobot Tool."""
|
||||
|
||||
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||
self._session = session
|
||||
self._original_name = tool_def.name
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._description = tool_def.description or tool_def.name
|
||||
self._parameters = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
self._tool_timeout = tool_timeout
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._description
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return self._parameters
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
from mcp import types
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._session.call_tool(self._original_name, arguments=kwargs),
|
||||
timeout=self._tool_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("MCP tool '{}' timed out after {}s", self._name, self._tool_timeout)
|
||||
return f"(MCP tool call timed out after {self._tool_timeout}s)"
|
||||
except asyncio.CancelledError:
|
||||
# MCP SDK's anyio cancel scopes can leak CancelledError on timeout/failure.
|
||||
# Re-raise only if our task was externally cancelled (e.g. /stop).
|
||||
task = asyncio.current_task()
|
||||
if task is not None and task.cancelling() > 0:
|
||||
raise
|
||||
logger.warning("MCP tool '{}' was cancelled by server/SDK", self._name)
|
||||
return "(MCP tool call was cancelled)"
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"MCP tool '{}' failed: {}: {}",
|
||||
self._name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return f"(MCP tool call failed: {type(exc).__name__})"
|
||||
|
||||
parts = []
|
||||
for block in result.content:
|
||||
if isinstance(block, types.TextContent):
|
||||
parts.append(block.text)
|
||||
else:
|
||||
parts.append(str(block))
|
||||
return "\n".join(parts) or "(no output)"
|
||||
|
||||
|
||||
async def connect_mcp_servers(
|
||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
||||
) -> None:
|
||||
"""Connect to configured MCP servers and register their tools."""
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamable_http_client
|
||||
|
||||
for name, cfg in mcp_servers.items():
|
||||
try:
|
||||
transport_type = cfg.type
|
||||
if not transport_type:
|
||||
if cfg.command:
|
||||
transport_type = "stdio"
|
||||
elif cfg.url:
|
||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
||||
transport_type = (
|
||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||
continue
|
||||
|
||||
if transport_type == "stdio":
|
||||
params = StdioServerParameters(
|
||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||
)
|
||||
read, write = await stack.enter_async_context(stdio_client(params))
|
||||
elif transport_type == "sse":
|
||||
def httpx_client_factory(
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
merged_headers = {**(cfg.headers or {}), **(headers or {})}
|
||||
return httpx.AsyncClient(
|
||||
headers=merged_headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=timeout,
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
read, write = await stack.enter_async_context(
|
||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||
)
|
||||
elif transport_type == "streamableHttp":
|
||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
||||
http_client = await stack.enter_async_context(
|
||||
httpx.AsyncClient(
|
||||
headers=cfg.headers or None,
|
||||
follow_redirects=True,
|
||||
timeout=None,
|
||||
)
|
||||
)
|
||||
read, write, _ = await stack.enter_async_context(
|
||||
streamable_http_client(cfg.url, http_client=http_client)
|
||||
)
|
||||
else:
|
||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||
continue
|
||||
|
||||
session = await stack.enter_async_context(ClientSession(read, write))
|
||||
await session.initialize()
|
||||
|
||||
tools = await session.list_tools()
|
||||
for tool_def in tools.tools:
|
||||
wrapper = MCPToolWrapper(session, name, tool_def, tool_timeout=cfg.tool_timeout)
|
||||
registry.register(wrapper)
|
||||
logger.debug("MCP: registered tool '{}' from server '{}'", wrapper.name, name)
|
||||
|
||||
logger.info("MCP server '{}': connected, {} tools registered", name, len(tools.tools))
|
||||
except Exception as e:
|
||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
class MessageTool(Tool):
|
||||
"""Tool to send messages to users on chat channels."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send_callback: Callable[[OutboundMessage], Awaitable[None]] | None = None,
|
||||
default_channel: str = "",
|
||||
default_chat_id: str = "",
|
||||
default_message_id: str | None = None,
|
||||
):
|
||||
self._send_callback = send_callback
|
||||
self._default_channel = default_channel
|
||||
self._default_chat_id = default_chat_id
|
||||
self._default_message_id = default_message_id
|
||||
self._sent_in_turn: bool = False
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel = channel
|
||||
self._default_chat_id = chat_id
|
||||
self._default_message_id = message_id
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
self._send_callback = callback
|
||||
|
||||
def start_turn(self) -> None:
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "message"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Send a message to the user. Use this when you want to communicate something."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content to send"
|
||||
},
|
||||
"channel": {
|
||||
"type": "string",
|
||||
"description": "Optional: target channel (telegram, discord, etc.)"
|
||||
},
|
||||
"chat_id": {
|
||||
"type": "string",
|
||||
"description": "Optional: target chat/user ID"
|
||||
},
|
||||
"media": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional: list of file paths to attach (images, audio, documents)"
|
||||
}
|
||||
},
|
||||
"required": ["content"]
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
content: str,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
media: list[str] | None = None,
|
||||
**kwargs: Any
|
||||
) -> str:
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
message_id = message_id or self._default_message_id
|
||||
|
||||
if not channel or not chat_id:
|
||||
return "Error: No target channel/chat specified"
|
||||
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await self._send_callback(msg)
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
self._sent_in_turn = True
|
||||
media_info = f" with {len(media)} attachments" if media else ""
|
||||
return f"Message sent to {channel}:{chat_id}{media_info}"
|
||||
except Exception as e:
|
||||
return f"Error sending message: {str(e)}"
|
||||
@@ -0,0 +1,70 @@
|
||||
"""Tool registry for dynamic tool management."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""
|
||||
Registry for agent tools.
|
||||
|
||||
Allows dynamic registration and execution of tools.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def unregister(self, name: str) -> None:
|
||||
"""Unregister a tool by name."""
|
||||
self._tools.pop(name, None)
|
||||
|
||||
def get(self, name: str) -> Tool | None:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
"""Check if a tool is registered."""
|
||||
return name in self._tools
|
||||
|
||||
def get_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Get all tool definitions in OpenAI format."""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> str:
|
||||
"""Execute a tool by name with given parameters."""
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
result = await tool.execute(**params)
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _HINT
|
||||
return result
|
||||
except Exception as e:
|
||||
return f"Error executing {name}: {str(e)}" + _HINT
|
||||
|
||||
@property
|
||||
def tool_names(self) -> list[str]:
|
||||
"""Get list of registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._tools)
|
||||
|
||||
def __contains__(self, name: str) -> bool:
|
||||
return name in self._tools
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Shell execution tool."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
|
||||
class ExecTool(Tool):
|
||||
"""Tool to execute shell commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 60,
|
||||
working_dir: str | None = None,
|
||||
deny_patterns: list[str] | None = None,
|
||||
allow_patterns: list[str] | None = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
path_append: str = "",
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
self.deny_patterns = deny_patterns or [
|
||||
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
|
||||
r"\bdel\s+/[fq]\b", # del /f, del /q
|
||||
r"\brmdir\s+/s\b", # rmdir /s
|
||||
r"(?:^|[;&|]\s*)format\b", # format (as standalone command only)
|
||||
r"\b(mkfs|diskpart)\b", # disk operations
|
||||
r"\bdd\s+if=", # dd
|
||||
r">\s*/dev/sd", # write to disk
|
||||
r"\b(shutdown|reboot|poweroff)\b", # system power
|
||||
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
||||
]
|
||||
self.allow_patterns = allow_patterns or []
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.path_append = path_append
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "exec"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The shell command to execute"
|
||||
},
|
||||
"working_dir": {
|
||||
"type": "string",
|
||||
"description": "Optional working directory for the command"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
|
||||
async def execute(self, command: str, working_dir: str | None = None, **kwargs: Any) -> str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
guard_error = self._guard_command(command, cwd)
|
||||
if guard_error:
|
||||
return guard_error
|
||||
|
||||
env = os.environ.copy()
|
||||
if self.path_append:
|
||||
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
process.kill()
|
||||
# Wait for the process to fully terminate so pipes are
|
||||
# drained and file descriptors are released.
|
||||
try:
|
||||
await asyncio.wait_for(process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
return f"Error: Command timed out after {self.timeout} seconds"
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
if process.returncode != 0:
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
# Truncate very long output
|
||||
max_len = 10000
|
||||
if len(result) > max_len:
|
||||
result = result[:max_len] + f"\n... (truncated, {len(result) - max_len} more chars)"
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
def _guard_command(self, command: str, cwd: str) -> str | None:
|
||||
"""Best-effort safety guard for potentially destructive commands."""
|
||||
cmd = command.strip()
|
||||
lower = cmd.lower()
|
||||
|
||||
for pattern in self.deny_patterns:
|
||||
if re.search(pattern, lower):
|
||||
return "Error: Command blocked by safety guard (dangerous pattern detected)"
|
||||
|
||||
if self.allow_patterns:
|
||||
if not any(re.search(p, lower) for p in self.allow_patterns):
|
||||
return "Error: Command blocked by safety guard (not in allowlist)"
|
||||
|
||||
if self.restrict_to_workspace:
|
||||
if "..\\" in cmd or "../" in cmd:
|
||||
return "Error: Command blocked by safety guard (path traversal detected)"
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
|
||||
for raw in self._extract_absolute_paths(cmd):
|
||||
try:
|
||||
p = Path(raw.strip()).resolve()
|
||||
except Exception:
|
||||
continue
|
||||
if p.is_absolute() and cwd_path not in p.parents and p != cwd_path:
|
||||
return "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_absolute_paths(command: str) -> list[str]:
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||
posix_paths = re.findall(r"(?:^|[\s|>])(/[^\s\"'>]+)", command) # POSIX: /absolute only
|
||||
return win_paths + posix_paths
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Spawn tool for creating background subagents."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
|
||||
class SpawnTool(Tool):
|
||||
"""Tool to spawn a subagent for background task execution."""
|
||||
|
||||
def __init__(self, manager: "SubagentManager"):
|
||||
self._manager = manager
|
||||
self._origin_channel = "cli"
|
||||
self._origin_chat_id = "direct"
|
||||
self._session_key = "cli:direct"
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
"""Set the origin context for subagent announcements."""
|
||||
self._origin_channel = channel
|
||||
self._origin_chat_id = chat_id
|
||||
self._session_key = f"{channel}:{chat_id}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spawn"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Spawn a subagent to handle a task in the background. "
|
||||
"Use this for complex or time-consuming tasks that can run independently. "
|
||||
"The subagent will complete the task and report back when done."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"task": {
|
||||
"type": "string",
|
||||
"description": "The task for the subagent to complete",
|
||||
},
|
||||
"label": {
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
},
|
||||
"required": ["task"],
|
||||
}
|
||||
|
||||
async def execute(self, task: str, label: str | None = None, **kwargs: Any) -> str:
|
||||
"""Spawn a subagent to execute the given task."""
|
||||
return await self._manager.spawn(
|
||||
task=task,
|
||||
label=label,
|
||||
origin_channel=self._origin_channel,
|
||||
origin_chat_id=self._origin_chat_id,
|
||||
session_key=self._session_key,
|
||||
)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Web tools: web_search and web_fetch."""
|
||||
|
||||
import html
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
# Shared constants
|
||||
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 14_7_2) AppleWebKit/537.36"
|
||||
MAX_REDIRECTS = 5 # Limit redirects to prevent DoS attacks
|
||||
|
||||
|
||||
def _strip_tags(text: str) -> str:
|
||||
"""Remove HTML tags and decode entities."""
|
||||
text = re.sub(r'<script[\s\S]*?</script>', '', text, flags=re.I)
|
||||
text = re.sub(r'<style[\s\S]*?</style>', '', text, flags=re.I)
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
return html.unescape(text).strip()
|
||||
|
||||
|
||||
def _normalize(text: str) -> str:
|
||||
"""Normalize whitespace."""
|
||||
text = re.sub(r'[ \t]+', ' ', text)
|
||||
return re.sub(r'\n{3,}', '\n\n', text).strip()
|
||||
|
||||
|
||||
def _validate_url(url: str) -> tuple[bool, str]:
|
||||
"""Validate URL: must be http(s) with valid domain."""
|
||||
try:
|
||||
p = urlparse(url)
|
||||
if p.scheme not in ('http', 'https'):
|
||||
return False, f"Only http/https allowed, got '{p.scheme or 'none'}'"
|
||||
if not p.netloc:
|
||||
return False, "Missing domain"
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
class WebSearchTool(Tool):
|
||||
"""Search the web using Brave Search API."""
|
||||
|
||||
name = "web_search"
|
||||
description = "Search the web. Returns titles, URLs, and snippets."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"count": {"type": "integer", "description": "Results (1-10)", "minimum": 1, "maximum": 10}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
|
||||
def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None):
|
||||
self._init_api_key = api_key
|
||||
self.max_results = max_results
|
||||
self.proxy = proxy
|
||||
|
||||
@property
|
||||
def api_key(self) -> str:
|
||||
"""Resolve API key at call time so env/config changes are picked up."""
|
||||
return self._init_api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
if not self.api_key:
|
||||
return (
|
||||
"Error: Brave Search API key not configured. Set it in "
|
||||
"~/.nanobot/config.json under tools.web.search.apiKey "
|
||||
"(or export BRAVE_API_KEY), then restart the gateway."
|
||||
)
|
||||
|
||||
try:
|
||||
n = min(max(count or self.max_results, 1), 10)
|
||||
logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection")
|
||||
async with httpx.AsyncClient(proxy=self.proxy) as client:
|
||||
r = await client.get(
|
||||
"https://api.search.brave.com/res/v1/web/search",
|
||||
params={"q": query, "count": n},
|
||||
headers={"Accept": "application/json", "X-Subscription-Token": self.api_key},
|
||||
timeout=10.0
|
||||
)
|
||||
r.raise_for_status()
|
||||
|
||||
results = r.json().get("web", {}).get("results", [])[:n]
|
||||
if not results:
|
||||
return f"No results for: {query}"
|
||||
|
||||
lines = [f"Results for: {query}\n"]
|
||||
for i, item in enumerate(results, 1):
|
||||
lines.append(f"{i}. {item.get('title', '')}\n {item.get('url', '')}")
|
||||
if desc := item.get("description"):
|
||||
lines.append(f" {desc}")
|
||||
return "\n".join(lines)
|
||||
except httpx.ProxyError as e:
|
||||
logger.error("WebSearch proxy error: {}", e)
|
||||
return f"Proxy error: {e}"
|
||||
except Exception as e:
|
||||
logger.error("WebSearch error: {}", e)
|
||||
return f"Error: {e}"
|
||||
|
||||
|
||||
class WebFetchTool(Tool):
|
||||
"""Fetch and extract content from a URL using Readability."""
|
||||
|
||||
name = "web_fetch"
|
||||
description = "Fetch URL and extract readable content (HTML → markdown/text)."
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "description": "URL to fetch"},
|
||||
"extractMode": {"type": "string", "enum": ["markdown", "text"], "default": "markdown"},
|
||||
"maxChars": {"type": "integer", "minimum": 100}
|
||||
},
|
||||
"required": ["url"]
|
||||
}
|
||||
|
||||
def __init__(self, max_chars: int = 50000, proxy: str | None = None):
|
||||
self.max_chars = max_chars
|
||||
self.proxy = proxy
|
||||
|
||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
|
||||
from readability import Document
|
||||
|
||||
max_chars = maxChars or self.max_chars
|
||||
is_valid, error_msg = _validate_url(url)
|
||||
if not is_valid:
|
||||
return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection")
|
||||
async with httpx.AsyncClient(
|
||||
follow_redirects=True,
|
||||
max_redirects=MAX_REDIRECTS,
|
||||
timeout=30.0,
|
||||
proxy=self.proxy,
|
||||
) as client:
|
||||
r = await client.get(url, headers={"User-Agent": USER_AGENT})
|
||||
r.raise_for_status()
|
||||
|
||||
ctype = r.headers.get("content-type", "")
|
||||
|
||||
if "application/json" in ctype:
|
||||
text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
|
||||
elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
|
||||
doc = Document(r.text)
|
||||
content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
|
||||
text = f"# {doc.title()}\n\n{content}" if doc.title() else content
|
||||
extractor = "readability"
|
||||
else:
|
||||
text, extractor = r.text, "raw"
|
||||
|
||||
truncated = len(text) > max_chars
|
||||
if truncated: text = text[:max_chars]
|
||||
|
||||
return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
|
||||
"extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
|
||||
except httpx.ProxyError as e:
|
||||
logger.error("WebFetch proxy error for {}: {}", url, e)
|
||||
return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error("WebFetch error for {}: {}", url, e)
|
||||
return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
|
||||
|
||||
def _to_markdown(self, html: str) -> str:
|
||||
"""Convert HTML to markdown."""
|
||||
# Convert links, headings, lists before stripping tags
|
||||
text = re.sub(r'<a\s+[^>]*href=["\']([^"\']+)["\'][^>]*>([\s\S]*?)</a>',
|
||||
lambda m: f'[{_strip_tags(m[2])}]({m[1]})', html, flags=re.I)
|
||||
text = re.sub(r'<h([1-6])[^>]*>([\s\S]*?)</h\1>',
|
||||
lambda m: f'\n{"#" * int(m[1])} {_strip_tags(m[2])}\n', text, flags=re.I)
|
||||
text = re.sub(r'<li[^>]*>([\s\S]*?)</li>', lambda m: f'\n- {_strip_tags(m[1])}', text, flags=re.I)
|
||||
text = re.sub(r'</(p|div|section|article)>', '\n\n', text, flags=re.I)
|
||||
text = re.sub(r'<(br|hr)\s*/?>', '\n', text, flags=re.I)
|
||||
return _normalize(_strip_tags(text))
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Message bus module for decoupled channel-agent communication."""
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
__all__ = ["MessageBus", "InboundMessage", "OutboundMessage"]
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Event types for the message bus."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class InboundMessage:
|
||||
"""Message received from a chat channel."""
|
||||
|
||||
channel: str # telegram, discord, slack, whatsapp
|
||||
sender_id: str # User identifier
|
||||
chat_id: str # Chat/channel identifier
|
||||
content: str # Message text
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
media: list[str] = field(default_factory=list) # Media URLs
|
||||
metadata: dict[str, Any] = field(default_factory=dict) # Channel-specific data
|
||||
session_key_override: str | None = None # Optional override for thread-scoped sessions
|
||||
|
||||
@property
|
||||
def session_key(self) -> str:
|
||||
"""Unique key for session identification."""
|
||||
return self.session_key_override or f"{self.channel}:{self.chat_id}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutboundMessage:
|
||||
"""Message to send to a chat channel."""
|
||||
|
||||
channel: str
|
||||
chat_id: str
|
||||
content: str
|
||||
reply_to: str | None = None
|
||||
media: list[str] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Async message queue for decoupled channel-agent communication."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
|
||||
|
||||
class MessageBus:
|
||||
"""
|
||||
Async message bus that decouples chat channels from the agent core.
|
||||
|
||||
Channels push messages to the inbound queue, and the agent processes
|
||||
them and pushes responses to the outbound queue.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.inbound: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
self.outbound: asyncio.Queue[OutboundMessage] = asyncio.Queue()
|
||||
|
||||
async def publish_inbound(self, msg: InboundMessage) -> None:
|
||||
"""Publish a message from a channel to the agent."""
|
||||
await self.inbound.put(msg)
|
||||
|
||||
async def consume_inbound(self) -> InboundMessage:
|
||||
"""Consume the next inbound message (blocks until available)."""
|
||||
return await self.inbound.get()
|
||||
|
||||
async def publish_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""Publish a response from the agent to channels."""
|
||||
await self.outbound.put(msg)
|
||||
|
||||
async def consume_outbound(self) -> OutboundMessage:
|
||||
"""Consume the next outbound message (blocks until available)."""
|
||||
return await self.outbound.get()
|
||||
|
||||
@property
|
||||
def inbound_size(self) -> int:
|
||||
"""Number of pending inbound messages."""
|
||||
return self.inbound.qsize()
|
||||
|
||||
@property
|
||||
def outbound_size(self) -> int:
|
||||
"""Number of pending outbound messages."""
|
||||
return self.outbound.qsize()
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Chat channels module with plugin architecture."""
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
|
||||
__all__ = ["BaseChannel", "ChannelManager"]
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Base channel interface for chat platforms."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
"""
|
||||
Abstract base class for chat channel implementations.
|
||||
|
||||
Each channel (Telegram, Discord, etc.) should implement this interface
|
||||
to integrate with the nanobot message bus.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
"""
|
||||
Initialize the channel.
|
||||
|
||||
Args:
|
||||
config: Channel-specific configuration.
|
||||
bus: The message bus for communication.
|
||||
"""
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self._running = False
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the channel and begin listening for messages.
|
||||
|
||||
This should be a long-running async task that:
|
||||
1. Connects to the chat platform
|
||||
2. Listens for incoming messages
|
||||
3. Forwards messages to the bus via _handle_message()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Stop the channel and clean up resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""
|
||||
Send a message through this channel.
|
||||
|
||||
Args:
|
||||
msg: The message to send.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list:
|
||||
logger.warning("{}: allow_from is empty — all access denied", self.name)
|
||||
return False
|
||||
if "*" in allow_list:
|
||||
return True
|
||||
return str(sender_id) in allow_list
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
msg = InboundMessage(
|
||||
channel=self.name,
|
||||
sender_id=str(sender_id),
|
||||
chat_id=str(chat_id),
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata=metadata or {},
|
||||
session_key_override=session_key,
|
||||
)
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the channel is running."""
|
||||
return self._running
|
||||
@@ -0,0 +1,471 @@
|
||||
"""DingTalk/DingDing channel implementation using Stream Mode."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import DingTalkConfig
|
||||
|
||||
try:
|
||||
from dingtalk_stream import (
|
||||
AckMessage,
|
||||
CallbackHandler,
|
||||
CallbackMessage,
|
||||
Credential,
|
||||
DingTalkStreamClient,
|
||||
)
|
||||
from dingtalk_stream.chatbot import ChatbotMessage
|
||||
|
||||
DINGTALK_AVAILABLE = True
|
||||
except ImportError:
|
||||
DINGTALK_AVAILABLE = False
|
||||
# Fallback so class definitions don't crash at module level
|
||||
CallbackHandler = object # type: ignore[assignment,misc]
|
||||
CallbackMessage = None # type: ignore[assignment,misc]
|
||||
AckMessage = None # type: ignore[assignment,misc]
|
||||
ChatbotMessage = None # type: ignore[assignment,misc]
|
||||
|
||||
|
||||
class NanobotDingTalkHandler(CallbackHandler):
|
||||
"""
|
||||
Standard DingTalk Stream SDK Callback Handler.
|
||||
Parses incoming messages and forwards them to the Nanobot channel.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: "DingTalkChannel"):
|
||||
super().__init__()
|
||||
self.channel = channel
|
||||
|
||||
async def process(self, message: CallbackMessage):
|
||||
"""Process incoming stream message."""
|
||||
try:
|
||||
# Parse using SDK's ChatbotMessage for robust handling
|
||||
chatbot_msg = ChatbotMessage.from_dict(message.data)
|
||||
|
||||
# Extract text content; fall back to raw dict if SDK object is empty
|
||||
content = ""
|
||||
if chatbot_msg.text:
|
||||
content = chatbot_msg.text.content.strip()
|
||||
if not content:
|
||||
content = message.data.get("text", {}).get("content", "").strip()
|
||||
|
||||
if not content:
|
||||
logger.warning(
|
||||
"Received empty or unsupported message type: {}",
|
||||
chatbot_msg.message_type,
|
||||
)
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
sender_id = chatbot_msg.sender_staff_id or chatbot_msg.sender_id
|
||||
sender_name = chatbot_msg.sender_nick or "Unknown"
|
||||
|
||||
conversation_type = message.data.get("conversationType")
|
||||
conversation_id = (
|
||||
message.data.get("conversationId")
|
||||
or message.data.get("openConversationId")
|
||||
)
|
||||
|
||||
logger.info("Received DingTalk message from {} ({}): {}", sender_name, sender_id, content)
|
||||
|
||||
# Forward to Nanobot via _on_message (non-blocking).
|
||||
# Store reference to prevent GC before task completes.
|
||||
task = asyncio.create_task(
|
||||
self.channel._on_message(
|
||||
content,
|
||||
sender_id,
|
||||
sender_name,
|
||||
conversation_type,
|
||||
conversation_id,
|
||||
)
|
||||
)
|
||||
self.channel._background_tasks.add(task)
|
||||
task.add_done_callback(self.channel._background_tasks.discard)
|
||||
|
||||
return AckMessage.STATUS_OK, "OK"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing DingTalk message: {}", e)
|
||||
# Return OK to avoid retry loop from DingTalk server
|
||||
return AckMessage.STATUS_OK, "Error"
|
||||
|
||||
|
||||
class DingTalkChannel(BaseChannel):
|
||||
"""
|
||||
DingTalk channel using Stream Mode.
|
||||
|
||||
Uses WebSocket to receive events via `dingtalk-stream` SDK.
|
||||
Uses direct HTTP API to send messages (SDK is mainly for receiving).
|
||||
|
||||
Supports both private (1:1) and group chats.
|
||||
Group chat_id is stored with a "group:" prefix to route replies back.
|
||||
"""
|
||||
|
||||
name = "dingtalk"
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||
|
||||
def __init__(self, config: DingTalkConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DingTalkConfig = config
|
||||
self._client: Any = None
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
|
||||
# Access Token management for sending messages
|
||||
self._access_token: str | None = None
|
||||
self._token_expiry: float = 0
|
||||
|
||||
# Hold references to background tasks to prevent GC
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the DingTalk bot with Stream Mode."""
|
||||
try:
|
||||
if not DINGTALK_AVAILABLE:
|
||||
logger.error(
|
||||
"DingTalk Stream SDK not installed. Run: pip install dingtalk-stream"
|
||||
)
|
||||
return
|
||||
|
||||
if not self.config.client_id or not self.config.client_secret:
|
||||
logger.error("DingTalk client_id and client_secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient()
|
||||
|
||||
logger.info(
|
||||
"Initializing DingTalk Stream Client with Client ID: {}...",
|
||||
self.config.client_id,
|
||||
)
|
||||
credential = Credential(self.config.client_id, self.config.client_secret)
|
||||
self._client = DingTalkStreamClient(credential)
|
||||
|
||||
# Register standard handler
|
||||
handler = NanobotDingTalkHandler(self)
|
||||
self._client.register_callback_handler(ChatbotMessage.TOPIC, handler)
|
||||
|
||||
logger.info("DingTalk bot started with Stream Mode")
|
||||
|
||||
# Reconnect loop: restart stream if SDK exits or crashes
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start()
|
||||
except Exception as e:
|
||||
logger.warning("DingTalk stream error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting DingTalk stream in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to start DingTalk channel: {}", e)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the DingTalk bot."""
|
||||
self._running = False
|
||||
# Close the shared HTTP client
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
# Cancel outstanding background tasks
|
||||
for task in self._background_tasks:
|
||||
task.cancel()
|
||||
self._background_tasks.clear()
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""Get or refresh Access Token."""
|
||||
if self._access_token and time.time() < self._token_expiry:
|
||||
return self._access_token
|
||||
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
data = {
|
||||
"appKey": self.config.client_id,
|
||||
"appSecret": self.config.client_secret,
|
||||
}
|
||||
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot refresh token")
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=data)
|
||||
resp.raise_for_status()
|
||||
res_data = resp.json()
|
||||
self._access_token = res_data.get("accessToken")
|
||||
# Expire 60s early to be safe
|
||||
self._token_expiry = time.time() + int(res_data.get("expireIn", 7200)) - 60
|
||||
return self._access_token
|
||||
except Exception as e:
|
||||
logger.error("Failed to get DingTalk access token: {}", e)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _is_http_url(value: str) -> bool:
|
||||
return urlparse(value).scheme in ("http", "https")
|
||||
|
||||
def _guess_upload_type(self, media_ref: str) -> str:
|
||||
ext = Path(urlparse(media_ref).path).suffix.lower()
|
||||
if ext in self._IMAGE_EXTS: return "image"
|
||||
if ext in self._AUDIO_EXTS: return "voice"
|
||||
if ext in self._VIDEO_EXTS: return "video"
|
||||
return "file"
|
||||
|
||||
def _guess_filename(self, media_ref: str, upload_type: str) -> str:
|
||||
name = os.path.basename(urlparse(media_ref).path)
|
||||
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
||||
|
||||
async def _read_media_bytes(
|
||||
self,
|
||||
media_ref: str,
|
||||
) -> tuple[bytes | None, str | None, str | None]:
|
||||
if not media_ref:
|
||||
return None, None, None
|
||||
|
||||
if self._is_http_url(media_ref):
|
||||
if not self._http:
|
||||
return None, None, None
|
||||
try:
|
||||
resp = await self._http.get(media_ref, follow_redirects=True)
|
||||
if resp.status_code >= 400:
|
||||
logger.warning(
|
||||
"DingTalk media download failed status={} ref={}",
|
||||
resp.status_code,
|
||||
media_ref,
|
||||
)
|
||||
return None, None, None
|
||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
return resp.content, filename, content_type or None
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
try:
|
||||
if media_ref.startswith("file://"):
|
||||
parsed = urlparse(media_ref)
|
||||
local_path = Path(unquote(parsed.path))
|
||||
else:
|
||||
local_path = Path(os.path.expanduser(media_ref))
|
||||
if not local_path.is_file():
|
||||
logger.warning("DingTalk media file not found: {}", local_path)
|
||||
return None, None, None
|
||||
data = await asyncio.to_thread(local_path.read_bytes)
|
||||
content_type = mimetypes.guess_type(local_path.name)[0]
|
||||
return data, local_path.name, content_type
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media read error ref={} err={}", media_ref, e)
|
||||
return None, None, None
|
||||
|
||||
async def _upload_media(
|
||||
self,
|
||||
token: str,
|
||||
data: bytes,
|
||||
media_type: str,
|
||||
filename: str,
|
||||
content_type: str | None,
|
||||
) -> str | None:
|
||||
if not self._http:
|
||||
return None
|
||||
url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
|
||||
mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
files = {"media": (filename, data, mime)}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, files=files)
|
||||
text = resp.text
|
||||
result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
|
||||
if resp.status_code >= 400:
|
||||
logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
|
||||
return None
|
||||
errcode = result.get("errcode", 0)
|
||||
if errcode != 0:
|
||||
logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
|
||||
return None
|
||||
sub = result.get("result") or {}
|
||||
media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
|
||||
if not media_id:
|
||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||
return None
|
||||
return str(media_id)
|
||||
except Exception as e:
|
||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||
return None
|
||||
|
||||
async def _send_batch_message(
|
||||
self,
|
||||
token: str,
|
||||
chat_id: str,
|
||||
msg_key: str,
|
||||
msg_param: dict[str, Any],
|
||||
) -> bool:
|
||||
if not self._http:
|
||||
logger.warning("DingTalk HTTP client not initialized, cannot send")
|
||||
return False
|
||||
|
||||
headers = {"x-acs-dingtalk-access-token": token}
|
||||
if chat_id.startswith("group:"):
|
||||
# Group chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/groupMessages/send"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"openConversationId": chat_id[6:], # Remove "group:" prefix,
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
else:
|
||||
# Private chat
|
||||
url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
payload = {
|
||||
"robotCode": self.config.client_id,
|
||||
"userIds": [chat_id],
|
||||
"msgKey": msg_key,
|
||||
"msgParam": json.dumps(msg_param, ensure_ascii=False),
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, json=payload, headers=headers)
|
||||
body = resp.text
|
||||
if resp.status_code != 200:
|
||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||
return False
|
||||
try: result = resp.json()
|
||||
except Exception: result = {}
|
||||
errcode = result.get("errcode")
|
||||
if errcode not in (None, 0):
|
||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||
return False
|
||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||
return False
|
||||
|
||||
async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleMarkdown",
|
||||
{"text": content, "title": "Nanobot Reply"},
|
||||
)
|
||||
|
||||
async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
|
||||
media_ref = (media_ref or "").strip()
|
||||
if not media_ref:
|
||||
return True
|
||||
|
||||
upload_type = self._guess_upload_type(media_ref)
|
||||
if upload_type == "image" and self._is_http_url(media_ref):
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_ref},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
|
||||
|
||||
data, filename, content_type = await self._read_media_bytes(media_ref)
|
||||
if not data:
|
||||
logger.error("DingTalk media read failed: {}", media_ref)
|
||||
return False
|
||||
|
||||
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||
file_type = Path(filename).suffix.lower().lstrip(".")
|
||||
if not file_type:
|
||||
guessed = mimetypes.guess_extension(content_type or "")
|
||||
file_type = (guessed or ".bin").lstrip(".")
|
||||
if file_type == "jpeg":
|
||||
file_type = "jpg"
|
||||
|
||||
media_id = await self._upload_media(
|
||||
token=token,
|
||||
data=data,
|
||||
media_type=upload_type,
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
)
|
||||
if not media_id:
|
||||
return False
|
||||
|
||||
if upload_type == "image":
|
||||
# Verified in production: sampleImageMsg accepts media_id in photoURL.
|
||||
ok = await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleImageMsg",
|
||||
{"photoURL": media_id},
|
||||
)
|
||||
if ok:
|
||||
return True
|
||||
logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
|
||||
|
||||
return await self._send_batch_message(
|
||||
token,
|
||||
chat_id,
|
||||
"sampleFile",
|
||||
{"mediaId": media_id, "fileName": filename, "fileType": file_type},
|
||||
)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through DingTalk."""
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
|
||||
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media_ref(token, msg.chat_id, media_ref)
|
||||
if ok:
|
||||
continue
|
||||
logger.error("DingTalk media send failed for {}", media_ref)
|
||||
# Send visible fallback so failures are observable by the user.
|
||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||
await self._send_markdown_text(
|
||||
token,
|
||||
msg.chat_id,
|
||||
f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
async def _on_message(
|
||||
self,
|
||||
content: str,
|
||||
sender_id: str,
|
||||
sender_name: str,
|
||||
conversation_type: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Handle incoming message (called by NanobotDingTalkHandler).
|
||||
|
||||
Delegates to BaseChannel._handle_message() which enforces allow_from
|
||||
permission checks before publishing to the bus.
|
||||
"""
|
||||
try:
|
||||
logger.info("DingTalk inbound: {} from {}", content, sender_name)
|
||||
is_group = conversation_type == "2" and conversation_id
|
||||
chat_id = f"group:{conversation_id}" if is_group else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=str(content),
|
||||
metadata={
|
||||
"sender_name": sender_name,
|
||||
"platform": "dingtalk",
|
||||
"conversation_type": conversation_type,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error publishing DingTalk message: {}", e)
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import DiscordConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
|
||||
name = "discord"
|
||||
|
||||
def __init__(self, config: DiscordConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
|
||||
try:
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
media_paths: list[str] = []
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
@@ -0,0 +1,408 @@
|
||||
"""Email channel implementation using IMAP polling + SMTP replies."""
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import imaplib
|
||||
import re
|
||||
import smtplib
|
||||
import ssl
|
||||
from datetime import date
|
||||
from email import policy
|
||||
from email.header import decode_header, make_header
|
||||
from email.message import EmailMessage
|
||||
from email.parser import BytesParser
|
||||
from email.utils import parseaddr
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import EmailConfig
|
||||
|
||||
|
||||
class EmailChannel(BaseChannel):
|
||||
"""
|
||||
Email channel.
|
||||
|
||||
Inbound:
|
||||
- Poll IMAP mailbox for unread messages.
|
||||
- Convert each message into an inbound event.
|
||||
|
||||
Outbound:
|
||||
- Send responses via SMTP back to the sender address.
|
||||
"""
|
||||
|
||||
name = "email"
|
||||
_IMAP_MONTHS = (
|
||||
"Jan",
|
||||
"Feb",
|
||||
"Mar",
|
||||
"Apr",
|
||||
"May",
|
||||
"Jun",
|
||||
"Jul",
|
||||
"Aug",
|
||||
"Sep",
|
||||
"Oct",
|
||||
"Nov",
|
||||
"Dec",
|
||||
)
|
||||
|
||||
def __init__(self, config: EmailConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: EmailConfig = config
|
||||
self._last_subject_by_chat: dict[str, str] = {}
|
||||
self._last_message_id_by_chat: dict[str, str] = {}
|
||||
self._processed_uids: set[str] = set() # Capped to prevent unbounded growth
|
||||
self._MAX_PROCESSED_UIDS = 100000
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start polling IMAP for inbound emails."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning(
|
||||
"Email channel disabled: consent_granted is false. "
|
||||
"Set channels.email.consentGranted=true after explicit user permission."
|
||||
)
|
||||
return
|
||||
|
||||
if not self._validate_config():
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("Starting Email channel (IMAP polling mode)...")
|
||||
|
||||
poll_seconds = max(5, int(self.config.poll_interval_seconds))
|
||||
while self._running:
|
||||
try:
|
||||
inbound_items = await asyncio.to_thread(self._fetch_new_messages)
|
||||
for item in inbound_items:
|
||||
sender = item["sender"]
|
||||
subject = item.get("subject", "")
|
||||
message_id = item.get("message_id", "")
|
||||
|
||||
if subject:
|
||||
self._last_subject_by_chat[sender] = subject
|
||||
if message_id:
|
||||
self._last_message_id_by_chat[sender] = message_id
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender,
|
||||
chat_id=sender,
|
||||
content=item["content"],
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Email polling error: {}", e)
|
||||
|
||||
await asyncio.sleep(poll_seconds)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop polling loop."""
|
||||
self._running = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send email via SMTP."""
|
||||
if not self.config.consent_granted:
|
||||
logger.warning("Skip email send: consent_granted is false")
|
||||
return
|
||||
|
||||
if not self.config.smtp_host:
|
||||
logger.warning("Email channel SMTP host not configured")
|
||||
return
|
||||
|
||||
to_addr = msg.chat_id.strip()
|
||||
if not to_addr:
|
||||
logger.warning("Email channel missing recipient address")
|
||||
return
|
||||
|
||||
# Determine if this is a reply (recipient has sent us an email before)
|
||||
is_reply = to_addr in self._last_subject_by_chat
|
||||
force_send = bool((msg.metadata or {}).get("force_send"))
|
||||
|
||||
# autoReplyEnabled only controls automatic replies, not proactive sends
|
||||
if is_reply and not self.config.auto_reply_enabled and not force_send:
|
||||
logger.info("Skip automatic email reply to {}: auto_reply_enabled is false", to_addr)
|
||||
return
|
||||
|
||||
base_subject = self._last_subject_by_chat.get(to_addr, "nanobot reply")
|
||||
subject = self._reply_subject(base_subject)
|
||||
if msg.metadata and isinstance(msg.metadata.get("subject"), str):
|
||||
override = msg.metadata["subject"].strip()
|
||||
if override:
|
||||
subject = override
|
||||
|
||||
email_msg = EmailMessage()
|
||||
email_msg["From"] = self.config.from_address or self.config.smtp_username or self.config.imap_username
|
||||
email_msg["To"] = to_addr
|
||||
email_msg["Subject"] = subject
|
||||
email_msg.set_content(msg.content or "")
|
||||
|
||||
in_reply_to = self._last_message_id_by_chat.get(to_addr)
|
||||
if in_reply_to:
|
||||
email_msg["In-Reply-To"] = in_reply_to
|
||||
email_msg["References"] = in_reply_to
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(self._smtp_send, email_msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending email to {}: {}", to_addr, e)
|
||||
raise
|
||||
|
||||
def _validate_config(self) -> bool:
|
||||
missing = []
|
||||
if not self.config.imap_host:
|
||||
missing.append("imap_host")
|
||||
if not self.config.imap_username:
|
||||
missing.append("imap_username")
|
||||
if not self.config.imap_password:
|
||||
missing.append("imap_password")
|
||||
if not self.config.smtp_host:
|
||||
missing.append("smtp_host")
|
||||
if not self.config.smtp_username:
|
||||
missing.append("smtp_username")
|
||||
if not self.config.smtp_password:
|
||||
missing.append("smtp_password")
|
||||
|
||||
if missing:
|
||||
logger.error("Email channel not configured, missing: {}", ', '.join(missing))
|
||||
return False
|
||||
return True
|
||||
|
||||
def _smtp_send(self, msg: EmailMessage) -> None:
|
||||
timeout = 30
|
||||
if self.config.smtp_use_ssl:
|
||||
with smtplib.SMTP_SSL(
|
||||
self.config.smtp_host,
|
||||
self.config.smtp_port,
|
||||
timeout=timeout,
|
||||
) as smtp:
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
return
|
||||
|
||||
with smtplib.SMTP(self.config.smtp_host, self.config.smtp_port, timeout=timeout) as smtp:
|
||||
if self.config.smtp_use_tls:
|
||||
smtp.starttls(context=ssl.create_default_context())
|
||||
smtp.login(self.config.smtp_username, self.config.smtp_password)
|
||||
smtp.send_message(msg)
|
||||
|
||||
def _fetch_new_messages(self) -> list[dict[str, Any]]:
|
||||
"""Poll IMAP and return parsed unread messages."""
|
||||
return self._fetch_messages(
|
||||
search_criteria=("UNSEEN",),
|
||||
mark_seen=self.config.mark_seen,
|
||||
dedupe=True,
|
||||
limit=0,
|
||||
)
|
||||
|
||||
def fetch_messages_between_dates(
|
||||
self,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Fetch messages in [start_date, end_date) by IMAP date search.
|
||||
|
||||
This is used for historical summarization tasks (e.g. "yesterday").
|
||||
"""
|
||||
if end_date <= start_date:
|
||||
return []
|
||||
|
||||
return self._fetch_messages(
|
||||
search_criteria=(
|
||||
"SINCE",
|
||||
self._format_imap_date(start_date),
|
||||
"BEFORE",
|
||||
self._format_imap_date(end_date),
|
||||
),
|
||||
mark_seen=False,
|
||||
dedupe=False,
|
||||
limit=max(1, int(limit)),
|
||||
)
|
||||
|
||||
def _fetch_messages(
|
||||
self,
|
||||
search_criteria: tuple[str, ...],
|
||||
mark_seen: bool,
|
||||
dedupe: bool,
|
||||
limit: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch messages by arbitrary IMAP search criteria."""
|
||||
messages: list[dict[str, Any]] = []
|
||||
mailbox = self.config.imap_mailbox or "INBOX"
|
||||
|
||||
if self.config.imap_use_ssl:
|
||||
client = imaplib.IMAP4_SSL(self.config.imap_host, self.config.imap_port)
|
||||
else:
|
||||
client = imaplib.IMAP4(self.config.imap_host, self.config.imap_port)
|
||||
|
||||
try:
|
||||
client.login(self.config.imap_username, self.config.imap_password)
|
||||
status, _ = client.select(mailbox)
|
||||
if status != "OK":
|
||||
return messages
|
||||
|
||||
status, data = client.search(None, *search_criteria)
|
||||
if status != "OK" or not data:
|
||||
return messages
|
||||
|
||||
ids = data[0].split()
|
||||
if limit > 0 and len(ids) > limit:
|
||||
ids = ids[-limit:]
|
||||
for imap_id in ids:
|
||||
status, fetched = client.fetch(imap_id, "(BODY.PEEK[] UID)")
|
||||
if status != "OK" or not fetched:
|
||||
continue
|
||||
|
||||
raw_bytes = self._extract_message_bytes(fetched)
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
uid = self._extract_uid(fetched)
|
||||
if dedupe and uid and uid in self._processed_uids:
|
||||
continue
|
||||
|
||||
parsed = BytesParser(policy=policy.default).parsebytes(raw_bytes)
|
||||
sender = parseaddr(parsed.get("From", ""))[1].strip().lower()
|
||||
if not sender:
|
||||
continue
|
||||
|
||||
subject = self._decode_header_value(parsed.get("Subject", ""))
|
||||
date_value = parsed.get("Date", "")
|
||||
message_id = parsed.get("Message-ID", "").strip()
|
||||
body = self._extract_text_body(parsed)
|
||||
|
||||
if not body:
|
||||
body = "(empty email body)"
|
||||
|
||||
body = body[: self.config.max_body_chars]
|
||||
content = (
|
||||
f"Email received.\n"
|
||||
f"From: {sender}\n"
|
||||
f"Subject: {subject}\n"
|
||||
f"Date: {date_value}\n\n"
|
||||
f"{body}"
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"subject": subject,
|
||||
"date": date_value,
|
||||
"sender_email": sender,
|
||||
"uid": uid,
|
||||
}
|
||||
messages.append(
|
||||
{
|
||||
"sender": sender,
|
||||
"subject": subject,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
|
||||
if dedupe and uid:
|
||||
self._processed_uids.add(uid)
|
||||
# mark_seen is the primary dedup; this set is a safety net
|
||||
if len(self._processed_uids) > self._MAX_PROCESSED_UIDS:
|
||||
# Evict a random half to cap memory; mark_seen is the primary dedup
|
||||
self._processed_uids = set(list(self._processed_uids)[len(self._processed_uids) // 2:])
|
||||
|
||||
if mark_seen:
|
||||
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||
finally:
|
||||
try:
|
||||
client.logout()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return messages
|
||||
|
||||
@classmethod
|
||||
def _format_imap_date(cls, value: date) -> str:
|
||||
"""Format date for IMAP search (always English month abbreviations)."""
|
||||
month = cls._IMAP_MONTHS[value.month - 1]
|
||||
return f"{value.day:02d}-{month}-{value.year}"
|
||||
|
||||
@staticmethod
|
||||
def _extract_message_bytes(fetched: list[Any]) -> bytes | None:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and len(item) >= 2 and isinstance(item[1], (bytes, bytearray)):
|
||||
return bytes(item[1])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_uid(fetched: list[Any]) -> str:
|
||||
for item in fetched:
|
||||
if isinstance(item, tuple) and item and isinstance(item[0], (bytes, bytearray)):
|
||||
head = bytes(item[0]).decode("utf-8", errors="ignore")
|
||||
m = re.search(r"UID\s+(\d+)", head)
|
||||
if m:
|
||||
return m.group(1)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _decode_header_value(value: str) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
try:
|
||||
return str(make_header(decode_header(value)))
|
||||
except Exception:
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _extract_text_body(cls, msg: Any) -> str:
|
||||
"""Best-effort extraction of readable body text."""
|
||||
if msg.is_multipart():
|
||||
plain_parts: list[str] = []
|
||||
html_parts: list[str] = []
|
||||
for part in msg.walk():
|
||||
if part.get_content_disposition() == "attachment":
|
||||
continue
|
||||
content_type = part.get_content_type()
|
||||
try:
|
||||
payload = part.get_content()
|
||||
except Exception:
|
||||
payload_bytes = part.get_payload(decode=True) or b""
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
continue
|
||||
if content_type == "text/plain":
|
||||
plain_parts.append(payload)
|
||||
elif content_type == "text/html":
|
||||
html_parts.append(payload)
|
||||
if plain_parts:
|
||||
return "\n\n".join(plain_parts).strip()
|
||||
if html_parts:
|
||||
return cls._html_to_text("\n\n".join(html_parts)).strip()
|
||||
return ""
|
||||
|
||||
try:
|
||||
payload = msg.get_content()
|
||||
except Exception:
|
||||
payload_bytes = msg.get_payload(decode=True) or b""
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
payload = payload_bytes.decode(charset, errors="replace")
|
||||
if not isinstance(payload, str):
|
||||
return ""
|
||||
if msg.get_content_type() == "text/html":
|
||||
return cls._html_to_text(payload).strip()
|
||||
return payload.strip()
|
||||
|
||||
@staticmethod
|
||||
def _html_to_text(raw_html: str) -> str:
|
||||
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<\s*/\s*p\s*>", "\n", text, flags=re.IGNORECASE)
|
||||
text = re.sub(r"<[^>]+>", "", text)
|
||||
return html.unescape(text)
|
||||
|
||||
def _reply_subject(self, base_subject: str) -> str:
|
||||
subject = (base_subject or "").strip() or "nanobot reply"
|
||||
prefix = self.config.subject_prefix or "Re: "
|
||||
if subject.lower().startswith("re:"):
|
||||
return subject
|
||||
return f"{prefix}{subject}"
|
||||
@@ -0,0 +1,985 @@
|
||||
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import FeishuConfig
|
||||
|
||||
import importlib.util
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
MSG_TYPE_MAP = {
|
||||
"image": "[image]",
|
||||
"audio": "[audio]",
|
||||
"file": "[file]",
|
||||
"sticker": "[sticker]",
|
||||
}
|
||||
|
||||
|
||||
def _extract_share_card_content(content_json: dict, msg_type: str) -> str:
|
||||
"""Extract text representation from share cards and interactive messages."""
|
||||
parts = []
|
||||
|
||||
if msg_type == "share_chat":
|
||||
parts.append(f"[shared chat: {content_json.get('chat_id', '')}]")
|
||||
elif msg_type == "share_user":
|
||||
parts.append(f"[shared user: {content_json.get('user_id', '')}]")
|
||||
elif msg_type == "interactive":
|
||||
parts.extend(_extract_interactive_content(content_json))
|
||||
elif msg_type == "share_calendar_event":
|
||||
parts.append(f"[shared calendar event: {content_json.get('event_key', '')}]")
|
||||
elif msg_type == "system":
|
||||
parts.append("[system message]")
|
||||
elif msg_type == "merge_forward":
|
||||
parts.append("[merged forward messages]")
|
||||
|
||||
return "\n".join(parts) if parts else f"[{msg_type}]"
|
||||
|
||||
|
||||
def _extract_interactive_content(content: dict) -> list[str]:
|
||||
"""Recursively extract text and links from interactive card content."""
|
||||
parts = []
|
||||
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content = json.loads(content)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return [content] if content.strip() else []
|
||||
|
||||
if not isinstance(content, dict):
|
||||
return parts
|
||||
|
||||
if "title" in content:
|
||||
title = content["title"]
|
||||
if isinstance(title, dict):
|
||||
title_content = title.get("content", "") or title.get("text", "")
|
||||
if title_content:
|
||||
parts.append(f"title: {title_content}")
|
||||
elif isinstance(title, str):
|
||||
parts.append(f"title: {title}")
|
||||
|
||||
for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
|
||||
for element in elements:
|
||||
parts.extend(_extract_element_content(element))
|
||||
|
||||
card = content.get("card", {})
|
||||
if card:
|
||||
parts.extend(_extract_interactive_content(card))
|
||||
|
||||
header = content.get("header", {})
|
||||
if header:
|
||||
header_title = header.get("title", {})
|
||||
if isinstance(header_title, dict):
|
||||
header_text = header_title.get("content", "") or header_title.get("text", "")
|
||||
if header_text:
|
||||
parts.append(f"title: {header_text}")
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_element_content(element: dict) -> list[str]:
|
||||
"""Extract content from a single card element."""
|
||||
parts = []
|
||||
|
||||
if not isinstance(element, dict):
|
||||
return parts
|
||||
|
||||
tag = element.get("tag", "")
|
||||
|
||||
if tag in ("markdown", "lark_md"):
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
elif tag == "div":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
text_content = text.get("content", "") or text.get("text", "")
|
||||
if text_content:
|
||||
parts.append(text_content)
|
||||
elif isinstance(text, str):
|
||||
parts.append(text)
|
||||
for field in element.get("fields", []):
|
||||
if isinstance(field, dict):
|
||||
field_text = field.get("text", {})
|
||||
if isinstance(field_text, dict):
|
||||
c = field_text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
|
||||
elif tag == "a":
|
||||
href = element.get("href", "")
|
||||
text = element.get("text", "")
|
||||
if href:
|
||||
parts.append(f"link: {href}")
|
||||
if text:
|
||||
parts.append(text)
|
||||
|
||||
elif tag == "button":
|
||||
text = element.get("text", {})
|
||||
if isinstance(text, dict):
|
||||
c = text.get("content", "")
|
||||
if c:
|
||||
parts.append(c)
|
||||
url = element.get("url", "") or element.get("multi_url", {}).get("url", "")
|
||||
if url:
|
||||
parts.append(f"link: {url}")
|
||||
|
||||
elif tag == "img":
|
||||
alt = element.get("alt", {})
|
||||
parts.append(alt.get("content", "[image]") if isinstance(alt, dict) else "[image]")
|
||||
|
||||
elif tag == "note":
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
elif tag == "column_set":
|
||||
for col in element.get("columns", []):
|
||||
for ce in col.get("elements", []):
|
||||
parts.extend(_extract_element_content(ce))
|
||||
|
||||
elif tag == "plain_text":
|
||||
content = element.get("content", "")
|
||||
if content:
|
||||
parts.append(content)
|
||||
|
||||
else:
|
||||
for ne in element.get("elements", []):
|
||||
parts.extend(_extract_element_content(ne))
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
|
||||
"""Extract text and image keys from Feishu post (rich text) message.
|
||||
|
||||
Handles three payload shapes:
|
||||
- Direct: {"title": "...", "content": [[...]]}
|
||||
- Localized: {"zh_cn": {"title": "...", "content": [...]}}
|
||||
- Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
|
||||
"""
|
||||
|
||||
def _parse_block(block: dict) -> tuple[str | None, list[str]]:
|
||||
if not isinstance(block, dict) or not isinstance(block.get("content"), list):
|
||||
return None, []
|
||||
texts, images = [], []
|
||||
if title := block.get("title"):
|
||||
texts.append(title)
|
||||
for row in block["content"]:
|
||||
if not isinstance(row, list):
|
||||
continue
|
||||
for el in row:
|
||||
if not isinstance(el, dict):
|
||||
continue
|
||||
tag = el.get("tag")
|
||||
if tag in ("text", "a"):
|
||||
texts.append(el.get("text", ""))
|
||||
elif tag == "at":
|
||||
texts.append(f"@{el.get('user_name', 'user')}")
|
||||
elif tag == "img" and (key := el.get("image_key")):
|
||||
images.append(key)
|
||||
return (" ".join(texts).strip() or None), images
|
||||
|
||||
# Unwrap optional {"post": ...} envelope
|
||||
root = content_json
|
||||
if isinstance(root, dict) and isinstance(root.get("post"), dict):
|
||||
root = root["post"]
|
||||
if not isinstance(root, dict):
|
||||
return "", []
|
||||
|
||||
# Direct format
|
||||
if "content" in root:
|
||||
text, imgs = _parse_block(root)
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
|
||||
# Localized: prefer known locales, then fall back to any dict child
|
||||
for key in ("zh_cn", "en_us", "ja_jp"):
|
||||
if key in root:
|
||||
text, imgs = _parse_block(root[key])
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
for val in root.values():
|
||||
if isinstance(val, dict):
|
||||
text, imgs = _parse_block(val)
|
||||
if text or imgs:
|
||||
return text or "", imgs
|
||||
|
||||
return "", []
|
||||
|
||||
|
||||
def _extract_post_text(content_json: dict) -> str:
|
||||
"""Extract plain text from Feishu post (rich text) message content.
|
||||
|
||||
Legacy wrapper for _extract_post_content, returns only text.
|
||||
"""
|
||||
text, _ = _extract_post_content(content_json)
|
||||
return text
|
||||
|
||||
|
||||
class FeishuChannel(BaseChannel):
|
||||
"""
|
||||
Feishu/Lark channel using WebSocket long connection.
|
||||
|
||||
Uses WebSocket to receive events - no public IP or webhook required.
|
||||
|
||||
Requires:
|
||||
- App ID and App Secret from Feishu Open Platform
|
||||
- Bot capability enabled
|
||||
- Event subscription enabled (im.message.receive_v1)
|
||||
"""
|
||||
|
||||
name = "feishu"
|
||||
|
||||
def __init__(self, config: FeishuConfig, bus: MessageBus, groq_api_key: str = ""):
|
||||
super().__init__(config, bus)
|
||||
self.config: FeishuConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self._client: Any = None
|
||||
self._ws_client: Any = None
|
||||
self._ws_thread: threading.Thread | None = None
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
@staticmethod
|
||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||
"""Register an event handler only when the SDK supports it."""
|
||||
method = getattr(builder, method_name, None)
|
||||
return method(handler) if callable(method) else builder
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Feishu bot with WebSocket long connection."""
|
||||
if not FEISHU_AVAILABLE:
|
||||
logger.error("Feishu SDK not installed. Run: pip install lark-oapi")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.app_secret:
|
||||
logger.error("Feishu app_id and app_secret not configured")
|
||||
return
|
||||
|
||||
import lark_oapi as lark
|
||||
self._running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Create Lark client for sending messages
|
||||
self._client = lark.Client.builder() \
|
||||
.app_id(self.config.app_id) \
|
||||
.app_secret(self.config.app_secret) \
|
||||
.log_level(lark.LogLevel.INFO) \
|
||||
.build()
|
||||
builder = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_message_read_v1", self._on_message_read
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder,
|
||||
"register_p2_im_chat_access_event_bot_p2p_chat_entered_v1",
|
||||
self._on_bot_p2p_chat_entered,
|
||||
)
|
||||
event_handler = builder.build()
|
||||
|
||||
# Create WebSocket client for long connection
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop.
|
||||
# A dedicated event loop is created for this thread so that lark_oapi's
|
||||
# module-level `loop = asyncio.get_event_loop()` picks up an idle loop
|
||||
# instead of the already-running main asyncio loop, which would cause
|
||||
# "This event loop is already running" errors.
|
||||
def run_ws():
|
||||
import time
|
||||
import lark_oapi.ws.client as _lark_ws_client
|
||||
ws_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(ws_loop)
|
||||
# Patch the module-level loop used by lark's ws Client.start()
|
||||
_lark_ws_client.loop = ws_loop
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
self._ws_client.start()
|
||||
except Exception as e:
|
||||
logger.warning("Feishu WebSocket error: {}", e)
|
||||
if self._running:
|
||||
time.sleep(5)
|
||||
finally:
|
||||
ws_loop.close()
|
||||
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the Feishu bot.
|
||||
|
||||
Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client.
|
||||
|
||||
Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
|
||||
"""
|
||||
self._running = False
|
||||
logger.info("Feishu bot stopped")
|
||||
|
||||
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
|
||||
"""Sync helper for adding reaction (runs in thread pool)."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
|
||||
try:
|
||||
request = CreateMessageReactionRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.request_body(
|
||||
CreateMessageReactionRequestBody.builder()
|
||||
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
|
||||
.build()
|
||||
).build()
|
||||
|
||||
response = self._client.im.v1.message_reaction.create(request)
|
||||
|
||||
if not response.success():
|
||||
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
|
||||
else:
|
||||
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
|
||||
except Exception as e:
|
||||
logger.warning("Error adding reaction: {}", e)
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._add_reaction_sync, message_id, emoji_type)
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
_HEADING_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
||||
|
||||
_CODE_BLOCK_RE = re.compile(r"(```[\s\S]*?```)", re.MULTILINE)
|
||||
|
||||
@staticmethod
|
||||
def _parse_md_table(table_text: str) -> dict | None:
|
||||
"""Parse a markdown table into a Feishu table element."""
|
||||
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
|
||||
if len(lines) < 3:
|
||||
return None
|
||||
def split(_line: str) -> list[str]:
|
||||
return [c.strip() for c in _line.strip("|").split("|")]
|
||||
headers = split(lines[0])
|
||||
rows = [split(_line) for _line in lines[2:]]
|
||||
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
|
||||
for i, h in enumerate(headers)]
|
||||
return {
|
||||
"tag": "table",
|
||||
"page_size": len(rows) + 1,
|
||||
"columns": columns,
|
||||
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
|
||||
}
|
||||
|
||||
def _build_card_elements(self, content: str) -> list[dict]:
|
||||
"""Split content into div/markdown + table elements for Feishu card."""
|
||||
elements, last_end = [], 0
|
||||
for m in self._TABLE_RE.finditer(content):
|
||||
before = content[last_end:m.start()]
|
||||
if before.strip():
|
||||
elements.extend(self._split_headings(before))
|
||||
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
|
||||
last_end = m.end()
|
||||
remaining = content[last_end:]
|
||||
if remaining.strip():
|
||||
elements.extend(self._split_headings(remaining))
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
|
||||
"""Split card elements into groups with at most *max_tables* table elements each.
|
||||
|
||||
Feishu cards have a hard limit of one table per card (API error 11310).
|
||||
When the rendered content contains multiple markdown tables each table is
|
||||
placed in a separate card message so every table reaches the user.
|
||||
"""
|
||||
if not elements:
|
||||
return [[]]
|
||||
groups: list[list[dict]] = []
|
||||
current: list[dict] = []
|
||||
table_count = 0
|
||||
for el in elements:
|
||||
if el.get("tag") == "table":
|
||||
if table_count >= max_tables:
|
||||
if current:
|
||||
groups.append(current)
|
||||
current = []
|
||||
table_count = 0
|
||||
current.append(el)
|
||||
table_count += 1
|
||||
else:
|
||||
current.append(el)
|
||||
if current:
|
||||
groups.append(current)
|
||||
return groups or [[]]
|
||||
|
||||
def _split_headings(self, content: str) -> list[dict]:
|
||||
"""Split content by headings, converting headings to div elements."""
|
||||
protected = content
|
||||
code_blocks = []
|
||||
for m in self._CODE_BLOCK_RE.finditer(content):
|
||||
code_blocks.append(m.group(1))
|
||||
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
|
||||
|
||||
elements = []
|
||||
last_end = 0
|
||||
for m in self._HEADING_RE.finditer(protected):
|
||||
before = protected[last_end:m.start()].strip()
|
||||
if before:
|
||||
elements.append({"tag": "markdown", "content": before})
|
||||
text = m.group(2).strip()
|
||||
elements.append({
|
||||
"tag": "div",
|
||||
"text": {
|
||||
"tag": "lark_md",
|
||||
"content": f"**{text}**",
|
||||
},
|
||||
})
|
||||
last_end = m.end()
|
||||
remaining = protected[last_end:].strip()
|
||||
if remaining:
|
||||
elements.append({"tag": "markdown", "content": remaining})
|
||||
|
||||
for i, cb in enumerate(code_blocks):
|
||||
for el in elements:
|
||||
if el.get("tag") == "markdown":
|
||||
el["content"] = el["content"].replace(f"\x00CODE{i}\x00", cb)
|
||||
|
||||
return elements or [{"tag": "markdown", "content": content}]
|
||||
|
||||
# ── Smart format detection ──────────────────────────────────────────
|
||||
# Patterns that indicate "complex" markdown needing card rendering
|
||||
_COMPLEX_MD_RE = re.compile(
|
||||
r"```" # fenced code block
|
||||
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
|
||||
r"|^#{1,6}\s+" # headings
|
||||
, re.MULTILINE,
|
||||
)
|
||||
|
||||
# Simple markdown patterns (bold, italic, strikethrough)
|
||||
_SIMPLE_MD_RE = re.compile(
|
||||
r"\*\*.+?\*\*" # **bold**
|
||||
r"|__.+?__" # __bold__
|
||||
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
|
||||
r"|~~.+?~~" # ~~strikethrough~~
|
||||
, re.DOTALL,
|
||||
)
|
||||
|
||||
# Markdown link: [text](url)
|
||||
_MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
|
||||
|
||||
# Unordered list items
|
||||
_LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
|
||||
|
||||
# Ordered list items
|
||||
_OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
|
||||
|
||||
# Max length for plain text format
|
||||
_TEXT_MAX_LEN = 200
|
||||
|
||||
# Max length for post (rich text) format; beyond this, use card
|
||||
_POST_MAX_LEN = 2000
|
||||
|
||||
@classmethod
|
||||
def _detect_msg_format(cls, content: str) -> str:
|
||||
"""Determine the optimal Feishu message format for *content*.
|
||||
|
||||
Returns one of:
|
||||
- ``"text"`` – plain text, short and no markdown
|
||||
- ``"post"`` – rich text (links only, moderate length)
|
||||
- ``"interactive"`` – card with full markdown rendering
|
||||
"""
|
||||
stripped = content.strip()
|
||||
|
||||
# Complex markdown (code blocks, tables, headings) → always card
|
||||
if cls._COMPLEX_MD_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Long content → card (better readability with card layout)
|
||||
if len(stripped) > cls._POST_MAX_LEN:
|
||||
return "interactive"
|
||||
|
||||
# Has bold/italic/strikethrough → card (post format can't render these)
|
||||
if cls._SIMPLE_MD_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Has list items → card (post format can't render list bullets well)
|
||||
if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
|
||||
return "interactive"
|
||||
|
||||
# Has links → post format (supports <a> tags)
|
||||
if cls._MD_LINK_RE.search(stripped):
|
||||
return "post"
|
||||
|
||||
# Short plain text → text format
|
||||
if len(stripped) <= cls._TEXT_MAX_LEN:
|
||||
return "text"
|
||||
|
||||
# Medium plain text without any formatting → post format
|
||||
return "post"
|
||||
|
||||
@classmethod
|
||||
def _markdown_to_post(cls, content: str) -> str:
|
||||
"""Convert markdown content to Feishu post message JSON.
|
||||
|
||||
Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
|
||||
Each line becomes a paragraph (row) in the post body.
|
||||
"""
|
||||
lines = content.strip().split("\n")
|
||||
paragraphs: list[list[dict]] = []
|
||||
|
||||
for line in lines:
|
||||
elements: list[dict] = []
|
||||
last_end = 0
|
||||
|
||||
for m in cls._MD_LINK_RE.finditer(line):
|
||||
# Text before this link
|
||||
before = line[last_end:m.start()]
|
||||
if before:
|
||||
elements.append({"tag": "text", "text": before})
|
||||
elements.append({
|
||||
"tag": "a",
|
||||
"text": m.group(1),
|
||||
"href": m.group(2),
|
||||
})
|
||||
last_end = m.end()
|
||||
|
||||
# Remaining text after last link
|
||||
remaining = line[last_end:]
|
||||
if remaining:
|
||||
elements.append({"tag": "text", "text": remaining})
|
||||
|
||||
# Empty line → empty paragraph for spacing
|
||||
if not elements:
|
||||
elements.append({"tag": "text", "text": ""})
|
||||
|
||||
paragraphs.append(elements)
|
||||
|
||||
post_body = {
|
||||
"zh_cn": {
|
||||
"content": paragraphs,
|
||||
}
|
||||
}
|
||||
return json.dumps(post_body, ensure_ascii=False)
|
||||
|
||||
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
|
||||
_AUDIO_EXTS = {".opus"}
|
||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
|
||||
_FILE_TYPE_MAP = {
|
||||
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
|
||||
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
|
||||
}
|
||||
|
||||
def _upload_image_sync(self, file_path: str) -> str | None:
|
||||
"""Upload an image to Feishu and return the image_key."""
|
||||
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateImageRequest.builder() \
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.image.create(request)
|
||||
if response.success():
|
||||
image_key = response.data.image_key
|
||||
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
|
||||
return image_key
|
||||
else:
|
||||
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading image {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _upload_file_sync(self, file_path: str) -> str | None:
|
||||
"""Upload a file to Feishu and return the file_key."""
|
||||
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
|
||||
file_name = os.path.basename(file_path)
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
request = CreateFileRequest.builder() \
|
||||
.request_body(
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(file_name)
|
||||
.file(f)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.file.create(request)
|
||||
if response.success():
|
||||
file_key = response.data.file_key
|
||||
logger.debug("Uploaded file {}: {}", file_name, file_key)
|
||||
return file_key
|
||||
else:
|
||||
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error("Error uploading file {}: {}", file_path, e)
|
||||
return None
|
||||
|
||||
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
|
||||
"""Download an image from Feishu message by message_id and image_key."""
|
||||
from lark_oapi.api.im.v1 import GetMessageResourceRequest
|
||||
try:
|
||||
request = GetMessageResourceRequest.builder() \
|
||||
.message_id(message_id) \
|
||||
.file_key(image_key) \
|
||||
.type("image") \
|
||||
.build()
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
# GetMessageResourceRequest returns BytesIO, need to read bytes
|
||||
if hasattr(file_data, 'read'):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("Error downloading image {}: {}", image_key, e)
|
||||
return None, None
|
||||
|
||||
def _download_file_sync(
|
||||
self, message_id: str, file_key: str, resource_type: str = "file"
|
||||
) -> tuple[bytes | None, str | None]:
|
||||
"""Download a file/audio/media from a Feishu message by message_id and file_key."""
|
||||
from lark_oapi.api.im.v1 import GetMessageResourceRequest
|
||||
|
||||
# Feishu API only accepts 'image' or 'file' as type parameter
|
||||
# Convert 'audio' to 'file' for API compatibility
|
||||
if resource_type == "audio":
|
||||
resource_type = "file"
|
||||
|
||||
try:
|
||||
request = (
|
||||
GetMessageResourceRequest.builder()
|
||||
.message_id(message_id)
|
||||
.file_key(file_key)
|
||||
.type(resource_type)
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message_resource.get(request)
|
||||
if response.success():
|
||||
file_data = response.file
|
||||
if hasattr(file_data, "read"):
|
||||
file_data = file_data.read()
|
||||
return file_data, response.file_name
|
||||
else:
|
||||
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
|
||||
return None, None
|
||||
except Exception:
|
||||
logger.exception("Error downloading {} {}", resource_type, file_key)
|
||||
return None, None
|
||||
|
||||
async def _download_and_save_media(
|
||||
self,
|
||||
msg_type: str,
|
||||
content_json: dict,
|
||||
message_id: str | None = None
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Download media from Feishu and save to local disk.
|
||||
|
||||
Returns:
|
||||
(file_path, content_text) - file_path is None if download failed
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
media_dir = get_media_dir("feishu")
|
||||
|
||||
data, filename = None, None
|
||||
|
||||
if msg_type == "image":
|
||||
image_key = content_json.get("image_key")
|
||||
if image_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_image_sync, message_id, image_key
|
||||
)
|
||||
if not filename:
|
||||
filename = f"{image_key[:16]}.jpg"
|
||||
|
||||
elif msg_type in ("audio", "file", "media"):
|
||||
file_key = content_json.get("file_key")
|
||||
if file_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
ext = {"audio": ".opus", "media": ".mp4"}.get(msg_type, "")
|
||||
filename = f"{file_key[:16]}{ext}"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded {} to {}", msg_type, file_path)
|
||||
return str(file_path), f"[{msg_type}: {filename}]"
|
||||
|
||||
return None, f"[{msg_type}: download failed]"
|
||||
|
||||
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Send a single message (text/image/file/interactive) synchronously."""
|
||||
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
|
||||
try:
|
||||
request = CreateMessageRequest.builder() \
|
||||
.receive_id_type(receive_id_type) \
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.msg_type(msg_type)
|
||||
.content(content)
|
||||
.build()
|
||||
).build()
|
||||
response = self._client.im.v1.message.create(request)
|
||||
if not response.success():
|
||||
logger.error(
|
||||
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
|
||||
msg_type, response.code, response.msg, response.get_log_id()
|
||||
)
|
||||
return False
|
||||
logger.debug("Feishu {} message sent to {}", msg_type, receive_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Feishu, including media (images/files) if present."""
|
||||
if not self._client:
|
||||
logger.warning("Feishu client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for file_path in msg.media:
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("Media file not found: {}", file_path)
|
||||
continue
|
||||
ext = os.path.splitext(file_path)[1].lower()
|
||||
if ext in self._IMAGE_EXTS:
|
||||
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
|
||||
if key:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "image", json.dumps({"image_key": key}, ensure_ascii=False),
|
||||
)
|
||||
else:
|
||||
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
|
||||
if key:
|
||||
# Use msg_type "media" for audio/video so users can play inline;
|
||||
# "file" for everything else (documents, archives, etc.)
|
||||
if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
|
||||
media_type = "media"
|
||||
else:
|
||||
media_type = "file"
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
if msg.content and msg.content.strip():
|
||||
fmt = self._detect_msg_format(msg.content)
|
||||
|
||||
if fmt == "text":
|
||||
# Short plain text – send as simple text message
|
||||
text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "text", text_body,
|
||||
)
|
||||
|
||||
elif fmt == "post":
|
||||
# Medium content with links – send as rich-text post
|
||||
post_body = self._markdown_to_post(msg.content)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "post", post_body,
|
||||
)
|
||||
|
||||
else:
|
||||
# Complex / long content – send as interactive card
|
||||
elements = self._build_card_elements(msg.content)
|
||||
for chunk in self._split_elements_by_table_limit(elements):
|
||||
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync,
|
||||
receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Feishu message: {}", e)
|
||||
|
||||
def _on_message_sync(self, data: Any) -> None:
|
||||
"""
|
||||
Sync handler for incoming messages (called from WebSocket thread).
|
||||
Schedules async handling in the main event loop.
|
||||
"""
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
|
||||
|
||||
async def _on_message(self, data: Any) -> None:
|
||||
"""Handle incoming message from Feishu."""
|
||||
try:
|
||||
event = data.event
|
||||
message = event.message
|
||||
sender = event.sender
|
||||
|
||||
# Deduplication check
|
||||
message_id = message.message_id
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
|
||||
# Trim cache
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Skip bot messages
|
||||
if sender.sender_type == "bot":
|
||||
return
|
||||
|
||||
sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
|
||||
chat_id = message.chat_id
|
||||
chat_type = message.chat_type
|
||||
msg_type = message.message_type
|
||||
|
||||
# Add reaction
|
||||
await self._add_reaction(message_id, self.config.react_emoji)
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
try:
|
||||
content_json = json.loads(message.content) if message.content else {}
|
||||
except json.JSONDecodeError:
|
||||
content_json = {}
|
||||
|
||||
if msg_type == "text":
|
||||
text = content_json.get("text", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
elif msg_type == "post":
|
||||
text, image_keys = _extract_post_content(content_json)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
# Download images embedded in post
|
||||
for img_key in image_keys:
|
||||
file_path, content_text = await self._download_and_save_media(
|
||||
"image", {"image_key": img_key}, message_id
|
||||
)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("image", "audio", "file", "media"):
|
||||
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
|
||||
# Transcribe audio using Groq Whisper
|
||||
if msg_type == "audio" and file_path and self.groq_api_key:
|
||||
try:
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
content_text = f"[transcription: {transcription}]"
|
||||
except Exception as e:
|
||||
logger.warning("Failed to transcribe audio: {}", e)
|
||||
|
||||
content_parts.append(content_text)
|
||||
|
||||
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
|
||||
# Handle share cards and interactive messages
|
||||
text = _extract_share_card_content(content_json, msg_type)
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=reply_to,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing Feishu message: {}", e)
|
||||
|
||||
def _on_reaction_created(self, data: Any) -> None:
|
||||
"""Ignore reaction events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_message_read(self, data: Any) -> None:
|
||||
"""Ignore read events so they do not generate SDK noise."""
|
||||
pass
|
||||
|
||||
def _on_bot_p2p_chat_entered(self, data: Any) -> None:
|
||||
"""Ignore p2p-enter events when a user opens a bot chat."""
|
||||
logger.debug("Bot entered p2p chat (user opened chat window)")
|
||||
pass
|
||||
@@ -0,0 +1,256 @@
|
||||
"""Channel manager for coordinating chat channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
class ChannelManager:
|
||||
"""
|
||||
Manages chat channels and coordinates message routing.
|
||||
|
||||
Responsibilities:
|
||||
- Initialize enabled channels (Telegram, WhatsApp, etc.)
|
||||
- Start/stop channels
|
||||
- Route outbound messages
|
||||
"""
|
||||
|
||||
def __init__(self, config: Config, bus: MessageBus):
|
||||
self.config = config
|
||||
self.bus = bus
|
||||
self.channels: dict[str, BaseChannel] = {}
|
||||
self._dispatch_task: asyncio.Task | None = None
|
||||
|
||||
self._init_channels()
|
||||
|
||||
def _init_channels(self) -> None:
|
||||
"""Initialize channels based on config."""
|
||||
|
||||
# Telegram channel
|
||||
if self.config.channels.telegram.enabled:
|
||||
try:
|
||||
from nanobot.channels.telegram import TelegramChannel
|
||||
self.channels["telegram"] = TelegramChannel(
|
||||
self.config.channels.telegram,
|
||||
self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Telegram channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Telegram channel not available: {}", e)
|
||||
|
||||
# WhatsApp channel
|
||||
if self.config.channels.whatsapp.enabled:
|
||||
try:
|
||||
from nanobot.channels.whatsapp import WhatsAppChannel
|
||||
self.channels["whatsapp"] = WhatsAppChannel(
|
||||
self.config.channels.whatsapp, self.bus
|
||||
)
|
||||
logger.info("WhatsApp channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("WhatsApp channel not available: {}", e)
|
||||
|
||||
# Discord channel
|
||||
if self.config.channels.discord.enabled:
|
||||
try:
|
||||
from nanobot.channels.discord import DiscordChannel
|
||||
self.channels["discord"] = DiscordChannel(
|
||||
self.config.channels.discord, self.bus
|
||||
)
|
||||
logger.info("Discord channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Discord channel not available: {}", e)
|
||||
|
||||
# Feishu channel
|
||||
if self.config.channels.feishu.enabled:
|
||||
try:
|
||||
from nanobot.channels.feishu import FeishuChannel
|
||||
self.channels["feishu"] = FeishuChannel(
|
||||
self.config.channels.feishu, self.bus,
|
||||
groq_api_key=self.config.providers.groq.api_key,
|
||||
)
|
||||
logger.info("Feishu channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Feishu channel not available: {}", e)
|
||||
|
||||
# Mochat channel
|
||||
if self.config.channels.mochat.enabled:
|
||||
try:
|
||||
from nanobot.channels.mochat import MochatChannel
|
||||
|
||||
self.channels["mochat"] = MochatChannel(
|
||||
self.config.channels.mochat, self.bus
|
||||
)
|
||||
logger.info("Mochat channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Mochat channel not available: {}", e)
|
||||
|
||||
# DingTalk channel
|
||||
if self.config.channels.dingtalk.enabled:
|
||||
try:
|
||||
from nanobot.channels.dingtalk import DingTalkChannel
|
||||
self.channels["dingtalk"] = DingTalkChannel(
|
||||
self.config.channels.dingtalk, self.bus
|
||||
)
|
||||
logger.info("DingTalk channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("DingTalk channel not available: {}", e)
|
||||
|
||||
# Email channel
|
||||
if self.config.channels.email.enabled:
|
||||
try:
|
||||
from nanobot.channels.email import EmailChannel
|
||||
self.channels["email"] = EmailChannel(
|
||||
self.config.channels.email, self.bus
|
||||
)
|
||||
logger.info("Email channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Email channel not available: {}", e)
|
||||
|
||||
# Slack channel
|
||||
if self.config.channels.slack.enabled:
|
||||
try:
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
self.channels["slack"] = SlackChannel(
|
||||
self.config.channels.slack, self.bus
|
||||
)
|
||||
logger.info("Slack channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Slack channel not available: {}", e)
|
||||
|
||||
# QQ channel
|
||||
if self.config.channels.qq.enabled:
|
||||
try:
|
||||
from nanobot.channels.qq import QQChannel
|
||||
self.channels["qq"] = QQChannel(
|
||||
self.config.channels.qq,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("QQ channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("QQ channel not available: {}", e)
|
||||
|
||||
# Matrix channel
|
||||
if self.config.channels.matrix.enabled:
|
||||
try:
|
||||
from nanobot.channels.matrix import MatrixChannel
|
||||
self.channels["matrix"] = MatrixChannel(
|
||||
self.config.channels.matrix,
|
||||
self.bus,
|
||||
)
|
||||
logger.info("Matrix channel enabled")
|
||||
except ImportError as e:
|
||||
logger.warning("Matrix channel not available: {}", e)
|
||||
|
||||
self._validate_allow_from()
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
)
|
||||
|
||||
async def _start_channel(self, name: str, channel: BaseChannel) -> None:
|
||||
"""Start a channel and log any exceptions."""
|
||||
try:
|
||||
await channel.start()
|
||||
except Exception as e:
|
||||
logger.error("Failed to start channel {}: {}", name, e)
|
||||
|
||||
async def start_all(self) -> None:
|
||||
"""Start all channels and the outbound dispatcher."""
|
||||
if not self.channels:
|
||||
logger.warning("No channels enabled")
|
||||
return
|
||||
|
||||
# Start outbound dispatcher
|
||||
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
|
||||
|
||||
# Start channels
|
||||
tasks = []
|
||||
for name, channel in self.channels.items():
|
||||
logger.info("Starting {} channel...", name)
|
||||
tasks.append(asyncio.create_task(self._start_channel(name, channel)))
|
||||
|
||||
# Wait for all to complete (they should run forever)
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
"""Stop all channels and the dispatcher."""
|
||||
logger.info("Stopping all channels...")
|
||||
|
||||
# Stop dispatcher
|
||||
if self._dispatch_task:
|
||||
self._dispatch_task.cancel()
|
||||
try:
|
||||
await self._dispatch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Stop all channels
|
||||
for name, channel in self.channels.items():
|
||||
try:
|
||||
await channel.stop()
|
||||
logger.info("Stopped {} channel", name)
|
||||
except Exception as e:
|
||||
logger.error("Error stopping {}: {}", name, e)
|
||||
|
||||
async def _dispatch_outbound(self) -> None:
|
||||
"""Dispatch outbound messages to the appropriate channel."""
|
||||
logger.info("Outbound dispatcher started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(
|
||||
self.bus.consume_outbound(),
|
||||
timeout=1.0
|
||||
)
|
||||
|
||||
if msg.metadata.get("_progress"):
|
||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||
continue
|
||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||
continue
|
||||
|
||||
channel = self.channels.get(msg.channel)
|
||||
if channel:
|
||||
try:
|
||||
await channel.send(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending to {}: {}", msg.channel, e)
|
||||
else:
|
||||
logger.warning("Unknown channel: {}", msg.channel)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
def get_channel(self, name: str) -> BaseChannel | None:
|
||||
"""Get a channel by name."""
|
||||
return self.channels.get(name)
|
||||
|
||||
def get_status(self) -> dict[str, Any]:
|
||||
"""Get status of all channels."""
|
||||
return {
|
||||
name: {
|
||||
"enabled": True,
|
||||
"running": channel.is_running
|
||||
}
|
||||
for name, channel in self.channels.items()
|
||||
}
|
||||
|
||||
@property
|
||||
def enabled_channels(self) -> list[str]:
|
||||
"""Get list of enabled channel names."""
|
||||
return list(self.channels.keys())
|
||||
@@ -0,0 +1,697 @@
|
||||
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import nh3
|
||||
from mistune import create_markdown
|
||||
from nio import (
|
||||
AsyncClient,
|
||||
AsyncClientConfig,
|
||||
ContentRepositoryConfigError,
|
||||
DownloadError,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
MatrixRoom,
|
||||
MemoryDownloadResponse,
|
||||
RoomEncryptedMedia,
|
||||
RoomMessage,
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
|
||||
) from e
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_data_dir, get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
TYPING_NOTICE_TIMEOUT_MS = 30_000
|
||||
# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
|
||||
TYPING_KEEPALIVE_INTERVAL_MS = 20_000
|
||||
MATRIX_HTML_FORMAT = "org.matrix.custom.html"
|
||||
_ATTACH_MARKER = "[attachment: {}]"
|
||||
_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
|
||||
_ATTACH_FAILED = "[attachment: {} - download failed]"
|
||||
_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
|
||||
_DEFAULT_ATTACH_NAME = "attachment"
|
||||
_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
|
||||
|
||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||
|
||||
MATRIX_MARKDOWN = create_markdown(
|
||||
escape=True,
|
||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||
)
|
||||
|
||||
MATRIX_ALLOWED_HTML_TAGS = {
|
||||
"p", "a", "strong", "em", "del", "code", "pre", "blockquote",
|
||||
"ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
|
||||
"hr", "br", "table", "thead", "tbody", "tr", "th", "td",
|
||||
"caption", "sup", "sub", "img",
|
||||
}
|
||||
MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
|
||||
"a": {"href"}, "code": {"class"}, "ol": {"start"},
|
||||
"img": {"src", "alt", "title", "width", "height"},
|
||||
}
|
||||
MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
|
||||
|
||||
|
||||
def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
|
||||
"""Filter attribute values to a safe Matrix-compatible subset."""
|
||||
if tag == "a" and attr == "href":
|
||||
return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
|
||||
if tag == "img" and attr == "src":
|
||||
return value if value.lower().startswith("mxc://") else None
|
||||
if tag == "code" and attr == "class":
|
||||
classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
|
||||
return " ".join(classes) if classes else None
|
||||
return value
|
||||
|
||||
|
||||
MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
tags=MATRIX_ALLOWED_HTML_TAGS,
|
||||
attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
|
||||
attribute_filter=_filter_matrix_html_attribute,
|
||||
url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
|
||||
strip_comments=True,
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
try:
|
||||
formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
|
||||
except Exception:
|
||||
return None
|
||||
if not formatted:
|
||||
return None
|
||||
# Skip formatted_body for plain <p>text</p> to keep payload minimal.
|
||||
if formatted.startswith("<p>") and formatted.endswith("</p>"):
|
||||
inner = formatted[3:-4]
|
||||
if "<" not in inner and ">" not in inner:
|
||||
return None
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
return content
|
||||
|
||||
|
||||
class _NioLoguruHandler(logging.Handler):
|
||||
"""Route matrix-nio stdlib logs into Loguru."""
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame and frame.f_code.co_filename == logging.__file__:
|
||||
frame, depth = frame.f_back, depth + 1
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
|
||||
|
||||
|
||||
def _configure_nio_logging_bridge() -> None:
|
||||
"""Bridge matrix-nio logs to Loguru (idempotent)."""
|
||||
nio_logger = logging.getLogger("nio")
|
||||
if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
|
||||
nio_logger.handlers = [_NioLoguruHandler()]
|
||||
nio_logger.propagate = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
"""Matrix (Element) channel using long-polling sync."""
|
||||
|
||||
name = "matrix"
|
||||
|
||||
def __init__(self, config: Any, bus, *, restrict_to_workspace: bool = False,
|
||||
workspace: Path | None = None):
|
||||
super().__init__(config, bus)
|
||||
self.client: AsyncClient | None = None
|
||||
self._sync_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._restrict_to_workspace = restrict_to_workspace
|
||||
self._workspace = workspace.expanduser().resolve() if workspace else None
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
self._running = True
|
||||
_configure_nio_logging_bridge()
|
||||
|
||||
store_path = get_data_dir() / "matrix-store"
|
||||
store_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.client = AsyncClient(
|
||||
homeserver=self.config.homeserver, user=self.config.user_id,
|
||||
store_path=store_path,
|
||||
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
|
||||
)
|
||||
self.client.user_id = self.config.user_id
|
||||
self.client.access_token = self.config.access_token
|
||||
self.client.device_id = self.config.device_id
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
|
||||
|
||||
if self.config.device_id:
|
||||
try:
|
||||
self.client.load_store()
|
||||
except Exception:
|
||||
logger.exception("Matrix store load failed; restart may replay recent messages.")
|
||||
else:
|
||||
logger.warning("Matrix device_id empty; restart may replay recent messages.")
|
||||
|
||||
self._sync_task = asyncio.create_task(self._sync_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Matrix channel with graceful sync shutdown."""
|
||||
self._running = False
|
||||
for room_id in list(self._typing_tasks):
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
if self.client:
|
||||
self.client.stop_sync_forever()
|
||||
if self._sync_task:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(self._sync_task),
|
||||
timeout=self.config.sync_stop_grace_seconds)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
self._sync_task.cancel()
|
||||
try:
|
||||
await self._sync_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
def _is_workspace_path_allowed(self, path: Path) -> bool:
|
||||
"""Check path is inside workspace (when restriction enabled)."""
|
||||
if not self._restrict_to_workspace or not self._workspace:
|
||||
return True
|
||||
try:
|
||||
path.resolve(strict=False).relative_to(self._workspace)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
|
||||
"""Deduplicate and resolve outbound attachment paths."""
|
||||
seen: set[str] = set()
|
||||
candidates: list[Path] = []
|
||||
for raw in media:
|
||||
if not isinstance(raw, str) or not raw.strip():
|
||||
continue
|
||||
path = Path(raw.strip()).expanduser()
|
||||
try:
|
||||
key = str(path.resolve(strict=False))
|
||||
except OSError:
|
||||
key = str(path)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(path)
|
||||
return candidates
|
||||
|
||||
@staticmethod
|
||||
def _build_outbound_attachment_content(
|
||||
*, filename: str, mime: str, size_bytes: int,
|
||||
mxc_url: str, encryption_info: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build Matrix content payload for an uploaded file/image/audio/video."""
|
||||
prefix = mime.split("/")[0]
|
||||
msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
|
||||
content: dict[str, Any] = {
|
||||
"msgtype": msgtype, "body": filename, "filename": filename,
|
||||
"info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
|
||||
}
|
||||
if encryption_info:
|
||||
content["file"] = {**encryption_info, "url": mxc_url}
|
||||
else:
|
||||
content["url"] = mxc_url
|
||||
return content
|
||||
|
||||
def _is_encrypted_room(self, room_id: str) -> bool:
|
||||
if not self.client:
|
||||
return False
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
if self._server_upload_limit_checked:
|
||||
return self._server_upload_limit_bytes
|
||||
self._server_upload_limit_checked = True
|
||||
if not self.client:
|
||||
return None
|
||||
try:
|
||||
response = await self.client.content_repository_config()
|
||||
except Exception:
|
||||
return None
|
||||
upload_size = getattr(response, "upload_size", None)
|
||||
if isinstance(upload_size, int) and upload_size > 0:
|
||||
self._server_upload_limit_bytes = upload_size
|
||||
return upload_size
|
||||
return None
|
||||
|
||||
async def _effective_media_limit_bytes(self) -> int:
|
||||
"""min(local config, server advertised) — 0 blocks all uploads."""
|
||||
local_limit = max(int(self.config.max_media_bytes), 0)
|
||||
server_limit = await self._resolve_server_upload_limit_bytes()
|
||||
if server_limit is None:
|
||||
return local_limit
|
||||
return min(local_limit, server_limit) if local_limit else 0
|
||||
|
||||
async def _upload_and_send_attachment(
|
||||
self, room_id: str, path: Path, limit_bytes: int,
|
||||
relates_to: dict[str, Any] | None = None,
|
||||
) -> str | None:
|
||||
"""Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
|
||||
if not self.client:
|
||||
return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
|
||||
|
||||
resolved = path.expanduser().resolve(strict=False)
|
||||
filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
|
||||
fail = _ATTACH_UPLOAD_FAILED.format(filename)
|
||||
|
||||
if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
|
||||
return fail
|
||||
try:
|
||||
size_bytes = resolved.stat().st_size
|
||||
except OSError:
|
||||
return fail
|
||||
if limit_bytes <= 0 or size_bytes > limit_bytes:
|
||||
return _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
|
||||
try:
|
||||
with resolved.open("rb") as f:
|
||||
upload_result = await self.client.upload(
|
||||
f, content_type=mime, filename=filename,
|
||||
encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
|
||||
filesize=size_bytes,
|
||||
)
|
||||
except Exception:
|
||||
return fail
|
||||
|
||||
upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
|
||||
encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
|
||||
if isinstance(upload_response, UploadError):
|
||||
return fail
|
||||
mxc_url = getattr(upload_response, "content_uri", None)
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return fail
|
||||
|
||||
content = self._build_outbound_attachment_content(
|
||||
filename=filename, mime=mime, size_bytes=size_bytes,
|
||||
mxc_url=mxc_url, encryption_info=encryption_info,
|
||||
)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
try:
|
||||
await self._send_room_content(room_id, content)
|
||||
except Exception:
|
||||
return fail
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound content; clear typing for non-progress messages."""
|
||||
if not self.client:
|
||||
return
|
||||
text = msg.content or ""
|
||||
candidates = self._collect_outbound_media_candidates(msg.media)
|
||||
relates_to = self._build_thread_relates_to(msg.metadata)
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
try:
|
||||
failures: list[str] = []
|
||||
if candidates:
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
for path in candidates:
|
||||
if fail := await self._upload_and_send_attachment(
|
||||
room_id=msg.chat_id,
|
||||
path=path,
|
||||
limit_bytes=limit_bytes,
|
||||
relates_to=relates_to,
|
||||
):
|
||||
failures.append(fail)
|
||||
if failures:
|
||||
text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
|
||||
if text or not candidates:
|
||||
content = _build_matrix_text_content(text)
|
||||
if relates_to:
|
||||
content["m.relates_to"] = relates_to
|
||||
await self._send_room_content(msg.chat_id, content)
|
||||
finally:
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _log_response_error(self, label: str, response: Any) -> None:
|
||||
"""Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
is_fatal = is_auth or getattr(response, "soft_logout", False)
|
||||
(logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
|
||||
|
||||
async def _on_sync_error(self, response: SyncError) -> None:
|
||||
self._log_response_error("sync", response)
|
||||
|
||||
async def _on_join_error(self, response: JoinError) -> None:
|
||||
self._log_response_error("join", response)
|
||||
|
||||
async def _on_send_error(self, response: RoomSendError) -> None:
|
||||
self._log_response_error("send", response)
|
||||
|
||||
async def _set_typing(self, room_id: str, typing: bool) -> None:
|
||||
"""Best-effort typing indicator update."""
|
||||
if not self.client:
|
||||
return
|
||||
try:
|
||||
response = await self.client.room_typing(room_id=room_id, typing_state=typing,
|
||||
timeout=TYPING_NOTICE_TIMEOUT_MS)
|
||||
if isinstance(response, RoomTypingError):
|
||||
logger.debug("Matrix typing failed for {}: {}", room_id, response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _start_typing_keepalive(self, room_id: str) -> None:
|
||||
"""Start periodic typing refresh (spec-recommended keepalive)."""
|
||||
await self._stop_typing_keepalive(room_id, clear_typing=False)
|
||||
await self._set_typing(room_id, True)
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
async def loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
|
||||
await self._set_typing(room_id, True)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[room_id] = asyncio.create_task(loop())
|
||||
|
||||
async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
|
||||
if task := self._typing_tasks.pop(room_id, None):
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if clear_typing:
|
||||
await self._set_typing(room_id, False)
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
while self._running:
|
||||
try:
|
||||
await self.client.sync_forever(timeout=30000, full_state=True)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
|
||||
if self.is_allowed(event.sender):
|
||||
await self.client.join(room.room_id)
|
||||
|
||||
def _is_direct_room(self, room: MatrixRoom) -> bool:
|
||||
count = getattr(room, "member_count", None)
|
||||
return isinstance(count, int) and count <= 2
|
||||
|
||||
def _is_bot_mentioned(self, event: RoomMessage) -> bool:
|
||||
"""Check m.mentions payload for bot mention."""
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return False
|
||||
mentions = (source.get("content") or {}).get("m.mentions")
|
||||
if not isinstance(mentions, dict):
|
||||
return False
|
||||
user_ids = mentions.get("user_ids")
|
||||
if isinstance(user_ids, list) and self.config.user_id in user_ids:
|
||||
return True
|
||||
return bool(self.config.allow_room_mentions and mentions.get("room") is True)
|
||||
|
||||
def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
|
||||
"""Apply sender and room policy checks."""
|
||||
if not self.is_allowed(event.sender):
|
||||
return False
|
||||
if self._is_direct_room(room):
|
||||
return True
|
||||
policy = self.config.group_policy
|
||||
if policy == "open":
|
||||
return True
|
||||
if policy == "allowlist":
|
||||
return room.room_id in (self.config.group_allow_from or [])
|
||||
if policy == "mention":
|
||||
return self._is_bot_mentioned(event)
|
||||
return False
|
||||
|
||||
def _media_dir(self) -> Path:
|
||||
return get_media_dir("matrix")
|
||||
|
||||
@staticmethod
|
||||
def _event_source_content(event: RoomMessage) -> dict[str, Any]:
|
||||
source = getattr(event, "source", None)
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
content = source.get("content")
|
||||
return content if isinstance(content, dict) else {}
|
||||
|
||||
def _event_thread_root_id(self, event: RoomMessage) -> str | None:
|
||||
relates_to = self._event_source_content(event).get("m.relates_to")
|
||||
if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
|
||||
return None
|
||||
root_id = relates_to.get("event_id")
|
||||
return root_id if isinstance(root_id, str) and root_id else None
|
||||
|
||||
def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
|
||||
if not (root_id := self._event_thread_root_id(event)):
|
||||
return None
|
||||
meta: dict[str, str] = {"thread_root_event_id": root_id}
|
||||
if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
|
||||
meta["thread_reply_to_event_id"] = reply_to
|
||||
return meta
|
||||
|
||||
@staticmethod
|
||||
def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not metadata:
|
||||
return None
|
||||
root_id = metadata.get("thread_root_event_id")
|
||||
if not isinstance(root_id, str) or not root_id:
|
||||
return None
|
||||
reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
|
||||
if not isinstance(reply_to, str) or not reply_to:
|
||||
return None
|
||||
return {"rel_type": "m.thread", "event_id": root_id,
|
||||
"m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
|
||||
|
||||
def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
|
||||
msgtype = self._event_source_content(event).get("msgtype")
|
||||
return _MSGTYPE_MAP.get(msgtype, "file")
|
||||
|
||||
@staticmethod
|
||||
def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
|
||||
return (isinstance(getattr(event, "key", None), dict)
|
||||
and isinstance(getattr(event, "hashes", None), dict)
|
||||
and isinstance(getattr(event, "iv", None), str))
|
||||
|
||||
def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
size = info.get("size") if isinstance(info, dict) else None
|
||||
return size if isinstance(size, int) and size >= 0 else None
|
||||
|
||||
def _event_mime(self, event: MatrixMediaEvent) -> str | None:
|
||||
info = self._event_source_content(event).get("info")
|
||||
if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
|
||||
return m
|
||||
m = getattr(event, "mimetype", None)
|
||||
return m if isinstance(m, str) and m else None
|
||||
|
||||
def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
|
||||
body = getattr(event, "body", None)
|
||||
if isinstance(body, str) and body.strip():
|
||||
if candidate := safe_filename(Path(body).name):
|
||||
return candidate
|
||||
return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
|
||||
|
||||
def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
|
||||
filename: str, mime: str | None) -> Path:
|
||||
safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
|
||||
suffix = Path(safe_name).suffix
|
||||
if not suffix and mime:
|
||||
if guessed := mimetypes.guess_extension(mime, strict=False):
|
||||
safe_name, suffix = f"{safe_name}{guessed}", guessed
|
||||
stem = (Path(safe_name).stem or attachment_type)[:72]
|
||||
suffix = suffix[:16]
|
||||
event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
|
||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||
|
||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
||||
if not self.client:
|
||||
return None
|
||||
response = await self.client.download(mxc=mxc_url)
|
||||
if isinstance(response, DownloadError):
|
||||
logger.warning("Matrix download failed for {}: {}", mxc_url, response)
|
||||
return None
|
||||
body = getattr(response, "body", None)
|
||||
if isinstance(body, (bytes, bytearray)):
|
||||
return bytes(body)
|
||||
if isinstance(response, MemoryDownloadResponse):
|
||||
return bytes(response.body)
|
||||
if isinstance(body, (str, Path)):
|
||||
path = Path(body)
|
||||
if path.is_file():
|
||||
try:
|
||||
return path.read_bytes()
|
||||
except OSError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||
key = key_obj.get("k") if isinstance(key_obj, dict) else None
|
||||
sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
|
||||
if not all(isinstance(v, str) for v in (key, sha256, iv)):
|
||||
return None
|
||||
try:
|
||||
return decrypt_attachment(ciphertext, key, sha256, iv)
|
||||
except (EncryptionError, ValueError, TypeError):
|
||||
logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
|
||||
return None
|
||||
|
||||
async def _fetch_media_attachment(
|
||||
self, room: MatrixRoom, event: MatrixMediaEvent,
|
||||
) -> tuple[dict[str, Any] | None, str]:
|
||||
"""Download, decrypt if needed, and persist a Matrix attachment."""
|
||||
atype = self._event_attachment_type(event)
|
||||
mime = self._event_mime(event)
|
||||
filename = self._event_filename(event, atype)
|
||||
mxc_url = getattr(event, "url", None)
|
||||
fail = _ATTACH_FAILED.format(filename)
|
||||
|
||||
if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
|
||||
return None, fail
|
||||
|
||||
limit_bytes = await self._effective_media_limit_bytes()
|
||||
declared = self._event_declared_size_bytes(event)
|
||||
if declared is not None and declared > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
downloaded = await self._download_media_bytes(mxc_url)
|
||||
if downloaded is None:
|
||||
return None, fail
|
||||
|
||||
encrypted = self._is_encrypted_media_event(event)
|
||||
data = downloaded
|
||||
if encrypted:
|
||||
if (data := self._decrypt_media_bytes(event, downloaded)) is None:
|
||||
return None, fail
|
||||
|
||||
if len(data) > limit_bytes:
|
||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||
|
||||
path = self._build_attachment_path(event, atype, filename, mime)
|
||||
try:
|
||||
path.write_bytes(data)
|
||||
except OSError:
|
||||
return None, fail
|
||||
|
||||
attachment = {
|
||||
"type": atype, "mime": mime, "filename": filename,
|
||||
"event_id": str(getattr(event, "event_id", "") or ""),
|
||||
"encrypted": encrypted, "size_bytes": len(data),
|
||||
"path": str(path), "mxc_url": mxc_url,
|
||||
}
|
||||
return attachment, _ATTACH_MARKER.format(path)
|
||||
|
||||
def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
|
||||
"""Build common metadata for text and media handlers."""
|
||||
meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
|
||||
if isinstance(eid := getattr(event, "event_id", None), str) and eid:
|
||||
meta["event_id"] = eid
|
||||
if thread := self._thread_metadata(event):
|
||||
meta.update(thread)
|
||||
return meta
|
||||
|
||||
async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
|
||||
async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
|
||||
if event.sender == self.config.user_id or not self._should_process_message(room, event):
|
||||
return
|
||||
attachment, marker = await self._fetch_media_attachment(room, event)
|
||||
parts: list[str] = []
|
||||
if isinstance(body := getattr(event, "body", None), str) and body.strip():
|
||||
parts.append(body.strip())
|
||||
if marker:
|
||||
parts.append(marker)
|
||||
|
||||
await self._start_typing_keepalive(room.room_id)
|
||||
try:
|
||||
meta = self._base_metadata(room, event)
|
||||
meta["attachments"] = []
|
||||
if attachment:
|
||||
meta["attachments"] = [attachment]
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
raise
|
||||
@@ -0,0 +1,895 @@
|
||||
"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
from nanobot.config.schema import MochatConfig
|
||||
|
||||
try:
|
||||
import socketio
|
||||
SOCKETIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
socketio = None
|
||||
SOCKETIO_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import msgpack # noqa: F401
|
||||
MSGPACK_AVAILABLE = True
|
||||
except ImportError:
|
||||
MSGPACK_AVAILABLE = False
|
||||
|
||||
MAX_SEEN_MESSAGE_IDS = 2000
|
||||
CURSOR_SAVE_DEBOUNCE_S = 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class MochatBufferedEntry:
|
||||
"""Buffered inbound entry for delayed dispatch."""
|
||||
raw_body: str
|
||||
author: str
|
||||
sender_name: str = ""
|
||||
sender_username: str = ""
|
||||
timestamp: int | None = None
|
||||
message_id: str = ""
|
||||
group_id: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class DelayState:
|
||||
"""Per-target delayed message state."""
|
||||
entries: list[MochatBufferedEntry] = field(default_factory=list)
|
||||
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||
timer: asyncio.Task | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MochatTarget:
|
||||
"""Outbound target resolution result."""
|
||||
id: str
|
||||
is_panel: bool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _safe_dict(value: Any) -> dict:
|
||||
"""Return *value* if it's a dict, else empty dict."""
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _str_field(src: dict, *keys: str) -> str:
|
||||
"""Return the first non-empty str value found for *keys*, stripped."""
|
||||
for k in keys:
|
||||
v = src.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
return v.strip()
|
||||
return ""
|
||||
|
||||
|
||||
def _make_synthetic_event(
|
||||
message_id: str, author: str, content: Any,
|
||||
meta: Any, group_id: str, converse_id: str,
|
||||
timestamp: Any = None, *, author_info: Any = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a synthetic ``message.add`` event dict."""
|
||||
payload: dict[str, Any] = {
|
||||
"messageId": message_id, "author": author,
|
||||
"content": content, "meta": _safe_dict(meta),
|
||||
"groupId": group_id, "converseId": converse_id,
|
||||
}
|
||||
if author_info is not None:
|
||||
payload["authorInfo"] = _safe_dict(author_info)
|
||||
return {
|
||||
"type": "message.add",
|
||||
"timestamp": timestamp or datetime.utcnow().isoformat(),
|
||||
"payload": payload,
|
||||
}
|
||||
|
||||
|
||||
def normalize_mochat_content(content: Any) -> str:
|
||||
"""Normalize content payload to text."""
|
||||
if isinstance(content, str):
|
||||
return content.strip()
|
||||
if content is None:
|
||||
return ""
|
||||
try:
|
||||
return json.dumps(content, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return str(content)
|
||||
|
||||
|
||||
def resolve_mochat_target(raw: str) -> MochatTarget:
|
||||
"""Resolve id and target kind from user-provided target string."""
|
||||
trimmed = (raw or "").strip()
|
||||
if not trimmed:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
|
||||
lowered = trimmed.lower()
|
||||
cleaned, forced_panel = trimmed, False
|
||||
for prefix in ("mochat:", "group:", "channel:", "panel:"):
|
||||
if lowered.startswith(prefix):
|
||||
cleaned = trimmed[len(prefix):].strip()
|
||||
forced_panel = prefix in {"group:", "channel:", "panel:"}
|
||||
break
|
||||
|
||||
if not cleaned:
|
||||
return MochatTarget(id="", is_panel=False)
|
||||
return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
|
||||
|
||||
|
||||
def extract_mention_ids(value: Any) -> list[str]:
|
||||
"""Extract mention ids from heterogeneous mention payload."""
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
ids: list[str] = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
if item.strip():
|
||||
ids.append(item.strip())
|
||||
elif isinstance(item, dict):
|
||||
for key in ("id", "userId", "_id"):
|
||||
candidate = item.get(key)
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
ids.append(candidate.strip())
|
||||
break
|
||||
return ids
|
||||
|
||||
|
||||
def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
|
||||
"""Resolve mention state from payload metadata and text fallback."""
|
||||
meta = payload.get("meta")
|
||||
if isinstance(meta, dict):
|
||||
if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
|
||||
return True
|
||||
for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
|
||||
if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
|
||||
return True
|
||||
if not agent_user_id:
|
||||
return False
|
||||
content = payload.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
return False
|
||||
return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
|
||||
|
||||
|
||||
def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
|
||||
"""Resolve mention requirement for group/panel conversations."""
|
||||
groups = config.groups or {}
|
||||
for key in (group_id, session_id, "*"):
|
||||
if key and key in groups:
|
||||
return bool(groups[key].require_mention)
|
||||
return bool(config.mention.require_in_groups)
|
||||
|
||||
|
||||
def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
|
||||
"""Build text body from one or more buffered entries."""
|
||||
if not entries:
|
||||
return ""
|
||||
if len(entries) == 1:
|
||||
return entries[0].raw_body
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
if not entry.raw_body:
|
||||
continue
|
||||
if is_group:
|
||||
label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
|
||||
if label:
|
||||
lines.append(f"{label}: {entry.raw_body}")
|
||||
continue
|
||||
lines.append(entry.raw_body)
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
|
||||
def parse_timestamp(value: Any) -> int | None:
|
||||
"""Parse event timestamp to epoch milliseconds."""
|
||||
if not isinstance(value, str) or not value.strip():
|
||||
return None
|
||||
try:
|
||||
return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Channel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class MochatChannel(BaseChannel):
|
||||
"""Mochat channel using socket.io with fallback polling workers."""
|
||||
|
||||
name = "mochat"
|
||||
|
||||
def __init__(self, config: MochatConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: MochatConfig = config
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._socket: Any = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
self._state_dir = get_runtime_subdir("mochat")
|
||||
self._cursor_path = self._state_dir / "session_cursors.json"
|
||||
self._session_cursor: dict[str, int] = {}
|
||||
self._cursor_save_task: asyncio.Task | None = None
|
||||
|
||||
self._session_set: set[str] = set()
|
||||
self._panel_set: set[str] = set()
|
||||
self._auto_discover_sessions = self._auto_discover_panels = False
|
||||
|
||||
self._cold_sessions: set[str] = set()
|
||||
self._session_by_converse: dict[str, str] = {}
|
||||
|
||||
self._seen_set: dict[str, set[str]] = {}
|
||||
self._seen_queue: dict[str, deque[str]] = {}
|
||||
self._delay_states: dict[str, DelayState] = {}
|
||||
|
||||
self._fallback_mode = False
|
||||
self._session_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
|
||||
self._refresh_task: asyncio.Task | None = None
|
||||
self._target_locks: dict[str, asyncio.Lock] = {}
|
||||
|
||||
# ---- lifecycle ---------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Mochat channel workers and websocket connection."""
|
||||
if not self.config.claw_token:
|
||||
logger.error("Mochat claw_token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
await self._load_session_cursors()
|
||||
self._seed_targets_from_config()
|
||||
await self._refresh_targets(subscribe_new=False)
|
||||
|
||||
if not await self._start_socket_client():
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
self._refresh_task = asyncio.create_task(self._refresh_loop())
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop all workers and clean up resources."""
|
||||
self._running = False
|
||||
if self._refresh_task:
|
||||
self._refresh_task.cancel()
|
||||
self._refresh_task = None
|
||||
|
||||
await self._stop_fallback_workers()
|
||||
await self._cancel_delay_timers()
|
||||
|
||||
if self._socket:
|
||||
try:
|
||||
await self._socket.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
|
||||
if self._cursor_save_task:
|
||||
self._cursor_save_task.cancel()
|
||||
self._cursor_save_task = None
|
||||
await self._save_session_cursors()
|
||||
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
self._ws_connected = self._ws_ready = False
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send outbound message to session or panel."""
|
||||
if not self.config.claw_token:
|
||||
logger.warning("Mochat claw_token missing, skip send")
|
||||
return
|
||||
|
||||
parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
|
||||
if msg.media:
|
||||
parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
|
||||
content = "\n".join(parts).strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
target = resolve_mochat_target(msg.chat_id)
|
||||
if not target.id:
|
||||
logger.warning("Mochat outbound target is empty")
|
||||
return
|
||||
|
||||
is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
|
||||
try:
|
||||
if is_panel:
|
||||
await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
|
||||
content, msg.reply_to, self._read_group_id(msg.metadata))
|
||||
else:
|
||||
await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
|
||||
content, msg.reply_to)
|
||||
except Exception as e:
|
||||
logger.error("Failed to send Mochat message: {}", e)
|
||||
|
||||
# ---- config / init helpers ---------------------------------------------
|
||||
|
||||
def _seed_targets_from_config(self) -> None:
|
||||
sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
|
||||
panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
|
||||
self._session_set.update(sessions)
|
||||
self._panel_set.update(panels)
|
||||
for sid in sessions:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
|
||||
cleaned = [str(v).strip() for v in values if str(v).strip()]
|
||||
return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
|
||||
|
||||
# ---- websocket ---------------------------------------------------------
|
||||
|
||||
async def _start_socket_client(self) -> bool:
|
||||
if not SOCKETIO_AVAILABLE:
|
||||
logger.warning("python-socketio not installed, Mochat using polling fallback")
|
||||
return False
|
||||
|
||||
serializer = "default"
|
||||
if not self.config.socket_disable_msgpack:
|
||||
if MSGPACK_AVAILABLE:
|
||||
serializer = "msgpack"
|
||||
else:
|
||||
logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
|
||||
|
||||
client = socketio.AsyncClient(
|
||||
reconnection=True,
|
||||
reconnection_attempts=self.config.max_retry_attempts or None,
|
||||
reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
|
||||
reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
|
||||
logger=False, engineio_logger=False, serializer=serializer,
|
||||
)
|
||||
|
||||
@client.event
|
||||
async def connect() -> None:
|
||||
self._ws_connected, self._ws_ready = True, False
|
||||
logger.info("Mochat websocket connected")
|
||||
subscribed = await self._subscribe_all()
|
||||
self._ws_ready = subscribed
|
||||
await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
|
||||
|
||||
@client.event
|
||||
async def disconnect() -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._ws_connected = self._ws_ready = False
|
||||
logger.warning("Mochat websocket disconnected")
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
@client.event
|
||||
async def connect_error(data: Any) -> None:
|
||||
logger.error("Mochat websocket connect error: {}", data)
|
||||
|
||||
@client.on("claw.session.events")
|
||||
async def on_session_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
|
||||
@client.on("claw.panel.events")
|
||||
async def on_panel_events(payload: dict[str, Any]) -> None:
|
||||
await self._handle_watch_payload(payload, "panel")
|
||||
|
||||
for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
|
||||
"notify:chat.message.update", "notify:chat.message.recall",
|
||||
"notify:chat.message.delete"):
|
||||
client.on(ev, self._build_notify_handler(ev))
|
||||
|
||||
socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
|
||||
socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
|
||||
|
||||
try:
|
||||
self._socket = client
|
||||
await client.connect(
|
||||
socket_url, transports=["websocket"], socketio_path=socket_path,
|
||||
auth={"token": self.config.claw_token},
|
||||
wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to connect Mochat websocket: {}", e)
|
||||
try:
|
||||
await client.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
self._socket = None
|
||||
return False
|
||||
|
||||
def _build_notify_handler(self, event_name: str):
|
||||
async def handler(payload: Any) -> None:
|
||||
if event_name == "notify:chat.inbox.append":
|
||||
await self._handle_notify_inbox_append(payload)
|
||||
elif event_name.startswith("notify:chat.message."):
|
||||
await self._handle_notify_chat_message(payload)
|
||||
return handler
|
||||
|
||||
# ---- subscribe ---------------------------------------------------------
|
||||
|
||||
async def _subscribe_all(self) -> bool:
|
||||
ok = await self._subscribe_sessions(sorted(self._session_set))
|
||||
ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
|
||||
if self._auto_discover_sessions or self._auto_discover_panels:
|
||||
await self._refresh_targets(subscribe_new=True)
|
||||
return ok
|
||||
|
||||
async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
|
||||
if not session_ids:
|
||||
return True
|
||||
for sid in session_ids:
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
|
||||
ack = await self._socket_call("com.claw.im.subscribeSessions", {
|
||||
"sessionIds": session_ids, "cursors": self._session_cursor,
|
||||
"limit": self.config.watch_limit,
|
||||
})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
|
||||
data = ack.get("data")
|
||||
items: list[dict[str, Any]] = []
|
||||
if isinstance(data, list):
|
||||
items = [i for i in data if isinstance(i, dict)]
|
||||
elif isinstance(data, dict):
|
||||
sessions = data.get("sessions")
|
||||
if isinstance(sessions, list):
|
||||
items = [i for i in sessions if isinstance(i, dict)]
|
||||
elif "sessionId" in data:
|
||||
items = [data]
|
||||
for p in items:
|
||||
await self._handle_watch_payload(p, "session")
|
||||
return True
|
||||
|
||||
async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
|
||||
if not self._auto_discover_panels and not panel_ids:
|
||||
return True
|
||||
ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
|
||||
if not ack.get("result"):
|
||||
logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._socket:
|
||||
return {"result": False, "message": "socket not connected"}
|
||||
try:
|
||||
raw = await self._socket.call(event_name, payload, timeout=10)
|
||||
except Exception as e:
|
||||
return {"result": False, "message": str(e)}
|
||||
return raw if isinstance(raw, dict) else {"result": True, "data": raw}
|
||||
|
||||
# ---- refresh / discovery -----------------------------------------------
|
||||
|
||||
async def _refresh_loop(self) -> None:
|
||||
interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running:
|
||||
await asyncio.sleep(interval_s)
|
||||
try:
|
||||
await self._refresh_targets(subscribe_new=self._ws_ready)
|
||||
except Exception as e:
|
||||
logger.warning("Mochat refresh failed: {}", e)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_targets(self, subscribe_new: bool) -> None:
|
||||
if self._auto_discover_sessions:
|
||||
await self._refresh_sessions_directory(subscribe_new)
|
||||
if self._auto_discover_panels:
|
||||
await self._refresh_panels(subscribe_new)
|
||||
|
||||
async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/sessions/list", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat listSessions failed: {}", e)
|
||||
return
|
||||
|
||||
sessions = response.get("sessions")
|
||||
if not isinstance(sessions, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for s in sessions:
|
||||
if not isinstance(s, dict):
|
||||
continue
|
||||
sid = _str_field(s, "sessionId")
|
||||
if not sid:
|
||||
continue
|
||||
if sid not in self._session_set:
|
||||
self._session_set.add(sid)
|
||||
new_ids.append(sid)
|
||||
if sid not in self._session_cursor:
|
||||
self._cold_sessions.add(sid)
|
||||
cid = _str_field(s, "converseId")
|
||||
if cid:
|
||||
self._session_by_converse[cid] = sid
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_sessions(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
async def _refresh_panels(self, subscribe_new: bool) -> None:
|
||||
try:
|
||||
response = await self._post_json("/api/claw/groups/get", {})
|
||||
except Exception as e:
|
||||
logger.warning("Mochat getWorkspaceGroup failed: {}", e)
|
||||
return
|
||||
|
||||
raw_panels = response.get("panels")
|
||||
if not isinstance(raw_panels, list):
|
||||
return
|
||||
|
||||
new_ids: list[str] = []
|
||||
for p in raw_panels:
|
||||
if not isinstance(p, dict):
|
||||
continue
|
||||
pt = p.get("type")
|
||||
if isinstance(pt, int) and pt != 0:
|
||||
continue
|
||||
pid = _str_field(p, "id", "_id")
|
||||
if pid and pid not in self._panel_set:
|
||||
self._panel_set.add(pid)
|
||||
new_ids.append(pid)
|
||||
|
||||
if not new_ids:
|
||||
return
|
||||
if self._ws_ready and subscribe_new:
|
||||
await self._subscribe_panels(new_ids)
|
||||
if self._fallback_mode:
|
||||
await self._ensure_fallback_workers()
|
||||
|
||||
# ---- fallback workers --------------------------------------------------
|
||||
|
||||
async def _ensure_fallback_workers(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._fallback_mode = True
|
||||
for sid in sorted(self._session_set):
|
||||
t = self._session_fallback_tasks.get(sid)
|
||||
if not t or t.done():
|
||||
self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
|
||||
for pid in sorted(self._panel_set):
|
||||
t = self._panel_fallback_tasks.get(pid)
|
||||
if not t or t.done():
|
||||
self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
|
||||
|
||||
async def _stop_fallback_workers(self) -> None:
|
||||
self._fallback_mode = False
|
||||
tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
|
||||
for t in tasks:
|
||||
t.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._session_fallback_tasks.clear()
|
||||
self._panel_fallback_tasks.clear()
|
||||
|
||||
async def _session_watch_worker(self, session_id: str) -> None:
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
payload = await self._post_json("/api/claw/sessions/watch", {
|
||||
"sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
|
||||
"timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
|
||||
})
|
||||
await self._handle_watch_payload(payload, "session")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
|
||||
await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
|
||||
|
||||
async def _panel_poll_worker(self, panel_id: str) -> None:
|
||||
sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
|
||||
while self._running and self._fallback_mode:
|
||||
try:
|
||||
resp = await self._post_json("/api/claw/groups/panels/messages", {
|
||||
"panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
|
||||
})
|
||||
msgs = resp.get("messages")
|
||||
if isinstance(msgs, list):
|
||||
for m in reversed(msgs):
|
||||
if not isinstance(m, dict):
|
||||
continue
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(m.get("messageId") or ""),
|
||||
author=str(m.get("author") or ""),
|
||||
content=m.get("content"),
|
||||
meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
|
||||
converse_id=panel_id, timestamp=m.get("createdAt"),
|
||||
author_info=m.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
|
||||
await asyncio.sleep(sleep_s)
|
||||
|
||||
# ---- inbound event processing ------------------------------------------
|
||||
|
||||
async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
target_id = _str_field(payload, "sessionId")
|
||||
if not target_id:
|
||||
return
|
||||
|
||||
lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
|
||||
async with lock:
|
||||
prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
|
||||
pc = payload.get("cursor")
|
||||
if target_kind == "session" and isinstance(pc, int) and pc >= 0:
|
||||
self._mark_session_cursor(target_id, pc)
|
||||
|
||||
raw_events = payload.get("events")
|
||||
if not isinstance(raw_events, list):
|
||||
return
|
||||
if target_kind == "session" and target_id in self._cold_sessions:
|
||||
self._cold_sessions.discard(target_id)
|
||||
return
|
||||
|
||||
for event in raw_events:
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
seq = event.get("seq")
|
||||
if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
|
||||
self._mark_session_cursor(target_id, seq)
|
||||
if event.get("type") == "message.add":
|
||||
await self._process_inbound_event(target_id, event, target_kind)
|
||||
|
||||
async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
|
||||
payload = event.get("payload")
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
|
||||
author = _str_field(payload, "author")
|
||||
if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
|
||||
return
|
||||
if not self.is_allowed(author):
|
||||
return
|
||||
|
||||
message_id = _str_field(payload, "messageId")
|
||||
seen_key = f"{target_kind}:{target_id}"
|
||||
if message_id and self._remember_message_id(seen_key, message_id):
|
||||
return
|
||||
|
||||
raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
|
||||
ai = _safe_dict(payload.get("authorInfo"))
|
||||
sender_name = _str_field(ai, "nickname", "email")
|
||||
sender_username = _str_field(ai, "agentId")
|
||||
|
||||
group_id = _str_field(payload, "groupId")
|
||||
is_group = bool(group_id)
|
||||
was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
|
||||
require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
|
||||
use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
|
||||
|
||||
if require_mention and not was_mentioned and not use_delay:
|
||||
return
|
||||
|
||||
entry = MochatBufferedEntry(
|
||||
raw_body=raw_body, author=author, sender_name=sender_name,
|
||||
sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
|
||||
message_id=message_id, group_id=group_id,
|
||||
)
|
||||
|
||||
if use_delay:
|
||||
delay_key = seen_key
|
||||
if was_mentioned:
|
||||
await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
|
||||
else:
|
||||
await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
|
||||
return
|
||||
|
||||
await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
|
||||
|
||||
# ---- dedup / buffering -------------------------------------------------
|
||||
|
||||
def _remember_message_id(self, key: str, message_id: str) -> bool:
|
||||
seen_set = self._seen_set.setdefault(key, set())
|
||||
seen_queue = self._seen_queue.setdefault(key, deque())
|
||||
if message_id in seen_set:
|
||||
return True
|
||||
seen_set.add(message_id)
|
||||
seen_queue.append(message_id)
|
||||
while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
|
||||
seen_set.discard(seen_queue.popleft())
|
||||
return False
|
||||
|
||||
async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
state.entries.append(entry)
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
|
||||
|
||||
async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
|
||||
await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
|
||||
await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
|
||||
|
||||
async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
|
||||
state = self._delay_states.setdefault(key, DelayState())
|
||||
async with state.lock:
|
||||
if entry:
|
||||
state.entries.append(entry)
|
||||
current = asyncio.current_task()
|
||||
if state.timer and state.timer is not current:
|
||||
state.timer.cancel()
|
||||
state.timer = None
|
||||
entries = state.entries[:]
|
||||
state.entries.clear()
|
||||
if entries:
|
||||
await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
|
||||
|
||||
async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
|
||||
if not entries:
|
||||
return
|
||||
last = entries[-1]
|
||||
is_group = bool(last.group_id)
|
||||
body = build_buffered_body(entries, is_group) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=last.author, chat_id=target_id, content=body,
|
||||
metadata={
|
||||
"message_id": last.message_id, "timestamp": last.timestamp,
|
||||
"is_group": is_group, "group_id": last.group_id,
|
||||
"sender_name": last.sender_name, "sender_username": last.sender_username,
|
||||
"target_kind": target_kind, "was_mentioned": was_mentioned,
|
||||
"buffered_count": len(entries),
|
||||
},
|
||||
)
|
||||
|
||||
async def _cancel_delay_timers(self) -> None:
|
||||
for state in self._delay_states.values():
|
||||
if state.timer:
|
||||
state.timer.cancel()
|
||||
self._delay_states.clear()
|
||||
|
||||
# ---- notify handlers ---------------------------------------------------
|
||||
|
||||
async def _handle_notify_chat_message(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict):
|
||||
return
|
||||
group_id = _str_field(payload, "groupId")
|
||||
panel_id = _str_field(payload, "converseId", "panelId")
|
||||
if not group_id or not panel_id:
|
||||
return
|
||||
if self._panel_set and panel_id not in self._panel_set:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(payload.get("_id") or payload.get("messageId") or ""),
|
||||
author=str(payload.get("author") or ""),
|
||||
content=payload.get("content"), meta=payload.get("meta"),
|
||||
group_id=group_id, converse_id=panel_id,
|
||||
timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
|
||||
)
|
||||
await self._process_inbound_event(panel_id, evt, "panel")
|
||||
|
||||
async def _handle_notify_inbox_append(self, payload: Any) -> None:
|
||||
if not isinstance(payload, dict) or payload.get("type") != "message":
|
||||
return
|
||||
detail = payload.get("payload")
|
||||
if not isinstance(detail, dict):
|
||||
return
|
||||
if _str_field(detail, "groupId"):
|
||||
return
|
||||
converse_id = _str_field(detail, "converseId")
|
||||
if not converse_id:
|
||||
return
|
||||
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
await self._refresh_sessions_directory(self._ws_ready)
|
||||
session_id = self._session_by_converse.get(converse_id)
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
evt = _make_synthetic_event(
|
||||
message_id=str(detail.get("messageId") or payload.get("_id") or ""),
|
||||
author=str(detail.get("messageAuthor") or ""),
|
||||
content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
|
||||
meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
|
||||
group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
|
||||
)
|
||||
await self._process_inbound_event(session_id, evt, "session")
|
||||
|
||||
# ---- cursor persistence ------------------------------------------------
|
||||
|
||||
def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
|
||||
if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
|
||||
return
|
||||
self._session_cursor[session_id] = cursor
|
||||
if not self._cursor_save_task or self._cursor_save_task.done():
|
||||
self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
|
||||
|
||||
async def _save_cursor_debounced(self) -> None:
|
||||
await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
|
||||
await self._save_session_cursors()
|
||||
|
||||
async def _load_session_cursors(self) -> None:
|
||||
if not self._cursor_path.exists():
|
||||
return
|
||||
try:
|
||||
data = json.loads(self._cursor_path.read_text("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read Mochat cursor file: {}", e)
|
||||
return
|
||||
cursors = data.get("cursors") if isinstance(data, dict) else None
|
||||
if isinstance(cursors, dict):
|
||||
for sid, cur in cursors.items():
|
||||
if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
|
||||
self._session_cursor[sid] = cur
|
||||
|
||||
async def _save_session_cursors(self) -> None:
|
||||
try:
|
||||
self._state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._cursor_path.write_text(json.dumps({
|
||||
"schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
|
||||
"cursors": self._session_cursor,
|
||||
}, ensure_ascii=False, indent=2) + "\n", "utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save Mochat cursor file: {}", e)
|
||||
|
||||
# ---- HTTP helpers ------------------------------------------------------
|
||||
|
||||
async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
if not self._http:
|
||||
raise RuntimeError("Mochat HTTP client not initialized")
|
||||
url = f"{self.config.base_url.strip().rstrip('/')}{path}"
|
||||
response = await self._http.post(url, headers={
|
||||
"Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
|
||||
}, json=payload)
|
||||
if not response.is_success:
|
||||
raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
|
||||
try:
|
||||
parsed = response.json()
|
||||
except Exception:
|
||||
parsed = response.text
|
||||
if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
|
||||
if parsed["code"] != 200:
|
||||
msg = str(parsed.get("message") or parsed.get("name") or "request failed")
|
||||
raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
|
||||
data = parsed.get("data")
|
||||
return data if isinstance(data, dict) else {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
async def _api_send(self, path: str, id_key: str, id_val: str,
|
||||
content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
|
||||
"""Unified send helper for session and panel messages."""
|
||||
body: dict[str, Any] = {id_key: id_val, "content": content}
|
||||
if reply_to:
|
||||
body["replyTo"] = reply_to
|
||||
if group_id:
|
||||
body["groupId"] = group_id
|
||||
return await self._post_json(path, body)
|
||||
|
||||
@staticmethod
|
||||
def _read_group_id(metadata: dict[str, Any]) -> str | None:
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
value = metadata.get("group_id") or metadata.get("groupId")
|
||||
return value.strip() if isinstance(value, str) and value.strip() else None
|
||||
@@ -0,0 +1,160 @@
|
||||
"""QQ channel implementation using botpy SDK."""
|
||||
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import QQConfig
|
||||
|
||||
try:
|
||||
import botpy
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
QQ_AVAILABLE = True
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
botpy = None
|
||||
C2CMessage = None
|
||||
GroupMessage = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botpy.message import C2CMessage, GroupMessage
|
||||
|
||||
|
||||
def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
|
||||
"""Create a botpy Client subclass bound to the given channel."""
|
||||
intents = botpy.Intents(public_messages=True, direct_message=True)
|
||||
|
||||
class _Bot(botpy.Client):
|
||||
def __init__(self):
|
||||
# Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
|
||||
super().__init__(intents=intents, ext_handlers=False)
|
||||
|
||||
async def on_ready(self):
|
||||
logger.info("QQ bot ready: {}", self.robot.name)
|
||||
|
||||
async def on_c2c_message_create(self, message: "C2CMessage"):
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
async def on_group_at_message_create(self, message: "GroupMessage"):
|
||||
await channel._on_message(message, is_group=True)
|
||||
|
||||
async def on_direct_message_create(self, message):
|
||||
await channel._on_message(message, is_group=False)
|
||||
|
||||
return _Bot
|
||||
|
||||
|
||||
class QQChannel(BaseChannel):
|
||||
"""QQ channel using botpy SDK with WebSocket connection."""
|
||||
|
||||
name = "qq"
|
||||
|
||||
def __init__(self, config: QQConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: QQConfig = config
|
||||
self._client: "botpy.Client | None" = None
|
||||
self._processed_ids: deque = deque(maxlen=1000)
|
||||
self._msg_seq: int = 1 # 消息序列号,避免被 QQ API 去重
|
||||
self._chat_type_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the QQ bot."""
|
||||
if not QQ_AVAILABLE:
|
||||
logger.error("QQ SDK not installed. Run: pip install qq-botpy")
|
||||
return
|
||||
|
||||
if not self.config.app_id or not self.config.secret:
|
||||
logger.error("QQ app_id and secret not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
BotClass = _make_bot_class(self)
|
||||
self._client = BotClass()
|
||||
logger.info("QQ bot started (C2C & Group supported)")
|
||||
await self._run_bot()
|
||||
|
||||
async def _run_bot(self) -> None:
|
||||
"""Run the bot connection with auto-reconnect."""
|
||||
while self._running:
|
||||
try:
|
||||
await self._client.start(appid=self.config.app_id, secret=self.config.secret)
|
||||
except Exception as e:
|
||||
logger.warning("QQ bot error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting QQ bot in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the QQ bot."""
|
||||
self._running = False
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("QQ bot stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through QQ."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
try:
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
self._msg_seq += 1
|
||||
msg_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
if msg_type == "group":
|
||||
await self._client.api.post_group_message(
|
||||
group_openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
else:
|
||||
await self._client.api.post_c2c_message(
|
||||
openid=msg.chat_id,
|
||||
msg_type=2,
|
||||
markdown={"content": msg.content},
|
||||
msg_id=msg_id,
|
||||
msg_seq=self._msg_seq,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error sending QQ message: {}", e)
|
||||
|
||||
async def _on_message(self, data: "C2CMessage | GroupMessage", is_group: bool = False) -> None:
|
||||
"""Handle incoming message from QQ."""
|
||||
try:
|
||||
# Dedup by message ID
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
content = (data.content or "").strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(getattr(data.author, 'id', None) or getattr(data.author, 'user_openid', 'unknown'))
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
metadata={"message_id": data.id},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ message")
|
||||
@@ -0,0 +1,281 @@
|
||||
"""Slack channel implementation using Socket Mode."""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import SlackConfig
|
||||
|
||||
|
||||
class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
name = "slack"
|
||||
|
||||
def __init__(self, config: SlackConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: SlackConfig = config
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
if not self.config.bot_token or not self.config.app_token:
|
||||
logger.error("Slack bot/app token not configured")
|
||||
return
|
||||
if self.config.mode != "socket":
|
||||
logger.error("Unsupported Slack mode: {}", self.config.mode)
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
self._web_client = AsyncWebClient(token=self.config.bot_token)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=self.config.app_token,
|
||||
web_client=self._web_client,
|
||||
)
|
||||
|
||||
self._socket_client.socket_mode_request_listeners.append(self._on_socket_request)
|
||||
|
||||
# Resolve bot user ID for mention handling
|
||||
try:
|
||||
auth = await self._web_client.auth_test()
|
||||
self._bot_user_id = auth.get("user_id")
|
||||
logger.info("Slack bot connected as {}", self._bot_user_id)
|
||||
except Exception as e:
|
||||
logger.warning("Slack auth_test failed: {}", e)
|
||||
|
||||
logger.info("Starting Slack Socket Mode client...")
|
||||
await self._socket_client.connect()
|
||||
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Slack client."""
|
||||
self._running = False
|
||||
if self._socket_client:
|
||||
try:
|
||||
await self._socket_client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Slack socket close failed: {}", e)
|
||||
self._socket_client = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Slack."""
|
||||
if not self._web_client:
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
# Only reply in thread for channel/group messages; DMs don't use threads
|
||||
thread_ts_param = thread_ts if use_thread else None
|
||||
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to upload file {}: {}", media_path, e)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
req: SocketModeRequest,
|
||||
) -> None:
|
||||
"""Handle incoming Socket Mode requests."""
|
||||
if req.type != "events_api":
|
||||
return
|
||||
|
||||
# Acknowledge right away
|
||||
await client.send_socket_mode_response(
|
||||
SocketModeResponse(envelope_id=req.envelope_id)
|
||||
)
|
||||
|
||||
payload = req.payload or {}
|
||||
event = payload.get("event") or {}
|
||||
event_type = event.get("type")
|
||||
|
||||
# Handle app mentions or plain messages
|
||||
if event_type not in ("message", "app_mention"):
|
||||
return
|
||||
|
||||
sender_id = event.get("user")
|
||||
chat_id = event.get("channel")
|
||||
|
||||
# Ignore bot/system messages (any subtype = not a normal user message)
|
||||
if event.get("subtype"):
|
||||
return
|
||||
if self._bot_user_id and sender_id == self._bot_user_id:
|
||||
return
|
||||
|
||||
# Avoid double-processing: Slack sends both `message` and `app_mention`
|
||||
# for mentions in channels. Prefer `app_mention`.
|
||||
text = event.get("text") or ""
|
||||
if event_type == "message" and self._bot_user_id and f"<@{self._bot_user_id}>" in text:
|
||||
return
|
||||
|
||||
# Debug: log basic event shape
|
||||
logger.debug(
|
||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||
event_type,
|
||||
event.get("subtype"),
|
||||
sender_id,
|
||||
chat_id,
|
||||
event.get("channel_type"),
|
||||
text[:80],
|
||||
)
|
||||
if not sender_id or not chat_id:
|
||||
return
|
||||
|
||||
channel_type = event.get("channel_type") or ""
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
return
|
||||
|
||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||
return
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
await self._web_client.reactions_add(
|
||||
channel=chat_id,
|
||||
name=self.config.react_emoji,
|
||||
timestamp=event.get("ts"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Slack reactions_add failed: {}", e)
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
"thread_ts": thread_ts,
|
||||
"channel_type": channel_type,
|
||||
},
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
if self.config.dm.policy == "allowlist":
|
||||
return sender_id in self.config.dm.allow_from
|
||||
return True
|
||||
|
||||
# Group / channel messages
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return True
|
||||
|
||||
def _should_respond_in_channel(self, event_type: str, text: str, chat_id: str) -> bool:
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
if self.config.group_policy == "mention":
|
||||
if event_type == "app_mention":
|
||||
return True
|
||||
return self._bot_user_id is not None and f"<@{self._bot_user_id}>" in text
|
||||
if self.config.group_policy == "allowlist":
|
||||
return chat_id in self.config.group_allow_from
|
||||
return False
|
||||
|
||||
def _strip_bot_mention(self, text: str) -> str:
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
|
||||
|
||||
_TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
|
||||
_CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
|
||||
_INLINE_CODE_RE = re.compile(r"`[^`]+`")
|
||||
_LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
|
||||
_LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||
_BARE_URL_RE = re.compile(r"(?<![|<])(https?://\S+)")
|
||||
|
||||
@classmethod
|
||||
def _to_mrkdwn(cls, text: str) -> str:
|
||||
"""Convert Markdown to Slack mrkdwn, including tables."""
|
||||
if not text:
|
||||
return ""
|
||||
text = cls._TABLE_RE.sub(cls._convert_table, text)
|
||||
return cls._fixup_mrkdwn(slackify_markdown(text))
|
||||
|
||||
@classmethod
|
||||
def _fixup_mrkdwn(cls, text: str) -> str:
|
||||
"""Fix markdown artifacts that slackify_markdown misses."""
|
||||
code_blocks: list[str] = []
|
||||
|
||||
def _save_code(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(0))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = cls._CODE_FENCE_RE.sub(_save_code, text)
|
||||
text = cls._INLINE_CODE_RE.sub(_save_code, text)
|
||||
text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
|
||||
text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
|
||||
text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
|
||||
|
||||
for i, block in enumerate(code_blocks):
|
||||
text = text.replace(f"\x00CB{i}\x00", block)
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _convert_table(match: re.Match) -> str:
|
||||
"""Convert a Markdown table to a Slack-readable list."""
|
||||
lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
|
||||
if len(lines) < 2:
|
||||
return match.group(0)
|
||||
headers = [h.strip() for h in lines[0].strip("|").split("|")]
|
||||
start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
|
||||
rows: list[str] = []
|
||||
for line in lines[start:]:
|
||||
cells = [c.strip() for c in line.strip("|").split("|")]
|
||||
cells = (cells + [""] * len(headers))[: len(headers)]
|
||||
parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
|
||||
if parts:
|
||||
rows.append(" · ".join(parts))
|
||||
return "\n".join(rows)
|
||||
|
||||
@@ -0,0 +1,672 @@
|
||||
"""Telegram channel implementation using python-telegram-bot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import time
|
||||
import unicodedata
|
||||
|
||||
from loguru import logger
|
||||
from telegram import BotCommand, ReplyParameters, Update
|
||||
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
|
||||
from telegram.request import HTTPXRequest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import TelegramConfig
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
|
||||
|
||||
|
||||
def _strip_md(s: str) -> str:
|
||||
"""Strip markdown inline formatting from text."""
|
||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
def _render_table_box(table_lines: list[str]) -> str:
|
||||
"""Convert markdown pipe-table to compact aligned text for <pre> display."""
|
||||
|
||||
def dw(s: str) -> int:
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||
|
||||
rows: list[list[str]] = []
|
||||
has_sep = False
|
||||
for line in table_lines:
|
||||
cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
|
||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||
has_sep = True
|
||||
continue
|
||||
rows.append(cells)
|
||||
if not rows or not has_sep:
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
ncols = max(len(r) for r in rows)
|
||||
for r in rows:
|
||||
r.extend([''] * (ncols - len(r)))
|
||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||
|
||||
def dr(cells: list[str]) -> str:
|
||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||
|
||||
out = [dr(rows[0])]
|
||||
out.append(' '.join('─' * w for w in widths))
|
||||
for row in rows[1:]:
|
||||
out.append(dr(row))
|
||||
return '\n'.join(out)
|
||||
|
||||
|
||||
def _markdown_to_telegram_html(text: str) -> str:
|
||||
"""
|
||||
Convert markdown to Telegram-safe HTML.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# 1. Extract and protect code blocks (preserve content from other processing)
|
||||
code_blocks: list[str] = []
|
||||
def save_code_block(m: re.Match) -> str:
|
||||
code_blocks.append(m.group(1))
|
||||
return f"\x00CB{len(code_blocks) - 1}\x00"
|
||||
|
||||
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
|
||||
|
||||
# 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
|
||||
lines = text.split('\n')
|
||||
rebuilt: list[str] = []
|
||||
li = 0
|
||||
while li < len(lines):
|
||||
if re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl: list[str] = []
|
||||
while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
|
||||
tbl.append(lines[li])
|
||||
li += 1
|
||||
box = _render_table_box(tbl)
|
||||
if box != '\n'.join(tbl):
|
||||
code_blocks.append(box)
|
||||
rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
|
||||
else:
|
||||
rebuilt.extend(tbl)
|
||||
else:
|
||||
rebuilt.append(lines[li])
|
||||
li += 1
|
||||
text = '\n'.join(rebuilt)
|
||||
|
||||
# 2. Extract and protect inline code
|
||||
inline_codes: list[str] = []
|
||||
def save_inline_code(m: re.Match) -> str:
|
||||
inline_codes.append(m.group(1))
|
||||
return f"\x00IC{len(inline_codes) - 1}\x00"
|
||||
|
||||
text = re.sub(r'`([^`]+)`', save_inline_code, text)
|
||||
|
||||
# 3. Headers # Title -> just the title text
|
||||
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 4. Blockquotes > text -> just the text (before HTML escaping)
|
||||
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
|
||||
|
||||
# 5. Escape HTML special characters
|
||||
text = text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
# 6. Links [text](url) - must be before bold/italic to handle nested cases
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'<a href="\2">\1</a>', text)
|
||||
|
||||
# 7. Bold **text** or __text__
|
||||
text = re.sub(r'\*\*(.+?)\*\*', r'<b>\1</b>', text)
|
||||
text = re.sub(r'__(.+?)__', r'<b>\1</b>', text)
|
||||
|
||||
# 8. Italic _text_ (avoid matching inside words like some_var_name)
|
||||
text = re.sub(r'(?<![a-zA-Z0-9])_([^_]+)_(?![a-zA-Z0-9])', r'<i>\1</i>', text)
|
||||
|
||||
# 9. Strikethrough ~~text~~
|
||||
text = re.sub(r'~~(.+?)~~', r'<s>\1</s>', text)
|
||||
|
||||
# 10. Bullet lists - item -> • item
|
||||
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
|
||||
|
||||
# 11. Restore inline code with HTML tags
|
||||
for i, code in enumerate(inline_codes):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00IC{i}\x00", f"<code>{escaped}</code>")
|
||||
|
||||
# 12. Restore code blocks with HTML tags
|
||||
for i, code in enumerate(code_blocks):
|
||||
# Escape HTML in code content
|
||||
escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
text = text.replace(f"\x00CB{i}\x00", f"<pre><code>{escaped}</code></pre>")
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
"""
|
||||
Telegram channel using long polling.
|
||||
|
||||
Simple and reliable - no webhook/public IP needed.
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
# Commands registered with Telegram's command menu
|
||||
BOT_COMMANDS = [
|
||||
BotCommand("start", "Start the bot"),
|
||||
BotCommand("new", "Start a new conversation"),
|
||||
BotCommand("stop", "Stop the current task"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TelegramConfig,
|
||||
bus: MessageBus,
|
||||
groq_api_key: str = "",
|
||||
):
|
||||
super().__init__(config, bus)
|
||||
self.config: TelegramConfig = config
|
||||
self.groq_api_key = groq_api_key
|
||||
self._app: Application | None = None
|
||||
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
|
||||
self._media_group_buffers: dict[str, dict] = {}
|
||||
self._media_group_tasks: dict[str, asyncio.Task] = {}
|
||||
self._message_threads: dict[tuple[str, int], int] = {}
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Preserve Telegram's legacy id|username allowlist matching."""
|
||||
if super().is_allowed(sender_id):
|
||||
return True
|
||||
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list or "*" in allow_list:
|
||||
return False
|
||||
|
||||
sender_str = str(sender_id)
|
||||
if sender_str.count("|") != 1:
|
||||
return False
|
||||
|
||||
sid, username = sender_str.split("|", 1)
|
||||
if not sid.isdigit() or not username:
|
||||
return False
|
||||
|
||||
return sid in allow_list or username in allow_list
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Telegram bot with long polling."""
|
||||
if not self.config.token:
|
||||
logger.error("Telegram bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
|
||||
# Build the application with larger connection pool to avoid pool-timeout on long runs
|
||||
req = HTTPXRequest(
|
||||
connection_pool_size=16,
|
||||
pool_timeout=5.0,
|
||||
connect_timeout=30.0,
|
||||
read_timeout=30.0,
|
||||
proxy=self.config.proxy if self.config.proxy else None,
|
||||
)
|
||||
builder = Application.builder().token(self.config.token).request(req).get_updates_request(req)
|
||||
self._app = builder.build()
|
||||
self._app.add_error_handler(self._on_error)
|
||||
|
||||
# Add command handlers
|
||||
self._app.add_handler(CommandHandler("start", self._on_start))
|
||||
self._app.add_handler(CommandHandler("new", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("stop", self._forward_command))
|
||||
self._app.add_handler(CommandHandler("help", self._on_help))
|
||||
|
||||
# Add message handler for text, photos, voice, documents
|
||||
self._app.add_handler(
|
||||
MessageHandler(
|
||||
(filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
|
||||
& ~filters.COMMAND,
|
||||
self._on_message
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Starting Telegram bot (polling mode)...")
|
||||
|
||||
# Initialize and start polling
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
|
||||
# Get bot info and register command menu
|
||||
bot_info = await self._app.bot.get_me()
|
||||
logger.info("Telegram bot @{} connected", bot_info.username)
|
||||
|
||||
try:
|
||||
await self._app.bot.set_my_commands(self.BOT_COMMANDS)
|
||||
logger.debug("Telegram bot commands registered")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to register bot commands: {}", e)
|
||||
|
||||
# Start polling (this runs until stopped)
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=["message"],
|
||||
drop_pending_updates=True # Ignore old messages on startup
|
||||
)
|
||||
|
||||
# Keep running until stopped
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Telegram bot."""
|
||||
self._running = False
|
||||
|
||||
# Cancel all typing indicators
|
||||
for chat_id in list(self._typing_tasks):
|
||||
self._stop_typing(chat_id)
|
||||
|
||||
for task in self._media_group_tasks.values():
|
||||
task.cancel()
|
||||
self._media_group_tasks.clear()
|
||||
self._media_group_buffers.clear()
|
||||
|
||||
if self._app:
|
||||
logger.info("Stopping Telegram bot...")
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
self._app = None
|
||||
|
||||
@staticmethod
|
||||
def _get_media_type(path: str) -> str:
|
||||
"""Guess media type from file extension."""
|
||||
ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
|
||||
if ext in ("jpg", "jpeg", "png", "gif", "webp"):
|
||||
return "photo"
|
||||
if ext == "ogg":
|
||||
return "voice"
|
||||
if ext in ("mp3", "m4a", "wav", "aac"):
|
||||
return "audio"
|
||||
return "document"
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Telegram."""
|
||||
if not self._app:
|
||||
logger.warning("Telegram bot not running")
|
||||
return
|
||||
|
||||
# Only stop typing indicator for final responses
|
||||
if not msg.metadata.get("_progress", False):
|
||||
self._stop_typing(msg.chat_id)
|
||||
|
||||
try:
|
||||
chat_id = int(msg.chat_id)
|
||||
except ValueError:
|
||||
logger.error("Invalid chat_id: {}", msg.chat_id)
|
||||
return
|
||||
reply_to_message_id = msg.metadata.get("message_id")
|
||||
message_thread_id = msg.metadata.get("message_thread_id")
|
||||
if message_thread_id is None and reply_to_message_id is not None:
|
||||
message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
|
||||
thread_kwargs = {}
|
||||
if message_thread_id is not None:
|
||||
thread_kwargs["message_thread_id"] = message_thread_id
|
||||
|
||||
reply_params = None
|
||||
if self.config.reply_to_message:
|
||||
if reply_to_message_id:
|
||||
reply_params = ReplyParameters(
|
||||
message_id=reply_to_message_id,
|
||||
allow_sending_without_reply=True
|
||||
)
|
||||
|
||||
# Send media files
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
media_type = self._get_media_type(media_path)
|
||||
sender = {
|
||||
"photo": self._app.bot.send_photo,
|
||||
"voice": self._app.bot.send_voice,
|
||||
"audio": self._app.bot.send_audio,
|
||||
}.get(media_type, self._app.bot.send_document)
|
||||
param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
|
||||
with open(media_path, 'rb') as f:
|
||||
await sender(
|
||||
chat_id=chat_id,
|
||||
**{param: f},
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
filename = media_path.rsplit("/", 1)[-1]
|
||||
logger.error("Failed to send media {}: {}", media_path, e)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=f"[Failed to send: {filename}]",
|
||||
reply_parameters=reply_params,
|
||||
**thread_kwargs,
|
||||
)
|
||||
|
||||
# Send text content
|
||||
if msg.content and msg.content != "[empty message]":
|
||||
is_progress = msg.metadata.get("_progress", False)
|
||||
|
||||
for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
|
||||
# Final response: simulate streaming via draft, then persist
|
||||
if not is_progress:
|
||||
await self._send_with_streaming(chat_id, chunk, reply_params, thread_kwargs)
|
||||
else:
|
||||
await self._send_text(chat_id, chunk, reply_params, thread_kwargs)
|
||||
|
||||
async def _send_text(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Send a plain text message with HTML fallback."""
|
||||
try:
|
||||
html = _markdown_to_telegram_html(text)
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id, text=html, parse_mode="HTML",
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._app.bot.send_message(
|
||||
chat_id=chat_id,
|
||||
text=text,
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error("Error sending Telegram message: {}", e2)
|
||||
|
||||
async def _send_with_streaming(
|
||||
self,
|
||||
chat_id: int,
|
||||
text: str,
|
||||
reply_params=None,
|
||||
thread_kwargs: dict | None = None,
|
||||
) -> None:
|
||||
"""Simulate streaming via send_message_draft, then persist with send_message."""
|
||||
draft_id = int(time.time() * 1000) % (2**31)
|
||||
try:
|
||||
step = max(len(text) // 8, 40)
|
||||
for i in range(step, len(text), step):
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text[:i],
|
||||
)
|
||||
await asyncio.sleep(0.04)
|
||||
await self._app.bot.send_message_draft(
|
||||
chat_id=chat_id, draft_id=draft_id, text=text,
|
||||
)
|
||||
await asyncio.sleep(0.15)
|
||||
except Exception:
|
||||
pass
|
||||
await self._send_text(chat_id, text, reply_params, thread_kwargs)
|
||||
|
||||
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /start command."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
user = update.effective_user
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||
"Send me a message and I'll respond!\n"
|
||||
"Type /help to see available commands."
|
||||
)
|
||||
|
||||
async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle /help command, bypassing ACL so all users can access it."""
|
||||
if not update.message:
|
||||
return
|
||||
await update.message.reply_text(
|
||||
"🐈 nanobot commands:\n"
|
||||
"/new — Start a new conversation\n"
|
||||
"/stop — Stop the current task\n"
|
||||
"/help — Show available commands"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sender_id(user) -> str:
|
||||
"""Build sender_id with username for allowlist matching."""
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
@staticmethod
|
||||
def _derive_topic_session_key(message) -> str | None:
|
||||
"""Derive topic-scoped session key for non-private Telegram chats."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message.chat.type == "private" or message_thread_id is None:
|
||||
return None
|
||||
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
|
||||
|
||||
@staticmethod
|
||||
def _build_message_metadata(message, user) -> dict:
|
||||
"""Build common Telegram inbound metadata payload."""
|
||||
return {
|
||||
"message_id": message.message_id,
|
||||
"user_id": user.id,
|
||||
"username": user.username,
|
||||
"first_name": user.first_name,
|
||||
"is_group": message.chat.type != "private",
|
||||
"message_thread_id": getattr(message, "message_thread_id", None),
|
||||
"is_forum": bool(getattr(message.chat, "is_forum", False)),
|
||||
}
|
||||
|
||||
def _remember_thread_context(self, message) -> None:
|
||||
"""Cache topic thread id by chat/message id for follow-up replies."""
|
||||
message_thread_id = getattr(message, "message_thread_id", None)
|
||||
if message_thread_id is None:
|
||||
return
|
||||
key = (str(message.chat_id), message.message_id)
|
||||
self._message_threads[key] = message_thread_id
|
||||
if len(self._message_threads) > 1000:
|
||||
self._message_threads.pop(next(iter(self._message_threads)))
|
||||
|
||||
async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Forward slash commands to the bus for unified handling in AgentLoop."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
self._remember_thread_context(message)
|
||||
await self._handle_message(
|
||||
sender_id=self._sender_id(user),
|
||||
chat_id=str(message.chat_id),
|
||||
content=message.text,
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming messages (text, photos, voice, documents)."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
|
||||
message = update.message
|
||||
user = update.effective_user
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
self._remember_thread_context(message)
|
||||
|
||||
# Store chat_id for replies
|
||||
self._chat_ids[sender_id] = chat_id
|
||||
|
||||
# Build content from text and/or media
|
||||
content_parts = []
|
||||
media_paths = []
|
||||
|
||||
# Text content
|
||||
if message.text:
|
||||
content_parts.append(message.text)
|
||||
if message.caption:
|
||||
content_parts.append(message.caption)
|
||||
|
||||
# Handle media files
|
||||
media_file = None
|
||||
media_type = None
|
||||
|
||||
if message.photo:
|
||||
media_file = message.photo[-1] # Largest photo
|
||||
media_type = "image"
|
||||
elif message.voice:
|
||||
media_file = message.voice
|
||||
media_type = "voice"
|
||||
elif message.audio:
|
||||
media_file = message.audio
|
||||
media_type = "audio"
|
||||
elif message.document:
|
||||
media_file = message.document
|
||||
media_type = "file"
|
||||
|
||||
# Download media if present
|
||||
if media_file and self._app:
|
||||
try:
|
||||
file = await self._app.bot.get_file(media_file.file_id)
|
||||
ext = self._get_extension(
|
||||
media_type,
|
||||
getattr(media_file, 'mime_type', None),
|
||||
getattr(media_file, 'file_name', None),
|
||||
)
|
||||
media_dir = get_media_dir("telegram")
|
||||
|
||||
file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
|
||||
await file.download_to_drive(str(file_path))
|
||||
|
||||
media_paths.append(str(file_path))
|
||||
|
||||
# Handle voice transcription
|
||||
if media_type == "voice" or media_type == "audio":
|
||||
from nanobot.providers.transcription import GroqTranscriptionProvider
|
||||
transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
|
||||
transcription = await transcriber.transcribe(file_path)
|
||||
if transcription:
|
||||
logger.info("Transcribed {}: {}...", media_type, transcription[:50])
|
||||
content_parts.append(f"[transcription: {transcription}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
else:
|
||||
content_parts.append(f"[{media_type}: {file_path}]")
|
||||
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to download media: {}", e)
|
||||
content_parts.append(f"[{media_type}: download failed]")
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else "[empty message]"
|
||||
|
||||
logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
|
||||
|
||||
str_chat_id = str(chat_id)
|
||||
metadata = self._build_message_metadata(message, user)
|
||||
session_key = self._derive_topic_session_key(message)
|
||||
|
||||
# Telegram media groups: buffer briefly, forward as one aggregated turn.
|
||||
if media_group_id := getattr(message, "media_group_id", None):
|
||||
key = f"{str_chat_id}:{media_group_id}"
|
||||
if key not in self._media_group_buffers:
|
||||
self._media_group_buffers[key] = {
|
||||
"sender_id": sender_id, "chat_id": str_chat_id,
|
||||
"contents": [], "media": [],
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
}
|
||||
self._start_typing(str_chat_id)
|
||||
buf = self._media_group_buffers[key]
|
||||
if content and content != "[empty message]":
|
||||
buf["contents"].append(content)
|
||||
buf["media"].extend(media_paths)
|
||||
if key not in self._media_group_tasks:
|
||||
self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
|
||||
return
|
||||
|
||||
# Start typing indicator before processing
|
||||
self._start_typing(str_chat_id)
|
||||
|
||||
# Forward to the message bus
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str_chat_id,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
async def _flush_media_group(self, key: str) -> None:
|
||||
"""Wait briefly, then forward buffered media-group as one turn."""
|
||||
try:
|
||||
await asyncio.sleep(0.6)
|
||||
if not (buf := self._media_group_buffers.pop(key, None)):
|
||||
return
|
||||
content = "\n".join(buf["contents"]) or "[empty message]"
|
||||
await self._handle_message(
|
||||
sender_id=buf["sender_id"], chat_id=buf["chat_id"],
|
||||
content=content, media=list(dict.fromkeys(buf["media"])),
|
||||
metadata=buf["metadata"],
|
||||
session_key=buf.get("session_key"),
|
||||
)
|
||||
finally:
|
||||
self._media_group_tasks.pop(key, None)
|
||||
|
||||
def _start_typing(self, chat_id: str) -> None:
|
||||
"""Start sending 'typing...' indicator for a chat."""
|
||||
# Cancel any existing typing task for this chat
|
||||
self._stop_typing(chat_id)
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
|
||||
|
||||
def _stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
async def _typing_loop(self, chat_id: str) -> None:
|
||||
"""Repeatedly send 'typing' action until cancelled."""
|
||||
try:
|
||||
while self._app:
|
||||
await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
|
||||
await asyncio.sleep(4)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
|
||||
|
||||
async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Log polling / handler errors instead of silently swallowing them."""
|
||||
logger.error("Telegram error: {}", context.error)
|
||||
|
||||
def _get_extension(
|
||||
self,
|
||||
media_type: str,
|
||||
mime_type: str | None,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get file extension based on media type or original filename."""
|
||||
if mime_type:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
|
||||
"audio/ogg": ".ogg", "audio/mpeg": ".mp3", "audio/mp4": ".m4a",
|
||||
}
|
||||
if mime_type in ext_map:
|
||||
return ext_map[mime_type]
|
||||
|
||||
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
|
||||
if ext := type_map.get(media_type, ""):
|
||||
return ext
|
||||
|
||||
if filename:
|
||||
from pathlib import Path
|
||||
|
||||
return "".join(Path(filename).suffixes)
|
||||
|
||||
return ""
|
||||
@@ -0,0 +1,170 @@
|
||||
"""WhatsApp channel implementation using Node.js bridge."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
from collections import OrderedDict
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import WhatsAppConfig
|
||||
|
||||
|
||||
class WhatsAppChannel(BaseChannel):
|
||||
"""
|
||||
WhatsApp channel that connects to a Node.js bridge.
|
||||
|
||||
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
|
||||
Communication between Python and Node.js is via WebSocket.
|
||||
"""
|
||||
|
||||
name = "whatsapp"
|
||||
|
||||
def __init__(self, config: WhatsAppConfig, bus: MessageBus):
|
||||
super().__init__(config, bus)
|
||||
self.config: WhatsAppConfig = config
|
||||
self._ws = None
|
||||
self._connected = False
|
||||
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the WhatsApp channel by connecting to the bridge."""
|
||||
import websockets
|
||||
|
||||
bridge_url = self.config.bridge_url
|
||||
|
||||
logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
|
||||
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with websockets.connect(bridge_url) as ws:
|
||||
self._ws = ws
|
||||
# Send auth token if configured
|
||||
if self.config.bridge_token:
|
||||
await ws.send(json.dumps({"type": "auth", "token": self.config.bridge_token}))
|
||||
self._connected = True
|
||||
logger.info("Connected to WhatsApp bridge")
|
||||
|
||||
# Listen for messages
|
||||
async for message in ws:
|
||||
try:
|
||||
await self._handle_bridge_message(message)
|
||||
except Exception as e:
|
||||
logger.error("Error handling bridge message: {}", e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self._connected = False
|
||||
self._ws = None
|
||||
logger.warning("WhatsApp bridge connection error: {}", e)
|
||||
|
||||
if self._running:
|
||||
logger.info("Reconnecting in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the WhatsApp channel."""
|
||||
self._running = False
|
||||
self._connected = False
|
||||
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WhatsApp."""
|
||||
if not self._ws or not self._connected:
|
||||
logger.warning("WhatsApp bridge not connected")
|
||||
return
|
||||
|
||||
try:
|
||||
payload = {
|
||||
"type": "send",
|
||||
"to": msg.chat_id,
|
||||
"text": msg.content
|
||||
}
|
||||
await self._ws.send(json.dumps(payload, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error("Error sending WhatsApp message: {}", e)
|
||||
|
||||
async def _handle_bridge_message(self, raw: str) -> None:
|
||||
"""Handle a message from the bridge."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from bridge: {}", raw[:100])
|
||||
return
|
||||
|
||||
msg_type = data.get("type")
|
||||
|
||||
if msg_type == "message":
|
||||
# Incoming message from WhatsApp
|
||||
# Deprecated by whatsapp: old phone number style typically: <phone>@s.whatspp.net
|
||||
pn = data.get("pn", "")
|
||||
# New LID sytle typically:
|
||||
sender = data.get("sender", "")
|
||||
content = data.get("content", "")
|
||||
message_id = data.get("id", "")
|
||||
|
||||
if message_id:
|
||||
if message_id in self._processed_message_ids:
|
||||
return
|
||||
self._processed_message_ids[message_id] = None
|
||||
while len(self._processed_message_ids) > 1000:
|
||||
self._processed_message_ids.popitem(last=False)
|
||||
|
||||
# Extract just the phone number or lid as chat_id
|
||||
user_id = pn if pn else sender
|
||||
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
|
||||
logger.info("Sender {}", sender)
|
||||
|
||||
# Handle voice transcription if it's a voice message
|
||||
if content == "[Voice Message]":
|
||||
logger.info("Voice message received from {}, but direct download from bridge is not yet supported.", sender_id)
|
||||
content = "[Voice Message: Transcription not available for WhatsApp yet]"
|
||||
|
||||
# Extract media paths (images/documents/videos downloaded by the bridge)
|
||||
media_paths = data.get("media") or []
|
||||
|
||||
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
|
||||
if media_paths:
|
||||
for p in media_paths:
|
||||
mime, _ = mimetypes.guess_type(p)
|
||||
media_type = "image" if mime and mime.startswith("image/") else "file"
|
||||
media_tag = f"[{media_type}: {p}]"
|
||||
content = f"{content}\n{media_tag}" if content else media_tag
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=sender, # Use full LID for replies
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"is_group": data.get("isGroup", False)
|
||||
}
|
||||
)
|
||||
|
||||
elif msg_type == "status":
|
||||
# Connection status update
|
||||
status = data.get("status")
|
||||
logger.info("WhatsApp status: {}", status)
|
||||
|
||||
if status == "connected":
|
||||
self._connected = True
|
||||
elif status == "disconnected":
|
||||
self._connected = False
|
||||
|
||||
elif msg_type == "qr":
|
||||
# QR code for authentication
|
||||
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
|
||||
|
||||
elif msg_type == "error":
|
||||
logger.error("WhatsApp bridge error: {}", data.get('error'))
|
||||
@@ -0,0 +1 @@
|
||||
"""CLI module for nanobot."""
|
||||
@@ -0,0 +1,975 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Force UTF-8 encoding for Windows console
|
||||
if sys.platform == "win32":
|
||||
if sys.stdout.encoding != "utf-8":
|
||||
os.environ["PYTHONIOENCODING"] = "utf-8"
|
||||
# Re-open stdout/stderr with UTF-8 encoding
|
||||
try:
|
||||
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
||||
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import typer
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
from nanobot import __logo__, __version__
|
||||
from nanobot.config.paths import get_workspace_path
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
|
||||
app = typer.Typer(
|
||||
name="nanobot",
|
||||
help=f"{__logo__} nanobot - Personal AI Assistant",
|
||||
no_args_is_help=True,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
EXIT_COMMANDS = {"exit", "quit", "/exit", "/quit", ":q"}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI input: prompt_toolkit for editing, paste, history, and display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PROMPT_SESSION: PromptSession | None = None
|
||||
_SAVED_TERM_ATTRS = None # original termios settings, restored on exit
|
||||
|
||||
|
||||
def _flush_pending_tty_input() -> None:
|
||||
"""Drop unread keypresses typed while the model was generating output."""
|
||||
try:
|
||||
fd = sys.stdin.fileno()
|
||||
if not os.isatty(fd):
|
||||
return
|
||||
except Exception:
|
||||
return
|
||||
|
||||
try:
|
||||
import termios
|
||||
termios.tcflush(fd, termios.TCIFLUSH)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
while True:
|
||||
ready, _, _ = select.select([fd], [], [], 0)
|
||||
if not ready:
|
||||
break
|
||||
if not os.read(fd, 4096):
|
||||
break
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
def _restore_terminal() -> None:
|
||||
"""Restore terminal to its original state (echo, line buffering, etc.)."""
|
||||
if _SAVED_TERM_ATTRS is None:
|
||||
return
|
||||
try:
|
||||
import termios
|
||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _init_prompt_session() -> None:
|
||||
"""Create the prompt_toolkit session with persistent file history."""
|
||||
global _PROMPT_SESSION, _SAVED_TERM_ATTRS
|
||||
|
||||
# Save terminal state so we can restore it on exit
|
||||
try:
|
||||
import termios
|
||||
_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from nanobot.config.paths import get_cli_history_path
|
||||
|
||||
history_file = get_cli_history_path()
|
||||
history_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_PROMPT_SESSION = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
enable_open_in_editor=False,
|
||||
multiline=False, # Enter submits (single line mode)
|
||||
)
|
||||
|
||||
|
||||
def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||
"""Render assistant response with consistent terminal styling."""
|
||||
content = response or ""
|
||||
body = Markdown(content) if render_markdown else Text(content)
|
||||
console.print()
|
||||
console.print(f"[cyan]{__logo__} nanobot[/cyan]")
|
||||
console.print(body)
|
||||
console.print()
|
||||
|
||||
|
||||
def _is_exit_command(command: str) -> bool:
|
||||
"""Return True when input should end interactive chat."""
|
||||
return command.lower() in EXIT_COMMANDS
|
||||
|
||||
|
||||
async def _read_interactive_input_async() -> str:
|
||||
"""Read user input using prompt_toolkit (handles paste, history, display).
|
||||
|
||||
prompt_toolkit natively handles:
|
||||
- Multiline paste (bracketed paste mode)
|
||||
- History navigation (up/down arrows)
|
||||
- Clean display (no ghost characters or artifacts)
|
||||
"""
|
||||
if _PROMPT_SESSION is None:
|
||||
raise RuntimeError("Call _init_prompt_session() first")
|
||||
try:
|
||||
with patch_stdout():
|
||||
return await _PROMPT_SESSION.prompt_async(
|
||||
HTML("<b fg='ansiblue'>You:</b> "),
|
||||
)
|
||||
except EOFError as exc:
|
||||
raise KeyboardInterrupt from exc
|
||||
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
console.print(f"{__logo__} nanobot v{__version__}")
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
version: bool = typer.Option(
|
||||
None, "--version", "-v", callback=version_callback, is_eager=True
|
||||
),
|
||||
):
|
||||
"""nanobot - Personal AI Assistant."""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Onboard / Setup
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def onboard():
|
||||
"""Initialize nanobot configuration and workspace."""
|
||||
from nanobot.config.loader import get_config_path, load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
config_path = get_config_path()
|
||||
|
||||
if config_path.exists():
|
||||
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
||||
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
||||
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
|
||||
if typer.confirm("Overwrite?"):
|
||||
config = Config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
|
||||
else:
|
||||
config = load_config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
|
||||
else:
|
||||
save_config(Config())
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
|
||||
# Create workspace
|
||||
workspace = get_workspace_path()
|
||||
|
||||
if not workspace.exists():
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||
|
||||
sync_workspace_templates(workspace)
|
||||
|
||||
console.print(f"\n{__logo__} nanobot is ready!")
|
||||
console.print("\nNext steps:")
|
||||
console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
|
||||
console.print(" Get one at: https://openrouter.ai/keys")
|
||||
console.print(" 2. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]")
|
||||
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config."""
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
|
||||
# OpenAI Codex (OAuth)
|
||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||
return OpenAICodexProvider(default_model=model)
|
||||
|
||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
if provider_name == "custom":
|
||||
return CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||
if provider_name == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return LiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
"""Load config and optionally override the active workspace."""
|
||||
from nanobot.config.loader import load_config, set_config_path
|
||||
|
||||
config_path = None
|
||||
if config:
|
||||
config_path = Path(config).expanduser().resolve()
|
||||
if not config_path.exists():
|
||||
console.print(f"[red]Error: Config file not found: {config_path}[/red]")
|
||||
raise typer.Exit(1)
|
||||
set_config_path(config_path)
|
||||
console.print(f"[dim]Using config: {config_path}[/dim]")
|
||||
|
||||
loaded = load_config(config_path)
|
||||
if workspace:
|
||||
loaded.agents.defaults.workspace = workspace
|
||||
return loaded
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def gateway(
|
||||
port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the nanobot gateway."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
|
||||
console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
|
||||
# Create cron service first (callback set after agent creation)
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
# Create agent with cron service
|
||||
agent = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
async def on_cron_job(job: CronJob) -> str | None:
|
||||
"""Execute a cron job through the agent."""
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
reminder_note = (
|
||||
"[Scheduled Task] Timer finished.\n\n"
|
||||
f"Task '{job.name}' has been triggered.\n"
|
||||
f"Scheduled instruction: {job.payload.message}"
|
||||
)
|
||||
|
||||
# Prevent the agent from scheduling new cron jobs during execution
|
||||
cron_tool = agent.tools.get("cron")
|
||||
cron_token = None
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_token = cron_tool.set_cron_context(True)
|
||||
try:
|
||||
response = await agent.process_direct(
|
||||
reminder_note,
|
||||
session_key=f"cron:{job.id}",
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to or "direct",
|
||||
)
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
|
||||
message_tool = agent.tools.get("message")
|
||||
if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
return response
|
||||
|
||||
if job.payload.deliver and job.payload.to and response:
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response
|
||||
))
|
||||
return response
|
||||
cron.on_job = on_cron_job
|
||||
|
||||
# Create channel manager
|
||||
channels = ChannelManager(config, bus)
|
||||
|
||||
def _pick_heartbeat_target() -> tuple[str, str]:
|
||||
"""Pick a routable channel/chat target for heartbeat-triggered messages."""
|
||||
enabled = set(channels.enabled_channels)
|
||||
# Prefer the most recently updated non-internal session on an enabled channel.
|
||||
for item in session_manager.list_sessions():
|
||||
key = item.get("key") or ""
|
||||
if ":" not in key:
|
||||
continue
|
||||
channel, chat_id = key.split(":", 1)
|
||||
if channel in {"cli", "system"}:
|
||||
continue
|
||||
if channel in enabled and chat_id:
|
||||
return channel, chat_id
|
||||
# Fallback keeps prior behavior but remains explicit.
|
||||
return "cli", "direct"
|
||||
|
||||
# Create heartbeat service
|
||||
async def on_heartbeat_execute(tasks: str) -> str:
|
||||
"""Phase 2: execute heartbeat tasks through the full agent loop."""
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
|
||||
async def _silent(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
return await agent.process_direct(
|
||||
tasks,
|
||||
session_key="heartbeat",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
)
|
||||
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
if channel == "cli":
|
||||
return # No external channel available to deliver to
|
||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
workspace=config.workspace_path,
|
||||
provider=provider,
|
||||
model=agent.model,
|
||||
on_execute=on_heartbeat_execute,
|
||||
on_notify=on_heartbeat_notify,
|
||||
interval_s=hb_cfg.interval_s,
|
||||
enabled=hb_cfg.enabled,
|
||||
)
|
||||
|
||||
if channels.enabled_channels:
|
||||
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
|
||||
else:
|
||||
console.print("[yellow]Warning: No channels enabled[/yellow]")
|
||||
|
||||
cron_status = cron.status()
|
||||
if cron_status["jobs"] > 0:
|
||||
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
|
||||
|
||||
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||
|
||||
async def run():
|
||||
try:
|
||||
await cron.start()
|
||||
await heartbeat.start()
|
||||
await asyncio.gather(
|
||||
agent.run(),
|
||||
channels.start_all(),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\nShutting down...")
|
||||
finally:
|
||||
await agent.close_mcp()
|
||||
heartbeat.stop()
|
||||
cron.stop()
|
||||
agent.stop()
|
||||
await channels.stop_all()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def agent(
|
||||
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
|
||||
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
|
||||
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
|
||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||
):
|
||||
"""Interact with the agent directly."""
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.paths import get_cron_dir
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
|
||||
# Create cron service for tool usage (no callback needed for CLI unless running)
|
||||
cron_store_path = get_cron_dir() / "jobs.json"
|
||||
cron = CronService(cron_store_path)
|
||||
|
||||
if logs:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
cron_service=cron,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
)
|
||||
|
||||
# Show spinner when logs are off (no output to miss); skip when logs are on
|
||||
def _thinking_ctx():
|
||||
if logs:
|
||||
from contextlib import nullcontext
|
||||
return nullcontext()
|
||||
# Animated spinner is safe to use with prompt_toolkit input handling
|
||||
return console.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
|
||||
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
ch = agent_loop.channels_config
|
||||
if ch and tool_hint and not ch.send_tool_hints:
|
||||
return
|
||||
if ch and not tool_hint and not ch.send_progress:
|
||||
return
|
||||
console.print(f" [dim]↳ {content}[/dim]")
|
||||
|
||||
if message:
|
||||
# Single message mode — direct call, no bus needed
|
||||
async def run_once():
|
||||
with _thinking_ctx():
|
||||
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_once())
|
||||
else:
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
else:
|
||||
cli_channel, cli_chat_id = "cli", session_id
|
||||
|
||||
def _handle_signal(signum, frame):
|
||||
sig_name = signal.Signals(signum).name
|
||||
_restore_terminal()
|
||||
console.print(f"\nReceived {sig_name}, goodbye!")
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, _handle_signal)
|
||||
signal.signal(signal.SIGTERM, _handle_signal)
|
||||
# SIGHUP is not available on Windows
|
||||
if hasattr(signal, 'SIGHUP'):
|
||||
signal.signal(signal.SIGHUP, _handle_signal)
|
||||
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
|
||||
# SIGPIPE is not available on Windows
|
||||
if hasattr(signal, 'SIGPIPE'):
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
|
||||
async def run_interactive():
|
||||
bus_task = asyncio.create_task(agent_loop.run())
|
||||
turn_done = asyncio.Event()
|
||||
turn_done.set()
|
||||
turn_response: list[str] = []
|
||||
|
||||
async def _consume_outbound():
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
if msg.metadata.get("_progress"):
|
||||
is_tool_hint = msg.metadata.get("_tool_hint", False)
|
||||
ch = agent_loop.channels_config
|
||||
if ch and is_tool_hint and not ch.send_tool_hints:
|
||||
pass
|
||||
elif ch and not is_tool_hint and not ch.send_progress:
|
||||
pass
|
||||
else:
|
||||
console.print(f" [dim]↳ {msg.content}[/dim]")
|
||||
elif not turn_done.is_set():
|
||||
if msg.content:
|
||||
turn_response.append(msg.content)
|
||||
turn_done.set()
|
||||
elif msg.content:
|
||||
console.print()
|
||||
_print_agent_response(msg.content, render_markdown=markdown)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
outbound_task = asyncio.create_task(_consume_outbound())
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
_flush_pending_tty_input()
|
||||
user_input = await _read_interactive_input_async()
|
||||
command = user_input.strip()
|
||||
if not command:
|
||||
continue
|
||||
|
||||
if _is_exit_command(command):
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
|
||||
turn_done.clear()
|
||||
turn_response.clear()
|
||||
|
||||
await bus.publish_inbound(InboundMessage(
|
||||
channel=cli_channel,
|
||||
sender_id="user",
|
||||
chat_id=cli_chat_id,
|
||||
content=user_input,
|
||||
))
|
||||
|
||||
with _thinking_ctx():
|
||||
await turn_done.wait()
|
||||
|
||||
if turn_response:
|
||||
_print_agent_response(turn_response[0], render_markdown=markdown)
|
||||
except KeyboardInterrupt:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
_restore_terminal()
|
||||
console.print("\nGoodbye!")
|
||||
break
|
||||
finally:
|
||||
agent_loop.stop()
|
||||
outbound_task.cancel()
|
||||
await asyncio.gather(bus_task, outbound_task, return_exceptions=True)
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
asyncio.run(run_interactive())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Channel Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
channels_app = typer.Typer(help="Manage channels")
|
||||
app.add_typer(channels_app, name="channels")
|
||||
|
||||
|
||||
@channels_app.command("status")
|
||||
def channels_status():
|
||||
"""Show channel status."""
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
config = load_config()
|
||||
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
table.add_column("Enabled", style="green")
|
||||
table.add_column("Configuration", style="yellow")
|
||||
|
||||
# WhatsApp
|
||||
wa = config.channels.whatsapp
|
||||
table.add_row(
|
||||
"WhatsApp",
|
||||
"✓" if wa.enabled else "✗",
|
||||
wa.bridge_url
|
||||
)
|
||||
|
||||
dc = config.channels.discord
|
||||
table.add_row(
|
||||
"Discord",
|
||||
"✓" if dc.enabled else "✗",
|
||||
dc.gateway_url
|
||||
)
|
||||
|
||||
# Feishu
|
||||
fs = config.channels.feishu
|
||||
fs_config = f"app_id: {fs.app_id[:10]}..." if fs.app_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Feishu",
|
||||
"✓" if fs.enabled else "✗",
|
||||
fs_config
|
||||
)
|
||||
|
||||
# Mochat
|
||||
mc = config.channels.mochat
|
||||
mc_base = mc.base_url or "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Mochat",
|
||||
"✓" if mc.enabled else "✗",
|
||||
mc_base
|
||||
)
|
||||
|
||||
# Telegram
|
||||
tg = config.channels.telegram
|
||||
tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Telegram",
|
||||
"✓" if tg.enabled else "✗",
|
||||
tg_config
|
||||
)
|
||||
|
||||
# Slack
|
||||
slack = config.channels.slack
|
||||
slack_config = "socket" if slack.app_token and slack.bot_token else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Slack",
|
||||
"✓" if slack.enabled else "✗",
|
||||
slack_config
|
||||
)
|
||||
|
||||
# DingTalk
|
||||
dt = config.channels.dingtalk
|
||||
dt_config = f"client_id: {dt.client_id[:10]}..." if dt.client_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"DingTalk",
|
||||
"✓" if dt.enabled else "✗",
|
||||
dt_config
|
||||
)
|
||||
|
||||
# QQ
|
||||
qq = config.channels.qq
|
||||
qq_config = f"app_id: {qq.app_id[:10]}..." if qq.app_id else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"QQ",
|
||||
"✓" if qq.enabled else "✗",
|
||||
qq_config
|
||||
)
|
||||
|
||||
# Email
|
||||
em = config.channels.email
|
||||
em_config = em.imap_host if em.imap_host else "[dim]not configured[/dim]"
|
||||
table.add_row(
|
||||
"Email",
|
||||
"✓" if em.enabled else "✗",
|
||||
em_config
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def _get_bridge_dir() -> Path:
|
||||
"""Get the bridge directory, setting it up if needed."""
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
# User's bridge location
|
||||
from nanobot.config.paths import get_bridge_install_dir
|
||||
|
||||
user_bridge = get_bridge_install_dir()
|
||||
|
||||
# Check if already built
|
||||
if (user_bridge / "dist" / "index.js").exists():
|
||||
return user_bridge
|
||||
|
||||
# Check for npm
|
||||
if not shutil.which("npm"):
|
||||
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Find source bridge: first check package data, then source dir
|
||||
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
|
||||
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
|
||||
|
||||
source = None
|
||||
if (pkg_bridge / "package.json").exists():
|
||||
source = pkg_bridge
|
||||
elif (src_bridge / "package.json").exists():
|
||||
source = src_bridge
|
||||
|
||||
if not source:
|
||||
console.print("[red]Bridge source not found.[/red]")
|
||||
console.print("Try reinstalling: pip install --force-reinstall nanobot")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"{__logo__} Setting up bridge...")
|
||||
|
||||
# Copy to user directory
|
||||
user_bridge.parent.mkdir(parents=True, exist_ok=True)
|
||||
if user_bridge.exists():
|
||||
shutil.rmtree(user_bridge)
|
||||
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
|
||||
|
||||
# Install and build
|
||||
try:
|
||||
console.print(" Installing dependencies...")
|
||||
subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
|
||||
|
||||
console.print(" Building...")
|
||||
subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
|
||||
|
||||
console.print("[green]✓[/green] Bridge ready\n")
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]Build failed: {e}[/red]")
|
||||
if e.stderr:
|
||||
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
return user_bridge
|
||||
|
||||
|
||||
@channels_app.command("login")
|
||||
def channels_login():
|
||||
"""Link device via QR code."""
|
||||
import subprocess
|
||||
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.paths import get_runtime_subdir
|
||||
|
||||
config = load_config()
|
||||
bridge_dir = _get_bridge_dir()
|
||||
|
||||
console.print(f"{__logo__} Starting bridge...")
|
||||
console.print("Scan the QR code to connect.\n")
|
||||
|
||||
env = {**os.environ}
|
||||
if config.channels.whatsapp.bridge_token:
|
||||
env["BRIDGE_TOKEN"] = config.channels.whatsapp.bridge_token
|
||||
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
||||
|
||||
try:
|
||||
subprocess.run(["npm", "start"], cwd=bridge_dir, check=True, env=env)
|
||||
except subprocess.CalledProcessError as e:
|
||||
console.print(f"[red]Bridge failed: {e}[/red]")
|
||||
except FileNotFoundError:
|
||||
console.print("[red]npm not found. Please install Node.js.[/red]")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Status Commands
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def status():
|
||||
"""Show nanobot status."""
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
|
||||
config_path = get_config_path()
|
||||
config = load_config()
|
||||
workspace = config.workspace_path
|
||||
|
||||
console.print(f"{__logo__} nanobot Status\n")
|
||||
|
||||
console.print(f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}")
|
||||
console.print(f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}")
|
||||
|
||||
if config_path.exists():
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
console.print(f"Model: {config.agents.defaults.model}")
|
||||
|
||||
# Check API keys from registry
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(config.providers, spec.name, None)
|
||||
if p is None:
|
||||
continue
|
||||
if spec.is_oauth:
|
||||
console.print(f"{spec.label}: [green]✓ (OAuth)[/green]")
|
||||
elif spec.is_local:
|
||||
# Local deployments show api_base instead of api_key
|
||||
if p.api_base:
|
||||
console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]")
|
||||
else:
|
||||
console.print(f"{spec.label}: [dim]not set[/dim]")
|
||||
else:
|
||||
has_key = bool(p.api_key)
|
||||
console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth Login
|
||||
# ============================================================================
|
||||
|
||||
provider_app = typer.Typer(help="Manage providers")
|
||||
app.add_typer(provider_app, name="provider")
|
||||
|
||||
|
||||
_LOGIN_HANDLERS: dict[str, callable] = {}
|
||||
|
||||
|
||||
def _register_login(name: str):
|
||||
def decorator(fn):
|
||||
_LOGIN_HANDLERS[name] = fn
|
||||
return fn
|
||||
return decorator
|
||||
|
||||
|
||||
@provider_app.command("login")
|
||||
def provider_login(
|
||||
provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"),
|
||||
):
|
||||
"""Authenticate with an OAuth provider."""
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
key = provider.replace("-", "_")
|
||||
spec = next((s for s in PROVIDERS if s.name == key and s.is_oauth), None)
|
||||
if not spec:
|
||||
names = ", ".join(s.name.replace("_", "-") for s in PROVIDERS if s.is_oauth)
|
||||
console.print(f"[red]Unknown OAuth provider: {provider}[/red] Supported: {names}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
handler = _LOGIN_HANDLERS.get(spec.name)
|
||||
if not handler:
|
||||
console.print(f"[red]Login not implemented for {spec.label}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"{__logo__} OAuth Login - {spec.label}\n")
|
||||
handler()
|
||||
|
||||
|
||||
@_register_login("openai_codex")
|
||||
def _login_openai_codex() -> None:
|
||||
try:
|
||||
from oauth_cli_kit import get_token, login_oauth_interactive
|
||||
token = None
|
||||
try:
|
||||
token = get_token()
|
||||
except Exception:
|
||||
pass
|
||||
if not (token and token.access):
|
||||
console.print("[cyan]Starting interactive OAuth login...[/cyan]\n")
|
||||
token = login_oauth_interactive(
|
||||
print_fn=lambda s: console.print(s),
|
||||
prompt_fn=lambda s: typer.prompt(s),
|
||||
)
|
||||
if not (token and token.access):
|
||||
console.print("[red]✗ Authentication failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]")
|
||||
except ImportError:
|
||||
console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
@_register_login("github_copilot")
|
||||
def _login_github_copilot() -> None:
|
||||
import asyncio
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
|
||||
async def _trigger():
|
||||
from litellm import acompletion
|
||||
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Authentication error: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Configuration module for nanobot."""
|
||||
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.config.paths import (
|
||||
get_bridge_install_dir,
|
||||
get_cli_history_path,
|
||||
get_cron_dir,
|
||||
get_data_dir,
|
||||
get_legacy_sessions_dir,
|
||||
get_logs_dir,
|
||||
get_media_dir,
|
||||
get_runtime_subdir,
|
||||
get_workspace_path,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
__all__ = [
|
||||
"Config",
|
||||
"load_config",
|
||||
"get_config_path",
|
||||
"get_data_dir",
|
||||
"get_runtime_subdir",
|
||||
"get_media_dir",
|
||||
"get_cron_dir",
|
||||
"get_logs_dir",
|
||||
"get_workspace_path",
|
||||
"get_cli_history_path",
|
||||
"get_bridge_install_dir",
|
||||
"get_legacy_sessions_dir",
|
||||
]
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Configuration loading utilities."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
|
||||
# Global variable to store current config path (for multi-instance support)
|
||||
_current_config_path: Path | None = None
|
||||
|
||||
|
||||
def set_config_path(path: Path) -> None:
|
||||
"""Set the current config path (used to derive data directory)."""
|
||||
global _current_config_path
|
||||
_current_config_path = path
|
||||
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Get the configuration file path."""
|
||||
if _current_config_path:
|
||||
return _current_config_path
|
||||
return Path.home() / ".nanobot" / "config.json"
|
||||
|
||||
|
||||
def load_config(config_path: Path | None = None) -> Config:
|
||||
"""
|
||||
Load configuration from file or create default.
|
||||
|
||||
Args:
|
||||
config_path: Optional path to config file. Uses default if not provided.
|
||||
|
||||
Returns:
|
||||
Loaded configuration object.
|
||||
"""
|
||||
path = config_path or get_config_path()
|
||||
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
data = _migrate_config(data)
|
||||
return Config.model_validate(data)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Warning: Failed to load config from {path}: {e}")
|
||||
print("Using default configuration.")
|
||||
|
||||
return Config()
|
||||
|
||||
|
||||
def save_config(config: Config, config_path: Path | None = None) -> None:
|
||||
"""
|
||||
Save configuration to file.
|
||||
|
||||
Args:
|
||||
config: Configuration to save.
|
||||
config_path: Optional path to save to. Uses default if not provided.
|
||||
"""
|
||||
path = config_path or get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = config.model_dump(by_alias=True)
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def _migrate_config(data: dict) -> dict:
|
||||
"""Migrate old config formats to current."""
|
||||
# Move tools.exec.restrictToWorkspace → tools.restrictToWorkspace
|
||||
tools = data.get("tools", {})
|
||||
exec_cfg = tools.get("exec", {})
|
||||
if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
|
||||
tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
|
||||
return data
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Runtime path helpers derived from the active config context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.loader import get_config_path
|
||||
from nanobot.utils.helpers import ensure_dir
|
||||
|
||||
|
||||
def get_data_dir() -> Path:
|
||||
"""Return the instance-level runtime data directory."""
|
||||
return ensure_dir(get_config_path().parent)
|
||||
|
||||
|
||||
def get_runtime_subdir(name: str) -> Path:
|
||||
"""Return a named runtime subdirectory under the instance data dir."""
|
||||
return ensure_dir(get_data_dir() / name)
|
||||
|
||||
|
||||
def get_media_dir(channel: str | None = None) -> Path:
|
||||
"""Return the media directory, optionally namespaced per channel."""
|
||||
base = get_runtime_subdir("media")
|
||||
return ensure_dir(base / channel) if channel else base
|
||||
|
||||
|
||||
def get_cron_dir() -> Path:
|
||||
"""Return the cron storage directory."""
|
||||
return get_runtime_subdir("cron")
|
||||
|
||||
|
||||
def get_logs_dir() -> Path:
|
||||
"""Return the logs directory."""
|
||||
return get_runtime_subdir("logs")
|
||||
|
||||
|
||||
def get_workspace_path(workspace: str | None = None) -> Path:
|
||||
"""Resolve and ensure the agent workspace path."""
|
||||
path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
|
||||
return ensure_dir(path)
|
||||
|
||||
|
||||
def get_cli_history_path() -> Path:
|
||||
"""Return the shared CLI history file path."""
|
||||
return Path.home() / ".nanobot" / "history" / "cli_history"
|
||||
|
||||
|
||||
def get_bridge_install_dir() -> Path:
|
||||
"""Return the shared WhatsApp bridge installation directory."""
|
||||
return Path.home() / ".nanobot" / "bridge"
|
||||
|
||||
|
||||
def get_legacy_sessions_dir() -> Path:
|
||||
"""Return the legacy global session directory used for migration fallback."""
|
||||
return Path.home() / ".nanobot" / "sessions"
|
||||
@@ -0,0 +1,421 @@
|
||||
"""Configuration schema using Pydantic."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic.alias_generators import to_camel
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Base(BaseModel):
|
||||
"""Base model that accepts both camelCase and snake_case keys."""
|
||||
|
||||
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
|
||||
|
||||
|
||||
class WhatsAppConfig(Base):
|
||||
"""WhatsApp channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
bridge_url: str = "ws://localhost:3001"
|
||||
bridge_token: str = "" # Shared token for bridge auth (optional, recommended)
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
|
||||
|
||||
|
||||
class TelegramConfig(Base):
|
||||
"""Telegram channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
token: str = "" # Bot token from @BotFather
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
|
||||
proxy: str | None = (
|
||||
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||
)
|
||||
reply_to_message: bool = False # If true, bot replies quote the original message
|
||||
|
||||
|
||||
class FeishuConfig(Base):
|
||||
"""Feishu/Lark channel configuration using WebSocket long connection."""
|
||||
|
||||
enabled: bool = False
|
||||
app_id: str = "" # App ID from Feishu Open Platform
|
||||
app_secret: str = "" # App Secret from Feishu Open Platform
|
||||
encrypt_key: str = "" # Encrypt Key for event subscription (optional)
|
||||
verification_token: str = "" # Verification Token for event subscription (optional)
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user open_ids
|
||||
react_emoji: str = (
|
||||
"THUMBSUP" # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
|
||||
)
|
||||
|
||||
|
||||
class DingTalkConfig(Base):
|
||||
"""DingTalk channel configuration using Stream mode."""
|
||||
|
||||
enabled: bool = False
|
||||
client_id: str = "" # AppKey
|
||||
client_secret: str = "" # AppSecret
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed staff_ids
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
"""Discord channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
token: str = "" # Bot token from Discord Developer Portal
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed user IDs
|
||||
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
|
||||
intents: int = 37377 # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
|
||||
group_policy: Literal["mention", "open"] = "mention"
|
||||
|
||||
|
||||
class MatrixConfig(Base):
|
||||
"""Matrix (Element) channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
homeserver: str = "https://matrix.org"
|
||||
access_token: str = ""
|
||||
user_id: str = "" # @bot:matrix.org
|
||||
device_id: str = ""
|
||||
e2ee_enabled: bool = True # Enable Matrix E2EE support (encryption + encrypted room handling).
|
||||
sync_stop_grace_seconds: int = (
|
||||
2 # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
|
||||
)
|
||||
max_media_bytes: int = (
|
||||
20 * 1024 * 1024
|
||||
) # Max attachment size accepted for Matrix media handling (inbound + outbound).
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
|
||||
|
||||
class EmailConfig(Base):
|
||||
"""Email channel configuration (IMAP inbound + SMTP outbound)."""
|
||||
|
||||
enabled: bool = False
|
||||
consent_granted: bool = False # Explicit owner permission to access mailbox data
|
||||
|
||||
# IMAP (receive)
|
||||
imap_host: str = ""
|
||||
imap_port: int = 993
|
||||
imap_username: str = ""
|
||||
imap_password: str = ""
|
||||
imap_mailbox: str = "INBOX"
|
||||
imap_use_ssl: bool = True
|
||||
|
||||
# SMTP (send)
|
||||
smtp_host: str = ""
|
||||
smtp_port: int = 587
|
||||
smtp_username: str = ""
|
||||
smtp_password: str = ""
|
||||
smtp_use_tls: bool = True
|
||||
smtp_use_ssl: bool = False
|
||||
from_address: str = ""
|
||||
|
||||
# Behavior
|
||||
auto_reply_enabled: bool = (
|
||||
True # If false, inbound email is read but no automatic reply is sent
|
||||
)
|
||||
poll_interval_seconds: int = 30
|
||||
mark_seen: bool = True
|
||||
max_body_chars: int = 12000
|
||||
subject_prefix: str = "Re: "
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed sender email addresses
|
||||
|
||||
|
||||
class MochatMentionConfig(Base):
|
||||
"""Mochat mention behavior configuration."""
|
||||
|
||||
require_in_groups: bool = False
|
||||
|
||||
|
||||
class MochatGroupRule(Base):
|
||||
"""Mochat per-group mention requirement."""
|
||||
|
||||
require_mention: bool = False
|
||||
|
||||
|
||||
class MochatConfig(Base):
|
||||
"""Mochat channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
base_url: str = "https://mochat.io"
|
||||
socket_url: str = ""
|
||||
socket_path: str = "/socket.io"
|
||||
socket_disable_msgpack: bool = False
|
||||
socket_reconnect_delay_ms: int = 1000
|
||||
socket_max_reconnect_delay_ms: int = 10000
|
||||
socket_connect_timeout_ms: int = 10000
|
||||
refresh_interval_ms: int = 30000
|
||||
watch_timeout_ms: int = 25000
|
||||
watch_limit: int = 100
|
||||
retry_delay_ms: int = 500
|
||||
max_retry_attempts: int = 0 # 0 means unlimited retries
|
||||
claw_token: str = ""
|
||||
agent_user_id: str = ""
|
||||
sessions: list[str] = Field(default_factory=list)
|
||||
panels: list[str] = Field(default_factory=list)
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
|
||||
groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
|
||||
reply_delay_mode: str = "non-mention" # off | non-mention
|
||||
reply_delay_ms: int = 120000
|
||||
|
||||
|
||||
class SlackDMConfig(Base):
|
||||
"""Slack DM policy configuration."""
|
||||
|
||||
enabled: bool = True
|
||||
policy: str = "open" # "open" or "allowlist"
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs
|
||||
|
||||
|
||||
class SlackConfig(Base):
|
||||
"""Slack channel configuration."""
|
||||
|
||||
enabled: bool = False
|
||||
mode: str = "socket" # "socket" supported
|
||||
webhook_path: str = "/slack/events"
|
||||
bot_token: str = "" # xoxb-...
|
||||
app_token: str = "" # xapp-...
|
||||
user_token_read_only: bool = True
|
||||
reply_in_thread: bool = True
|
||||
react_emoji: str = "eyes"
|
||||
allow_from: list[str] = Field(default_factory=list) # Allowed Slack user IDs (sender-level)
|
||||
group_policy: str = "mention" # "mention", "open", "allowlist"
|
||||
group_allow_from: list[str] = Field(default_factory=list) # Allowed channel IDs if allowlist
|
||||
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
||||
|
||||
|
||||
class QQConfig(Base):
|
||||
"""QQ channel configuration using botpy SDK."""
|
||||
|
||||
enabled: bool = False
|
||||
app_id: str = "" # 机器人 ID (AppID) from q.qq.com
|
||||
secret: str = "" # 机器人密钥 (AppSecret) from q.qq.com
|
||||
allow_from: list[str] = Field(
|
||||
default_factory=list
|
||||
) # Allowed user openids (empty = public access)
|
||||
|
||||
|
||||
|
||||
|
||||
class ChannelsConfig(Base):
|
||||
"""Configuration for chat channels."""
|
||||
|
||||
send_progress: bool = True # stream agent's text progress to the channel
|
||||
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
|
||||
whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
|
||||
telegram: TelegramConfig = Field(default_factory=TelegramConfig)
|
||||
discord: DiscordConfig = Field(default_factory=DiscordConfig)
|
||||
feishu: FeishuConfig = Field(default_factory=FeishuConfig)
|
||||
mochat: MochatConfig = Field(default_factory=MochatConfig)
|
||||
dingtalk: DingTalkConfig = Field(default_factory=DingTalkConfig)
|
||||
email: EmailConfig = Field(default_factory=EmailConfig)
|
||||
slack: SlackConfig = Field(default_factory=SlackConfig)
|
||||
qq: QQConfig = Field(default_factory=QQConfig)
|
||||
matrix: MatrixConfig = Field(default_factory=MatrixConfig)
|
||||
|
||||
|
||||
class AgentDefaults(Base):
|
||||
"""Default agent configuration."""
|
||||
|
||||
workspace: str = "~/.nanobot/workspace"
|
||||
model: str = "anthropic/claude-opus-4-5"
|
||||
provider: str = (
|
||||
"auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
memory_window: int = 100
|
||||
reasoning_effort: str | None = None # low / medium / high — enables LLM thinking mode
|
||||
|
||||
|
||||
class AgentsConfig(Base):
|
||||
"""Agent configuration."""
|
||||
|
||||
defaults: AgentDefaults = Field(default_factory=AgentDefaults)
|
||||
|
||||
|
||||
class ProviderConfig(Base):
|
||||
"""LLM provider configuration."""
|
||||
|
||||
api_key: str = ""
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||
|
||||
|
||||
class ProvidersConfig(Base):
|
||||
"""Configuration for LLM providers."""
|
||||
|
||||
custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
|
||||
azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
|
||||
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openai: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
groq: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
dashscope: ProviderConfig = Field(default_factory=ProviderConfig) # 阿里云通义千问
|
||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||
openai_codex: ProviderConfig = Field(default_factory=ProviderConfig) # OpenAI Codex (OAuth)
|
||||
github_copilot: ProviderConfig = Field(default_factory=ProviderConfig) # Github Copilot (OAuth)
|
||||
|
||||
|
||||
class HeartbeatConfig(Base):
|
||||
"""Heartbeat service configuration."""
|
||||
|
||||
enabled: bool = True
|
||||
interval_s: int = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 18790
|
||||
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||
|
||||
|
||||
class WebSearchConfig(Base):
|
||||
"""Web search tool configuration."""
|
||||
|
||||
api_key: str = "" # Brave Search API key
|
||||
max_results: int = 5
|
||||
|
||||
|
||||
class WebToolsConfig(Base):
|
||||
"""Web tools configuration."""
|
||||
|
||||
proxy: str | None = (
|
||||
None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
|
||||
)
|
||||
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
||||
|
||||
|
||||
class ExecToolConfig(Base):
|
||||
"""Shell exec tool configuration."""
|
||||
|
||||
timeout: int = 60
|
||||
path_append: str = ""
|
||||
|
||||
|
||||
class MCPServerConfig(Base):
|
||||
"""MCP server connection configuration (stdio or HTTP)."""
|
||||
|
||||
type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
|
||||
command: str = "" # Stdio: command to run (e.g. "npx")
|
||||
args: list[str] = Field(default_factory=list) # Stdio: command arguments
|
||||
env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
|
||||
url: str = "" # HTTP/SSE: endpoint URL
|
||||
headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
|
||||
tool_timeout: int = 30 # seconds before a tool call is cancelled
|
||||
|
||||
|
||||
class ToolsConfig(Base):
|
||||
"""Tools configuration."""
|
||||
|
||||
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
||||
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
|
||||
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
"""Root configuration for nanobot."""
|
||||
|
||||
agents: AgentsConfig = Field(default_factory=AgentsConfig)
|
||||
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
|
||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
|
||||
@property
|
||||
def workspace_path(self) -> Path:
|
||||
"""Get expanded workspace path."""
|
||||
return Path(self.agents.defaults.workspace).expanduser()
|
||||
|
||||
def _match_provider(
|
||||
self, model: str | None = None
|
||||
) -> tuple["ProviderConfig | None", str | None]:
|
||||
"""Match provider config and its registry name. Returns (config, spec_name)."""
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
forced = self.agents.defaults.provider
|
||||
if forced != "auto":
|
||||
p = getattr(self.providers, forced, None)
|
||||
return (p, forced) if p else (None, None)
|
||||
|
||||
model_lower = (model or self.agents.defaults.model).lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
|
||||
def _kw_matches(kw: str) -> bool:
|
||||
kw = kw.lower()
|
||||
return kw in model_lower or kw.replace("-", "_") in model_normalized
|
||||
|
||||
# Explicit provider prefix wins — prevents `github-copilot/...codex` matching openai_codex.
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and model_prefix and normalized_prefix == spec.name:
|
||||
if spec.is_oauth or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# Match by keyword (order follows PROVIDERS registry)
|
||||
for spec in PROVIDERS:
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and any(_kw_matches(kw) for kw in spec.keywords):
|
||||
if spec.is_oauth or p.api_key:
|
||||
return p, spec.name
|
||||
|
||||
# Fallback: gateways first, then others (follows registry order)
|
||||
# OAuth providers are NOT valid fallbacks — they require explicit model selection
|
||||
for spec in PROVIDERS:
|
||||
if spec.is_oauth:
|
||||
continue
|
||||
p = getattr(self.providers, spec.name, None)
|
||||
if p and p.api_key:
|
||||
return p, spec.name
|
||||
return None, None
|
||||
|
||||
def get_provider(self, model: str | None = None) -> ProviderConfig | None:
|
||||
"""Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
|
||||
p, _ = self._match_provider(model)
|
||||
return p
|
||||
|
||||
def get_provider_name(self, model: str | None = None) -> str | None:
|
||||
"""Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
|
||||
_, name = self._match_provider(model)
|
||||
return name
|
||||
|
||||
def get_api_key(self, model: str | None = None) -> str | None:
|
||||
"""Get API key for the given model. Falls back to first available key."""
|
||||
p = self.get_provider(model)
|
||||
return p.api_key if p else None
|
||||
|
||||
def get_api_base(self, model: str | None = None) -> str | None:
|
||||
"""Get API base URL for the given model. Applies default URLs for known gateways."""
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
p, name = self._match_provider(model)
|
||||
if p and p.api_base:
|
||||
return p.api_base
|
||||
# Only gateways get a default api_base here. Standard providers
|
||||
# (like Moonshot) set their base URL via env vars in _setup_env
|
||||
# to avoid polluting the global litellm.api_base.
|
||||
if name:
|
||||
spec = find_by_name(name)
|
||||
if spec and spec.is_gateway and spec.default_api_base:
|
||||
return spec.default_api_base
|
||||
return None
|
||||
|
||||
model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Cron service for scheduled agent tasks."""
|
||||
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronSchedule
|
||||
|
||||
__all__ = ["CronService", "CronJob", "CronSchedule"]
|
||||
@@ -0,0 +1,376 @@
|
||||
"""Cron service for scheduling agent tasks."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule, CronStore
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
|
||||
"""Compute next run time in ms."""
|
||||
if schedule.kind == "at":
|
||||
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
|
||||
|
||||
if schedule.kind == "every":
|
||||
if not schedule.every_ms or schedule.every_ms <= 0:
|
||||
return None
|
||||
# Next interval from now
|
||||
return now_ms + schedule.every_ms
|
||||
|
||||
if schedule.kind == "cron" and schedule.expr:
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
from croniter import croniter
|
||||
# Use caller-provided reference time for deterministic scheduling
|
||||
base_time = now_ms / 1000
|
||||
tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
|
||||
base_dt = datetime.fromtimestamp(base_time, tz=tz)
|
||||
cron = croniter(schedule.expr, base_dt)
|
||||
next_dt = cron.get_next(datetime)
|
||||
return int(next_dt.timestamp() * 1000)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_schedule_for_add(schedule: CronSchedule) -> None:
|
||||
"""Validate schedule fields that would otherwise create non-runnable jobs."""
|
||||
if schedule.tz and schedule.kind != "cron":
|
||||
raise ValueError("tz can only be used with cron schedules")
|
||||
|
||||
if schedule.kind == "cron" and schedule.tz:
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
ZoneInfo(schedule.tz)
|
||||
except Exception:
|
||||
raise ValueError(f"unknown timezone '{schedule.tz}'") from None
|
||||
|
||||
|
||||
class CronService:
|
||||
"""Service for managing and executing scheduled jobs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
store_path: Path,
|
||||
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
|
||||
):
|
||||
self.store_path = store_path
|
||||
self.on_job = on_job
|
||||
self._store: CronStore | None = None
|
||||
self._last_mtime: float = 0.0
|
||||
self._timer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""Load jobs from disk. Reloads automatically if file was modified externally."""
|
||||
if self._store and self.store_path.exists():
|
||||
mtime = self.store_path.stat().st_mtime
|
||||
if mtime != self._last_mtime:
|
||||
logger.info("Cron: jobs.json modified externally, reloading")
|
||||
self._store = None
|
||||
if self._store:
|
||||
return self._store
|
||||
|
||||
if self.store_path.exists():
|
||||
try:
|
||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||
jobs = []
|
||||
for j in data.get("jobs", []):
|
||||
jobs.append(CronJob(
|
||||
id=j["id"],
|
||||
name=j["name"],
|
||||
enabled=j.get("enabled", True),
|
||||
schedule=CronSchedule(
|
||||
kind=j["schedule"]["kind"],
|
||||
at_ms=j["schedule"].get("atMs"),
|
||||
every_ms=j["schedule"].get("everyMs"),
|
||||
expr=j["schedule"].get("expr"),
|
||||
tz=j["schedule"].get("tz"),
|
||||
),
|
||||
payload=CronPayload(
|
||||
kind=j["payload"].get("kind", "agent_turn"),
|
||||
message=j["payload"].get("message", ""),
|
||||
deliver=j["payload"].get("deliver", False),
|
||||
channel=j["payload"].get("channel"),
|
||||
to=j["payload"].get("to"),
|
||||
),
|
||||
state=CronJobState(
|
||||
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
||||
last_run_at_ms=j.get("state", {}).get("lastRunAtMs"),
|
||||
last_status=j.get("state", {}).get("lastStatus"),
|
||||
last_error=j.get("state", {}).get("lastError"),
|
||||
),
|
||||
created_at_ms=j.get("createdAtMs", 0),
|
||||
updated_at_ms=j.get("updatedAtMs", 0),
|
||||
delete_after_run=j.get("deleteAfterRun", False),
|
||||
))
|
||||
self._store = CronStore(jobs=jobs)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load cron store: {}", e)
|
||||
self._store = CronStore()
|
||||
else:
|
||||
self._store = CronStore()
|
||||
|
||||
return self._store
|
||||
|
||||
def _save_store(self) -> None:
|
||||
"""Save jobs to disk."""
|
||||
if not self._store:
|
||||
return
|
||||
|
||||
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = {
|
||||
"version": self._store.version,
|
||||
"jobs": [
|
||||
{
|
||||
"id": j.id,
|
||||
"name": j.name,
|
||||
"enabled": j.enabled,
|
||||
"schedule": {
|
||||
"kind": j.schedule.kind,
|
||||
"atMs": j.schedule.at_ms,
|
||||
"everyMs": j.schedule.every_ms,
|
||||
"expr": j.schedule.expr,
|
||||
"tz": j.schedule.tz,
|
||||
},
|
||||
"payload": {
|
||||
"kind": j.payload.kind,
|
||||
"message": j.payload.message,
|
||||
"deliver": j.payload.deliver,
|
||||
"channel": j.payload.channel,
|
||||
"to": j.payload.to,
|
||||
},
|
||||
"state": {
|
||||
"nextRunAtMs": j.state.next_run_at_ms,
|
||||
"lastRunAtMs": j.state.last_run_at_ms,
|
||||
"lastStatus": j.state.last_status,
|
||||
"lastError": j.state.last_error,
|
||||
},
|
||||
"createdAtMs": j.created_at_ms,
|
||||
"updatedAtMs": j.updated_at_ms,
|
||||
"deleteAfterRun": j.delete_after_run,
|
||||
}
|
||||
for j in self._store.jobs
|
||||
]
|
||||
}
|
||||
|
||||
self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
|
||||
self._last_mtime = self.store_path.stat().st_mtime
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the cron service."""
|
||||
self._running = True
|
||||
self._load_store()
|
||||
self._recompute_next_runs()
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the cron service."""
|
||||
self._running = False
|
||||
if self._timer_task:
|
||||
self._timer_task.cancel()
|
||||
self._timer_task = None
|
||||
|
||||
def _recompute_next_runs(self) -> None:
|
||||
"""Recompute next run times for all enabled jobs."""
|
||||
if not self._store:
|
||||
return
|
||||
now = _now_ms()
|
||||
for job in self._store.jobs:
|
||||
if job.enabled:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
|
||||
|
||||
def _get_next_wake_ms(self) -> int | None:
|
||||
"""Get the earliest next run time across all jobs."""
|
||||
if not self._store:
|
||||
return None
|
||||
times = [j.state.next_run_at_ms for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms]
|
||||
return min(times) if times else None
|
||||
|
||||
def _arm_timer(self) -> None:
|
||||
"""Schedule the next timer tick."""
|
||||
if self._timer_task:
|
||||
self._timer_task.cancel()
|
||||
|
||||
next_wake = self._get_next_wake_ms()
|
||||
if not next_wake or not self._running:
|
||||
return
|
||||
|
||||
delay_ms = max(0, next_wake - _now_ms())
|
||||
delay_s = delay_ms / 1000
|
||||
|
||||
async def tick():
|
||||
await asyncio.sleep(delay_s)
|
||||
if self._running:
|
||||
await self._on_timer()
|
||||
|
||||
self._timer_task = asyncio.create_task(tick())
|
||||
|
||||
async def _on_timer(self) -> None:
|
||||
"""Handle timer tick - run due jobs."""
|
||||
self._load_store()
|
||||
if not self._store:
|
||||
return
|
||||
|
||||
now = _now_ms()
|
||||
due_jobs = [
|
||||
j for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||
]
|
||||
|
||||
for job in due_jobs:
|
||||
await self._execute_job(job)
|
||||
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
|
||||
async def _execute_job(self, job: CronJob) -> None:
|
||||
"""Execute a single job."""
|
||||
start_ms = _now_ms()
|
||||
logger.info("Cron: executing job '{}' ({})", job.name, job.id)
|
||||
|
||||
try:
|
||||
response = None
|
||||
if self.on_job:
|
||||
response = await self.on_job(job)
|
||||
|
||||
job.state.last_status = "ok"
|
||||
job.state.last_error = None
|
||||
logger.info("Cron: job '{}' completed", job.name)
|
||||
|
||||
except Exception as e:
|
||||
job.state.last_status = "error"
|
||||
job.state.last_error = str(e)
|
||||
logger.error("Cron: job '{}' failed: {}", job.name, e)
|
||||
|
||||
job.state.last_run_at_ms = start_ms
|
||||
job.updated_at_ms = _now_ms()
|
||||
|
||||
# Handle one-shot jobs
|
||||
if job.schedule.kind == "at":
|
||||
if job.delete_after_run:
|
||||
self._store.jobs = [j for j in self._store.jobs if j.id != job.id]
|
||||
else:
|
||||
job.enabled = False
|
||||
job.state.next_run_at_ms = None
|
||||
else:
|
||||
# Compute next run
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
|
||||
# ========== Public API ==========
|
||||
|
||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||
"""List all jobs."""
|
||||
store = self._load_store()
|
||||
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
|
||||
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
|
||||
|
||||
def add_job(
|
||||
self,
|
||||
name: str,
|
||||
schedule: CronSchedule,
|
||||
message: str,
|
||||
deliver: bool = False,
|
||||
channel: str | None = None,
|
||||
to: str | None = None,
|
||||
delete_after_run: bool = False,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
store = self._load_store()
|
||||
_validate_schedule_for_add(schedule)
|
||||
now = _now_ms()
|
||||
|
||||
job = CronJob(
|
||||
id=str(uuid.uuid4())[:8],
|
||||
name=name,
|
||||
enabled=True,
|
||||
schedule=schedule,
|
||||
payload=CronPayload(
|
||||
kind="agent_turn",
|
||||
message=message,
|
||||
deliver=deliver,
|
||||
channel=channel,
|
||||
to=to,
|
||||
),
|
||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||
created_at_ms=now,
|
||||
updated_at_ms=now,
|
||||
delete_after_run=delete_after_run,
|
||||
)
|
||||
|
||||
store.jobs.append(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
|
||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||
return job
|
||||
|
||||
def remove_job(self, job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
store = self._load_store()
|
||||
before = len(store.jobs)
|
||||
store.jobs = [j for j in store.jobs if j.id != job_id]
|
||||
removed = len(store.jobs) < before
|
||||
|
||||
if removed:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
logger.info("Cron: removed job {}", job_id)
|
||||
|
||||
return removed
|
||||
|
||||
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
|
||||
"""Enable or disable a job."""
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
job.enabled = enabled
|
||||
job.updated_at_ms = _now_ms()
|
||||
if enabled:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
else:
|
||||
job.state.next_run_at_ms = None
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
return job
|
||||
return None
|
||||
|
||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||
"""Manually run a job."""
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
if not force and not job.enabled:
|
||||
return False
|
||||
await self._execute_job(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
return True
|
||||
return False
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Get service status."""
|
||||
store = self._load_store()
|
||||
return {
|
||||
"enabled": self._running,
|
||||
"jobs": len(store.jobs),
|
||||
"next_wake_at_ms": self._get_next_wake_ms(),
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
"""Cron types."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronSchedule:
|
||||
"""Schedule definition for a cron job."""
|
||||
kind: Literal["at", "every", "cron"]
|
||||
# For "at": timestamp in ms
|
||||
at_ms: int | None = None
|
||||
# For "every": interval in ms
|
||||
every_ms: int | None = None
|
||||
# For "cron": cron expression (e.g. "0 9 * * *")
|
||||
expr: str | None = None
|
||||
# Timezone for cron expressions
|
||||
tz: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronPayload:
|
||||
"""What to do when the job runs."""
|
||||
kind: Literal["system_event", "agent_turn"] = "agent_turn"
|
||||
message: str = ""
|
||||
# Deliver response to channel
|
||||
deliver: bool = False
|
||||
channel: str | None = None # e.g. "whatsapp"
|
||||
to: str | None = None # e.g. phone number
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronJobState:
|
||||
"""Runtime state of a job."""
|
||||
next_run_at_ms: int | None = None
|
||||
last_run_at_ms: int | None = None
|
||||
last_status: Literal["ok", "error", "skipped"] | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronJob:
|
||||
"""A scheduled job."""
|
||||
id: str
|
||||
name: str
|
||||
enabled: bool = True
|
||||
schedule: CronSchedule = field(default_factory=lambda: CronSchedule(kind="every"))
|
||||
payload: CronPayload = field(default_factory=CronPayload)
|
||||
state: CronJobState = field(default_factory=CronJobState)
|
||||
created_at_ms: int = 0
|
||||
updated_at_ms: int = 0
|
||||
delete_after_run: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronStore:
|
||||
"""Persistent store for cron jobs."""
|
||||
version: int = 1
|
||||
jobs: list[CronJob] = field(default_factory=list)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Heartbeat service for periodic agent wake-ups."""
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
|
||||
__all__ = ["HeartbeatService"]
|
||||
@@ -0,0 +1,173 @@
|
||||
"""Heartbeat service - periodic agent wake-up to check for tasks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
_HEARTBEAT_TOOL = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "heartbeat",
|
||||
"description": "Report heartbeat decision after reviewing tasks.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": ["skip", "run"],
|
||||
"description": "skip = nothing to do, run = has active tasks",
|
||||
},
|
||||
"tasks": {
|
||||
"type": "string",
|
||||
"description": "Natural-language summary of active tasks (required for run)",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class HeartbeatService:
|
||||
"""
|
||||
Periodic heartbeat service that wakes the agent to check for tasks.
|
||||
|
||||
Phase 1 (decision): reads HEARTBEAT.md and asks the LLM — via a virtual
|
||||
tool call — whether there are active tasks. This avoids free-text parsing
|
||||
and the unreliable HEARTBEAT_OK token.
|
||||
|
||||
Phase 2 (execution): only triggered when Phase 1 returns ``run``. The
|
||||
``on_execute`` callback runs the task through the full agent loop and
|
||||
returns the result to deliver.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||
interval_s: int = 30 * 60,
|
||||
enabled: bool = True,
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.on_execute = on_execute
|
||||
self.on_notify = on_notify
|
||||
self.interval_s = interval_s
|
||||
self.enabled = enabled
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
|
||||
@property
|
||||
def heartbeat_file(self) -> Path:
|
||||
return self.workspace / "HEARTBEAT.md"
|
||||
|
||||
def _read_heartbeat_file(self) -> str | None:
|
||||
if self.heartbeat_file.exists():
|
||||
try:
|
||||
return self.heartbeat_file.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _decide(self, content: str) -> tuple[str, str]:
|
||||
"""Phase 1: ask LLM to decide skip/run via virtual tool call.
|
||||
|
||||
Returns (action, tasks) where action is 'skip' or 'run'.
|
||||
"""
|
||||
response = await self.provider.chat(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||
{"role": "user", "content": (
|
||||
"Review the following HEARTBEAT.md and decide whether there are active tasks.\n\n"
|
||||
f"{content}"
|
||||
)},
|
||||
],
|
||||
tools=_HEARTBEAT_TOOL,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
if not response.has_tool_calls:
|
||||
return "skip", ""
|
||||
|
||||
args = response.tool_calls[0].arguments
|
||||
return args.get("action", "skip"), args.get("tasks", "")
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the heartbeat service."""
|
||||
if not self.enabled:
|
||||
logger.info("Heartbeat disabled")
|
||||
return
|
||||
if self._running:
|
||||
logger.warning("Heartbeat already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
logger.info("Heartbeat started (every {}s)", self.interval_s)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the heartbeat service."""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
"""Main heartbeat loop."""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.interval_s)
|
||||
if self._running:
|
||||
await self._tick()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error("Heartbeat error: {}", e)
|
||||
|
||||
async def _tick(self) -> None:
|
||||
"""Execute a single heartbeat tick."""
|
||||
content = self._read_heartbeat_file()
|
||||
if not content:
|
||||
logger.debug("Heartbeat: HEARTBEAT.md missing or empty")
|
||||
return
|
||||
|
||||
logger.info("Heartbeat: checking for tasks...")
|
||||
|
||||
try:
|
||||
action, tasks = await self._decide(content)
|
||||
|
||||
if action != "run":
|
||||
logger.info("Heartbeat: OK (nothing to report)")
|
||||
return
|
||||
|
||||
logger.info("Heartbeat: tasks found, executing...")
|
||||
if self.on_execute:
|
||||
response = await self.on_execute(tasks)
|
||||
if response and self.on_notify:
|
||||
logger.info("Heartbeat: completed, delivering response")
|
||||
await self.on_notify(response)
|
||||
except Exception:
|
||||
logger.exception("Heartbeat execution failed")
|
||||
|
||||
async def trigger_now(self) -> str | None:
|
||||
"""Manually trigger a heartbeat."""
|
||||
content = self._read_heartbeat_file()
|
||||
if not content:
|
||||
return None
|
||||
action, tasks = await self._decide(content)
|
||||
if action != "run" or not self.on_execute:
|
||||
return None
|
||||
return await self.on_execute(tasks)
|
||||
@@ -0,0 +1,8 @@
|
||||
"""LLM provider abstraction module."""
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Azure OpenAI provider implementation with API version 2024-10-21."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"})
|
||||
|
||||
|
||||
class AzureOpenAIProvider(LLMProvider):
|
||||
"""
|
||||
Azure OpenAI provider with API version 2024-10-21 compliance.
|
||||
|
||||
Features:
|
||||
- Hardcoded API version 2024-10-21
|
||||
- Uses model field as Azure deployment name in URL path
|
||||
- Uses api-key header instead of Authorization Bearer
|
||||
- Uses max_completion_tokens instead of max_tokens
|
||||
- Direct HTTP calls, bypasses LiteLLM
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "",
|
||||
api_base: str = "",
|
||||
default_model: str = "gpt-5.2-chat",
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.api_version = "2024-10-21"
|
||||
|
||||
# Validate required parameters
|
||||
if not api_key:
|
||||
raise ValueError("Azure OpenAI api_key is required")
|
||||
if not api_base:
|
||||
raise ValueError("Azure OpenAI api_base is required")
|
||||
|
||||
# Ensure api_base ends with /
|
||||
if not api_base.endswith('/'):
|
||||
api_base += '/'
|
||||
self.api_base = api_base
|
||||
|
||||
def _build_chat_url(self, deployment_name: str) -> str:
|
||||
"""Build the Azure OpenAI chat completions URL."""
|
||||
# Azure OpenAI URL format:
|
||||
# https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version}
|
||||
base_url = self.api_base
|
||||
if not base_url.endswith('/'):
|
||||
base_url += '/'
|
||||
|
||||
url = urljoin(
|
||||
base_url,
|
||||
f"openai/deployments/{deployment_name}/chat/completions"
|
||||
)
|
||||
return f"{url}?api-version={self.api_version}"
|
||||
|
||||
def _build_headers(self) -> dict[str, str]:
|
||||
"""Build headers for Azure OpenAI API with api-key header."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization
|
||||
"x-session-affinity": uuid.uuid4().hex, # For cache locality
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _supports_temperature(
|
||||
deployment_name: str,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> bool:
|
||||
"""Return True when temperature is likely supported for this deployment."""
|
||||
if reasoning_effort:
|
||||
return False
|
||||
name = deployment_name.lower()
|
||||
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
def _prepare_request_payload(
|
||||
self,
|
||||
deployment_name: str,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare the request payload with Azure OpenAI 2024-10-21 compliance."""
|
||||
payload: dict[str, Any] = {
|
||||
"messages": self._sanitize_request_messages(
|
||||
self._sanitize_empty_content(messages),
|
||||
_AZURE_MSG_KEYS,
|
||||
),
|
||||
"max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens
|
||||
}
|
||||
|
||||
if self._supports_temperature(deployment_name, reasoning_effort):
|
||||
payload["temperature"] = temperature
|
||||
|
||||
if reasoning_effort:
|
||||
payload["reasoning_effort"] = reasoning_effort
|
||||
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
payload["tool_choice"] = "auto"
|
||||
|
||||
return payload
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request to Azure OpenAI.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (used as deployment name).
|
||||
max_tokens: Maximum tokens in response (mapped to max_completion_tokens).
|
||||
temperature: Sampling temperature.
|
||||
reasoning_effort: Optional reasoning effort parameter.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
deployment_name = model or self.default_model
|
||||
url = self._build_chat_url(deployment_name)
|
||||
headers = self._build_headers()
|
||||
payload = self._prepare_request_payload(
|
||||
deployment_name, messages, tools, max_tokens, temperature, reasoning_effort
|
||||
)
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=True) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
return LLMResponse(
|
||||
content=f"Azure OpenAI API Error {response.status_code}: {response.text}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
return self._parse_response(response_data)
|
||||
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Azure OpenAI: {repr(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: dict[str, Any]) -> LLMResponse:
|
||||
"""Parse Azure OpenAI response into our standard format."""
|
||||
try:
|
||||
choice = response["choices"][0]
|
||||
message = choice["message"]
|
||||
|
||||
tool_calls = []
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc["function"]["arguments"]
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=tc["id"],
|
||||
name=tc["function"]["name"],
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
usage = {}
|
||||
if response.get("usage"):
|
||||
usage_data = response["usage"]
|
||||
usage = {
|
||||
"prompt_tokens": usage_data.get("prompt_tokens", 0),
|
||||
"completion_tokens": usage_data.get("completion_tokens", 0),
|
||||
"total_tokens": usage_data.get("total_tokens", 0),
|
||||
}
|
||||
|
||||
reasoning_content = message.get("reasoning_content") or None
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content"),
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=choice.get("finish_reason", "stop"),
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
|
||||
except (KeyError, IndexError) as e:
|
||||
return LLMResponse(
|
||||
content=f"Error parsing Azure OpenAI response: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model (also used as default deployment name)."""
|
||||
return self.default_model
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Base LLM provider interface."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
"""A tool call request from the LLM."""
|
||||
id: str
|
||||
name: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Response from an LLM provider."""
|
||||
content: str | None
|
||||
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
||||
finish_reason: str = "stop"
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""Check if response contains tool calls."""
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Replace empty text content that causes provider 400 errors.
|
||||
|
||||
Empty content can appear when MCP tools return nothing. Most providers
|
||||
reject empty-string content or empty text blocks in list content.
|
||||
"""
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
|
||||
if isinstance(content, str) and not content:
|
||||
clean = dict(msg)
|
||||
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, list):
|
||||
filtered = [
|
||||
item for item in content
|
||||
if not (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") in ("text", "input_text", "output_text")
|
||||
and not item.get("text")
|
||||
)
|
||||
]
|
||||
if len(filtered) != len(content):
|
||||
clean = dict(msg)
|
||||
if filtered:
|
||||
clean["content"] = filtered
|
||||
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
clean["content"] = None
|
||||
else:
|
||||
clean["content"] = "(empty)"
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
if isinstance(content, dict):
|
||||
clean = dict(msg)
|
||||
clean["content"] = [content]
|
||||
result.append(clean)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_request_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
allowed_keys: frozenset[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Keep only provider-safe message keys and normalize assistant content."""
|
||||
sanitized = []
|
||||
for msg in messages:
|
||||
clean = {k: v for k, v in msg.items() if k in allowed_keys}
|
||||
if clean.get("role") == "assistant" and "content" not in clean:
|
||||
clean["content"] = None
|
||||
sanitized.append(clean)
|
||||
return sanitized
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions.
|
||||
model: Model identifier (provider-specific).
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model for this provider."""
|
||||
pass
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
|
||||
|
||||
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
# Keep affinity stable for this provider instance to improve backend cache locality.
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||
)
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None) -> LLMResponse:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self._sanitize_empty_content(messages),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice="auto")
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error: {e}", finish_reason="error")
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
tool_calls = [
|
||||
ToolCallRequest(id=tc.id, name=tc.function.name,
|
||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments)
|
||||
for tc in (msg.tool_calls or [])
|
||||
]
|
||||
u = response.usage
|
||||
return LLMResponse(
|
||||
content=msg.content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "stop",
|
||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
@@ -0,0 +1,340 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.registry import find_by_model, find_gateway
|
||||
|
||||
# Standard chat-completion message keys.
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""
|
||||
LLM provider using LiteLLM for multi-provider support.
|
||||
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
|
||||
a unified interface. Provider-specific logic is driven by the registry
|
||||
(see providers/registry.py) — no if-elif chains needed here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Detect gateway / local deployment.
|
||||
# provider_name (from config key) is the primary signal;
|
||||
# api_key / api_base are fallback for auto-detection.
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
|
||||
# Configure environment variables
|
||||
if api_key:
|
||||
self._setup_env(api_key, api_base, default_model)
|
||||
|
||||
if api_base:
|
||||
litellm.api_base = api_base
|
||||
|
||||
# Disable LiteLLM logging noise
|
||||
litellm.suppress_debug_info = True
|
||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
||||
litellm.drop_params = True
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||
"""Set environment variables based on detected provider."""
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if not spec:
|
||||
return
|
||||
if not spec.env_key:
|
||||
# OAuth/provider-only specs (for example: openai_codex)
|
||||
return
|
||||
|
||||
# Gateway/local overrides existing env; standard provider doesn't
|
||||
if self._gateway:
|
||||
os.environ[spec.env_key] = api_key
|
||||
else:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""Resolve model name by applying provider/gateway prefixes."""
|
||||
if self._gateway:
|
||||
# Gateway mode: apply gateway prefix, skip provider-specific prefixes
|
||||
prefix = self._gateway.litellm_prefix
|
||||
if self._gateway.strip_model_prefix:
|
||||
model = model.split("/")[-1]
|
||||
if prefix and not model.startswith(f"{prefix}/"):
|
||||
model = f"{prefix}/{model}"
|
||||
return model
|
||||
|
||||
# Standard mode: auto-prefix for known providers
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
|
||||
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
"""Normalize explicit provider prefixes like `github-copilot/...`."""
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
|
||||
def _supports_cache_control(self, model: str) -> bool:
|
||||
"""Return True when the provider supports cache_control on content blocks."""
|
||||
if self._gateway is not None:
|
||||
return self._gateway.supports_prompt_caching
|
||||
spec = find_by_model(model)
|
||||
return spec is not None and spec.supports_prompt_caching
|
||||
|
||||
def _apply_cache_control(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
"""Return copies of messages and tools with cache_control injected."""
|
||||
new_messages = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
content = msg["content"]
|
||||
if isinstance(content, str):
|
||||
new_content = [{"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}]
|
||||
else:
|
||||
new_content = list(content)
|
||||
new_content[-1] = {**new_content[-1], "cache_control": {"type": "ephemeral"}}
|
||||
new_messages.append({**msg, "content": new_content})
|
||||
else:
|
||||
new_messages.append(msg)
|
||||
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": {"type": "ephemeral"}}
|
||||
|
||||
return new_messages, new_tools
|
||||
|
||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||
"""Apply model-specific parameter overrides from the registry."""
|
||||
model_lower = model.lower()
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
|
||||
"""Return provider-specific extra keys to preserve in request messages."""
|
||||
spec = find_by_model(original_model) or find_by_model(resolved_model)
|
||||
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
|
||||
return _ANTHROPIC_EXTRA_KEYS
|
||||
return frozenset()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
|
||||
if not isinstance(tool_call_id, str):
|
||||
return tool_call_id
|
||||
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||
return tool_call_id
|
||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||
id_map: dict[str, str] = {}
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||
|
||||
for clean in sanitized:
|
||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||
# shortening, otherwise strict providers reject the broken linkage.
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized_tool_calls = []
|
||||
for tc in clean["tool_calls"]:
|
||||
if not isinstance(tc, dict):
|
||||
normalized_tool_calls.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized_tool_calls.append(tc_clean)
|
||||
clean["tool_calls"] = normalized_tool_calls
|
||||
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Send a chat completion request via LiteLLM.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: Optional list of tool definitions in OpenAI format.
|
||||
model: Model identifier (e.g., 'anthropic/claude-sonnet-4-5').
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool calls.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
model = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, model)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
# Clamp max_tokens to at least 1 — negative or zero values cause
|
||||
# LiteLLM to reject the request with "max_tokens must be at least 1".
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
# Apply model-specific overrides (e.g. kimi-k2.5 temperature)
|
||||
self._apply_model_overrides(model, kwargs)
|
||||
|
||||
# Pass api_key directly — more reliable than env vars alone
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
|
||||
# Pass api_base for custom endpoints
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
|
||||
# Pass extra headers (e.g. APP-Code for AiHubMix)
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
# Return error as content for graceful handling
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse LiteLLM response into our standard format."""
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
content = message.content
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
# Some providers (e.g. GitHub Copilot) split content and tool_calls
|
||||
# across multiple choices. Merge them so tool_calls are not lost.
|
||||
raw_tool_calls = []
|
||||
for ch in response.choices:
|
||||
msg = ch.message
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
raw_tool_calls.extend(msg.tool_calls)
|
||||
if ch.finish_reason in ("tool_calls", "stop"):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and msg.content:
|
||||
content = msg.content
|
||||
|
||||
if len(response.choices) > 1:
|
||||
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
|
||||
len(response.choices), len(raw_tool_calls))
|
||||
|
||||
tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
usage = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
@@ -0,0 +1,316 @@
|
||||
"""OpenAI Codex Responses Provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
|
||||
DEFAULT_ORIGINATOR = "nanobot"
|
||||
|
||||
|
||||
class OpenAICodexProvider(LLMProvider):
|
||||
"""Use Codex OAuth to call the Responses API."""
|
||||
|
||||
def __init__(self, default_model: str = "openai-codex/gpt-5.1-codex"):
|
||||
super().__init__(api_key=None, api_base=None)
|
||||
self.default_model = default_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
) -> LLMResponse:
|
||||
model = model or self.default_model
|
||||
system_prompt, input_items = _convert_messages(messages)
|
||||
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
headers = _build_headers(token.account_id, token.access)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": _strip_model_prefix(model),
|
||||
"store": False,
|
||||
"stream": True,
|
||||
"instructions": system_prompt,
|
||||
"input": input_items,
|
||||
"text": {"verbosity": "medium"},
|
||||
"include": ["reasoning.encrypted_content"],
|
||||
"prompt_cache_key": _prompt_cache_key(messages),
|
||||
"tool_choice": "auto",
|
||||
"parallel_tool_calls": True,
|
||||
}
|
||||
|
||||
if reasoning_effort:
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
|
||||
if tools:
|
||||
body["tools"] = _convert_tools(tools)
|
||||
|
||||
url = DEFAULT_CODEX_URL
|
||||
|
||||
try:
|
||||
try:
|
||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=True)
|
||||
except Exception as e:
|
||||
if "CERTIFICATE_VERIFY_FAILED" not in str(e):
|
||||
raise
|
||||
logger.warning("SSL certificate verification failed for Codex API; retrying with verify=False")
|
||||
content, tool_calls, finish_reason = await _request_codex(url, headers, body, verify=False)
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling Codex: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
if model.startswith("openai-codex/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
|
||||
def _build_headers(account_id: str, token: str) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"chatgpt-account-id": account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": DEFAULT_ORIGINATOR,
|
||||
"User-Agent": "nanobot (python)",
|
||||
"accept": "text/event-stream",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
async def _request_codex(
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
body: dict[str, Any],
|
||||
verify: bool,
|
||||
) -> tuple[str, list[ToolCallRequest], str]:
|
||||
async with httpx.AsyncClient(timeout=60.0, verify=verify) as client:
|
||||
async with client.stream("POST", url, headers=headers, json=body) as response:
|
||||
if response.status_code != 200:
|
||||
text = await response.aread()
|
||||
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
|
||||
return await _consume_sse(response)
|
||||
|
||||
|
||||
def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert OpenAI function-calling schema to Codex flat format."""
|
||||
converted: list[dict[str, Any]] = []
|
||||
for tool in tools:
|
||||
fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool
|
||||
name = fn.get("name")
|
||||
if not name:
|
||||
continue
|
||||
params = fn.get("parameters") or {}
|
||||
converted.append({
|
||||
"type": "function",
|
||||
"name": name,
|
||||
"description": fn.get("description") or "",
|
||||
"parameters": params if isinstance(params, dict) else {},
|
||||
})
|
||||
return converted
|
||||
|
||||
|
||||
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system_prompt = content if isinstance(content, str) else ""
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
input_items.append(_convert_user_message(content))
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Handle text first.
|
||||
if isinstance(content, str) and content:
|
||||
input_items.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed",
|
||||
"id": f"msg_{idx}",
|
||||
}
|
||||
)
|
||||
# Then handle tool calls.
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = _split_tool_call_id(tool_call.get("id"))
|
||||
call_id = call_id or f"call_{idx}"
|
||||
item_id = item_id or f"fc_{idx}"
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call",
|
||||
"id": item_id,
|
||||
"call_id": call_id,
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
call_id, _ = _split_tool_call_id(msg.get("tool_call_id"))
|
||||
output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False)
|
||||
input_items.append(
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
"output": output_text,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
return system_prompt, input_items
|
||||
|
||||
|
||||
def _convert_user_message(content: Any) -> dict[str, Any]:
|
||||
if isinstance(content, str):
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": content}]}
|
||||
if isinstance(content, list):
|
||||
converted: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") == "text":
|
||||
converted.append({"type": "input_text", "text": item.get("text", "")})
|
||||
elif item.get("type") == "image_url":
|
||||
url = (item.get("image_url") or {}).get("url")
|
||||
if url:
|
||||
converted.append({"type": "input_image", "image_url": url, "detail": "auto"})
|
||||
if converted:
|
||||
return {"role": "user", "content": converted}
|
||||
return {"role": "user", "content": [{"type": "input_text", "text": ""}]}
|
||||
|
||||
|
||||
def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
if isinstance(tool_call_id, str) and tool_call_id:
|
||||
if "|" in tool_call_id:
|
||||
call_id, item_id = tool_call_id.split("|", 1)
|
||||
return call_id, item_id or None
|
||||
return tool_call_id, None
|
||||
return "call_0", None
|
||||
|
||||
|
||||
def _prompt_cache_key(messages: list[dict[str, Any]]) -> str:
|
||||
raw = json.dumps(messages, ensure_ascii=True, sort_keys=True)
|
||||
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]:
|
||||
buffer: list[str] = []
|
||||
async for line in response.aiter_lines():
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")]
|
||||
buffer = []
|
||||
if not data_lines:
|
||||
continue
|
||||
data = "\n".join(data_lines).strip()
|
||||
if not data or data == "[DONE]":
|
||||
continue
|
||||
try:
|
||||
yield json.loads(data)
|
||||
except Exception:
|
||||
continue
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
|
||||
async def _consume_sse(response: httpx.Response) -> tuple[str, list[ToolCallRequest], str]:
|
||||
content = ""
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
tool_call_buffers: dict[str, dict[str, Any]] = {}
|
||||
finish_reason = "stop"
|
||||
|
||||
async for event in _iter_sse(response):
|
||||
event_type = event.get("type")
|
||||
if event_type == "response.output_item.added":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
tool_call_buffers[call_id] = {
|
||||
"id": item.get("id") or "fc_0",
|
||||
"name": item.get("name"),
|
||||
"arguments": item.get("arguments") or "",
|
||||
}
|
||||
elif event_type == "response.output_text.delta":
|
||||
content += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] += event.get("delta") or ""
|
||||
elif event_type == "response.function_call_arguments.done":
|
||||
call_id = event.get("call_id")
|
||||
if call_id and call_id in tool_call_buffers:
|
||||
tool_call_buffers[call_id]["arguments"] = event.get("arguments") or ""
|
||||
elif event_type == "response.output_item.done":
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") == "function_call":
|
||||
call_id = item.get("call_id")
|
||||
if not call_id:
|
||||
continue
|
||||
buf = tool_call_buffers.get(call_id) or {}
|
||||
args_raw = buf.get("arguments") or item.get("arguments") or "{}"
|
||||
try:
|
||||
args = json.loads(args_raw)
|
||||
except Exception:
|
||||
args = {"raw": args_raw}
|
||||
tool_calls.append(
|
||||
ToolCallRequest(
|
||||
id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}",
|
||||
name=buf.get("name") or item.get("name"),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
elif event_type == "response.completed":
|
||||
status = (event.get("response") or {}).get("status")
|
||||
finish_reason = _map_finish_reason(status)
|
||||
elif event_type in {"error", "response.failed"}:
|
||||
raise RuntimeError("Codex response failed")
|
||||
|
||||
return content, tool_calls, finish_reason
|
||||
|
||||
|
||||
_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"}
|
||||
|
||||
|
||||
def _map_finish_reason(status: str | None) -> str:
|
||||
return _FINISH_REASON_MAP.get(status or "completed", "stop")
|
||||
|
||||
|
||||
def _friendly_error(status_code: int, raw: str) -> str:
|
||||
if status_code == 429:
|
||||
return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later."
|
||||
return f"HTTP {status_code}: {raw}"
|
||||
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Provider Registry — single source of truth for LLM provider metadata.
|
||||
|
||||
Adding a new provider:
|
||||
1. Add a ProviderSpec to PROVIDERS below.
|
||||
2. Add a field to ProvidersConfig in config/schema.py.
|
||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
||||
|
||||
Order matters — it controls match priority and fallback. Gateways first.
|
||||
Every entry writes out all fields so you can copy-paste as a template.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSpec:
|
||||
"""One LLM provider's metadata. See PROVIDERS below for real examples.
|
||||
|
||||
Placeholders in env_extras values:
|
||||
{api_key} — the user's API key
|
||||
{api_base} — api_base from config, or this spec's default_api_base
|
||||
"""
|
||||
|
||||
# identity
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# model prefixing
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
|
||||
# gateway / local detection
|
||||
is_gateway: bool = False # routes any model (OpenRouter, AiHubMix)
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
|
||||
# gateway behavior
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
|
||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
|
||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||
|
||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||
is_direct: bool = False
|
||||
|
||||
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||
supports_prompt_caching: bool = False
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PROVIDERS — the registry. Order = priority. Copy any entry as template.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="",
|
||||
display_name="Custom",
|
||||
litellm_prefix="",
|
||||
is_direct=True,
|
||||
),
|
||||
|
||||
# === Azure OpenAI (direct API calls with API version 2024-10-21) =====
|
||||
ProviderSpec(
|
||||
name="azure_openai",
|
||||
keywords=("azure", "azure-openai"),
|
||||
env_key="",
|
||||
display_name="Azure OpenAI",
|
||||
litellm_prefix="",
|
||||
is_direct=True,
|
||||
),
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
# Gateways can route any model, so they win in fallback.
|
||||
# OpenRouter: global gateway, keys start with "sk-or-"
|
||||
ProviderSpec(
|
||||
name="openrouter",
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # claude-3 → openrouter/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
),
|
||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||
ProviderSpec(
|
||||
name="siliconflow",
|
||||
keywords=("siliconflow",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="SiliconFlow",
|
||||
litellm_prefix="openai",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="siliconflow",
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# VolcEngine (火山引擎): OpenAI-compatible gateway
|
||||
ProviderSpec(
|
||||
name="volcengine",
|
||||
keywords=("volcengine", "volces", "ark"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="volces",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# OpenAI Codex: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex",),
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="OpenAI Codex",
|
||||
litellm_prefix="", # Not routed through LiteLLM
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="codex",
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
# Github Copilot: uses OAuth, not API key.
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="", # OAuth-based, no API key
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||
skip_prefixes=("github_copilot/",),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
),
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||
),
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# Detected when config key is "vllm" (provider_name="vllm").
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="", # user must provide in config
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lookup helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||
model_lower = model.lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local]
|
||||
|
||||
# Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex.
|
||||
for spec in std_specs:
|
||||
if model_prefix and normalized_prefix == spec.name:
|
||||
return spec
|
||||
|
||||
for spec in std_specs:
|
||||
if any(
|
||||
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
|
||||
):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""Detect gateway/local provider.
|
||||
|
||||
Priority:
|
||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||
|
||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||
"""
|
||||
# 1. Direct match by config key
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||
for spec in PROVIDERS:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
return None
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Voice transcription provider using Groq."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class GroqTranscriptionProvider:
|
||||
"""
|
||||
Voice transcription provider using Groq's Whisper API.
|
||||
|
||||
Groq offers extremely fast transcription with a generous free tier.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str | None = None):
|
||||
self.api_key = api_key or os.environ.get("GROQ_API_KEY")
|
||||
self.api_url = "https://api.groq.com/openai/v1/audio/transcriptions"
|
||||
|
||||
async def transcribe(self, file_path: str | Path) -> str:
|
||||
"""
|
||||
Transcribe an audio file using Groq.
|
||||
|
||||
Args:
|
||||
file_path: Path to the audio file.
|
||||
|
||||
Returns:
|
||||
Transcribed text.
|
||||
"""
|
||||
if not self.api_key:
|
||||
logger.warning("Groq API key not configured for transcription")
|
||||
return ""
|
||||
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
logger.error("Audio file not found: {}", file_path)
|
||||
return ""
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
with open(path, "rb") as f:
|
||||
files = {
|
||||
"file": (path.name, f),
|
||||
"model": (None, "whisper-large-v3"),
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
files=files,
|
||||
timeout=60.0
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("text", "")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Groq transcription error: {}", e)
|
||||
return ""
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Session management module."""
|
||||
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
__all__ = ["SessionManager", "Session"]
|
||||
@@ -0,0 +1,213 @@
|
||||
"""Session management for conversation history."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""
|
||||
A conversation session.
|
||||
|
||||
Stores messages in JSONL format for easy reading and persistence.
|
||||
|
||||
Important: Messages are append-only for LLM cache efficiency.
|
||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||
but does NOT modify the messages list or get_history() output.
|
||||
"""
|
||||
|
||||
key: str # channel:chat_id
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
updated_at: datetime = field(default_factory=datetime.now)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||
"""Add a message to the session."""
|
||||
msg = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
**kwargs
|
||||
}
|
||||
self.messages.append(msg)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a user turn."""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
# Drop leading non-user messages to avoid orphaned tool_result blocks
|
||||
for i, m in enumerate(sliced):
|
||||
if m.get("role") == "user":
|
||||
sliced = sliced[i:]
|
||||
break
|
||||
|
||||
out: list[dict[str, Any]] = []
|
||||
for m in sliced:
|
||||
entry: dict[str, Any] = {"role": m["role"], "content": m.get("content", "")}
|
||||
for k in ("tool_calls", "tool_call_id", "name"):
|
||||
if k in m:
|
||||
entry[k] = m[k]
|
||||
out.append(entry)
|
||||
return out
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages and reset session to initial state."""
|
||||
self.messages = []
|
||||
self.last_consolidated = 0
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages conversation sessions.
|
||||
|
||||
Sessions are stored as JSONL files in the sessions directory.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace: Path):
|
||||
self.workspace = workspace
|
||||
self.sessions_dir = ensure_dir(self.workspace / "sessions")
|
||||
self.legacy_sessions_dir = get_legacy_sessions_dir()
|
||||
self._cache: dict[str, Session] = {}
|
||||
|
||||
def _get_session_path(self, key: str) -> Path:
|
||||
"""Get the file path for a session."""
|
||||
safe_key = safe_filename(key.replace(":", "_"))
|
||||
return self.sessions_dir / f"{safe_key}.jsonl"
|
||||
|
||||
def _get_legacy_session_path(self, key: str) -> Path:
|
||||
"""Legacy global session path (~/.nanobot/sessions/)."""
|
||||
safe_key = safe_filename(key.replace(":", "_"))
|
||||
return self.legacy_sessions_dir / f"{safe_key}.jsonl"
|
||||
|
||||
def get_or_create(self, key: str) -> Session:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Args:
|
||||
key: Session key (usually channel:chat_id).
|
||||
|
||||
Returns:
|
||||
The session.
|
||||
"""
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
|
||||
session = self._load(key)
|
||||
if session is None:
|
||||
session = Session(key=key)
|
||||
|
||||
self._cache[key] = session
|
||||
return session
|
||||
|
||||
def _load(self, key: str) -> Session | None:
|
||||
"""Load a session from disk."""
|
||||
path = self._get_session_path(key)
|
||||
if not path.exists():
|
||||
legacy_path = self._get_legacy_session_path(key)
|
||||
if legacy_path.exists():
|
||||
try:
|
||||
shutil.move(str(legacy_path), str(path))
|
||||
logger.info("Migrated session {} from legacy path", key)
|
||||
except Exception:
|
||||
logger.exception("Failed to migrate session {}", key)
|
||||
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
messages = []
|
||||
metadata = {}
|
||||
created_at = None
|
||||
last_consolidated = 0
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
data = json.loads(line)
|
||||
|
||||
if data.get("_type") == "metadata":
|
||||
metadata = data.get("metadata", {})
|
||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||
last_consolidated = data.get("last_consolidated", 0)
|
||||
else:
|
||||
messages.append(data)
|
||||
|
||||
return Session(
|
||||
key=key,
|
||||
messages=messages,
|
||||
created_at=created_at or datetime.now(),
|
||||
metadata=metadata,
|
||||
last_consolidated=last_consolidated
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load session {}: {}", key, e)
|
||||
return None
|
||||
|
||||
def save(self, session: Session) -> None:
|
||||
"""Save a session to disk."""
|
||||
path = self._get_session_path(session.key)
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
metadata_line = {
|
||||
"_type": "metadata",
|
||||
"key": session.key,
|
||||
"created_at": session.created_at.isoformat(),
|
||||
"updated_at": session.updated_at.isoformat(),
|
||||
"metadata": session.metadata,
|
||||
"last_consolidated": session.last_consolidated
|
||||
}
|
||||
f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
|
||||
for msg in session.messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
self._cache[session.key] = session
|
||||
|
||||
def invalidate(self, key: str) -> None:
|
||||
"""Remove a session from the in-memory cache."""
|
||||
self._cache.pop(key, None)
|
||||
|
||||
def list_sessions(self) -> list[dict[str, Any]]:
|
||||
"""
|
||||
List all sessions.
|
||||
|
||||
Returns:
|
||||
List of session info dicts.
|
||||
"""
|
||||
sessions = []
|
||||
|
||||
for path in self.sessions_dir.glob("*.jsonl"):
|
||||
try:
|
||||
# Read just the metadata line
|
||||
with open(path, encoding="utf-8") as f:
|
||||
first_line = f.readline().strip()
|
||||
if first_line:
|
||||
data = json.loads(first_line)
|
||||
if data.get("_type") == "metadata":
|
||||
key = data.get("key") or path.stem.replace("_", ":", 1)
|
||||
sessions.append({
|
||||
"key": key,
|
||||
"created_at": data.get("created_at"),
|
||||
"updated_at": data.get("updated_at"),
|
||||
"path": str(path)
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
|
||||
@@ -0,0 +1,25 @@
|
||||
# nanobot Skills
|
||||
|
||||
This directory contains built-in skills that extend nanobot's capabilities.
|
||||
|
||||
## Skill Format
|
||||
|
||||
Each skill is a directory containing a `SKILL.md` file with:
|
||||
- YAML frontmatter (name, description, metadata)
|
||||
- Markdown instructions for the agent
|
||||
|
||||
## Attribution
|
||||
|
||||
These skills are adapted from [OpenClaw](https://github.com/openclaw/openclaw)'s skill system.
|
||||
The skill format and metadata structure follow OpenClaw's conventions to maintain compatibility.
|
||||
|
||||
## Available Skills
|
||||
|
||||
| Skill | Description |
|
||||
|-------|-------------|
|
||||
| `github` | Interact with GitHub using the `gh` CLI |
|
||||
| `weather` | Get weather info using wttr.in and Open-Meteo |
|
||||
| `summarize` | Summarize URLs, files, and YouTube videos |
|
||||
| `tmux` | Remote-control tmux sessions |
|
||||
| `clawhub` | Search and install skills from ClawHub registry |
|
||||
| `skill-creator` | Create new skills |
|
||||
@@ -0,0 +1,53 @@
|
||||
---
|
||||
name: clawhub
|
||||
description: Search and install agent skills from ClawHub, the public skill registry.
|
||||
homepage: https://clawhub.ai
|
||||
metadata: {"nanobot":{"emoji":"🦞"}}
|
||||
---
|
||||
|
||||
# ClawHub
|
||||
|
||||
Public skill registry for AI agents. Search by natural language (vector search).
|
||||
|
||||
## When to use
|
||||
|
||||
Use this skill when the user asks any of:
|
||||
- "find a skill for …"
|
||||
- "search for skills"
|
||||
- "install a skill"
|
||||
- "what skills are available?"
|
||||
- "update my skills"
|
||||
|
||||
## Search
|
||||
|
||||
```bash
|
||||
npx --yes clawhub@latest search "web scraping" --limit 5
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
npx --yes clawhub@latest install <slug> --workdir ~/.nanobot/workspace
|
||||
```
|
||||
|
||||
Replace `<slug>` with the skill name from search results. This places the skill into `~/.nanobot/workspace/skills/`, where nanobot loads workspace skills from. Always include `--workdir`.
|
||||
|
||||
## Update
|
||||
|
||||
```bash
|
||||
npx --yes clawhub@latest update --all --workdir ~/.nanobot/workspace
|
||||
```
|
||||
|
||||
## List installed
|
||||
|
||||
```bash
|
||||
npx --yes clawhub@latest list --workdir ~/.nanobot/workspace
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
- Requires Node.js (`npx` comes with it).
|
||||
- No API key needed for search and install.
|
||||
- Login (`npx --yes clawhub@latest login`) is only required for publishing.
|
||||
- `--workdir ~/.nanobot/workspace` is critical — without it, skills install to the current directory instead of the nanobot workspace.
|
||||
- After install, remind the user to start a new session to load the skill.
|
||||
@@ -0,0 +1,57 @@
|
||||
---
|
||||
name: cron
|
||||
description: Schedule reminders and recurring tasks.
|
||||
---
|
||||
|
||||
# Cron
|
||||
|
||||
Use the `cron` tool to schedule reminders or recurring tasks.
|
||||
|
||||
## Three Modes
|
||||
|
||||
1. **Reminder** - message is sent directly to user
|
||||
2. **Task** - message is a task description, agent executes and sends result
|
||||
3. **One-time** - runs once at a specific time, then auto-deletes
|
||||
|
||||
## Examples
|
||||
|
||||
Fixed reminder:
|
||||
```
|
||||
cron(action="add", message="Time to take a break!", every_seconds=1200)
|
||||
```
|
||||
|
||||
Dynamic task (agent executes each time):
|
||||
```
|
||||
cron(action="add", message="Check HKUDS/nanobot GitHub stars and report", every_seconds=600)
|
||||
```
|
||||
|
||||
One-time scheduled task (compute ISO datetime from current time):
|
||||
```
|
||||
cron(action="add", message="Remind me about the meeting", at="<ISO datetime>")
|
||||
```
|
||||
|
||||
Timezone-aware cron:
|
||||
```
|
||||
cron(action="add", message="Morning standup", cron_expr="0 9 * * 1-5", tz="America/Vancouver")
|
||||
```
|
||||
|
||||
List/remove:
|
||||
```
|
||||
cron(action="list")
|
||||
cron(action="remove", job_id="abc123")
|
||||
```
|
||||
|
||||
## Time Expressions
|
||||
|
||||
| User says | Parameters |
|
||||
|-----------|------------|
|
||||
| every 20 minutes | every_seconds: 1200 |
|
||||
| every hour | every_seconds: 3600 |
|
||||
| every day at 8am | cron_expr: "0 8 * * *" |
|
||||
| weekdays at 5pm | cron_expr: "0 17 * * 1-5" |
|
||||
| 9am Vancouver time daily | cron_expr: "0 9 * * *", tz: "America/Vancouver" |
|
||||
| at a specific time | at: ISO datetime string (compute from current time) |
|
||||
|
||||
## Timezone
|
||||
|
||||
Use `tz` with `cron_expr` to schedule in a specific IANA timezone. Without `tz`, the server's local timezone is used.
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user