Update 2026-05-13 16:43:53

This commit is contained in:
yi
2026-05-13 16:43:53 +08:00
parent 6af5c584f4
commit afd7c5fe85
490 changed files with 850 additions and 922 deletions
View File
+203
View File
@@ -0,0 +1,203 @@
import mimetypes
import re
from pathlib import Path
from typing import Any, Iterable
from urllib.parse import quote
from pydantic import BaseModel
from app.core.data_root import get_data_root, get_reports_root, get_uploads_root, get_workspace_root
# Matches "local://..." URIs; the match stops at whitespace, quotes, angle
# brackets, or a closing bracket/parenthesis/brace.
LOCAL_URI_PATTERN = re.compile(r"local://[^\s<>'\"\]\)\}]+")
# Matches file-looking paths that end in a 1-12 character alphanumeric
# extension: Windows drive paths, POSIX absolute paths, or relative paths
# containing at least one directory separator.
PATH_PATTERN = re.compile(
    r"(?:[A-Za-z]:[\\/][^\s<>'\"]+\.[A-Za-z0-9]{1,12}|/[^\s<>'\"]+\.[A-Za-z0-9]{1,12}|(?:\.\./|\.?/)?(?:[\w\-.]+[\\/])+[\w\-.]+\.[A-Za-z0-9]{1,12})"
)
# Matches "data/data/<name>.<ext>" with either slash style, case-insensitively.
REPORT_PATH_PATTERN = re.compile(r"data[\\/]data[\\/][\w\-.]+\.[A-Za-z0-9]{1,12}", re.IGNORECASE)
# Extensions that _is_previewable accepts regardless of the guessed MIME type.
PREVIEWABLE_EXTENSIONS = {
    ".html",
    ".htm",
    ".pdf",
    ".pptx",
    ".txt",
    ".md",
    ".json",
    ".csv",
    ".tsv",
    ".yaml",
    ".yml",
    ".xml",
    ".log",
}
class ArtifactPayload(BaseModel):
    """Description of one file artifact, as built by ``_build_artifact_payload``."""

    # File name (the final path component).
    name: str
    # Guessed MIME type; "application/octet-stream" when unknown.
    mime_type: str
    # File size in bytes, taken from stat().
    size: int
    # Download endpoint URL carrying the URL-encoded locator.
    download_url: str
    # Whether a preview URL is offered for this file.
    previewable: bool
    # Preview endpoint URL; None when the file is not previewable.
    preview_url: str | None = None
def extract_artifacts(content: str, session_messages: list[dict[str, Any]] | None = None) -> list[dict[str, Any]]:
    """Find files referenced in a reply and describe each one as an artifact.

    Locators are collected from ``content`` and from the session messages
    selected by ``_collect_candidate_texts``, de-duplicated in first-seen
    order, resolved to paths, and de-duplicated again by resolved path so two
    locators naming the same file yield a single artifact.

    Args:
        content: Text to scan; a falsy value is treated as empty text.
        session_messages: Optional message dicts to scan as well.

    Returns:
        One ``ArtifactPayload.model_dump(exclude_none=True)`` dict per
        existing regular file, in discovery order. Locators whose filesystem
        checks fail are skipped rather than aborting the extraction.
    """
    ordered_locators: list[str] = []
    seen_locators: set[str] = set()
    for text in _collect_candidate_texts(content, session_messages or []):
        for locator in _extract_locators(text):
            if locator in seen_locators:
                continue
            seen_locators.add(locator)
            ordered_locators.append(locator)

    artifacts: list[dict[str, Any]] = []
    seen_paths: set[Path] = set()
    for locator in ordered_locators:
        # Locators come from free text, so probing them can fail in many ways
        # (permission denied, over-long names, a file removed between the
        # check and stat()). One bad locator must not lose the others.
        try:
            path = _resolve_locator(locator)
            if path is None or not path.is_file():
                continue
            resolved = path.resolve()
            if resolved in seen_paths:
                continue
            artifact = _build_artifact_payload(locator, resolved)
        except (OSError, RuntimeError, ValueError):
            continue
        seen_paths.add(resolved)
        artifacts.append(artifact.model_dump(exclude_none=True))
    return artifacts
def _build_artifact_payload(locator: str, path: Path) -> ArtifactPayload:
    """Build the payload describing the file at ``path``.

    Args:
        locator: Original locator string; URL-encoded into the endpoint URLs.
        path: Path of the file to describe; it is stat()-ed for its size.

    Returns:
        An ``ArtifactPayload`` whose ``preview_url`` is set only when the
        file is previewable.
    """
    mime = _guess_mime_type(path)
    can_preview = _is_previewable(path, mime)
    # safe="" so that slashes in the locator are percent-encoded too.
    target = quote(locator, safe="")
    preview_link = None
    if can_preview:
        preview_link = f"/nanobot/artifacts/preview?target={target}"
    return ArtifactPayload(
        name=path.name,
        mime_type=mime,
        size=path.stat().st_size,
        download_url=f"/nanobot/artifacts/download?target={target}",
        previewable=can_preview,
        preview_url=preview_link,
    )
