feat: add get schema tool
This commit is contained in:
@@ -4,12 +4,15 @@ import os
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||
if str(PROJECT_ROOT) not in sys.path:
|
||||
@@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
try:
|
||||
ok = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.test_connection),
|
||||
timeout=10.0
|
||||
timeout=15.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
print("Connection test failed or timed out: Timeout after 15 seconds")
|
||||
ok = False
|
||||
except Exception as e:
|
||||
print(f"Connection test failed or timed out: {e}")
|
||||
ok = False
|
||||
@@ -300,8 +306,10 @@ async def process_nl2sql(
|
||||
try:
|
||||
schema = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.get_schema),
|
||||
timeout=30.0
|
||||
timeout=120.0
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
||||
|
||||
@@ -449,7 +457,13 @@ Language: Chinese (Simplified)
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}")
|
||||
|
||||
# 格式化单行 SQL 用于在前端进度中展示
|
||||
formatted_sql = sql_query.replace('\n', ' ')
|
||||
if len(formatted_sql) > 150:
|
||||
formatted_sql = formatted_sql[:147] + "..."
|
||||
await emit_progress(f"SQL 生成完成: {formatted_sql}")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
@@ -29,7 +29,44 @@ class PostgresConnector:
|
||||
# Default schema for postgres is 'public', sqlite is None
|
||||
schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None
|
||||
|
||||
for table_name in inspector.get_table_names(schema=schema_name):
|
||||
table_names = inspector.get_table_names(schema=schema_name)
|
||||
|
||||
# Use SQLAlchemy 2.0+ multi-fetch to avoid N+1 queries issue, especially over remote networks
|
||||
if hasattr(inspector, 'get_multi_columns'):
|
||||
multi_columns = inspector.get_multi_columns(schema=schema_name)
|
||||
multi_pk = inspector.get_multi_pk_constraint(schema=schema_name)
|
||||
multi_fk = inspector.get_multi_foreign_keys(schema=schema_name)
|
||||
|
||||
for table_name in table_names:
|
||||
key = (schema_name, table_name)
|
||||
|
||||
columns = []
|
||||
for col in multi_columns.get(key, []):
|
||||
columns.append({
|
||||
"name": col['name'],
|
||||
"type": str(col['type'])
|
||||
})
|
||||
|
||||
pk_constraint = multi_pk.get(key)
|
||||
pks = pk_constraint.get('constrained_columns', []) if pk_constraint else []
|
||||
|
||||
foreign_keys = []
|
||||
for fk in multi_fk.get(key, []):
|
||||
foreign_keys.append({
|
||||
"constrained_columns": fk['constrained_columns'],
|
||||
"referred_table": fk['referred_table'],
|
||||
"referred_columns": fk['referred_columns']
|
||||
})
|
||||
|
||||
schema[table_name] = {
|
||||
"columns": columns,
|
||||
"primary_keys": pks,
|
||||
"foreign_keys": foreign_keys
|
||||
}
|
||||
return schema
|
||||
|
||||
# Fallback for older SQLAlchemy versions
|
||||
for table_name in table_names:
|
||||
columns = []
|
||||
# get columns
|
||||
for col in inspector.get_columns(table_name, schema=schema_name):
|
||||
@@ -59,8 +96,10 @@ class PostgresConnector:
|
||||
}
|
||||
return schema
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"Error getting schema: {e}")
|
||||
return {}
|
||||
raise e
|
||||
|
||||
def test_connection(self) -> bool:
|
||||
try:
|
||||
|
||||
@@ -108,8 +108,10 @@ class NanobotIntegration:
|
||||
def _register_custom_tools(self, agent: AgentLoop):
    """Register the application's custom tools on ``agent.tools``.

    One instance each of ``NL2SQLTool``, ``VisualizationTool`` and
    ``GetDatabaseSchemaTool`` is created and registered, in that order.

    Args:
        agent: The agent loop whose tool registry receives the tools.
    """
    # The tool modules are imported at call time rather than at module
    # import time, exactly as before.
    from app.tools.get_schema import GetDatabaseSchemaTool
    from app.tools.nl2sql import NL2SQLTool
    from app.tools.visualization import VisualizationTool

    custom_tool_classes = (NL2SQLTool, VisualizationTool, GetDatabaseSchemaTool)
    for tool_cls in custom_tool_classes:
        agent.tools.register(tool_cls())
