add timeout
This commit is contained in:
@@ -136,7 +136,129 @@ CHART_EXAMPLES = """
|
||||
}
|
||||
"""
|
||||
|
||||
TEMPORAL_KEYWORDS = ("date", "time", "day", "month", "year", "日期", "时间", "月份", "年份")
|
||||
PIE_QUERY_KEYWORDS = ("占比", "构成", "比例", "份额", "分布", "pie")
|
||||
|
||||
|
||||
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
|
||||
for row in rows:
|
||||
value = row.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _is_number(value: Any) -> bool:
|
||||
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
|
||||
lowered = key.lower()
|
||||
if any(token in lowered for token in TEMPORAL_KEYWORDS):
|
||||
return True
|
||||
if not isinstance(sample_value, str):
|
||||
return False
|
||||
text = sample_value.strip()
|
||||
patterns = [
|
||||
r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
|
||||
r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
|
||||
r"^\d{8}$",
|
||||
]
|
||||
return any(re.match(p, text) for p in patterns)
|
||||
|
||||
|
||||
def _encoding_title(field: str) -> str:
|
||||
return field.replace("_", " ").strip() or field
|
||||
|
||||
|
||||
def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
|
||||
if not data or not isinstance(data[0], dict):
|
||||
return None
|
||||
columns = list(data[0].keys())
|
||||
if not columns:
|
||||
return None
|
||||
numeric_cols: List[str] = []
|
||||
temporal_cols: List[str] = []
|
||||
categorical_cols: List[str] = []
|
||||
sample_rows = data[:50]
|
||||
for col in columns:
|
||||
sample_value = _first_non_null(sample_rows, col)
|
||||
if sample_value is None:
|
||||
continue
|
||||
if _is_number(sample_value):
|
||||
numeric_cols.append(col)
|
||||
continue
|
||||
if _looks_temporal_field(col, sample_value):
|
||||
temporal_cols.append(col)
|
||||
continue
|
||||
categorical_cols.append(col)
|
||||
if not numeric_cols:
|
||||
return None
|
||||
|
||||
title = "查询结果可视化"
|
||||
query_lower = (query or "").lower()
|
||||
|
||||
if temporal_cols:
|
||||
x_col = temporal_cols[0]
|
||||
y_col = numeric_cols[0]
|
||||
chart_spec = {
|
||||
"title": title,
|
||||
"mark": {"type": "line"},
|
||||
"encoding": {
|
||||
"x": {"field": x_col, "type": "temporal", "timeUnit": "yearmonth", "title": _encoding_title(x_col)},
|
||||
"y": {"field": y_col, "type": "quantitative", "title": _encoding_title(y_col)},
|
||||
},
|
||||
}
|
||||
return ChartGenerationResponse(
|
||||
reasoning="已基于字段类型快速生成趋势图",
|
||||
chart_type="line",
|
||||
chart_spec=chart_spec,
|
||||
can_visualize=True,
|
||||
)
|
||||
|
||||
if categorical_cols:
|
||||
cat_col = categorical_cols[0]
|
||||
val_col = numeric_cols[0]
|
||||
unique_values = {str(row.get(cat_col)) for row in sample_rows if row.get(cat_col) is not None}
|
||||
use_pie = len(unique_values) <= 8 and any(token in query_lower for token in PIE_QUERY_KEYWORDS)
|
||||
if use_pie:
|
||||
chart_spec = {
|
||||
"title": title,
|
||||
"mark": {"type": "arc"},
|
||||
"encoding": {
|
||||
"theta": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
|
||||
"color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
||||
},
|
||||
}
|
||||
return ChartGenerationResponse(
|
||||
reasoning="已基于字段类型快速生成占比图",
|
||||
chart_type="pie",
|
||||
chart_spec=chart_spec,
|
||||
can_visualize=True,
|
||||
)
|
||||
chart_spec = {
|
||||
"title": title,
|
||||
"mark": {"type": "bar"},
|
||||
"encoding": {
|
||||
"x": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
||||
"y": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
|
||||
"color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
||||
},
|
||||
}
|
||||
return ChartGenerationResponse(
|
||||
reasoning="已基于字段类型快速生成对比图",
|
||||
chart_type="bar",
|
||||
chart_spec=chart_spec,
|
||||
can_visualize=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||
fast_result = _fast_generate_chart(data, query)
|
||||
if fast_result:
|
||||
return fast_result
|
||||
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
@@ -33,6 +34,9 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
NL2SQL_MAX_TOKENS = 900
|
||||
NL2SQL_TEMPERATURE = 0.1
|
||||
NL2SQL_REASONING_EFFORT = "low"
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
||||
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -341,11 +345,14 @@ Language: Chinese (Simplified)
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
content = response.content.strip()
|
||||
|
||||
@@ -366,19 +373,30 @@ Language: Chinese (Simplified)
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
results = connector.execute_query(sql_query)
|
||||
timeout_stage = "sql_execution"
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
@@ -418,10 +436,18 @@ Language: Chinese (Simplified)
|
||||
if request.generate_chart and formatted_results:
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
timeout_stage = "chart_generation"
|
||||
chart_response = await asyncio.wait_for(
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user