optimize nl2sql

This commit is contained in:
qixinbo
2026-03-15 18:36:28 +08:00
parent 5b25563c0a
commit 29b12ec1d4
+91 -9
View File
@@ -1,6 +1,8 @@
import sys
import os
import json
import time
import threading
from pathlib import Path
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
@@ -19,6 +21,16 @@ from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart
# Cache lifetimes, in seconds.
SCHEMA_CACHE_TTL_SECONDS = 5 * 60  # database schemas
CONNECTION_CACHE_TTL_SECONDS = 30  # connection-test outcomes
UPLOAD_CACHE_TTL_SECONDS = 15 * 60  # parsed uploads and their schemas

# Upper bound on the number of parsed uploads held in memory at once.
MAX_UPLOAD_CACHE_ITEMS = 8

# In-process caches. Every entry is a dict carrying an "expires_at" timestamp
# next to its payload ("schema", "ok", or "df" + "schema" respectively).
_schema_cache: Dict[str, Dict[str, Any]] = dict()
_connection_cache: Dict[str, Dict[str, Any]] = dict()
_upload_cache: Dict[str, Dict[str, Any]] = dict()

# One lock guards all three caches.
_cache_lock = threading.Lock()
class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
@@ -95,8 +107,7 @@ def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
raise ValueError(f"Uploaded file not found: {safe_name}")
return file_path
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
file_path = _resolve_upload_file_path(file_url)
def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
suffix = file_path.suffix.lower()
if suffix == ".csv":
return pd.read_csv(file_path)
@@ -104,8 +115,7 @@ def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
return pd.read_excel(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
df = _load_upload_dataframe(file_url)
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df)
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
@@ -113,34 +123,104 @@ def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
conn.close()
return schema
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
df = _load_upload_dataframe(file_url)
def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
    """Return the parsed DataFrame and schema for an uploaded file.

    Results are cached per file version so repeated requests against the
    same upload do not re-read and re-describe the file.

    Args:
        file_url: Reference to the uploaded file, resolved through
            ``_resolve_upload_file_path``.

    Returns:
        A dict with keys ``"df"`` (the loaded DataFrame) and ``"schema"``
        (the schema built from it). The DataFrame object is shared with the
        cache, so callers should treat it as read-only.

    Raises:
        Whatever the path resolver, ``Path.stat`` or the loader raise
        (e.g. ``ValueError`` for an unsupported file type, ``OSError`` if
        the file cannot be stat'ed).
    """
    file_path = _resolve_upload_file_path(file_url)
    stat = file_path.stat()
    # Use the nanosecond mtime: truncating to whole seconds would let a file
    # rewritten within the same second (with the same size) hit a stale entry.
    cache_key = f"{file_path}:{stat.st_mtime_ns}:{stat.st_size}"

    with _cache_lock:
        cached = _upload_cache.get(cache_key)
        if cached is not None and time.time() < cached["expires_at"]:
            return {"df": cached["df"], "schema": cached["schema"]}

    # Parse outside the lock so a slow load does not block other cache users.
    # Two concurrent misses may both parse the file; the last writer wins.
    df = _load_upload_dataframe_from_path(file_path)
    schema = _build_upload_schema(df)

    with _cache_lock:
        # Take the timestamp after loading so the entry lives for the full TTL.
        now = time.time()

        # Drop expired entries first so stale DataFrames (including older
        # versions of this same file) do not stay in memory indefinitely.
        expired_keys = [
            key for key, entry in _upload_cache.items() if entry["expires_at"] <= now
        ]
        for key in expired_keys:
            del _upload_cache[key]

        # Only evict when a new key is about to be added to a full cache;
        # refreshing an existing key must not push out an unrelated entry.
        if cache_key not in _upload_cache and len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
            oldest_key = min(_upload_cache, key=lambda key: _upload_cache[key]["expires_at"])
            _upload_cache.pop(oldest_key, None)

        _upload_cache[cache_key] = {
            "df": df,
            "schema": schema,
            "expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
        }
    return {"df": df, "schema": schema}
