add timeout

This commit is contained in:
qixinbo
2026-03-17 22:09:47 +08:00
parent 49d38692cd
commit f03a653112
2 changed files with 156 additions and 8 deletions
+122
View File
@@ -136,7 +136,129 @@ CHART_EXAMPLES = """
}
"""
# Substrings that mark a column name as date/time-like.  They are matched
# against the lowercased column name, in English and Chinese
# (date, time, month, year).
TEMPORAL_KEYWORDS = (
    "date",
    "time",
    "day",
    "month",
    "year",
    "日期",
    "时间",
    "月份",
    "年份",
)

# Substrings of the lowercased user query that ask for a share/composition
# view (proportion, composition, ratio, share, distribution, or "pie").
PIE_QUERY_KEYWORDS = (
    "占比",
    "构成",
    "比例",
    "份额",
    "分布",
    "pie",
)
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
    """Return the first value stored under *key* that is not ``None``.

    Rows are scanned in order and the scan stops at the first hit.  Falsy
    values such as ``0`` or ``""`` count as present; only ``None`` (or a
    missing key) is skipped.

    Args:
        rows: Row dictionaries to inspect.
        key: Column name to look up in each row.

    Returns:
        The first non-``None`` value, or ``None`` when no row has one.
    """
    candidates = (record.get(key) for record in rows)
    return next((item for item in candidates if item is not None), None)
def _is_number(value: Any) -> bool:
    """Tell whether *value* is an ``int`` or ``float`` but not a ``bool``.

    ``bool`` is a subclass of ``int``, so it is rejected explicitly before
    the numeric check.
    """
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float))
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
    """Guess whether a column holds date/time data.

    A column qualifies when its lowercased name contains any
    ``TEMPORAL_KEYWORDS`` entry, or when its sample value is a string whose
    stripped text has one of the recognised date shapes.

    Args:
        key: Column name.
        sample_value: A representative value from the column.

    Returns:
        ``True`` when either the name or the sample looks temporal.
    """
    name = key.lower()
    for keyword in TEMPORAL_KEYWORDS:
        if keyword in name:
            return True

    # Without a keyword hit, only string samples can still qualify.
    if isinstance(sample_value, str):
        candidate = sample_value.strip()
        date_shapes = (
            # YYYY-M or YYYY/M, optionally followed by -D or /D.
            r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
            # YYYY-MM-DD HH:MM with optional :SS.
            r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
            # Exactly eight digits.
            r"^\d{8}$",
        )
        for shape in date_shapes:
            if re.match(shape, candidate):
                return True
    return False
def _encoding_title(field: str) -> str:
    """Turn a column name into an axis/legend title.

    Underscores become spaces and surrounding whitespace is trimmed.  If
    nothing is left after that, the original name is returned unchanged.
    """
    spaced = field.replace("_", " ").strip()
    return spaced if spaced else field
