import json
import logging
import os
import sys
import threading
import time
from pathlib import Path
from typing import List, Optional, Dict, Any

import duckdb
import pandas as pd
from pydantic import BaseModel, Field

# Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from nanobot.providers.litellm_provider import LiteLLMProvider  # noqa: E402

from app.connectors.postgres import postgres_connector  # noqa: E402
from app.connectors.clickhouse import clickhouse_connector  # noqa: E402
from app.connectors.factory import get_connector  # noqa: E402
from app.api.llm import _load_data as load_llm_config  # noqa: E402
from app.schemas.chart import ChartGenerationResponse  # noqa: E402
from app.agent.chart import generate_chart  # noqa: E402
from app.database import SessionLocal  # noqa: E402
from app.models.datasource import DataSource  # noqa: E402
from app.core.files import resolve_upload_file_path  # noqa: E402
from app.services.mdl import MDLService  # noqa: E402
2026-03-14 15:44:48 +08:00
2026-03-15 18:36:28 +08:00
# --- In-process caches used by the NL2SQL pipeline -------------------------
# Lifetime, in seconds, of an entry written by _set_cached_schema().
SCHEMA_CACHE_TTL_SECONDS = 300
# Lifetime, in seconds, of a connectivity result stored by
# _check_connection_with_cache().
CONNECTION_CACHE_TTL_SECONDS = 30
# Lifetime, in seconds, of a parsed upload stored by _get_upload_payload().
UPLOAD_CACHE_TTL_SECONDS = 900
# Once the upload cache holds this many entries, _get_upload_payload() evicts
# one before inserting another.
MAX_UPLOAD_CACHE_ITEMS = 8
# Cache key -> entry holding the schema and its "expires_at" timestamp.
_schema_cache : Dict [ str , Dict [ str , Any ] ] = { }
# Cache key -> entry holding the "ok" flag and its "expires_at" timestamp.
_connection_cache : Dict [ str , Dict [ str , Any ] ] = { }
# Cache key -> entry holding the DataFrame ("df"), its schema and "expires_at".
_upload_cache : Dict [ str , Dict [ str , Any ] ] = { }
# Single lock taken around every read and write of the three dicts above.
_cache_lock = threading . Lock ( )
2026-03-14 15:44:48 +08:00
class NL2SQLRequest(BaseModel):
    # Request payload for one natural-language-to-SQL conversion.
    # (Kept as comments: a class docstring would be picked up by pydantic as
    # the schema description.)

    # The question to translate, exactly as the user typed it.
    query: str = Field(
        ...,
        description="User's natural language query",
    )
    # Which backend to query; process_nl2sql() dispatches on this value.
    source: str = Field(
        ...,
        description="Data source to query (postgres, clickhouse, upload, ds:{id})",
    )
    # Only consulted when source is "upload".
    file_url: Optional[str] = Field(
        None,
        description="Uploaded file URL when source is upload",
    )
    # Optional conversation identifier supplied by the caller.
    session_id: Optional[str] = Field(
        None,
        description="Conversation session identifier",
    )
2026-03-14 15:44:48 +08:00
class NL2SQLResponse(BaseModel):
    # Outcome of one NL2SQL request.
    # (Kept as comments: a class docstring would be picked up by pydantic as
    # the schema description.)

    # SQL text produced by the model; empty when the request failed earlier.
    sql: str
    # Result rows, one dict per row.
    result: List[Dict[str, Any]]
    # Human-readable failure description; None on success.
    error: Optional[str] = None
    # Chart suggestion; only populated when the query returned rows.
    chart: Optional[ChartGenerationResponse] = None
