diff --git a/backend/app/agent/chart.py b/backend/app/agent/chart.py index 7148c49..b7f941f 100644 --- a/backend/app/agent/chart.py +++ b/backend/app/agent/chart.py @@ -1,5 +1,7 @@ import json +import re from typing import List, Dict, Any, Optional +import time import sys from pathlib import Path @@ -136,129 +138,7 @@ CHART_EXAMPLES = """ } """ -TEMPORAL_KEYWORDS = ("date", "time", "day", "month", "year", "日期", "时间", "月份", "年份") -PIE_QUERY_KEYWORDS = ("占比", "构成", "比例", "份额", "分布", "pie") - - -def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any: - for row in rows: - value = row.get(key) - if value is not None: - return value - return None - - -def _is_number(value: Any) -> bool: - return isinstance(value, (int, float)) and not isinstance(value, bool) - - -def _looks_temporal_field(key: str, sample_value: Any) -> bool: - lowered = key.lower() - if any(token in lowered for token in TEMPORAL_KEYWORDS): - return True - if not isinstance(sample_value, str): - return False - text = sample_value.strip() - patterns = [ - r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$", - r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$", - r"^\d{8}$", - ] - return any(re.match(p, text) for p in patterns) - - -def _encoding_title(field: str) -> str: - return field.replace("_", " ").strip() or field - - -def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]: - if not data or not isinstance(data[0], dict): - return None - columns = list(data[0].keys()) - if not columns: - return None - numeric_cols: List[str] = [] - temporal_cols: List[str] = [] - categorical_cols: List[str] = [] - sample_rows = data[:50] - for col in columns: - sample_value = _first_non_null(sample_rows, col) - if sample_value is None: - continue - if _is_number(sample_value): - numeric_cols.append(col) - continue - if _looks_temporal_field(col, sample_value): - temporal_cols.append(col) - continue - categorical_cols.append(col) - if not numeric_cols: - return None - - title = "查询结果可视化" - query_lower = (query or "").lower() - - if temporal_cols: - x_col = temporal_cols[0] - y_col = numeric_cols[0] - chart_spec = { - "title": title, - "mark": {"type": "line"}, - "encoding": { - "x": {"field": x_col, "type": "temporal", "timeUnit": "yearmonth", "title": _encoding_title(x_col)}, - "y": {"field": y_col, "type": "quantitative", "title": _encoding_title(y_col)}, - }, - } - return ChartGenerationResponse( - reasoning="已基于字段类型快速生成趋势图", - chart_type="line", - chart_spec=chart_spec, - can_visualize=True, - ) - - if categorical_cols: - cat_col = categorical_cols[0] - val_col = numeric_cols[0] - unique_values = {str(row.get(cat_col)) for row in sample_rows if row.get(cat_col) is not None} - use_pie = len(unique_values) <= 8 and any(token in query_lower for token in PIE_QUERY_KEYWORDS) - if use_pie: - chart_spec = { - "title": title, - "mark": {"type": "arc"}, - "encoding": { - "theta": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)}, - "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)}, - }, - } - return ChartGenerationResponse( - reasoning="已基于字段类型快速生成占比图", - chart_type="pie", - chart_spec=chart_spec, - can_visualize=True, - ) - chart_spec = { - "title": title, - "mark": {"type": "bar"}, - "encoding": { - "x": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)}, - "y": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)}, - "color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)}, - }, - } - return ChartGenerationResponse( - reasoning="已基于字段类型快速生成对比图", - chart_type="bar", - chart_spec=chart_spec, - can_visualize=True, - ) - - return None - async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse: - fast_result = _fast_generate_chart(data, query) - if fast_result: - return fast_result - active_config = get_active_llm_config() if not active_config: diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 9c487e5..c620f94 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -4,6 +4,7 @@ import os import json import time import threading +import re from pathlib import Path from typing import List, Optional, Dict, Any, Callable, Awaitable from pydantic import BaseModel, Field @@ -354,7 +355,13 @@ Language: Chinese (Simplified) ), timeout=NL2SQL_LLM_TIMEOUT_SECONDS, ) - content = response.content.strip() + + if response.finish_reason == "error": + return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error") + + content = (response.content or "").strip() + if not content: + return NL2SQLResponse(sql="", result=[], error="LLM returned empty response") # Clean up code blocks if "```json" in content: diff --git a/backend/main.py b/backend/main.py index 2b17e70..78a3cfc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -278,7 +278,7 @@ async def nanobot_chat_stream(request: ChatRequest): continue nl2sql_result = await sql_task if nl2sql_result.error: - yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n" else: yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n" persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)