|
||||
|
||||
def _make_provider(self, config: Config):
|
||||
# Logic adapted from nanobot/cli/commands.py
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from app.context import current_data_source, current_file_url, current_progress_callback
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
|
||||
# Import schema logic from nl2sql
|
||||
from app.agent.nl2sql import (
|
||||
_get_cached_schema,
|
||||
_set_cached_schema,
|
||||
_check_connection_with_cache,
|
||||
_get_upload_payload
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GetDatabaseSchemaTool(Tool):
    """Agent tool that returns the schema of the active data source as text.

    The data source, uploaded-file URL and progress callback are read from
    the ``current_data_source``, ``current_file_url`` and
    ``current_progress_callback`` context variables; the tool itself takes
    no arguments. Every failure is reported to the caller as a plain error
    string (and logged), never raised.
    """

    # Upper bound, in seconds, for a single ``connector.get_schema`` call.
    _SCHEMA_TIMEOUT_SECONDS = 120.0

    @property
    def name(self) -> str:
        """Identifier under which the tool is registered."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Natural-language description of when the tool should be used."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema style parameter spec: an object with no properties."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the schema of the current data source and format it as text.

        Supported sources are ``"postgres"``, ``"clickhouse"``, ``"upload"``
        and ``"ds:<id>"`` (a ``DataSource`` row looked up by integer id).
        For connector-backed sources the schema cache is consulted first;
        on a miss the connection is checked, the schema is fetched in a
        worker thread with a timeout, and the result is cached.

        Args:
            **kwargs: Ignored; the tool declares no parameters.

        Returns:
            The formatted schema, ``"No tables found in schema."`` when it is
            empty, or a human-readable error message on failure.
        """
        source = current_data_source.get()
        # Validate before announcing progress so that no "fetching" message
        # is emitted when there is nothing to fetch from.
        if not source:
            return "Error: No data source connected."

        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        connector = None
        schema = {}

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            # Uploaded files carry their schema in the payload; no connector.
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema")
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])

                def _get_ds_connector():
                    # Runs in a worker thread; the session is always closed.
                    db = SessionLocal()
                    try:
                        ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
                        if not ds:
                            return None
                        return get_connector(ds)
                    finally:
                        db.close()

                connector = await asyncio.to_thread(_get_ds_connector)
            except Exception as e:
                logger.exception("Failed to load data source %r", source)
                return f"Failed to load data source: {e}"
            if not connector:
                return f"Data source not found: {source}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    return f"Failed to connect to {source}"

                timeout = self._SCHEMA_TIMEOUT_SECONDS
                try:
                    # NOTE: on timeout only the await is abandoned; the worker
                    # thread running get_schema cannot be interrupted.
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=timeout,
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning(
                        "Schema fetch for %r timed out after %g seconds", source, timeout
                    )
                    return f"Failed to fetch schema: Timeout after {timeout:g} seconds."
                except Exception as e:
                    logger.exception("Failed to fetch schema for %r", source)
                    return f"Failed to fetch schema: {e}"

        return self._format_schema(schema)

    @staticmethod
    def _format_schema(schema: Any) -> str:
        """Render a schema mapping as compact, LLM-readable text.

        Two per-table shapes are handled: a list of ``{"name", "type"}``
        column dicts (ClickHouse/upload format) and a dict with ``"columns"``,
        ``"primary_keys"`` and ``"foreign_keys"`` (Postgres format). Tables in
        any other shape are skipped.
        """

        def _columns_text(columns: Any) -> str:
            return ", ".join(f"{c['name']} ({c['type']})" for c in columns or [])

        lines = []
        for table_name, table_info in (schema or {}).items():
            if isinstance(table_info, list):
                lines.append(
                    f"Table: {table_name}\n  Columns: {_columns_text(table_info)}"
                )
            elif isinstance(table_info, dict):
                cols = _columns_text(table_info.get("columns"))
                pks = ", ".join(table_info.get("primary_keys") or [])
                entry = f"Table: {table_name}\n  Columns: {cols}\n  Primary Keys: {pks}"

                # Foreign keys tell the LLM how tables join; the line is only
                # added when there are any, to keep the output token-efficient.
                fk_parts = []
                for fk in table_info.get("foreign_keys") or []:
                    local_cols = ", ".join(map(str, fk.get("constrained_columns") or []))
                    remote_cols = ", ".join(map(str, fk.get("referred_columns") or []))
                    fk_parts.append(
                        f"{local_cols} -> {fk.get('referred_table')}({remote_cols})"
                    )
                if fk_parts:
                    entry += f"\n  Foreign Keys: {'; '.join(fk_parts)}"

                lines.append(entry)

        return "\n\n".join(lines) if lines else "No tables found in schema."
|
||||
@@ -74,14 +74,8 @@ class NL2SQLTool(Tool):
|
||||
# Call the core logic
|
||||
result = await process_nl2sql(request, on_progress=on_progress)
|
||||
|
||||
if result.error:
|
||||
return f"Error executing query: {result.error}"
|
||||
|
||||
# Save the result data to context for potential later use by VisualizationTool
|
||||
if result.result:
|
||||
current_data.set(result.result)
|
||||
|
||||
# Save visualization payload to context so the chat stream can pick it up
|
||||
# Always save visualization payload to context so the chat stream can pick it up
|
||||
# Even if there's an error, we want the frontend to see the generated SQL
|
||||
viz_payload = _build_sql_chart_viz(result)
|
||||
existing_viz = current_viz_data.get()
|
||||
if isinstance(existing_viz, dict):
|
||||
@@ -91,6 +85,13 @@ class NL2SQLTool(Tool):
|
||||
else:
|
||||
current_viz_data.set(viz_payload)
|
||||
|
||||
if result.error:
|
||||
return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"
|
||||
|
||||
# Save the result data to context for potential later use by VisualizationTool
|
||||
if result.result:
|
||||
current_data.set(result.result)
|
||||
|
||||
# Build a summary string for the Agent to read
|
||||
row_count = len(result.result) if result.result else 0
|
||||
|
||||
|
||||
@@ -246,6 +246,7 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
yield ": keep-alive\n\n"
|
||||
continue
|
||||
|
||||
response = await current_task
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
|
||||
|
||||
async def main() -> None:
    """Run one NL2SQL request against the "postgres" source and print the outcome.

    Sends a fixed natural-language query with chart generation disabled and
    prints the generated SQL, any error, and the result rows.
    """
    req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
    res = await process_nl2sql(req)
    print("SQL:", res.sql)
    print("Error:", res.error)
    print("Result:", res.result)


# Guarded so that importing this module does not trigger a live NL2SQL run.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user