def _guess_mime_type(path: Path) -> str:
mime_type, _ = mimetypes.guess_type(path.name)
return mime_type or "application/octet-stream"
def _is_previewable(path: Path, mime_type: str) -> bool:
if mime_type.startswith("image/") or mime_type.startswith("text/"):
return True
extension = path.suffix.lower()
if extension in PREVIEWABLE_EXTENSIONS:
return True
return mime_type in {
"application/pdf",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
}
def _collect_candidate_texts(content: str, session_messages: list[dict[str, Any]]) -> list[str]:
texts = [content or ""]
if not session_messages:
return texts
last_user_idx = -1
for idx, message in enumerate(session_messages):
if message.get("role") == "user":
last_user_idx = idx
if last_user_idx == -1:
segment = session_messages
else:
segment = session_messages[last_user_idx + 1 :]
for message in segment:
raw = message.get("content")
flattened = _flatten_content(raw)
if flattened:
texts.append(flattened)
return texts
def _extract_locators(text: str) -> Iterable[str]:
if not text:
return []
ordered: list[str] = []
seen: set[str] = set()
patterns = (LOCAL_URI_PATTERN, REPORT_PATH_PATTERN, PATH_PATTERN)
for pattern in patterns:
for match in pattern.findall(text):
normalized = _normalize_locator(match)
if not normalized or normalized in seen:
continue
seen.add(normalized)
ordered.append(normalized)
return ordered
def _normalize_locator(raw_locator: str) -> str:
locator = raw_locator.strip().strip("`'\"")
locator = locator.rstrip(".,;:!?)]}")
return locator
def _resolve_locator(locator: str) -> Path | None:
    """Map a locator string to a filesystem path.

    ``local://`` locators are tried under the workspace, reports and uploads
    roots (and finally by bare file name under uploads); when nothing exists
    the uploads-by-name path is returned anyway. Other locators are returned
    as-is when absolute, anchored at the data root's parent when they start
    with ``data/data/``, and otherwise tried under the workspace root, the
    data root and the data root's parent.

    NOTE(review): absolute paths are returned without checking that they lie
    under any known root; callers must not treat the result as access-checked.

    Returns:
        The resolved path, or ``None`` when the locator is empty or no
        candidate exists.
    """
    data_root = get_data_root()
    workspace_root = get_workspace_root()
    uploads_root = get_uploads_root()
    reports_root = get_reports_root()
    repo_root = data_root.parent

    if locator.startswith("local://"):
        remainder = locator.replace("local://", "", 1).strip().lstrip("/\\")
        if not remainder:
            return None
        relative = Path(remainder)
        if relative.is_absolute():
            return relative
        by_name = uploads_root / relative.name
        options = (workspace_root / relative, reports_root / relative, uploads_root / relative, by_name)
        for option in options:
            if option.exists():
                return option
        return by_name

    forward_slashed = locator.replace("\\", "/")
    direct = Path(locator)
    if direct.is_absolute():
        return direct
    if forward_slashed.startswith("data/data/"):
        return repo_root / forward_slashed
    bases = (workspace_root, data_root, repo_root)
    return next((base / forward_slashed for base in bases if (base / forward_slashed).exists()), None)
