something

This commit is contained in:
qixinbo
2026-03-17 22:38:10 +08:00
parent f03a653112
commit cc93e0ea5d
3 changed files with 11 additions and 124 deletions
+2 -122
View File
@@ -1,5 +1,7 @@
import json import json
import re
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
import time
import sys import sys
from pathlib import Path from pathlib import Path
@@ -136,129 +138,7 @@ CHART_EXAMPLES = """
} }
""" """
# Substrings that mark a column as temporal when they occur anywhere in the
# lower-cased column name (see _looks_temporal_field). The Chinese entries
# mean "date", "time", "month" and "year".
TEMPORAL_KEYWORDS = ("date", "time", "day", "month", "year", "日期", "时间", "月份", "年份")
# Substrings searched for in the lower-cased user query; a hit is one of the
# two conditions _fast_generate_chart requires before choosing a pie chart.
# The Chinese entries roughly mean "proportion", "composition", "ratio",
# "share" and "distribution".
PIE_QUERY_KEYWORDS = ("占比", "构成", "比例", "份额", "分布", "pie")
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
for row in rows:
value = row.get(key)
if value is not None:
return value
return None
def _is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
    """Guess whether a column holds date/time data.

    The column counts as temporal when either:

    * its lower-cased name contains any entry of ``TEMPORAL_KEYWORDS``, or
    * ``sample_value`` is a string which, after stripping surrounding
      whitespace, matches one of three shapes: ``YYYY-M[-D]`` (``-`` or ``/``
      separators), ``YYYY-MM-DD HH:MM[:SS]``, or exactly eight digits.

    Args:
        key: Column name.
        sample_value: A representative value taken from that column.

    Returns:
        ``True`` when the column looks temporal, ``False`` otherwise.
    """
    name = key.lower()
    for keyword in TEMPORAL_KEYWORDS:
        if keyword in name:
            return True

    # Only string samples can be matched against the date-like shapes.
    if isinstance(sample_value, str):
        candidate = sample_value.strip()
        date_shapes = (
            r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
            r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
            r"^\d{8}$",
        )
        for shape in date_shapes:
            if re.match(shape, candidate):
                return True
    return False
def _encoding_title(field: str) -> str:
return field.replace("_", " ").strip() or field
def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
    """Build a chart spec from column types alone.

    Columns are taken from the keys of the first row. Each one is classified
    from its first non-null value within the first 50 rows as numeric,
    temporal or categorical; columns with no non-null sample are ignored.
    The first numeric column is always the measure.

    * With a temporal column: a line chart of the measure over that column.
    * Otherwise, with a categorical column: a pie ("arc") chart when the
      sampled rows hold at most 8 distinct category values and the query
      contains one of ``PIE_QUERY_KEYWORDS``; a bar chart in all other cases.

    Args:
        data: Result rows as dictionaries.
        query: The user's query text; may be empty or ``None``.

    Returns:
        A ``ChartGenerationResponse``, or ``None`` when ``data`` is empty, the
        first row is not a dict or has no keys, there is no numeric column,
        or there is neither a temporal nor a categorical column.
    """
    if not data or not isinstance(data[0], dict):
        return None
    field_names = list(data[0].keys())
    if not field_names:
        return None

    # Classify every column from one representative sample value.
    preview = data[:50]
    buckets: Dict[str, List[str]] = {"numeric": [], "temporal": [], "categorical": []}
    for name in field_names:
        sample = _first_non_null(preview, name)
        if sample is None:
            continue
        if _is_number(sample):
            bucket = "numeric"
        elif _looks_temporal_field(name, sample):
            bucket = "temporal"
        else:
            bucket = "categorical"
        buckets[bucket].append(name)

    numerics = buckets["numeric"]
    temporals = buckets["temporal"]
    categoricals = buckets["categorical"]
    if not numerics:
        return None

    heading = "查询结果可视化"
    normalized_query = (query or "").lower()
    measure = numerics[0]

    if temporals:
        time_field = temporals[0]
        x_channel = {
            "field": time_field,
            "type": "temporal",
            "timeUnit": "yearmonth",
            "title": _encoding_title(time_field),
        }
        y_channel = {"field": measure, "type": "quantitative", "title": _encoding_title(measure)}
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成趋势图",
            chart_type="line",
            chart_spec={
                "title": heading,
                "mark": {"type": "line"},
                "encoding": {"x": x_channel, "y": y_channel},
            },
            can_visualize=True,
        )

    if not categoricals:
        return None

    dimension = categoricals[0]
    distinct = set()
    for row in preview:
        cell = row.get(dimension)
        if cell is not None:
            distinct.add(str(cell))

    # A pie needs both a small number of slices and an explicit hint in the query.
    wants_pie = len(distinct) <= 8 and any(token in normalized_query for token in PIE_QUERY_KEYWORDS)
    if wants_pie:
        theta_channel = {"field": measure, "type": "quantitative", "title": _encoding_title(measure)}
        slice_color = {"field": dimension, "type": "nominal", "title": _encoding_title(dimension)}
        return ChartGenerationResponse(
            reasoning="已基于字段类型快速生成占比图",
            chart_type="pie",
            chart_spec={
                "title": heading,
                "mark": {"type": "arc"},
                "encoding": {"theta": theta_channel, "color": slice_color},
            },
            can_visualize=True,
        )

    bar_x = {"field": dimension, "type": "nominal", "title": _encoding_title(dimension)}
    bar_y = {"field": measure, "type": "quantitative", "title": _encoding_title(measure)}
    bar_color = {"field": dimension, "type": "nominal", "title": _encoding_title(dimension)}
    return ChartGenerationResponse(
        reasoning="已基于字段类型快速生成对比图",
        chart_type="bar",
        chart_spec={
            "title": heading,
            "mark": {"type": "bar"},
            "encoding": {"x": bar_x, "y": bar_y, "color": bar_color},
        },
        can_visualize=True,
    )
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse: async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
fast_result = _fast_generate_chart(data, query)
if fast_result:
return fast_result
active_config = get_active_llm_config() active_config = get_active_llm_config()
if not active_config: if not active_config:
+8 -1
View File
@@ -4,6 +4,7 @@ import os
import json import json
import time import time
import threading import threading
import re
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -354,7 +355,13 @@ Language: Chinese (Simplified)
), ),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS, timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
) )
content = response.content.strip()
if response.finish_reason == "error":
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
content = (response.content or "").strip()
if not content:
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
# Clean up code blocks # Clean up code blocks
if "```json" in content: if "```json" in content:
+1 -1
View File
@@ -278,7 +278,7 @@ async def nanobot_chat_stream(request: ChatRequest):
continue continue
nl2sql_result = await sql_task nl2sql_result = await sql_task
if nl2sql_result.error: if nl2sql_result.error:
yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n"
else: else:
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)