add timeout
This commit is contained in:
@@ -136,7 +136,129 @@ CHART_EXAMPLES = """
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Substrings that mark a column name as time-like; _looks_temporal_field
# checks them against the lower-cased column name.
TEMPORAL_KEYWORDS = (
    "date",
    "time",
    "day",
    "month",
    "year",
    "日期",
    "时间",
    "月份",
    "年份",
)

# Substrings of the (lower-cased) user query that request a share /
# composition style chart; _fast_generate_chart uses them to pick a pie chart.
PIE_QUERY_KEYWORDS = (
    "占比",
    "构成",
    "比例",
    "份额",
    "分布",
    "pie",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
|
||||||
|
for row in rows:
|
||||||
|
value = row.get(key)
|
||||||
|
if value is not None:
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_number(value: Any) -> bool:
|
||||||
|
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
    """Heuristically decide whether a column holds date/time data.

    The column counts as temporal when its lower-cased name contains any of
    ``TEMPORAL_KEYWORDS``. Otherwise, a string sample is stripped and tested
    against a few date-like shapes: ``YYYY-M[-D]`` (``-`` or ``/`` separated),
    ``YYYY-MM-DD HH:MM[:SS]`` and a bare run of eight digits. Non-string
    samples with a non-matching name are never temporal.
    """
    name = key.lower()
    for keyword in TEMPORAL_KEYWORDS:
        if keyword in name:
            return True

    if isinstance(sample_value, str):
        candidate = sample_value.strip()
        date_shapes = (
            r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
            r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
            r"^\d{8}$",
        )
        for shape in date_shapes:
            if re.match(shape, candidate) is not None:
                return True

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _encoding_title(field: str) -> str:
|
||||||
|
return field.replace("_", " ").strip() or field
|
||||||
|
|
||||||
|
|
||||||
|
def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
    """Build a chart spec from column types alone.

    Columns are taken from the first row and classified as numeric, temporal
    or categorical using the first non-null value found in the first 50 rows.
    Columns with no non-null sample are ignored.

    Selection rules, in order:
      * no usable rows/columns, or no numeric column -> ``None``
      * a temporal column exists -> line chart (first temporal vs. first numeric)
      * a categorical column exists -> pie chart when the sampled rows hold at
        most 8 distinct category values and the query contains one of
        ``PIE_QUERY_KEYWORDS``; otherwise a bar chart
      * anything else -> ``None``

    Args:
        data: Result rows; the first row must be a dict.
        query: The user's query text; ``None`` or empty is treated as "".

    Returns:
        A ``ChartGenerationResponse`` with ``can_visualize=True``, or ``None``
        when no rule applies.
    """
    if not data:
        return None
    head = data[0]
    if not isinstance(head, dict):
        return None
    field_names = list(head.keys())
    if not field_names:
        return None

    # Only a bounded prefix of the rows is inspected.
    preview = data[:50]
    buckets: Dict[str, List[str]] = {"numeric": [], "temporal": [], "categorical": []}
    for name in field_names:
        probe = _first_non_null(preview, name)
        if probe is None:
            continue
        if _is_number(probe):
            bucket = "numeric"
        elif _looks_temporal_field(name, probe):
            bucket = "temporal"
        else:
            bucket = "categorical"
        buckets[bucket].append(name)

    measures = buckets["numeric"]
    if not measures:
        return None
    measure = measures[0]

    chart_title = "查询结果可视化"
    normalized_query = (query or "").lower()

    def channel(field: str, kind: str, **extra: Any) -> Dict[str, Any]:
        # Key order is field, type, any extras, then title.
        return {"field": field, "type": kind, **extra, "title": _encoding_title(field)}

    timelines = buckets["temporal"]
    if timelines:
        time_field = timelines[0]
        line_spec = {
            "title": chart_title,
            "mark": {"type": "line"},
            "encoding": {
                "x": channel(time_field, "temporal", timeUnit="yearmonth"),
                "y": channel(measure, "quantitative"),
            },
        }
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成趋势图",
            chart_type="line",
            chart_spec=line_spec,
            can_visualize=True,
        )

    dimensions = buckets["categorical"]
    if not dimensions:
        return None
    dimension = dimensions[0]

    # Distinct category labels within the sampled rows, compared as strings.
    distinct = set()
    for row in preview:
        cell = row.get(dimension)
        if cell is not None:
            distinct.add(str(cell))

    if len(distinct) <= 8 and any(word in normalized_query for word in PIE_QUERY_KEYWORDS):
        pie_spec = {
            "title": chart_title,
            "mark": {"type": "arc"},
            "encoding": {
                "theta": channel(measure, "quantitative"),
                "color": channel(dimension, "nominal"),
            },
        }
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成占比图",
            chart_type="pie",
            chart_spec=pie_spec,
            can_visualize=True,
        )

    bar_spec = {
        "title": chart_title,
        "mark": {"type": "bar"},
        "encoding": {
            "x": channel(dimension, "nominal"),
            "y": channel(measure, "quantitative"),
            "color": channel(dimension, "nominal"),
        },
    }
    return ChartGenerationResponse(
        reasoning="已基于字段类型快速生成对比图",
        chart_type="bar",
        chart_spec=bar_spec,
        can_visualize=True,
    )
|
||||||
|
|
||||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||||
|
fast_result = _fast_generate_chart(data, query)
|
||||||
|
if fast_result:
|
||||||
|
return fast_result
|
||||||
|
|
||||||
active_config = get_active_llm_config()
|
active_config = get_active_llm_config()
|
||||||
|
|
||||||
if not active_config:
|
if not active_config:
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
@@ -33,6 +34,9 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
|||||||
NL2SQL_MAX_TOKENS = 900
|
NL2SQL_MAX_TOKENS = 900
|
||||||
NL2SQL_TEMPERATURE = 0.1
|
NL2SQL_TEMPERATURE = 0.1
|
||||||
NL2SQL_REASONING_EFFORT = "low"
|
NL2SQL_REASONING_EFFORT = "low"
|
||||||
|
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
||||||
|
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||||
|
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
||||||
|
|
||||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
@@ -341,11 +345,14 @@ Language: Chinese (Simplified)
|
|||||||
try:
|
try:
|
||||||
llm_started = time.perf_counter()
|
llm_started = time.perf_counter()
|
||||||
await emit_progress("正在生成 SQL")
|
await emit_progress("正在生成 SQL")
|
||||||
response = await provider.chat(
|
response = await asyncio.wait_for(
|
||||||
messages=messages,
|
provider.chat(
|
||||||
max_tokens=NL2SQL_MAX_TOKENS,
|
messages=messages,
|
||||||
temperature=NL2SQL_TEMPERATURE,
|
max_tokens=NL2SQL_MAX_TOKENS,
|
||||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
temperature=NL2SQL_TEMPERATURE,
|
||||||
|
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||||
|
),
|
||||||
|
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
content = response.content.strip()
|
content = response.content.strip()
|
||||||
|
|
||||||
@@ -366,19 +373,30 @@ Language: Chinese (Simplified)
|
|||||||
sql_query = content
|
sql_query = content
|
||||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||||
|
|
||||||
# 6. Execute SQL
|
# 6. Execute SQL
|
||||||
try:
|
try:
|
||||||
|
timeout_stage = "sql_execution"
|
||||||
sql_exec_started = time.perf_counter()
|
sql_exec_started = time.perf_counter()
|
||||||
await emit_progress("正在执行 SQL 查询")
|
await emit_progress("正在执行 SQL 查询")
|
||||||
if request.source == "upload":
|
if request.source == "upload":
|
||||||
if upload_df is None:
|
if upload_df is None:
|
||||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||||
formatted_results = _execute_upload_sql(sql_query, upload_df)
|
timeout_stage = "sql_execution"
|
||||||
|
formatted_results = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||||
|
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
results = connector.execute_query(sql_query)
|
timeout_stage = "sql_execution"
|
||||||
|
results = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(connector.execute_query, sql_query),
|
||||||
|
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
# Format results
|
# Format results
|
||||||
formatted_results = []
|
formatted_results = []
|
||||||
@@ -418,10 +436,18 @@ Language: Chinese (Simplified)
|
|||||||
if request.generate_chart and formatted_results:
|
if request.generate_chart and formatted_results:
|
||||||
chart_started = time.perf_counter()
|
chart_started = time.perf_counter()
|
||||||
await emit_progress("正在生成可视化方案")
|
await emit_progress("正在生成可视化方案")
|
||||||
chart_response = await generate_chart(formatted_results, request.query)
|
timeout_stage = "chart_generation"
|
||||||
|
chart_response = await asyncio.wait_for(
|
||||||
|
generate_chart(formatted_results, request.query),
|
||||||
|
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||||
|
)
|
||||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||||
|
|
||||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if timeout_stage == "chart_generation":
|
||||||
|
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
||||||
|
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||||
|
|||||||
Reference in New Issue
Block a user