# WrenAI-inspired SQL rules.
# Plain-text rule list for the SQL-generating model. It is not used on its
# own: SQL_GENERATION_SYSTEM_PROMPT interpolates it into the system prompt.
# The text is prompt content, so any edit changes what the model is told.
DEFAULT_TEXT_TO_SQL_RULES = """
### SQL RULES ###
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
- ONLY USE the tables and columns mentioned in the database schema.
- ONLY USE " * " if the user query asks for all the columns of a table.
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
- DON ' T INCLUDE comments in the generated SQL query.
- YOU MUST USE " JOIN " if you choose columns from multiple tables!
- PREFER USING CTEs over subqueries.
- When generating SQL query, always:
- Put double quotes around column and table names.
- Put single quotes around string literals.
- Never quote numeric literals.
For example: SELECT " customers " . " customer_name " FROM " customers " WHERE " customers " . " city " = ' Taipei ' and " customers " . " year " = 1992;
- YOU MUST USE " lower(<table_name>.<column_name>) like lower(<value>) " function or " lower(<table_name>.<column_name>) = lower(<value>) " function for case-insensitive comparison!
- Use " lower(<table_name>.<column_name>) LIKE lower(<value>) " when:
- The user requests a pattern or partial match.
- The value is not specific enough to be a single, exact value.
- Wildcards ( % ) are needed to capture the pattern.
- Use " lower(<table_name>.<column_name>) = lower(<value>) " when:
- The user requests an exact, specific value.
- There is no ambiguity or pattern in the value.
- If the column is date/time related field, and it is a INT/BIGINT/DOUBLE/FLOAT type, please use the appropriate function mentioned in the SQL FUNCTIONS section to cast the column to " TIMESTAMP " type first before using it in the query
- ALWAYS CAST the date/time related field to " TIMESTAMP WITH TIME ZONE " type when using them in the query
- If the user asks for a specific date, please give the date range in SQL query
- Aggregate functions are not allowed in the WHERE clause. Instead, they belong in the HAVING clause, which is used to filter after aggregation.
- You can only add " ORDER BY " and " LIMIT " to the final " UNION " result.
- For the ranking problem, you must use the ranking function, `DENSE_RANK()` to rank the results and then use `WHERE` clause to filter the results.
- For the ranking problem, you must add the ranking column to the final SELECT clause.
"""
# System prompt for the SQL-generation chat call. It embeds
# DEFAULT_TEXT_TO_SQL_RULES and asks the model to answer with a JSON object
# holding "reasoning" and "sql"; process_nl2sql() reads the "sql" member.
# The doubled braces are f-string escapes for the literal JSON braces.
SQL_GENERATION_SYSTEM_PROMPT = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
### GENERAL RULES ###
1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input.
2. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions.
{DEFAULT_TEXT_TO_SQL_RULES}
### FINAL ANSWER FORMAT ###
The final answer must be a ANSI SQL query in JSON format:
{{
"reasoning": <STEP_BY_STEP_REASONING_PLAN>,
"sql": <SQL_QUERY_STRING>
}}
"""
2026-03-14 15:44:48 +08:00
2026-03-15 10:49:37 +08:00
def _resolve_upload_file_path(file_url: Optional[str]) -> Path:
    """Turn an uploaded-file URL into a filesystem path.

    Thin wrapper around ``resolve_upload_file_path`` that rewords its
    ``ValueError`` for the NL2SQL error messages.

    Args:
        file_url: URL of the uploaded file as sent by the client; may be None.

    Returns:
        The path returned by ``resolve_upload_file_path``.

    Raises:
        ValueError: If the underlying resolver rejects ``file_url``. The
            original exception is attached as ``__cause__``.
    """
    try:
        return resolve_upload_file_path(file_url)
    except ValueError as e:
        # Chain explicitly so the resolver's traceback is not lost.
        raise ValueError(f"Invalid uploaded file URL: {e}") from e
2026-03-15 10:49:37 +08:00
2026-03-15 18:36:28 +08:00
def _load_upload_dataframe_from_path(file_path: Path) -> pd.DataFrame:
    """Read an uploaded file into a DataFrame, choosing the reader by extension.

    Args:
        file_path: Location of the uploaded file. The extension is matched
            case-insensitively.

    Returns:
        The file contents as a pandas DataFrame.

    Raises:
        ValueError: If the extension is not .csv, .xls, .xlsx or .parquet.
    """
    readers_by_suffix = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
        ".parquet": pd.read_parquet,
    }
    suffix = file_path.suffix.lower()
    reader = readers_by_suffix.get(suffix)
    if reader is None:
        raise ValueError(f"Unsupported uploaded file type: {suffix}")
    return reader(file_path)
2026-03-16 22:18:23 +08:00
def _build_upload_schema(df: pd.DataFrame) -> Dict[str, List[Dict[str, str]]]:
    """Describe an uploaded DataFrame as a one-table schema.

    The frame is registered in a throwaway in-memory DuckDB database under
    the name ``uploaded_file`` and DuckDB's ``DESCRIBE`` output is used for
    the column names and types.

    Args:
        df: The uploaded data.

    Returns:
        ``{"uploaded_file": [{"name": ..., "type": ...}, ...]}`` with one
        entry per column, in the order DuckDB reports them.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        columns = conn.execute("DESCRIBE uploaded_file").fetchall()
    finally:
        # Release the connection even when register/DESCRIBE raises.
        conn.close()
    return {"uploaded_file": [{"name": col[0], "type": col[1]} for col in columns]}
