132 lines
5.1 KiB
Python
132 lines
5.1 KiB
Python
import json
|
|
import logging
|
|
from typing import Any
|
|
import asyncio
|
|
|
|
from nanobot.agent.tools.base import Tool
|
|
from app.context import current_data_source, current_file_url, current_progress_callback
|
|
from app.connectors.postgres import postgres_connector
|
|
from app.connectors.clickhouse import clickhouse_connector
|
|
from app.connectors.factory import get_connector
|
|
from app.database import SessionLocal
|
|
from app.models.datasource import DataSource
|
|
|
|
# Import schema logic from nl2sql
|
|
from app.agent.nl2sql import (
|
|
_get_cached_schema,
|
|
_set_cached_schema,
|
|
_check_connection_with_cache,
|
|
_get_upload_payload
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class GetDatabaseSchemaTool(Tool):
    """Tool that returns the schema of the current data source as plain text.

    The schema is fetched directly from the selected connector (or from the
    uploaded-file payload); no SQL is generated along the way.
    """

    @property
    def name(self) -> str:
        """Return the tool name, ``"get_database_schema"``."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Return the text telling the agent when this tool should be used."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """Return the parameter spec: an object with no properties, none required."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    @staticmethod
    def _format_schema(schema: dict[str, Any]) -> str:
        """Render a schema mapping as compact, readable text.

        Two per-table shapes are recognised:

        * a list of ``{"name": ..., "type": ...}`` column dicts, or
        * a dict with ``"columns"`` (same column dicts) and ``"primary_keys"``.

        Tables whose value is neither a list nor a dict are skipped.

        Args:
            schema: Mapping of table name to table description.

        Returns:
            One paragraph per table, separated by blank lines, or
            ``"No tables found in schema."`` when nothing was rendered.

        Raises:
            KeyError: If a column entry lacks ``"name"`` or ``"type"``.
        """
        sections = []
        for table_name, table_info in schema.items():
            if isinstance(table_info, list):
                # Clickhouse/Upload format: [{"name": "col", "type": "type"}]
                cols = ", ".join(f"{c['name']} ({c['type']})" for c in table_info)
                sections.append(f"Table: {table_name}\n Columns: {cols}")
            elif isinstance(table_info, dict):
                # Postgres format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
                # `or []` tolerates an explicit None for tables without columns/keys.
                columns = table_info.get("columns") or []
                primary_keys = table_info.get("primary_keys") or []
                cols = ", ".join(f"{c['name']} ({c['type']})" for c in columns)
                pks = ", ".join(str(pk) for pk in primary_keys)
                sections.append(
                    f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}"
                )

        return "\n\n".join(sections) if sections else "No tables found in schema."

    async def execute(self, **kwargs: Any) -> str:
        """Fetch and format the schema of the data source selected in context.

        The data source, file URL and progress callback are read from the
        ``current_data_source``, ``current_file_url`` and
        ``current_progress_callback`` context variables. Supported source
        values are ``"postgres"``, ``"clickhouse"``, ``"upload"`` and
        ``"ds:<id>"`` (a ``DataSource`` row looked up by integer id).

        Args:
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The formatted schema text, or a human-readable error message.
            Failures are reported through the return value (and logged)
            rather than raised.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            # Progress reporting is optional; skip when no callback is set.
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        connector = None
        schema = {}

        if not source:
            return "Error: No data source connected."

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            # Uploaded files carry their schema in the payload; no connector is used.
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema (file_url=%s)", file_url)
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])

                def _get_ds_connector():
                    # Runs in a worker thread: the session is opened and
                    # closed entirely inside this function.
                    db = SessionLocal()
                    try:
                        ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
                        if not ds:
                            return None
                        return get_connector(ds)
                    finally:
                        db.close()

                connector = await asyncio.to_thread(_get_ds_connector)
                if not connector:
                    return f"Data source not found: {source}"
            except Exception as e:
                logger.exception("Failed to load data source %s", source)
                return f"Failed to load data source: {e}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    logger.warning("Connection check failed for %s", source)
                    return f"Failed to connect to {source}"

                try:
                    # NOTE: on timeout the await is abandoned, but the worker
                    # thread running get_schema cannot be interrupted.
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=120.0,
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning("Schema fetch timed out for %s", source)
                    return "Failed to fetch schema: Timeout after 120 seconds."
                except Exception as e:
                    logger.exception("Failed to fetch schema for %s", source)
                    return f"Failed to fetch schema: {e}"

        # Format the output for the LLM to make it readable and token-efficient.
        # Malformed schema entries are reported as a message, consistent with
        # every other failure path in this method.
        try:
            return self._format_schema(schema)
        except Exception as e:
            logger.exception("Failed to format schema for %s", source)
            return f"Failed to format schema: {e}"
|