optimize nl2sql
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -19,6 +21,16 @@ from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
UPLOAD_CACHE_TTL_SECONDS = 900
|
||||
MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_upload_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_cache_lock = threading.Lock()
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||
@@ -95,8 +107,7 @@ def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
|
||||
raise ValueError(f"Uploaded file not found: {safe_name}")
|
||||
return file_path
|
||||
|
||||
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
|
||||
file_path = _resolve_upload_file_path(file_url)
|
||||
def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix == ".csv":
|
||||
return pd.read_csv(file_path)
|
||||
@@ -104,8 +115,7 @@ def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
|
||||
return pd.read_excel(file_path)
|
||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||
|
||||
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
|
||||
df = _load_upload_dataframe(file_url)
|
||||
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||
conn = duckdb.connect(":memory:")
|
||||
conn.register("uploaded_file", df)
|
||||
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
|
||||
@@ -113,34 +123,104 @@ def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
|
||||
conn.close()
|
||||
return schema
|
||||
|
||||
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
|
||||
df = _load_upload_dataframe(file_url)
|
||||
def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
|
||||
file_path = _resolve_upload_file_path(file_url)
|
||||
stat = file_path.stat()
|
||||
cache_key = f"{file_path}:{int(stat.st_mtime)}:{stat.st_size}"
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
cached = _upload_cache.get(cache_key)
|
||||
if cached and now < cached["expires_at"]:
|
||||
return {"df": cached["df"], "schema": cached["schema"]}
|
||||
df = _load_upload_dataframe_from_path(file_path)
|
||||
schema = _build_upload_schema(df)
|
||||
with _cache_lock:
|
||||
if len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
|
||||
oldest_key = min(_upload_cache.keys(), key=lambda key: _upload_cache[key]["expires_at"])
|
||||
_upload_cache.pop(oldest_key, None)
|
||||
_upload_cache[cache_key] = {
|
||||
"df": df,
|
||||
"schema": schema,
|
||||
"expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
|
||||
}
|
||||
return {"df": df, "schema": schema}
|
||||
|
||||
def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
|
||||
conn = duckdb.connect(":memory:")
|
||||
conn.register("uploaded_file", df)
|
||||
result_df = conn.execute(sql_query).df()
|
||||
conn.close()
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
if source == "postgres":
|
||||
return f"postgres:{getattr(connector, 'db_url', '')}"
|
||||
if source == "clickhouse":
|
||||
return (
|
||||
f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
|
||||
f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
|
||||
)
|
||||
return source
|
||||
|
||||
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
|
||||
key = _build_schema_cache_key(source, connector)
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
cached = _schema_cache.get(key)
|
||||
if cached and now < cached["expires_at"]:
|
||||
return cached["schema"]
|
||||
return None
|
||||
|
||||
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
|
||||
key = _build_schema_cache_key(source, connector)
|
||||
with _cache_lock:
|
||||
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||
|
||||
def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
cache_key = _build_schema_cache_key(source, connector)
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
cached = _connection_cache.get(cache_key)
|
||||
if cached and now < cached["expires_at"]:
|
||||
return bool(cached["ok"])
|
||||
ok = connector.test_connection()
|
||||
with _cache_lock:
|
||||
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||
return ok
|
||||
|
||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
# 1. Get the connector and schema
|
||||
connector = None
|
||||
schema = {}
|
||||
upload_df: Optional[pd.DataFrame] = None
|
||||
if request.source == "postgres":
|
||||
connector = postgres_connector
|
||||
elif request.source == "clickhouse":
|
||||
connector = clickhouse_connector
|
||||
elif request.source == "upload":
|
||||
try:
|
||||
schema = _get_upload_schema(request.file_url)
|
||||
upload_payload = _get_upload_payload(request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
schema = upload_payload["schema"]
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
else:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
if connector:
|
||||
if not connector.test_connection():
|
||||
cached_schema = _get_cached_schema(request.source, connector)
|
||||
if cached_schema:
|
||||
schema = cached_schema
|
||||
else:
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
if connector and not schema:
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
|
||||
# 2. Get the active LLM config
|
||||
@@ -206,7 +286,9 @@ Let's think step by step.
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
if request.source == "upload":
|
||||
formatted_results = _execute_upload_sql(sql_query, request.file_url)
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||
|
||||
Reference in New Issue
Block a user