data source binded

This commit is contained in:
qixinbo
2026-03-16 23:16:33 +08:00
parent 720c30a893
commit b9d1c182cf
2 changed files with 55 additions and 20 deletions
+24 -7
View File
@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
@@ -84,6 +85,13 @@ class ChatRequest(BaseModel):
file_url: Optional[str] = None
def _should_use_nl2sql(request: ChatRequest) -> bool:
if request.prefer_sql_chart:
return True
source = (request.source or "").strip().lower()
return source == "upload" or source.startswith("ds:")
class SessionAliasUpdateRequest(BaseModel):
title: Optional[str] = None
pinned: Optional[bool] = None
@@ -96,6 +104,7 @@ class BatchDeleteRequest(BaseModel):
class SessionFileContextUpdateRequest(BaseModel):
active_data_file: Optional[Dict[str, Any]] = None
selected_data_source: Optional[str] = None
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
@@ -112,12 +121,13 @@ def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
chart = nl2sql_result.chart
return {
payload = {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
}
return jsonable_encoder(payload)
def _persist_session_turn(
@@ -136,7 +146,7 @@ def _persist_session_turn(
@app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest):
try:
if request.prefer_sql_chart:
if _should_use_nl2sql(request):
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
@@ -161,7 +171,7 @@ async def nanobot_chat(request: ChatRequest):
async def nanobot_chat_stream(request: ChatRequest):
async def event_generator():
try:
if request.prefer_sql_chart:
if _should_use_nl2sql(request):
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
@@ -295,10 +305,17 @@ def update_session_context_file(session_id: str, payload: SessionFileContextUpda
if not nanobot_service.agent:
raise HTTPException(status_code=400, detail="Nanobot not running")
session = nanobot_service.agent.sessions.get_or_create(session_id)
if payload.active_data_file is None:
session.metadata.pop("active_data_file", None)
else:
session.metadata["active_data_file"] = payload.active_data_file
updated_fields = payload.model_fields_set
if "active_data_file" in updated_fields:
if payload.active_data_file is None:
session.metadata.pop("active_data_file", None)
else:
session.metadata["active_data_file"] = payload.active_data_file
if "selected_data_source" in updated_fields:
if payload.selected_data_source:
session.metadata["selected_data_source"] = payload.selected_data_source
else:
session.metadata.pop("selected_data_source", None)
session.updated_at = datetime.now()
nanobot_service.agent.sessions.save(session)
return {"status": "success", "metadata": session.metadata}