import sys
import os
import json
import logging
import re
import time
import threading
from pathlib import Path
from typing import List, Optional, Dict, Any

from pydantic import BaseModel, Field
import duckdb
import pandas as pd

# Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from nanobot.providers.litellm_provider import LiteLLMProvider
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart
from app.database import SessionLocal
from app.models.datasource import DataSource

logger = logging.getLogger(__name__)

# Cache lifetimes (seconds) and size limit for the process-wide caches below.
SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
UPLOAD_CACHE_TTL_SECONDS = 900
MAX_UPLOAD_CACHE_ITEMS = 8

# In-memory caches; every read/write happens while holding _cache_lock.
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
_upload_cache: Dict[str, Dict[str, Any]] = {}
_cache_lock = threading.Lock()

# Matches SQL string literals, quoted identifiers and comments, so the
# read-only check can ignore any ';' or keywords hidden inside them.
_SQL_LITERALS_AND_COMMENTS_RE = re.compile(
    r"'(?:[^']|'')*'"       # string literal ('' escapes a quote)
    r'|"(?:[^"]|"")*"'      # double-quoted identifier
    r"|`(?:[^`]|``)*`"      # backtick-quoted identifier
    r"|--[^\n]*"            # line comment
    r"|/\*.*?\*/",          # block comment
    re.DOTALL,
)
# First keyword a generated query is allowed to start with.
_READ_ONLY_LEADING_KEYWORDS = frozenset({"select", "with"})
# A Markdown code-fence info string such as "sql" or "json" (possibly empty).
_FENCE_INFO_STRING_RE = re.compile(r"[A-Za-z0-9_+.-]*")


class NL2SQLRequest(BaseModel):
    """Request payload for a natural-language-to-SQL query."""

    query: str = Field(..., description="User's natural language query")
    source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
    file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
    session_id: Optional[str] = Field(None, description="Conversation session identifier")


class NL2SQLResponse(BaseModel):
    """Generated SQL, its result rows, an optional chart and an optional error message."""

    sql: str
    result: List[Dict[str, Any]]
    error: Optional[str] = None
    chart: Optional[ChartGenerationResponse] = None


# WrenAI-inspired SQL Rules
DEFAULT_TEXT_TO_SQL_RULES = """
### SQL RULES ###
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE "*" if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- DON'T INCLUDE comments in the generated SQL query.
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
- PREFER USING CTEs over subqueries.
- When generating SQL query, always:
    - Put double quotes around column and table names.
    - Put single quotes around string literals.
    - Never quote numeric literals.
    For example: SELECT "customers"."customer_name" FROM "customers" WHERE "customers"."city" = 'Taipei' and "customers"."year" = 1992;
- YOU MUST USE "lower(<table_name>.<column_name>) like lower(<value>)" function or "lower(<table_name>.<column_name>) = lower(<value>)" function for case-insensitive comparison!
    - Use "lower(<table_name>.<column_name>) LIKE lower(<value>)" when:
        - The user requests a pattern or partial match.
        - The value is not specific enough to be a single, exact value.
        - Wildcards (%) are needed to capture the pattern.
    - Use "lower(<table_name>.<column_name>) = lower(<value>)" when:
        - The user requests an exact, specific value.
        - There is no ambiguity or pattern in the value.
- If the column is date/time related field, and it is a INT/BIGINT/DOUBLE/FLOAT type, please use the appropriate function mentioned in the SQL FUNCTIONS section to cast the column to "TIMESTAMP" type first before using it in the query
- ALWAYS CAST the date/time related field to "TIMESTAMP WITH TIME ZONE" type when using them in the query
- If the user asks for a specific date, please give the date range in SQL query
- Aggregate functions are not allowed in the WHERE clause. Instead, they belong in the HAVING clause, which is used to filter after aggregation.
- You can only add "ORDER BY" and "LIMIT" to the final "UNION" result.
- For the ranking problem, you must use the ranking function, `DENSE_RANK()` to rank the results and then use `WHERE` clause to filter the results.
- For the ranking problem, you must add the ranking column to the final SELECT clause.
"""

SQL_GENERATION_SYSTEM_PROMPT = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.

Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.

### GENERAL RULES ###

1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input.
2. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions.

{DEFAULT_TEXT_TO_SQL_RULES}

### FINAL ANSWER FORMAT ###
The final answer must be a ANSI SQL query in JSON format:

