feat: add get schema tool

This commit is contained in:
qixinbo
2026-03-22 00:42:48 +08:00
parent b0c8f84db9
commit 0e7f275285
8 changed files with 228 additions and 20 deletions
+17 -3
View File
@@ -4,12 +4,15 @@ import os
import json
import time
import threading
import logging
from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
import duckdb
import pandas as pd
logger = logging.getLogger(__name__)
# Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path:
@@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool:
try:
ok = await asyncio.wait_for(
asyncio.to_thread(connector.test_connection),
timeout=10.0
timeout=15.0
)
except asyncio.TimeoutError:
print("Connection test failed or timed out: Timeout after 15 seconds")
ok = False
except Exception as e:
print(f"Connection test failed or timed out: {e}")
ok = False
@@ -300,8 +306,10 @@ async def process_nl2sql(
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=30.0
timeout=120.0
)
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
@@ -449,7 +457,13 @@ Language: Chinese (Simplified)
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}")
# 格式化单行 SQL 用于在前端进度中展示
formatted_sql = sql_query.replace('\n', ' ')
if len(formatted_sql) > 150:
formatted_sql = formatted_sql[:147] + "..."
await emit_progress(f"SQL 生成完成: {formatted_sql}")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
+41 -2
View File
@@ -29,7 +29,44 @@ class PostgresConnector:
# Default schema for postgres is 'public', sqlite is None
schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None
for table_name in inspector.get_table_names(schema=schema_name):
table_names = inspector.get_table_names(schema=schema_name)
# Use SQLAlchemy 2.0+ multi-fetch to avoid N+1 queries issue, especially over remote networks
if hasattr(inspector, 'get_multi_columns'):
multi_columns = inspector.get_multi_columns(schema=schema_name)
multi_pk = inspector.get_multi_pk_constraint(schema=schema_name)
multi_fk = inspector.get_multi_foreign_keys(schema=schema_name)
for table_name in table_names:
key = (schema_name, table_name)
columns = []
for col in multi_columns.get(key, []):
columns.append({
"name": col['name'],
"type": str(col['type'])
})
pk_constraint = multi_pk.get(key)
pks = pk_constraint.get('constrained_columns', []) if pk_constraint else []
foreign_keys = []
for fk in multi_fk.get(key, []):
foreign_keys.append({
"constrained_columns": fk['constrained_columns'],
"referred_table": fk['referred_table'],
"referred_columns": fk['referred_columns']
})
schema[table_name] = {
"columns": columns,
"primary_keys": pks,
"foreign_keys": foreign_keys
}
return schema
# Fallback for older SQLAlchemy versions
for table_name in table_names:
columns = []
# get columns
for col in inspector.get_columns(table_name, schema=schema_name):
@@ -59,8 +96,10 @@ class PostgresConnector:
}
return schema
except Exception as e:
import traceback
traceback.print_exc()
print(f"Error getting schema: {e}")
return {}
raise e
def test_connection(self) -> bool:
try:
+2
View File
@@ -108,8 +108,10 @@ class NanobotIntegration:
def _register_custom_tools(self, agent: AgentLoop):
    """Register the application's custom tools on the given agent.

    Each tool class is instantiated with no arguments and handed to
    ``agent.tools.register`` in a fixed order: NL2SQL, visualization,
    then database-schema lookup.
    """
    # The tool modules are imported inside the method rather than at module
    # level; the import order is kept as-is.
    from app.tools.nl2sql import NL2SQLTool
    from app.tools.visualization import VisualizationTool
    from app.tools.get_schema import GetDatabaseSchemaTool

    for tool_cls in (NL2SQLTool, VisualizationTool, GetDatabaseSchemaTool):
        agent.tools.register(tool_cls())