def _flatten_content(value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value
if isinstance(value, list):
fragments: list[str] = []
for item in value:
flattened = _flatten_content(item)
if flattened:
fragments.append(flattened)
return "\n".join(fragments)
if isinstance(value, dict):
fragments: list[str] = []
text = value.get("text")
if isinstance(text, str):
fragments.append(text)
content = value.get("content")
if content is not None:
nested = _flatten_content(content)
if nested:
fragments.append(nested)
for field in ("path", "file", "file_path", "url"):
data = value.get(field)
if isinstance(data, str):
fragments.append(data)
return "\n".join(fragments)
return str(value)
+39
View File
@@ -0,0 +1,39 @@
import os
from pathlib import Path
# Directory two levels above the directory that contains this file.
BACKEND_ROOT = Path(__file__).resolve().parents[2]
# Parent directory of BACKEND_ROOT.
REPO_ROOT = BACKEND_ROOT.parent
# Preferred data location: "<REPO_ROOT>/data".
DEFAULT_DATA_ROOT = REPO_ROOT / "data"
# Older data location: "<BACKEND_ROOT>/data"; used only when the default is absent.
LEGACY_DATA_ROOT = BACKEND_ROOT / "data"
def get_data_root() -> Path:
    """Return the data root directory.

    A non-blank ``DATA_ROOT`` environment variable wins (user-expanded and
    resolved). Otherwise the default root is used, except when it does not
    exist and the legacy root does; in that case a migration notice is
    printed and the legacy root is returned.
    """
    override = (os.getenv("DATA_ROOT") or "").strip()
    if override:
        return Path(override).expanduser().resolve()
    if not DEFAULT_DATA_ROOT.exists() and LEGACY_DATA_ROOT.exists():
        print(f"[DATA_ROOT] legacy path detected: {LEGACY_DATA_ROOT}. Please migrate to {DEFAULT_DATA_ROOT}.")
        return LEGACY_DATA_ROOT
    return DEFAULT_DATA_ROOT
def get_workspace_root() -> Path:
    """Return the "workspace" directory under the data root."""
    return get_data_root().joinpath("workspace")
def get_uploads_root() -> Path:
    """Return the "uploads" directory under the data root."""
    return get_data_root().joinpath("uploads")
def get_reports_root() -> Path:
    """Return the reports directory, which is "data" under the data root."""
    return get_data_root().joinpath("data")
def ensure_data_layout() -> None:
    """Create the data root and its workspace, uploads and reports directories."""
    for locate in (get_data_root, get_workspace_root, get_uploads_root, get_reports_root):
        locate().mkdir(parents=True, exist_ok=True)
+43
View File
@@ -0,0 +1,43 @@
import smtplib
import os
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
def send_verification_email(to_email: str, token: str):
    """Send the account-verification email over SMTP-over-SSL (best effort).

    Settings are read from the environment: ``SMTP_HOST``, ``SMTP_PORT``,
    ``SMTP_USER``, ``SMTP_PASSWORD`` and ``FRONTEND_URL``. When the user or
    password is missing, a notice is printed and nothing is sent. Sending
    failures are printed, not raised.

    Args:
        to_email: Recipient address.
        token: Verification token placed in the link's ``token`` query parameter.
    """
    from urllib.parse import quote

    smtp_host = os.getenv("SMTP_HOST", "smtp.qq.com")
    smtp_port = int(os.getenv("SMTP_PORT", "465"))
    smtp_user = os.getenv("SMTP_USER", "")
    smtp_password = os.getenv("SMTP_PASSWORD", "")
    frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
    if not smtp_user or not smtp_password:
        print("SMTP configuration is missing. Skip sending email.")
        return

    msg = MIMEMultipart()
    msg['From'] = smtp_user
    msg['To'] = to_email
    msg['Subject'] = "请验证你的邮箱地址"
    # Percent-encode the token so characters such as "&", "+" or "#" cannot
    # break the query string; URL-safe tokens pass through unchanged.
    verify_link = f"{frontend_url}/verify-email?token={quote(token, safe='')}"
    body = f"""
<html>
<body>
<h2>欢迎使用全源灵动!</h2>
<p>请点击下方链接验证邮箱并激活账号:</p>
<p><a href="{verify_link}">{verify_link}</a></p>
<p>如果你没有发起该请求,请忽略此邮件。</p>
</body>
</html>
"""
    msg.attach(MIMEText(body, 'html'))
    try:
        # SMTP_SSL matches the default port 465. The context manager closes
        # the connection even when login or sending fails.
        with smtplib.SMTP_SSL(smtp_host, smtp_port) as server:
            server.login(smtp_user, smtp_password)
            server.send_message(msg)
        print(f"Verification email sent to {to_email}")
    except Exception as e:
        # Deliberately best-effort: report the failure without raising.
        print(f"Failed to send email: {e}")
+96
View File
@@ -0,0 +1,96 @@
import os
from pathlib import Path
from typing import Optional
from app.core.data_root import (
BACKEND_ROOT,
LEGACY_DATA_ROOT,
get_data_root,
get_reports_root,
get_uploads_root,
get_workspace_root,
)
# Directory roots, computed once when this module is imported.
data_root = get_data_root()
workspace_root = get_workspace_root()
uploads_root = get_uploads_root()
reports_root = get_reports_root()
# Counterparts of the three roots under LEGACY_DATA_ROOT.
legacy_workspace_root = LEGACY_DATA_ROOT / "workspace"
legacy_uploads_root = LEGACY_DATA_ROOT / "uploads"
legacy_reports_root = LEGACY_DATA_ROOT / "data"
backend_root = BACKEND_ROOT
# ensure_artifact_access only accepts paths located under one of these roots.
allowed_artifact_roots = (
    workspace_root,
    uploads_root,
    reports_root,
    legacy_workspace_root,
    legacy_uploads_root,
    legacy_reports_root,
)
def resolve_upload_file_path(file_url: Optional[str]) -> Path:
    """Turn an upload URL into a filesystem path.

    ``local://`` URLs are reduced to their final path component and placed
    directly under the uploads root. Any other value is returned as a
    ``Path`` unchanged.

    Args:
        file_url: A ``local://`` URL or a plain path string.

    Returns:
        The path of the upload.

    Raises:
        ValueError: If ``file_url`` is empty, or a ``local://`` URL does not
            name a file (empty name, ``.`` or ``..``).
    """
    if not file_url:
        raise ValueError("File URL is empty")
    if not file_url.startswith("local://"):
        return Path(file_url)
    raw_name = file_url.replace("local://", "", 1)
    safe_name = os.path.basename(raw_name)
    # basename() leaves "", "." and ".." untouched; joining those would point
    # at the uploads directory itself or at its parent.
    if safe_name in ("", ".", ".."):
        raise ValueError("File URL does not name a file")
    return uploads_root / safe_name
def resolve_artifact_target(target: str) -> Path | None:
locator = (target or "").strip().strip("'\"")
if not locator:
return None
if locator.startswith("local://"):
raw_local = locator.replace("local://", "", 1).strip().lstrip("/\\")
if not raw_local:
return None
candidate = Path(raw_local)
if candidate.is_absolute():
return candidate
checks = (
workspace_root / candidate,
reports_root / candidate,
uploads_root / candidate,
uploads_root / candidate.name,
)
for path in checks:
if path.exists():
return path
return uploads_root / candidate.name
normalized = locator.replace("\\", "/")
path = Path(locator)
if path.is_absolute():
return path
if normalized.startswith("data/data/"):
return data_root.parent / normalized
checks = (
workspace_root / normalized,
data_root / normalized,
backend_root / normalized,
)
for candidate in checks:
if candidate.exists():
return candidate
return None
def ensure_artifact_access(path: Path, *, require_file: bool = True) -> Path:
    """Resolve ``path`` and verify it lies under an allowed artifact root.

    Args:
        path: Path to check; it must exist.
        require_file: When true the path must be a regular file, otherwise
            it must be a directory.

    Returns:
        The strictly resolved path.

    Raises:
        FileNotFoundError: If the path does not exist or is of the wrong kind.
        PermissionError: If the resolved path is outside every allowed root.
    """
    try:
        real_path = path.resolve(strict=True)
    except FileNotFoundError as exc:
        raise FileNotFoundError("目标文件不存在") from exc

    if require_file:
        if not real_path.is_file():
            raise FileNotFoundError("目标文件不存在")
    elif not real_path.is_dir():
        raise FileNotFoundError("目标目录不存在")

    # Roots are resolved too, so the comparison is made on resolved paths.
    if any(real_path.is_relative_to(root.resolve()) for root in allowed_artifact_roots):
        return real_path
    raise PermissionError("非法路径访问")
+87
View File
@@ -0,0 +1,87 @@
import os
from typing import Optional, Dict
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.registry import find_by_name
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
def normalize_provider_name(provider: Optional[str]) -> Optional[str]:
    """Return the canonical provider name.

    The name is trimmed and lower-cased; the aliases "azure" and "local" map
    to "azure_openai" and "vllm". A falsy input yields ``None``.
    """
    if not provider:
        return None
    key = provider.strip().lower()
    if key == "azure":
        return "azure_openai"
    if key == "local":
        return "vllm"
    return key
def _running_in_docker() -> bool:
# Best-effort, cross-platform detection.
if os.environ.get("DATACLAW_RUNNING_IN_DOCKER", "").strip().lower() in ("1", "true", "yes", "y"):
return True
return os.path.exists("/.dockerenv")
def _rewrite_localhost_api_base(api_base: Optional[str]) -> Optional[str]:
"""
When running inside Docker, `localhost` points to the container itself.
For host-local LLMs (Ollama/vLLM), users often configure `http://localhost:...`,
which breaks in containers. We rewrite it to `host.docker.internal`.
"""
if not api_base:
return api_base
base = api_base.strip()
if base.startswith("http://localhost") or base.startswith("https://localhost"):
return base.replace("://localhost", "://host.docker.internal", 1)
if base.startswith("http://127.0.0.1") or base.startswith("https://127.0.0.1"):
return base.replace("://127.0.0.1", "://host.docker.internal", 1)
return api_base
def build_llm_provider(
    *,
    model: str,
    provider: Optional[str] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    extra_headers: Optional[Dict[str, str]] = None,
):
    """Construct the provider object for a model.

    The provider name is normalized and looked up in the registry; the
    registry entry's backend selects the provider class, defaulting to the
    OpenAI-compatible one. Inside Docker, a loopback ``api_base`` is rewritten
    first.

    Raises:
        ValueError: If the Azure OpenAI backend is selected without both
            ``api_key`` and ``api_base``.
    """
    provider_name = normalize_provider_name(provider)
    spec = None
    if provider_name:
        spec = find_by_name(provider_name)
    backend = "openai_compat" if not spec else spec.backend

    if _running_in_docker():
        api_base = _rewrite_localhost_api_base(api_base)

    # A model prefixed "openai-codex/" selects the Codex provider even
    # without a matching registry entry.
    if backend == "openai_codex" or model.startswith("openai-codex/"):
        return OpenAICodexProvider(default_model=model)

    if backend == "azure_openai":
        if not (api_key and api_base):
            raise ValueError("Azure OpenAI requires api_key and api_base.")
        return AzureOpenAIProvider(api_key=api_key, api_base=api_base, default_model=model)

    shared_kwargs = {
        "api_key": api_key,
        "api_base": api_base,
        "default_model": model,
        "extra_headers": extra_headers,
    }
    if backend == "anthropic":
        from nanobot.providers.anthropic_provider import AnthropicProvider

        return AnthropicProvider(**shared_kwargs)
    return PatchedOpenAICompatProvider(**shared_kwargs, spec=spec)
+499
View File
@@ -0,0 +1,499 @@
import asyncio
import sys
import os
import shutil
from pathlib import Path
from typing import List, Callable, Awaitable, Any, Dict
# Make the "agent-core" directory under the project root importable so that
# the `nanobot` package imported below can be found.
# PROJECT_ROOT is three levels above the directory containing this file.
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT / "agent-core") not in sys.path:
    sys.path.append(str(PROJECT_ROOT / "agent-core"))
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.config.loader import load_config
from nanobot.cron.service import CronService
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.base import GenerationSettings
from nanobot.providers.registry import find_by_name
from nanobot.session.manager import SessionManager
from nanobot.config.schema import Config
# Import the skills loader at module level. This is safe only as long as
# app.api.skills does not import this module; if that ever changes, move the
# import into the function that needs it to avoid a circular import.
from app.api.skills import load_skills
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
from app.core.llm_provider import _rewrite_localhost_api_base, _running_in_docker
from app.services.llm_cache import get_llm_configs, get_active_llm_config
from app.services.web_search_config_store import get_web_search_config
from app.core.data_root import get_workspace_root
from app.trace import build_error_attributes, build_usage_attributes, trace_service
class NanobotIntegration:
def __init__(self):
self.agent: AgentLoop | None = None
self.bus: MessageBus | None = None
self.cron: CronService | None = None
self.config: Config | None = None
self._started = False
self._model_agent_cache: Dict[tuple[str | None, int | None], AgentLoop] = {}
self._model_agent_lock = asyncio.Lock()
self._last_usage_by_session: Dict[str, Dict[str, Any]] = {}
@staticmethod
def _normalize_config_value(value: Any) -> Any:
if isinstance(value, str):
stripped = value.strip()
return stripped or None
return value
@staticmethod
def _normalize_model_id(value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
stripped = value.strip()
return stripped or None
return str(value)
@staticmethod
def _extract_response_text(response: Any) -> str:
if response is None:
return ""
if isinstance(response, str):
return response
if isinstance(response, OutboundMessage):
return response.content or ""
if isinstance(response, dict):
content = response.get("content")
if isinstance(content, str):
return content
return str(content or "")
content = getattr(response, "content", None)
if isinstance(content, str):
return content
return str(response)
@staticmethod
def _normalize_usage(usage: Any) -> Dict[str, int] | None:
if not isinstance(usage, dict):
return None
normalized: Dict[str, int] = {}
prompt = int(usage.get("prompt_tokens", 0) or 0)
completion = int(usage.get("completion_tokens", 0) or 0)
total = int(usage.get("total_tokens", 0) or 0)
# If total_tokens is missing or zero, calculate it
if total == 0:
total = prompt + completion
normalized["prompt_tokens"] = prompt
normalized["completion_tokens"] = completion
normalized["total_tokens"] = total
return normalized if (prompt > 0 or completion > 0) else None
def get_last_usage(self, session_id: str) -> Dict[str, int] | None:
usage = self._last_usage_by_session.get(session_id)
return dict(usage) if usage else None
def _get_web_search_config(self) -> Any:
    """Build a ``WebSearchConfig`` from the stored web-search settings.

    Missing keys fall back to provider "duckduckgo", empty API key and base
    URL, and five results.
    """
    from nanobot.config.schema import WebSearchConfig

    stored = get_web_search_config()
    settings = {
        "provider": stored.get("provider", "duckduckgo"),
        "api_key": stored.get("api_key", ""),
        "base_url": stored.get("base_url", ""),
        "max_results": stored.get("max_results", 5),
    }
    return WebSearchConfig(**settings)
def _need_custom_agent_for_target(self, target_config: Dict[str, Any]) -> bool:
if not self.agent:
return False
provider = self.agent.provider
target_model = self._normalize_config_value(target_config.get("model"))
current_model = self._normalize_config_value(
getattr(self.agent, "model", None) or getattr(provider, "default_model", None)
)
if target_model != current_model:
return True
target_provider = self._normalize_config_value(target_config.get("provider"))
current_provider = self._normalize_config_value(getattr(provider, "_provider_name_override", None))
if not current_provider:
current_provider = self._normalize_config_value(getattr(getattr(provider, "_spec", None), "name", None))
if not current_provider and current_model and self.config:
current_provider = self._normalize_config_value(self.config.get_provider_name(current_model))
if target_provider != current_provider:
return True
target_api_base = self._normalize_config_value(target_config.get("api_base"))
current_api_base = self._normalize_config_value(getattr(provider, "api_base", None))
if target_api_base != current_api_base:
return True
target_api_key = self._normalize_config_value(target_config.get("api_key"))
current_api_key = self._normalize_config_value(getattr(provider, "api_key", None))
if target_api_key != current_api_key:
return True
target_headers = target_config.get("extra_headers") or {}
current_headers = getattr(provider, "extra_headers", None) or {}
return target_headers != current_headers
def initialize(self):
    """Load configuration and build the default agent and its collaborators.

    Creates the workspace directory, copies the built-in skills into it,
    loads the config, then creates the message bus, provider, cron service,
    session manager and the default ``AgentLoop``, and registers the custom
    tools on that agent. The provider and model come from the active LLM
    config when it names a model, otherwise from the loaded config defaults.
    """
    workspace_path = get_workspace_root()
    workspace_path.mkdir(parents=True, exist_ok=True)
    self._sync_builtin_skills_to_workspace(workspace_path)
    # Override config workspace path via environment variable (since config is loaded from env)
    # This must happen before load_config() so the loader sees the value.
    os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
    self.config = load_config()
    # No need to set self.config.workspace_path as it's a property that reads from agents.defaults.workspace
    self.bus = MessageBus()
    active_config = get_active_llm_config()
    initial_model = self.config.agents.defaults.model
    if active_config and active_config.get("model"):
        provider = self._make_provider_from_target(active_config)
        initial_model = self._normalize_config_value(active_config.get("model")) or initial_model
    else:
        provider = self._make_provider(self.config)
    # Cron jobs are stored in <workspace>/cron/jobs.json.
    cron_store_path = workspace_path / "cron"
    cron_store_path.mkdir(parents=True, exist_ok=True)
    cron_store_file = cron_store_path / "jobs.json"
    self.cron = CronService(cron_store_file)
    session_manager = SessionManager(self.config.workspace_path)
    self.agent = AgentLoop(
        bus=self.bus,
        provider=provider,
        workspace=self.config.workspace_path,
        model=initial_model,
        max_iterations=self.config.agents.defaults.max_tool_iterations,
        context_window_tokens=self.config.agents.defaults.context_window_tokens,
        web_search_config=self._get_web_search_config(),
        web_proxy=self.config.tools.web.proxy or None,
        exec_config=self.config.tools.exec,
        cron_service=self.cron,
        restrict_to_workspace=self.config.tools.restrict_to_workspace,
        session_manager=session_manager,
        mcp_servers=self.config.tools.mcp_servers,
        channels_config=self.config.channels,
        timezone=self.config.agents.defaults.timezone,
    )
    self._register_custom_tools(self.agent)
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
workspace_skills_root = workspace_path / "skills"
workspace_skills_root.mkdir(parents=True, exist_ok=True)
for skill_name in ("nl2sql", "visualization", "knowledge-base"):
source_dir = builtin_root / skill_name
source_skill_file = source_dir / "SKILL.md"
if not source_skill_file.exists():
continue
target_dir = workspace_skills_root / skill_name
target_dir.mkdir(parents=True, exist_ok=True)
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
def _register_custom_tools(self, agent: AgentLoop, project_id: int | None = None):
    """Register the application's tools on ``agent``.

    Args:
        agent: Agent whose tool registry receives the tools.
        project_id: Passed to the two sub-agent tools.
    """
    from app.tools.nl2sql import NL2SQLTool
    from app.tools.visualization import VisualizationTool
    from app.tools.get_schema import GetDatabaseSchemaTool
    from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
    from app.tools.subagent import ListSubagentsTool, InvokeSubagentTool

    registry = agent.tools
    registry.register(NL2SQLTool())
    registry.register(VisualizationTool())
    registry.register(GetDatabaseSchemaTool())
    registry.register(KnowledgeBaseRetrieveTool())
    registry.register(ListSubagentsTool(project_id=project_id))
    registry.register(InvokeSubagentTool(project_id=project_id))
def _build_provider(
    self,
    model: str,
    provider_name: str | None,
    api_key: str | None,
    api_base: str | None,
    extra_headers: dict[str, Any] | None = None,
):
    """Construct the provider object for a model.

    The registry entry for ``provider_name`` selects the backend, defaulting
    to the OpenAI-compatible provider. Inside Docker, a loopback ``api_base``
    is rewritten first.

    Raises:
        ValueError: If the Azure OpenAI backend is selected without both
            ``api_key`` and ``api_base``.
    """
    spec = None
    if provider_name:
        spec = find_by_name(provider_name)
    backend = "openai_compat" if not spec else spec.backend

    if _running_in_docker():
        api_base = _rewrite_localhost_api_base(api_base)

    # A model prefixed "openai-codex/" selects the Codex provider even
    # without a matching registry entry.
    if backend == "openai_codex" or model.startswith("openai-codex/"):
        return OpenAICodexProvider(default_model=model)

    if backend == "azure_openai":
        if not (api_key and api_base):
            raise ValueError("Azure OpenAI requires api_key and api_base.")
        return AzureOpenAIProvider(api_key=api_key, api_base=api_base, default_model=model)

    shared_kwargs = {
        "api_key": api_key,
        "api_base": api_base,
        "default_model": model,
        "extra_headers": extra_headers,
    }
    if backend == "anthropic":
        from nanobot.providers.anthropic_provider import AnthropicProvider

        return AnthropicProvider(**shared_kwargs)
    return PatchedOpenAICompatProvider(**shared_kwargs, spec=spec)
def _make_provider(self, config: Config):
    """Build a provider from the config's default model and attach generation settings."""
    defaults = config.agents.defaults
    model = defaults.model
    provider_name = config.get_provider_name(model)
    provider_cfg = config.get_provider(model)
    # The config may have no provider entry for the model.
    api_key = provider_cfg.api_key if provider_cfg else None
    api_base = config.get_api_base(model)
    extra_headers = provider_cfg.extra_headers if provider_cfg else None
    provider = self._build_provider(
        model=model,
        provider_name=provider_name,
        api_key=api_key,
        api_base=api_base,
        extra_headers=extra_headers,
    )
    provider.generation = GenerationSettings(
        temperature=defaults.temperature,
        max_tokens=defaults.max_tokens,
        reasoning_effort=defaults.reasoning_effort,
    )
    return provider
def _make_provider_from_target(self, target_config: Dict[str, Any]):
    """Build a provider from a stored LLM config dict.

    The model falls back to the loaded config's default model, and the
    provider name falls back to the one the loaded config derives from the
    model. Generation settings always come from the loaded config defaults.
    """
    normalize = self._normalize_config_value
    model = normalize(target_config.get("model")) or self.config.agents.defaults.model
    provider_name = normalize(target_config.get("provider"))
    if not provider_name and model and self.config:
        provider_name = normalize(self.config.get_provider_name(model))
    provider = self._build_provider(
        model=model,
        provider_name=provider_name,
        api_key=normalize(target_config.get("api_key")),
        api_base=normalize(target_config.get("api_base")),
        extra_headers=target_config.get("extra_headers"),
    )
    defaults = self.config.agents.defaults
    provider.generation = GenerationSettings(
        temperature=defaults.temperature,
        max_tokens=defaults.max_tokens,
        reasoning_effort=defaults.reasoning_effort,
    )
    return provider
async def start(self):
    """Start the agent loop and the cron service as background tasks.

    Does nothing when already started. Initializes the integration first
    when no agent has been built yet. The created tasks are stored on
    ``self._background_tasks``.
    """
    if self._started:
        return
    if not self.agent:
        self.initialize()
    # The event loop keeps only weak references to tasks; hold strong
    # references so the tasks cannot be garbage-collected while running.
    self._background_tasks = [
        asyncio.create_task(self.agent.run()),
        asyncio.create_task(self.cron.start()),
    ]
    self._started = True
async def stop(self):
    """Stop the default agent, every cached agent and the cron service.

    Each agent is stopped and its MCP connections closed; the agent cache is
    then emptied and the started flag cleared.
    """
    default_agent = self.agent
    if default_agent:
        default_agent.stop()
        await default_agent.close_mcp()
    for cached_agent in self._model_agent_cache.values():
        cached_agent.stop()
        await cached_agent.close_mcp()
    self._model_agent_cache.clear()
    cron = self.cron
    if cron:
        cron.stop()
    self._started = False
def _build_agent_for_provider(self, provider: Any, mcp_servers: dict | None = None) -> AgentLoop:
    """Create an ``AgentLoop`` for ``provider`` using the loaded configuration.

    The new agent shares the bus, cron service and (when a default agent
    exists) the default agent's session manager.

    Args:
        provider: Provider to drive the agent; its ``default_model`` is used.
        mcp_servers: MCP servers for the agent; ``None`` means the servers
            from the loaded config.
    """
    defaults = self.config.agents.defaults
    tools_cfg = self.config.tools
    shared_sessions = self.agent.sessions if self.agent else None
    if mcp_servers is None:
        mcp_servers = tools_cfg.mcp_servers
    return AgentLoop(
        bus=self.bus,
        provider=provider,
        workspace=self.config.workspace_path,
        model=provider.default_model,
        max_iterations=defaults.max_tool_iterations,
        context_window_tokens=defaults.context_window_tokens,
        web_search_config=self._get_web_search_config(),
        web_proxy=tools_cfg.web.proxy or None,
        exec_config=tools_cfg.exec,
        cron_service=self.cron,
        restrict_to_workspace=tools_cfg.restrict_to_workspace,
        session_manager=shared_sessions,
        mcp_servers=mcp_servers,
        channels_config=self.config.channels,
        timezone=defaults.timezone,
    )
async def _get_or_create_model_agent(self, model_id: str | None, target_config: Dict[str, Any] | None, project_id: int | None = None) -> AgentLoop:
    """Return the cached agent for (model id, project id), creating it if needed.

    Args:
        model_id: Model identifier; normalized before use as a cache key.
        target_config: LLM config used to build the provider; when falsy the
            provider is built from the loaded config defaults.
        project_id: When given, the project's MCP servers are added on top of
            the configured ones and passed to the sub-agent tools.

    Returns:
        The cached or newly created agent.
    """
    normalized_model_id = self._normalize_model_id(model_id)
    cache_key = (normalized_model_id, project_id)
    async with self._model_agent_lock:
        cached = self._model_agent_cache.get(cache_key)
        if cached:
            return cached
        if target_config:
            provider = self._make_provider_from_target(target_config)
        else:
            provider = self._make_provider(self.config)
        # Copy so that project servers do not leak into the shared config.
        mcp_servers_dict = dict(self.config.tools.mcp_servers) if self.config.tools.mcp_servers else {}
        if project_id is not None:
            from app.api.mcp import list_mcp_servers
            from nanobot.config.schema import MCPServerConfig
            servers = await list_mcp_servers(project_id=project_id)
            for s in servers:
                cfg = MCPServerConfig(
                    type=s.get("type"),
                    command=s.get("command") or "",
                    args=s.get("args") or [],
                    env=s.get("env") or {},
                    url=s.get("url") or "",
                    headers=s.get("headers") or {}
                )
                # A project server replaces a configured server of the same name.
                mcp_servers_dict[s["name"]] = cfg
        agent = self._build_agent_for_provider(provider, mcp_servers=mcp_servers_dict)
        self._register_custom_tools(agent, project_id=project_id)
        self._model_agent_cache[cache_key] = agent
        return agent
async def process_message(
    self,
    message: str,
    session_id: str = "api:default",
    skill_ids: List[str] | None = None,
    model_id: str | None = None,
    project_id: int | None = None,
    on_progress: Callable[[str], Awaitable[None]] | None = None,
    on_stream: Callable[[str], Awaitable[None]] | None = None,
):
    """Run one user message through an agent and return the reply text.

    Initializes and starts the integration on first use, picks the agent
    (the default one, or a cached per-model/per-project agent), normalizes
    the stored session messages, runs the message, and records token usage
    for the session. The whole call is wrapped in trace spans.

    Args:
        message: User message text.
        session_id: Session key; also used as the chat id.
        skill_ids: Accepted but not used in this method body.
        model_id: Id of the LLM config to use; when it matches no config,
            the active LLM config is used if it has an id.
        project_id: Project scope; when omitted it is looked up from the
            session alias store.
        on_progress: Progress callback forwarded to the agent.
        on_stream: Streaming callback forwarded to the agent.

    Returns:
        The reply text extracted from the agent's response.

    Raises:
        Exception: Any error is recorded on the root span and re-raised.
    """
    span_attributes = {
        "session_id": session_id,
        "project_id": project_id,
        "model_id": model_id,
        "component": "nanobot.process_message",
    }
    with trace_service.start_span(
        "nanobot.process_message",
        attributes=span_attributes,
        input_payload={"message": message},
    ) as root_span:
        try:
            if not self.agent:
                self.initialize()
            if not self._started:
                await self.start()
            if project_id is None:
                # Imported here rather than at module level.
                from app.core.session_alias_store import session_alias_store
                alias_meta = session_alias_store.get_alias_meta(session_id)
                if alias_meta and alias_meta.get("project_id") is not None:
                    project_id = alias_meta.get("project_id")
                    root_span.set_attributes({"project_id": project_id})
            agent_to_use = self.agent
            need_custom_agent = False
            target_config = None
            selected_model_id = self._normalize_model_id(model_id)
            if selected_model_id:
                llm_configs = get_llm_configs()
                target_config = next(
                    (item for item in llm_configs if self._normalize_model_id(item.get("id")) == selected_model_id),
                    None,
                )
            # Fall back to the active config when no config was selected or matched.
            if target_config is None:
                active_config = get_active_llm_config()
                if active_config and active_config.get("id"):
                    selected_model_id = self._normalize_model_id(active_config.get("id"))
                    target_config = active_config
            if target_config and self._need_custom_agent_for_target(target_config):
                need_custom_agent = True
            # A project scope always uses a dedicated agent.
            if project_id is not None:
                need_custom_agent = True
            with trace_service.start_span(
                "nanobot.resolve_agent",
                attributes={
                    "session_id": session_id,
                    "project_id": project_id,
                    "selected_model_id": selected_model_id,
                    "custom_agent": need_custom_agent,
                },
            ):
                if need_custom_agent:
                    agent_to_use = await self._get_or_create_model_agent(selected_model_id, target_config, project_id)
            session = agent_to_use.sessions.get_or_create(session_id)
            normalized_messages = self._normalize_session_messages(session.messages)
            # Persist only when normalization changed the number of messages.
            if len(normalized_messages) != len(session.messages):
                session.messages = normalized_messages
                agent_to_use.sessions.save(session)
            with trace_service.start_span(
                "nanobot.process_direct",
                attributes={
                    "session_id": session_id,
                    "model": getattr(agent_to_use, "model", None),
                },
            ) as direct_span:
                response = await agent_to_use.process_direct(
                    message,
                    session_key=session_id,
                    channel="api",
                    chat_id=session_id,
                    on_progress=on_progress,
                    on_stream=on_stream,
                )
                usage = self._normalize_usage(getattr(agent_to_use, "_last_usage", None))
                if usage:
                    self._last_usage_by_session[session_id] = usage
                    direct_span.set_attributes(build_usage_attributes(usage))
                    root_span.set_attributes(build_usage_attributes(usage))
                text = self._extract_response_text(response)
                direct_span.update(output={"content": text})
                root_span.update(output={"content": text})
                return text
        except Exception as exc:
            root_span.set_attributes(build_error_attributes(exc, stage="nanobot_process_message"))
            root_span.record_error(exc, stage="nanobot_process_message")
            raise
def _normalize_session_messages(self, messages: List[Any]) -> List[dict[str, Any]]:
normalized: List[dict[str, Any]] = []
stack: List[Any] = list(messages)
while stack:
current = stack.pop(0)
if isinstance(current, dict):
normalized.append(current)
continue
if isinstance(current, list):
stack = list(current) + stack
return normalized
# Shared module-level instance. Construction only sets attributes; the agent
# is built later by initialize().
nanobot_service = NanobotIntegration()
@@ -0,0 +1,43 @@
from __future__ import annotations
from typing import Any
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class PatchedOpenAICompatProvider(OpenAICompatProvider):
    """OpenAI-compatible provider that sends ``max_completion_tokens`` where needed.

    The request kwargs built by the base class are post-processed: when the
    provider spec opts in, or the model name contains one of the markers
    below, the ``max_tokens`` entry is renamed to ``max_completion_tokens``.
    """

    # Substrings of (lower-cased) model names that trigger the rename.
    _MAX_COMPLETION_TOKEN_MODELS = ("gpt-5", "o1", "o3", "o4")

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Build request kwargs, renaming ``max_tokens`` when the model requires it."""
        kwargs = super()._build_kwargs(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        lowered_model = (model or self.default_model or "").lower()
        spec = self._spec
        spec_opts_in = bool(spec and getattr(spec, "supports_max_completion_tokens", False))
        rename = spec_opts_in or any(
            marker in lowered_model for marker in self._MAX_COMPLETION_TOKEN_MODELS
        )
        if rename and "max_tokens" in kwargs:
            kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
        return kwargs
+56
View File
@@ -0,0 +1,56 @@
import os
from datetime import datetime, timedelta
from typing import Optional

from fastapi import HTTPException, Depends, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import BaseModel
# JWT signing secret. Set the SECRET_KEY environment variable in any real
# deployment; the literal fallback is kept only so existing setups keep
# working and must not be relied on, since it is stored in the source code.
SECRET_KEY = os.getenv("SECRET_KEY") or "your-super-secret-key-for-dataclaw"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60  # 30 days
# Password hashing context (PBKDF2-SHA256).
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
# Bearer-token scheme used as a FastAPI dependency by get_current_user.
security = HTTPBearer()
class CurrentUser(BaseModel):
    """Identity of the authenticated caller, as decoded from the access token."""

    # Taken from the token's "id" claim.
    id: int
    # Taken from the token's "sub" claim.
    username: str
    # Taken from the token's "is_admin" claim; False when absent.
    is_admin: bool = False
def verify_password(plain_password, hashed_password):
    """Return the result of checking ``plain_password`` against ``hashed_password``."""
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def get_password_hash(password):
    """Return the hash of ``password`` produced by the module's password context."""
    hashed = pwd_context.hash(password)
    return hashed
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Return a signed JWT carrying ``data`` plus an ``exp`` expiry claim.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Lifetime of the token. When omitted (``None``) the
            token expires 15 minutes from now.

    Returns:
        The value produced by ``jwt.encode`` using ``SECRET_KEY`` and
        ``ALGORITHM``.
    """
    # Local import: the module-level import only brings in datetime/timedelta.
    from datetime import timezone

    # Compare against None rather than truthiness so that an explicit
    # zero-length lifetime is honoured instead of silently becoming 15 minutes.
    lifetime = timedelta(minutes=15) if expires_delta is None else expires_delta

    # Use a timezone-aware UTC timestamp; datetime.utcnow() returns a naive
    # value and is deprecated since Python 3.12.
    expire = datetime.now(timezone.utc) + lifetime

    to_encode = data.copy()
    to_encode.update({"exp": expire})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
    """Decode the bearer token and return the caller it identifies.

    Args:
        credentials: Bearer credentials supplied by the ``security`` dependency.

    Returns:
        A ``CurrentUser`` built from the token's ``id``, ``sub`` and
        ``is_admin`` claims (``is_admin`` defaults to ``False``).

    Raises:
        HTTPException: 401 when the token cannot be decoded, when the ``id``
            or ``sub`` claim is missing, or when the claims do not validate
            as a ``CurrentUser``.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        # Tell the client which authentication scheme is expected.
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise unauthorized from exc

    user_id = payload.get("id")
    username = payload.get("sub")
    is_admin = bool(payload.get("is_admin", False))
    if user_id is None or username is None:
        raise unauthorized

    try:
        return CurrentUser(id=user_id, username=username, is_admin=is_admin)
    except ValueError as exc:
        # pydantic's ValidationError derives from ValueError; a token whose
        # claims have the wrong type is an invalid credential, not a server error.
        raise unauthorized from exc
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Return the authenticated user, but only if they are an administrator.

    Raises:
        HTTPException: 403 when ``current_user.is_admin`` is false.
    """
    if current_user.is_admin:
        return current_user
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Admin permission required",
    )
@@ -0,0 +1,215 @@
from __future__ import annotations

import sqlite3
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from app.core.data_root import get_data_root
class SessionAliasStore:
def __init__(self) -> None:
data_dir = get_data_root()
try:
data_dir.mkdir(parents=True, exist_ok=True)
except PermissionError as exc:
raise RuntimeError(f"DATA_ROOT 权限不足: {data_dir}") from exc
self.db_path = data_dir / "nanobot_sessions.db"
try:
self._init_db()
except PermissionError as exc:
raise RuntimeError(f"DATA_ROOT 权限不足: {data_dir}") from exc
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(str(self.db_path))
conn.row_factory = sqlite3.Row
return conn
def _init_db(self) -> None:
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS session_cache (
session_key TEXT PRIMARY KEY,
created_at TEXT,
updated_at TEXT,
alias TEXT,
pinned INTEGER NOT NULL DEFAULT 0,
archived INTEGER NOT NULL DEFAULT 0,
last_seen_at TEXT NOT NULL
)
"""
)
cols = {
str(row["name"])
for row in conn.execute("PRAGMA table_info(session_cache)").fetchall()
}
if "pinned" not in cols:
conn.execute("ALTER TABLE session_cache ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0")
if "archived" not in cols:
conn.execute("ALTER TABLE session_cache ADD COLUMN archived INTEGER NOT NULL DEFAULT 0")
if "project_id" not in cols:
conn.execute("ALTER TABLE session_cache ADD COLUMN project_id INTEGER")
def sync_sessions(self, sessions: list[dict[str, Any]]) -> None:
now = datetime.now(timezone.utc).isoformat()
keys: list[str] = []
with self._connect() as conn:
for item in sessions:
key = str(item.get("key") or "").strip()
if not key:
continue
keys.append(key)
created_at = str(item.get("created_at") or "")
updated_at = str(item.get("updated_at") or "")
conn.execute(
"""
INSERT INTO session_cache (session_key, created_at, updated_at, last_seen_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(session_key) DO UPDATE SET
created_at = excluded.created_at,
updated_at = excluded.updated_at,
last_seen_at = excluded.last_seen_at
""",
(key, created_at, updated_at, now),
)
if keys:
placeholders = ",".join("?" for _ in keys)
conn.execute(
f"DELETE FROM session_cache WHERE session_key NOT IN ({placeholders})",
keys,
)
else:
conn.execute("DELETE FROM session_cache")
def list_cached_sessions(self, project_id: int | None = None) -> list[dict[str, Any]]:
with self._connect() as conn:
if project_id is not None:
rows = conn.execute(
"""
SELECT session_key, created_at, updated_at, alias, pinned, archived, project_id
FROM session_cache
WHERE project_id = ? OR project_id IS NULL
ORDER BY pinned DESC, archived ASC, updated_at DESC
""",
(project_id,)
).fetchall()
else:
rows = conn.execute(
"""
SELECT session_key, created_at, updated_at, alias, pinned, archived, project_id
FROM session_cache
ORDER BY pinned DESC, archived ASC, updated_at DESC
"""
).fetchall()
return [self._row_to_session_item(row) for row in rows]
def sync_and_list(self, sessions: list[dict[str, Any]], project_id: int | None = None) -> list[dict[str, Any]]:
self.sync_sessions(sessions)
return self.list_cached_sessions(project_id)
def set_alias(self, session_key: str, alias: str) -> None:
now = datetime.now(timezone.utc).isoformat()
clean_alias = alias.strip()
with self._connect() as conn:
conn.execute(
"""
INSERT INTO session_cache (session_key, created_at, updated_at, alias, last_seen_at)
VALUES (?, '', '', ?, ?)
ON CONFLICT(session_key) DO UPDATE SET
alias = excluded.alias,
last_seen_at = excluded.last_seen_at
""",
(session_key, clean_alias, now),
)
def update_alias_meta(
self,
session_key: str,
alias: str | None = None,
pinned: bool | None = None,
archived: bool | None = None,
project_id: int | None = None,
) -> dict[str, Any]:
now = datetime.now(timezone.utc).isoformat()
with self._connect() as conn:
row = conn.execute(
"SELECT alias, pinned, archived, project_id FROM session_cache WHERE session_key = ?",
(session_key,),
).fetchone()
current_alias = (str(row["alias"]) if row and row["alias"] else "")
current_pinned = bool(row["pinned"]) if row else False
current_archived = bool(row["archived"]) if row else False
current_project_id = row["project_id"] if row and "project_id" in row.keys() else None
next_alias = current_alias if alias is None else alias.strip()
next_pinned = current_pinned if pinned is None else bool(pinned)
next_archived = current_archived if archived is None else bool(archived)
next_project_id = current_project_id if project_id is None else project_id
conn.execute(
"""
INSERT INTO session_cache (session_key, created_at, updated_at, alias, pinned, archived, project_id, last_seen_at)
VALUES (?, '', '', ?, ?, ?, ?, ?)
ON CONFLICT(session_key) DO UPDATE SET
alias = excluded.alias,
pinned = excluded.pinned,
archived = excluded.archived,
project_id = excluded.project_id,
last_seen_at = excluded.last_seen_at
""",
(session_key, next_alias, int(next_pinned), int(next_archived), next_project_id, now),
)
return {"alias": next_alias or None, "pinned": next_pinned, "archived": next_archived, "project_id": next_project_id}
def get_alias(self, session_key: str) -> str | None:
with self._connect() as conn:
row = conn.execute(
"SELECT alias FROM session_cache WHERE session_key = ?",
(session_key,),
).fetchone()
if not row:
return None
alias = row["alias"]
return str(alias) if alias else None
def get_alias_meta(self, session_key: str) -> dict[str, Any] | None:
with self._connect() as conn:
row = conn.execute(
"SELECT alias, pinned, archived, project_id FROM session_cache WHERE session_key = ?",
(session_key,),
).fetchone()
if not row:
return None
alias = (row["alias"] or "").strip()
return {
"alias": alias or None,
"pinned": bool(row["pinned"]) if "pinned" in row.keys() else False,
"archived": bool(row["archived"]) if "archived" in row.keys() else False,
"project_id": row["project_id"] if "project_id" in row.keys() else None,
}
def delete_session(self, session_key: str) -> None:
with self._connect() as conn:
conn.execute("DELETE FROM session_cache WHERE session_key = ?", (session_key,))
def _row_to_session_item(self, row: sqlite3.Row) -> dict[str, Any]:
alias = (row["alias"] or "").strip()
fallback = str(row["session_key"]).replace("api:", "")
title = alias or fallback
pinned = bool(row["pinned"]) if "pinned" in row.keys() else False
archived = bool(row["archived"]) if "archived" in row.keys() else False
project_id = row["project_id"] if "project_id" in row.keys() else None
return {
"key": row["session_key"],
"created_at": row["created_at"],
"updated_at": row["updated_at"],
"metadata": {"title": title},
"alias": alias or None,
"pinned": pinned,
"archived": archived,
"project_id": project_id,
}
# Shared module-level store; constructing it at import time creates the data
# directory and database schema (see SessionAliasStore.__init__).
session_alias_store = SessionAliasStore()