fix: prompt for viz enhanced

This commit is contained in:
qixinbo
2026-04-01 10:00:40 +08:00
parent 8e174df3d5
commit 9952af198a
5 changed files with 49 additions and 6 deletions
+12 -3
View File
@@ -1,5 +1,6 @@
import json
import logging
import re
from typing import Any
from nanobot.agent.tools.base import Tool
@@ -9,12 +10,19 @@ from fastapi.encoders import jsonable_encoder
logger = logging.getLogger(__name__)
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
def _normalize_query(value: str) -> str:
return re.sub(r"\s+", "", (value or "")).lower()
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse, query: str) -> dict:
chart = nl2sql_result.chart
payload = {
"sql": nl2sql_result.sql,
"result": nl2sql_result.result,
"chart": chart.model_dump(by_alias=True, exclude_none=True) if chart else None,
"chart_query": query,
"chart_query_normalized": _normalize_query(query),
"chart_generated_by": "nl2sql",
"error": nl2sql_result.error,
}
return jsonable_encoder(payload)
@@ -34,7 +42,8 @@ class NL2SQLTool(Tool):
return (
"Query the connected database or data source using natural language. "
"Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. "
"Set generate_chart=True if the user also wants to visualize or plot the data."
"Set generate_chart=True if the user also wants to visualize or plot the data. "
"If generate_chart=True, do not call visualization again for the same request."
)
@property
@@ -76,7 +85,7 @@ class NL2SQLTool(Tool):
# Always save visualization payload to context so the chat stream can pick it up
# Even if there's an error, we want the frontend to see the generated SQL
viz_payload = _build_sql_chart_viz(result)
viz_payload = _build_sql_chart_viz(result, query)
existing_viz = current_viz_data.get()
if isinstance(existing_viz, dict):
existing_viz.clear()
+24 -1
View File
@@ -1,4 +1,5 @@
import logging
import re
from typing import Any
from nanobot.agent.tools.base import Tool
@@ -8,6 +9,10 @@ from fastapi.encoders import jsonable_encoder
logger = logging.getLogger(__name__)
def _normalize_query(value: str) -> str:
return re.sub(r"\s+", "", (value or "")).lower()
class VisualizationTool(Tool):
"""
Tool for generating a visualization (chart) from existing data.
@@ -22,7 +27,8 @@ class VisualizationTool(Tool):
return (
"Generate a chart or visualization based on the most recently queried data. "
"Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. "
"Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first."
"Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first. "
"Do not call this tool right after nl2sql(generate_chart=True) for the same request."
)
@property
@@ -47,6 +53,20 @@ class VisualizationTool(Tool):
return "Error: No data available to visualize. Please query the data first using the nl2sql tool."
try:
existing_viz = current_viz_data.get() or {}
existing_chart = existing_viz.get("chart") if isinstance(existing_viz, dict) else None
existing_result = existing_viz.get("result") if isinstance(existing_viz, dict) else None
existing_query_normalized = (
existing_viz.get("chart_query_normalized") if isinstance(existing_viz, dict) else None
)
if (
existing_chart
and existing_result == data
and existing_query_normalized
and existing_query_normalized == _normalize_query(query)
):
return "Chart already exists for this query and dataset. Reusing existing Vega visualization."
if on_progress:
await on_progress("正在分析数据特征并生成可视化方案...")
@@ -61,6 +81,9 @@ class VisualizationTool(Tool):
"sql": existing_viz.get("sql", ""),
"result": data,
"chart": chart_response.model_dump(by_alias=True, exclude_none=True),
"chart_query": query,
"chart_query_normalized": _normalize_query(query),
"chart_generated_by": "visualization",
"error": None,
}
encoded_viz = jsonable_encoder(viz_payload)