{{
    "reasoning": <REASONING_STRING>,
    "sql": <SQL_QUERY_STRING>
}}
"""


class _DataSourceNotFound(Exception):
    """Raised when a ``ds:{id}`` source has no matching DataSource row."""


def _error_response(message: str, sql: str = "") -> NL2SQLResponse:
    """Build a response that carries only an error message (and optionally the SQL)."""
    return NL2SQLResponse(sql=sql, result=[], error=message)


def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
    """Map a ``local://<name>`` URL to an existing file in the uploads directory.

    Only the basename of ``<name>`` is used, so path segments such as ``..``
    cannot escape the uploads directory.

    Raises:
        ValueError: if the URL is missing/malformed or the file does not exist.
    """
    if not file_url or not file_url.startswith("local://"):
        raise ValueError("Invalid uploaded file URL")
    safe_name = os.path.basename(file_url[len("local://"):])
    if not safe_name:
        raise ValueError("Invalid uploaded file URL")
    upload_dir = Path(__file__).resolve().parents[2] / "data" / "uploads"
    file_path = upload_dir / safe_name
    # is_file() rather than exists(): a directory (e.g. "." or "..") is not an upload.
    if not file_path.is_file():
        raise ValueError(f"Uploaded file not found: {safe_name}")
    return file_path


def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
    """Read a CSV, Excel or Parquet file into a DataFrame based on its suffix.

    Raises:
        ValueError: for any other file extension.
    """
    suffix = file_path.suffix.lower()
    if suffix == ".csv":
        return pd.read_csv(file_path)
    if suffix in [".xls", ".xlsx"]:
        return pd.read_excel(file_path)
    if suffix == ".parquet":
        return pd.read_parquet(file_path)
    raise ValueError(f"Unsupported uploaded file type: {suffix}")


def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[str]]:
    """Describe ``df`` as ``{"uploaded_file": ["<column> (<duckdb type>)", ...]}``."""
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        conn.close()
    return {"uploaded_file": [f"{col[0]} ({col[1]})" for col in columns]}


def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
    """Return ``{"df": DataFrame, "schema": schema}`` for an uploaded file, cached.

    The cache key includes the file's mtime (ns) and size, so a rewritten file
    is reloaded. Expired entries are pruned and, at capacity, the entry that
    expires soonest is evicted.
    """
    file_path = _resolve_upload_file_path(file_url)
    stat = file_path.stat()
    cache_key = f"{file_path}:{stat.st_mtime_ns}:{stat.st_size}"
    now = time.time()
    with _cache_lock:
        cached = _upload_cache.get(cache_key)
        if cached and now < cached["expires_at"]:
            return {"df": cached["df"], "schema": cached["schema"]}

    # Load outside the lock: file parsing can be slow.
    df = _load_upload_dataframe_from_path(file_path)
    schema = _build_upload_schema(df)

    with _cache_lock:
        # Drop expired entries so stale DataFrames do not stay pinned in memory.
        for key in [k for k, v in _upload_cache.items() if v["expires_at"] <= now]:
            del _upload_cache[key]
        if len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS:
            oldest_key = min(_upload_cache, key=lambda key: _upload_cache[key]["expires_at"])
            _upload_cache.pop(oldest_key, None)
        _upload_cache[cache_key] = {
            "df": df,
            "schema": schema,
            "expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
        }
    return {"df": df, "schema": schema}


def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Run ``sql_query`` against ``df`` (exposed as table ``uploaded_file``) in DuckDB.

    NOTE(review): the query runs with DuckDB's default configuration, which
    presumably still permits file-reading table functions (e.g. read_csv);
    consider disabling external access — verify against the DuckDB docs.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        result_df = conn.execute(sql_query).df()
    finally:
        # Always release the connection: LLM-generated SQL frequently fails.
        conn.close()
    return result_df.to_dict(orient="records")


def _build_schema_cache_key(source: str, connector: Any) -> str:
    """Return the key identifying ``connector``'s target in the schema/connection caches.

    ``ds:{id}`` sources use the source string itself. Postgres is keyed by
    ``db_url`` and ClickHouse by host/port/user/database. Any other source falls
    back to the source string.
    """
    if source.startswith("ds:"):
        return source
    if source == "postgres":
        return f"postgres:{getattr(connector, 'db_url', '')}"
    if source == "clickhouse":
        return (
            f"clickhouse:{getattr(connector, 'host', '')}:{getattr(connector, 'port', '')}:"
            f"{getattr(connector, 'user', '')}:{getattr(connector, 'database', '')}"
        )
    return source


def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
    """Return the cached schema for this source if present and unexpired, else None."""
    key = _build_schema_cache_key(source, connector)
    now = time.time()
    with _cache_lock:
        cached = _schema_cache.get(key)
        if cached and now < cached["expires_at"]:
            return cached["schema"]
    return None


def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[str]]) -> None:
    """Store ``schema`` for this source for SCHEMA_CACHE_TTL_SECONDS."""
    key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        _schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}


def _check_connection_with_cache(source: str, connector: Any) -> bool:
    """Return ``connector.test_connection()``, memoized for CONNECTION_CACHE_TTL_SECONDS.

    Both successes and failures are cached.
    """
    cache_key = _build_schema_cache_key(source, connector)
    now = time.time()
    with _cache_lock:
        cached = _connection_cache.get(cache_key)
        if cached and now < cached["expires_at"]:
            return bool(cached["ok"])
    ok = connector.test_connection()
    with _cache_lock:
        _connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
    return ok


def _load_datasource_connector(ds_id: int) -> Any:
    """Build a connector for the DataSource row ``ds_id``.

    Raises:
        _DataSourceNotFound: if no row with that id exists.
    """
    db = SessionLocal()
    try:
        ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
        if ds is None:
            raise _DataSourceNotFound(ds_id)
        return get_connector(ds)
    finally:
        db.close()


def _load_connector_schema(source: str, connector: Any) -> Optional[Dict[str, List[str]]]:
    """Return the (cached) schema for ``connector``, or None if it is unreachable.

    A cached empty schema counts as a miss so an empty database is
    re-introspected, but ``get_schema`` is called at most once per request.
    """
    cached_schema = _get_cached_schema(source, connector)
    if cached_schema:
        return cached_schema
    if not _check_connection_with_cache(source, connector):
        return None
    schema = connector.get_schema() or {}
    _set_cached_schema(source, connector, schema)
    return schema


def _strip_code_fence(content: str) -> str:
    """Return the body of the first Markdown code fence in ``content``, else ``content``.

    A ```json fence is preferred when present. For any other fence, the info
    string on the opening line (e.g. ``sql``) is dropped so it cannot leak into
    the extracted SQL.
    """
    if "```json" in content:
        return content.split("```json", 1)[1].split("```", 1)[0].strip()
    if "```" in content:
        body = content.split("```", 1)[1].split("```", 1)[0]
        first_line, newline, rest = body.partition("\n")
        if newline and _FENCE_INFO_STRING_RE.fullmatch(first_line.strip()):
            body = rest
        return body.strip()
    return content.strip()


