diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 58747b6..78269ae 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -23,6 +23,7 @@ class NL2SQLRequest(BaseModel): query: str = Field(..., description="User's natural language query") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)") file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload") + session_id: Optional[str] = Field(None, description="Conversation session identifier") class NL2SQLResponse(BaseModel): sql: str diff --git a/backend/main.py b/backend/main.py index a5fad05..82b53c7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -77,6 +77,42 @@ class SessionAliasUpdateRequest(BaseModel): pinned: Optional[bool] = None archived: Optional[bool] = None + +def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str: + chart = nl2sql_result.chart + can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) + text = ( + f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" + f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" + ) + if chart and chart.reasoning: + return f"{text}\n\n可视化说明:{chart.reasoning}" + return text + + +def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: + chart = nl2sql_result.chart + return { + "sql": nl2sql_result.sql, + "result": nl2sql_result.result, + "chart": chart.model_dump() if chart else None, + "error": nl2sql_result.error, + } + + +def _persist_session_turn( + session_id: str, + user_message: str, + assistant_message: str, + assistant_extra: Optional[dict] = None, +) -> None: + if not nanobot_service.agent: + return + session = nanobot_service.agent.sessions.get_or_create(session_id) + session.add_message("user", user_message) + session.add_message("assistant", assistant_message, **(assistant_extra or {})) + nanobot_service.agent.sessions.save(session) + @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: @@ -84,22 +120,12 @@ async def nanobot_chat(request: ChatRequest): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) - chart = nl2sql_result.chart - can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) - text = ( - f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" - f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" - ) - if chart and chart.reasoning: - text = f"{text}\n\n可视化说明:{chart.reasoning}" + text = _build_sql_chart_text(nl2sql_result) + viz_payload = _build_sql_chart_viz(nl2sql_result) + _persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload}) return { "response": text, - "viz": { - "sql": nl2sql_result.sql, - "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, - "error": nl2sql_result.error, - }, + "viz": viz_payload, } response = await nanobot_service.process_message( request.message, @@ -119,22 +145,14 @@ async def nanobot_chat_stream(request: ChatRequest): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) - chart = nl2sql_result.chart + persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = { "type": "viz", - "sql": nl2sql_result.sql, - "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, - "error": nl2sql_result.error, + **persisted_viz_payload, } yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n" - can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) - text = ( - f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" - f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" - ) - if chart and chart.reasoning: - text = f"{text}\n\n可视化说明:{chart.reasoning}" + text = _build_sql_chart_text(nl2sql_result) + _persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload}) yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" return @@ -229,4 +247,9 @@ def update_session(session_id: str, payload: SessionAliasUpdateRequest): @app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse) async def run_nl2sql(request: NL2SQLRequest): - return await process_nl2sql(request) + result = await process_nl2sql(request) + if request.session_id: + text = _build_sql_chart_text(result) + viz_payload = _build_sql_chart_viz(result) + _persist_session_turn(request.session_id, request.query, text, {"viz": viz_payload}) + return result diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index fa4a093..bffe873 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -85,7 +85,8 @@ export function ChatInterface() { const formattedMessages = data.messages.map((m, idx) => ({ id: `${Date.now()}-${idx}`, role: m.role as 'user' | 'assistant', - content: m.content + content: m.content, + viz: m.viz ? buildMessageViz(m.viz) : undefined, })); setMessages(formattedMessages); } else { @@ -461,7 +462,7 @@ export function ChatInterface() { onChange={handleFileUpload} />