def _make_provider(self, config: Config):
# Logic adapted from nanobot/cli/commands.py
+131
View File
@@ -0,0 +1,131 @@
import json
import logging
from typing import Any
import asyncio
from nanobot.agent.tools.base import Tool
from app.context import current_data_source, current_file_url, current_progress_callback
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.database import SessionLocal
from app.models.datasource import DataSource
# Import schema logic from nl2sql
from app.agent.nl2sql import (
_get_cached_schema,
_set_cached_schema,
_check_connection_with_cache,
_get_upload_payload
)
logger = logging.getLogger(__name__)
class GetDatabaseSchemaTool(Tool):
    """Agent tool that returns the schema of the active data source as text.

    The data source, uploaded-file URL and progress callback are read from
    context variables rather than from tool arguments, so the tool takes no
    parameters. No SQL is generated; the schema is fetched (or taken from
    the shared schema cache) and rendered as a compact text summary.
    """

    @property
    def name(self) -> str:
        """Name under which the tool is exposed to the agent."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Natural-language description telling the agent when to use the tool."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema style parameter spec; the tool accepts no arguments."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    @staticmethod
    def _format_schema(schema: dict[str, Any]) -> str:
        """Render a schema mapping as a readable, token-efficient summary.

        Two per-table shapes are handled:

        * a list of ``{"name": ..., "type": ...}`` column dicts
          (ClickHouse / upload format);
        * a dict with ``"columns"``, ``"primary_keys"`` and
          ``"foreign_keys"`` entries (Postgres format).

        Tables whose value is neither a list nor a dict are skipped.

        Args:
            schema: Mapping of table name to table description.

        Returns:
            One paragraph per table separated by blank lines, or
            ``"No tables found in schema."`` when nothing could be rendered.
        """
        blocks = []
        for table_name, table_info in schema.items():
            if isinstance(table_info, list):
                # Clickhouse/Upload format: [{"name": "col", "type": "type"}]
                cols = ", ".join(f"{c['name']} ({c['type']})" for c in table_info)
                blocks.append(f"Table: {table_name}\n Columns: {cols}")
            elif isinstance(table_info, dict):
                # Postgres format:
                # {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
                cols = ", ".join(
                    f"{c['name']} ({c['type']})"
                    for c in table_info.get("columns", [])
                )
                pks = ", ".join(table_info.get("primary_keys", []))
                text = f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}"

                # Foreign keys are part of the Postgres-format schema; include
                # them so relationship questions can be answered. The line is
                # only added when a table actually has foreign keys, so the
                # output for all other tables is unchanged.
                fks = "; ".join(
                    f"({', '.join(fk.get('constrained_columns', []))}) -> "
                    f"{fk.get('referred_table')}"
                    f"({', '.join(fk.get('referred_columns', []))})"
                    for fk in table_info.get("foreign_keys", [])
                )
                if fks:
                    text += f"\n Foreign Keys: {fks}"
                blocks.append(text)

        return "\n\n".join(blocks) if blocks else "No tables found in schema."

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the schema of the current data source and return it as text.

        Args:
            **kwargs: Ignored; the tool declares no parameters.

        Returns:
            The formatted schema, or a human-readable error message when the
            data source is missing, unsupported, unreachable, or the schema
            fetch fails or times out. Errors are returned as strings rather
            than raised.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str):
            # Progress reporting is optional: no callback means no-op.
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        connector = None
        schema = {}

        if not source:
            return "Error: No data source connected."

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            # Uploaded files have no connector; the schema comes with the
            # upload payload, which is loaded off the event loop.
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            # "ds:<id>" refers to a DataSource row; build its connector from
            # the database record in a worker thread.
            try:
                ds_id = int(source.split(":")[1])

                def _get_ds_connector():
                    db = SessionLocal()
                    try:
                        ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
                        if not ds:
                            return None
                        return get_connector(ds)
                    finally:
                        db.close()

                connector = await asyncio.to_thread(_get_ds_connector)
                if not connector:
                    return f"Data source not found: {source}"
            except Exception as e:
                return f"Failed to load data source: {e}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                # Cache miss: verify connectivity first, then fetch the schema
                # in a worker thread with a hard timeout.
                if not await _check_connection_with_cache(source, connector):
                    return f"Failed to connect to {source}"

                try:
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=120.0
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    return "Failed to fetch schema: Timeout after 120 seconds."
                except Exception as e:
                    return f"Failed to fetch schema: {e}"

        # Format the output for the LLM to make it readable and token-efficient
        return self._format_schema(schema)
+9 -8
View File
@@ -74,14 +74,8 @@ class NL2SQLTool(Tool):
# Call the core logic
result = await process_nl2sql(request, on_progress=on_progress)
if result.error:
return f"Error executing query: {result.error}"
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Save visualization payload to context so the chat stream can pick it up
# Always save visualization payload to context so the chat stream can pick it up
# Even if there's an error, we want the frontend to see the generated SQL
viz_payload = _build_sql_chart_viz(result)
existing_viz = current_viz_data.get()
if isinstance(existing_viz, dict):
@@ -91,6 +85,13 @@ class NL2SQLTool(Tool):
else:
current_viz_data.set(viz_payload)
if result.error:
return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Build a summary string for the Agent to read
row_count = len(result.result) if result.result else 0
+1
View File
@@ -246,6 +246,7 @@ async def nanobot_chat_stream(request: ChatRequest):
else:
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
yield ": keep-alive\n\n"
continue
response = await current_task
+12
View File
@@ -0,0 +1,12 @@
import asyncio
import json
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
async def main():
    """Run one NL2SQL request against the postgres source and print the outcome.

    Manual smoke check: sends the query "列出所有表" with chart generation
    disabled and prints the generated SQL, any error, and the result rows.
    """
    req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
    res = await process_nl2sql(req)
    print("SQL:", res.sql)
    print("Error:", res.error)
    print("Result:", res.result)


# Only run when executed as a script, so importing this module (for example
# during test collection) does not trigger a live request.
if __name__ == "__main__":
    asyncio.run(main())