From 4f46f3f8d5d9ff9b38ee8f458df8a93f44f296ec Mon Sep 17 00:00:00 2001 From: qixinbo Date: Sun, 15 Mar 2026 17:57:09 +0800 Subject: [PATCH] fix: session bug --- backend/app/agent/nl2sql.py | 1 + backend/main.py | 77 ++++++++----- frontend/src/components/ChatInterface.tsx | 5 +- .../components/InlineVisualizationCard.tsx | 1 + .../src/components/VisualizationPanel.tsx | 1 + frontend/src/pages/Dashboard.tsx | 104 +++++++++++++----- frontend/src/store/dashboardStore.ts | 2 + 7 files changed, 136 insertions(+), 55 deletions(-) diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 58747b6..78269ae 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -23,6 +23,7 @@ class NL2SQLRequest(BaseModel): query: str = Field(..., description="User's natural language query") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)") file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload") + session_id: Optional[str] = Field(None, description="Conversation session identifier") class NL2SQLResponse(BaseModel): sql: str diff --git a/backend/main.py b/backend/main.py index a5fad05..82b53c7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -77,6 +77,42 @@ class SessionAliasUpdateRequest(BaseModel): pinned: Optional[bool] = None archived: Optional[bool] = None + +def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str: + chart = nl2sql_result.chart + can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) + text = ( + f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" + f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" + ) + if chart and chart.reasoning: + return f"{text}\n\n可视化说明:{chart.reasoning}" + return text + + +def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: + chart = nl2sql_result.chart + return { + "sql": nl2sql_result.sql, + "result": nl2sql_result.result, + "chart": chart.model_dump() if chart else None, + "error": nl2sql_result.error, + } + + +def _persist_session_turn( + session_id: str, + user_message: str, + assistant_message: str, + assistant_extra: Optional[dict] = None, +) -> None: + if not nanobot_service.agent: + return + session = nanobot_service.agent.sessions.get_or_create(session_id) + session.add_message("user", user_message) + session.add_message("assistant", assistant_message, **(assistant_extra or {})) + nanobot_service.agent.sessions.save(session) + @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: @@ -84,22 +120,12 @@ async def nanobot_chat(request: ChatRequest): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) - chart = nl2sql_result.chart - can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) - text = ( - f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" - f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" - ) - if chart and chart.reasoning: - text = f"{text}\n\n可视化说明:{chart.reasoning}" + text = _build_sql_chart_text(nl2sql_result) + viz_payload = _build_sql_chart_viz(nl2sql_result) + _persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload}) return { "response": text, - "viz": { - "sql": nl2sql_result.sql, - "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, - "error": nl2sql_result.error, - }, + "viz": viz_payload, } response = await nanobot_service.process_message( request.message, @@ -119,22 +145,14 @@ async def nanobot_chat_stream(request: ChatRequest): nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) - chart = nl2sql_result.chart + persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = { "type": "viz", - "sql": nl2sql_result.sql, - "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, - "error": nl2sql_result.error, + **persisted_viz_payload, } yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n" - can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) - text = ( - f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" - f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" - ) - if chart and chart.reasoning: - text = f"{text}\n\n可视化说明:{chart.reasoning}" + text = _build_sql_chart_text(nl2sql_result) + _persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload}) yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" return @@ -229,4 +247,9 @@ def update_session(session_id: str, payload: SessionAliasUpdateRequest): @app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse) async def run_nl2sql(request: NL2SQLRequest): - return await process_nl2sql(request) + result = await process_nl2sql(request) + if request.session_id: + text = _build_sql_chart_text(result) + viz_payload = _build_sql_chart_viz(result) + _persist_session_turn(request.session_id, request.query, text, {"viz": viz_payload}) + return result diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index fa4a093..bffe873 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -85,7 +85,8 @@ export function ChatInterface() { const formattedMessages = data.messages.map((m, idx) => ({ id: `${Date.now()}-${idx}`, role: m.role as 'user' | 'assistant', - content: m.content + content: m.content, + viz: m.viz ? buildMessageViz(m.viz) : undefined, })); setMessages(formattedMessages); } else { @@ -461,7 +462,7 @@ export function ChatInterface() { onChange={handleFileUpload} />
- {messages.length <= 1 ? ( + {messages.length === 0 ? (
{/* Logo Area */}
diff --git a/frontend/src/components/InlineVisualizationCard.tsx b/frontend/src/components/InlineVisualizationCard.tsx index ee6f829..5031178 100644 --- a/frontend/src/components/InlineVisualizationCard.tsx +++ b/frontend/src/components/InlineVisualizationCard.tsx @@ -36,6 +36,7 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) { type: dashboardType, data: objectRows, sql: viz.sql, + chartSpec: viz.chartSpec, }); }; diff --git a/frontend/src/components/VisualizationPanel.tsx b/frontend/src/components/VisualizationPanel.tsx index 661254d..0344200 100644 --- a/frontend/src/components/VisualizationPanel.tsx +++ b/frontend/src/components/VisualizationPanel.tsx @@ -25,6 +25,7 @@ export function VisualizationPanel() { type: dashboardType, data: currentData, sql: currentSQL, + chartSpec: currentChartSpec, }); alert("Added to Dashboard!"); }; diff --git a/frontend/src/pages/Dashboard.tsx b/frontend/src/pages/Dashboard.tsx index 1ce1d8e..2dd5a8c 100644 --- a/frontend/src/pages/Dashboard.tsx +++ b/frontend/src/pages/Dashboard.tsx @@ -5,9 +5,39 @@ import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/com import { Button } from "@/components/ui/button"; import { X } from "lucide-react"; import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer, LineChart, Line } from 'recharts'; +import { VegaChart } from "@/components/VegaChart"; import 'react-grid-layout/css/styles.css'; import 'react-resizable/css/styles.css'; +const CHART_COLORS = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#06b6d4']; + +function isNumericValue(value: unknown) { + if (typeof value === 'number') return Number.isFinite(value); + if (typeof value === 'string') { + const trimmed = value.trim(); + if (!trimmed) return false; + const parsed = Number(trimmed); + return Number.isFinite(parsed); + } + return false; +} + +function inferChartKeys(data: Record[]) { + if (data.length === 0) { + return { xKey: null as string | null, yKeys: [] as string[] }; + } + const allKeys = Object.keys(data[0] || {}); + if (allKeys.length === 0) { + return { xKey: null as string | null, yKeys: [] as string[] }; + } + const preferredX = ['name', 'date', 'time', 'category', 'label']; + const xKey = preferredX.find((k) => allKeys.includes(k)) || allKeys[0]; + const candidateY = allKeys.filter((k) => k !== xKey); + const numericY = candidateY.filter((key) => data.some((row) => isNumericValue(row[key]))); + const yKeys = (numericY.length > 0 ? numericY : candidateY).slice(0, 3); + return { xKey, yKeys }; +} + export function Dashboard() { const { charts, removeChart } = useDashboardStore(); const ResponsiveGridLayout = useMemo( @@ -65,32 +95,54 @@ export function Dashboard() { - - {chart.type === 'bar' ? ( - - - - - - - - - ) : ( - - - - - - - - - )} - + {(() => { + const rows = chart.data as Record[]; + if (chart.chartSpec && rows.length > 0) { + return ( +
+ +
+ ); + } + const { xKey, yKeys } = inferChartKeys(rows); + if (!xKey || yKeys.length === 0) { + return ( +
+ 当前图表数据缺少可绘制字段 +
+ ); + } + return ( + + {chart.type === 'bar' ? ( + + + + + + {yKeys.map((key, index) => ( + + ))} + + ) : ( + + + + + + {yKeys.map((key, index) => ( + + ))} + + )} + + ); + })()}
diff --git a/frontend/src/store/dashboardStore.ts b/frontend/src/store/dashboardStore.ts index ba8463a..67cacf4 100644 --- a/frontend/src/store/dashboardStore.ts +++ b/frontend/src/store/dashboardStore.ts @@ -1,4 +1,5 @@ import { create } from 'zustand'; +import type { ChartSpec } from './visualizationStore'; type ChartRow = Record; type GridLayout = { i: string; x: number; y: number; w: number; h: number }; @@ -9,6 +10,7 @@ export interface ChartConfig { type: 'bar' | 'line'; data: ChartRow[]; sql: string; + chartSpec?: ChartSpec | null; layout: GridLayout; }