speed acc

This commit is contained in:
qixinbo
2026-03-17 21:32:01 +08:00
parent c51f51ff69
commit 49d38692cd
4 changed files with 155 additions and 36 deletions
+10 -3
View File
@@ -12,6 +12,10 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from app.schemas.chart import ChartGenerationResponse from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config from app.services.llm_cache import get_active_llm_config
CHART_MAX_TOKENS = 700
CHART_TEMPERATURE = 0.2
CHART_REASONING_EFFORT = "low"
CHART_INSTRUCTIONS = """ CHART_INSTRUCTIONS = """
### INSTRUCTIONS ### ### INSTRUCTIONS ###
@@ -202,8 +206,6 @@ Question: {query}
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)} Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
Sample Column Values: {columns} Sample Column Values: {columns}
Language: Chinese (Simplified) Language: Chinese (Simplified)
Please think step by step
""" """
messages = [ messages = [
@@ -213,7 +215,12 @@ Please think step by step
# 4. Call LLM # 4. Call LLM
try: try:
response = await provider.chat(messages=messages) response = await provider.chat(
messages=messages,
max_tokens=CHART_MAX_TOKENS,
temperature=CHART_TEMPERATURE,
reasoning_effort=CHART_REASONING_EFFORT,
)
content = response.content content = response.content
# Clean up code blocks # Clean up code blocks
+42 -6
View File
@@ -4,7 +4,7 @@ import json
import time import time
import threading import threading
from pathlib import Path from pathlib import Path
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import duckdb import duckdb
import pandas as pd import pandas as pd
@@ -30,6 +30,9 @@ SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30 CONNECTION_CACHE_TTL_SECONDS = 30
UPLOAD_CACHE_TTL_SECONDS = 900 UPLOAD_CACHE_TTL_SECONDS = 900
MAX_UPLOAD_CACHE_ITEMS = 8 MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
_schema_cache: Dict[str, Dict[str, Any]] = {} _schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {} _connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -84,7 +87,7 @@ DEFAULT_TEXT_TO_SQL_RULES = """
SQL_GENERATION_SYSTEM_PROMPT = f""" SQL_GENERATION_SYSTEM_PROMPT = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries. You are a helpful assistant that converts natural language queries into ANSI SQL queries.
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step. Given user's question and database schema, generate accurate ANSI SQL directly and concisely.
### GENERAL RULES ### ### GENERAL RULES ###
@@ -195,7 +198,15 @@ def _check_connection_with_cache(source: str, connector: Any) -> bool:
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS} _connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
return ok return ok
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse: async def process_nl2sql(
request: NL2SQLRequest,
on_progress: Callable[[str], Awaitable[None]] | None = None,
) -> NL2SQLResponse:
async def emit_progress(content: str) -> None:
if on_progress and content:
await on_progress(content)
total_started = time.perf_counter()
# 1. Get the connector and schema # 1. Get the connector and schema
connector = None connector = None
schema = {} schema = {}
@@ -207,13 +218,16 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = clickhouse_connector connector = clickhouse_connector
elif request.source == "upload": elif request.source == "upload":
try: try:
upload_started = time.perf_counter()
upload_payload = _get_upload_payload(request.file_url) upload_payload = _get_upload_payload(request.file_url)
upload_df = upload_payload["df"] upload_df = upload_payload["df"]
schema = upload_payload["schema"] schema = upload_payload["schema"]
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}") return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
elif request.source.startswith("ds:"): elif request.source.startswith("ds:"):
try: try:
ds_started = time.perf_counter()
ds_id = int(request.source.split(":")[1]) ds_id = int(request.source.split(":")[1])
db = SessionLocal() db = SessionLocal()
try: try:
@@ -223,6 +237,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = get_connector(ds) connector = get_connector(ds)
finally: finally:
db.close() db.close()
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
except ValueError: except ValueError:
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
except Exception as e: except Exception as e:
@@ -231,21 +246,29 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector: if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector) cached_schema = _get_cached_schema(request.source, connector)
if cached_schema: if cached_schema:
schema = cached_schema schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else: else:
conn_started = time.perf_counter()
if not _check_connection_with_cache(request.source, connector): if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
schema_started = time.perf_counter()
schema = connector.get_schema() schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema) _set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema: if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db) # Double check in case schema was empty but connection is ok (e.g. empty db)
if not _check_connection_with_cache(request.source, connector): if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema() schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema) _set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":")) schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -307,8 +330,6 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
### INPUTS ### ### INPUTS ###
User's Question: {request.query} User's Question: {request.query}
Language: Chinese (Simplified) Language: Chinese (Simplified)
Let's think step by step.
""" """
messages = [ messages = [
@@ -318,7 +339,14 @@ Let's think step by step.
# 5. Call LLM # 5. Call LLM
try: try:
response = await provider.chat(messages=messages) llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
)
content = response.content.strip() content = response.content.strip()
# Clean up code blocks # Clean up code blocks
@@ -336,12 +364,15 @@ Let's think step by step.
except json.JSONDecodeError: except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions # Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except Exception as e: except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL # 6. Execute SQL
try: try:
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload": if request.source == "upload":
if upload_df is None: if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"] upload_df = _get_upload_payload(request.file_url)["df"]
@@ -380,11 +411,16 @@ Let's think step by step.
else: else:
# Unknown format, try to return as is or empty # Unknown format, try to return as is or empty
formatted_results = [] formatted_results = []
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
# 7. Generate Chart # 7. Generate Chart
chart_response = None chart_response = None
if request.generate_chart and formatted_results: if request.generate_chart and formatted_results:
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
chart_response = await generate_chart(formatted_results, request.query) chart_response = await generate_chart(formatted_results, request.query)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except Exception as e: except Exception as e:
+31 -7
View File
@@ -250,14 +250,37 @@ async def nanobot_chat_stream(request: ChatRequest):
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
if use_nl2sql: if use_nl2sql:
nl2sql_result = await process_nl2sql( yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n"
NL2SQLRequest( sql_progress_queue: asyncio.Queue[str] = asyncio.Queue()
query=request.message,
source=resolved_source, async def _on_sql_progress(content: str) -> None:
file_url=request.file_url, if content:
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), await sql_progress_queue.put(content)
sql_task = asyncio.create_task(
process_nl2sql(
NL2SQLRequest(
query=request.message,
source=resolved_source,
file_url=request.file_url,
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
),
on_progress=_on_sql_progress,
) )
) )
while True:
if sql_task.done() and sql_progress_queue.empty():
break
try:
progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
continue
nl2sql_result = await sql_task
if nl2sql_result.error:
yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n"
else:
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
viz_payload = { viz_payload = {
"type": "viz", "type": "viz",
@@ -284,13 +307,14 @@ async def nanobot_chat_stream(request: ChatRequest):
on_progress=_on_progress, on_progress=_on_progress,
) )
) )
yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n"
text = "" text = ""
while True: while True:
if task.done() and progress_queue.empty(): if task.done() and progress_queue.empty():
break break
try: try:
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2) progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
response = await task response = await task
+72 -20
View File
@@ -21,6 +21,8 @@ interface Message {
content: string; content: string;
awaitingFirstToken?: boolean; awaitingFirstToken?: boolean;
viz?: MessageViz; viz?: MessageViz;
progressLogs?: string[];
routeInfo?: string;
} }
interface MessageViz { interface MessageViz {
@@ -403,9 +405,23 @@ export function ChatInterface() {
id: assistantId, id: assistantId,
role: "assistant", role: "assistant",
content: "", content: "",
awaitingFirstToken: true awaitingFirstToken: true,
progressLogs: ["请求已提交,准备路由..."],
}]); }]);
const pushProgressLog = (text: string) => {
if (!text.trim()) return;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => {
if (msg.id !== assistantId) return msg;
const current = msg.progressLogs || [];
if (current[current.length - 1] === text) return msg;
const next = [...current, text].slice(-8);
return { ...msg, progressLogs: next };
})
);
};
const token = localStorage.getItem("token"); const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || ""; const effectiveModelId = selectedModelId || currentModel?.id || "";
@@ -497,6 +513,8 @@ export function ChatInterface() {
sql?: string; sql?: string;
result?: unknown; result?: unknown;
error?: string; error?: string;
selected?: string;
reason?: string;
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null; chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
}; };
@@ -505,13 +523,29 @@ export function ChatInterface() {
flushAssistant(false); flushAssistant(false);
} }
if (payload.type === "routing") {
const selected = payload.selected === "sql" ? "SQL 分析" : "通用对话";
const reason = payload.reason ? `${payload.reason}` : "";
pushProgressLog(`路由:${selected}${reason}`);
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, routeInfo: `${selected}${reason}` } : msg
)
);
}
if (payload.type === "progress" && payload.content) {
pushProgressLog(payload.content);
}
if (payload.type === "final" && payload.content) { if (payload.type === "final" && payload.content) {
hasFinalPayload = true; hasFinalPayload = true;
streamedText = payload.content; streamedText = payload.content;
flushAssistant(true); flushAssistant(true);
pushProgressLog("回答生成完成");
setMessagesForSession(targetSessionKey, (prev) => setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
) )
); );
} }
@@ -525,6 +559,7 @@ export function ChatInterface() {
} }
if (payload.type === "viz") { if (payload.type === "viz") {
pushProgressLog("可视化结果已生成");
streamedViz = buildMessageViz(payload); streamedViz = buildMessageViz(payload);
setMessagesForSession(targetSessionKey, (prev) => setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => prev.map((msg) =>
@@ -805,25 +840,42 @@ export function ChatInterface() {
}`} }`}
> >
{msg.role === "assistant" ? ( {msg.role === "assistant" ? (
msg.awaitingFirstToken && !msg.content ? ( <>
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1"> {msg.progressLogs && msg.progressLogs.length > 0 ? (
<Loader2 className="h-4 w-4 animate-spin" /> <div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
<span>...</span> <div className="flex items-center gap-2 text-zinc-500 text-xs">
</div> {msg.awaitingFirstToken ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
) : ( <span>{msg.awaitingFirstToken ? "正在处理中" : "处理完成"}</span>
<>
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
{msg.content}
</ReactMarkdown>
</div>
{msg.viz ? (
<div className="mt-3 pt-3 border-t border-zinc-100">
<InlineVisualizationCard viz={msg.viz} />
</div> </div>
) : null} <div className="mt-1.5 space-y-1">
</> {msg.progressLogs.map((log, idx) => (
) <div key={`${msg.id}-log-${idx}`} className="text-[12px] text-zinc-500 leading-5">
{idx + 1}. {log}
</div>
))}
</div>
</div>
) : null}
{msg.awaitingFirstToken && !msg.content ? (
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1">
<Loader2 className="h-4 w-4 animate-spin" />
<span>...</span>
</div>
) : (
<>
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
{msg.content}
</ReactMarkdown>
</div>
{msg.viz ? (
<div className="mt-3 pt-3 border-t border-zinc-100">
<InlineVisualizationCard viz={msg.viz} />
</div>
) : null}
</>
)}
</>
) : ( ) : (
msg.content msg.content
)} )}