def _extract_sql_from_llm_output(content: Optional[str]) -> str:
    """Pull the SQL string out of the model's reply; returns "" when none is found.

    The reply is expected to be ``{"reasoning": ..., "sql": ...}``, optionally
    inside a code fence. Non-JSON text is treated as bare SQL.
    """
    text = _strip_code_fence((content or "").strip())
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # The model ignored the JSON format instruction; assume bare SQL.
        return text
    if isinstance(parsed, dict):
        sql = parsed.get("sql")
        return sql.strip() if isinstance(sql, str) else ""
    if isinstance(parsed, str):
        return parsed.strip()
    return ""


def _ensure_read_only_sql(sql_query: str) -> None:
    """Reject generated SQL that is not a single statement starting with SELECT/WITH.

    This is a best-effort guard against prompt-injected DML/DDL. It does not
    inspect CTE bodies (e.g. data-modifying CTEs in Postgres), so it is not a
    substitute for read-only database credentials.

    Raises:
        ValueError: if the query fails the check.
    """
    cleaned = _SQL_LITERALS_AND_COMMENTS_RE.sub(" ", sql_query)
    cleaned = cleaned.strip().rstrip("; \t\r\n")
    if ";" in cleaned:
        raise ValueError("only a single SQL statement is allowed")
    match = re.match(r"[\s(]*([A-Za-z_]+)", cleaned)
    if not match or match.group(1).lower() not in _READ_ONLY_LEADING_KEYWORDS:
        raise ValueError("only SELECT queries are allowed")


def _row_to_dict(row: Any) -> Dict[str, Any]:
    """Convert one result row into a dict keyed by column name (or by position)."""
    if isinstance(row, dict):
        return row
    mapping = getattr(row, "_mapping", None)  # e.g. SQLAlchemy Row objects
    if mapping is not None:
        return dict(mapping)
    if hasattr(row, "_asdict"):  # namedtuple-style rows
        return dict(row._asdict())
    if isinstance(row, (list, tuple)):
        # No column metadata available: key values by position.
        return {f"column_{index}": value for index, value in enumerate(row)}
    return {"value": row}


def _format_query_results(results: Any) -> List[Dict[str, Any]]:
    """Normalize a connector's ``execute_query`` output into a list of row dicts.

    Accepted shapes:
      * a list of rows;
      * a ``(rows, columns)`` pair whose columns are ``(name, type, ...)``
        tuples or plain names;
      * a DataFrame.

    Anything else yields ``[]``.
    """
    if isinstance(results, pd.DataFrame):
        return results.to_dict(orient="records")
    if isinstance(results, tuple) and len(results) == 2:
        # Presumably ClickHouse-style (rows, columns_with_types) output.
        rows, cols = results
        col_names = [col if isinstance(col, str) else col[0] for col in cols]
        return [dict(zip(col_names, row)) for row in rows]
    if isinstance(results, list):
        return [_row_to_dict(row) for row in results]
    return []


