reorg skill folder

This commit is contained in:
qixinbo
2026-03-19 12:27:31 +08:00
parent baec21c774
commit cca492cfdb
13 changed files with 232 additions and 81 deletions
+13 -1
View File
@@ -44,7 +44,7 @@ CHART_INSTRUCTIONS = """
- For daily question, the time unit should be "yearmonthdate".
- Default time unit is "yearmonth".
- For each axis, generate the corresponding human-readable title based on the language provided by the user.
- Make sure all of the fields(x, y, xOffset, color, etc.) in the encoding section of the chart schema are present in the column names of the data.
- **CRITICAL REQUIREMENT**: Make sure all of the `field` values in the encoding section of the chart schema EXACTLY MATCH the column names of the sample data provided! DO NOT translate, rename, or hallucinate `field` names. If you want to show a translated name in the chart, use the `title` property, NOT the `field` property!
### GUIDELINES TO PLOT CHART ###
@@ -233,6 +233,18 @@ Language: Chinese (Simplified)
content = content.strip()
result = json.loads(content)
# Post-process to fix common LLM hallucinations (translating field names)
if result.get("chart_spec") and isinstance(result["chart_spec"], dict):
encoding = result["chart_spec"].get("encoding", {})
for channel, enc_def in encoding.items():
if isinstance(enc_def, dict) and "field" in enc_def:
field = enc_def["field"]
# If field is not in columns, try to find a match or let it be (Vega will render empty)
# But if we can detect it was translated, we might not be able to fix it perfectly.
# As a simple fallback, if there's only one string column and one numeric column, we could guess,
# but it's safer to just rely on the stricter prompt.
return ChartGenerationResponse(**result)
except Exception as e:
+77 -42
View File
@@ -4,7 +4,6 @@ import os
import json
import time
import threading
import re
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
NL2SQL_CHART_TIMEOUT_SECONDS = 45
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
conn.close()
return result_df.to_dict(orient="records")
def _to_number(value: Any) -> Optional[float]:
if isinstance(value, bool):
return None
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, str):
text = value.strip().replace(",", "")
if not text:
return None
try:
return float(text)
except ValueError:
return None
return None
# _build_fallback_chart removed as per user request to not hardcode fallbacks
def _build_schema_cache_key(source: str, connector: Any) -> str:
# If source is ds:ID, that's already a good key
if source.startswith("ds:"):
@@ -270,7 +288,7 @@ async def process_nl2sql(
if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
if cached_schema is not None:
schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else:
@@ -289,23 +307,6 @@ async def process_nl2sql(
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not await _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
)
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -378,42 +379,66 @@ Language: Chinese (Simplified)
try:
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
if response.finish_reason == "error":
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
response = None
last_error = ""
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
if response.finish_reason == "error":
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip()
if not content:
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
# Clean up code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
try:
result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip()
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
if not chart_response or not chart_response.chart_spec:
# Do not fallback automatically if the LLM explicitly decided not to or failed.
# Just pass whatever it returned (or lack thereof)
pass
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError:
if timeout_stage == "chart_generation":
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
fallback_chart = ChartGenerationResponse(
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
chart_type="",
can_visualize=False,
chart_spec=None,
)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")