feat: add get schema tool

This commit is contained in:
qixinbo
2026-03-22 00:42:48 +08:00
parent b0c8f84db9
commit 0e7f275285
8 changed files with 228 additions and 20 deletions
+17 -3
View File
@@ -4,12 +4,15 @@ import os
import json
import time
import threading
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
import duckdb
import pandas as pd
logger = logging.getLogger(__name__)
# Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
@@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool:
try:
ok = await asyncio.wait_for(
asyncio.to_thread(connector.test_connection),
timeout=10.0
timeout=15.0
)
except asyncio.TimeoutError:
print("Connection test failed or timed out: Timeout after 15 seconds")
ok = False
except Exception as e:
print(f"Connection test failed or timed out: {e}")
ok = False
@@ -300,8 +306,10 @@ async def process_nl2sql(
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
timeout=120.0
)
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
@@ -449,7 +457,13 @@ Language: Chinese (Simplified)
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}")
# 格式化单行 SQL 用于在前端进度中展示
formatted_sql = sql_query.replace('\n', ' ')
if len(formatted_sql) > 150:
formatted_sql = formatted_sql[:147] + "..."
await emit_progress(f"SQL 生成完成: {formatted_sql}")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")