From 696fd94ff3e0942c31571144975d6ae5eda0cb59 Mon Sep 17 00:00:00 2001 From: qixinbo Date: Sun, 15 Mar 2026 10:49:37 +0800 Subject: [PATCH] feature: nl2sql first successful --- backend/app/agent/nl2sql.py | 71 +++++++++-- backend/main.py | 47 ++++++++ frontend/src/App.tsx | 10 +- frontend/src/components/ChatInterface.tsx | 112 ++++++++++++++++-- frontend/src/components/VegaChart.tsx | 17 +-- .../src/components/VisualizationPanel.tsx | 10 +- frontend/src/store/visualizationStore.ts | 32 +++-- 7 files changed, 252 insertions(+), 47 deletions(-) diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index cb0e6a3..58747b6 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -4,6 +4,8 @@ import json from pathlib import Path from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field +import duckdb +import pandas as pd # Add project root to sys.path to allow importing nanobot PROJECT_ROOT = Path(__file__).resolve().parents[3] @@ -19,7 +21,8 @@ from app.agent.chart import generate_chart class NL2SQLRequest(BaseModel): query: str = Field(..., description="User's natural language query") - source: str = Field(..., description="Data source to query (postgres, clickhouse)") + source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)") + file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload") class NL2SQLResponse(BaseModel): sql: str @@ -80,20 +83,63 @@ The final answer must be a ANSI SQL query in JSON format: }} """ +def _resolve_upload_file_path(file_url: Optional[str]) -> Path: + if not file_url or not file_url.startswith("local://"): + raise ValueError("Invalid uploaded file URL") + raw_name = file_url.replace("local://", "", 1) + safe_name = os.path.basename(raw_name) + upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads" + file_path = upload_dir / safe_name + if not file_path.exists(): + raise ValueError(f"Uploaded file not found: {safe_name}") + return file_path + +def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame: + file_path = _resolve_upload_file_path(file_url) + suffix = file_path.suffix.lower() + if suffix == ".csv": + return pd.read_csv(file_path) + if suffix in [".xls", ".xlsx"]: + return pd.read_excel(file_path) + raise ValueError(f"Unsupported uploaded file type: {suffix}") + +def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]: + df = _load_upload_dataframe(file_url) + conn = duckdb.connect(":memory:") + conn.register("uploaded_file", df) + columns = conn.execute("DESCRIBE uploaded_file").fetchall() + schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]} + conn.close() + return schema + +def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]: + df = _load_upload_dataframe(file_url) + conn = duckdb.connect(":memory:") + conn.register("uploaded_file", df) + result_df = conn.execute(sql_query).df() + conn.close() + return result_df.to_dict(orient="records") + async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: # 1. Get the connector and schema connector = None + schema = {} if request.source == "postgres": connector = postgres_connector elif request.source == "clickhouse": connector = clickhouse_connector + elif request.source == "upload": + try: + schema = _get_upload_schema(request.file_url) + except Exception as e: + return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}") else: return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}") - if not connector.test_connection(): - return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") - - schema = connector.get_schema() + if connector: + if not connector.test_connection(): + return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") + schema = connector.get_schema() schema_str = json.dumps(schema, indent=2) # 2. Get the active LLM config @@ -158,19 +204,22 @@ Let's think step by step. # 6. Execute SQL try: - results = connector.execute_query(sql_query) + if request.source == "upload": + formatted_results = _execute_upload_sql(sql_query, request.file_url) + else: + results = connector.execute_query(sql_query) # Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples) - formatted_results = [] - if request.source == "postgres": - formatted_results = results - elif request.source == "clickhouse": + formatted_results = [] + if request.source == "postgres": + formatted_results = results + elif request.source == "clickhouse": # ClickHouse returns list of tuples, we need column names # But execute_query in ClickHouseConnector just returns raw results from client.execute # client.execute(query, with_column_types=True) might be better but let's stick to simple for now # Actually, without column names it's hard to format as dict. # Let's assume we can just return the raw tuples for now or try to fetch column names. # For now, let's just return as list of lists/tuples if it's not a dict - formatted_results = [list(row) for row in results] + formatted_results = [list(row) for row in results] # 7. Generate Chart chart_response = None diff --git a/backend/main.py b/backend/main.py index 843005a..5165c9a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -67,6 +67,9 @@ class ChatRequest(BaseModel): session_id: str = "api:default" skill_ids: Optional[List[str]] = None model_id: Optional[str] = None + source: str = "postgres" + prefer_sql_chart: bool = False + file_url: Optional[str] = None class SessionAliasUpdateRequest(BaseModel): @@ -77,6 +80,27 @@ class SessionAliasUpdateRequest(BaseModel): @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: + if request.prefer_sql_chart: + nl2sql_result = await process_nl2sql( + NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) + ) + chart = nl2sql_result.chart + can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) + text = ( + f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" + f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" + ) + if chart and chart.reasoning: + text = f"{text}\n\n可视化说明:{chart.reasoning}" + return { + "response": text, + "viz": { + "sql": nl2sql_result.sql, + "result": nl2sql_result.result, + "chart": chart.model_dump() if chart else None, + "error": nl2sql_result.error, + }, + } response = await nanobot_service.process_message( request.message, session_id=request.session_id, @@ -91,6 +115,29 @@ async def nanobot_chat(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): try: + if request.prefer_sql_chart: + nl2sql_result = await process_nl2sql( + NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) + ) + chart = nl2sql_result.chart + viz_payload = { + "type": "viz", + "sql": nl2sql_result.sql, + "result": nl2sql_result.result, + "chart": chart.model_dump() if chart else None, + "error": nl2sql_result.error, + } + yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n" + can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) + text = ( + f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" + f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" + ) + if chart and chart.reasoning: + text = f"{text}\n\n可视化说明:{chart.reasoning}" + yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + return response = await nanobot_service.process_message( request.message, session_id=request.session_id, diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 503a195..7a80db4 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -1,6 +1,7 @@ import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom"; import { Sidebar } from "./components/Sidebar"; import { ChatInterface } from "./components/ChatInterface"; +import { VisualizationPanel } from "./components/VisualizationPanel"; import { Dashboard } from "./pages/Dashboard"; import { Skills } from "./pages/Skills"; import { Settings } from "./pages/Settings"; @@ -45,8 +46,13 @@ function App() { -
- +
+
+ +
+
+ +
diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index d386fae..e824e4a 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -41,7 +41,7 @@ export function ChatInterface() { const [messages, setMessages] = useState([]); const [input, setInput] = useState(""); const [selectedCapability, setSelectedCapability] = useState("智能问答"); - const selectedDataSource = "postgres-main"; + const [selectedDataSource, setSelectedDataSource] = useState("postgres-main"); const [isLoading, setIsLoading] = useState(false); const scrollRef = useRef(null); const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore(); @@ -114,6 +114,7 @@ export function ChatInterface() { { icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" }, { icon: Search, label: "深度问数", color: "text-blue-500", bg: "bg-blue-50" }, ]; + const chartIntentPattern = /(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|chart|plot|visuali[sz]e)/i; const handleFileUpload = async (e: React.ChangeEvent) => { const file = e.target.files?.[0]; @@ -168,8 +169,9 @@ export function ChatInterface() { setInput(""); let messagePayload = newMessage.content; - if (attachedFile) { - messagePayload = `[用户上传了文件: ${attachedFile.filename}]\n[文件内容摘要: ${attachedFile.summary || "无"}]\n[数据列: ${attachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${attachedFile.url}]\n\n${newMessage.content}`; + const currentAttachedFile = attachedFile; + if (currentAttachedFile) { + messagePayload = `[用户上传了文件: ${currentAttachedFile.filename}]\n[文件内容摘要: ${currentAttachedFile.summary || "无"}]\n[数据列: ${currentAttachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${currentAttachedFile.url}]\n\n${newMessage.content}`; setAttachedFile(null); } @@ -189,6 +191,9 @@ export function ChatInterface() { const token = localStorage.getItem("token"); const effectiveModelId = selectedModelId || currentModel?.id || ""; + const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0]; + const fileUrl = currentAttachedFile?.url || undefined; + const preferSqlChart = chartIntentPattern.test(messagePayload); const response = await fetch("/nanobot/chat/stream", { method: "POST", headers: { @@ -199,6 +204,9 @@ export function ChatInterface() { message: messagePayload, session_id: activeSessionKey, model_id: effectiveModelId, + source, + prefer_sql_chart: preferSqlChart, + file_url: fileUrl, }), }); @@ -226,7 +234,14 @@ export function ChatInterface() { if (!line) continue; const payloadText = line.slice(5).trim(); if (!payloadText) continue; - const payload = JSON.parse(payloadText) as { type: string; content?: string }; + const payload = JSON.parse(payloadText) as { + type: string; + content?: string; + sql?: string; + result?: unknown; + error?: string; + chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null; + }; if (payload.type === "delta" && payload.content) { streamedText = `${streamedText}${payload.content}`; @@ -249,15 +264,69 @@ export function ChatInterface() { if (payload.type === "error") { throw new Error(payload.content || "流式响应错误"); } + + if (payload.type === "viz") { + if (payload.error) { + setVizError(payload.error); + } else { + const rows = Array.isArray(payload.result) ? payload.result : []; + const sql = typeof payload.sql === "string" ? payload.sql : ""; + const chart = payload.chart ?? undefined; + const canVisualize = Boolean(chart?.can_visualize); + const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null; + setVisualization( + rows, + sql, + chartSpec, + { + canVisualize, + reasoning: chart?.reasoning, + chartType: chart?.chart_type, + description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化", + } + ); + } + } } } if (!streamedText) { - const fallback = await api.post<{ response: string }>("/nanobot/chat", { + const fallback = await api.post<{ + response: string; + viz?: { + sql?: string; + result?: unknown; + error?: string | null; + chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null; + }; + }>("/nanobot/chat", { message: messagePayload, session_id: activeSessionKey, model_id: effectiveModelId, + source, + prefer_sql_chart: preferSqlChart, + file_url: fileUrl, }); + if (fallback.viz?.error) { + setVizError(fallback.viz.error); + } else if (fallback.viz) { + const rows = Array.isArray(fallback.viz.result) ? fallback.viz.result : []; + const sql = typeof fallback.viz.sql === "string" ? fallback.viz.sql : ""; + const chart = fallback.viz.chart ?? undefined; + const canVisualize = Boolean(chart?.can_visualize); + const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null; + setVisualization( + rows, + sql, + chartSpec, + { + canVisualize, + reasoning: chart?.reasoning, + chartType: chart?.chart_type, + description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化", + } + ); + } setMessages((prev) => prev.map((msg) => msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg @@ -266,15 +335,16 @@ export function ChatInterface() { } } else { // Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数") - const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres + const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0]; const response = await api.post<{ sql?: string, result?: unknown, error?: string, - chart?: { chart_spec: ChartSpec, reasoning: string, can_visualize: boolean } + chart?: { chart_spec?: ChartSpec | null, reasoning?: string, can_visualize?: boolean, chart_type?: string } }>('/api/v1/agent/nl2sql', { query: messagePayload, source: source, + file_url: currentAttachedFile?.url, session_id: activeSessionKey, model_id: selectedModelId }); @@ -289,12 +359,25 @@ export function ChatInterface() { } else { const rows = Array.isArray(response.result) ? response.result : []; const sql = typeof response.sql === "string" ? response.sql : ""; + const chart = response.chart; + const canVisualize = Boolean(chart?.can_visualize); + const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null; setMessages(prev => [...prev, { id: (Date.now() + 1).toString(), role: 'assistant', - content: `I've generated a SQL query and fetched ${rows.length} rows for you. Check the visualization panel.${response.chart?.reasoning ? `\n\nVisualization reasoning: ${response.chart.reasoning}` : ''}` + content: `已为你生成 SQL 并查询到 ${rows.length} 行数据。${canVisualize ? '可视化面板已同步更新图表。' : '本次结果不适合图表展示。'}${chart?.reasoning ? `\n\n可视化说明:${chart.reasoning}` : ''}` }]); - setVisualization(rows, sql, response.chart?.chart_spec); + setVisualization( + rows, + sql, + chartSpec, + { + canVisualize, + reasoning: chart?.reasoning, + chartType: chart?.chart_type, + description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化", + } + ); } } } catch (error: any) { @@ -353,6 +436,17 @@ export function ChatInterface() { +
+ 数据源 + +
diff --git a/frontend/src/components/VegaChart.tsx b/frontend/src/components/VegaChart.tsx index 729355f..4826e1b 100644 --- a/frontend/src/components/VegaChart.tsx +++ b/frontend/src/components/VegaChart.tsx @@ -9,23 +9,14 @@ interface VegaChartProps { export const VegaChart: React.FC = ({ data, spec }) => { const vegaSpec: any = { - $schema: 'https://vega.github.io/schema/vega-lite/v5.json', - description: spec.description, - title: spec.title, + $schema: typeof spec.$schema === 'string' ? spec.$schema : 'https://vega.github.io/schema/vega-lite/v5.json', + ...spec, width: "container", height: "container", - mark: { type: spec.chart_type, tooltip: true }, - encoding: { - x: { field: spec.x_axis, type: 'nominal', axis: { labelAngle: -45 } }, - y: { field: spec.y_axis, type: 'quantitative' }, - }, - data: { values: data } + data: { values: data }, + autosize: { type: "fit", contains: "padding" }, }; - if (spec.color) { - vegaSpec.encoding.color = { field: spec.color, type: 'nominal' }; - } - return (
('chart'); const { addChart } = useDashboardStore(); - const { currentData, currentSQL, currentChartSpec, isLoading, error } = useVisualizationStore(); + const { currentData, currentSQL, currentChartSpec, currentChartInfo, isLoading, error } = useVisualizationStore(); const handleAddToDashboard = () => { if (!currentData || !currentSQL) return; - + const mark = currentChartSpec?.mark; + const markType = typeof mark === "string" ? mark : mark?.type; + const dashboardType = markType === "line" ? "line" : "bar"; addChart({ id: Date.now().toString(), title: currentChartSpec?.title || 'Generated Analysis', - type: currentChartSpec?.chart_type as any || 'bar', + type: dashboardType, data: currentData, sql: currentSQL, }); @@ -134,7 +136,7 @@ export function VisualizationPanel() { {currentChartSpec?.title || 'Analysis Result'} - {currentChartSpec?.description || 'Generated from your query'} + {currentChartInfo?.reasoning || currentChartSpec?.description || 'Generated from your query'} {view === 'chart' ? ( diff --git a/frontend/src/store/visualizationStore.ts b/frontend/src/store/visualizationStore.ts index 6da500a..69e2f49 100644 --- a/frontend/src/store/visualizationStore.ts +++ b/frontend/src/store/visualizationStore.ts @@ -1,11 +1,19 @@ import { create } from 'zustand'; export interface ChartSpec { - chart_type: string; - title: string; - x_axis: string; - y_axis: string; - color?: string; + $schema?: string; + title?: string; + description?: string; + mark?: string | { type?: string; [key: string]: unknown }; + encoding?: Record; + transform?: Array>; + [key: string]: unknown; +} + +export interface ChartInfo { + canVisualize: boolean; + reasoning?: string; + chartType?: string; description?: string; } @@ -13,9 +21,10 @@ export interface VisualizationState { currentData: any[] | null; currentSQL: string | null; currentChartSpec: ChartSpec | null; + currentChartInfo: ChartInfo | null; isLoading: boolean; error: string | null; - setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null) => void; + setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null, chartInfo?: ChartInfo | null) => void; setLoading: (loading: boolean) => void; setError: (error: string | null) => void; clearVisualization: () => void; @@ -25,10 +34,17 @@ export const useVisualizationStore = create((set) => ({ currentData: null, currentSQL: null, currentChartSpec: null, + currentChartInfo: null, isLoading: false, error: null, - setVisualization: (data, sql, chartSpec = null) => set({ currentData: data, currentSQL: sql, currentChartSpec: chartSpec, error: null }), + setVisualization: (data, sql, chartSpec = null, chartInfo = null) => set({ + currentData: data, + currentSQL: sql, + currentChartSpec: chartSpec, + currentChartInfo: chartInfo, + error: null, + }), setLoading: (loading) => set({ isLoading: loading }), setError: (error) => set({ error, isLoading: false }), - clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, error: null }), + clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, currentChartInfo: null, error: null }), }));