something
This commit is contained in:
+2
-122
@@ -1,5 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
import time
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -136,129 +138,7 @@ CHART_EXAMPLES = """
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
TEMPORAL_KEYWORDS = ("date", "time", "day", "month", "year", "日期", "时间", "月份", "年份")
|
|
||||||
PIE_QUERY_KEYWORDS = ("占比", "构成", "比例", "份额", "分布", "pie")
|
|
||||||
|
|
||||||
|
|
||||||
def _first_non_null(rows: List[Dict[str, Any]], key: str) -> Any:
|
|
||||||
for row in rows:
|
|
||||||
value = row.get(key)
|
|
||||||
if value is not None:
|
|
||||||
return value
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_number(value: Any) -> bool:
|
|
||||||
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
|
||||||
|
|
||||||
|
|
||||||
def _looks_temporal_field(key: str, sample_value: Any) -> bool:
|
|
||||||
lowered = key.lower()
|
|
||||||
if any(token in lowered for token in TEMPORAL_KEYWORDS):
|
|
||||||
return True
|
|
||||||
if not isinstance(sample_value, str):
|
|
||||||
return False
|
|
||||||
text = sample_value.strip()
|
|
||||||
patterns = [
|
|
||||||
r"^\d{4}[-/]\d{1,2}([-/]\d{1,2})?$",
|
|
||||||
r"^\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}(:\d{2})?$",
|
|
||||||
r"^\d{8}$",
|
|
||||||
]
|
|
||||||
return any(re.match(p, text) for p in patterns)
|
|
||||||
|
|
||||||
|
|
||||||
def _encoding_title(field: str) -> str:
|
|
||||||
return field.replace("_", " ").strip() or field
|
|
||||||
|
|
||||||
|
|
||||||
def _fast_generate_chart(data: List[Dict[str, Any]], query: str) -> Optional[ChartGenerationResponse]:
|
|
||||||
if not data or not isinstance(data[0], dict):
|
|
||||||
return None
|
|
||||||
columns = list(data[0].keys())
|
|
||||||
if not columns:
|
|
||||||
return None
|
|
||||||
numeric_cols: List[str] = []
|
|
||||||
temporal_cols: List[str] = []
|
|
||||||
categorical_cols: List[str] = []
|
|
||||||
sample_rows = data[:50]
|
|
||||||
for col in columns:
|
|
||||||
sample_value = _first_non_null(sample_rows, col)
|
|
||||||
if sample_value is None:
|
|
||||||
continue
|
|
||||||
if _is_number(sample_value):
|
|
||||||
numeric_cols.append(col)
|
|
||||||
continue
|
|
||||||
if _looks_temporal_field(col, sample_value):
|
|
||||||
temporal_cols.append(col)
|
|
||||||
continue
|
|
||||||
categorical_cols.append(col)
|
|
||||||
if not numeric_cols:
|
|
||||||
return None
|
|
||||||
|
|
||||||
title = "查询结果可视化"
|
|
||||||
query_lower = (query or "").lower()
|
|
||||||
|
|
||||||
if temporal_cols:
|
|
||||||
x_col = temporal_cols[0]
|
|
||||||
y_col = numeric_cols[0]
|
|
||||||
chart_spec = {
|
|
||||||
"title": title,
|
|
||||||
"mark": {"type": "line"},
|
|
||||||
"encoding": {
|
|
||||||
"x": {"field": x_col, "type": "temporal", "timeUnit": "yearmonth", "title": _encoding_title(x_col)},
|
|
||||||
"y": {"field": y_col, "type": "quantitative", "title": _encoding_title(y_col)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return ChartGenerationResponse(
|
|
||||||
reasoning="已基于字段类型快速生成趋势图",
|
|
||||||
chart_type="line",
|
|
||||||
chart_spec=chart_spec,
|
|
||||||
can_visualize=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if categorical_cols:
|
|
||||||
cat_col = categorical_cols[0]
|
|
||||||
val_col = numeric_cols[0]
|
|
||||||
unique_values = {str(row.get(cat_col)) for row in sample_rows if row.get(cat_col) is not None}
|
|
||||||
use_pie = len(unique_values) <= 8 and any(token in query_lower for token in PIE_QUERY_KEYWORDS)
|
|
||||||
if use_pie:
|
|
||||||
chart_spec = {
|
|
||||||
"title": title,
|
|
||||||
"mark": {"type": "arc"},
|
|
||||||
"encoding": {
|
|
||||||
"theta": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
|
|
||||||
"color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return ChartGenerationResponse(
|
|
||||||
reasoning="已基于字段类型快速生成占比图",
|
|
||||||
chart_type="pie",
|
|
||||||
chart_spec=chart_spec,
|
|
||||||
can_visualize=True,
|
|
||||||
)
|
|
||||||
chart_spec = {
|
|
||||||
"title": title,
|
|
||||||
"mark": {"type": "bar"},
|
|
||||||
"encoding": {
|
|
||||||
"x": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
|
||||||
"y": {"field": val_col, "type": "quantitative", "title": _encoding_title(val_col)},
|
|
||||||
"color": {"field": cat_col, "type": "nominal", "title": _encoding_title(cat_col)},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return ChartGenerationResponse(
|
|
||||||
reasoning="已基于字段类型快速生成对比图",
|
|
||||||
chart_type="bar",
|
|
||||||
chart_spec=chart_spec,
|
|
||||||
can_visualize=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||||
fast_result = _fast_generate_chart(data, query)
|
|
||||||
if fast_result:
|
|
||||||
return fast_result
|
|
||||||
|
|
||||||
active_config = get_active_llm_config()
|
active_config = get_active_llm_config()
|
||||||
|
|
||||||
if not active_config:
|
if not active_config:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -354,7 +355,13 @@ Language: Chinese (Simplified)
|
|||||||
),
|
),
|
||||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||||
)
|
)
|
||||||
content = response.content.strip()
|
|
||||||
|
if response.finish_reason == "error":
|
||||||
|
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
|
||||||
|
|
||||||
|
content = (response.content or "").strip()
|
||||||
|
if not content:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||||
|
|
||||||
# Clean up code blocks
|
# Clean up code blocks
|
||||||
if "```json" in content:
|
if "```json" in content:
|
||||||
|
|||||||
+1
-1
@@ -278,7 +278,7 @@ async def nanobot_chat_stream(request: ChatRequest):
|
|||||||
continue
|
continue
|
||||||
nl2sql_result = await sql_task
|
nl2sql_result = await sql_task
|
||||||
if nl2sql_result.error:
|
if nl2sql_result.error:
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n"
|
||||||
else:
|
else:
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
|
||||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||||
|
|||||||
Reference in New Issue
Block a user