speed acc

This commit is contained in:
qixinbo
2026-03-17 21:32:01 +08:00
parent c51f51ff69
commit 49d38692cd
4 changed files with 155 additions and 36 deletions
+10 -3
View File
@@ -12,6 +12,10 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config
CHART_MAX_TOKENS = 700
CHART_TEMPERATURE = 0.2
CHART_REASONING_EFFORT = "low"
CHART_INSTRUCTIONS = """
### INSTRUCTIONS ###
@@ -202,8 +206,6 @@ Question: {query}
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
Sample Column Values: {columns}
Language: Chinese (Simplified)
Please think step by step
"""
messages = [
@@ -213,7 +215,12 @@ Please think step by step
# 4. Call LLM
try:
response = await provider.chat(messages=messages)
response = await provider.chat(
messages=messages,
max_tokens=CHART_MAX_TOKENS,
temperature=CHART_TEMPERATURE,
reasoning_effort=CHART_REASONING_EFFORT,
)
content = response.content
# Clean up code blocks
+42 -6
View File
@@ -4,7 +4,7 @@ import json
import time
import threading
from pathlib import Path
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
import duckdb
import pandas as pd
@@ -30,6 +30,9 @@ SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
UPLOAD_CACHE_TTL_SECONDS = 900
MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -84,7 +87,7 @@ DEFAULT_TEXT_TO_SQL_RULES = """
SQL_GENERATION_SYSTEM_PROMPT = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
Given user's question and database schema, generate accurate ANSI SQL directly and concisely.
### GENERAL RULES ###
@@ -195,7 +198,15 @@ def _check_connection_with_cache(source: str, connector: Any) -> bool:
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
return ok
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
async def process_nl2sql(
request: NL2SQLRequest,
on_progress: Callable[[str], Awaitable[None]] | None = None,
) -> NL2SQLResponse:
async def emit_progress(content: str) -> None:
if on_progress and content:
await on_progress(content)
total_started = time.perf_counter()
# 1. Get the connector and schema
connector = None
schema = {}
@@ -207,13 +218,16 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = clickhouse_connector
elif request.source == "upload":
try:
upload_started = time.perf_counter()
upload_payload = _get_upload_payload(request.file_url)
upload_df = upload_payload["df"]
schema = upload_payload["schema"]
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
elif request.source.startswith("ds:"):
try:
ds_started = time.perf_counter()
ds_id = int(request.source.split(":")[1])
db = SessionLocal()
try:
@@ -223,6 +237,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = get_connector(ds)
finally:
db.close()
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
except ValueError:
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
except Exception as e:
@@ -231,21 +246,29 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else:
conn_started = time.perf_counter()
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
schema_started = time.perf_counter()
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -307,8 +330,6 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
### INPUTS ###
User's Question: {request.query}
Language: Chinese (Simplified)
Let's think step by step.
"""
messages = [
@@ -318,7 +339,14 @@ Let's think step by step.
# 5. Call LLM
try:
response = await provider.chat(messages=messages)
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
)
content = response.content.strip()
# Clean up code blocks
@@ -336,12 +364,15 @@ Let's think step by step.
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL
try:
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"]
@@ -380,11 +411,16 @@ Let's think step by step.
else:
# Unknown format, try to return as is or empty
formatted_results = []
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
# 7. Generate Chart
chart_response = None
if request.generate_chart and formatted_results:
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
chart_response = await generate_chart(formatted_results, request.query)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except Exception as e:
+31 -7
View File
@@ -250,14 +250,37 @@ async def nanobot_chat_stream(request: ChatRequest):
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
if use_nl2sql:
nl2sql_result = await process_nl2sql(
NL2SQLRequest(
query=request.message,
source=resolved_source,
file_url=request.file_url,
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n"
sql_progress_queue: asyncio.Queue[str] = asyncio.Queue()
async def _on_sql_progress(content: str) -> None:
if content:
await sql_progress_queue.put(content)
sql_task = asyncio.create_task(
process_nl2sql(
NL2SQLRequest(
query=request.message,
source=resolved_source,
file_url=request.file_url,
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
),
on_progress=_on_sql_progress,
)
)
while True:
if sql_task.done() and sql_progress_queue.empty():
break
try:
progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
continue
nl2sql_result = await sql_task
if nl2sql_result.error:
yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n"
else:
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
viz_payload = {
"type": "viz",
@@ -284,13 +307,14 @@ async def nanobot_chat_stream(request: ChatRequest):
on_progress=_on_progress,
)
)
yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n"
text = ""
while True:
if task.done() and progress_queue.empty():
break
try:
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
continue
response = await task
+72 -20
View File
@@ -21,6 +21,8 @@ interface Message {
content: string;
awaitingFirstToken?: boolean;
viz?: MessageViz;
progressLogs?: string[];
routeInfo?: string;
}
interface MessageViz {
@@ -403,9 +405,23 @@ export function ChatInterface() {
id: assistantId,
role: "assistant",
content: "",
awaitingFirstToken: true
awaitingFirstToken: true,
progressLogs: ["请求已提交,准备路由..."],
}]);
const pushProgressLog = (text: string) => {
if (!text.trim()) return;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => {
if (msg.id !== assistantId) return msg;
const current = msg.progressLogs || [];
if (current[current.length - 1] === text) return msg;
const next = [...current, text].slice(-8);
return { ...msg, progressLogs: next };
})
);
};
const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || "";
@@ -497,6 +513,8 @@ export function ChatInterface() {
sql?: string;
result?: unknown;
error?: string;
selected?: string;
reason?: string;
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
};
@@ -505,13 +523,29 @@ export function ChatInterface() {
flushAssistant(false);
}
if (payload.type === "routing") {
const selected = payload.selected === "sql" ? "SQL 分析" : "通用对话";
const reason = payload.reason ? `${payload.reason}` : "";
pushProgressLog(`路由:${selected}${reason}`);
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, routeInfo: `${selected}${reason}` } : msg
)
);
}
if (payload.type === "progress" && payload.content) {
pushProgressLog(payload.content);
}
if (payload.type === "final" && payload.content) {
hasFinalPayload = true;
streamedText = payload.content;
flushAssistant(true);
pushProgressLog("回答生成完成");
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
)
);
}
@@ -525,6 +559,7 @@ export function ChatInterface() {
}
if (payload.type === "viz") {
pushProgressLog("可视化结果已生成");
streamedViz = buildMessageViz(payload);
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
@@ -805,25 +840,42 @@ export function ChatInterface() {
}`}
>
{msg.role === "assistant" ? (
msg.awaitingFirstToken && !msg.content ? (
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1">
<Loader2 className="h-4 w-4 animate-spin" />
<span>...</span>
</div>
) : (
<>
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
{msg.content}
</ReactMarkdown>
</div>
{msg.viz ? (
<div className="mt-3 pt-3 border-t border-zinc-100">
<InlineVisualizationCard viz={msg.viz} />
<>
{msg.progressLogs && msg.progressLogs.length > 0 ? (
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
<div className="flex items-center gap-2 text-zinc-500 text-xs">
{msg.awaitingFirstToken ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
<span>{msg.awaitingFirstToken ? "正在处理中" : "处理完成"}</span>
</div>
) : null}
</>
)
<div className="mt-1.5 space-y-1">
{msg.progressLogs.map((log, idx) => (
<div key={`${msg.id}-log-${idx}`} className="text-[12px] text-zinc-500 leading-5">
{idx + 1}. {log}
</div>
))}
</div>
</div>
) : null}
{msg.awaitingFirstToken && !msg.content ? (
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1">
<Loader2 className="h-4 w-4 animate-spin" />
<span>...</span>
</div>
) : (
<>
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
{msg.content}
</ReactMarkdown>
</div>
{msg.viz ? (
<div className="mt-3 pt-3 border-t border-zinc-100">
<InlineVisualizationCard viz={msg.viz} />
</div>
) : null}
</>
)}
</>
) : (
msg.content
)}