2026-03-15 18:36:28 +08:00
def _get_upload_payload(file_url: Optional[str]) -> Dict[str, Any]:
    """Return the parsed DataFrame and schema for an uploaded file, cached.

    The cache key combines the resolved path, the modification time in
    nanoseconds and the file size, so a rewritten file gets a new entry even
    when it is rewritten within the same second.

    Args:
        file_url: URL of the uploaded file; may be None (the resolver decides
            whether that is acceptable).

    Returns:
        ``{"df": <DataFrame>, "schema": <schema dict>}``.

    Raises:
        ValueError: If the URL cannot be resolved or the file type is not
            supported.
        OSError: If the file cannot be stat'ed.
    """
    file_path = _resolve_upload_file_path(file_url)
    stat = file_path.stat()
    cache_key = f"{file_path}:{stat.st_mtime_ns}:{stat.st_size}"

    with _cache_lock:
        cached = _upload_cache.get(cache_key)
        if cached and time.time() < cached["expires_at"]:
            return {"df": cached["df"], "schema": cached["schema"]}

    # Parse outside the lock: reading a large file must not block other
    # cache users.
    df = _load_upload_dataframe_from_path(file_path)
    schema = _build_upload_schema(df)

    now = time.time()
    with _cache_lock:
        # Drop expired entries first so a live one is not evicted needlessly.
        expired_keys = [
            key for key, entry in _upload_cache.items() if entry["expires_at"] <= now
        ]
        for key in expired_keys:
            del _upload_cache[key]

        # Only evict when inserting would actually grow the cache past its bound.
        needs_room = (
            cache_key not in _upload_cache
            and len(_upload_cache) >= MAX_UPLOAD_CACHE_ITEMS
        )
        if needs_room and _upload_cache:
            oldest_key = min(
                _upload_cache, key=lambda key: _upload_cache[key]["expires_at"]
            )
            del _upload_cache[oldest_key]

        _upload_cache[cache_key] = {
            "df": df,
            "schema": schema,
            "expires_at": now + UPLOAD_CACHE_TTL_SECONDS,
        }
    return {"df": df, "schema": schema}
