beta
This commit is contained in:
+96
-18
@@ -14,6 +14,8 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
|
||||
class NL2SQLRequest(BaseModel):
|
||||
query: str = Field(..., description="User's natural language query")
|
||||
@@ -23,6 +25,60 @@ class NL2SQLResponse(BaseModel):
|
||||
sql: str
|
||||
result: List[Dict[str, Any]]
|
||||
error: Optional[str] = None
|
||||
chart: Optional[ChartGenerationResponse] = None
|
||||
|
||||
# WrenAI-inspired SQL Rules
|
||||
DEFAULT_TEXT_TO_SQL_RULES = """
|
||||
### SQL RULES ###
|
||||
- ONLY USE SELECT statements, NO DELETE, UPDATE OR INSERT etc. statements that might change the data in the database.
|
||||
- ONLY USE the tables and columns mentioned in the database schema.
|
||||
- ONLY USE "*" if the user query asks for all the columns of a table.
|
||||
- ONLY CHOOSE columns belong to the tables mentioned in the database schema.
|
||||
- DON'T INCLUDE comments in the generated SQL query.
|
||||
- YOU MUST USE "JOIN" if you choose columns from multiple tables!
|
||||
- PREFER USING CTEs over subqueries.
|
||||
- When generating SQL query, always:
|
||||
- Put double quotes around column and table names.
|
||||
- Put single quotes around string literals.
|
||||
- Never quote numeric literals.
|
||||
For example: SELECT "customers"."customer_name" FROM "customers" WHERE "customers"."city" = 'Taipei' and "customers"."year" = 1992;
|
||||
- YOU MUST USE "lower(<table_name>.<column_name>) like lower(<value>)" function or "lower(<table_name>.<column_name>) = lower(<value>)" function for case-insensitive comparison!
|
||||
- Use "lower(<table_name>.<column_name>) LIKE lower(<value>)" when:
|
||||
- The user requests a pattern or partial match.
|
||||
- The value is not specific enough to be a single, exact value.
|
||||
- Wildcards (%) are needed to capture the pattern.
|
||||
- Use "lower(<table_name>.<column_name>) = lower(<value>)" when:
|
||||
- The user requests an exact, specific value.
|
||||
- There is no ambiguity or pattern in the value.
|
||||
- If the column is date/time related field, and it is a INT/BIGINT/DOUBLE/FLOAT type, please use the appropriate function mentioned in the SQL FUNCTIONS section to cast the column to "TIMESTAMP" type first before using it in the query
|
||||
- ALWAYS CAST the date/time related field to "TIMESTAMP WITH TIME ZONE" type when using them in the query
|
||||
- If the user asks for a specific date, please give the date range in SQL query
|
||||
- Aggregate functions are not allowed in the WHERE clause. Instead, they belong in the HAVING clause, which is used to filter after aggregation.
|
||||
- You can only add "ORDER BY" and "LIMIT" to the final "UNION" result.
|
||||
- For the ranking problem, you must use the ranking function, `DENSE_RANK()` to rank the results and then use `WHERE` clause to filter the results.
|
||||
- For the ranking problem, you must add the ranking column to the final SELECT clause.
|
||||
"""
|
||||
|
||||
SQL_GENERATION_SYSTEM_PROMPT = f"""
|
||||
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
|
||||
|
||||
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
|
||||
|
||||
### GENERAL RULES ###
|
||||
|
||||
1. YOU MUST FOLLOW the instructions strictly to generate the SQL query if the section of USER INSTRUCTIONS is available in user's input.
|
||||
2. YOU MUST FOLLOW SQL Rules if they are not contradicted with instructions.
|
||||
|
||||
{DEFAULT_TEXT_TO_SQL_RULES}
|
||||
|
||||
### FINAL ANSWER FORMAT ###
|
||||
The final answer must be a ANSI SQL query in JSON format:
|
||||
|
||||
{{
|
||||
"reasoning": <STEP_BY_STEP_REASONING_PLAN>,
|
||||
"sql": <SQL_QUERY_STRING>
|
||||
}}
|
||||
"""
|
||||
|
||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
# 1. Get the connector and schema
|
||||
@@ -53,35 +109,50 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
api_key=active_config.get("api_key"),
|
||||
api_base=active_config.get("api_base"),
|
||||
default_model=active_config.get("model"),
|
||||
extra_headers=active_config.get("extra_headers")
|
||||
extra_headers=active_config.get("extra_headers") or {},
|
||||
provider_name=active_config.get("provider")
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to initialize LLM provider: {e}")
|
||||
|
||||
# 4. Construct Prompt
|
||||
prompt = f"""You are an expert SQL generator.
|
||||
Given the following database schema for a {request.source} database:
|
||||
user_prompt = f"""
|
||||
### DATABASE SCHEMA ###
|
||||
{schema_str}
|
||||
|
||||
Write a SQL query to answer the following question:
|
||||
"{request.query}"
|
||||
### INPUTS ###
|
||||
User's Question: {request.query}
|
||||
Language: Chinese (Simplified)
|
||||
|
||||
Return ONLY the SQL query. Do not include any markdown formatting, explanations, or code blocks. Just the raw SQL string.
|
||||
Let's think step by step.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": SQL_GENERATION_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 5. Call LLM
|
||||
try:
|
||||
# provider.complete returns a string
|
||||
response = await provider.complete(prompt)
|
||||
sql_query = response.strip()
|
||||
# Remove potential markdown code blocks if the LLM ignores instructions
|
||||
if sql_query.startswith("```sql"):
|
||||
sql_query = sql_query[6:]
|
||||
if sql_query.startswith("```"):
|
||||
sql_query = sql_query[3:]
|
||||
if sql_query.endswith("```"):
|
||||
sql_query = sql_query[:-3]
|
||||
sql_query = sql_query.strip()
|
||||
response = await provider.chat(messages=messages)
|
||||
content = response.content.strip()
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
content = content.strip()
|
||||
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
@@ -100,7 +171,14 @@ Return ONLY the SQL query. Do not include any markdown formatting, explanations,
|
||||
# Let's assume we can just return the raw tuples for now or try to fetch column names.
|
||||
# For now, let's just return as list of lists/tuples if it's not a dict
|
||||
formatted_results = [list(row) for row in results]
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
# Only try to generate chart if we have results
|
||||
# Convert to list of dicts if possible, or pass as is
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results)
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user