# Rule-based "fast path" chart selection for backend/app/agent/chart.py.
# generate_chart() tries _fast_generate_chart() first and only continues to its
# LLM-backed path when this returns None.
# NOTE(review): the helpers below call `re` at run time; confirm the module
# already imports it at the top of the file.

# Substrings (English and Chinese) that mark a column name as time-like.
TEMPORAL_KEYWORDS = ("date", "time", "day", "month", "year", "日期", "时间", "月份", "年份")
# Query words that signal the user is asking for a share/composition view.
PIE_QUERY_KEYWORDS = ("占比", "构成", "比例", "份额", "分布", "pie")

# Number of leading rows inspected when inferring each column's kind.
_FAST_CHART_SAMPLE_ROWS = 50
# A pie chart is only chosen when the category column has at most this many
# distinct values across the whole result set.
_PIE_MAX_CATEGORIES = 8
# Time unit used when a column is temporal by name only (no parsable sample).
_DEFAULT_TIME_UNIT = "yearmonth"

# (regex, Vega-Lite timeUnit) pairs, finest granularity first. The time unit
# follows the granularity of the value so that daily / timestamp data is not
# collapsed onto month buckets.
_TEMPORAL_VALUE_UNITS = (
    # "2024-01-15 10:30[:00]" and ISO-8601 "2024-01-15T10:30:00[.123][Z|+08:00]"
    (
        r"^\d{4}-\d{2}-\d{2}(?:T|\s+)\d{2}:\d{2}(:\d{2}(\.\d+)?)?(Z|[+-]\d{2}:?\d{2})?$",
        "yearmonthdatehoursminutesseconds",
    ),
    # "2024-01-15" / "2024/1/5"
    (r"^\d{4}[-/]\d{1,2}[-/]\d{1,2}$", "yearmonthdate"),
    # "20240115"
    (r"^\d{8}$", "yearmonthdate"),
    # "2024-01" / "2024/1"
    (r"^\d{4}[-/]\d{1,2}$", "yearmonth"),
)


def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
    """Return the first non-None value stored under *key* in *rows*.

    Rows that are not dicts are skipped instead of raising, so one malformed
    row cannot abort chart generation. Returns None when no value is found.
    """
    for row in rows:
        if not isinstance(row, dict):
            continue
        value = row.get(key)
        if value is not None:
            return value
    return None


