feat: add get schema tool

This commit is contained in:
qixinbo
2026-03-22 00:42:48 +08:00
parent b0c8f84db9
commit 0e7f275285
8 changed files with 228 additions and 20 deletions
+17 -3
View File
@@ -4,12 +4,15 @@ import os
import json import json
import time import time
import threading import threading
import logging
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any, Callable, Awaitable from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import duckdb import duckdb
import pandas as pd import pandas as pd
logger = logging.getLogger(__name__)
# Add project root to sys.path to allow importing nanobot # Add project root to sys.path to allow importing nanobot
PROJECT_ROOT = Path(__file__).resolve().parents[3] PROJECT_ROOT = Path(__file__).resolve().parents[3]
if str(PROJECT_ROOT) not in sys.path: if str(PROJECT_ROOT) not in sys.path:
@@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool:
try: try:
ok = await asyncio.wait_for( ok = await asyncio.wait_for(
asyncio.to_thread(connector.test_connection), asyncio.to_thread(connector.test_connection),
timeout=10.0 timeout=15.0
) )
except asyncio.TimeoutError:
print("Connection test failed or timed out: Timeout after 15 seconds")
ok = False
except Exception as e: except Exception as e:
print(f"Connection test failed or timed out: {e}") print(f"Connection test failed or timed out: {e}")
ok = False ok = False
@@ -300,8 +306,10 @@ async def process_nl2sql(
try: try:
schema = await asyncio.wait_for( schema = await asyncio.wait_for(
asyncio.to_thread(connector.get_schema), asyncio.to_thread(connector.get_schema),
timeout=30.0 timeout=120.0
) )
except asyncio.TimeoutError:
return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}") return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
@@ -449,7 +457,13 @@ Language: Chinese (Simplified)
# Fallback if LLM doesn't return valid JSON despite instructions # Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}")
# 格式化单行 SQL 用于在前端进度中展示
formatted_sql = sql_query.replace('\n', ' ')
if len(formatted_sql) > 150:
formatted_sql = formatted_sql[:147] + "..."
await emit_progress(f"SQL 生成完成: {formatted_sql}")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}") return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
+41 -2
View File
@@ -29,7 +29,44 @@ class PostgresConnector:
# Default schema for postgres is 'public', sqlite is None # Default schema for postgres is 'public', sqlite is None
schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None
for table_name in inspector.get_table_names(schema=schema_name): table_names = inspector.get_table_names(schema=schema_name)
# Use SQLAlchemy 2.0+ multi-fetch to avoid N+1 queries issue, especially over remote networks
if hasattr(inspector, 'get_multi_columns'):
multi_columns = inspector.get_multi_columns(schema=schema_name)
multi_pk = inspector.get_multi_pk_constraint(schema=schema_name)
multi_fk = inspector.get_multi_foreign_keys(schema=schema_name)
for table_name in table_names:
key = (schema_name, table_name)
columns = []
for col in multi_columns.get(key, []):
columns.append({
"name": col['name'],
"type": str(col['type'])
})
pk_constraint = multi_pk.get(key)
pks = pk_constraint.get('constrained_columns', []) if pk_constraint else []
foreign_keys = []
for fk in multi_fk.get(key, []):
foreign_keys.append({
"constrained_columns": fk['constrained_columns'],
"referred_table": fk['referred_table'],
"referred_columns": fk['referred_columns']
})
schema[table_name] = {
"columns": columns,
"primary_keys": pks,
"foreign_keys": foreign_keys
}
return schema
# Fallback for older SQLAlchemy versions
for table_name in table_names:
columns = [] columns = []
# get columns # get columns
for col in inspector.get_columns(table_name, schema=schema_name): for col in inspector.get_columns(table_name, schema=schema_name):
@@ -59,8 +96,10 @@ class PostgresConnector:
} }
return schema return schema
except Exception as e: except Exception as e:
import traceback
traceback.print_exc()
print(f"Error getting schema: {e}") print(f"Error getting schema: {e}")
return {} raise e
def test_connection(self) -> bool: def test_connection(self) -> bool:
try: try:
+2
View File
@@ -108,8 +108,10 @@ class NanobotIntegration:
def _register_custom_tools(self, agent: AgentLoop): def _register_custom_tools(self, agent: AgentLoop):
from app.tools.nl2sql import NL2SQLTool from app.tools.nl2sql import NL2SQLTool
from app.tools.visualization import VisualizationTool from app.tools.visualization import VisualizationTool
from app.tools.get_schema import GetDatabaseSchemaTool
agent.tools.register(NL2SQLTool()) agent.tools.register(NL2SQLTool())
agent.tools.register(VisualizationTool()) agent.tools.register(VisualizationTool())
agent.tools.register(GetDatabaseSchemaTool())
def _make_provider(self, config: Config): def _make_provider(self, config: Config):
# Logic adapted from nanobot/cli/commands.py # Logic adapted from nanobot/cli/commands.py
+131
View File
@@ -0,0 +1,131 @@
import json
import logging
from typing import Any
import asyncio
from nanobot.agent.tools.base import Tool
from app.context import current_data_source, current_file_url, current_progress_callback
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector
from app.database import SessionLocal
from app.models.datasource import DataSource
# Import schema logic from nl2sql
from app.agent.nl2sql import (
_get_cached_schema,
_set_cached_schema,
_check_connection_with_cache,
_get_upload_payload
)
logger = logging.getLogger(__name__)
def _describe_foreign_key(fk: dict) -> str:
    """Render one foreign-key entry as ``cols -> table(cols)``.

    Reads the ``constrained_columns`` / ``referred_table`` /
    ``referred_columns`` keys; missing keys are tolerated so a connector
    that omits one of them does not break the whole listing.
    """
    local_cols = ", ".join(fk.get("constrained_columns") or [])
    remote_cols = ", ".join(fk.get("referred_columns") or [])
    return f"{local_cols} -> {fk.get('referred_table')}({remote_cols})"