async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
    """Translate a natural-language question into SQL, run it, and chart the result.

    Steps:
      1. Resolve the data source and its schema.
      2. Pick the active LLM configuration.
      3. Initialize the provider.
      4. Build the prompt.
      5. Ask the model for SQL and check it is read-only.
      6. Execute the SQL against the source.
      7. Optionally generate a chart.

    Failures at each step are reported through ``NL2SQLResponse.error``; once
    SQL has been generated it is echoed back in ``sql``.
    """
    # 1. Resolve the connector (or uploaded DataFrame) and its schema.
    connector: Any = None
    schema: Dict[str, List[str]] = {}
    upload_df: Optional[pd.DataFrame] = None

    if request.source == "postgres":
        connector = postgres_connector
    elif request.source == "clickhouse":
        connector = clickhouse_connector
    elif request.source == "upload":
        try:
            upload_payload = _get_upload_payload(request.file_url)
        except Exception as e:
            return _error_response(f"Failed to load uploaded file: {e}")
        upload_df = upload_payload["df"]
        schema = upload_payload["schema"]
    elif request.source.startswith("ds:"):
        try:
            ds_id = int(request.source.split(":")[1])
        except ValueError:
            return _error_response(f"Invalid data source ID: {request.source}")
        try:
            connector = _load_datasource_connector(ds_id)
        except _DataSourceNotFound:
            return _error_response(f"Data source not found: {request.source}")
        except Exception as e:
            return _error_response(f"Failed to load data source: {e}")
    else:
        return _error_response(f"Unsupported data source: {request.source}")

    if request.source != "upload":
        if connector is None:
            return _error_response(f"Data source is not configured: {request.source}")
        try:
            connector_schema = _load_connector_schema(request.source, connector)
        except Exception as e:
            return _error_response(f"Failed to load schema for {request.source}: {e}")
        if connector_schema is None:
            return _error_response(f"Failed to connect to {request.source}")
        schema = connector_schema

    # ensure_ascii=False keeps non-ASCII table/column names readable for the model.
    schema_str = json.dumps(schema, indent=2, ensure_ascii=False)

    # 2. Get the active LLM config
    try:
        llm_configs = load_llm_config()
    except Exception as e:
        return _error_response(f"Failed to load LLM configuration: {e}")
    active_config = next((c for c in (llm_configs or []) if c.get("is_active")), None)
    if not active_config:
        return _error_response("No active LLM configuration found")

    # 3. Initialize Provider
    try:
        provider = LiteLLMProvider(
            api_key=active_config.get("api_key"),
            api_base=active_config.get("api_base"),
            default_model=active_config.get("model"),
            extra_headers=active_config.get("extra_headers") or {},
            provider_name=active_config.get("provider"),
        )
    except Exception as e:
        return _error_response(f"Failed to initialize LLM provider: {e}")

    # 4. Construct Prompt
    user_prompt = f"""
### DATABASE SCHEMA ###
{schema_str}

### INPUTS ###
User's Question: {request.query}
Language: Chinese (Simplified)

Let's think step by step.
"""
    messages = [
        {"role": "system", "content": SQL_GENERATION_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    # 5. Call LLM and extract the SQL
    try:
        response = await provider.chat(messages=messages)
        sql_query = _extract_sql_from_llm_output(response.content)
    except Exception as e:
        return _error_response(f"LLM generation failed: {e}")
    if not sql_query:
        return _error_response("LLM did not return a SQL query")

    # The user's question is untrusted and flows into the SQL via the model.
    try:
        _ensure_read_only_sql(sql_query)
    except ValueError as e:
        return _error_response(f"Rejected generated SQL: {e}", sql=sql_query)

    # 6. Execute SQL
    # NOTE(review): these calls are synchronous and block the event loop;
    # consider offloading them to a thread once connector thread-safety is confirmed.
    try:
        if request.source == "upload":
            if upload_df is None:
                upload_df = _get_upload_payload(request.file_url)["df"]
            formatted_results = _execute_upload_sql(sql_query, upload_df)
        else:
            formatted_results = _format_query_results(connector.execute_query(sql_query))
    except Exception as e:
        return _error_response(f"SQL execution failed: {e}", sql=sql_query)

    # 7. Generate Chart (best effort: a chart failure must not discard the results)
    chart_response: Optional[ChartGenerationResponse] = None
    if formatted_results:
        try:
            chart_response = await generate_chart(formatted_results, request.query)
        except Exception:
            logger.exception("Chart generation failed for query %r", request.query)

    try:
        return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
    except Exception as e:  # e.g. the response model rejects a value in the rows
        return _error_response(f"Failed to build response: {e}", sql=sql_query)