refactor: convert to nl2sql skills
This commit is contained in:
+50
-17
@@ -191,14 +191,24 @@ def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[
|
||||
with _cache_lock:
|
||||
_schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS}
|
||||
|
||||
def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
async def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
cache_key = _build_schema_cache_key(source, connector)
|
||||
now = time.time()
|
||||
with _cache_lock:
|
||||
cached = _connection_cache.get(cache_key)
|
||||
if cached and now < cached["expires_at"]:
|
||||
return bool(cached["ok"])
|
||||
ok = connector.test_connection()
|
||||
|
||||
# Run synchronous test_connection in a separate thread to avoid blocking event loop
|
||||
try:
|
||||
ok = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.test_connection),
|
||||
timeout=10.0
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Connection test failed or timed out: {e}")
|
||||
ok = False
|
||||
|
||||
with _cache_lock:
|
||||
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||
return ok
|
||||
@@ -224,7 +234,7 @@ async def process_nl2sql(
|
||||
elif request.source == "upload":
|
||||
try:
|
||||
upload_started = time.perf_counter()
|
||||
upload_payload = _get_upload_payload(request.file_url)
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
schema = upload_payload["schema"]
|
||||
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
|
||||
@@ -234,14 +244,21 @@ async def process_nl2sql(
|
||||
try:
|
||||
ds_started = time.perf_counter()
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _get_ds_connector():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||
if not ds:
|
||||
return None
|
||||
return get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
connector = await asyncio.to_thread(_get_ds_connector)
|
||||
if not connector:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}")
|
||||
|
||||
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
@@ -258,20 +275,35 @@ async def process_nl2sql(
|
||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||
else:
|
||||
conn_started = time.perf_counter()
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
|
||||
schema_started = time.perf_counter()
|
||||
schema = connector.get_schema()
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||
|
||||
if connector and not schema:
|
||||
retry_started = time.perf_counter()
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
if not await _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
)
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}")
|
||||
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||
|
||||
@@ -282,7 +314,7 @@ async def process_nl2sql(
|
||||
if request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
mdl = MDLService.get_mdl(ds_id)
|
||||
mdl = await asyncio.to_thread(MDLService.get_mdl, ds_id)
|
||||
if mdl:
|
||||
mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"]
|
||||
|
||||
@@ -392,7 +424,8 @@ Language: Chinese (Simplified)
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
|
||||
Reference in New Issue
Block a user