def _render_schema(schema: Any) -> str:
    """Turn a schema mapping into compact, LLM-readable text.

    Two per-table shapes are understood:

    * a list of ``{"name": ..., "type": ...}`` column dicts, and
    * a dict with ``"columns"``, ``"primary_keys"`` and (optionally)
      ``"foreign_keys"`` entries.

    Tables whose value has any other shape are skipped. An empty or
    missing schema yields the "no tables" message.
    """
    lines = []
    for table_name, table_info in (schema or {}).items():
        if isinstance(table_info, list):
            # Flat format: [{"name": "col", "type": "type"}]
            cols = ", ".join(f"{c['name']} ({c['type']})" for c in table_info)
            lines.append(f"Table: {table_name}\n Columns: {cols}")
        elif isinstance(table_info, dict):
            # Rich format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
            cols = ", ".join(
                f"{c['name']} ({c['type']})" for c in table_info.get("columns") or []
            )
            pks = ", ".join(table_info.get("primary_keys") or [])
            entry = f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}"
            fks = "; ".join(
                _describe_foreign_key(fk) for fk in table_info.get("foreign_keys") or []
            )
            # Only emitted when relationships exist, so tables without
            # foreign keys render exactly as before.
            if fks:
                entry += f"\n Foreign Keys: {fks}"
            lines.append(entry)
    return "\n\n".join(lines) if lines else "No tables found in schema."


