fix: session bug

This commit is contained in:
qixinbo
2026-03-15 17:57:09 +08:00
parent eb09129148
commit 4f46f3f8d5
7 changed files with 136 additions and 55 deletions
+1
View File
@@ -23,6 +23,7 @@ class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
session_id: Optional[str] = Field(None, description="Conversation session identifier")
class NL2SQLResponse(BaseModel):
sql: str
+50 -27
View File
@@ -77,6 +77,42 @@ class SessionAliasUpdateRequest(BaseModel):
pinned: Optional[bool] = None
archived: Optional[bool] = None
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
chart = nl2sql_result.chart
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
return f"{text}\n\n可视化说明:{chart.reasoning}"
return text
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
chart = nl2sql_result.chart
return {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
}
def _persist_session_turn(
session_id: str,
user_message: str,
assistant_message: str,
assistant_extra: Optional[dict] = None,
) -> None:
if not nanobot_service.agent:
return
session = nanobot_service.agent.sessions.get_or_create(session_id)
session.add_message("user", user_message)
session.add_message("assistant", assistant_message, **(assistant_extra or {}))
nanobot_service.agent.sessions.save(session)
@app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest):
try:
@@ -84,22 +120,12 @@ async def nanobot_chat(request: ChatRequest):
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
chart = nl2sql_result.chart
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
text = f"{text}\n\n可视化说明:{chart.reasoning}"
text = _build_sql_chart_text(nl2sql_result)
viz_payload = _build_sql_chart_viz(nl2sql_result)
_persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload})
return {
"response": text,
"viz": {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
},
"viz": viz_payload,
}
response = await nanobot_service.process_message(
request.message,
@@ -119,22 +145,14 @@ async def nanobot_chat_stream(request: ChatRequest):
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
chart = nl2sql_result.chart
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
viz_payload = {
"type": "viz",
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump() if chart else None,
"error": nl2sql_result.error,
**persisted_viz_payload,
}
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
text = f"{text}\n\n可视化说明:{chart.reasoning}"
text = _build_sql_chart_text(nl2sql_result)
_persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload})
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return
@@ -229,4 +247,9 @@ def update_session(session_id: str, payload: SessionAliasUpdateRequest):
@app.post("/api/v1/agent/nl2sql", response_model=NL2SQLResponse)
async def run_nl2sql(request: NL2SQLRequest):
return await process_nl2sql(request)
result = await process_nl2sql(request)
if request.session_id:
text = _build_sql_chart_text(result)
viz_payload = _build_sql_chart_viz(result)
_persist_session_turn(request.session_id, request.query, text, {"viz": viz_payload})
return result