diff --git a/backend/app/agent/chart.py b/backend/app/agent/chart.py index 969c608..3ab74cb 100644 --- a/backend/app/agent/chart.py +++ b/backend/app/agent/chart.py @@ -12,6 +12,10 @@ from nanobot.providers.litellm_provider import LiteLLMProvider from app.schemas.chart import ChartGenerationResponse from app.services.llm_cache import get_active_llm_config +CHART_MAX_TOKENS = 700 +CHART_TEMPERATURE = 0.2 +CHART_REASONING_EFFORT = "low" + CHART_INSTRUCTIONS = """ ### INSTRUCTIONS ### @@ -202,8 +206,6 @@ Question: {query} Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)} Sample Column Values: {columns} Language: Chinese (Simplified) - -Please think step by step """ messages = [ @@ -213,7 +215,12 @@ Please think step by step # 4. Call LLM try: - response = await provider.chat(messages=messages) + response = await provider.chat( + messages=messages, + max_tokens=CHART_MAX_TOKENS, + temperature=CHART_TEMPERATURE, + reasoning_effort=CHART_REASONING_EFFORT, + ) content = response.content # Clean up code blocks diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index a463ced..ee3eb8b 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -4,7 +4,7 @@ import json import time import threading from pathlib import Path -from typing import List, Optional, Dict, Any +from typing import List, Optional, Dict, Any, Callable, Awaitable from pydantic import BaseModel, Field import duckdb import pandas as pd @@ -30,6 +30,9 @@ SCHEMA_CACHE_TTL_SECONDS = 300 CONNECTION_CACHE_TTL_SECONDS = 30 UPLOAD_CACHE_TTL_SECONDS = 900 MAX_UPLOAD_CACHE_ITEMS = 8 +NL2SQL_MAX_TOKENS = 900 +NL2SQL_TEMPERATURE = 0.1 +NL2SQL_REASONING_EFFORT = "low" _schema_cache: Dict[str, Dict[str, Any]] = {} _connection_cache: Dict[str, Dict[str, Any]] = {} @@ -84,7 +87,7 @@ DEFAULT_TEXT_TO_SQL_RULES = """ SQL_GENERATION_SYSTEM_PROMPT = f""" You are a helpful assistant that converts natural language queries into ANSI SQL queries. -Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. +Given user's question and database schema, generate accurate ANSI SQL directly and concisely. ### GENERAL RULES ### @@ -195,7 +198,15 @@ def _check_connection_with_cache(source: str, connector: Any) -> bool: _connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS} return ok -async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: +async def process_nl2sql( + request: NL2SQLRequest, + on_progress: Callable[[str], Awaitable[None]] | None = None, +) -> NL2SQLResponse: + async def emit_progress(content: str) -> None: + if on_progress and content: + await on_progress(content) + + total_started = time.perf_counter() # 1. Get the connector and schema connector = None schema = {} @@ -207,13 +218,16 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: connector = clickhouse_connector elif request.source == "upload": try: + upload_started = time.perf_counter() upload_payload = _get_upload_payload(request.file_url) upload_df = upload_payload["df"] schema = upload_payload["schema"] + await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)") except Exception as e: return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}") elif request.source.startswith("ds:"): try: + ds_started = time.perf_counter() ds_id = int(request.source.split(":")[1]) db = SessionLocal() try: @@ -223,6 +237,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: connector = get_connector(ds) finally: db.close() + await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)") except ValueError: return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}") except Exception as e: @@ -231,21 +246,29 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}") if connector: + await emit_progress("正在检测数据源连通性") cached_schema = _get_cached_schema(request.source, connector) if cached_schema: schema = cached_schema + await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表") else: + conn_started = time.perf_counter() if not _check_connection_with_cache(request.source, connector): return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") + await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)") + schema_started = time.perf_counter() schema = connector.get_schema() _set_cached_schema(request.source, connector, schema) + await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)") if connector and not schema: + retry_started = time.perf_counter() # Double check in case schema was empty but connection is ok (e.g. empty db) if not _check_connection_with_cache(request.source, connector): return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") schema = connector.get_schema() _set_cached_schema(request.source, connector, schema) + await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)") schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":")) @@ -307,8 +330,6 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: ### INPUTS ### User's Question: {request.query} Language: Chinese (Simplified) - -Let's think step by step. """ messages = [ @@ -318,7 +339,14 @@ Let's think step by step. # 5. Call LLM try: - response = await provider.chat(messages=messages) + llm_started = time.perf_counter() + await emit_progress("正在生成 SQL") + response = await provider.chat( + messages=messages, + max_tokens=NL2SQL_MAX_TOKENS, + temperature=NL2SQL_TEMPERATURE, + reasoning_effort=NL2SQL_REASONING_EFFORT, + ) content = response.content.strip() # Clean up code blocks @@ -336,12 +364,15 @@ Let's think step by step. except json.JSONDecodeError: # Fallback if LLM doesn't return valid JSON despite instructions sql_query = content + await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") except Exception as e: return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") # 6. Execute SQL try: + sql_exec_started = time.perf_counter() + await emit_progress("正在执行 SQL 查询") if request.source == "upload": if upload_df is None: upload_df = _get_upload_payload(request.file_url)["df"] @@ -380,11 +411,16 @@ Let's think step by step. else: # Unknown format, try to return as is or empty formatted_results = [] + await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)") # 7. Generate Chart chart_response = None if request.generate_chart and formatted_results: + chart_started = time.perf_counter() + await emit_progress("正在生成可视化方案") chart_response = await generate_chart(formatted_results, request.query) + await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)") + await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) except Exception as e: diff --git a/backend/main.py b/backend/main.py index 391d963..2b17e70 100644 --- a/backend/main.py +++ b/backend/main.py @@ -250,14 +250,37 @@ async def nanobot_chat_stream(request: ChatRequest): use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" if use_nl2sql: - nl2sql_result = await process_nl2sql( - NL2SQLRequest( - query=request.message, - source=resolved_source, - file_url=request.file_url, - generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), + yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n" + sql_progress_queue: asyncio.Queue[str] = asyncio.Queue() + + async def _on_sql_progress(content: str) -> None: + if content: + await sql_progress_queue.put(content) + + sql_task = asyncio.create_task( + process_nl2sql( + NL2SQLRequest( + query=request.message, + source=resolved_source, + file_url=request.file_url, + generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), + ), + on_progress=_on_sql_progress, ) ) + while True: + if sql_task.done() and sql_progress_queue.empty(): + break + try: + progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2) + yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" + except asyncio.TimeoutError: + continue + nl2sql_result = await sql_task + if nl2sql_result.error: + yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n" + else: + yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n" persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = { "type": "viz", @@ -284,13 +307,14 @@ async def nanobot_chat_stream(request: ChatRequest): on_progress=_on_progress, ) ) + yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n" text = "" while True: if task.done() and progress_queue.empty(): break try: progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2) - yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" except asyncio.TimeoutError: continue response = await task diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index d9f397e..ea87d64 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -21,6 +21,8 @@ interface Message { content: string; awaitingFirstToken?: boolean; viz?: MessageViz; + progressLogs?: string[]; + routeInfo?: string; } interface MessageViz { @@ -403,9 +405,23 @@ export function ChatInterface() { id: assistantId, role: "assistant", content: "", - awaitingFirstToken: true + awaitingFirstToken: true, + progressLogs: ["请求已提交,准备路由..."], }]); + const pushProgressLog = (text: string) => { + if (!text.trim()) return; + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => { + if (msg.id !== assistantId) return msg; + const current = msg.progressLogs || []; + if (current[current.length - 1] === text) return msg; + const next = [...current, text].slice(-8); + return { ...msg, progressLogs: next }; + }) + ); + }; + const token = localStorage.getItem("token"); const effectiveModelId = selectedModelId || currentModel?.id || ""; @@ -497,6 +513,8 @@ export function ChatInterface() { sql?: string; result?: unknown; error?: string; + selected?: string; + reason?: string; chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null; }; @@ -505,13 +523,29 @@ export function ChatInterface() { flushAssistant(false); } + if (payload.type === "routing") { + const selected = payload.selected === "sql" ? "SQL 分析" : "通用对话"; + const reason = payload.reason ? `(${payload.reason})` : ""; + pushProgressLog(`路由:${selected}${reason}`); + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => + msg.id === assistantId ? { ...msg, routeInfo: `${selected}${reason}` } : msg + ) + ); + } + + if (payload.type === "progress" && payload.content) { + pushProgressLog(payload.content); + } + if (payload.type === "final" && payload.content) { hasFinalPayload = true; streamedText = payload.content; flushAssistant(true); + pushProgressLog("回答生成完成"); setMessagesForSession(targetSessionKey, (prev) => prev.map((msg) => - msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg + msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg ) ); } @@ -525,6 +559,7 @@ export function ChatInterface() { } if (payload.type === "viz") { + pushProgressLog("可视化结果已生成"); streamedViz = buildMessageViz(payload); setMessagesForSession(targetSessionKey, (prev) => prev.map((msg) => @@ -805,25 +840,42 @@ export function ChatInterface() { }`} > {msg.role === "assistant" ? ( - msg.awaitingFirstToken && !msg.content ? ( -