diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 78269ae..7ccd9fa 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -1,6 +1,8 @@ import sys import os import json +import time +import threading from pathlib import Path from typing import List, Optional, Dict, Any from pydantic import BaseModel, Field @@ -19,6 +21,16 @@ from app.api.llm import _load_data as load_llm_config from app.schemas.chart import ChartGenerationResponse from app.agent.chart import generate_chart +SCHEMA_CACHE_TTL_SECONDS = 300 +CONNECTION_CACHE_TTL_SECONDS = 30 +UPLOAD_CACHE_TTL_SECONDS = 900 +MAX_UPLOAD_CACHE_ITEMS = 8 + +_schema_cache: Dict[str, Dict[str, Any]] = {} +_connection_cache: Dict[str, Dict[str, Any]] = {} +_upload_cache: Dict[str, Dict[str, Any]] = {} +_cache_lock = threading.Lock() + class NL2SQLRequest(BaseModel): query: str = Field(..., description="User's natural language query") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload)") @@ -95,8 +107,7 @@ def _resolve_upload_file_path(file_url: Optional[str]) -> Path: raise ValueError(f"Uploaded file not found: {safe_name}") return file_path -def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame: - file_path = _resolve_upload_file_path(file_url) +def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame: suffix = file_path.suffix.lower() if suffix == ".csv": return pd.read_csv(file_path) @@ -104,8 +115,7 @@ def _load_upload_dataframe(file_url: Optional[str]) -> pd.DataFrame: return pd.read_excel(file_path) raise ValueError(f"Unsupported uploaded file type: {suffix}") -def _get_upload_schema(file_url: Optional[str]) -> Dict[str, List[str]]: - df = _load_upload_dataframe(file_url) +def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]: conn = duckdb.connect(":memory:") conn.register("uploaded_file", df) columns = conn.execute("DESCRIBE uploaded_file").fetchall() @@ -113,34 +123,104 @@ def 
def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
    """Return the uploaded file's DataFrame and schema, using the upload cache.

    Entries are keyed by resolved path, modification time and size, so a
    rewritten file produces a new key instead of being served stale.

    Args:
        file_url: Upload reference, resolved by ``_resolve_upload_file_path``.

    Returns:
        A dict with keys ``"df"`` (the loaded DataFrame) and ``"schema"``.
        The DataFrame is the cached object itself, not a copy.

    Raises:
        ValueError: Propagated from path resolution or file loading.
        OSError: If the resolved file cannot be stat'ed.
    """
    file_path = _resolve_upload_file_path(file_url)
    stat = file_path.stat()
    # Nanosecond mtime: a rewrite within the same second (same size) must not
    # hit the entry of the previous file version.
    cache_key = f"{file_path}:{stat.st_mtime_ns}:{stat.st_size}"

    with _cache_lock:
        cached = _upload_cache.get(cache_key)
        if cached and time.time() < cached["expires_at"]:
            return {"df": cached["df"], "schema": cached["schema"]}

    # Load outside the lock so a slow parse does not block other cache users.
    df = _load_upload_dataframe_from_path(file_path)
    schema = _build_upload_schema(df)

    with _cache_lock:
        stored_at = time.time()
        # Expired entries only pin DataFrames in memory; drop them first.
        expired_keys = [
            key
            for key, entry in _upload_cache.items()
            if entry["expires_at"] <= stored_at
        ]
        for key in expired_keys:
            del _upload_cache[key]
        # Evict only when a genuinely new key would exceed the capacity;
        # overwriting an existing key does not grow the cache.
        if cache_key not in _upload_cache:
            while _upload_cache and len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
                oldest_key = min(
                    _upload_cache,
                    key=lambda key: _upload_cache[key]["expires_at"],
                )
                del _upload_cache[oldest_key]
        _upload_cache[cache_key] = {
            "df": df,
            "schema": schema,
            "expires_at": stored_at + UPLOAD_CACHE_TTL_SECONDS,
        }
    return {"df": df, "schema": schema}


def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Run *sql_query* against *df*, exposed to DuckDB as ``uploaded_file``.

    Args:
        sql_query: SQL text to execute.
        df: DataFrame registered under the table name ``uploaded_file``.

    Returns:
        The result rows as a list of column-name -> value dicts.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        result_df = conn.execute(sql_query).df()
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
    return result_df.to_dict(orient="records")


def _build_schema_cache_key(source: str, connector: Any) -> str:
    """Build the in-memory cache key identifying a source/connector pair.

    Missing connector attributes are treated as empty strings.
    """
    # NOTE(review): db_url may embed credentials -- keep this key out of logs.
    if source == "postgres":
        return f"postgres:{getattr(connector, 'db_url', '')}"
    if source == "clickhouse":
        return (
            f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
            f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
        )
    return source


def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
    """Return the cached schema for the connector, or ``None`` on miss/expiry."""
    key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        cached = _schema_cache.get(key)
        if cached is None:
            return None
        if time.time() < cached["expires_at"]:
            return cached["schema"]
        # Expired: remove it so stale schemas do not accumulate.
        del _schema_cache[key]
        return None


def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
    """Store *schema* for the connector for ``SCHEMA_CACHE_TTL_SECONDS``."""
    key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        _schema_cache[key] = {
            "schema": schema,
            "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS,
        }


def _check_connection_with_cache(source: str, connector: Any) -> bool:
    """Return whether the connector is reachable, reusing a recent result.

    Both successful and failed probes are cached for
    ``CONNECTION_CACHE_TTL_SECONDS``; within that window
    ``connector.test_connection()`` is not called again.
    """
    cache_key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        cached = _connection_cache.get(cache_key)
        if cached and time.time() < cached["expires_at"]:
            return bool(cached["ok"])

    # Probe outside the lock; normalise so hits and misses both return bool.
    ok = bool(connector.test_connection())
    with _cache_lock:
        # Start the TTL after the probe: a slow timeout must not eat into it.
        _connection_cache[cache_key] = {
            "ok": ok,
            "expires_at": time.time() + CONNECTION_CACHE_TTL_SECONDS,
        }
    return ok
error=f"Failed to connect to {request.source}") schema = connector.get_schema() + _set_cached_schema(request.source, connector, schema) schema_str = json.dumps(schema, indent=2) # 2. Get the active LLM config @@ -206,7 +286,9 @@ Let's think step by step. # 6. Execute SQL try: if request.source == "upload": - formatted_results = _execute_upload_sql(sql_query, request.file_url) + if upload_df is None: + upload_df = _get_upload_payload(request.file_url)["df"] + formatted_results = _execute_upload_sql(sql_query, upload_df) else: results = connector.execute_query(sql_query) # Convert results to list of dicts if not already (Postgres returns list of dicts, ClickHouse returns list of tuples)