feature: nl2sql first successful

This commit is contained in:
qixinbo
2026-03-15 10:49:37 +08:00
parent 76724b2313
commit 696fd94ff3
7 changed files with 252 additions and 47 deletions
+51 -2
View File
@@ -4,6 +4,8 @@ import json
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import duckdb
import pandas as pd
# Add project root to sys.path to allow importing nanobot # Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3] PROJECT_ROOT = Path(__file__).resolve().parents[3]
@@ -19,7 +21,8 @@ from app.agent.chart import generate_chart
class NL2SQLRequest(BaseModel): class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query") query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse)") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
class NL2SQLResponse(BaseModel): class NL2SQLResponse(BaseModel):
sql: str sql: str
@@ -80,19 +83,62 @@ The final answer must be a ANSI SQL query in JSON format:
}} }}
""" """
def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
if not file_url or not file_url.startswith("local://"):
raise ValueError("Invalid uploaded file URL")
raw_name = file_url.replace("local://", "", 1)
safe_name = os.path.basename(raw_name)
upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads"
file_path = upload_dir / safe_name
if not file_path.exists():
raise ValueError(f"Uploaded file not found: {safe_name}")
return file_path
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
file_path = _resolve_upload_file_path(file_url)
suffix = file_path.suffix.lower()
if suffix == ".csv":
return pd.read_csv(file_path)
if suffix in [".xls", ".xlsx"]:
return pd.read_excel(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
df = _load_upload_dataframe(file_url)
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}
conn.close()
return schema
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
df = _load_upload_dataframe(file_url)
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
result_df = conn.execute(sql_query).df()
conn.close()
return result_df.to_dict(orient="records")
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
# 1. Get the connector and schema # 1. Get the connector and schema
connector = None connector = None
schema = {}
if request.source == "postgres": if request.source == "postgres":
connector = postgres_connector connector = postgres_connector
elif request.source == "clickhouse": elif request.source == "clickhouse":
connector = clickhouse_connector connector = clickhouse_connector
elif request.source == "upload":
try:
schema = _get_upload_schema(request.file_url)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
else: else:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector:
if not connector.test_connection(): if not connector.test_connection():
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema() schema = connector.get_schema()
schema_str = json.dumps(schema, indent=2) schema_str = json.dumps(schema, indent=2)
@@ -158,6 +204,9 @@ Let's think step by step.
# 6. Execute SQL # 6. Execute SQL
try: try:
if request.source == "upload":
formatted_results = _execute_upload_sql(sql_query, request.file_url)
else:
results = connector.execute_query(sql_query) results = connector.execute_query(sql_query)
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples) # Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
formatted_results = [] formatted_results = []
+47
View File
@@ -67,6 +67,9 @@ class ChatRequest(BaseModel):
session_id: str = "api:default" session_id: str = "api:default"
skill_ids: Optional[List[str]] = None skill_ids: Optional[List[str]] = None
model_id: Optional[str] = None model_id: Optional[str] = None
source: str = "postgres"
prefer_sql_chart: bool = False
file_url: Optional[str] = None
class SessionAliasUpdateRequest(BaseModel): class SessionAliasUpdateRequest(BaseModel):
@@ -77,6 +80,27 @@ class SessionAliasUpdateRequest(BaseModel):
@app.post("/nanobot/chat") @app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest): async def nanobot_chat(request: ChatRequest):
try: try:
if request.prefer_sql_chart:
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
chart = nl2sql_result.chart
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
text = f"{text}\n\n可视化说明:{chart.reasoning}"
return {
"response": text,
"viz": {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
},
}
response = await nanobot_service.process_message( response = await nanobot_service.process_message(
request.message, request.message,
session_id=request.session_id, session_id=request.session_id,
@@ -91,6 +115,29 @@ async def nanobot_chat(request: ChatRequest):
async def nanobot_chat_stream(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest):
async def event_generator(): async def event_generator():
try: try:
if request.prefer_sql_chart:
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
chart = nl2sql_result.chart
viz_payload = {
"type": "viz",
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
}
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
text = f"{text}\n\n可视化说明:{chart.reasoning}"
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return
response = await nanobot_service.process_message( response = await nanobot_service.process_message(
request.message, request.message,
session_id=request.session_id, session_id=request.session_id,
+7 -1
View File
@@ -1,6 +1,7 @@
import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom"; import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom";
import { Sidebar } from "./components/Sidebar"; import { Sidebar } from "./components/Sidebar";
import { ChatInterface } from "./components/ChatInterface"; import { ChatInterface } from "./components/ChatInterface";
import { VisualizationPanel } from "./components/VisualizationPanel";
import { Dashboard } from "./pages/Dashboard"; import { Dashboard } from "./pages/Dashboard";
import { Skills } from "./pages/Skills"; import { Skills } from "./pages/Skills";
import { Settings } from "./pages/Settings"; import { Settings } from "./pages/Settings";
@@ -45,9 +46,14 @@ function App() {
<Route path="/" element={ <Route path="/" element={
<ProtectedRoute> <ProtectedRoute>
<MainLayout> <MainLayout>
<div className="h-full overflow-hidden bg-white"> <div className="h-full overflow-hidden bg-white flex">
<div className="flex-1 min-w-0">
<ChatInterface /> <ChatInterface />
</div> </div>
<div className="w-[42%] min-w-[420px] border-l bg-background">
<VisualizationPanel />
</div>
</div>
</MainLayout> </MainLayout>
</ProtectedRoute> </ProtectedRoute>
} /> } />
+103 -9
View File
@@ -41,7 +41,7 @@ export function ChatInterface() {
const [messages, setMessages] = useState<Message[]>([]); const [messages, setMessages] = useState<Message[]>([]);
const [input, setInput] = useState(""); const [input, setInput] = useState("");
const [selectedCapability, setSelectedCapability] = useState<string>("智能问答"); const [selectedCapability, setSelectedCapability] = useState<string>("智能问答");
const selectedDataSource = "postgres-main"; const [selectedDataSource, setSelectedDataSource] = useState<string>("postgres-main");
const [isLoading, setIsLoading] = useState(false); const [isLoading, setIsLoading] = useState(false);
const scrollRef = useRef<HTMLDivElement>(null); const scrollRef = useRef<HTMLDivElement>(null);
const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore(); const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore();
@@ -114,6 +114,7 @@ export function ChatInterface() {
{ icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" }, { icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" },
{ icon: Search, label: "深度问数", color: "text-blue-500", bg: "bg-blue-50" }, { icon: Search, label: "深度问数", color: "text-blue-500", bg: "bg-blue-50" },
]; ];
const chartIntentPattern = /(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|chart|plot|visuali[sz]e)/i;
const handleFileUpload = async (e: React.ChangeEvent<HTMLInputElement>) => { const handleFileUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0]; const file = e.target.files?.[0];
@@ -168,8 +169,9 @@ export function ChatInterface() {
setInput(""); setInput("");
let messagePayload = newMessage.content; let messagePayload = newMessage.content;
if (attachedFile) { const currentAttachedFile = attachedFile;
messagePayload = `[用户上传了文件: ${attachedFile.filename}]\n[文件内容摘要: ${attachedFile.summary || "无"}]\n[数据列: ${attachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${attachedFile.url}]\n\n${newMessage.content}`; if (currentAttachedFile) {
messagePayload = `[用户上传了文件: ${currentAttachedFile.filename}]\n[文件内容摘要: ${currentAttachedFile.summary || "无"}]\n[数据列: ${currentAttachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${currentAttachedFile.url}]\n\n${newMessage.content}`;
setAttachedFile(null); setAttachedFile(null);
} }
@@ -189,6 +191,9 @@ export function ChatInterface() {
const token = localStorage.getItem("token"); const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || ""; const effectiveModelId = selectedModelId || currentModel?.id || "";
const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0];
const fileUrl = currentAttachedFile?.url || undefined;
const preferSqlChart = chartIntentPattern.test(messagePayload);
const response = await fetch("/nanobot/chat/stream", { const response = await fetch("/nanobot/chat/stream", {
method: "POST", method: "POST",
headers: { headers: {
@@ -199,6 +204,9 @@ export function ChatInterface() {
message: messagePayload, message: messagePayload,
session_id: activeSessionKey, session_id: activeSessionKey,
model_id: effectiveModelId, model_id: effectiveModelId,
source,
prefer_sql_chart: preferSqlChart,
file_url: fileUrl,
}), }),
}); });
@@ -226,7 +234,14 @@ export function ChatInterface() {
if (!line) continue; if (!line) continue;
const payloadText = line.slice(5).trim(); const payloadText = line.slice(5).trim();
if (!payloadText) continue; if (!payloadText) continue;
const payload = JSON.parse(payloadText) as { type: string; content?: string }; const payload = JSON.parse(payloadText) as {
type: string;
content?: string;
sql?: string;
result?: unknown;
error?: string;
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
};
if (payload.type === "delta" && payload.content) { if (payload.type === "delta" && payload.content) {
streamedText = `${streamedText}${payload.content}`; streamedText = `${streamedText}${payload.content}`;
@@ -249,15 +264,69 @@ export function ChatInterface() {
if (payload.type === "error") { if (payload.type === "error") {
throw new Error(payload.content || "流式响应错误"); throw new Error(payload.content || "流式响应错误");
} }
if (payload.type === "viz") {
if (payload.error) {
setVizError(payload.error);
} else {
const rows = Array.isArray(payload.result) ? payload.result : [];
const sql = typeof payload.sql === "string" ? payload.sql : "";
const chart = payload.chart ?? undefined;
const canVisualize = Boolean(chart?.can_visualize);
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
setVisualization(
rows,
sql,
chartSpec,
{
canVisualize,
reasoning: chart?.reasoning,
chartType: chart?.chart_type,
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
}
);
}
}
} }
} }
if (!streamedText) { if (!streamedText) {
const fallback = await api.post<{ response: string }>("/nanobot/chat", { const fallback = await api.post<{
response: string;
viz?: {
sql?: string;
result?: unknown;
error?: string | null;
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
};
}>("/nanobot/chat", {
message: messagePayload, message: messagePayload,
session_id: activeSessionKey, session_id: activeSessionKey,
model_id: effectiveModelId, model_id: effectiveModelId,
source,
prefer_sql_chart: preferSqlChart,
file_url: fileUrl,
}); });
if (fallback.viz?.error) {
setVizError(fallback.viz.error);
} else if (fallback.viz) {
const rows = Array.isArray(fallback.viz.result) ? fallback.viz.result : [];
const sql = typeof fallback.viz.sql === "string" ? fallback.viz.sql : "";
const chart = fallback.viz.chart ?? undefined;
const canVisualize = Boolean(chart?.can_visualize);
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
setVisualization(
rows,
sql,
chartSpec,
{
canVisualize,
reasoning: chart?.reasoning,
chartType: chart?.chart_type,
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
}
);
}
setMessages((prev) => setMessages((prev) =>
prev.map((msg) => prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg
@@ -266,15 +335,16 @@ export function ChatInterface() {
} }
} else { } else {
// Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数") // Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数")
const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0];
const response = await api.post<{ const response = await api.post<{
sql?: string, sql?: string,
result?: unknown, result?: unknown,
error?: string, error?: string,
chart?: { chart_spec: ChartSpec, reasoning: string, can_visualize: boolean } chart?: { chart_spec?: ChartSpec | null, reasoning?: string, can_visualize?: boolean, chart_type?: string }
}>('/api/v1/agent/nl2sql', { }>('/api/v1/agent/nl2sql', {
query: messagePayload, query: messagePayload,
source: source, source: source,
file_url: currentAttachedFile?.url,
session_id: activeSessionKey, session_id: activeSessionKey,
model_id: selectedModelId model_id: selectedModelId
}); });
@@ -289,12 +359,25 @@ export function ChatInterface() {
} else { } else {
const rows = Array.isArray(response.result) ? response.result : []; const rows = Array.isArray(response.result) ? response.result : [];
const sql = typeof response.sql === "string" ? response.sql : ""; const sql = typeof response.sql === "string" ? response.sql : "";
const chart = response.chart;
const canVisualize = Boolean(chart?.can_visualize);
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
setMessages(prev => [...prev, { setMessages(prev => [...prev, {
id: (Date.now() + 1).toString(), id: (Date.now() + 1).toString(),
role: 'assistant', role: 'assistant',
content: `I've generated a SQL query and fetched ${rows.length} rows for you. Check the visualization panel.${response.chart?.reasoning ? `\n\nVisualization reasoning: ${response.chart.reasoning}` : ''}` content: `已为你生成 SQL 并查询到 ${rows.length} 行数据。${canVisualize ? '可视化面板已同步更新图表。' : '本次结果不适合图表展示。'}${chart?.reasoning ? `\n\n可视化说明:${chart.reasoning}` : ''}`
}]); }]);
setVisualization(rows, sql, response.chart?.chart_spec); setVisualization(
rows,
sql,
chartSpec,
{
canVisualize,
reasoning: chart?.reasoning,
chartType: chart?.chart_type,
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
}
);
} }
} }
} catch (error: any) { } catch (error: any) {
@@ -353,6 +436,17 @@ export function ChatInterface() {
</Command> </Command>
</PopoverContent> </PopoverContent>
</Popover> </Popover>
<div className="flex items-center gap-2 bg-white/80 backdrop-blur-sm rounded-md px-3 py-2 text-sm text-zinc-700">
<span className="text-zinc-500"></span>
<select
value={selectedDataSource}
onChange={(e) => setSelectedDataSource(e.target.value)}
className="bg-transparent border-none outline-none text-sm font-medium"
>
<option value="postgres-main">PostgreSQL</option>
<option value="clickhouse-main">ClickHouse</option>
</select>
</div>
</div> </div>
<ScrollArea className="flex-1 h-[calc(100vh-100px)]"> <ScrollArea className="flex-1 h-[calc(100vh-100px)]">
+4 -13
View File
@@ -9,23 +9,14 @@ interface VegaChartProps {
export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => { export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
const vegaSpec: any = { const vegaSpec: any = {
$schema: 'https://vega.github.io/schema/vega-lite/v5.json', $schema: typeof spec.$schema === 'string' ? spec.$schema : 'https://vega.github.io/schema/vega-lite/v5.json',
description: spec.description, ...spec,
title: spec.title,
width: "container", width: "container",
height: "container", height: "container",
mark: { type: spec.chart_type, tooltip: true }, data: { values: data },
encoding: { autosize: { type: "fit", contains: "padding" },
x: { field: spec.x_axis, type: 'nominal', axis: { labelAngle: -45 } },
y: { field: spec.y_axis, type: 'quantitative' },
},
data: { values: data }
}; };
if (spec.color) {
vegaSpec.encoding.color = { field: spec.color, type: 'nominal' };
}
return ( return (
<div className="w-full h-full"> <div className="w-full h-full">
<VegaEmbed <VegaEmbed
@@ -12,15 +12,17 @@ import { VegaChart } from "./VegaChart";
export function VisualizationPanel() { export function VisualizationPanel() {
const [view, setView] = useState<'table' | 'chart'>('chart'); const [view, setView] = useState<'table' | 'chart'>('chart');
const { addChart } = useDashboardStore(); const { addChart } = useDashboardStore();
const { currentData, currentSQL, currentChartSpec, isLoading, error } = useVisualizationStore(); const { currentData, currentSQL, currentChartSpec, currentChartInfo, isLoading, error } = useVisualizationStore();
const handleAddToDashboard = () => { const handleAddToDashboard = () => {
if (!currentData || !currentSQL) return; if (!currentData || !currentSQL) return;
const mark = currentChartSpec?.mark;
const markType = typeof mark === "string" ? mark : mark?.type;
const dashboardType = markType === "line" ? "line" : "bar";
addChart({ addChart({
id: Date.now().toString(), id: Date.now().toString(),
title: currentChartSpec?.title || 'Generated Analysis', title: currentChartSpec?.title || 'Generated Analysis',
type: currentChartSpec?.chart_type as any || 'bar', type: dashboardType,
data: currentData, data: currentData,
sql: currentSQL, sql: currentSQL,
}); });
@@ -134,7 +136,7 @@ export function VisualizationPanel() {
<Card className="h-full flex flex-col shadow-sm border-muted"> <Card className="h-full flex flex-col shadow-sm border-muted">
<CardHeader className="pb-2 shrink-0"> <CardHeader className="pb-2 shrink-0">
<CardTitle>{currentChartSpec?.title || 'Analysis Result'}</CardTitle> <CardTitle>{currentChartSpec?.title || 'Analysis Result'}</CardTitle>
<CardDescription>{currentChartSpec?.description || 'Generated from your query'}</CardDescription> <CardDescription>{currentChartInfo?.reasoning || currentChartSpec?.description || 'Generated from your query'}</CardDescription>
</CardHeader> </CardHeader>
<CardContent className="flex-1 min-h-0 p-4"> <CardContent className="flex-1 min-h-0 p-4">
{view === 'chart' ? ( {view === 'chart' ? (
+24 -8
View File
@@ -1,11 +1,19 @@
import { create } from 'zustand'; import { create } from 'zustand';
export interface ChartSpec { export interface ChartSpec {
chart_type: string; $schema?: string;
title: string; title?: string;
x_axis: string; description?: string;
y_axis: string; mark?: string | { type?: string; [key: string]: unknown };
color?: string; encoding?: Record<string, unknown>;
transform?: Array<Record<string, unknown>>;
[key: string]: unknown;
}
export interface ChartInfo {
canVisualize: boolean;
reasoning?: string;
chartType?: string;
description?: string; description?: string;
} }
@@ -13,9 +21,10 @@ export interface VisualizationState {
currentData: any[] | null; currentData: any[] | null;
currentSQL: string | null; currentSQL: string | null;
currentChartSpec: ChartSpec | null; currentChartSpec: ChartSpec | null;
currentChartInfo: ChartInfo | null;
isLoading: boolean; isLoading: boolean;
error: string | null; error: string | null;
setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null) => void; setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null, chartInfo?: ChartInfo | null) => void;
setLoading: (loading: boolean) => void; setLoading: (loading: boolean) => void;
setError: (error: string | null) => void; setError: (error: string | null) => void;
clearVisualization: () => void; clearVisualization: () => void;
@@ -25,10 +34,17 @@ export const useVisualizationStore = create<VisualizationState>((set) => ({
currentData: null, currentData: null,
currentSQL: null, currentSQL: null,
currentChartSpec: null, currentChartSpec: null,
currentChartInfo: null,
isLoading: false, isLoading: false,
error: null, error: null,
setVisualization: (data, sql, chartSpec = null) => set({ currentData: data, currentSQL: sql, currentChartSpec: chartSpec, error: null }), setVisualization: (data, sql, chartSpec = null, chartInfo = null) => set({
currentData: data,
currentSQL: sql,
currentChartSpec: chartSpec,
currentChartInfo: chartInfo,
error: null,
}),
setLoading: (loading) => set({ isLoading: loading }), setLoading: (loading) => set({ isLoading: loading }),
setError: (error) => set({ error, isLoading: false }), setError: (error) => set({ error, isLoading: false }),
clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, error: null }), clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, currentChartInfo: null, error: null }),
})); }));