def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
    """Build a chart spec from column types alone.

    Columns are taken from the first row and classified using the first
    non-null value found in the first 50 rows: numbers are numeric, values
    accepted by ``_looks_temporal_field`` are temporal, and everything else
    is categorical.  Columns with no non-null sample are ignored.

    Chart selection:
      * a temporal column exists -> line chart of the measure over it;
      * otherwise, a categorical column exists -> pie chart when the query
        contains a ``PIE_QUERY_KEYWORDS`` token and the column has at most
        eight distinct non-null values, else a bar chart.

    The measure is the first numeric column whose name does not look
    temporal, falling back to the first numeric column.  Numeric columns are
    classified before the temporal check, so an integer ``year`` or
    ``month`` column lands in the numeric list; without this preference it
    would be plotted as the quantity whenever it precedes the real measure.

    Args:
        data: Result rows; each row maps column name to value.
        query: The user's natural-language question (may be empty or None).

    Returns:
        A ``ChartGenerationResponse`` with ``can_visualize=True``, or
        ``None`` when no heuristic applies (empty data, first row not a
        dict, no columns, no numeric column, or neither a temporal nor a
        categorical column).
    """
    if not data or not isinstance(data[0], dict):
        return None
    columns = list(data[0].keys())
    if not columns:
        return None

    # Type inference only looks at a bounded prefix of the result set.
    sample_rows = data[:50]
    numeric_cols: List[str] = []
    temporal_cols: List[str] = []
    categorical_cols: List[str] = []
    for col in columns:
        sample_value = _first_non_null(sample_rows, col)
        if sample_value is None:
            continue
        if _is_number(sample_value):
            numeric_cols.append(col)
        elif _looks_temporal_field(col, sample_value):
            temporal_cols.append(col)
        else:
            categorical_cols.append(col)
    if not numeric_cols:
        return None

    # With a non-string sample, _looks_temporal_field decides on the column
    # name only, which is exactly the check needed to skip "year"/"month"
    # style integers when picking the quantity to plot.
    measure_col = next(
        (col for col in numeric_cols if not _looks_temporal_field(col, None)),
        numeric_cols[0],
    )

    title = "查询结果可视化"
    query_lower = (query or "").lower()

    if temporal_cols:
        x_col = temporal_cols[0]
        chart_spec = {
            "title": title,
            "mark": {"type": "line"},
            "encoding": {
                "x": {"field": x_col, "type": "temporal", "timeUnit": "yearmonth", "title": _encoding_title(x_col)},
                "y": {"field": measure_col, "type": "quantitative", "title": _encoding_title(measure_col)},
            },
        }
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成趋势图",
            chart_type="line",
            chart_spec=chart_spec,
            can_visualize=True,
        )

    if categorical_cols:
        cat_col = categorical_cols[0]
        max_pie_slices = 8
        use_pie = False
        if any(token in query_lower for token in PIE_QUERY_KEYWORDS):
            # Count distinct categories over every row, not just the sample,
            # so a large result cannot slip through as a pie.  Stop as soon
            # as the limit is exceeded.
            distinct_values = set()
            for row in data:
                cell = row.get(cat_col)
                if cell is None:
                    continue
                distinct_values.add(str(cell))
                if len(distinct_values) > max_pie_slices:
                    break
            use_pie = len(distinct_values) <= max_pie_slices

        if use_pie:
            chart_spec = {
                "title": title,
                "mark": {"type": "arc"},
                "encoding": {
                    "theta": {"field": measure_col, "type": "quantitative", "title": _encoding_title(measure_col)},
                    "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
                },
            }
            return ChartGenerationResponse(
                reasoning="已基于字段类型快速生成占比图",
                chart_type="pie",
                chart_spec=chart_spec,
                can_visualize=True,
            )

        chart_spec = {
            "title": title,
            "mark": {"type": "bar"},
            "encoding": {
                "x": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
                "y": {"field": measure_col, "type": "quantitative", "title": _encoding_title(measure_col)},
                "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
            },
        }
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成对比图",
            chart_type="bar",
            chart_spec=chart_spec,
            can_visualize=True,
        )

    return None
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
fast_result = _fast_generate_chart(data, query)
if fast_result:
return fast_result
active_config = get_active_llm_config()
if not active_config:
+34 -8
View File
@@ -1,3 +1,4 @@
import asyncio
import sys
import os
import json
@@ -33,6 +34,9 @@ MAX_UPLOAD_CACHE_ITEMS = 8
# Generation settings passed to provider.chat() when producing SQL.
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"

# Per-stage deadlines in seconds, enforced with asyncio.wait_for().
NL2SQL_LLM_TIMEOUT_SECONDS = 5 * 60  # LLM call that generates the SQL
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60  # running the generated SQL
NL2SQL_CHART_TIMEOUT_SECONDS = 45  # chart generation for the result rows

# Module-level caches, empty at import time.
# NOTE(review): key scheme and eviction are not visible in this excerpt.
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -341,11 +345,14 @@ Language: Chinese (Simplified)
try:
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
content = response.content.strip()
@@ -366,19 +373,30 @@ Language: Chinese (Simplified)
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL
try:
timeout_stage = "sql_execution"
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"]
formatted_results = _execute_upload_sql(sql_query, upload_df)
timeout_stage = "sql_execution"
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
results = connector.execute_query(sql_query)
timeout_stage = "sql_execution"
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
# Format results
formatted_results = []
@@ -418,10 +436,18 @@ Language: Chinese (Simplified)
if request.generate_chart and formatted_results:
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
chart_response = await generate_chart(formatted_results, request.query)
timeout_stage = "chart_generation"
chart_response = await asyncio.wait_for(
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError:
if timeout_stage == "chart_generation":
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")