optimize nl2sql

This commit is contained in:
qixinbo
2026-03-15 18:36:28 +08:00
parent 5b25563c0a
commit 29b12ec1d4
+91 -9
View File
@@ -1,6 +1,8 @@
import sys import sys
import os import os
import json import json
import time
import threading
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -19,6 +21,16 @@ from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart from app.agent.chart import generate_chart
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
UPLOAD_CACHE_TTL_SECONDS = 900
MAX_UPLOAD_CACHE_ITEMS = 8
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
_upload_cache: Dict[str, Dict[str, Any]] = {}
_cache_lock = threading.Lock()
class NL2SQLRequest(BaseModel): class NL2SQLRequest(BaseModel):
query: str = Field(..., description="User's natural language query") query: str = Field(..., description="User's natural language query")
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
@@ -95,8 +107,7 @@ def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
raise ValueError(f"Uploaded file not found: {safe_name}") raise ValueError(f"Uploaded file not found: {safe_name}")
return file_path return file_path
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame: def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
file_path = _resolve_upload_file_path(file_url)
suffix = file_path.suffix.lower() suffix = file_path.suffix.lower()
if suffix == ".csv": if suffix == ".csv":
return pd.read_csv(file_path) return pd.read_csv(file_path)
@@ -104,8 +115,7 @@ def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
return pd.read_excel(file_path) return pd.read_excel(file_path)
raise ValueError(f"Unsupported uploaded file type: {suffix}") raise ValueError(f"Unsupported uploaded file type: {suffix}")
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]: def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
df = _load_upload_dataframe(file_url)
conn = duckdb.connect(":memory:") conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df) conn.register("uploaded_file", df)
columns = conn.execute("DESCRIBE uploaded_file").fetchall() columns = conn.execute("DESCRIBE uploaded_file").fetchall()
@@ -113,34 +123,104 @@ def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
conn.close() conn.close()
return schema return schema
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]: def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
df = _load_upload_dataframe(file_url) file_path = _resolve_upload_file_path(file_url)
stat = file_path.stat()
cache_key = f"{file_path}:{int(stat.st_mtime)}:{stat.st_size}"
now = time.time()
with _cache_lock:
cached = _upload_cache.get(cache_key)
if cached and now < cached["expires_at"]:
return {"df": cached["df"], "schema": cached["schema"]}
df = _load_upload_dataframe_from_path(file_path)
schema = _build_upload_schema(df)
with _cache_lock:
if len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
oldest_key = min(_upload_cache.keys(), key=lambda key: _upload_cache[key]["expires_at"])
_upload_cache.pop(oldest_key, None)
_upload_cache[cache_key] = {
"df": df,
"schema": schema,
"expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
}
return {"df": df, "schema": schema}
def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
conn = duckdb.connect(":memory:") conn = duckdb.connect(":memory:")
conn.register("uploaded_file", df) conn.register("uploaded_file", df)
result_df = conn.execute(sql_query).df() result_df = conn.execute(sql_query).df()
conn.close() conn.close()
return result_df.to_dict(orient="records") return result_df.to_dict(orient="records")
def _build_schema_cache_key(source: str, connector: Any) -> str:
if source == "postgres":
return f"postgres:{getattr(connector, 'db_url', '')}"
if source == "clickhouse":
return (
f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
)
return source
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
key = _build_schema_cache_key(source, connector)
now = time.time()
with _cache_lock:
cached = _schema_cache.get(key)
if cached and now < cached["expires_at"]:
return cached["schema"]
return None
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
key = _build_schema_cache_key(source, connector)
with _cache_lock:
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
def _check_connection_with_cache(source: str, connector: Any) -> bool:
cache_key = _build_schema_cache_key(source, connector)
now = time.time()
with _cache_lock:
cached = _connection_cache.get(cache_key)
if cached and now < cached["expires_at"]:
return bool(cached["ok"])
ok = connector.test_connection()
with _cache_lock:
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
return ok
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
# 1. Get the connector and schema # 1. Get the connector and schema
connector = None connector = None
schema = {} schema = {}
upload_df: Optional[pd.DataFrame] = None
if request.source == "postgres": if request.source == "postgres":
connector = postgres_connector connector = postgres_connector
elif request.source == "clickhouse": elif request.source == "clickhouse":
connector = clickhouse_connector connector = clickhouse_connector
elif request.source == "upload": elif request.source == "upload":
try: try:
schema = _get_upload_schema(request.file_url) upload_payload = _get_upload_payload(request.file_url)
upload_df = upload_payload["df"]
schema = upload_payload["schema"]
except Exception as e: except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}") return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
else: else:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector: if connector:
if not connector.test_connection(): cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
schema = cached_schema
else:
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
if connector and not schema:
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema() schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
schema_str = json.dumps(schema, indent=2) schema_str = json.dumps(schema, indent=2)
# 2. Get the active LLM config # 2. Get the active LLM config
@@ -206,7 +286,9 @@ Let's think step by step.
# 6. Execute SQL # 6. Execute SQL
try: try:
if request.source == "upload": if request.source == "upload":
formatted_results = _execute_upload_sql(sql_query, request.file_url) if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"]
formatted_results = _execute_upload_sql(sql_query, upload_df)
else: else:
results = connector.execute_query(sql_query) results = connector.execute_query(sql_query)
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples) # Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)