def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Run a SQL query against an uploaded DataFrame via in-memory DuckDB.

    The frame is exposed to the query as the table ``uploaded_file``.

    Args:
        sql_query: SQL text to execute.
        df: The uploaded data.

    Returns:
        The result rows as a list of ``{column: value}`` dicts.

    Raises:
        Exception: Whatever DuckDB raises for an invalid query; the
            connection is closed before the exception propagates.
    """
    conn = duckdb.connect(":memory:")
    try:
        conn.register("uploaded_file", df)
        result_df = conn.execute(sql_query).df()
    finally:
        # A failing query is the common case for generated SQL; do not leak
        # the connection when it happens.
        conn.close()
    return result_df.to_dict(orient="records")
2026-03-15 18:36:28 +08:00
def _build_schema_cache_key(source: str, connector: Any) -> str:
    """Derive the cache key shared by the schema and connection caches.

    Args:
        source: The request's source string.
        connector: Connector object; only inspected for the built-in
            ``postgres`` and ``clickhouse`` sources.

    Returns:
        For ``postgres`` and ``clickhouse`` the key includes the connector's
        connection attributes (missing attributes count as empty strings).
        Every other source, including ``ds:<id>``, is its own key.
    """
    if source == "postgres":
        db_url = getattr(connector, "db_url", "")
        return f"postgres:{db_url}"

    if source == "clickhouse":
        attribute_names = ("host", "port", "user", "database")
        parts = [str(getattr(connector, name, "")) for name in attribute_names]
        return "clickhouse:" + ":".join(parts)

    # "ds:<id>" is already unique per data source; anything else has no
    # connector details worth adding.
    return source
2026-03-16 22:18:23 +08:00
def _get_cached_schema(source: str, connector: Any) -> Optional[Dict[str, List[Dict[str, str]]]]:
    """Look up a still-valid schema for this source.

    Args:
        source: The request's source string.
        connector: Connector used to derive the cache key.

    Returns:
        The cached schema, or None when there is no entry or it has expired.
    """
    cache_key = _build_schema_cache_key(source, connector)
    checked_at = time.time()
    with _cache_lock:
        entry = _schema_cache.get(cache_key)
        if not entry or checked_at >= entry["expires_at"]:
            return None
        return entry["schema"]
2026-03-16 22:18:23 +08:00
def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[str, str]]]) -> None:
    """Store a schema for this source, valid for SCHEMA_CACHE_TTL_SECONDS.

    Args:
        source: The request's source string.
        connector: Connector used to derive the cache key.
        schema: Schema to remember; replaces any existing entry for the key.
    """
    cache_key = _build_schema_cache_key(source, connector)
    with _cache_lock:
        expires_at = time.time() + SCHEMA_CACHE_TTL_SECONDS
        _schema_cache[cache_key] = {"schema": schema, "expires_at": expires_at}
def _check_connection_with_cache(source: str, connector: Any) -> bool:
    """Test connectivity, reusing a recent result when one is cached.

    Both successful and failed results are remembered for
    CONNECTION_CACHE_TTL_SECONDS, measured from the moment this call started.

    Args:
        source: The request's source string.
        connector: Object exposing ``test_connection()``.

    Returns:
        The cached flag when fresh, otherwise whatever
        ``connector.test_connection()`` returns.
    """
    key = _build_schema_cache_key(source, connector)
    started_at = time.time()

    with _cache_lock:
        entry = _connection_cache.get(key)
        if entry and started_at < entry["expires_at"]:
            return bool(entry["ok"])

    # Probe without holding the lock; the call may block on the network.
    reachable = connector.test_connection()
    fresh_entry = {
        "ok": reachable,
        "expires_at": started_at + CONNECTION_CACHE_TTL_SECONDS,
    }
    with _cache_lock:
        _connection_cache[key] = fresh_entry
    return reachable
2026-03-14 15:44:48 +08:00
def _nl2sql_error(message: str, sql: str = "") -> NL2SQLResponse:
    """Build a failed response carrying ``message`` and no rows."""
    return NL2SQLResponse(sql=sql, result=[], error=message)


def _parse_datasource_id(source: str) -> int:
    """Extract the numeric id from a ``ds:<id>`` source string.

    Raises:
        ValueError: If the part after ``ds:`` is not an integer.
    """
    return int(source.split(":")[1])


def _build_mdl_context(source: str) -> str:
    """Render the semantic model (MDL) of a ``ds:<id>`` source as prompt text.

    Best effort: returns an empty string for other sources, when no MDL is
    stored, or when loading or rendering fails (the failure is logged).
    """
    if not source.startswith("ds:"):
        return ""
    try:
        mdl = MDLService.get_mdl(_parse_datasource_id(source))
        if not mdl:
            return ""

        lines = ["\n### SEMANTIC MODEL (WrenMDL) ###", "MODELS:"]
        for model in mdl.models:
            table_ref = model.tableReference.table if model.tableReference else model.name
            model_description = model.properties.get("description")
            desc = f" - Description: {model_description}" if model_description else ""
            lines.append(f"- Model: {model.name} (Table: {table_ref}){desc}")
            if model.columns:
                lines.append("Columns:")
                for col in model.columns:
                    col_description = col.properties.get("description")
                    col_desc = f" ({col_description})" if col_description else ""
                    expr = f" [Calculated: {col.expression}]" if col.isCalculated else ""
                    lines.append(f"- {col.name} ({col.type}){col_desc}{expr}")

        if mdl.relationships:
            lines.append("\nRELATIONSHIPS:")
            for rel in mdl.relationships:
                lines.append(
                    f"- {rel.name}: {rel.joinType} between {rel.models} ON {rel.condition}"
                )
        return "\n".join(lines)
    except Exception as e:
        # The MDL only enriches the prompt; never fail the request over it.
        logging.getLogger(__name__).warning("Failed to load MDL: %s", e)
        return ""


def _extract_sql_from_llm_reply(content: str) -> str:
    """Pull the SQL text out of the model's reply.

    Accepts the requested JSON object (optionally inside a code fence) and
    falls back to treating the reply body as raw SQL when it is not JSON.

    Raises:
        ValueError: If the reply is valid JSON but neither an object nor a
            string.
    """
    text = content.strip()
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0]
    elif "```" in text:
        text = text.split("```")[1].split("```")[0]
        # A fence such as ```sql keeps its language tag on the first line;
        # drop it so it does not end up inside the query.
        first_line, newline, remainder = text.partition("\n")
        if newline and first_line.strip().isalpha():
            text = remainder
    text = text.strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        # Fallback if the LLM doesn't return valid JSON despite instructions.
        return text

    if isinstance(parsed, dict):
        return str(parsed.get("sql") or "").strip()
    if isinstance(parsed, str):
        return parsed.strip()
    raise ValueError("LLM returned JSON that is neither an object nor a string")


def _format_query_results(results: Any) -> List[Dict[str, Any]]:
    """Normalize a connector's ``execute_query`` result to a list of dicts.

    Handled shapes:
      * a list of dicts, returned unchanged;
      * a ``(rows, columns)`` 2-tuple whose column entries start with the
        column name, zipped into dicts;
      * a list of tuples/lists or of scalars without column metadata, given
        synthesized ``col_<index>`` keys so the rows still satisfy the
        response model.
    Anything else yields an empty list.
    """
    if isinstance(results, tuple) and len(results) == 2:
        rows, cols = results
        col_names = [c[0] for c in cols]
        return [dict(zip(col_names, row)) for row in rows]

    if not isinstance(results, list) or not results:
        return []

    first = results[0]
    if isinstance(first, dict):
        return results
    if isinstance(first, (list, tuple)):
        return [
            {f"col_{index}": value for index, value in enumerate(row)}
            for row in results
        ]
    return [{"col_0": value} for value in results]


async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
    """Translate a natural-language question into SQL, run it and chart it.

    Steps: resolve the data source and its schema, build the prompt (adding
    the semantic model for ``ds:<id>`` sources), ask the active LLM for SQL,
    execute the SQL, then try to generate a chart for non-empty results.

    Args:
        request: The query, the source selector and, for uploads, the file URL.

    Returns:
        An ``NL2SQLResponse``. Failures are reported through its ``error``
        field rather than raised; ``sql`` is filled in whenever the model
        produced a query. A chart failure leaves ``chart`` as None and does
        not discard the rows.
    """
    source = request.source

    # 1. Get the connector and schema
    connector = None
    schema: Any = {}
    upload_df: Optional[pd.DataFrame] = None

    if source == "postgres":
        connector = postgres_connector
    elif source == "clickhouse":
        connector = clickhouse_connector
    elif source == "upload":
        try:
            upload_payload = _get_upload_payload(request.file_url)
        except Exception as e:
            return _nl2sql_error(f"Failed to load uploaded file: {e}")
        upload_df = upload_payload["df"]
        schema = upload_payload["schema"]
    elif source.startswith("ds:"):
        # Parse the id on its own so only a malformed id is reported as such.
        try:
            ds_id = _parse_datasource_id(source)
        except ValueError:
            return _nl2sql_error(f"Invalid data source ID: {source}")
        try:
            db = SessionLocal()
            try:
                ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
                if not ds:
                    return _nl2sql_error(f"Data source not found: {source}")
                connector = get_connector(ds)
            finally:
                db.close()
        except Exception as e:
            return _nl2sql_error(f"Failed to load data source: {e}")
    else:
        return _nl2sql_error(f"Unsupported data source: {source}")

    if connector:
        cached_schema = _get_cached_schema(source, connector)
        if cached_schema:
            schema = cached_schema
        else:
            # An empty cached schema is treated as a miss so tables added to
            # a previously empty database are picked up.
            try:
                if not _check_connection_with_cache(source, connector):
                    return _nl2sql_error(f"Failed to connect to {source}")
                schema = connector.get_schema()
            except Exception as e:
                return _nl2sql_error(f"Failed to load schema from {source}: {e}")
            _set_cached_schema(source, connector, schema)

    # Keep non-ASCII identifiers readable for the model instead of \uXXXX.
    schema_str = json.dumps(schema, indent=2, ensure_ascii=False)

    # Try to load MDL context
    mdl_context = _build_mdl_context(source)

    # 2. Get the active LLM config
    llm_configs = load_llm_config()
    active_config = next((c for c in llm_configs if c.get("is_active")), None)
    if not active_config:
        return _nl2sql_error("No active LLM configuration found")

    # 3. Initialize Provider
    try:
        provider = LiteLLMProvider(
            api_key=active_config.get("api_key"),
            api_base=active_config.get("api_base"),
            default_model=active_config.get("model"),
            extra_headers=active_config.get("extra_headers") or {},
            provider_name=active_config.get("provider"),
        )
    except Exception as e:
        return _nl2sql_error(f"Failed to initialize LLM provider: {e}")

    # 4. Construct Prompt
    user_prompt = f"""
