optimize nl2sql
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -19,6 +21,16 @@ from app.api.llm import _load_data as load_llm_config
|
|||||||
from app.schemas.chart import ChartGenerationResponse
|
from app.schemas.chart import ChartGenerationResponse
|
||||||
from app.agent.chart import generate_chart
|
from app.agent.chart import generate_chart
|
||||||
|
|
||||||
|
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||||
|
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||||
|
UPLOAD_CACHE_TTL_SECONDS = 900
|
||||||
|
MAX_UPLOAD_CACHE_ITEMS = 8
|
||||||
|
|
||||||
|
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
_upload_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
_cache_lock = threading.Lock()
|
||||||
|
|
||||||
class NL2SQLRequest(BaseModel):
|
class NL2SQLRequest(BaseModel):
|
||||||
query: str = Field(..., description="User's natural language query")
|
query: str = Field(..., description="User's natural language query")
|
||||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)")
|
||||||
@@ -95,8 +107,7 @@ def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
|
|||||||
raise ValueError(f"Uploaded file not found: {safe_name}")
|
raise ValueError(f"Uploaded file not found: {safe_name}")
|
||||||
return file_path
|
return file_path
|
||||||
|
|
||||||
def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
|
def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
|
||||||
file_path = _resolve_upload_file_path(file_url)
|
|
||||||
suffix = file_path.suffix.lower()
|
suffix = file_path.suffix.lower()
|
||||||
if suffix == ".csv":
|
if suffix == ".csv":
|
||||||
return pd.read_csv(file_path)
|
return pd.read_csv(file_path)
|
||||||
@@ -104,8 +115,7 @@ def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame:
|
|||||||
return pd.read_excel(file_path)
|
return pd.read_excel(file_path)
|
||||||
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
raise ValueError(f"Unsupported uploaded file type: {suffix}")
|
||||||
|
|
||||||
def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
|
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||||
df = _load_upload_dataframe(file_url)
|
|
||||||
conn = duckdb.connect(":memory:")
|
conn = duckdb.connect(":memory:")
|
||||||
conn.register("uploaded_file", df)
|
conn.register("uploaded_file", df)
|
||||||
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
|
columns = conn.execute("DESCRIBE uploaded_file").fetchall()
|
||||||
@@ -113,34 +123,104 @@ def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]:
|
|||||||
conn.close()
|
conn.close()
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def _execute_upload_sql(sql_query: str, file_url: Optional[str]) -> List[Dict[str, Any]]:
|
def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
|
||||||
df = _load_upload_dataframe(file_url)
|
file_path = _resolve_upload_file_path(file_url)
|
||||||
|
stat = file_path.stat()
|
||||||
|
cache_key = f"{file_path}:{int(stat.st_mtime)}:{stat.st_size}"
|
||||||
|
now = time.time()
|
||||||
|
with _cache_lock:
|
||||||
|
cached = _upload_cache.get(cache_key)
|
||||||
|
if cached and now < cached["expires_at"]:
|
||||||
|
return {"df": cached["df"], "schema": cached["schema"]}
|
||||||
|
df = _load_upload_dataframe_from_path(file_path)
|
||||||
|
schema = _build_upload_schema(df)
|
||||||
|
with _cache_lock:
|
||||||
|
if len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
|
||||||
|
oldest_key = min(_upload_cache.keys(), key=lambda key: _upload_cache[key]["expires_at"])
|
||||||
|
_upload_cache.pop(oldest_key, None)
|
||||||
|
_upload_cache[cache_key] = {
|
||||||
|
"df": df,
|
||||||
|
"schema": schema,
|
||||||
|
"expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
|
||||||
|
}
|
||||||
|
return {"df": df, "schema": schema}
|
||||||
|
|
||||||
|
def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
|
||||||
conn = duckdb.connect(":memory:")
|
conn = duckdb.connect(":memory:")
|
||||||
conn.register("uploaded_file", df)
|
conn.register("uploaded_file", df)
|
||||||
result_df = conn.execute(sql_query).df()
|
result_df = conn.execute(sql_query).df()
|
||||||
conn.close()
|
conn.close()
|
||||||
return result_df.to_dict(orient="records")
|
return result_df.to_dict(orient="records")
|
||||||
|
|
||||||
|
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||||
|
if source == "postgres":
|
||||||
|
return f"postgres:{getattr(connector, 'db_url', '')}"
|
||||||
|
if source == "clickhouse":
|
||||||
|
return (
|
||||||
|
f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
|
||||||
|
f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
|
||||||
|
)
|
||||||
|
return source
|
||||||
|
|
||||||
|
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
|
||||||
|
key = _build_schema_cache_key(source, connector)
|
||||||
|
now = time.time()
|
||||||
|
with _cache_lock:
|
||||||
|
cached = _schema_cache.get(key)
|
||||||
|
if cached and now < cached["expires_at"]:
|
||||||
|
return cached["schema"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
|
||||||
|
key = _build_schema_cache_key(source, connector)
|
||||||
|
with _cache_lock:
|
||||||
|
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||||
|
|
||||||
|
def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||||
|
cache_key = _build_schema_cache_key(source, connector)
|
||||||
|
now = time.time()
|
||||||
|
with _cache_lock:
|
||||||
|
cached = _connection_cache.get(cache_key)
|
||||||
|
if cached and now < cached["expires_at"]:
|
||||||
|
return bool(cached["ok"])
|
||||||
|
ok = connector.test_connection()
|
||||||
|
with _cache_lock:
|
||||||
|
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||||
|
return ok
|
||||||
|
|
||||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||||
# 1. Get the connector and schema
|
# 1. Get the connector and schema
|
||||||
connector = None
|
connector = None
|
||||||
schema = {}
|
schema = {}
|
||||||
|
upload_df: Optional[pd.DataFrame] = None
|
||||||
if request.source == "postgres":
|
if request.source == "postgres":
|
||||||
connector = postgres_connector
|
connector = postgres_connector
|
||||||
elif request.source == "clickhouse":
|
elif request.source == "clickhouse":
|
||||||
connector = clickhouse_connector
|
connector = clickhouse_connector
|
||||||
elif request.source == "upload":
|
elif request.source == "upload":
|
||||||
try:
|
try:
|
||||||
schema = _get_upload_schema(request.file_url)
|
upload_payload = _get_upload_payload(request.file_url)
|
||||||
|
upload_df = upload_payload["df"]
|
||||||
|
schema = upload_payload["schema"]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||||
else:
|
else:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||||
|
|
||||||
if connector:
|
if connector:
|
||||||
if not connector.test_connection():
|
cached_schema = _get_cached_schema(request.source, connector)
|
||||||
|
if cached_schema:
|
||||||
|
schema = cached_schema
|
||||||
|
else:
|
||||||
|
if not _check_connection_with_cache(request.source, connector):
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||||
|
schema = connector.get_schema()
|
||||||
|
_set_cached_schema(request.source, connector, schema)
|
||||||
|
if connector and not schema:
|
||||||
|
if not _check_connection_with_cache(request.source, connector):
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||||
schema = connector.get_schema()
|
schema = connector.get_schema()
|
||||||
|
_set_cached_schema(request.source, connector, schema)
|
||||||
schema_str = json.dumps(schema, indent=2)
|
schema_str = json.dumps(schema, indent=2)
|
||||||
|
|
||||||
# 2. Get the active LLM config
|
# 2. Get the active LLM config
|
||||||
@@ -206,7 +286,9 @@ Let's think step by step.
|
|||||||
# 6. Execute SQL
|
# 6. Execute SQL
|
||||||
try:
|
try:
|
||||||
if request.source == "upload":
|
if request.source == "upload":
|
||||||
formatted_results = _execute_upload_sql(sql_query, request.file_url)
|
if upload_df is None:
|
||||||
|
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||||
|
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||||
else:
|
else:
|
||||||
results = connector.execute_query(sql_query)
|
results = connector.execute_query(sql_query)
|
||||||
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
# Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)
|
||||||
|
|||||||
Reference in New Issue
Block a user