feature: nl2sql first successful
This commit is contained in:
@@ -4,6 +4,8 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
import duckdb
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
# Add project root to sys.path to allow importing nanobot
|
# Add project root to sys.path to allow importing nanobot
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||||
@@ -19,7 +21,8 @@ from app.agent.chart import generate_chart
|
|||||||
|
|
||||||
class NL2SQLRequest(BaseModel):
|
class NL2SQLRequest(BaseModel):
|
||||||
query: str = Field(..., description="User's natural language query")
|
query: str = Field(..., description="User's natural language query")
|
||||||
source: str = Field(..., description="Data source to query (postgres, clickhouse)")
|
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||||
|
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||||
|
|
||||||
class NL2SQLResponse(BaseModel):
|
class NL2SQLResponse(BaseModel):
|
||||||
sql: str
|
sql: str
|
||||||
@@ -80,19 +83,62 @@ The final answer must be a ANSI SQL query in JSON format:
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
|
||||||
|
if not file_url or not file_url.startswith("local://"):
|
||||||
|
raise ValueError("Invalid uploaded file URL")
|
||||||
|
raw_name = file_url.replace("local://", "", 1)
|
||||||
|
safe_name = os.path.basename(raw_name)
|
||||||
|
upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads"
|
||||||
|
file_path = upload_dir / safe_name
|
||||||
|
if not file_path.exists():
|
||||||
|
raise ValueError(f"Uploaded file not found: {safe_name}")
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
|
||||||
|
file_path = _resolve_upload_file_path(file_url)
|
||||||
|
suffix = file_path.suffix.lower()
|
||||||
|
if suffix == ".csv":
|
||||||
|
return pd.read_csv(file_path)
|
||||||
|
if suffix in [".xls", ".xlsx"]:
|
||||||
|
return pd.read_excel(file_path)
|
||||||
|
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||||
|
|
||||||
|
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
|
||||||
|
df = _load_upload_dataframe(file_url)
|
||||||
|
conn = duckdb.connect(":memory:")
|
||||||
|
conn.register("uploaded_file", df)
|
||||||
|
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
|
||||||
|
schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}
|
||||||
|
conn.close()
|
||||||
|
return schema
|
||||||
|
|
||||||
|
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
|
||||||
|
df = _load_upload_dataframe(file_url)
|
||||||
|
conn = duckdb.connect(":memory:")
|
||||||
|
conn.register("uploaded_file", df)
|
||||||
|
result_df = conn.execute(sql_query).df()
|
||||||
|
conn.close()
|
||||||
|
return result_df.to_dict(orient="records")
|
||||||
|
|
||||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||||
# 1. Get the connector and schema
|
# 1. Get the connector and schema
|
||||||
connector = None
|
connector = None
|
||||||
|
schema = {}
|
||||||
if request.source == "postgres":
|
if request.source == "postgres":
|
||||||
connector = postgres_connector
|
connector = postgres_connector
|
||||||
elif request.source == "clickhouse":
|
elif request.source == "clickhouse":
|
||||||
connector = clickhouse_connector
|
connector = clickhouse_connector
|
||||||
|
elif request.source == "upload":
|
||||||
|
try:
|
||||||
|
schema = _get_upload_schema(request.file_url)
|
||||||
|
except Exception as e:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||||
else:
|
else:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||||
|
|
||||||
|
if connector:
|
||||||
if not connector.test_connection():
|
if not connector.test_connection():
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||||
|
|
||||||
schema = connector.get_schema()
|
schema = connector.get_schema()
|
||||||
schema_str = json.dumps(schema, indent=2)
|
schema_str = json.dumps(schema, indent=2)
|
||||||
|
|
||||||
@@ -158,6 +204,9 @@ Let's think step by step.
|
|||||||
|
|
||||||
# 6. Execute SQL
|
# 6. Execute SQL
|
||||||
try:
|
try:
|
||||||
|
if request.source == "upload":
|
||||||
|
formatted_results = _execute_upload_sql(sql_query, request.file_url)
|
||||||
|
else:
|
||||||
results = connector.execute_query(sql_query)
|
results = connector.execute_query(sql_query)
|
||||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||||
formatted_results = []
|
formatted_results = []
|
||||||
|
|||||||
@@ -67,6 +67,9 @@ class ChatRequest(BaseModel):
|
|||||||
session_id: str = "api:default"
|
session_id: str = "api:default"
|
||||||
skill_ids: Optional[List[str]] = None
|
skill_ids: Optional[List[str]] = None
|
||||||
model_id: Optional[str] = None
|
model_id: Optional[str] = None
|
||||||
|
source: str = "postgres"
|
||||||
|
prefer_sql_chart: bool = False
|
||||||
|
file_url: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class SessionAliasUpdateRequest(BaseModel):
|
class SessionAliasUpdateRequest(BaseModel):
|
||||||
@@ -77,6 +80,27 @@ class SessionAliasUpdateRequest(BaseModel):
|
|||||||
@app.post("/nanobot/chat")
|
@app.post("/nanobot/chat")
|
||||||
async def nanobot_chat(request: ChatRequest):
|
async def nanobot_chat(request: ChatRequest):
|
||||||
try:
|
try:
|
||||||
|
if request.prefer_sql_chart:
|
||||||
|
nl2sql_result = await process_nl2sql(
|
||||||
|
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||||
|
)
|
||||||
|
chart = nl2sql_result.chart
|
||||||
|
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||||
|
text = (
|
||||||
|
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
||||||
|
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||||
|
)
|
||||||
|
if chart and chart.reasoning:
|
||||||
|
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||||
|
return {
|
||||||
|
"response": text,
|
||||||
|
"viz": {
|
||||||
|
"sql": nl2sql_result.sql,
|
||||||
|
"result": nl2sql_result.result,
|
||||||
|
"chart": chart.model_dump() if chart else None,
|
||||||
|
"error": nl2sql_result.error,
|
||||||
|
},
|
||||||
|
}
|
||||||
response = await nanobot_service.process_message(
|
response = await nanobot_service.process_message(
|
||||||
request.message,
|
request.message,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
@@ -91,6 +115,29 @@ async def nanobot_chat(request: ChatRequest):
|
|||||||
async def nanobot_chat_stream(request: ChatRequest):
|
async def nanobot_chat_stream(request: ChatRequest):
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
try:
|
try:
|
||||||
|
if request.prefer_sql_chart:
|
||||||
|
nl2sql_result = await process_nl2sql(
|
||||||
|
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||||
|
)
|
||||||
|
chart = nl2sql_result.chart
|
||||||
|
viz_payload = {
|
||||||
|
"type": "viz",
|
||||||
|
"sql": nl2sql_result.sql,
|
||||||
|
"result": nl2sql_result.result,
|
||||||
|
"chart": chart.model_dump() if chart else None,
|
||||||
|
"error": nl2sql_result.error,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
||||||
|
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||||
|
text = (
|
||||||
|
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
||||||
|
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||||
|
)
|
||||||
|
if chart and chart.reasoning:
|
||||||
|
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||||
|
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||||
|
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||||
|
return
|
||||||
response = await nanobot_service.process_message(
|
response = await nanobot_service.process_message(
|
||||||
request.message,
|
request.message,
|
||||||
session_id=request.session_id,
|
session_id=request.session_id,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom";
|
import { BrowserRouter, Routes, Route, Navigate } from "react-router-dom";
|
||||||
import { Sidebar } from "./components/Sidebar";
|
import { Sidebar } from "./components/Sidebar";
|
||||||
import { ChatInterface } from "./components/ChatInterface";
|
import { ChatInterface } from "./components/ChatInterface";
|
||||||
|
import { VisualizationPanel } from "./components/VisualizationPanel";
|
||||||
import { Dashboard } from "./pages/Dashboard";
|
import { Dashboard } from "./pages/Dashboard";
|
||||||
import { Skills } from "./pages/Skills";
|
import { Skills } from "./pages/Skills";
|
||||||
import { Settings } from "./pages/Settings";
|
import { Settings } from "./pages/Settings";
|
||||||
@@ -45,9 +46,14 @@ function App() {
|
|||||||
<Route path="/" element={
|
<Route path="/" element={
|
||||||
<ProtectedRoute>
|
<ProtectedRoute>
|
||||||
<MainLayout>
|
<MainLayout>
|
||||||
<div className="h-full overflow-hidden bg-white">
|
<div className="h-full overflow-hidden bg-white flex">
|
||||||
|
<div className="flex-1 min-w-0">
|
||||||
<ChatInterface />
|
<ChatInterface />
|
||||||
</div>
|
</div>
|
||||||
|
<div className="w-[42%] min-w-[420px] border-l bg-background">
|
||||||
|
<VisualizationPanel />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</MainLayout>
|
</MainLayout>
|
||||||
</ProtectedRoute>
|
</ProtectedRoute>
|
||||||
} />
|
} />
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ export function ChatInterface() {
|
|||||||
const [messages, setMessages] = useState<Message[]>([]);
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
const [input, setInput] = useState("");
|
const [input, setInput] = useState("");
|
||||||
const [selectedCapability, setSelectedCapability] = useState<string>("智能问答");
|
const [selectedCapability, setSelectedCapability] = useState<string>("智能问答");
|
||||||
const selectedDataSource = "postgres-main";
|
const [selectedDataSource, setSelectedDataSource] = useState<string>("postgres-main");
|
||||||
const [isLoading, setIsLoading] = useState(false);
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
const scrollRef = useRef<HTMLDivElement>(null);
|
const scrollRef = useRef<HTMLDivElement>(null);
|
||||||
const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore();
|
const { setVisualization, setLoading: setVizLoading, setError: setVizError } = useVisualizationStore();
|
||||||
@@ -114,6 +114,7 @@ export function ChatInterface() {
|
|||||||
{ icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" },
|
{ icon: Table, label: "表格问答", color: "text-orange-500", bg: "bg-orange-50" },
|
||||||
{ icon: Search, label: "深度问数", color: "text-blue-500", bg: "bg-blue-50" },
|
{ icon: Search, label: "深度问数", color: "text-blue-500", bg: "bg-blue-50" },
|
||||||
];
|
];
|
||||||
|
const chartIntentPattern = /(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|chart|plot|visuali[sz]e)/i;
|
||||||
|
|
||||||
const handleFileUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
|
const handleFileUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
const file = e.target.files?.[0];
|
const file = e.target.files?.[0];
|
||||||
@@ -168,8 +169,9 @@ export function ChatInterface() {
|
|||||||
setInput("");
|
setInput("");
|
||||||
|
|
||||||
let messagePayload = newMessage.content;
|
let messagePayload = newMessage.content;
|
||||||
if (attachedFile) {
|
const currentAttachedFile = attachedFile;
|
||||||
messagePayload = `[用户上传了文件: ${attachedFile.filename}]\n[文件内容摘要: ${attachedFile.summary || "无"}]\n[数据列: ${attachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${attachedFile.url}]\n\n${newMessage.content}`;
|
if (currentAttachedFile) {
|
||||||
|
messagePayload = `[用户上传了文件: ${currentAttachedFile.filename}]\n[文件内容摘要: ${currentAttachedFile.summary || "无"}]\n[数据列: ${currentAttachedFile.columns?.join(", ") || "无"}]\n[文件下载链接: ${currentAttachedFile.url}]\n\n${newMessage.content}`;
|
||||||
setAttachedFile(null);
|
setAttachedFile(null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,6 +191,9 @@ export function ChatInterface() {
|
|||||||
|
|
||||||
const token = localStorage.getItem("token");
|
const token = localStorage.getItem("token");
|
||||||
const effectiveModelId = selectedModelId || currentModel?.id || "";
|
const effectiveModelId = selectedModelId || currentModel?.id || "";
|
||||||
|
const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0];
|
||||||
|
const fileUrl = currentAttachedFile?.url || undefined;
|
||||||
|
const preferSqlChart = chartIntentPattern.test(messagePayload);
|
||||||
const response = await fetch("/nanobot/chat/stream", {
|
const response = await fetch("/nanobot/chat/stream", {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
@@ -199,6 +204,9 @@ export function ChatInterface() {
|
|||||||
message: messagePayload,
|
message: messagePayload,
|
||||||
session_id: activeSessionKey,
|
session_id: activeSessionKey,
|
||||||
model_id: effectiveModelId,
|
model_id: effectiveModelId,
|
||||||
|
source,
|
||||||
|
prefer_sql_chart: preferSqlChart,
|
||||||
|
file_url: fileUrl,
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -226,7 +234,14 @@ export function ChatInterface() {
|
|||||||
if (!line) continue;
|
if (!line) continue;
|
||||||
const payloadText = line.slice(5).trim();
|
const payloadText = line.slice(5).trim();
|
||||||
if (!payloadText) continue;
|
if (!payloadText) continue;
|
||||||
const payload = JSON.parse(payloadText) as { type: string; content?: string };
|
const payload = JSON.parse(payloadText) as {
|
||||||
|
type: string;
|
||||||
|
content?: string;
|
||||||
|
sql?: string;
|
||||||
|
result?: unknown;
|
||||||
|
error?: string;
|
||||||
|
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
|
||||||
|
};
|
||||||
|
|
||||||
if (payload.type === "delta" && payload.content) {
|
if (payload.type === "delta" && payload.content) {
|
||||||
streamedText = `${streamedText}${payload.content}`;
|
streamedText = `${streamedText}${payload.content}`;
|
||||||
@@ -249,15 +264,69 @@ export function ChatInterface() {
|
|||||||
if (payload.type === "error") {
|
if (payload.type === "error") {
|
||||||
throw new Error(payload.content || "流式响应错误");
|
throw new Error(payload.content || "流式响应错误");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (payload.type === "viz") {
|
||||||
|
if (payload.error) {
|
||||||
|
setVizError(payload.error);
|
||||||
|
} else {
|
||||||
|
const rows = Array.isArray(payload.result) ? payload.result : [];
|
||||||
|
const sql = typeof payload.sql === "string" ? payload.sql : "";
|
||||||
|
const chart = payload.chart ?? undefined;
|
||||||
|
const canVisualize = Boolean(chart?.can_visualize);
|
||||||
|
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
|
||||||
|
setVisualization(
|
||||||
|
rows,
|
||||||
|
sql,
|
||||||
|
chartSpec,
|
||||||
|
{
|
||||||
|
canVisualize,
|
||||||
|
reasoning: chart?.reasoning,
|
||||||
|
chartType: chart?.chart_type,
|
||||||
|
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!streamedText) {
|
if (!streamedText) {
|
||||||
const fallback = await api.post<{ response: string }>("/nanobot/chat", {
|
const fallback = await api.post<{
|
||||||
|
response: string;
|
||||||
|
viz?: {
|
||||||
|
sql?: string;
|
||||||
|
result?: unknown;
|
||||||
|
error?: string | null;
|
||||||
|
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
|
||||||
|
};
|
||||||
|
}>("/nanobot/chat", {
|
||||||
message: messagePayload,
|
message: messagePayload,
|
||||||
session_id: activeSessionKey,
|
session_id: activeSessionKey,
|
||||||
model_id: effectiveModelId,
|
model_id: effectiveModelId,
|
||||||
|
source,
|
||||||
|
prefer_sql_chart: preferSqlChart,
|
||||||
|
file_url: fileUrl,
|
||||||
});
|
});
|
||||||
|
if (fallback.viz?.error) {
|
||||||
|
setVizError(fallback.viz.error);
|
||||||
|
} else if (fallback.viz) {
|
||||||
|
const rows = Array.isArray(fallback.viz.result) ? fallback.viz.result : [];
|
||||||
|
const sql = typeof fallback.viz.sql === "string" ? fallback.viz.sql : "";
|
||||||
|
const chart = fallback.viz.chart ?? undefined;
|
||||||
|
const canVisualize = Boolean(chart?.can_visualize);
|
||||||
|
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
|
||||||
|
setVisualization(
|
||||||
|
rows,
|
||||||
|
sql,
|
||||||
|
chartSpec,
|
||||||
|
{
|
||||||
|
canVisualize,
|
||||||
|
reasoning: chart?.reasoning,
|
||||||
|
chartType: chart?.chart_type,
|
||||||
|
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
prev.map((msg) =>
|
prev.map((msg) =>
|
||||||
msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg
|
msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false } : msg
|
||||||
@@ -266,15 +335,16 @@ export function ChatInterface() {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数")
|
// Fallback to existing NL2SQL or other skills (e.g. for "表格问答" or "深度问数")
|
||||||
const source = selectedDataSource.split('-')[0]; // postgres-main -> postgres
|
const source = currentAttachedFile?.url?.startsWith("local://") ? "upload" : selectedDataSource.split('-')[0];
|
||||||
const response = await api.post<{
|
const response = await api.post<{
|
||||||
sql?: string,
|
sql?: string,
|
||||||
result?: unknown,
|
result?: unknown,
|
||||||
error?: string,
|
error?: string,
|
||||||
chart?: { chart_spec: ChartSpec, reasoning: string, can_visualize: boolean }
|
chart?: { chart_spec?: ChartSpec | null, reasoning?: string, can_visualize?: boolean, chart_type?: string }
|
||||||
}>('/api/v1/agent/nl2sql', {
|
}>('/api/v1/agent/nl2sql', {
|
||||||
query: messagePayload,
|
query: messagePayload,
|
||||||
source: source,
|
source: source,
|
||||||
|
file_url: currentAttachedFile?.url,
|
||||||
session_id: activeSessionKey,
|
session_id: activeSessionKey,
|
||||||
model_id: selectedModelId
|
model_id: selectedModelId
|
||||||
});
|
});
|
||||||
@@ -289,12 +359,25 @@ export function ChatInterface() {
|
|||||||
} else {
|
} else {
|
||||||
const rows = Array.isArray(response.result) ? response.result : [];
|
const rows = Array.isArray(response.result) ? response.result : [];
|
||||||
const sql = typeof response.sql === "string" ? response.sql : "";
|
const sql = typeof response.sql === "string" ? response.sql : "";
|
||||||
|
const chart = response.chart;
|
||||||
|
const canVisualize = Boolean(chart?.can_visualize);
|
||||||
|
const chartSpec = canVisualize ? (chart?.chart_spec ?? null) : null;
|
||||||
setMessages(prev => [...prev, {
|
setMessages(prev => [...prev, {
|
||||||
id: (Date.now() + 1).toString(),
|
id: (Date.now() + 1).toString(),
|
||||||
role: 'assistant',
|
role: 'assistant',
|
||||||
content: `I've generated a SQL query and fetched ${rows.length} rows for you. Check the visualization panel.${response.chart?.reasoning ? `\n\nVisualization reasoning: ${response.chart.reasoning}` : ''}`
|
content: `已为你生成 SQL 并查询到 ${rows.length} 行数据。${canVisualize ? '可视化面板已同步更新图表。' : '本次结果不适合图表展示。'}${chart?.reasoning ? `\n\n可视化说明:${chart.reasoning}` : ''}`
|
||||||
}]);
|
}]);
|
||||||
setVisualization(rows, sql, response.chart?.chart_spec);
|
setVisualization(
|
||||||
|
rows,
|
||||||
|
sql,
|
||||||
|
chartSpec,
|
||||||
|
{
|
||||||
|
canVisualize,
|
||||||
|
reasoning: chart?.reasoning,
|
||||||
|
chartType: chart?.chart_type,
|
||||||
|
description: canVisualize ? "根据模型返回的 Vega-Lite schema 渲染" : "当前结果不适合可视化",
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@@ -353,6 +436,17 @@ export function ChatInterface() {
|
|||||||
</Command>
|
</Command>
|
||||||
</PopoverContent>
|
</PopoverContent>
|
||||||
</Popover>
|
</Popover>
|
||||||
|
<div className="flex items-center gap-2 bg-white/80 backdrop-blur-sm rounded-md px-3 py-2 text-sm text-zinc-700">
|
||||||
|
<span className="text-zinc-500">数据源</span>
|
||||||
|
<select
|
||||||
|
value={selectedDataSource}
|
||||||
|
onChange={(e) => setSelectedDataSource(e.target.value)}
|
||||||
|
className="bg-transparent border-none outline-none text-sm font-medium"
|
||||||
|
>
|
||||||
|
<option value="postgres-main">PostgreSQL</option>
|
||||||
|
<option value="clickhouse-main">ClickHouse</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<ScrollArea className="flex-1 h-[calc(100vh-100px)]">
|
<ScrollArea className="flex-1 h-[calc(100vh-100px)]">
|
||||||
|
|||||||
@@ -9,23 +9,14 @@ interface VegaChartProps {
|
|||||||
|
|
||||||
export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
export const VegaChart: React.FC<VegaChartProps> = ({ data, spec }) => {
|
||||||
const vegaSpec: any = {
|
const vegaSpec: any = {
|
||||||
$schema: 'https://vega.github.io/schema/vega-lite/v5.json',
|
$schema: typeof spec.$schema === 'string' ? spec.$schema : 'https://vega.github.io/schema/vega-lite/v5.json',
|
||||||
description: spec.description,
|
...spec,
|
||||||
title: spec.title,
|
|
||||||
width: "container",
|
width: "container",
|
||||||
height: "container",
|
height: "container",
|
||||||
mark: { type: spec.chart_type, tooltip: true },
|
data: { values: data },
|
||||||
encoding: {
|
autosize: { type: "fit", contains: "padding" },
|
||||||
x: { field: spec.x_axis, type: 'nominal', axis: { labelAngle: -45 } },
|
|
||||||
y: { field: spec.y_axis, type: 'quantitative' },
|
|
||||||
},
|
|
||||||
data: { values: data }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (spec.color) {
|
|
||||||
vegaSpec.encoding.color = { field: spec.color, type: 'nominal' };
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="w-full h-full">
|
<div className="w-full h-full">
|
||||||
<VegaEmbed
|
<VegaEmbed
|
||||||
|
|||||||
@@ -12,15 +12,17 @@ import { VegaChart } from "./VegaChart";
|
|||||||
export function VisualizationPanel() {
|
export function VisualizationPanel() {
|
||||||
const [view, setView] = useState<'table' | 'chart'>('chart');
|
const [view, setView] = useState<'table' | 'chart'>('chart');
|
||||||
const { addChart } = useDashboardStore();
|
const { addChart } = useDashboardStore();
|
||||||
const { currentData, currentSQL, currentChartSpec, isLoading, error } = useVisualizationStore();
|
const { currentData, currentSQL, currentChartSpec, currentChartInfo, isLoading, error } = useVisualizationStore();
|
||||||
|
|
||||||
const handleAddToDashboard = () => {
|
const handleAddToDashboard = () => {
|
||||||
if (!currentData || !currentSQL) return;
|
if (!currentData || !currentSQL) return;
|
||||||
|
const mark = currentChartSpec?.mark;
|
||||||
|
const markType = typeof mark === "string" ? mark : mark?.type;
|
||||||
|
const dashboardType = markType === "line" ? "line" : "bar";
|
||||||
addChart({
|
addChart({
|
||||||
id: Date.now().toString(),
|
id: Date.now().toString(),
|
||||||
title: currentChartSpec?.title || 'Generated Analysis',
|
title: currentChartSpec?.title || 'Generated Analysis',
|
||||||
type: currentChartSpec?.chart_type as any || 'bar',
|
type: dashboardType,
|
||||||
data: currentData,
|
data: currentData,
|
||||||
sql: currentSQL,
|
sql: currentSQL,
|
||||||
});
|
});
|
||||||
@@ -134,7 +136,7 @@ export function VisualizationPanel() {
|
|||||||
<Card className="h-full flex flex-col shadow-sm border-muted">
|
<Card className="h-full flex flex-col shadow-sm border-muted">
|
||||||
<CardHeader className="pb-2 shrink-0">
|
<CardHeader className="pb-2 shrink-0">
|
||||||
<CardTitle>{currentChartSpec?.title || 'Analysis Result'}</CardTitle>
|
<CardTitle>{currentChartSpec?.title || 'Analysis Result'}</CardTitle>
|
||||||
<CardDescription>{currentChartSpec?.description || 'Generated from your query'}</CardDescription>
|
<CardDescription>{currentChartInfo?.reasoning || currentChartSpec?.description || 'Generated from your query'}</CardDescription>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="flex-1 min-h-0 p-4">
|
<CardContent className="flex-1 min-h-0 p-4">
|
||||||
{view === 'chart' ? (
|
{view === 'chart' ? (
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
import { create } from 'zustand';
|
import { create } from 'zustand';
|
||||||
|
|
||||||
export interface ChartSpec {
|
export interface ChartSpec {
|
||||||
chart_type: string;
|
$schema?: string;
|
||||||
title: string;
|
title?: string;
|
||||||
x_axis: string;
|
description?: string;
|
||||||
y_axis: string;
|
mark?: string | { type?: string; [key: string]: unknown };
|
||||||
color?: string;
|
encoding?: Record<string, unknown>;
|
||||||
|
transform?: Array<Record<string, unknown>>;
|
||||||
|
[key: string]: unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChartInfo {
|
||||||
|
canVisualize: boolean;
|
||||||
|
reasoning?: string;
|
||||||
|
chartType?: string;
|
||||||
description?: string;
|
description?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -13,9 +21,10 @@ export interface VisualizationState {
|
|||||||
currentData: any[] | null;
|
currentData: any[] | null;
|
||||||
currentSQL: string | null;
|
currentSQL: string | null;
|
||||||
currentChartSpec: ChartSpec | null;
|
currentChartSpec: ChartSpec | null;
|
||||||
|
currentChartInfo: ChartInfo | null;
|
||||||
isLoading: boolean;
|
isLoading: boolean;
|
||||||
error: string | null;
|
error: string | null;
|
||||||
setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null) => void;
|
setVisualization: (data: any[], sql: string, chartSpec?: ChartSpec | null, chartInfo?: ChartInfo | null) => void;
|
||||||
setLoading: (loading: boolean) => void;
|
setLoading: (loading: boolean) => void;
|
||||||
setError: (error: string | null) => void;
|
setError: (error: string | null) => void;
|
||||||
clearVisualization: () => void;
|
clearVisualization: () => void;
|
||||||
@@ -25,10 +34,17 @@ export const useVisualizationStore = create<VisualizationState>((set) => ({
|
|||||||
currentData: null,
|
currentData: null,
|
||||||
currentSQL: null,
|
currentSQL: null,
|
||||||
currentChartSpec: null,
|
currentChartSpec: null,
|
||||||
|
currentChartInfo: null,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
error: null,
|
error: null,
|
||||||
setVisualization: (data, sql, chartSpec = null) => set({ currentData: data, currentSQL: sql, currentChartSpec: chartSpec, error: null }),
|
setVisualization: (data, sql, chartSpec = null, chartInfo = null) => set({
|
||||||
|
currentData: data,
|
||||||
|
currentSQL: sql,
|
||||||
|
currentChartSpec: chartSpec,
|
||||||
|
currentChartInfo: chartInfo,
|
||||||
|
error: null,
|
||||||
|
}),
|
||||||
setLoading: (loading) => set({ isLoading: loading }),
|
setLoading: (loading) => set({ isLoading: loading }),
|
||||||
setError: (error) => set({ error, isLoading: false }),
|
setError: (error) => set({ error, isLoading: false }),
|
||||||
clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, error: null }),
|
clearVisualization: () => set({ currentData: null, currentSQL: null, currentChartSpec: null, currentChartInfo: null, error: null }),
|
||||||
}));
|
}));
|
||||||
|
|||||||
Reference in New Issue
Block a user