feature: nl2sql first successful
This commit is contained in:
+60
-11
@@ -4,6 +4,8 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
@@ -19,7 +21,8 @@ from app.agent.chart import generate_chart
|
||||
|
||||
class NL2SQLRequest(BaseModel):
    """Request payload for a single natural-language-to-SQL run."""

    # Natural-language question to be translated into SQL.
    query: str = Field(..., description="User's natural language query")
    # Backend the generated SQL runs against. Declared once: the earlier
    # declaration without "upload" was superseded by this one.
    source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
    # URL of the uploaded file; only consulted when source is "upload".
    file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
|
||||
class NL2SQLResponse(BaseModel):
|
||||
sql: str
|
||||
@@ -80,20 +83,63 @@ The final answer must be a ANSI SQL query in JSON format:
|
||||
}}
|
||||
"""
|
||||
|
||||
def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
    """Map a ``local://<name>`` upload URL to a file in the uploads directory.

    Only the final path component of the URL is used, so the returned path
    always sits directly inside ``data/uploads``.

    Args:
        file_url: URL of the uploaded file, expected to start with ``local://``.

    Returns:
        Path of the existing uploaded file.

    Raises:
        ValueError: If the URL is missing, does not start with ``local://``,
            names no file, or the named file does not exist.
    """
    if not file_url or not file_url.startswith("local://"):
        raise ValueError("Invalid uploaded file URL")
    raw_name = file_url.replace("local://", "", 1)
    # Drop any directory components so a crafted URL cannot leave upload_dir.
    safe_name = os.path.basename(raw_name)
    # basename() returns "" for "local://" or a trailing slash and passes
    # "." / ".." through; all of these would resolve to a directory.
    if safe_name in ("", ".", ".."):
        raise ValueError("Invalid uploaded file URL")
    upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads"
    file_path = upload_dir / safe_name
    # is_file() rather than exists(): a directory is not a loadable upload.
    if not file_path.is_file():
        raise ValueError(f"Uploaded file not found: {safe_name}")
    return file_path
|
||||
|
||||
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
    """Read the uploaded file behind *file_url* into a pandas DataFrame.

    The reader is chosen from the lower-cased file extension: ``.csv`` goes
    through ``pd.read_csv``, ``.xls`` / ``.xlsx`` through ``pd.read_excel``.

    Raises:
        ValueError: If the URL cannot be resolved to a file, or the file
            extension is none of the supported ones.
    """
    upload_path = _resolve_upload_file_path(file_url)
    extension = upload_path.suffix.lower()
    if extension == ".csv":
        frame = pd.read_csv(upload_path)
    elif extension in (".xls", ".xlsx"):
        frame = pd.read_excel(upload_path)
    else:
        raise ValueError(f"Unsupported uploaded file type: {extension}")
    return frame
|
||||
|
||||
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
    """Describe the uploaded file as a one-table schema.

    The file is loaded into a DataFrame, registered in an in-memory DuckDB
    database under the name ``uploaded_file``, and described with
    ``DESCRIBE``.

    Returns:
        ``{"uploaded_file": ["<column> (<type>)", ...]}`` built from the
        first two fields of each ``DESCRIBE`` row.

    Raises:
        ValueError: If the uploaded file cannot be resolved or has an
            unsupported type.
    """
    df = _load_upload_dataframe(file_url)
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        # Close the connection even when register/execute raises.
        conn.close()
    return {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}
|
||||
|
||||
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
    """Run *sql_query* against the uploaded file and return the rows as dicts.

    The file is loaded into a DataFrame and registered in an in-memory
    DuckDB database as the table ``uploaded_file``; the query result is
    converted with ``DataFrame.to_dict(orient="records")``.

    Raises:
        ValueError: If the uploaded file cannot be resolved or has an
            unsupported type.
        Exception: Whatever DuckDB raises for an invalid query propagates
            to the caller.
    """
    df = _load_upload_dataframe(file_url)
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        # NOTE(review): sql_query is executed verbatim. If it can be influenced
        # by untrusted input, consider restricting what DuckDB may access
        # (e.g. external file access) -- TODO confirm the threat model.
        result_df = conn.execute(sql_query).df()
    finally:
        # A failing query must not leak the connection.
        conn.close()
    return result_df.to_dict(orient="records")
|
||||
|
||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
# 1. Get the connector and schema
|
||||
connector = None
|
||||
schema = {}
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
connector = clickhouse_connector
|
||||
elif request.source == "upload":
|
||||
try:
|
||||
schema = _get_upload_schema(request.file_url)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
if not connector.test_connection():
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
|
||||
schema = connector.get_schema()
|
||||
if connector:
|
||||
if not connector.test_connection():
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
@@ -158,19 +204,22 @@ Let's think step by step.
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
results = connector.execute_query(sql_query)
|
||||
if request.source == "upload":
|
||||
formatted_results = _execute_upload_sql(sql_query, request.file_url)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
formatted_results = []
|
||||
if request.source == "postgres":
|
||||
formatted_results = results
|
||||
elif request.source == "clickhouse":
|
||||
# ClickHouse returns list of tuples, we need column names
|
||||
# But execute_query in ClickHouseConnector just returns raw results from client.execute
|
||||
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
|
||||
# Actually, without column names it's hard to format as dict.
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
|
||||
@@ -67,6 +67,9 @@ class ChatRequest(BaseModel):
|
||||
session_id: str = "api:default"
|
||||
skill_ids: Optional[List[str]] = None
|
||||
model_id: Optional[str] = None
|
||||
source: str = "postgres"
|
||||
prefer_sql_chart: bool = False
|
||||
file_url: Optional[str] = None
|
||||
|
||||
|
||||
class SessionAliasUpdateRequest(BaseModel):
|
||||
@@ -77,6 +80,27 @@ class SessionAliasUpdateRequest(BaseModel):
|
||||
@app.post("/nanobot/chat")
|
||||
async def nanobot_chat(request: ChatRequest):
|
||||
try:
|
||||
if request.prefer_sql_chart:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||
)
|
||||
chart = nl2sql_result.chart
|
||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||
text = (
|
||||
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||
)
|
||||
if chart and chart.reasoning:
|
||||
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||
return {
|
||||
"response": text,
|
||||
"viz": {
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
},
|
||||
}
|
||||
response = await nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
@@ -91,6 +115,29 @@ async def nanobot_chat(request: ChatRequest):
|
||||
async def nanobot_chat_stream(request: ChatRequest):
|
||||
async def event_generator():
|
||||
try:
|
||||
if request.prefer_sql_chart:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
||||
)
|
||||
chart = nl2sql_result.chart
|
||||
viz_payload = {
|
||||
"type": "viz",
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
}
|
||||
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||
text = (
|
||||
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||
)
|
||||
if chart and chart.reasoning:
|
||||
text = f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
response = await nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
|
||||
Reference in New Issue
Block a user