Files
DataClaw/backend/app/tools/get_schema.py
T
2026-03-22 00:42:48 +08:00

132 lines
5.1 KiB
Python

import json
import logging
from typing import Any
import asyncio
from nanobot.agent.tools.base import Tool
from app.context import current_data_source, current_file_url, current_progress_callback
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.database import SessionLocal
from app.models.datasource import DataSource
# Import schema logic from nl2sql
from app.agent.nl2sql import (
_get_cached_schema,
_set_cached_schema,
_check_connection_with_cache,
_get_upload_payload
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class GetDatabaseSchemaTool(Tool):
    """
    Tool for fetching the database schema directly without SQL generation.

    The target is read from request-scoped context variables
    (``current_data_source``, ``current_file_url``) rather than from tool
    arguments, so the tool declares an empty parameter schema. Expected
    failures (missing/unsupported source, connection failure, fetch error or
    timeout) are returned to the LLM as human-readable strings and logged.
    """

    # Upper bound on a live ``connector.get_schema()`` call. ``asyncio.wait_for``
    # only stops *waiting*: the worker thread started by ``asyncio.to_thread``
    # cannot be cancelled and keeps running until the connector returns.
    SCHEMA_FETCH_TIMEOUT_SECONDS = 120.0

    @property
    def name(self) -> str:
        return "get_database_schema"

    @property
    def description(self) -> str:
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        # No arguments: the data source comes from the request context.
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the schema of the current data source and render it as text.

        Supported ``current_data_source`` values: ``"postgres"``,
        ``"clickhouse"``, ``"upload"`` (schema taken from the uploaded file at
        ``current_file_url``) and ``"ds:<id>"`` (a configured ``DataSource``
        row). Live connector schemas are served from / stored into the shared
        nl2sql schema cache.

        Args:
            **kwargs: Ignored; accepted for the ``Tool`` calling convention.

        Returns:
            A compact description of every table, or an error message string.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            if on_progress:
                await on_progress(msg)

        # Check first so we don't announce a fetch that cannot happen.
        if not source:
            return "Error: No data source connected."

        await emit_progress("正在获取数据源结构...")

        if source == "upload":
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema")
                return f"Failed to get upload schema: {e}"
            return self._format_schema(schema)

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])
                # DB access is synchronous; keep it off the event loop.
                connector = await asyncio.to_thread(self._load_ds_connector, ds_id)
            except Exception as e:
                logger.exception("Failed to load data source %s", source)
                return f"Failed to load data source: {e}"
            if not connector:
                return f"Data source not found: {source}"
        else:
            return f"Unsupported data source: {source}"

        schema: dict[str, Any] = {}
        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    return f"Failed to connect to {source}"
                timeout = self.SCHEMA_FETCH_TIMEOUT_SECONDS
                try:
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=timeout,
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning("Schema fetch for %s timed out after %ss", source, timeout)
                    return f"Failed to fetch schema: Timeout after {timeout:g} seconds."
                except Exception as e:
                    logger.exception("Failed to fetch schema for %s", source)
                    return f"Failed to fetch schema: {e}"

        return self._format_schema(schema)

    @staticmethod
    def _load_ds_connector(ds_id: int) -> Any:
        """Build the connector for DataSource ``ds_id``; ``None`` if no such row.

        Blocking (synchronous session query), so callers run it in a thread.
        The session is always closed, even if ``get_connector`` raises.
        """
        db = SessionLocal()
        try:
            ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
            if not ds:
                return None
            return get_connector(ds)
        finally:
            db.close()

    @staticmethod
    def _format_columns(columns: Any) -> str:
        """Render column entries as ``"name (type)"`` joined by ``", "``.

        Tolerates a missing/None column list, dict entries lacking ``name`` or
        ``type`` (shown as ``?``), and non-dict entries (rendered with ``str``),
        so a malformed schema degrades the output instead of raising.
        """
        parts = []
        for col in columns or []:
            if isinstance(col, dict):
                parts.append(f"{col.get('name', '?')} ({col.get('type', '?')})")
            else:
                parts.append(str(col))
        return ", ".join(parts)

    @classmethod
    def _format_schema(cls, schema: Any) -> str:
        """Format the schema for the LLM in a readable, token-efficient way.

        Two per-table shapes are understood:

        * list of column dicts (ClickHouse / upload):
          ``[{"name": "col", "type": "type"}, ...]``
        * dict (Postgres):
          ``{"columns": [...], "primary_keys": [...], "foreign_keys": [...]}``
          -- foreign keys are not rendered.

        Tables whose info has neither shape are skipped. A ``None``/empty
        schema yields ``"No tables found in schema."``.
        """
        lines = []
        for table_name, table_info in (schema or {}).items():
            if isinstance(table_info, list):
                cols = cls._format_columns(table_info)
                lines.append(f"Table: {table_name}\n Columns: {cols}")
            elif isinstance(table_info, dict):
                cols = cls._format_columns(table_info.get("columns"))
                # `or []` also covers an explicit None value.
                pks = ", ".join(str(pk) for pk in table_info.get("primary_keys") or [])
                lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}")
        return "\n\n".join(lines) if lines else "No tables found in schema."