data source binded
This commit is contained in:
+24
-7
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
@@ -84,6 +85,13 @@ class ChatRequest(BaseModel):
|
||||
file_url: Optional[str] = None
|
||||
|
||||
|
||||
def _should_use_nl2sql(request: ChatRequest) -> bool:
|
||||
if request.prefer_sql_chart:
|
||||
return True
|
||||
source = (request.source or "").strip().lower()
|
||||
return source == "upload" or source.startswith("ds:")
|
||||
|
||||
|
||||
class SessionAliasUpdateRequest(BaseModel):
|
||||
title: Optional[str] = None
|
||||
pinned: Optional[bool] = None
|
||||
@@ -96,6 +104,7 @@ class BatchDeleteRequest(BaseModel):
|
||||
|
||||
class SessionFileContextUpdateRequest(BaseModel):
|
||||
active_data_file: Optional[Dict[str, Any]] = None
|
||||
selected_data_source: Optional[str] = None
|
||||
|
||||
|
||||
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
|
||||
@@ -112,12 +121,13 @@ def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
|
||||
|
||||
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
||||
chart = nl2sql_result.chart
|
||||
return {
|
||||
payload = {
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
}
|
||||
return jsonable_encoder(payload)
|
||||
|
||||
|
||||
def _persist_session_turn(
|
||||
@@ -136,7 +146,7 @@ def _persist_session_turn(
|
||||
@app.post("/nanobot/chat")
|
||||
async def nanobot_chat(request: ChatRequest):
|
||||
try:
|
||||
if request.prefer_sql_chart:
|
||||
if _should_use_nl2sql(request):
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||
)
|
||||
@@ -161,7 +171,7 @@ async def nanobot_chat(request: ChatRequest):
|
||||
async def nanobot_chat_stream(request: ChatRequest):
|
||||
async def event_generator():
|
||||
try:
|
||||
if request.prefer_sql_chart:
|
||||
if _should_use_nl2sql(request):
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||
)
|
||||
@@ -295,10 +305,17 @@ def update_session_context_file(session_id: str, payload: SessionFileContextUpda
|
||||
if not nanobot_service.agent:
|
||||
raise HTTPException(status_code=400, detail="Nanobot not running")
|
||||
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
||||
if payload.active_data_file is None:
|
||||
session.metadata.pop("active_data_file", None)
|
||||
else:
|
||||
session.metadata["active_data_file"] = payload.active_data_file
|
||||
updated_fields = payload.model_fields_set
|
||||
if "active_data_file" in updated_fields:
|
||||
if payload.active_data_file is None:
|
||||
session.metadata.pop("active_data_file", None)
|
||||
else:
|
||||
session.metadata["active_data_file"] = payload.active_data_file
|
||||
if "selected_data_source" in updated_fields:
|
||||
if payload.selected_data_source:
|
||||
session.metadata["selected_data_source"] = payload.selected_data_source
|
||||
else:
|
||||
session.metadata.pop("selected_data_source", None)
|
||||
session.updated_at = datetime.now()
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
return {"status": "success", "metadata": session.metadata}
|
||||
|
||||
Reference in New Issue
Block a user