class GetDatabaseSchemaTool(Tool):
    """
    Tool for fetching the database schema directly without SQL generation.

    The data source is taken from the ``current_data_source`` context
    variable; supported values are ``"postgres"``, ``"clickhouse"``,
    ``"upload"`` and ``"ds:<id>"``. All failures are reported to the agent
    as plain strings (never raised) and are additionally written to the
    module logger so they can be diagnosed server-side.
    """

    @property
    def name(self) -> str:
        """Name under which the tool is exposed to the agent."""
        return "get_database_schema"

    @property
    def description(self) -> str:
        """Natural-language hint telling the agent when to call this tool."""
        return (
            "Get the structural schema of the currently connected database or data source. "
            "Use this tool when the user asks questions about metadata, such as 'what tables are there', "
            "'show me the database structure', 'what are the columns in table X', etc. "
            "It directly returns the schema without generating SQL."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema style parameter spec; the tool takes no arguments."""
        return {
            "type": "object",
            "properties": {},
            "required": [],
        }

    async def execute(self, **kwargs: Any) -> str:
        """Fetch the active data source's schema and return it as text.

        Args:
            **kwargs: Ignored; accepted for compatibility with the tool
                calling convention.

        Returns:
            A formatted listing of tables, or a human-readable error
            message when the source is missing, unsupported, unreachable
            or the schema fetch fails / times out.
        """
        source = current_data_source.get()
        file_url = current_file_url.get()
        on_progress = current_progress_callback.get()

        async def emit_progress(msg: str) -> None:
            if on_progress:
                await on_progress(msg)

        await emit_progress("正在获取数据源结构...")

        connector = None
        schema = {}

        if not source:
            return "Error: No data source connected."

        if source == "postgres":
            connector = postgres_connector
        elif source == "clickhouse":
            connector = clickhouse_connector
        elif source == "upload":
            # Uploaded files carry their schema in the payload; no connector needed.
            try:
                payload = await asyncio.to_thread(_get_upload_payload, file_url)
                schema = payload["schema"]
                await emit_progress("文件 Schema 获取完成")
            except Exception as e:
                logger.exception("Failed to get upload schema for %r", file_url)
                return f"Failed to get upload schema: {e}"
        elif source.startswith("ds:"):
            try:
                ds_id = int(source.split(":")[1])

                def _get_ds_connector():
                    # Runs in a worker thread; the session is always closed.
                    db = SessionLocal()
                    try:
                        ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
                        if not ds:
                            return None
                        return get_connector(ds)
                    finally:
                        db.close()

                connector = await asyncio.to_thread(_get_ds_connector)
                if not connector:
                    return f"Data source not found: {source}"
            except Exception as e:
                logger.exception("Failed to load data source %r", source)
                return f"Failed to load data source: {e}"
        else:
            return f"Unsupported data source: {source}"

        if connector:
            cached_schema = _get_cached_schema(source, connector)
            if cached_schema is not None:
                schema = cached_schema
                await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
            else:
                if not await _check_connection_with_cache(source, connector):
                    logger.warning("Connection check failed for data source %r", source)
                    return f"Failed to connect to {source}"
                try:
                    # get_schema is blocking, so it runs in a thread under a deadline.
                    schema = await asyncio.wait_for(
                        asyncio.to_thread(connector.get_schema),
                        timeout=120.0
                    )
                    _set_cached_schema(source, connector, schema)
                    await emit_progress(f"成功获取 {len(schema)} 张表结构")
                except asyncio.TimeoutError:
                    logger.warning("Schema fetch timed out after 120s for %r", source)
                    return "Failed to fetch schema: Timeout after 120 seconds."
                except Exception as e:
                    logger.exception("Failed to fetch schema for %r", source)
                    return f"Failed to fetch schema: {e}"

        # Format the output for the LLM to make it readable and token-efficient
        return _render_schema(schema)
+9 -8
View File
@@ -74,14 +74,8 @@ class NL2SQLTool(Tool):
# Call the core logic # Call the core logic
result = await process_nl2sql(request, on_progress=on_progress) result = await process_nl2sql(request, on_progress=on_progress)
if result.error: # Always save visualization payload to context so the chat stream can pick it up
return f"Error executing query: {result.error}" # Even if there's an error, we want the frontend to see the generated SQL
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Save visualization payload to context so the chat stream can pick it up
viz_payload = _build_sql_chart_viz(result) viz_payload = _build_sql_chart_viz(result)
existing_viz = current_viz_data.get() existing_viz = current_viz_data.get()
if isinstance(existing_viz, dict): if isinstance(existing_viz, dict):
@@ -91,6 +85,13 @@ class NL2SQLTool(Tool):
else: else:
current_viz_data.set(viz_payload) current_viz_data.set(viz_payload)
if result.error:
return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"
# Save the result data to context for potential later use by VisualizationTool
if result.result:
current_data.set(result.result)
# Build a summary string for the Agent to read # Build a summary string for the Agent to read
row_count = len(result.result) if result.result else 0 row_count = len(result.result) if result.result else 0
+1
View File
@@ -246,6 +246,7 @@ async def nanobot_chat_stream(request: ChatRequest):
else: else:
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError: except asyncio.TimeoutError:
yield ": keep-alive\n\n"
continue continue
response = await current_task response = await current_task
+12
View File
@@ -0,0 +1,12 @@
import asyncio
import json
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
async def main() -> None:
    """Send one fixed NL2SQL request to the "postgres" source and print the outcome.

    Manual smoke check: prints the generated SQL, the error (if any) and
    the result rows returned by ``process_nl2sql``.
    """
    req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
    res = await process_nl2sql(req)
    print("SQL:", res.sql)
    print("Error:", res.error)
    print("Result:", res.result)


# Guarded so that merely importing this module (e.g. by a test collector or
# an IDE) does not fire a live request.
if __name__ == "__main__":
    asyncio.run(main())
+15 -7
View File
@@ -958,7 +958,8 @@ export function ChatInterface() {
</div> </div>
) : ( ) : (
<div className="max-w-3xl mx-auto px-4 py-8 space-y-8"> <div className="max-w-3xl mx-auto px-4 py-8 space-y-8">
{messages.map((msg) => { {messages.map((msg, msgIdx) => {
const isMessageGenerating = isLoading && msgIdx === messages.length - 1;
const { markdown, reportHtml } = splitReportHtml(msg.content); const { markdown, reportHtml } = splitReportHtml(msg.content);
const externalReportUrl = extractExternalReport(msg.content); const externalReportUrl = extractExternalReport(msg.content);
return ( return (
@@ -992,17 +993,24 @@ export function ChatInterface() {
{msg.progressLogs && msg.progressLogs.length > 0 ? ( {msg.progressLogs && msg.progressLogs.length > 0 ? (
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2"> <div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
<div className="flex items-center gap-2 text-zinc-500 text-xs mb-1.5 pb-1.5 border-b border-zinc-100/50"> <div className="flex items-center gap-2 text-zinc-500 text-xs mb-1.5 pb-1.5 border-b border-zinc-100/50">
{msg.awaitingFirstToken ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />} {isMessageGenerating ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
<span>{msg.awaitingFirstToken ? t('processing') : t('processCompleted')}</span> <span>{isMessageGenerating ? t('processing') : t('processCompleted')}</span>
</div> </div>
<div className="space-y-1.5 max-h-[160px] overflow-y-auto pr-1"> <div
className="space-y-1.5 max-h-[160px] overflow-y-auto pr-1"
ref={(el) => {
if (el && isMessageGenerating) {
el.scrollTop = el.scrollHeight;
}
}}
>
{msg.progressLogs.map((log, idx, arr) => { {msg.progressLogs.map((log, idx, arr) => {
const isLast = idx === arr.length - 1; const isLast = idx === arr.length - 1;
// 如果是正在处理的会话,且当前日志是最后一条,或者是明确包含“正在”的日志,则显示 loading // 只有当是整个会话的最后一条消息,且当前日志是最后一条时,才显示 loading 动画
const isLoadingLog = (isLast && msg.awaitingFirstToken) || log.includes(t('processingIndicator')); const isLoadingLog = isLast && isMessageGenerating;
return ( return (
<div key={`${msg.id}-log-${idx}`} className="flex items-start gap-2 text-[12px] text-zinc-500 leading-5"> <div key={`${msg.id}-log-${idx}`} className="flex items-start gap-2 text-[12px] text-zinc-500 leading-5">
{isLoadingLog && msg.awaitingFirstToken ? ( {isLoadingLog ? (
<Settings className="mt-0.5 h-3.5 w-3.5 text-amber-500 animate-spin shrink-0" /> <Settings className="mt-0.5 h-3.5 w-3.5 text-amber-500 animate-spin shrink-0" />
) : ( ) : (
<CheckCircle2 className="mt-0.5 h-3.5 w-3.5 text-emerald-500 shrink-0" /> <CheckCircle2 className="mt-0.5 h-3.5 w-3.5 text-emerald-500 shrink-0" />