reorg skill folder
This commit is contained in:
+77
-42
@@ -4,7 +4,6 @@ import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -35,9 +34,11 @@ MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
NL2SQL_MAX_TOKENS = 900
|
||||
NL2SQL_TEMPERATURE = 0.1
|
||||
NL2SQL_REASONING_EFFORT = "low"
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = 60*5
|
||||
NL2SQL_LLM_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_TIMEOUT_SECONDS", "90"))
|
||||
NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS", "45"))
|
||||
NL2SQL_LLM_RETRY_COUNT = int(os.getenv("NL2SQL_LLM_RETRY_COUNT", "0"))
|
||||
NL2SQL_SQL_EXEC_TIMEOUT_SECONDS = 60
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = 45
|
||||
NL2SQL_CHART_TIMEOUT_SECONDS = int(os.getenv("NL2SQL_CHART_TIMEOUT_SECONDS", "20"))
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -163,6 +164,23 @@ def _execute_upload_sql(sql_query: str, df: pd.DataFrame) -> List[Dict[str, Any]
|
||||
conn.close()
|
||||
return result_df.to_dict(orient="records")
|
||||
|
||||
def _to_number(value: Any) -> Optional[float]:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, (int, float)):
|
||||
return float(value)
|
||||
if isinstance(value, str):
|
||||
text = value.strip().replace(",", "")
|
||||
if not text:
|
||||
return None
|
||||
try:
|
||||
return float(text)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
# _build_fallback_chart removed as per user request to not hardcode fallbacks
|
||||
|
||||
def _build_schema_cache_key(source: str, connector: Any) -> str:
|
||||
# If source is ds:ID, that's already a good key
|
||||
if source.startswith("ds:"):
|
||||
@@ -270,7 +288,7 @@ async def process_nl2sql(
|
||||
if connector:
|
||||
await emit_progress("正在检测数据源连通性")
|
||||
cached_schema = _get_cached_schema(request.source, connector)
|
||||
if cached_schema:
|
||||
if cached_schema is not None:
|
||||
schema = cached_schema
|
||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||
else:
|
||||
@@ -289,23 +307,6 @@ async def process_nl2sql(
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||
|
||||
if connector and not schema:
|
||||
retry_started = time.perf_counter()
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||
|
||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
@@ -378,42 +379,66 @@ Language: Chinese (Simplified)
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
return NL2SQLResponse(sql="", result=[], error=response.content or "LLM Error")
|
||||
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
|
||||
content = content.strip()
|
||||
|
||||
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
reasoning = result_json.get("reasoning", "") # We can log this or return it if needed
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
@@ -481,13 +506,23 @@ Language: Chinese (Simplified)
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
if not chart_response or not chart_response.chart_spec:
|
||||
# Do not fallback automatically if the LLM explicitly decided not to or failed.
|
||||
# Just pass whatever it returned (or lack thereof)
|
||||
pass
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, error=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s")
|
||||
fallback_chart = ChartGenerationResponse(
|
||||
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
|
||||
chart_type="",
|
||||
can_visualize=False,
|
||||
chart_spec=None,
|
||||
)
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
|
||||
Reference in New Issue
Block a user