add timeout
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
@@ -33,6 +34,9 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
NL2SQL_MAX_TOKENS = 900
|
||||
NL2SQL_TEMPERATURE = 0.1
|
||||
NL2SQL_REASONING_EFFORT = "low"
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
||||
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -341,11 +345,14 @@ Language: Chinese (Simplified)
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
content = response.content.strip()
|
||||
|
||||
@@ -366,19 +373,30 @@ Language: Chinese (Simplified)
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
timeout_stage = "sql_execution"
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
@@ -418,10 +436,18 @@ Language: Chinese (Simplified)
|
||||
if request.generate_chart and formatted_results:
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
timeout_stage = "chart_generation"
|
||||
chart_response = await asyncio.wait_for(
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user