diff --git a/backend/app/core/session_alias_store.py b/backend/app/core/session_alias_store.py new file mode 100644 index 0000000..153d1c1 --- /dev/null +++ b/backend/app/core/session_alias_store.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + + +class SessionAliasStore: + def __init__(self) -> None: + backend_root = Path(__file__).resolve().parents[2] + data_dir = backend_root / "data" + data_dir.mkdir(parents=True, exist_ok=True) + self.db_path = data_dir / "nanobot_sessions.db" + self._init_db() + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(str(self.db_path)) + conn.row_factory = sqlite3.Row + return conn + + def _init_db(self) -> None: + with self._connect() as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS session_cache ( + session_key TEXT PRIMARY KEY, + created_at TEXT, + updated_at TEXT, + alias TEXT, + pinned INTEGER NOT NULL DEFAULT 0, + archived INTEGER NOT NULL DEFAULT 0, + last_seen_at TEXT NOT NULL + ) + """ + ) + cols = { + str(row["name"]) + for row in conn.execute("PRAGMA table_info(session_cache)").fetchall() + } + if "pinned" not in cols: + conn.execute("ALTER TABLE session_cache ADD COLUMN pinned INTEGER NOT NULL DEFAULT 0") + if "archived" not in cols: + conn.execute("ALTER TABLE session_cache ADD COLUMN archived INTEGER NOT NULL DEFAULT 0") + + def sync_sessions(self, sessions: list[dict[str, Any]]) -> None: + now = datetime.now(timezone.utc).isoformat() + keys: list[str] = [] + with self._connect() as conn: + for item in sessions: + key = str(item.get("key") or "").strip() + if not key: + continue + keys.append(key) + created_at = str(item.get("created_at") or "") + updated_at = str(item.get("updated_at") or "") + conn.execute( + """ + INSERT INTO session_cache (session_key, created_at, updated_at, last_seen_at) + VALUES (?, ?, ?, ?) + ON CONFLICT(session_key) DO UPDATE SET + created_at = excluded.created_at, + updated_at = excluded.updated_at, + last_seen_at = excluded.last_seen_at + """, + (key, created_at, updated_at, now), + ) + + if keys: + placeholders = ",".join("?" for _ in keys) + conn.execute( + f"DELETE FROM session_cache WHERE session_key NOT IN ({placeholders})", + keys, + ) + else: + conn.execute("DELETE FROM session_cache") + + def list_cached_sessions(self) -> list[dict[str, Any]]: + with self._connect() as conn: + rows = conn.execute( + """ + SELECT session_key, created_at, updated_at, alias, pinned, archived + FROM session_cache + ORDER BY pinned DESC, archived ASC, updated_at DESC + """ + ).fetchall() + return [self._row_to_session_item(row) for row in rows] + + def sync_and_list(self, sessions: list[dict[str, Any]]) -> list[dict[str, Any]]: + self.sync_sessions(sessions) + return self.list_cached_sessions() + + def set_alias(self, session_key: str, alias: str) -> None: + now = datetime.now(timezone.utc).isoformat() + clean_alias = alias.strip() + with self._connect() as conn: + conn.execute( + """ + INSERT INTO session_cache (session_key, created_at, updated_at, alias, last_seen_at) + VALUES (?, '', '', ?, ?) + ON CONFLICT(session_key) DO UPDATE SET + alias = excluded.alias, + last_seen_at = excluded.last_seen_at + """, + (session_key, clean_alias, now), + ) + + def update_alias_meta( + self, + session_key: str, + alias: str | None = None, + pinned: bool | None = None, + archived: bool | None = None, + ) -> dict[str, Any]: + now = datetime.now(timezone.utc).isoformat() + with self._connect() as conn: + row = conn.execute( + "SELECT alias, pinned, archived FROM session_cache WHERE session_key = ?", + (session_key,), + ).fetchone() + current_alias = (str(row["alias"]) if row and row["alias"] else "") + current_pinned = bool(row["pinned"]) if row else False + current_archived = bool(row["archived"]) if row else False + next_alias = current_alias if alias is None else alias.strip() + next_pinned = current_pinned if pinned is None else bool(pinned) + next_archived = current_archived if archived is None else bool(archived) + conn.execute( + """ + INSERT INTO session_cache (session_key, created_at, updated_at, alias, pinned, archived, last_seen_at) + VALUES (?, '', '', ?, ?, ?, ?) + ON CONFLICT(session_key) DO UPDATE SET + alias = excluded.alias, + pinned = excluded.pinned, + archived = excluded.archived, + last_seen_at = excluded.last_seen_at + """, + (session_key, next_alias, int(next_pinned), int(next_archived), now), + ) + return {"alias": next_alias or None, "pinned": next_pinned, "archived": next_archived} + + def get_alias(self, session_key: str) -> str | None: + with self._connect() as conn: + row = conn.execute( + "SELECT alias FROM session_cache WHERE session_key = ?", + (session_key,), + ).fetchone() + if not row: + return None + alias = row["alias"] + return str(alias) if alias else None + + def delete_session(self, session_key: str) -> None: + with self._connect() as conn: + conn.execute("DELETE FROM session_cache WHERE session_key = ?", (session_key,)) + + def _row_to_session_item(self, row: sqlite3.Row) -> dict[str, Any]: + alias = (row["alias"] or "").strip() + fallback = str(row["session_key"]).replace("api:", "") + title = alias or fallback + pinned = bool(row["pinned"]) if "pinned" in row.keys() else False + archived = bool(row["archived"]) if "archived" in row.keys() else False + return { + "key": row["session_key"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "metadata": {"title": title}, + "alias": alias or None, + "pinned": pinned, + "archived": archived, + } + + +session_alias_store = SessionAliasStore() diff --git a/backend/main.py b/backend/main.py index 53fc541..f1d1815 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,5 @@ from typing import List, Optional -from fastapi import FastAPI, HTTPException, Body +from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -11,6 +11,7 @@ from app.connectors.postgres import postgres_connector from app.connectors.clickhouse import clickhouse_connector from app.connectors.minio import minio_connector from app.core.nanobot import nanobot_service +from app.core.session_alias_store import session_alias_store from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse app = FastAPI() @@ -74,10 +75,21 @@ class ChatRequest(BaseModel): skill_ids: Optional[List[str]] = None model_id: Optional[str] = None + +class SessionAliasUpdateRequest(BaseModel): + title: Optional[str] = None + pinned: Optional[bool] = None + archived: Optional[bool] = None + @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: - response = await nanobot_service.process_message(request.message, skill_ids=request.skill_ids, model_id=request.model_id) + response = await nanobot_service.process_message( + request.message, + session_id=request.session_id, + skill_ids=request.skill_ids, + model_id=request.model_id, + ) return {"response": response} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -115,21 +127,22 @@ async def nanobot_chat_stream(request: ChatRequest): @app.get("/nanobot/sessions") def get_sessions(): if not nanobot_service.agent: - return [] - # session_manager has list_sessions() + return session_alias_store.list_cached_sessions() sessions = nanobot_service.agent.sessions.list_sessions() - return sessions + return session_alias_store.sync_and_list(sessions) @app.get("/nanobot/sessions/{session_id}") def get_session(session_id: str): if not nanobot_service.agent: raise HTTPException(status_code=400, detail="Nanobot not running") session = nanobot_service.agent.sessions.get_or_create(session_id) + alias = session_alias_store.get_alias(session_id) return { "key": session.key, "created_at": session.created_at, "updated_at": session.updated_at, "metadata": session.metadata, + "alias": alias, "messages": session.messages } @@ -145,18 +158,19 @@ def delete_session(session_id: str): path = nanobot_service.agent.sessions._get_session_path(session_id) if path.exists(): path.unlink() + session_alias_store.delete_session(session_id) return {"status": "success"} raise HTTPException(status_code=404, detail="Session not found") @app.put("/nanobot/sessions/{session_id}") -def update_session(session_id: str, title: str = Body(..., embed=True)): - if not nanobot_service.agent: - raise HTTPException(status_code=400, detail="Nanobot not running") - - session = nanobot_service.agent.sessions.get_or_create(session_id) - session.metadata["title"] = title - nanobot_service.agent.sessions.save(session) - return {"status": "success", "title": title} +def update_session(session_id: str, payload: SessionAliasUpdateRequest): + updated = session_alias_store.update_alias_meta( + session_key=session_id, + alias=payload.title, + pinned=payload.pinned, + archived=payload.archived, + ) + return {"status": "success", **updated} @app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse) async def run_nl2sql(request: NL2SQLRequest): diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index 373a5d9..e8e1633 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -17,6 +17,7 @@ interface Message { id: string; role: 'user' | 'assistant'; content: string; + awaitingFirstToken?: boolean; } interface ModelConfig { @@ -131,10 +132,12 @@ export function ChatInterface() { setMessages(prev => [...prev, { id: assistantId, role: "assistant", - content: "" + content: "", + awaitingFirstToken: true }]); const token = localStorage.getItem("token"); + const effectiveModelId = selectedModelId || currentModel?.id || ""; const response = await fetch("/nanobot/chat/stream", { method: "POST", headers: { @@ -144,7 +147,7 @@ export function ChatInterface() { body: JSON.stringify({ message: newMessage.content, session_id: activeSessionKey, - model_id: selectedModelId, + model_id: effectiveModelId, }), }); @@ -178,7 +181,7 @@ export function ChatInterface() { streamedText = `${streamedText}${payload.content}`; setMessages((prev) => prev.map((msg) => - msg.id === assistantId ? { ...msg, content: streamedText } : msg + msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg ) ); } @@ -187,7 +190,7 @@ export function ChatInterface() { streamedText = payload.content; setMessages((prev) => prev.map((msg) => - msg.id === assistantId ? { ...msg, content: payload.content || "" } : msg + msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false } : msg ) ); } @@ -197,6 +200,19 @@ export function ChatInterface() { } } } + + if (!streamedText) { + const fallback = await api.post<{ response: string }>("/nanobot/chat", { + message: newMessage.content, + session_id: activeSessionKey, + model_id: effectiveModelId, + }); + setMessages((prev) => + prev.map((msg) => + msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg + ) + ); + } } else { // Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数") const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres @@ -235,6 +251,7 @@ export function ChatInterface() { } finally { setIsLoading(false); setVizLoading(false); + window.dispatchEvent(new Event("nanobot:sessions-changed")); } }; @@ -372,11 +389,18 @@ export function ChatInterface() { }`} > {msg.role === "assistant" ? ( -