Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,203 @@
|
||||
import mimetypes
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
from urllib.parse import quote
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.data_root import get_data_root, get_reports_root, get_uploads_root, get_workspace_root
|
||||
|
||||
LOCAL_URI_PATTERN = re.compile(r"local://[^\s<>'\"\]\)\}]+")
|
||||
PATH_PATTERN = re.compile(
|
||||
r"(?:[A-Za-z]:[\\/][^\s<>'\"]+\.[A-Za-z0-9]{1,12}|/[^\s<>'\"]+\.[A-Za-z0-9]{1,12}|(?:\.\./|\.?/)?(?:[\w\-.]+[\\/])+[\w\-.]+\.[A-Za-z0-9]{1,12})"
|
||||
)
|
||||
REPORT_PATH_PATTERN = re.compile(r"data[\\/]data[\\/][\w\-.]+\.[A-Za-z0-9]{1,12}", re.IGNORECASE)
|
||||
PREVIEWABLE_EXTENSIONS = {
|
||||
".html",
|
||||
".htm",
|
||||
".pdf",
|
||||
".pptx",
|
||||
".txt",
|
||||
".md",
|
||||
".json",
|
||||
".csv",
|
||||
".tsv",
|
||||
".yaml",
|
||||
".yml",
|
||||
".xml",
|
||||
".log",
|
||||
}
|
||||
|
||||
|
||||
class ArtifactPayload(BaseModel):
|
||||
name: str
|
||||
mime_type: str
|
||||
size: int
|
||||
download_url: str
|
||||
previewable: bool
|
||||
preview_url: str | None = None
|
||||
|
||||
|
||||
def extract_artifacts(content: str, session_messages: list[dict[str, Any]] | None = None) -> list[dict[str, Any]]:
|
||||
candidates = _collect_candidate_texts(content, session_messages or [])
|
||||
ordered_locators: list[str] = []
|
||||
seen_locators: set[str] = set()
|
||||
for text in candidates:
|
||||
for locator in _extract_locators(text):
|
||||
if locator in seen_locators:
|
||||
continue
|
||||
seen_locators.add(locator)
|
||||
ordered_locators.append(locator)
|
||||
artifacts: list[dict[str, Any]] = []
|
||||
seen_paths: set[Path] = set()
|
||||
for locator in ordered_locators:
|
||||
path = _resolve_locator(locator)
|
||||
if not path or not path.exists() or not path.is_file():
|
||||
continue
|
||||
resolved = path.resolve()
|
||||
if resolved in seen_paths:
|
||||
continue
|
||||
seen_paths.add(resolved)
|
||||
artifact = _build_artifact_payload(locator, resolved)
|
||||
artifacts.append(artifact.model_dump(exclude_none=True))
|
||||
return artifacts
|
||||
|
||||
|
||||
def _build_artifact_payload(locator: str, path: Path) -> ArtifactPayload:
|
||||
mime_type = _guess_mime_type(path)
|
||||
previewable = _is_previewable(path, mime_type)
|
||||
encoded = quote(locator, safe="")
|
||||
preview_url = f"/nanobot/artifacts/preview?target={encoded}" if previewable else None
|
||||
return ArtifactPayload(
|
||||
name=path.name,
|
||||
mime_type=mime_type,
|
||||
size=path.stat().st_size,
|
||||
download_url=f"/nanobot/artifacts/download?target={encoded}",
|
||||
previewable=previewable,
|
||||
preview_url=preview_url,
|
||||
)
|
||||
|
||||
|
||||
def _guess_mime_type(path: Path) -> str:
|
||||
mime_type, _ = mimetypes.guess_type(path.name)
|
||||
return mime_type or "application/octet-stream"
|
||||
|
||||
|
||||
def _is_previewable(path: Path, mime_type: str) -> bool:
|
||||
if mime_type.startswith("image/") or mime_type.startswith("text/"):
|
||||
return True
|
||||
extension = path.suffix.lower()
|
||||
if extension in PREVIEWABLE_EXTENSIONS:
|
||||
return True
|
||||
return mime_type in {
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
}
|
||||
|
||||
|
||||
def _collect_candidate_texts(content: str, session_messages: list[dict[str, Any]]) -> list[str]:
|
||||
texts = [content or ""]
|
||||
if not session_messages:
|
||||
return texts
|
||||
last_user_idx = -1
|
||||
for idx, message in enumerate(session_messages):
|
||||
if message.get("role") == "user":
|
||||
last_user_idx = idx
|
||||
if last_user_idx == -1:
|
||||
segment = session_messages
|
||||
else:
|
||||
segment = session_messages[last_user_idx + 1 :]
|
||||
for message in segment:
|
||||
raw = message.get("content")
|
||||
flattened = _flatten_content(raw)
|
||||
if flattened:
|
||||
texts.append(flattened)
|
||||
return texts
|
||||
|
||||
|
||||
def _extract_locators(text: str) -> Iterable[str]:
|
||||
if not text:
|
||||
return []
|
||||
ordered: list[str] = []
|
||||
seen: set[str] = set()
|
||||
patterns = (LOCAL_URI_PATTERN, REPORT_PATH_PATTERN, PATH_PATTERN)
|
||||
for pattern in patterns:
|
||||
for match in pattern.findall(text):
|
||||
normalized = _normalize_locator(match)
|
||||
if not normalized or normalized in seen:
|
||||
continue
|
||||
seen.add(normalized)
|
||||
ordered.append(normalized)
|
||||
return ordered
|
||||
|
||||
|
||||
def _normalize_locator(raw_locator: str) -> str:
|
||||
locator = raw_locator.strip().strip("`'\"")
|
||||
locator = locator.rstrip(".,;:!?)]}")
|
||||
return locator
|
||||
|
||||
|
||||
def _resolve_locator(locator: str) -> Path | None:
|
||||
data_root = get_data_root()
|
||||
workspace_root = get_workspace_root()
|
||||
uploads_root = get_uploads_root()
|
||||
reports_root = get_reports_root()
|
||||
repo_root = data_root.parent
|
||||
if locator.startswith("local://"):
|
||||
raw_local = locator.replace("local://", "", 1).strip().lstrip("/\\")
|
||||
if not raw_local:
|
||||
return None
|
||||
candidate = Path(raw_local)
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
checks = [workspace_root / candidate, reports_root / candidate, uploads_root / candidate, uploads_root / candidate.name]
|
||||
for path in checks:
|
||||
if path.exists():
|
||||
return path
|
||||
return uploads_root / candidate.name
|
||||
normalized = locator.replace("\\", "/")
|
||||
path = Path(locator)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
if normalized.startswith("data/data/"):
|
||||
return repo_root / normalized
|
||||
checks = [
|
||||
workspace_root / normalized,
|
||||
data_root / normalized,
|
||||
repo_root / normalized,
|
||||
]
|
||||
for candidate in checks:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _flatten_content(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
fragments: list[str] = []
|
||||
for item in value:
|
||||
flattened = _flatten_content(item)
|
||||
if flattened:
|
||||
fragments.append(flattened)
|
||||
return "\n".join(fragments)
|
||||
if isinstance(value, dict):
|
||||
fragments: list[str] = []
|
||||
text = value.get("text")
|
||||
if isinstance(text, str):
|
||||
fragments.append(text)
|
||||
content = value.get("content")
|
||||
if content is not None:
|
||||
nested = _flatten_content(content)
|
||||
if nested:
|
||||
fragments.append(nested)
|
||||
for field in ("path", "file", "file_path", "url"):
|
||||
data = value.get(field)
|
||||
if isinstance(data, str):
|
||||
fragments.append(data)
|
||||
return "\n".join(fragments)
|
||||
return str(value)
|
||||
@@ -0,0 +1,39 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[2]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
DEFAULT_DATA_ROOT = REPO_ROOT / "data"
|
||||
LEGACY_DATA_ROOT = BACKEND_ROOT / "data"
|
||||
|
||||
|
||||
def get_data_root() -> Path:
|
||||
configured = (os.getenv("DATA_ROOT") or "").strip()
|
||||
if configured:
|
||||
return Path(configured).expanduser().resolve()
|
||||
if DEFAULT_DATA_ROOT.exists():
|
||||
return DEFAULT_DATA_ROOT
|
||||
if LEGACY_DATA_ROOT.exists():
|
||||
print(f"[DATA_ROOT] legacy path detected: {LEGACY_DATA_ROOT}. Please migrate to {DEFAULT_DATA_ROOT}.")
|
||||
return LEGACY_DATA_ROOT
|
||||
return DEFAULT_DATA_ROOT
|
||||
|
||||
|
||||
def get_workspace_root() -> Path:
|
||||
return get_data_root() / "workspace"
|
||||
|
||||
|
||||
def get_uploads_root() -> Path:
|
||||
return get_data_root() / "uploads"
|
||||
|
||||
|
||||
def get_reports_root() -> Path:
|
||||
return get_data_root() / "data"
|
||||
|
||||
|
||||
def ensure_data_layout() -> None:
|
||||
get_data_root().mkdir(parents=True, exist_ok=True)
|
||||
get_workspace_root().mkdir(parents=True, exist_ok=True)
|
||||
get_uploads_root().mkdir(parents=True, exist_ok=True)
|
||||
get_reports_root().mkdir(parents=True, exist_ok=True)
|
||||
@@ -0,0 +1,43 @@
|
||||
import smtplib
|
||||
import os
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
|
||||
def send_verification_email(to_email: str, token: str):
|
||||
smtp_host = os.getenv("SMTP_HOST", "smtp.qq.com")
|
||||
smtp_port = int(os.getenv("SMTP_PORT", "465"))
|
||||
smtp_user = os.getenv("SMTP_USER", "")
|
||||
smtp_password = os.getenv("SMTP_PASSWORD", "")
|
||||
frontend_url = os.getenv("FRONTEND_URL", "http://localhost:5173")
|
||||
|
||||
if not smtp_user or not smtp_password:
|
||||
print("SMTP configuration is missing. Skip sending email.")
|
||||
return
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg['From'] = smtp_user
|
||||
msg['To'] = to_email
|
||||
msg['Subject'] = "请验证你的邮箱地址"
|
||||
|
||||
verify_link = f"{frontend_url}/verify-email?token={token}"
|
||||
body = f"""
|
||||
<html>
|
||||
<body>
|
||||
<h2>欢迎使用全源灵动!</h2>
|
||||
<p>请点击下方链接验证邮箱并激活账号:</p>
|
||||
<p><a href="{verify_link}">{verify_link}</a></p>
|
||||
<p>如果你没有发起该请求,请忽略此邮件。</p>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
msg.attach(MIMEText(body, 'html'))
|
||||
|
||||
try:
|
||||
# Use SMTP_SSL for port 465
|
||||
server = smtplib.SMTP_SSL(smtp_host, smtp_port)
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
server.quit()
|
||||
print(f"Verification email sent to {to_email}")
|
||||
except Exception as e:
|
||||
print(f"Failed to send email: {e}")
|
||||
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from app.core.data_root import (
|
||||
BACKEND_ROOT,
|
||||
LEGACY_DATA_ROOT,
|
||||
get_data_root,
|
||||
get_reports_root,
|
||||
get_uploads_root,
|
||||
get_workspace_root,
|
||||
)
|
||||
|
||||
|
||||
data_root = get_data_root()
|
||||
workspace_root = get_workspace_root()
|
||||
uploads_root = get_uploads_root()
|
||||
reports_root = get_reports_root()
|
||||
legacy_workspace_root = LEGACY_DATA_ROOT / "workspace"
|
||||
legacy_uploads_root = LEGACY_DATA_ROOT / "uploads"
|
||||
legacy_reports_root = LEGACY_DATA_ROOT / "data"
|
||||
backend_root = BACKEND_ROOT
|
||||
allowed_artifact_roots = (
|
||||
workspace_root,
|
||||
uploads_root,
|
||||
reports_root,
|
||||
legacy_workspace_root,
|
||||
legacy_uploads_root,
|
||||
legacy_reports_root,
|
||||
)
|
||||
|
||||
|
||||
def resolve_upload_file_path(file_url: Optional[str]) -> Path:
|
||||
if not file_url:
|
||||
raise ValueError("File URL is empty")
|
||||
|
||||
if file_url.startswith("local://"):
|
||||
raw_name = file_url.replace("local://", "", 1)
|
||||
safe_name = os.path.basename(raw_name)
|
||||
file_path = uploads_root / safe_name
|
||||
return file_path
|
||||
|
||||
return Path(file_url)
|
||||
|
||||
|
||||
def resolve_artifact_target(target: str) -> Path | None:
|
||||
locator = (target or "").strip().strip("'\"")
|
||||
if not locator:
|
||||
return None
|
||||
if locator.startswith("local://"):
|
||||
raw_local = locator.replace("local://", "", 1).strip().lstrip("/\\")
|
||||
if not raw_local:
|
||||
return None
|
||||
candidate = Path(raw_local)
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
checks = (
|
||||
workspace_root / candidate,
|
||||
reports_root / candidate,
|
||||
uploads_root / candidate,
|
||||
uploads_root / candidate.name,
|
||||
)
|
||||
for path in checks:
|
||||
if path.exists():
|
||||
return path
|
||||
return uploads_root / candidate.name
|
||||
normalized = locator.replace("\\", "/")
|
||||
path = Path(locator)
|
||||
if path.is_absolute():
|
||||
return path
|
||||
if normalized.startswith("data/data/"):
|
||||
return data_root.parent / normalized
|
||||
checks = (
|
||||
workspace_root / normalized,
|
||||
data_root / normalized,
|
||||
backend_root / normalized,
|
||||
)
|
||||
for candidate in checks:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def ensure_artifact_access(path: Path, *, require_file: bool = True) -> Path:
|
||||
try:
|
||||
resolved = path.resolve(strict=True)
|
||||
except FileNotFoundError as exc:
|
||||
raise FileNotFoundError("目标文件不存在") from exc
|
||||
if require_file and not resolved.is_file():
|
||||
raise FileNotFoundError("目标文件不存在")
|
||||
if not require_file and not resolved.is_dir():
|
||||
raise FileNotFoundError("目标目录不存在")
|
||||
for root in allowed_artifact_roots:
|
||||
if resolved.is_relative_to(root.resolve()):
|
||||
return resolved
|
||||
raise PermissionError("非法路径访问")
|
||||
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
|
||||
|
||||
def normalize_provider_name(provider: Optional[str]) -> Optional[str]:
|
||||
if not provider:
|
||||
return None
|
||||
normalized = provider.strip().lower()
|
||||
alias_map = {
|
||||
"azure": "azure_openai",
|
||||
"local": "vllm",
|
||||
}
|
||||
return alias_map.get(normalized, normalized)
|
||||
|
||||
|
||||
def _running_in_docker() -> bool:
|
||||
# Best-effort, cross-platform detection.
|
||||
if os.environ.get("DATACLAW_RUNNING_IN_DOCKER", "").strip().lower() in ("1", "true", "yes", "y"):
|
||||
return True
|
||||
return os.path.exists("/.dockerenv")
|
||||
|
||||
|
||||
def _rewrite_localhost_api_base(api_base: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
When running inside Docker, `localhost` points to the container itself.
|
||||
For host-local LLMs (Ollama/vLLM), users often configure `http://localhost:...`,
|
||||
which breaks in containers. We rewrite it to `host.docker.internal`.
|
||||
"""
|
||||
if not api_base:
|
||||
return api_base
|
||||
base = api_base.strip()
|
||||
if base.startswith("http://localhost") or base.startswith("https://localhost"):
|
||||
return base.replace("://localhost", "://host.docker.internal", 1)
|
||||
if base.startswith("http://127.0.0.1") or base.startswith("https://127.0.0.1"):
|
||||
return base.replace("://127.0.0.1", "://host.docker.internal", 1)
|
||||
return api_base
|
||||
|
||||
|
||||
def build_llm_provider(
|
||||
*,
|
||||
model: str,
|
||||
provider: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
provider_name = normalize_provider_name(provider)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
if _running_in_docker():
|
||||
api_base = _rewrite_localhost_api_base(api_base)
|
||||
|
||||
if backend == "openai_codex" or model.startswith("openai-codex/"):
|
||||
return OpenAICodexProvider(default_model=model)
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not api_key or not api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base.")
|
||||
return AzureOpenAIProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
if backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
return AnthropicProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
return PatchedOpenAICompatProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
extra_headers=extra_headers,
|
||||
spec=spec,
|
||||
)
|
||||
@@ -0,0 +1,499 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import List, Callable, Awaitable, Any, Dict
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
# Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root
|
||||
# This path calculation seems correct for backend/app/core/nanobot.py relative to backend/
|
||||
# BUT nanobot package is in ../nanobot relative to backend/
|
||||
# So we need to go up one more level to reach the parent of backend/
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT / "agent-core") not in sys.path:
|
||||
sys.path.append(str(PROJECT_ROOT / "agent-core"))
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.session.manager import SessionManager
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
# Import skills loader
|
||||
# We use a lazy import inside the method to avoid potential circular dependencies if any arise,
|
||||
# or just import here if we are confident.
|
||||
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
|
||||
from app.api.skills import load_skills
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
from app.core.llm_provider import _rewrite_localhost_api_base, _running_in_docker
|
||||
from app.services.llm_cache import get_llm_configs, get_active_llm_config
|
||||
from app.services.web_search_config_store import get_web_search_config
|
||||
|
||||
from app.core.data_root import get_workspace_root
|
||||
from app.trace import build_error_attributes, build_usage_attributes, trace_service
|
||||
|
||||
class NanobotIntegration:
|
||||
def __init__(self):
|
||||
self.agent: AgentLoop | None = None
|
||||
self.bus: MessageBus | None = None
|
||||
self.cron: CronService | None = None
|
||||
self.config: Config | None = None
|
||||
self._started = False
|
||||
self._model_agent_cache: Dict[tuple[str | None, int | None], AgentLoop] = {}
|
||||
self._model_agent_lock = asyncio.Lock()
|
||||
self._last_usage_by_session: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_config_value(value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_id(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
return str(value)
|
||||
|
||||
@staticmethod
|
||||
def _extract_response_text(response: Any) -> str:
|
||||
if response is None:
|
||||
return ""
|
||||
if isinstance(response, str):
|
||||
return response
|
||||
if isinstance(response, OutboundMessage):
|
||||
return response.content or ""
|
||||
if isinstance(response, dict):
|
||||
content = response.get("content")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return str(content or "")
|
||||
content = getattr(response, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return str(response)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_usage(usage: Any) -> Dict[str, int] | None:
|
||||
if not isinstance(usage, dict):
|
||||
return None
|
||||
normalized: Dict[str, int] = {}
|
||||
prompt = int(usage.get("prompt_tokens", 0) or 0)
|
||||
completion = int(usage.get("completion_tokens", 0) or 0)
|
||||
total = int(usage.get("total_tokens", 0) or 0)
|
||||
|
||||
# If total_tokens is missing or zero, calculate it
|
||||
if total == 0:
|
||||
total = prompt + completion
|
||||
|
||||
normalized["prompt_tokens"] = prompt
|
||||
normalized["completion_tokens"] = completion
|
||||
normalized["total_tokens"] = total
|
||||
return normalized if (prompt > 0 or completion > 0) else None
|
||||
|
||||
def get_last_usage(self, session_id: str) -> Dict[str, int] | None:
|
||||
usage = self._last_usage_by_session.get(session_id)
|
||||
return dict(usage) if usage else None
|
||||
|
||||
def _get_web_search_config(self) -> Any:
|
||||
from nanobot.config.schema import WebSearchConfig
|
||||
ws_dict = get_web_search_config()
|
||||
return WebSearchConfig(
|
||||
provider=ws_dict.get("provider", "duckduckgo"),
|
||||
api_key=ws_dict.get("api_key", ""),
|
||||
base_url=ws_dict.get("base_url", ""),
|
||||
max_results=ws_dict.get("max_results", 5)
|
||||
)
|
||||
|
||||
def _need_custom_agent_for_target(self, target_config: Dict[str, Any]) -> bool:
|
||||
if not self.agent:
|
||||
return False
|
||||
|
||||
provider = self.agent.provider
|
||||
target_model = self._normalize_config_value(target_config.get("model"))
|
||||
current_model = self._normalize_config_value(
|
||||
getattr(self.agent, "model", None) or getattr(provider, "default_model", None)
|
||||
)
|
||||
if target_model != current_model:
|
||||
return True
|
||||
|
||||
target_provider = self._normalize_config_value(target_config.get("provider"))
|
||||
current_provider = self._normalize_config_value(getattr(provider, "_provider_name_override", None))
|
||||
if not current_provider:
|
||||
current_provider = self._normalize_config_value(getattr(getattr(provider, "_spec", None), "name", None))
|
||||
if not current_provider and current_model and self.config:
|
||||
current_provider = self._normalize_config_value(self.config.get_provider_name(current_model))
|
||||
if target_provider != current_provider:
|
||||
return True
|
||||
|
||||
target_api_base = self._normalize_config_value(target_config.get("api_base"))
|
||||
current_api_base = self._normalize_config_value(getattr(provider, "api_base", None))
|
||||
if target_api_base != current_api_base:
|
||||
return True
|
||||
|
||||
target_api_key = self._normalize_config_value(target_config.get("api_key"))
|
||||
current_api_key = self._normalize_config_value(getattr(provider, "api_key", None))
|
||||
if target_api_key != current_api_key:
|
||||
return True
|
||||
|
||||
target_headers = target_config.get("extra_headers") or {}
|
||||
current_headers = getattr(provider, "extra_headers", None) or {}
|
||||
return target_headers != current_headers
|
||||
|
||||
def initialize(self):
|
||||
workspace_path = get_workspace_root()
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
self._sync_builtin_skills_to_workspace(workspace_path)
|
||||
|
||||
# Override config workspace path via environment variable (since config is loaded from env)
|
||||
os.environ["NANOBOT_AGENTS__DEFAULTS__WORKSPACE"] = str(workspace_path)
|
||||
|
||||
self.config = load_config()
|
||||
# No need to set self.config.workspace_path as it's a property that reads from agents.defaults.workspace
|
||||
|
||||
self.bus = MessageBus()
|
||||
active_config = get_active_llm_config()
|
||||
initial_model = self.config.agents.defaults.model
|
||||
if active_config and active_config.get("model"):
|
||||
provider = self._make_provider_from_target(active_config)
|
||||
initial_model = self._normalize_config_value(active_config.get("model")) or initial_model
|
||||
else:
|
||||
provider = self._make_provider(self.config)
|
||||
|
||||
cron_store_path = workspace_path / "cron"
|
||||
cron_store_path.mkdir(parents=True, exist_ok=True)
|
||||
cron_store_file = cron_store_path / "jobs.json"
|
||||
|
||||
self.cron = CronService(cron_store_file)
|
||||
|
||||
session_manager = SessionManager(self.config.workspace_path)
|
||||
|
||||
self.agent = AgentLoop(
|
||||
bus=self.bus,
|
||||
provider=provider,
|
||||
workspace=self.config.workspace_path,
|
||||
model=initial_model,
|
||||
max_iterations=self.config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=self.config.agents.defaults.context_window_tokens,
|
||||
web_search_config=self._get_web_search_config(),
|
||||
web_proxy=self.config.tools.web.proxy or None,
|
||||
exec_config=self.config.tools.exec,
|
||||
cron_service=self.cron,
|
||||
restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=self.config.tools.mcp_servers,
|
||||
channels_config=self.config.channels,
|
||||
timezone=self.config.agents.defaults.timezone,
|
||||
)
|
||||
|
||||
self._register_custom_tools(self.agent)
|
||||
|
||||
def _sync_builtin_skills_to_workspace(self, workspace_path: Path) -> None:
|
||||
builtin_root = Path(__file__).resolve().parents[1] / "skills_builtin"
|
||||
workspace_skills_root = workspace_path / "skills"
|
||||
workspace_skills_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for skill_name in ("nl2sql", "visualization", "knowledge-base"):
|
||||
source_dir = builtin_root / skill_name
|
||||
source_skill_file = source_dir / "SKILL.md"
|
||||
if not source_skill_file.exists():
|
||||
continue
|
||||
target_dir = workspace_skills_root / skill_name
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(source_skill_file, target_dir / "SKILL.md")
|
||||
|
||||
def _register_custom_tools(self, agent: AgentLoop, project_id: int | None = None):
|
||||
from app.tools.nl2sql import NL2SQLTool
|
||||
from app.tools.visualization import VisualizationTool
|
||||
from app.tools.get_schema import GetDatabaseSchemaTool
|
||||
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
|
||||
from app.tools.subagent import ListSubagentsTool, InvokeSubagentTool
|
||||
agent.tools.register(NL2SQLTool())
|
||||
agent.tools.register(VisualizationTool())
|
||||
agent.tools.register(GetDatabaseSchemaTool())
|
||||
agent.tools.register(KnowledgeBaseRetrieveTool())
|
||||
agent.tools.register(ListSubagentsTool(project_id=project_id))
|
||||
agent.tools.register(InvokeSubagentTool(project_id=project_id))
|
||||
|
||||
def _build_provider(
|
||||
self,
|
||||
model: str,
|
||||
provider_name: str | None,
|
||||
api_key: str | None,
|
||||
api_base: str | None,
|
||||
extra_headers: dict[str, Any] | None = None,
|
||||
):
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
if _running_in_docker():
|
||||
api_base = _rewrite_localhost_api_base(api_base)
|
||||
|
||||
if backend == "openai_codex" or model.startswith("openai-codex/"):
|
||||
return OpenAICodexProvider(default_model=model)
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not api_key or not api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base.")
|
||||
return AzureOpenAIProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
)
|
||||
|
||||
if backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
return AnthropicProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
return PatchedOpenAICompatProvider(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
default_model=model,
|
||||
extra_headers=extra_headers,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
def _make_provider(self, config: Config):
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
provider = self._build_provider(
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
reasoning_effort=config.agents.defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
def _make_provider_from_target(self, target_config: Dict[str, Any]):
|
||||
model = self._normalize_config_value(target_config.get("model")) or self.config.agents.defaults.model
|
||||
provider_name = self._normalize_config_value(target_config.get("provider"))
|
||||
if not provider_name and model and self.config:
|
||||
provider_name = self._normalize_config_value(self.config.get_provider_name(model))
|
||||
provider = self._build_provider(
|
||||
model=model,
|
||||
provider_name=provider_name,
|
||||
api_key=self._normalize_config_value(target_config.get("api_key")),
|
||||
api_base=self._normalize_config_value(target_config.get("api_base")),
|
||||
extra_headers=target_config.get("extra_headers"),
|
||||
)
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=self.config.agents.defaults.temperature,
|
||||
max_tokens=self.config.agents.defaults.max_tokens,
|
||||
reasoning_effort=self.config.agents.defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
if not self.agent:
|
||||
self.initialize()
|
||||
asyncio.create_task(self.agent.run())
|
||||
asyncio.create_task(self.cron.start())
|
||||
self._started = True
|
||||
|
||||
async def stop(self):
|
||||
if self.agent:
|
||||
self.agent.stop()
|
||||
await self.agent.close_mcp()
|
||||
for agent in self._model_agent_cache.values():
|
||||
agent.stop()
|
||||
await agent.close_mcp()
|
||||
self._model_agent_cache.clear()
|
||||
if self.cron:
|
||||
self.cron.stop()
|
||||
self._started = False
|
||||
|
||||
def _build_agent_for_provider(self, provider: Any, mcp_servers: dict | None = None) -> AgentLoop:
|
||||
return AgentLoop(
|
||||
bus=self.bus,
|
||||
provider=provider,
|
||||
workspace=self.config.workspace_path,
|
||||
model=provider.default_model,
|
||||
max_iterations=self.config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=self.config.agents.defaults.context_window_tokens,
|
||||
web_search_config=self._get_web_search_config(),
|
||||
web_proxy=self.config.tools.web.proxy or None,
|
||||
exec_config=self.config.tools.exec,
|
||||
cron_service=self.cron,
|
||||
restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
session_manager=self.agent.sessions if self.agent else None,
|
||||
mcp_servers=mcp_servers if mcp_servers is not None else self.config.tools.mcp_servers,
|
||||
channels_config=self.config.channels,
|
||||
timezone=self.config.agents.defaults.timezone,
|
||||
)
|
||||
|
||||
async def _get_or_create_model_agent(self, model_id: str | None, target_config: Dict[str, Any] | None, project_id: int | None = None) -> AgentLoop:
|
||||
normalized_model_id = self._normalize_model_id(model_id)
|
||||
cache_key = (normalized_model_id, project_id)
|
||||
async with self._model_agent_lock:
|
||||
cached = self._model_agent_cache.get(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
if target_config:
|
||||
provider = self._make_provider_from_target(target_config)
|
||||
else:
|
||||
provider = self._make_provider(self.config)
|
||||
|
||||
mcp_servers_dict = dict(self.config.tools.mcp_servers) if self.config.tools.mcp_servers else {}
|
||||
if project_id is not None:
|
||||
from app.api.mcp import list_mcp_servers
|
||||
from nanobot.config.schema import MCPServerConfig
|
||||
servers = await list_mcp_servers(project_id=project_id)
|
||||
for s in servers:
|
||||
cfg = MCPServerConfig(
|
||||
type=s.get("type"),
|
||||
command=s.get("command") or "",
|
||||
args=s.get("args") or [],
|
||||
env=s.get("env") or {},
|
||||
url=s.get("url") or "",
|
||||
headers=s.get("headers") or {}
|
||||
)
|
||||
mcp_servers_dict[s["name"]] = cfg
|
||||
|
||||
agent = self._build_agent_for_provider(provider, mcp_servers=mcp_servers_dict)
|
||||
self._register_custom_tools(agent, project_id=project_id)
|
||||
self._model_agent_cache[cache_key] = agent
|
||||
return agent
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
message: str,
|
||||
session_id: str = "api:default",
|
||||
skill_ids: List[str] | None = None,
|
||||
model_id: str | None = None,
|
||||
project_id: int | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
):
|
||||
span_attributes = {
|
||||
"session_id": session_id,
|
||||
"project_id": project_id,
|
||||
"model_id": model_id,
|
||||
"component": "nanobot.process_message",
|
||||
}
|
||||
with trace_service.start_span(
|
||||
"nanobot.process_message",
|
||||
attributes=span_attributes,
|
||||
input_payload={"message": message},
|
||||
) as root_span:
|
||||
try:
|
||||
if not self.agent:
|
||||
self.initialize()
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
if project_id is None:
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
|
||||
alias_meta = session_alias_store.get_alias_meta(session_id)
|
||||
if alias_meta and alias_meta.get("project_id") is not None:
|
||||
project_id = alias_meta.get("project_id")
|
||||
root_span.set_attributes({"project_id": project_id})
|
||||
|
||||
agent_to_use = self.agent
|
||||
need_custom_agent = False
|
||||
target_config = None
|
||||
|
||||
selected_model_id = self._normalize_model_id(model_id)
|
||||
if selected_model_id:
|
||||
llm_configs = get_llm_configs()
|
||||
target_config = next(
|
||||
(item for item in llm_configs if self._normalize_model_id(item.get("id")) == selected_model_id),
|
||||
None,
|
||||
)
|
||||
|
||||
if target_config is None:
|
||||
active_config = get_active_llm_config()
|
||||
if active_config and active_config.get("id"):
|
||||
selected_model_id = self._normalize_model_id(active_config.get("id"))
|
||||
target_config = active_config
|
||||
|
||||
if target_config and self._need_custom_agent_for_target(target_config):
|
||||
need_custom_agent = True
|
||||
if project_id is not None:
|
||||
need_custom_agent = True
|
||||
|
||||
with trace_service.start_span(
|
||||
"nanobot.resolve_agent",
|
||||
attributes={
|
||||
"session_id": session_id,
|
||||
"project_id": project_id,
|
||||
"selected_model_id": selected_model_id,
|
||||
"custom_agent": need_custom_agent,
|
||||
},
|
||||
):
|
||||
if need_custom_agent:
|
||||
agent_to_use = await self._get_or_create_model_agent(selected_model_id, target_config, project_id)
|
||||
|
||||
session = agent_to_use.sessions.get_or_create(session_id)
|
||||
normalized_messages = self._normalize_session_messages(session.messages)
|
||||
if len(normalized_messages) != len(session.messages):
|
||||
session.messages = normalized_messages
|
||||
agent_to_use.sessions.save(session)
|
||||
|
||||
with trace_service.start_span(
|
||||
"nanobot.process_direct",
|
||||
attributes={
|
||||
"session_id": session_id,
|
||||
"model": getattr(agent_to_use, "model", None),
|
||||
},
|
||||
) as direct_span:
|
||||
response = await agent_to_use.process_direct(
|
||||
message,
|
||||
session_key=session_id,
|
||||
channel="api",
|
||||
chat_id=session_id,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
)
|
||||
usage = self._normalize_usage(getattr(agent_to_use, "_last_usage", None))
|
||||
if usage:
|
||||
self._last_usage_by_session[session_id] = usage
|
||||
direct_span.set_attributes(build_usage_attributes(usage))
|
||||
root_span.set_attributes(build_usage_attributes(usage))
|
||||
text = self._extract_response_text(response)
|
||||
direct_span.update(output={"content": text})
|
||||
root_span.update(output={"content": text})
|
||||
return text
|
||||
except Exception as exc:
|
||||
root_span.set_attributes(build_error_attributes(exc, stage="nanobot_process_message"))
|
||||
root_span.record_error(exc, stage="nanobot_process_message")
|
||||
raise
|
||||
|
||||
def _normalize_session_messages(self, messages: List[Any]) -> List[dict[str, Any]]:
|
||||
normalized: List[dict[str, Any]] = []
|
||||
stack: List[Any] = list(messages)
|
||||
while stack:
|
||||
current = stack.pop(0)
|
||||
if isinstance(current, dict):
|
||||
normalized.append(current)
|
||||
continue
|
||||
if isinstance(current, list):
|
||||
stack = list(current) + stack
|
||||
return normalized
|
||||
|
||||
nanobot_service = NanobotIntegration()
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class PatchedOpenAICompatProvider(OpenAICompatProvider):
|
||||
_MAX_COMPLETION_TOKEN_MODELS = ("gpt-5", "o1", "o3", "o4")
|
||||
|
||||
def _build_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
kwargs = super()._build_kwargs(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
model_name = (model or self.default_model or "").lower()
|
||||
spec = self._spec
|
||||
supports_max_completion_tokens = bool(
|
||||
spec and getattr(spec, "supports_max_completion_tokens", False)
|
||||
)
|
||||
should_use_max_completion_tokens = supports_max_completion_tokens or any(
|
||||
token in model_name for token in self._MAX_COMPLETION_TOKEN_MODELS
|
||||
)
|
||||
|
||||
if should_use_max_completion_tokens and "max_tokens" in kwargs:
|
||||
kwargs["max_completion_tokens"] = kwargs.pop("max_tokens")
|
||||
|
||||
return kwargs
|
||||
@@ -0,0 +1,56 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from pydantic import BaseModel
|
||||
|
||||
SECRET_KEY = "your-super-secret-key-for-dataclaw" # In production, use env variable
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 30 * 24 * 60 # 30 days
|
||||
|
||||
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
security = HTTPBearer()
|
||||
|
||||
class CurrentUser(BaseModel):
|
||||
id: int
|
||||
username: str
|
||||
is_admin: bool = False
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def get_password_hash(password):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)) -> CurrentUser:
|
||||
unauthorized = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(credentials.credentials, SECRET_KEY, algorithms=[ALGORITHM])
|
||||
except JWTError:
|
||||
raise unauthorized
|
||||
user_id = payload.get("id")
|
||||
username = payload.get("sub")
|
||||
is_admin = bool(payload.get("is_admin", False))
|
||||
if user_id is None or username is None:
|
||||
raise unauthorized
|
||||
return CurrentUser(id=user_id, username=username, is_admin=is_admin)
|
||||
|
||||
def get_admin_user(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin permission required")
|
||||
return current_user
|
||||
@@ -0,0 +1,215 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
class SessionAliasStore:
|
||||
def __init__(self) -> None:
|
||||
data_dir = get_data_root()
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as exc:
|
||||
raise RuntimeError(f"DATA_ROOT 权限不足: {data_dir}") from exc
|
||||
self.db_path = data_dir / "nanobot_sessions.db"
|
||||
try:
|
||||
self._init_db()
|
||||
except PermissionError as exc:
|
||||
raise RuntimeError(f"DATA_ROOT 权限不足: {data_dir}") from exc
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(str(self.db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS session_cache (
|
||||
session_key TEXT PRIMARY KEY,
|
||||
created_at TEXT,
|
||||
updated_at TEXT,
|
||||
alias TEXT,
|
||||
pinned INTEGER NOT NULL DEFAULT 0,
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
last_seen_at TEXT NOT NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
cols = {
|
||||
str(row["name"])
|
||||
for row in conn.execute("PRAGMA table_info(session_cache)").fetchall()
|
||||
}
|
||||
if "pinned" not in cols:
|
||||
conn.execute("ALTER TABLE session_cache ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0")
|
||||
if "archived" not in cols:
|
||||
conn.execute("ALTER TABLE session_cache ADD COLUMN archived INTEGER NOT NULL DEFAULT 0")
|
||||
if "project_id" not in cols:
|
||||
conn.execute("ALTER TABLE session_cache ADD COLUMN project_id INTEGER")
|
||||
|
||||
def sync_sessions(self, sessions: list[dict[str, Any]]) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
keys: list[str] = []
|
||||
with self._connect() as conn:
|
||||
for item in sessions:
|
||||
key = str(item.get("key") or "").strip()
|
||||
if not key:
|
||||
continue
|
||||
keys.append(key)
|
||||
created_at = str(item.get("created_at") or "")
|
||||
updated_at = str(item.get("updated_at") or "")
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO session_cache (session_key, created_at, updated_at, last_seen_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(session_key) DO UPDATE SET
|
||||
created_at = excluded.created_at,
|
||||
updated_at = excluded.updated_at,
|
||||
last_seen_at = excluded.last_seen_at
|
||||
""",
|
||||
(key, created_at, updated_at, now),
|
||||
)
|
||||
|
||||
if keys:
|
||||
placeholders = ",".join("?" for _ in keys)
|
||||
conn.execute(
|
||||
f"DELETE FROM session_cache WHERE session_key NOT IN ({placeholders})",
|
||||
keys,
|
||||
)
|
||||
else:
|
||||
conn.execute("DELETE FROM session_cache")
|
||||
|
||||
def list_cached_sessions(self, project_id: int | None = None) -> list[dict[str, Any]]:
|
||||
with self._connect() as conn:
|
||||
if project_id is not None:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT session_key, created_at, updated_at, alias, pinned, archived, project_id
|
||||
FROM session_cache
|
||||
WHERE project_id = ? OR project_id IS NULL
|
||||
ORDER BY pinned DESC, archived ASC, updated_at DESC
|
||||
""",
|
||||
(project_id,)
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT session_key, created_at, updated_at, alias, pinned, archived, project_id
|
||||
FROM session_cache
|
||||
ORDER BY pinned DESC, archived ASC, updated_at DESC
|
||||
"""
|
||||
).fetchall()
|
||||
return [self._row_to_session_item(row) for row in rows]
|
||||
|
||||
def sync_and_list(self, sessions: list[dict[str, Any]], project_id: int | None = None) -> list[dict[str, Any]]:
|
||||
self.sync_sessions(sessions)
|
||||
return self.list_cached_sessions(project_id)
|
||||
|
||||
def set_alias(self, session_key: str, alias: str) -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
clean_alias = alias.strip()
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO session_cache (session_key, created_at, updated_at, alias, last_seen_at)
|
||||
VALUES (?, '', '', ?, ?)
|
||||
ON CONFLICT(session_key) DO UPDATE SET
|
||||
alias = excluded.alias,
|
||||
last_seen_at = excluded.last_seen_at
|
||||
""",
|
||||
(session_key, clean_alias, now),
|
||||
)
|
||||
|
||||
def update_alias_meta(
|
||||
self,
|
||||
session_key: str,
|
||||
alias: str | None = None,
|
||||
pinned: bool | None = None,
|
||||
archived: bool | None = None,
|
||||
project_id: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT alias, pinned, archived, project_id FROM session_cache WHERE session_key = ?",
|
||||
(session_key,),
|
||||
).fetchone()
|
||||
current_alias = (str(row["alias"]) if row and row["alias"] else "")
|
||||
current_pinned = bool(row["pinned"]) if row else False
|
||||
current_archived = bool(row["archived"]) if row else False
|
||||
current_project_id = row["project_id"] if row and "project_id" in row.keys() else None
|
||||
next_alias = current_alias if alias is None else alias.strip()
|
||||
next_pinned = current_pinned if pinned is None else bool(pinned)
|
||||
next_archived = current_archived if archived is None else bool(archived)
|
||||
next_project_id = current_project_id if project_id is None else project_id
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO session_cache (session_key, created_at, updated_at, alias, pinned, archived, project_id, last_seen_at)
|
||||
VALUES (?, '', '', ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(session_key) DO UPDATE SET
|
||||
alias = excluded.alias,
|
||||
pinned = excluded.pinned,
|
||||
archived = excluded.archived,
|
||||
project_id = excluded.project_id,
|
||||
last_seen_at = excluded.last_seen_at
|
||||
""",
|
||||
(session_key, next_alias, int(next_pinned), int(next_archived), next_project_id, now),
|
||||
)
|
||||
return {"alias": next_alias or None, "pinned": next_pinned, "archived": next_archived, "project_id": next_project_id}
|
||||
|
||||
def get_alias(self, session_key: str) -> str | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT alias FROM session_cache WHERE session_key = ?",
|
||||
(session_key,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
alias = row["alias"]
|
||||
return str(alias) if alias else None
|
||||
|
||||
def get_alias_meta(self, session_key: str) -> dict[str, Any] | None:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute(
|
||||
"SELECT alias, pinned, archived, project_id FROM session_cache WHERE session_key = ?",
|
||||
(session_key,),
|
||||
).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
alias = (row["alias"] or "").strip()
|
||||
return {
|
||||
"alias": alias or None,
|
||||
"pinned": bool(row["pinned"]) if "pinned" in row.keys() else False,
|
||||
"archived": bool(row["archived"]) if "archived" in row.keys() else False,
|
||||
"project_id": row["project_id"] if "project_id" in row.keys() else None,
|
||||
}
|
||||
|
||||
def delete_session(self, session_key: str) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute("DELETE FROM session_cache WHERE session_key = ?", (session_key,))
|
||||
|
||||
def _row_to_session_item(self, row: sqlite3.Row) -> dict[str, Any]:
|
||||
alias = (row["alias"] or "").strip()
|
||||
fallback = str(row["session_key"]).replace("api:", "")
|
||||
title = alias or fallback
|
||||
pinned = bool(row["pinned"]) if "pinned" in row.keys() else False
|
||||
archived = bool(row["archived"]) if "archived" in row.keys() else False
|
||||
project_id = row["project_id"] if "project_id" in row.keys() else None
|
||||
return {
|
||||
"key": row["session_key"],
|
||||
"created_at": row["created_at"],
|
||||
"updated_at": row["updated_at"],
|
||||
"metadata": {"title": title},
|
||||
"alias": alias or None,
|
||||
"pinned": pinned,
|
||||
"archived": archived,
|
||||
"project_id": project_id,
|
||||
}
|
||||
|
||||
|
||||
session_alias_store = SessionAliasStore()
|
||||
Reference in New Issue
Block a user