def _is_number(value: Any) -> bool:
    """Return True for int/float values; bool is deliberately excluded."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _temporal_time_unit(sample_value: Any) -> Optional[str]:
    """Return the Vega-Lite timeUnit matching a date-like string.

    Returns None when *sample_value* is not a string or does not match any of
    the recognised date / timestamp layouts.
    """
    if not isinstance(sample_value, str):
        return None
    text = sample_value.strip()
    for pattern, unit in _TEMPORAL_VALUE_UNITS:
        if re.match(pattern, text):
            return unit
    return None


def _looks_temporal_field(key: str, sample_value: Any) -> bool:
    """Heuristically decide whether a column holds dates or times.

    A column qualifies when its name contains one of TEMPORAL_KEYWORDS
    (case-insensitive substring match) or when the sample value is a string in
    a recognised date / timestamp layout.
    """
    # str() keeps non-string column keys from raising on .lower().
    lowered = str(key).lower()
    if any(token in lowered for token in TEMPORAL_KEYWORDS):
        return True
    return _temporal_time_unit(sample_value) is not None


def _encoding_title(field: str) -> str:
    """Turn a column name into an axis/legend title (underscores -> spaces).

    Falls back to the original name when the cleaned title would be empty.
    """
    name = str(field)
    return name.replace("_", " ").strip() or name


def _count_distinct_capped(rows: List[Dict[str, Any]], key: str, cap: int) -> int:
    """Count distinct non-None values of *key*, stopping once *cap* is exceeded.

    The returned count is exact up to *cap*; any larger result only means
    "more than cap". Values are compared by their str() form.
    """
    seen = set()
    for row in rows:
        if not isinstance(row, dict):
            continue
        value = row.get(key)
        if value is None:
            continue
        seen.add(str(value))
        if len(seen) > cap:
            break
    return len(seen)


def _plan_fast_chart(data: List[Dict[str, Any]], query: str) -> Optional[Dict[str, Any]]:
    """Pick a chart for *data* from column kinds alone, without calling the LLM.

    Columns come from the first row's keys and are classified from the first
    non-null value among the leading _FAST_CHART_SAMPLE_ROWS rows as numeric,
    temporal or categorical. The first numeric column is plotted against the
    first temporal column (line chart) or, failing that, the first categorical
    column (pie when the query asks for a share and there are few categories,
    otherwise bar).

    Returns:
        A dict with keys ``reasoning``, ``chart_type``, ``chart_spec`` and
        ``can_visualize``, or None when no rule applies (empty data, first row
        not a dict, no numeric column, or no temporal/categorical column).
    """
    if not data or not isinstance(data[0], dict):
        return None
    columns = list(data[0].keys())
    if not columns:
        return None

    sample_rows = data[:_FAST_CHART_SAMPLE_ROWS]
    numeric_cols: List[str] = []
    temporal_cols: List[str] = []
    categorical_cols: List[str] = []
    for col in columns:
        sample_value = _first_non_null(sample_rows, col)
        if sample_value is None:
            # Entirely null in the sample: nothing to infer, ignore the column.
            continue
        if _is_number(sample_value):
            numeric_cols.append(col)
        elif _looks_temporal_field(col, sample_value):
            temporal_cols.append(col)
        else:
            categorical_cols.append(col)
    if not numeric_cols:
        return None

    title = "查询结果可视化"
    query_lower = (query or "").lower()
    val_col = numeric_cols[0]

    if temporal_cols:
        x_col = temporal_cols[0]
        # Match the axis time unit to the data's granularity; a column that is
        # temporal by name only keeps the month-level default.
        time_unit = _temporal_time_unit(_first_non_null(sample_rows, x_col)) or _DEFAULT_TIME_UNIT
        return {
            "reasoning": "已基于字段类型快速生成趋势图",
            "chart_type": "line",
            "chart_spec": {
                "title": title,
                "mark": {"type": "line"},
                "encoding": {
                    "x": {"field": x_col, "type": "temporal", "timeUnit": time_unit, "title": _encoding_title(x_col)},
                    "y": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
                },
            },
            "can_visualize": True,
        }

    if categorical_cols:
        cat_col = categorical_cols[0]
        wants_share = any(token in query_lower for token in PIE_QUERY_KEYWORDS)
        # Cardinality is checked on the full result set, not just the sample,
        # so categories that first appear after the sampled rows still count.
        use_pie = wants_share and _count_distinct_capped(data, cat_col, _PIE_MAX_CATEGORIES) <= _PIE_MAX_CATEGORIES
        if use_pie:
            return {
                "reasoning": "已基于字段类型快速生成占比图",
                "chart_type": "pie",
                "chart_spec": {
                    "title": title,
                    "mark": {"type": "arc"},
                    "encoding": {
                        "theta": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
                        "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
                    },
                },
                "can_visualize": True,
            }
        return {
            "reasoning": "已基于字段类型快速生成对比图",
            "chart_type": "bar",
            "chart_spec": {
                "title": title,
                "mark": {"type": "bar"},
                "encoding": {
                    "x": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
                    "y": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
                    "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
                },
            },
            "can_visualize": True,
        }

    return None


def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional["ChartGenerationResponse"]:
    """Build a chart response from simple column-type rules.

    Args:
        data: Query result rows, one dict per row keyed by column name.
        query: The user's natural-language question; only used to detect a
            request for a share/composition (pie) chart.

    Returns:
        A ChartGenerationResponse when a rule applies, otherwise None so the
        caller can fall back to another strategy.
    """
    plan = _plan_fast_chart(data, query)
    if plan is None:
        return None
    return ChartGenerationResponse(**plan)
ChartGenerationResponse: + fast_result = _fast_generate_chart(data, query) + if fast_result: + return fast_result + active_config = get_active_llm_config() if not active_config: diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index ee3eb8b..9c487e5 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -1,3 +1,4 @@ +import asyncio import sys import os import json @@ -33,6 +34,9 @@ MAX_UPLOAD_CACHE_ITEMS = 8 NL2SQL_MAX_TOKENS = 900 NL2SQL_TEMPERATURE = 0.1 NL2SQL_REASONING_EFFORT = "low" +NL2SQL_LLM_TIMEOUT_SECONDS = 60*5 +NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60 +NL2SQL_CHART_TIMEOUT_SECONDS = 45 _schema_cache: Dict[str, Dict[str, Any]] = {} _connection_cache: Dict[str, Dict[str, Any]] = {} @@ -341,11 +345,14 @@ Language: Chinese (Simplified) try: llm_started = time.perf_counter() await emit_progress("正在生成 SQL") - response = await provider.chat( - messages=messages, - max_tokens=NL2SQL_MAX_TOKENS, - temperature=NL2SQL_TEMPERATURE, - reasoning_effort=NL2SQL_REASONING_EFFORT, + response = await asyncio.wait_for( + provider.chat( + messages=messages, + max_tokens=NL2SQL_MAX_TOKENS, + temperature=NL2SQL_TEMPERATURE, + reasoning_effort=NL2SQL_REASONING_EFFORT, + ), + timeout=NL2SQL_LLM_TIMEOUT_SECONDS, ) content = response.content.strip() @@ -366,19 +373,30 @@ Language: Chinese (Simplified) sql_query = content await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") + except asyncio.TimeoutError: + return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s") except Exception as e: return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") # 6. 
Execute SQL try: + timeout_stage = "sql_execution" sql_exec_started = time.perf_counter() await emit_progress("正在执行 SQL 查询") if request.source == "upload": if upload_df is None: upload_df = _get_upload_payload(request.file_url)["df"] - formatted_results = _execute_upload_sql(sql_query, upload_df) + timeout_stage = "sql_execution" + formatted_results = await asyncio.wait_for( + asyncio.to_thread(_execute_upload_sql, sql_query, upload_df), + timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, + ) else: - results = connector.execute_query(sql_query) + timeout_stage = "sql_execution" + results = await asyncio.wait_for( + asyncio.to_thread(connector.execute_query, sql_query), + timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, + ) # Format results formatted_results = [] @@ -418,10 +436,18 @@ Language: Chinese (Simplified) if request.generate_chart and formatted_results: chart_started = time.perf_counter() await emit_progress("正在生成可视化方案") - chart_response = await generate_chart(formatted_results, request.query) + timeout_stage = "chart_generation" + chart_response = await asyncio.wait_for( + generate_chart(formatted_results, request.query), + timeout=NL2SQL_CHART_TIMEOUT_SECONDS, + ) await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)") await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) + except asyncio.TimeoutError: + if timeout_stage == "chart_generation": + return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s") + return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s") except Exception as e: return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")