feat: add streaming output
This commit is contained in:
+137
-132
@@ -375,130 +375,141 @@ Language: Chinese (Simplified)
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 5. Call LLM
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
content = content.strip()
|
||||
# 5. Call LLM & 6. Execute SQL (with Self-Correction Loop)
|
||||
MAX_SQL_EXEC_RETRIES = int(os.getenv("NL2SQL_MAX_EXEC_RETRIES", "2"))
|
||||
sql_query = ""
|
||||
formatted_results = []
|
||||
chart_response = None
|
||||
timeout_stage = "llm_generation"
|
||||
|
||||
for exec_attempt in range(MAX_SQL_EXEC_RETRIES + 1):
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
timeout_stage = "sql_execution"
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
|
||||
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
|
||||
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
|
||||
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
|
||||
# Let's check what connector.execute_query returns.
|
||||
# PostgresConnector returns list of dicts.
|
||||
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
|
||||
# Let's handle the ClickHouse case explicitly if possible or make it generic.
|
||||
|
||||
# If results is list of tuples/lists, we need headers.
|
||||
# Postgres returns list of dicts, so we are good.
|
||||
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
|
||||
# it returns `(result_rows, column_types_list)`.
|
||||
# So `results` here would be a tuple, not a list.
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# Likely ClickHouse (rows, columns)
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
llm_started = time.perf_counter()
|
||||
if exec_attempt == 0:
|
||||
await emit_progress("正在生成 SQL")
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
await emit_progress(f"正在尝试修复 SQL ({exec_attempt}/{MAX_SQL_EXEC_RETRIES})")
|
||||
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if request.generate_chart and formatted_results:
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error="LLM returned empty response")
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
content = content.strip()
|
||||
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
formatted_results = []
|
||||
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
break # Execution succeeded, break the retry loop
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
if exec_attempt < MAX_SQL_EXEC_RETRIES:
|
||||
await emit_progress(f"SQL 执行失败,准备自动修复 ({exec_attempt + 1}/{MAX_SQL_EXEC_RETRIES})")
|
||||
messages.append({"role": "assistant", "content": f"```json\n{{\"sql\": \"{sql_query}\"}}\n```"})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"The generated SQL failed to execute. Database error:\n{str(e)}\n\nPlease fix the SQL query to resolve this error and provide the corrected version following the exact same JSON format."
|
||||
})
|
||||
continue
|
||||
else:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed after {MAX_SQL_EXEC_RETRIES} retries: {e}")
|
||||
|
||||
# 7. Generate Chart
|
||||
if request.generate_chart and formatted_results:
|
||||
try:
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
timeout_stage = "chart_generation"
|
||||
@@ -506,16 +517,8 @@ Language: Chinese (Simplified)
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
if not chart_response or not chart_response.chart_spec:
|
||||
# Do not fallback automatically if the LLM explicitly decided not to or failed.
|
||||
# Just pass whatever it returned (or lack thereof)
|
||||
pass
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
except asyncio.TimeoutError:
|
||||
fallback_chart = ChartGenerationResponse(
|
||||
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
|
||||
chart_type="",
|
||||
@@ -523,6 +526,8 @@ Language: Chinese (Simplified)
|
||||
chart_spec=None,
|
||||
)
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
except Exception as e:
|
||||
pass # Ignore chart generation errors, return data only
|
||||
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
|
||||
Reference in New Issue
Block a user