feat: add get schema tool
This commit is contained in:
@@ -4,12 +4,15 @@ import os
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import duckdb
|
import duckdb
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Add project root to sys.path to allow importing nanobot
|
# Add project root to sys.path to allow importing nanobot
|
||||||
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
PROJECT_ROOT = Path(__file__).resolve().parents[3]
|
||||||
if str(PROJECT_ROOT) not in sys.path:
|
if str(PROJECT_ROOT) not in sys.path:
|
||||||
@@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
|||||||
try:
|
try:
|
||||||
ok = await asyncio.wait_for(
|
ok = await asyncio.wait_for(
|
||||||
asyncio.to_thread(connector.test_connection),
|
asyncio.to_thread(connector.test_connection),
|
||||||
timeout=10.0
|
timeout=15.0
|
||||||
)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
print("Connection test failed or timed out: Timeout after 15 seconds")
|
||||||
|
ok = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Connection test failed or timed out: {e}")
|
print(f"Connection test failed or timed out: {e}")
|
||||||
ok = False
|
ok = False
|
||||||
@@ -300,8 +306,10 @@ async def process_nl2sql(
|
|||||||
try:
|
try:
|
||||||
schema = await asyncio.wait_for(
|
schema = await asyncio.wait_for(
|
||||||
asyncio.to_thread(connector.get_schema),
|
asyncio.to_thread(connector.get_schema),
|
||||||
timeout=30.0
|
timeout=120.0
|
||||||
)
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}")
|
||||||
|
|
||||||
@@ -449,7 +457,13 @@ Language: Chinese (Simplified)
|
|||||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||||
sql_query = content
|
sql_query = content
|
||||||
|
|
||||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}")
|
||||||
|
|
||||||
|
# 格式化单行 SQL 用于在前端进度中展示
|
||||||
|
formatted_sql = sql_query.replace('\n', ' ')
|
||||||
|
if len(formatted_sql) > 150:
|
||||||
|
formatted_sql = formatted_sql[:147] + "..."
|
||||||
|
await emit_progress(f"SQL 生成完成: {formatted_sql}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
|
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
|
||||||
|
|||||||
@@ -29,7 +29,44 @@ class PostgresConnector:
|
|||||||
# Default schema for postgres is 'public', sqlite is None
|
# Default schema for postgres is 'public', sqlite is None
|
||||||
schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None
|
schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None
|
||||||
|
|
||||||
for table_name in inspector.get_table_names(schema=schema_name):
|
table_names = inspector.get_table_names(schema=schema_name)
|
||||||
|
|
||||||
|
# Use SQLAlchemy 2.0+ multi-fetch to avoid N+1 queries issue, especially over remote networks
|
||||||
|
if hasattr(inspector, 'get_multi_columns'):
|
||||||
|
multi_columns = inspector.get_multi_columns(schema=schema_name)
|
||||||
|
multi_pk = inspector.get_multi_pk_constraint(schema=schema_name)
|
||||||
|
multi_fk = inspector.get_multi_foreign_keys(schema=schema_name)
|
||||||
|
|
||||||
|
for table_name in table_names:
|
||||||
|
key = (schema_name, table_name)
|
||||||
|
|
||||||
|
columns = []
|
||||||
|
for col in multi_columns.get(key, []):
|
||||||
|
columns.append({
|
||||||
|
"name": col['name'],
|
||||||
|
"type": str(col['type'])
|
||||||
|
})
|
||||||
|
|
||||||
|
pk_constraint = multi_pk.get(key)
|
||||||
|
pks = pk_constraint.get('constrained_columns', []) if pk_constraint else []
|
||||||
|
|
||||||
|
foreign_keys = []
|
||||||
|
for fk in multi_fk.get(key, []):
|
||||||
|
foreign_keys.append({
|
||||||
|
"constrained_columns": fk['constrained_columns'],
|
||||||
|
"referred_table": fk['referred_table'],
|
||||||
|
"referred_columns": fk['referred_columns']
|
||||||
|
})
|
||||||
|
|
||||||
|
schema[table_name] = {
|
||||||
|
"columns": columns,
|
||||||
|
"primary_keys": pks,
|
||||||
|
"foreign_keys": foreign_keys
|
||||||
|
}
|
||||||
|
return schema
|
||||||
|
|
||||||
|
# Fallback for older SQLAlchemy versions
|
||||||
|
for table_name in table_names:
|
||||||
columns = []
|
columns = []
|
||||||
# get columns
|
# get columns
|
||||||
for col in inspector.get_columns(table_name, schema=schema_name):
|
for col in inspector.get_columns(table_name, schema=schema_name):
|
||||||
@@ -59,8 +96,10 @@ class PostgresConnector:
|
|||||||
}
|
}
|
||||||
return schema
|
return schema
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
print(f"Error getting schema: {e}")
|
print(f"Error getting schema: {e}")
|
||||||
return {}
|
raise e
|
||||||
|
|
||||||
def test_connection(self) -> bool:
|
def test_connection(self) -> bool:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -108,8 +108,10 @@ class NanobotIntegration:
|
|||||||
def _register_custom_tools(self, agent: AgentLoop):
|
def _register_custom_tools(self, agent: AgentLoop):
|
||||||
from app.tools.nl2sql import NL2SQLTool
|
from app.tools.nl2sql import NL2SQLTool
|
||||||
from app.tools.visualization import VisualizationTool
|
from app.tools.visualization import VisualizationTool
|
||||||
|
from app.tools.get_schema import GetDatabaseSchemaTool
|
||||||
agent.tools.register(NL2SQLTool())
|
agent.tools.register(NL2SQLTool())
|
||||||
agent.tools.register(VisualizationTool())
|
agent.tools.register(VisualizationTool())
|
||||||
|
agent.tools.register(GetDatabaseSchemaTool())
|
||||||
|
|
||||||
def _make_provider(self, config: Config):
|
def _make_provider(self, config: Config):
|
||||||
# Logic adapted from nanobot/cli/commands.py
|
# Logic adapted from nanobot/cli/commands.py
|
||||||
|
|||||||
@@ -0,0 +1,131 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from app.context import current_data_source, current_file_url, current_progress_callback
|
||||||
|
from app.connectors.postgres import postgres_connector
|
||||||
|
from app.connectors.clickhouse import clickhouse_connector
|
||||||
|
from app.connectors.factory import get_connector
|
||||||
|
from app.database import SessionLocal
|
||||||
|
from app.models.datasource import DataSource
|
||||||
|
|
||||||
|
# Import schema logic from nl2sql
|
||||||
|
from app.agent.nl2sql import (
|
||||||
|
_get_cached_schema,
|
||||||
|
_set_cached_schema,
|
||||||
|
_check_connection_with_cache,
|
||||||
|
_get_upload_payload
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class GetDatabaseSchemaTool(Tool):
|
||||||
|
"""
|
||||||
|
Tool for fetching the database schema directly without SQL generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "get_database_schema"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Get the structural schema of the currently connected database or data source. "
|
||||||
|
"Use this tool when the user asks questions about metadata, such as 'what tables are there', "
|
||||||
|
"'show me the database structure', 'what are the columns in table X', etc. "
|
||||||
|
"It directly returns the schema without generating SQL."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
"required": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
source = current_data_source.get()
|
||||||
|
file_url = current_file_url.get()
|
||||||
|
on_progress = current_progress_callback.get()
|
||||||
|
|
||||||
|
async def emit_progress(msg: str):
|
||||||
|
if on_progress:
|
||||||
|
await on_progress(msg)
|
||||||
|
|
||||||
|
await emit_progress("正在获取数据源结构...")
|
||||||
|
|
||||||
|
connector = None
|
||||||
|
schema = {}
|
||||||
|
|
||||||
|
if not source:
|
||||||
|
return "Error: No data source connected."
|
||||||
|
|
||||||
|
if source == "postgres":
|
||||||
|
connector = postgres_connector
|
||||||
|
elif source == "clickhouse":
|
||||||
|
connector = clickhouse_connector
|
||||||
|
elif source == "upload":
|
||||||
|
try:
|
||||||
|
payload = await asyncio.to_thread(_get_upload_payload, file_url)
|
||||||
|
schema = payload["schema"]
|
||||||
|
await emit_progress("文件 Schema 获取完成")
|
||||||
|
except Exception as e:
|
||||||
|
return f"Failed to get upload schema: {e}"
|
||||||
|
elif source.startswith("ds:"):
|
||||||
|
try:
|
||||||
|
ds_id = int(source.split(":")[1])
|
||||||
|
def _get_ds_connector():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
ds = db.query(DataSource).filter(DataSource.id == ds_id).first()
|
||||||
|
if not ds: return None
|
||||||
|
return get_connector(ds)
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
connector = await asyncio.to_thread(_get_ds_connector)
|
||||||
|
if not connector:
|
||||||
|
return f"Data source not found: {source}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Failed to load data source: {e}"
|
||||||
|
else:
|
||||||
|
return f"Unsupported data source: {source}"
|
||||||
|
|
||||||
|
if connector:
|
||||||
|
cached_schema = _get_cached_schema(source, connector)
|
||||||
|
if cached_schema is not None:
|
||||||
|
schema = cached_schema
|
||||||
|
await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构")
|
||||||
|
else:
|
||||||
|
if not await _check_connection_with_cache(source, connector):
|
||||||
|
return f"Failed to connect to {source}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
schema = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(connector.get_schema),
|
||||||
|
timeout=120.0
|
||||||
|
)
|
||||||
|
_set_cached_schema(source, connector, schema)
|
||||||
|
await emit_progress(f"成功获取 {len(schema)} 张表结构")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return "Failed to fetch schema: Timeout after 120 seconds."
|
||||||
|
except Exception as e:
|
||||||
|
return f"Failed to fetch schema: {e}"
|
||||||
|
|
||||||
|
# Format the output for the LLM to make it readable and token-efficient
|
||||||
|
lines = []
|
||||||
|
for table_name, table_info in schema.items():
|
||||||
|
if isinstance(table_info, list):
|
||||||
|
# Clickhouse/Upload format: [{"name": "col", "type": "type"}]
|
||||||
|
cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info])
|
||||||
|
lines.append(f"Table: {table_name}\n Columns: {cols}")
|
||||||
|
elif isinstance(table_info, dict):
|
||||||
|
# Postgres format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]}
|
||||||
|
cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info.get("columns", [])])
|
||||||
|
pks = ", ".join(table_info.get("primary_keys", []))
|
||||||
|
lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}")
|
||||||
|
|
||||||
|
return "\n\n".join(lines) if lines else "No tables found in schema."
|
||||||
@@ -74,14 +74,8 @@ class NL2SQLTool(Tool):
|
|||||||
# Call the core logic
|
# Call the core logic
|
||||||
result = await process_nl2sql(request, on_progress=on_progress)
|
result = await process_nl2sql(request, on_progress=on_progress)
|
||||||
|
|
||||||
if result.error:
|
# Always save visualization payload to context so the chat stream can pick it up
|
||||||
return f"Error executing query: {result.error}"
|
# Even if there's an error, we want the frontend to see the generated SQL
|
||||||
|
|
||||||
# Save the result data to context for potential later use by VisualizationTool
|
|
||||||
if result.result:
|
|
||||||
current_data.set(result.result)
|
|
||||||
|
|
||||||
# Save visualization payload to context so the chat stream can pick it up
|
|
||||||
viz_payload = _build_sql_chart_viz(result)
|
viz_payload = _build_sql_chart_viz(result)
|
||||||
existing_viz = current_viz_data.get()
|
existing_viz = current_viz_data.get()
|
||||||
if isinstance(existing_viz, dict):
|
if isinstance(existing_viz, dict):
|
||||||
@@ -91,6 +85,13 @@ class NL2SQLTool(Tool):
|
|||||||
else:
|
else:
|
||||||
current_viz_data.set(viz_payload)
|
current_viz_data.set(viz_payload)
|
||||||
|
|
||||||
|
if result.error:
|
||||||
|
return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}"
|
||||||
|
|
||||||
|
# Save the result data to context for potential later use by VisualizationTool
|
||||||
|
if result.result:
|
||||||
|
current_data.set(result.result)
|
||||||
|
|
||||||
# Build a summary string for the Agent to read
|
# Build a summary string for the Agent to read
|
||||||
row_count = len(result.result) if result.result else 0
|
row_count = len(result.result) if result.result else 0
|
||||||
|
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ async def nanobot_chat_stream(request: ChatRequest):
|
|||||||
else:
|
else:
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
yield ": keep-alive\n\n"
|
||||||
continue
|
continue
|
||||||
|
|
||||||
response = await current_task
|
response = await current_task
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
|
||||||
|
res = await process_nl2sql(req)
|
||||||
|
print("SQL:", res.sql)
|
||||||
|
print("Error:", res.error)
|
||||||
|
print("Result:", res.result)
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
@@ -958,7 +958,8 @@ export function ChatInterface() {
|
|||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<div className="max-w-3xl mx-auto px-4 py-8 space-y-8">
|
<div className="max-w-3xl mx-auto px-4 py-8 space-y-8">
|
||||||
{messages.map((msg) => {
|
{messages.map((msg, msgIdx) => {
|
||||||
|
const isMessageGenerating = isLoading && msgIdx === messages.length - 1;
|
||||||
const { markdown, reportHtml } = splitReportHtml(msg.content);
|
const { markdown, reportHtml } = splitReportHtml(msg.content);
|
||||||
const externalReportUrl = extractExternalReport(msg.content);
|
const externalReportUrl = extractExternalReport(msg.content);
|
||||||
return (
|
return (
|
||||||
@@ -992,17 +993,24 @@ export function ChatInterface() {
|
|||||||
{msg.progressLogs && msg.progressLogs.length > 0 ? (
|
{msg.progressLogs && msg.progressLogs.length > 0 ? (
|
||||||
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
|
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
|
||||||
<div className="flex items-center gap-2 text-zinc-500 text-xs mb-1.5 pb-1.5 border-b border-zinc-100/50">
|
<div className="flex items-center gap-2 text-zinc-500 text-xs mb-1.5 pb-1.5 border-b border-zinc-100/50">
|
||||||
{msg.awaitingFirstToken ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
|
{isMessageGenerating ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
|
||||||
<span>{msg.awaitingFirstToken ? t('processing') : t('processCompleted')}</span>
|
<span>{isMessageGenerating ? t('processing') : t('processCompleted')}</span>
|
||||||
</div>
|
</div>
|
||||||
<div className="space-y-1.5 max-h-[160px] overflow-y-auto pr-1">
|
<div
|
||||||
|
className="space-y-1.5 max-h-[160px] overflow-y-auto pr-1"
|
||||||
|
ref={(el) => {
|
||||||
|
if (el && isMessageGenerating) {
|
||||||
|
el.scrollTop = el.scrollHeight;
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
{msg.progressLogs.map((log, idx, arr) => {
|
{msg.progressLogs.map((log, idx, arr) => {
|
||||||
const isLast = idx === arr.length - 1;
|
const isLast = idx === arr.length - 1;
|
||||||
// 如果是正在处理的会话,且当前日志是最后一条,或者是明确包含“正在”的日志,则显示 loading
|
// 只有当是整个会话的最后一条消息,且当前日志是最后一条时,才显示 loading 动画
|
||||||
const isLoadingLog = (isLast && msg.awaitingFirstToken) || log.includes(t('processingIndicator'));
|
const isLoadingLog = isLast && isMessageGenerating;
|
||||||
return (
|
return (
|
||||||
<div key={`${msg.id}-log-${idx}`} className="flex items-start gap-2 text-[12px] text-zinc-500 leading-5">
|
<div key={`${msg.id}-log-${idx}`} className="flex items-start gap-2 text-[12px] text-zinc-500 leading-5">
|
||||||
{isLoadingLog && msg.awaitingFirstToken ? (
|
{isLoadingLog ? (
|
||||||
<Settings className="mt-0.5 h-3.5 w-3.5 text-amber-500 animate-spin shrink-0" />
|
<Settings className="mt-0.5 h-3.5 w-3.5 text-amber-500 animate-spin shrink-0" />
|
||||||
) : (
|
) : (
|
||||||
<CheckCircle2 className="mt-0.5 h-3.5 w-3.5 text-emerald-500 shrink-0" />
|
<CheckCircle2 className="mt-0.5 h-3.5 w-3.5 text-emerald-500 shrink-0" />
|
||||||
|
|||||||
Reference in New Issue
Block a user