fix: session bug
This commit is contained in:
@@ -23,6 +23,7 @@ class NL2SQLRequest(BaseModel):
|
|||||||
query: str = Field(..., description="User's natural language query")
|
query: str = Field(..., description="User's natural language query")
|
||||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||||
|
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||||
|
|
||||||
class NL2SQLResponse(BaseModel):
|
class NL2SQLResponse(BaseModel):
|
||||||
sql: str
|
sql: str
|
||||||
|
|||||||
+47
-24
@@ -77,13 +77,8 @@ class SessionAliasUpdateRequest(BaseModel):
|
|||||||
pinned: Optional[bool] = None
|
pinned: Optional[bool] = None
|
||||||
archived: Optional[bool] = None
|
archived: Optional[bool] = None
|
||||||
|
|
||||||
@app.post("/nanobot/chat")
|
|
||||||
async def nanobot_chat(request: ChatRequest):
|
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
|
||||||
try:
|
|
||||||
if request.prefer_sql_chart:
|
|
||||||
nl2sql_result = await process_nl2sql(
|
|
||||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
|
||||||
)
|
|
||||||
chart = nl2sql_result.chart
|
chart = nl2sql_result.chart
|
||||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||||
text = (
|
text = (
|
||||||
@@ -91,15 +86,46 @@ async def nanobot_chat(request: ChatRequest):
|
|||||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||||
)
|
)
|
||||||
if chart and chart.reasoning:
|
if chart and chart.reasoning:
|
||||||
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
return f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
||||||
|
chart = nl2sql_result.chart
|
||||||
return {
|
return {
|
||||||
"response": text,
|
|
||||||
"viz": {
|
|
||||||
"sql": nl2sql_result.sql,
|
"sql": nl2sql_result.sql,
|
||||||
"result": nl2sql_result.result,
|
"result": nl2sql_result.result,
|
||||||
"chart": chart.model_dump() if chart else None,
|
"chart": chart.model_dump() if chart else None,
|
||||||
"error": nl2sql_result.error,
|
"error": nl2sql_result.error,
|
||||||
},
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _persist_session_turn(
|
||||||
|
session_id: str,
|
||||||
|
user_message: str,
|
||||||
|
assistant_message: str,
|
||||||
|
assistant_extra: Optional[dict] = None,
|
||||||
|
) -> None:
|
||||||
|
if not nanobot_service.agent:
|
||||||
|
return
|
||||||
|
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
||||||
|
session.add_message("user", user_message)
|
||||||
|
session.add_message("assistant", assistant_message, **(assistant_extra or {}))
|
||||||
|
nanobot_service.agent.sessions.save(session)
|
||||||
|
|
||||||
|
@app.post("/nanobot/chat")
|
||||||
|
async def nanobot_chat(request: ChatRequest):
|
||||||
|
try:
|
||||||
|
if request.prefer_sql_chart:
|
||||||
|
nl2sql_result = await process_nl2sql(
|
||||||
|
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||||
|
)
|
||||||
|
text = _build_sql_chart_text(nl2sql_result)
|
||||||
|
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||||
|
_persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload})
|
||||||
|
return {
|
||||||
|
"response": text,
|
||||||
|
"viz": viz_payload,
|
||||||
}
|
}
|
||||||
response = await nanobot_service.process_message(
|
response = await nanobot_service.process_message(
|
||||||
request.message,
|
request.message,
|
||||||
@@ -119,22 +145,14 @@ async def nanobot_chat_stream(request: ChatRequest):
|
|||||||
nl2sql_result = await process_nl2sql(
|
nl2sql_result = await process_nl2sql(
|
||||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||||
)
|
)
|
||||||
chart = nl2sql_result.chart
|
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||||
viz_payload = {
|
viz_payload = {
|
||||||
"type": "viz",
|
"type": "viz",
|
||||||
"sql": nl2sql_result.sql,
|
**persisted_viz_payload,
|
||||||
"result": nl2sql_result.result,
|
|
||||||
"chart": chart.model_dump() if chart else None,
|
|
||||||
"error": nl2sql_result.error,
|
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
||||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
text = _build_sql_chart_text(nl2sql_result)
|
||||||
text = (
|
_persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload})
|
||||||
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
|
||||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
|
||||||
)
|
|
||||||
if chart and chart.reasoning:
|
|
||||||
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
|
||||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||||
return
|
return
|
||||||
@@ -229,4 +247,9 @@ def update_session(session_id: str, payload: SessionAliasUpdateRequest):
|
|||||||
|
|
||||||
@app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse)
|
@app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse)
|
||||||
async def run_nl2sql(request: NL2SQLRequest):
|
async def run_nl2sql(request: NL2SQLRequest):
|
||||||
return await process_nl2sql(request)
|
result = await process_nl2sql(request)
|
||||||
|
if request.session_id:
|
||||||
|
text = _build_sql_chart_text(result)
|
||||||
|
viz_payload = _build_sql_chart_viz(result)
|
||||||
|
_persist_session_turn(request.session_id, request.query, text, {"viz": viz_payload})
|
||||||
|
return result
|
||||||
|
|||||||
@@ -85,7 +85,8 @@ export function ChatInterface() {
|
|||||||
const formattedMessages = data.messages.map((m, idx) => ({
|
const formattedMessages = data.messages.map((m, idx) => ({
|
||||||
id: `${Date.now()}-${idx}`,
|
id: `${Date.now()}-${idx}`,
|
||||||
role: m.role as 'user' | 'assistant',
|
role: m.role as 'user' | 'assistant',
|
||||||
content: m.content
|
content: m.content,
|
||||||
|
viz: m.viz ? buildMessageViz(m.viz) : undefined,
|
||||||
}));
|
}));
|
||||||
setMessages(formattedMessages);
|
setMessages(formattedMessages);
|
||||||
} else {
|
} else {
|
||||||
@@ -461,7 +462,7 @@ export function ChatInterface() {
|
|||||||
onChange={handleFileUpload}
|
onChange={handleFileUpload}
|
||||||
/>
|
/>
|
||||||
<div className="min-h-full">
|
<div className="min-h-full">
|
||||||
{messages.length <= 1 ? (
|
{messages.length === 0 ? (
|
||||||
<div className="h-full flex flex-col items-center justify-center pt-[20vh] px-4 pb-32">
|
<div className="h-full flex flex-col items-center justify-center pt-[20vh] px-4 pb-32">
|
||||||
{/* Logo Area */}
|
{/* Logo Area */}
|
||||||
<div className="mb-16 flex items-center justify-center gap-4 select-none">
|
<div className="mb-16 flex items-center justify-center gap-4 select-none">
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ export function InlineVisualizationCard({ viz }: InlineVisualizationCardProps) {
|
|||||||
type: dashboardType,
|
type: dashboardType,
|
||||||
data: objectRows,
|
data: objectRows,
|
||||||
sql: viz.sql,
|
sql: viz.sql,
|
||||||
|
chartSpec: viz.chartSpec,
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ export function VisualizationPanel() {
|
|||||||
type: dashboardType,
|
type: dashboardType,
|
||||||
data: currentData,
|
data: currentData,
|
||||||
sql: currentSQL,
|
sql: currentSQL,
|
||||||
|
chartSpec: currentChartSpec,
|
||||||
});
|
});
|
||||||
alert("Added to Dashboard!");
|
alert("Added to Dashboard!");
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -5,9 +5,39 @@ import { Card, CardContent, CardHeader, CardTitle, CardDescription } from "@/com
|
|||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { X } from "lucide-react";
|
import { X } from "lucide-react";
|
||||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer, LineChart, Line } from 'recharts';
|
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer, LineChart, Line } from 'recharts';
|
||||||
|
import { VegaChart } from "@/components/VegaChart";
|
||||||
import 'react-grid-layout/css/styles.css';
|
import 'react-grid-layout/css/styles.css';
|
||||||
import 'react-resizable/css/styles.css';
|
import 'react-resizable/css/styles.css';
|
||||||
|
|
||||||
|
const CHART_COLORS = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6', '#06b6d4'];
|
||||||
|
|
||||||
|
function isNumericValue(value: unknown) {
|
||||||
|
if (typeof value === 'number') return Number.isFinite(value);
|
||||||
|
if (typeof value === 'string') {
|
||||||
|
const trimmed = value.trim();
|
||||||
|
if (!trimmed) return false;
|
||||||
|
const parsed = Number(trimmed);
|
||||||
|
return Number.isFinite(parsed);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
function inferChartKeys(data: Record<string, unknown>[]) {
|
||||||
|
if (data.length === 0) {
|
||||||
|
return { xKey: null as string | null, yKeys: [] as string[] };
|
||||||
|
}
|
||||||
|
const allKeys = Object.keys(data[0] || {});
|
||||||
|
if (allKeys.length === 0) {
|
||||||
|
return { xKey: null as string | null, yKeys: [] as string[] };
|
||||||
|
}
|
||||||
|
const preferredX = ['name', 'date', 'time', 'category', 'label'];
|
||||||
|
const xKey = preferredX.find((k) => allKeys.includes(k)) || allKeys[0];
|
||||||
|
const candidateY = allKeys.filter((k) => k !== xKey);
|
||||||
|
const numericY = candidateY.filter((key) => data.some((row) => isNumericValue(row[key])));
|
||||||
|
const yKeys = (numericY.length > 0 ? numericY : candidateY).slice(0, 3);
|
||||||
|
return { xKey, yKeys };
|
||||||
|
}
|
||||||
|
|
||||||
export function Dashboard() {
|
export function Dashboard() {
|
||||||
const { charts, removeChart } = useDashboardStore();
|
const { charts, removeChart } = useDashboardStore();
|
||||||
const ResponsiveGridLayout = useMemo(
|
const ResponsiveGridLayout = useMemo(
|
||||||
@@ -65,32 +95,54 @@ export function Dashboard() {
|
|||||||
</Button>
|
</Button>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent className="flex-1 min-h-0 p-2">
|
<CardContent className="flex-1 min-h-0 p-2">
|
||||||
|
{(() => {
|
||||||
|
const rows = chart.data as Record<string, unknown>[];
|
||||||
|
if (chart.chartSpec && rows.length > 0) {
|
||||||
|
return (
|
||||||
|
<div className="h-full w-full rounded-xl border border-zinc-100 p-2">
|
||||||
|
<VegaChart data={rows} spec={chart.chartSpec} />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const { xKey, yKeys } = inferChartKeys(rows);
|
||||||
|
if (!xKey || yKeys.length === 0) {
|
||||||
|
return (
|
||||||
|
<div className="h-full w-full flex items-center justify-center text-xs text-zinc-500">
|
||||||
|
当前图表数据缺少可绘制字段
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
<ResponsiveContainer width="100%" height="100%">
|
<ResponsiveContainer width="100%" height="100%">
|
||||||
{chart.type === 'bar' ? (
|
{chart.type === 'bar' ? (
|
||||||
<BarChart data={chart.data}>
|
<BarChart data={rows}>
|
||||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#e5e7eb" />
|
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#e5e7eb" />
|
||||||
<XAxis dataKey="name" tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
<XAxis dataKey={xKey} tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
||||||
<YAxis tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
<YAxis tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
||||||
<Tooltip
|
<Tooltip
|
||||||
cursor={{ fill: 'rgba(0,0,0,0.05)' }}
|
cursor={{ fill: 'rgba(0,0,0,0.05)' }}
|
||||||
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 6px -1px rgb(0 0 0 / 0.1)' }}
|
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 6px -1px rgb(0 0 0 / 0.1)' }}
|
||||||
/>
|
/>
|
||||||
<Bar dataKey="sales" fill="#3b82f6" radius={[4, 4, 0, 0]} name="Sales" />
|
{yKeys.map((key, index) => (
|
||||||
<Bar dataKey="profit" fill="#10b981" radius={[4, 4, 0, 0]} name="Profit" />
|
<Bar key={key} dataKey={key} fill={CHART_COLORS[index % CHART_COLORS.length]} radius={[4, 4, 0, 0]} name={key} />
|
||||||
|
))}
|
||||||
</BarChart>
|
</BarChart>
|
||||||
) : (
|
) : (
|
||||||
<LineChart data={chart.data}>
|
<LineChart data={rows}>
|
||||||
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#e5e7eb" />
|
<CartesianGrid strokeDasharray="3 3" vertical={false} stroke="#e5e7eb" />
|
||||||
<XAxis dataKey="name" tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
<XAxis dataKey={xKey} tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
||||||
<YAxis tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
<YAxis tickLine={false} axisLine={false} tick={{ fontSize: 10, fill: '#6b7280' }} />
|
||||||
<Tooltip
|
<Tooltip
|
||||||
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 6px -1px rgb(0 0 0 / 0.1)' }}
|
contentStyle={{ borderRadius: '8px', border: 'none', boxShadow: '0 4px 6px -1px rgb(0 0 0 / 0.1)' }}
|
||||||
/>
|
/>
|
||||||
<Line type="monotone" dataKey="sales" stroke="#3b82f6" strokeWidth={2} dot={{ r: 4 }} />
|
{yKeys.map((key, index) => (
|
||||||
<Line type="monotone" dataKey="profit" stroke="#10b981" strokeWidth={2} dot={{ r: 4 }} />
|
<Line key={key} type="monotone" dataKey={key} stroke={CHART_COLORS[index % CHART_COLORS.length]} strokeWidth={2} dot={{ r: 3 }} />
|
||||||
|
))}
|
||||||
</LineChart>
|
</LineChart>
|
||||||
)}
|
)}
|
||||||
</ResponsiveContainer>
|
</ResponsiveContainer>
|
||||||
|
);
|
||||||
|
})()}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import { create } from 'zustand';
|
import { create } from 'zustand';
|
||||||
|
import type { ChartSpec } from './visualizationStore';
|
||||||
|
|
||||||
type ChartRow = Record<string, unknown>;
|
type ChartRow = Record<string, unknown>;
|
||||||
type GridLayout = { i: string; x: number; y: number; w: number; h: number };
|
type GridLayout = { i: string; x: number; y: number; w: number; h: number };
|
||||||
@@ -9,6 +10,7 @@ export interface ChartConfig {
|
|||||||
type: 'bar' | 'line';
|
type: 'bar' | 'line';
|
||||||
data: ChartRow[];
|
data: ChartRow[];
|
||||||
sql: string;
|
sql: string;
|
||||||
|
chartSpec?: ChartSpec | null;
|
||||||
layout: GridLayout;
|
layout: GridLayout;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user