diff --git a/backend/main.py b/backend/main.py index 919d0b3..46a2c8a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional from fastapi import FastAPI, HTTPException +from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel @@ -84,6 +85,13 @@ class ChatRequest(BaseModel): file_url: Optional[str] = None +def _should_use_nl2sql(request: ChatRequest) -> bool: + if request.prefer_sql_chart: + return True + source = (request.source or "").strip().lower() + return source == "upload" or source.startswith("ds:") + + class SessionAliasUpdateRequest(BaseModel): title: Optional[str] = None pinned: Optional[bool] = None @@ -96,6 +104,7 @@ class BatchDeleteRequest(BaseModel): class SessionFileContextUpdateRequest(BaseModel): active_data_file: Optional[Dict[str, Any]] = None + selected_data_source: Optional[str] = None def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str: @@ -112,12 +121,13 @@ def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str: def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: chart = nl2sql_result.chart - return { + payload = { "sql": nl2sql_result.sql, "result": nl2sql_result.result, "chart": chart.model_dump() if chart else None, "error": nl2sql_result.error, } + return jsonable_encoder(payload) def _persist_session_turn( @@ -136,7 +146,7 @@ def _persist_session_turn( @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: - if request.prefer_sql_chart: + if _should_use_nl2sql(request): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) @@ -161,7 +171,7 @@ async def nanobot_chat(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): try: - if request.prefer_sql_chart: + if _should_use_nl2sql(request): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) @@ -295,10 +305,17 @@ def update_session_context_file(session_id: str, payload: SessionFileContextUpda if not nanobot_service.agent: raise HTTPException(status_code=400, detail="Nanobot not running") session = nanobot_service.agent.sessions.get_or_create(session_id) - if payload.active_data_file is None: - session.metadata.pop("active_data_file", None) - else: - session.metadata["active_data_file"] = payload.active_data_file + updated_fields = payload.model_fields_set + if "active_data_file" in updated_fields: + if payload.active_data_file is None: + session.metadata.pop("active_data_file", None) + else: + session.metadata["active_data_file"] = payload.active_data_file + if "selected_data_source" in updated_fields: + if payload.selected_data_source: + session.metadata["selected_data_source"] = payload.selected_data_source + else: + session.metadata.pop("selected_data_source", None) session.updated_at = datetime.now() nanobot_service.agent.sessions.save(session) return {"status": "success", "metadata": session.metadata} diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index 8bd0af1..af6c3f0 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -58,6 +58,7 @@ interface SessionData { key: string; metadata?: { active_data_file?: DataFileContext | null; + selected_data_source?: string | null; [key: string]: any; }; messages: Array<{ @@ -104,7 +105,6 @@ export function ChatInterface() { useEffect(() => { if (currentProject) { - setSelectedDataSource(""); fetchDataSources(); } }, [currentProject]); @@ -117,26 +117,37 @@ export function ChatInterface() { setAvailableDataSources(projectSources); if (selectedDataSource && !projectSources.find(ds => ds.id === selectedDataSource)) { setSelectedDataSource(""); + void syncSessionContext({ selected_data_source: null }); } } catch (e) { console.error("Failed to fetch data sources", e); } }; - const syncSessionFileContext = async (file: DataFileContext | null) => { + const syncSessionContext = async (payload: { + active_data_file?: DataFileContext | null; + selected_data_source?: string | null; + }) => { try { - await api.put(`/nanobot/sessions/${encodeURIComponent(activeSessionKey)}/context-file`, { - active_data_file: file, - }); + await api.put(`/nanobot/sessions/${encodeURIComponent(activeSessionKey)}/context-file`, payload); } catch (e) { - console.error("Failed to sync session file context", e); + console.error("Failed to sync session context", e); } }; + const handleSelectDataSource = async (sourceId: string) => { + setSelectedDataSource(sourceId); + await syncSessionContext({ selected_data_source: sourceId }); + }; + + const handleClearDataSource = async () => { + setSelectedDataSource(""); + await syncSessionContext({ selected_data_source: null }); + }; + useEffect(() => { const fetchSessionData = async () => { setIsLoading(true); - setSelectedDataSource(""); setSelectedSkillIds([]); try { const data = await api.get(`/nanobot/sessions/${activeSessionKey}`); @@ -152,12 +163,15 @@ export function ChatInterface() { setMessages([]); } const restoredFile = data.metadata?.active_data_file || null; + const restoredSource = data.metadata?.selected_data_source || ""; setActiveDataFile(restoredFile); + setSelectedDataSource(restoredSource); setAttachedFile(null); } catch (e) { console.error("Failed to fetch session messages", e); setMessages([]); setActiveDataFile(null); + setSelectedDataSource(""); setAttachedFile(null); } finally { setIsLoading(false); @@ -238,7 +252,7 @@ export function ChatInterface() { setAttachedFile(uploadedFile); setActiveDataFile(uploadedFile); setSelectedDataSource(""); - await syncSessionFileContext(uploadedFile); + await syncSessionContext({ active_data_file: uploadedFile, selected_data_source: null }); } catch (error) { console.error("File upload error:", error); // Could show a toast notification here @@ -253,7 +267,7 @@ export function ChatInterface() { const handleRemoveFile = async () => { setAttachedFile(null); setActiveDataFile(null); - await syncSessionFileContext(null); + await syncSessionContext({ active_data_file: null }); }; const selectedDataSourceName = availableDataSources.find(ds => ds.id === selectedDataSource)?.name || ""; @@ -610,7 +624,7 @@ export function ChatInterface() {