feat: add get schema tool

This commit is contained in:
qixinbo
2026-03-22 00:42:48 +08:00
parent b0c8f84db9
commit 0e7f275285
8 changed files with 228 additions and 20 deletions
+131
View File
@@ -0,0 +1,131 @@
import json
import logging
from typing import Any
import asyncio
from nanobot.agent.tools.base import Tool
from app.context import current_data_source, current_file_url, current_progress_callback
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.database import SessionLocal
from app.models.datasource import DataSource
# Import schema logic from nl2sql
from app.agent.nl2sql import (
_get_cached_schema,
_set_cached_schema,
_check_connection_with_cache,
_get_upload_payload
)
logger = logging.getLogger(__name__)
class GetDatabaseSchemaTool(Tool):
"""
Tool for fetching the database schema directly without SQL generation.
"""
@property
def name(self) -> str:
return "get_database_schema"
@property
def description(self) -> str:
return (
"Get the structural schema of the currently connected database or data source. "
"Use this tool when the user asks questions about metadata, such as 'what tables are there', "
"'show me the database structure', 'what are the columns in table X', etc. "
"It directly returns the schema without generating SQL."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {},
"required": [],
}
async def execute(self, **kwargs: Any) -> str:
source = current_data_source.get()
file_url = current_file_url.get()
on_progress = current_progress_callback.get()
async def emit_progress(msg: str):
if on_progress:
await on_progress(msg)
await emit_progress("正在获取数据源结构...")
connector = None
schema = {}
if not source:
return "Error: No data source connected."
if source == "postgres":
connector = postgres_connector
elif source == "clickhouse":
connector = clickhouse_connector
elif source == "upload":
try:
payload = await asyncio.to_thread(_get_upload_payload, file_url)
schema = payload["schema"]
await emit_progress("文件 Schema 获取完成")
except Exception as e:
return f"Failed to get upload schema: {e}"
elif source.startswith("ds:"):
try:
ds_id = int(source.split(":")[1])
def _get_ds_connector():
db = SessionLocal()
try:
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
if not ds: return None
return get_connector(ds)
finally:
db.close()
connector = await asyncio.to_thread(_get_ds_connector)
if not connector:
return f"Data source not found: {source}"
except Exception as e:
return f"Failed to load data source: {e}"
else:
return f"Unsupported data source: {source}"
if connector:
cached_schema = _get_cached_schema(source, connector)
if cached_schema is not None:
schema = cached_schema
await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
else:
if not await _check_connection_with_cache(source, connector):
return f"Failed to connect to {source}"
try:
schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema),
timeout=120.0
)
_set_cached_schema(source, connector, schema)
await emit_progress(f"成功获取 {len(schema)} 张表结构")
except asyncio.TimeoutError:
return "Failed to fetch schema: Timeout after 120 seconds."
except Exception as e:
return f"Failed to fetch schema: {e}"
# Format the output for the LLM to make it readable and token-efficient
lines = []
for table_name, table_info in schema.items():
if isinstance(table_info, list):
# Clickhouse/Upload format: [{"name": "col", "type": "type"}]
cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info])
lines.append(f"Table: {table_name}\n Columns: {cols}")
elif isinstance(table_info, dict):
# Postgres format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info.get("columns", [])])
pks = ", ".join(table_info.get("primary_keys", []))
lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}")
return "\n\n".join(lines) if lines else "No tables found in schema."
+9 -8
View File
@@ -74,14 +74,8 @@ class NL2SQLTool(Tool):
# Call the core logic
result = await process_nl2sql(request, on_progress=on_progress)
if result.error:
return f"Error executing query: {result.error}"
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Save visualization payload to context so the chat stream can pick it up
# Always save visualization payload to context so the chat stream can pick it up
# Even if there's an error, we want the frontend to see the generated SQL
viz_payload = _build_sql_chart_viz(result)
existing_viz = current_viz_data.get()
if isinstance(existing_viz, dict):
@@ -91,6 +85,13 @@ class NL2SQLTool(Tool):
else:
current_viz_data.set(viz_payload)
if result.error:
return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Build a summary string for the Agent to read
row_count = len(result.result) if result.result else 0