something
This commit is contained in:
+2
-122
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
import time
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -136,129 +138,7 @@ CHART_EXAMPLES = """
|
||||
}
|
||||
"""
|
||||
|
||||
# Substrings that mark a column name as time-related (English and Chinese);
# matched against the lower-cased column name.
TEMPORAL_KEYWORDS = (
    "date",
    "time",
    "day",
    "month",
    "year",
    "日期",
    "时间",
    "月份",
    "年份",
)

# Substrings that, when present in the lower-cased user query, request a
# proportion/composition (pie) chart.
PIE_QUERY_KEYWORDS = (
    "占比",
    "构成",
    "比例",
    "份额",
    "分布",
    "pie",
)
|
||||
|
||||
|
||||
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
|
||||
for row in rows:
|
||||
value = row.get(key)
|
||||
if value is not None:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _is_number(value: Any) -> bool:
|
||||
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
||||
|
||||
|
||||
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
    """Guess whether a column holds date/time values.

    A column is considered temporal when its name contains one of
    ``TEMPORAL_KEYWORDS`` (case-insensitive substring match), or when the
    sample value is a string shaped like a date or timestamp:

    * ``YYYY-M[-D]`` or ``YYYY/M[/D]``
    * ``YYYY-MM-DD HH:MM[:SS]``
    * ``YYYYMMDD`` (eight digits with a plausible month and day)

    Args:
        key: Column name.
        sample_value: A representative value taken from the column.

    Returns:
        ``True`` if the name or the sample value looks temporal.
    """
    lowered = key.lower()
    if any(token in lowered for token in TEMPORAL_KEYWORDS):
        return True
    if not isinstance(sample_value, str):
        return False

    text = sample_value.strip()

    # Eight bare digits are only treated as a date when they read as a
    # plausible YYYYMMDD; otherwise numeric identifiers stored as strings
    # (e.g. "12345678") would be misclassified as temporal.
    if re.match(r"^\d{8}$", text):
        month = int(text[4:6])
        day = int(text[6:8])
        return 1 <= month <= 12 and 1 <= day <= 31

    patterns = (
        r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
        r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
    )
    return any(re.match(p, text) for p in patterns)
|
||||
|
||||
|
||||
def _encoding_title(field: str) -> str:
|
||||
return field.replace("_", " ").strip() or field
|
||||
|
||||
|
||||
def _fast_chart_response(
    reasoning: str,
    chart_type: str,
    mark_type: str,
    encoding: Dict[str, Any],
) -> ChartGenerationResponse:
    """Wrap a mark type and encoding into a visualizable chart response."""
    chart_spec = {
        "title": "查询结果可视化",
        "mark": {"type": mark_type},
        "encoding": encoding,
    }
    return ChartGenerationResponse(
        reasoning=reasoning,
        chart_type=chart_type,
        chart_spec=chart_spec,
        can_visualize=True,
    )


def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
    """Build a chart spec from column types alone, without calling the LLM.

    Columns are taken from the first row and classified as numeric, temporal
    or categorical using the first non-null value among the first 50 rows.
    The chart is then chosen as follows:

    * a temporal column exists: line chart of the first numeric column over
      the first temporal column;
    * otherwise a categorical column exists: pie chart when the query
      contains one of ``PIE_QUERY_KEYWORDS`` and the sampled rows hold at
      most 8 distinct category values, else a bar chart.

    Args:
        data: Query result rows.
        query: The user's natural-language query; may be empty or ``None``.

    Returns:
        A ``ChartGenerationResponse``, or ``None`` when the data is empty,
        the first row is not a dict or has no columns, there is no numeric
        column, or there is neither a temporal nor a categorical column.
    """
    if not data or not isinstance(data[0], dict):
        return None
    columns = list(data[0].keys())
    if not columns:
        return None

    # Only the first row is validated above; drop any other non-dict rows
    # from the sample so the ``.get`` lookups below cannot raise.
    sample_rows = [row for row in data[:50] if isinstance(row, dict)]

    numeric_cols: List[str] = []
    temporal_cols: List[str] = []
    categorical_cols: List[str] = []
    for col in columns:
        sample_value = _first_non_null(sample_rows, col)
        if sample_value is None:
            # Column is entirely null in the sample; it cannot be classified.
            continue
        if _is_number(sample_value):
            numeric_cols.append(col)
        elif _looks_temporal_field(col, sample_value):
            temporal_cols.append(col)
        else:
            categorical_cols.append(col)

    if not numeric_cols:
        return None
    val_col = numeric_cols[0]
    value_encoding = {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)}

    if temporal_cols:
        x_col = temporal_cols[0]
        return _fast_chart_response(
            reasoning="已基于字段类型快速生成趋势图",
            chart_type="line",
            mark_type="line",
            encoding={
                "x": {"field": x_col, "type": "temporal", "timeUnit": "yearmonth", "title": _encoding_title(x_col)},
                "y": value_encoding,
            },
        )

    if not categorical_cols:
        return None

    cat_col = categorical_cols[0]
    category_encoding = {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)}

    query_lower = (query or "").lower()
    use_pie = False
    # Scan the category values only when the query actually asks for a pie.
    if any(token in query_lower for token in PIE_QUERY_KEYWORDS):
        unique_values = {str(row.get(cat_col)) for row in sample_rows if row.get(cat_col) is not None}
        use_pie = len(unique_values) <= 8

    if use_pie:
        return _fast_chart_response(
            reasoning="已基于字段类型快速生成占比图",
            chart_type="pie",
            mark_type="arc",
            encoding={
                "theta": value_encoding,
                "color": category_encoding,
            },
        )

    return _fast_chart_response(
        reasoning="已基于字段类型快速生成对比图",
        chart_type="bar",
        mark_type="bar",
        encoding={
            "x": category_encoding,
            "y": value_encoding,
            # Separate dict so the spec does not alias the "x" encoding object.
            "color": dict(category_encoding),
        },
    )
|
||||
|
||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||
fast_result = _fast_generate_chart(data, query)
|
||||
if fast_result:
|
||||
return fast_result
|
||||
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -354,7 +355,13 @@ Language: Chinese (Simplified)
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
content = response.content.strip()
|
||||
|
||||
if response.finish_reason == "error":
|
||||
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
|
||||
Reference in New Issue
Block a user