def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Run a SQL query against an uploaded DataFrame using in-memory DuckDB.

    The DataFrame is registered under the table name ``uploaded_file``, which
    is the name the query is expected to reference.

    Args:
        sql_query: SQL text to execute.
        df: DataFrame exposed to the query as ``uploaded_file``.

    Returns:
        The result rows as a list of ``{column: value}`` dicts.

    Raises:
        Any DuckDB error raised while registering the frame or running the
        query; the connection is closed before the error propagates.
    """
    # NOTE(review): sql_query is executed as given; confirm that callers
    # validate or restrict it before it reaches this point.
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        result_df = conn.execute(sql_query).df()
    finally:
        # Always release the connection, even when the query is invalid.
        conn.close()
    return result_df.to_dict(orient="records")
def _build_schema_cache_key(source: str, connector: Any) -> str:
if source == "postgres":
return f"postgres:{getattr(connector, 'db_url', '')}"
if source == "clickhouse":
return (
f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
)
return source
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
    """Look up a still-valid cached schema for the given source/connector.

    Returns the cached schema, or ``None`` when there is no entry or the
    entry has expired.
    """
    cache_key = _build_schema_cache_key(source, connector)
    checked_at = time.time()
    with _cache_lock:
        entry = _schema_cache.get(cache_key)
        # Guard clause: a missing or expired entry counts as a miss.
        if not entry or checked_at >= entry["expires_at"]:
            return None
        return entry["schema"]
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
    """Store ``schema`` in the schema cache for the given source/connector.

    The entry expires ``SCHEMA_CACHE_TTL_SECONDS`` after it is written and
    replaces any existing entry under the same key.
    """
    cache_key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        entry = {
            "schema": schema,
            "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS,
        }
        _schema_cache[cache_key] = entry
def _check_connection_with_cache(source: str, connector: Any) -> bool:
    """Report whether the connector is reachable, reusing a recent result.

    The outcome of ``connector.test_connection()`` is cached for
    ``CONNECTION_CACHE_TTL_SECONDS`` under the same key used for the schema
    cache. Both successful and failed probes are cached.

    Args:
        source: Data source name, used to build the cache key.
        connector: Object exposing ``test_connection()``.

    Returns:
        ``True`` if the (possibly cached) connection test succeeded,
        ``False`` otherwise.
    """
    cache_key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        cached = _connection_cache.get(cache_key)
        if cached is not None and time.time() < cached["expires_at"]:
            return bool(cached["ok"])

    # Probe outside the lock: the test may block on the network.
    # Normalise to bool so cache hits and misses return the same type.
    ok = bool(connector.test_connection())

    # Start the TTL once the probe has finished; a slow probe would otherwise
    # shorten the entry's lifetime or store one that is already expired.
    expires_at = time.time() + CONNECTION_CACHE_TTL_SECONDS
    with _cache_lock:
        _connection_cache[cache_key] = {"ok": ok, "expires_at": expires_at}
    return ok
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
# 1. Get the connector and schema
connector = None
schema = {}
upload_df: Optional[pd.DataFrame] = None
if request.source == "postgres":
connector = postgres_connector
elif request.source == "clickhouse":
connector = clickhouse_connector
elif request.source == "upload":
try:
schema = _get_upload_schema(request.file_url)
upload_payload = _get_upload_payload(request.file_url)
upload_df = upload_payload["df"]
schema = upload_payload["schema"]
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
else:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector:
if not connector.test_connection():
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
schema = cached_schema
else:
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
if connector and not schema:
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
schema_str = json.dumps(schema, indent=2)
# 2. Get the active LLM config
@@ -206,7 +286,9 @@ Let's think step by step.
# 6. Execute SQL
try:
if request.source == "upload":
formatted_results = _execute_upload_sql(sql_query, request.file_url)
if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"]
formatted_results = _execute_upload_sql(sql_query, upload_df)
else:
results = connector.execute_query(sql_query)
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)