feature: nl2sql first successful

This commit is contained in:
qixinbo
2026-03-15 10:49:37 +08:00
parent 76724b2313
commit 696fd94ff3
7 changed files with 252 additions and 47 deletions
+60 -11
View File
@@ -4,6 +4,8 @@ import json
from pathlib import Path
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
import duckdb
import pandas as pd
# Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3]
@@ -19,7 +21,8 @@ from app.agent.chart import generate_chart
class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse)")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
class NL2SQLResponse(BaseModel):
sql: str
@@ -80,20 +83,63 @@ The final answer must be a ANSI SQL query in JSON format:
}}
"""
def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
if not file_url or not file_url.startswith("local://"):
raise ValueError("Invalid uploaded file URL")
raw_name = file_url.replace("local://", "", 1)
safe_name = os.path.basename(raw_name)
upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads"
file_path = upload_dir / safe_name
if not file_path.exists():
raise ValueError(f"Uploaded file not found: {safe_name}")
return file_path
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
file_path = _resolve_upload_file_path(file_url)
suffix = file_path.suffix.lower()
if suffix == ".csv":
return pd.read_csv(file_path)
if suffix in [".xls", ".xlsx"]:
return pd.read_excel(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
df = _load_upload_dataframe(file_url)
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
schema = {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}
conn.close()
return schema
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
df = _load_upload_dataframe(file_url)
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
result_df = conn.execute(sql_query).df()
conn.close()
return result_df.to_dict(orient="records")
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
# 1. Get the connector and schema
connector = None
schema = {}
if request.source == "postgres":
connector = postgres_connector
elif request.source == "clickhouse":
connector = clickhouse_connector
elif request.source == "upload":
try:
schema = _get_upload_schema(request.file_url)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
else:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if not connector.test_connection():
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
if connector:
if not connector.test_connection():
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
schema_str = json.dumps(schema, indent=2)
# 2. Get the active LLM config
@@ -158,19 +204,22 @@ Let's think step by step.
# 6. Execute SQL
try:
results = connector.execute_query(sql_query)
if request.source == "upload":
formatted_results = _execute_upload_sql(sql_query, request.file_url)
else:
results = connector.execute_query(sql_query)
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
formatted_results = []
if request.source == "postgres":
formatted_results = results
elif request.source == "clickhouse":
formatted_results = []
if request.source == "postgres":
formatted_results = results
elif request.source == "clickhouse":
# ClickHouse returns list of tuples, we need column names
# But execute_query in ClickHouseConnector just returns raw results from client.execute
# client.execute(query, with_column_types=True) might be better but let's stick to simple for now
# Actually, without column names it's hard to format as dict.
# Let's assume we can just return the raw tuples for now or try to fetch column names.
# For now, let's just return as list of lists/tuples if it's not a dict
formatted_results = [list(row) for row in results]
formatted_results = [list(row) for row in results]
# 7. Generate Chart
chart_response = None