### DATABASE SCHEMA ###
{schema_str}
{mdl_context}

### INPUTS ###
User's Question: {request.query}
Language: Chinese (Simplified)

Let's think step by step.
"""
    messages = [
        {"role": "system", "content": SQL_GENERATION_SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    # 5. Call LLM
    try:
        response = await provider.chat(messages=messages)
        sql_query = _extract_sql_from_llm_reply(response.content)
    except Exception as e:
        return _nl2sql_error(f"LLM generation failed: {e}")

    if not sql_query:
        return _nl2sql_error("LLM generation failed: no SQL query was returned")

    # 6. Execute SQL
    # NOTE(review): the generated SQL is executed as-is. The "SELECT only"
    # rule lives in the prompt alone, so nothing here prevents a
    # data-modifying statement. Confirm the connectors use read-only
    # credentials or add a statement allow-list.
    try:
        if source == "upload":
            formatted_results = _execute_upload_sql(sql_query, upload_df)
        else:
            formatted_results = _format_query_results(connector.execute_query(sql_query))
    except Exception as e:
        return _nl2sql_error(f"SQL execution failed: {e}", sql=sql_query)

    # 7. Generate Chart
    chart_response = None
    if formatted_results:
        # Only try to generate a chart if we have results; the chart is
        # optional, so a failure must not throw away a successful query.
        try:
            chart_response = await generate_chart(formatted_results, request.query)
        except Exception as e:
            logging.getLogger(__name__).warning("Chart generation failed: %s", e)

    try:
        return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
    except Exception as e:
        # Rows that do not fit the response model (e.g. non-string keys).
        return _nl2sql_error(f"SQL execution failed: {e}", sql=sql_query)