diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 3f7aa04..2c6425c 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -4,12 +4,15 @@ import os import json import time import threading +import logging from pathlib import Path from typing import List, Optional, Dict, Any, Callable, Awaitable from pydantic import BaseModel, Field import duckdb import pandas as pd +logger = logging.getLogger(__name__) + # Add project root to sys.path to allow importing nanobot PROJECT_ROOT = Path(__file__).resolve().parents[3] if str(PROJECT_ROOT) not in sys.path: @@ -221,8 +224,11 @@ async def _check_connection_with_cache(source: str, connector: Any) -> bool: try: ok = await asyncio.wait_for( asyncio.to_thread(connector.test_connection), - timeout=10.0 + timeout=15.0 ) + except asyncio.TimeoutError: + print("Connection test failed or timed out: Timeout after 15 seconds") + ok = False except Exception as e: print(f"Connection test failed or timed out: {e}") ok = False @@ -300,8 +306,10 @@ async def process_nl2sql( try: schema = await asyncio.wait_for( asyncio.to_thread(connector.get_schema), - timeout=30.0 + timeout=120.0 ) + except asyncio.TimeoutError: + return NL2SQLResponse(sql="", result=[], error="Failed to fetch schema: Timeout after 120 seconds. Data source might be too large or network is slow.") except Exception as e: return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}") @@ -449,7 +457,13 @@ Language: Chinese (Simplified) # Fallback if LLM doesn't return valid JSON despite instructions sql_query = content - await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") + logger.info(f"Generated SQL for query '{request.query}':\n{sql_query}") + + # 格式化单行 SQL 用于在前端进度中展示 + formatted_sql = sql_query.replace('\n', ' ') + if len(formatted_sql) > 150: + formatted_sql = formatted_sql[:147] + "..." + await emit_progress(f"SQL 生成完成: {formatted_sql}") except Exception as e: return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}") diff --git a/backend/app/connectors/postgres.py b/backend/app/connectors/postgres.py index 1f00ef3..cec1bc2 100644 --- a/backend/app/connectors/postgres.py +++ b/backend/app/connectors/postgres.py @@ -29,7 +29,44 @@ class PostgresConnector: # Default schema for postgres is 'public', sqlite is None schema_name = 'public' if self.engine.dialect.name == 'postgresql' else None - for table_name in inspector.get_table_names(schema=schema_name): + table_names = inspector.get_table_names(schema=schema_name) + + # Use SQLAlchemy 2.0+ multi-fetch to avoid N+1 queries issue, especially over remote networks + if hasattr(inspector, 'get_multi_columns'): + multi_columns = inspector.get_multi_columns(schema=schema_name) + multi_pk = inspector.get_multi_pk_constraint(schema=schema_name) + multi_fk = inspector.get_multi_foreign_keys(schema=schema_name) + + for table_name in table_names: + key = (schema_name, table_name) + + columns = [] + for col in multi_columns.get(key, []): + columns.append({ + "name": col['name'], + "type": str(col['type']) + }) + + pk_constraint = multi_pk.get(key) + pks = pk_constraint.get('constrained_columns', []) if pk_constraint else [] + + foreign_keys = [] + for fk in multi_fk.get(key, []): + foreign_keys.append({ + "constrained_columns": fk['constrained_columns'], + "referred_table": fk['referred_table'], + "referred_columns": fk['referred_columns'] + }) + + schema[table_name] = { + "columns": columns, + "primary_keys": pks, + "foreign_keys": foreign_keys + } + return schema + + # Fallback for older SQLAlchemy versions + for table_name in table_names: columns = [] # get columns for col in inspector.get_columns(table_name, schema=schema_name): @@ -59,8 +96,10 @@ class PostgresConnector: } return schema except Exception as e: + import traceback + traceback.print_exc() print(f"Error getting schema: {e}") - return {} + raise e def test_connection(self) -> bool: try: diff --git a/backend/app/core/nanobot.py b/backend/app/core/nanobot.py index 659a6a8..04d4338 100644 --- a/backend/app/core/nanobot.py +++ b/backend/app/core/nanobot.py @@ -108,8 +108,10 @@ class NanobotIntegration: def _register_custom_tools(self, agent: AgentLoop): from app.tools.nl2sql import NL2SQLTool from app.tools.visualization import VisualizationTool + from app.tools.get_schema import GetDatabaseSchemaTool agent.tools.register(NL2SQLTool()) agent.tools.register(VisualizationTool()) + agent.tools.register(GetDatabaseSchemaTool()) def _make_provider(self, config: Config): # Logic adapted from nanobot/cli/commands.py diff --git a/backend/app/tools/get_schema.py b/backend/app/tools/get_schema.py new file mode 100644 index 0000000..31b1ed5 --- /dev/null +++ b/backend/app/tools/get_schema.py @@ -0,0 +1,131 @@ +import json +import logging +from typing import Any +import asyncio + +from nanobot.agent.tools.base import Tool +from app.context import current_data_source, current_file_url, current_progress_callback +from app.connectors.postgres import postgres_connector +from app.connectors.clickhouse import clickhouse_connector +from app.connectors.factory import get_connector +from app.database import SessionLocal +from app.models.datasource import DataSource + +# Import schema logic from nl2sql +from app.agent.nl2sql import ( + _get_cached_schema, + _set_cached_schema, + _check_connection_with_cache, + _get_upload_payload +) + +logger = logging.getLogger(__name__) + +class GetDatabaseSchemaTool(Tool): + """ + Tool for fetching the database schema directly without SQL generation. + """ + + @property + def name(self) -> str: + return "get_database_schema" + + @property + def description(self) -> str: + return ( + "Get the structural schema of the currently connected database or data source. " + "Use this tool when the user asks questions about metadata, such as 'what tables are there', " + "'show me the database structure', 'what are the columns in table X', etc. " + "It directly returns the schema without generating SQL." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": {}, + "required": [], + } + + async def execute(self, **kwargs: Any) -> str: + source = current_data_source.get() + file_url = current_file_url.get() + on_progress = current_progress_callback.get() + + async def emit_progress(msg: str): + if on_progress: + await on_progress(msg) + + await emit_progress("正在获取数据源结构...") + + connector = None + schema = {} + + if not source: + return "Error: No data source connected." + + if source == "postgres": + connector = postgres_connector + elif source == "clickhouse": + connector = clickhouse_connector + elif source == "upload": + try: + payload = await asyncio.to_thread(_get_upload_payload, file_url) + schema = payload["schema"] + await emit_progress("文件 Schema 获取完成") + except Exception as e: + return f"Failed to get upload schema: {e}" + elif source.startswith("ds:"): + try: + ds_id = int(source.split(":")[1]) + def _get_ds_connector(): + db = SessionLocal() + try: + ds = db.query(DataSource).filter(DataSource.id == ds_id).first() + if not ds: return None + return get_connector(ds) + finally: + db.close() + connector = await asyncio.to_thread(_get_ds_connector) + if not connector: + return f"Data source not found: {source}" + except Exception as e: + return f"Failed to load data source: {e}" + else: + return f"Unsupported data source: {source}" + + if connector: + cached_schema = _get_cached_schema(source, connector) + if cached_schema is not None: + schema = cached_schema + await emit_progress(f"命中缓存,成功获取 {len(schema)} 张表结构") + else: + if not await _check_connection_with_cache(source, connector): + return f"Failed to connect to {source}" + + try: + schema = await asyncio.wait_for( + asyncio.to_thread(connector.get_schema), + timeout=120.0 + ) + _set_cached_schema(source, connector, schema) + await emit_progress(f"成功获取 {len(schema)} 张表结构") + except asyncio.TimeoutError: + return "Failed to fetch schema: Timeout after 120 seconds." + except Exception as e: + return f"Failed to fetch schema: {e}" + + # Format the output for the LLM to make it readable and token-efficient + lines = [] + for table_name, table_info in schema.items(): + if isinstance(table_info, list): + # Clickhouse/Upload format: [{"name": "col", "type": "type"}] + cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info]) + lines.append(f"Table: {table_name}\n Columns: {cols}") + elif isinstance(table_info, dict): + # Postgres format: {"columns": [...], "primary_keys": [...], "foreign_keys": [...]} + cols = ", ".join([f"{c['name']} ({c['type']})" for c in table_info.get("columns", [])]) + pks = ", ".join(table_info.get("primary_keys", [])) + lines.append(f"Table: {table_name}\n Columns: {cols}\n Primary Keys: {pks}") + + return "\n\n".join(lines) if lines else "No tables found in schema." diff --git a/backend/app/tools/nl2sql.py b/backend/app/tools/nl2sql.py index 4036721..aab0108 100644 --- a/backend/app/tools/nl2sql.py +++ b/backend/app/tools/nl2sql.py @@ -74,14 +74,8 @@ class NL2SQLTool(Tool): # Call the core logic result = await process_nl2sql(request, on_progress=on_progress) - if result.error: - return f"Error executing query: {result.error}" - - # Save the result data to context for potential later use by VisualizationTool - if result.result: - current_data.set(result.result) - - # Save visualization payload to context so the chat stream can pick it up + # Always save visualization payload to context so the chat stream can pick it up + # Even if there's an error, we want the frontend to see the generated SQL viz_payload = _build_sql_chart_viz(result) existing_viz = current_viz_data.get() if isinstance(existing_viz, dict): @@ -91,6 +85,13 @@ class NL2SQLTool(Tool): else: current_viz_data.set(viz_payload) + if result.error: + return f"Error executing query: {result.error}\nGenerated SQL: {result.sql}" + + # Save the result data to context for potential later use by VisualizationTool + if result.result: + current_data.set(result.result) + # Build a summary string for the Agent to read row_count = len(result.result) if result.result else 0 diff --git a/backend/main.py b/backend/main.py index 4ce7673..00fcec7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -246,6 +246,7 @@ async def nanobot_chat_stream(request: ChatRequest): else: yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" except asyncio.TimeoutError: + yield ": keep-alive\n\n" continue response = await current_task diff --git a/backend/test_nl2sql.py b/backend/test_nl2sql.py new file mode 100644 index 0000000..e3d3051 --- /dev/null +++ b/backend/test_nl2sql.py @@ -0,0 +1,12 @@ +import asyncio +import json +from app.agent.nl2sql import process_nl2sql, NL2SQLRequest + +async def main(): + req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False) + res = await process_nl2sql(req) + print("SQL:", res.sql) + print("Error:", res.error) + print("Result:", res.result) + +asyncio.run(main()) diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index ae76dc7..6f7d175 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -958,7 +958,8 @@ export function ChatInterface() { ) : (
- {messages.map((msg) => { + {messages.map((msg, msgIdx) => { + const isMessageGenerating = isLoading && msgIdx === messages.length - 1; const { markdown, reportHtml } = splitReportHtml(msg.content); const externalReportUrl = extractExternalReport(msg.content); return ( @@ -992,17 +993,24 @@ export function ChatInterface() { {msg.progressLogs && msg.progressLogs.length > 0 ? (
- {msg.awaitingFirstToken ? : } - {msg.awaitingFirstToken ? t('processing') : t('processCompleted')} + {isMessageGenerating ? : } + {isMessageGenerating ? t('processing') : t('processCompleted')}
-
+
{ + if (el && isMessageGenerating) { + el.scrollTop = el.scrollHeight; + } + }} + > {msg.progressLogs.map((log, idx, arr) => { const isLast = idx === arr.length - 1; - // 如果是正在处理的会话,且当前日志是最后一条,或者是明确包含“正在”的日志,则显示 loading - const isLoadingLog = (isLast && msg.awaitingFirstToken) || log.includes(t('processingIndicator')); + // 只有当是整个会话的最后一条消息,且当前日志是最后一条时,才显示 loading 动画 + const isLoadingLog = isLast && isMessageGenerating; return (
- {isLoadingLog && msg.awaitingFirstToken ? ( + {isLoadingLog ? ( ) : (