speed acc
This commit is contained in:
@@ -12,6 +12,10 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
|
||||
CHART_MAX_TOKENS = 700
|
||||
CHART_TEMPERATURE = 0.2
|
||||
CHART_REASONING_EFFORT = "low"
|
||||
|
||||
CHART_INSTRUCTIONS = """
|
||||
### INSTRUCTIONS ###
|
||||
|
||||
@@ -202,8 +206,6 @@ Question: {query}
|
||||
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
|
||||
Sample Column Values: {columns}
|
||||
Language: Chinese (Simplified)
|
||||
|
||||
Please think step by step
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@@ -213,7 +215,12 @@ Please think step by step
|
||||
|
||||
# 4. Call LLM
|
||||
try:
|
||||
response = await provider.chat(messages=messages)
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=CHART_MAX_TOKENS,
|
||||
temperature=CHART_TEMPERATURE,
|
||||
reasoning_effort=CHART_REASONING_EFFORT,
|
||||
)
|
||||
content = response.content
|
||||
|
||||
# Clean up code blocks
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import time
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from pydantic import BaseModel, Field
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
@@ -30,6 +30,9 @@ SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
UPLOAD_CACHE_TTL_SECONDS = 900
|
||||
MAX_UPLOAD_CACHE_ITEMS = 8
|
||||
NL2SQL_MAX_TOKENS = 900
|
||||
NL2SQL_TEMPERATURE = 0.1
|
||||
NL2SQL_REASONING_EFFORT = "low"
|
||||
|
||||
_schema_cache: Dict[str, Dict[str, Any]] = {}
|
||||
_connection_cache: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -84,7 +87,7 @@ DEFAULT_TEXT_TO_SQL_RULES = """
|
||||
SQL_GENERATION_SYSTEM_PROMPT = f"""
|
||||
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
|
||||
|
||||
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
|
||||
Given user's question and database schema, generate accurate ANSI SQL directly and concisely.
|
||||
|
||||
### GENERAL RULES ###
|
||||
|
||||
@@ -195,7 +198,15 @@ def _check_connection_with_cache(source: str, connector: Any) -> bool:
|
||||
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
|
||||
return ok
|
||||
|
||||
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
async def process_nl2sql(
|
||||
request: NL2SQLRequest,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> NL2SQLResponse:
|
||||
async def emit_progress(content: str) -> None:
|
||||
if on_progress and content:
|
||||
await on_progress(content)
|
||||
|
||||
total_started = time.perf_counter()
|
||||
# 1. Get the connector and schema
|
||||
connector = None
|
||||
schema = {}
|
||||
@@ -207,13 +218,16 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
connector = clickhouse_connector
|
||||
elif request.source == "upload":
|
||||
try:
|
||||
upload_started = time.perf_counter()
|
||||
upload_payload = _get_upload_payload(request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
schema = upload_payload["schema"]
|
||||
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
|
||||
elif request.source.startswith("ds:"):
|
||||
try:
|
||||
ds_started = time.perf_counter()
|
||||
ds_id = int(request.source.split(":")[1])
|
||||
db = SessionLocal()
|
||||
try:
|
||||
@@ -223,6 +237,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
connector = get_connector(ds)
|
||||
finally:
|
||||
db.close()
|
||||
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
|
||||
except ValueError:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
|
||||
except Exception as e:
|
||||
@@ -231,21 +246,29 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
|
||||
|
||||
if connector:
|
||||
await emit_progress("正在检测数据源连通性")
|
||||
cached_schema = _get_cached_schema(request.source, connector)
|
||||
if cached_schema:
|
||||
schema = cached_schema
|
||||
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
|
||||
else:
|
||||
conn_started = time.perf_counter()
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
|
||||
schema_started = time.perf_counter()
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
|
||||
|
||||
if connector and not schema:
|
||||
retry_started = time.perf_counter()
|
||||
# Double check in case schema was empty but connection is ok (e.g. empty db)
|
||||
if not _check_connection_with_cache(request.source, connector):
|
||||
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
|
||||
|
||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
@@ -307,8 +330,6 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
### INPUTS ###
|
||||
User's Question: {request.query}
|
||||
Language: Chinese (Simplified)
|
||||
|
||||
Let's think step by step.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
@@ -318,7 +339,14 @@ Let's think step by step.
|
||||
|
||||
# 5. Call LLM
|
||||
try:
|
||||
response = await provider.chat(messages=messages)
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = await provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
)
|
||||
content = response.content.strip()
|
||||
|
||||
# Clean up code blocks
|
||||
@@ -336,12 +364,15 @@ Let's think step by step.
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_df = _get_upload_payload(request.file_url)["df"]
|
||||
@@ -380,11 +411,16 @@ Let's think step by step.
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if request.generate_chart and formatted_results:
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except Exception as e:
|
||||
|
||||
+31
-7
@@ -250,14 +250,37 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n"
|
||||
sql_progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def _on_sql_progress(content: str) -> None:
|
||||
if content:
|
||||
await sql_progress_queue.put(content)
|
||||
|
||||
sql_task = asyncio.create_task(
|
||||
process_nl2sql(
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
),
|
||||
on_progress=_on_sql_progress,
|
||||
)
|
||||
)
|
||||
while True:
|
||||
if sql_task.done() and sql_progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
nl2sql_result = await sql_task
|
||||
if nl2sql_result.error:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': '数据查询阶段返回错误,正在整理结果'}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
|
||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
viz_payload = {
|
||||
"type": "viz",
|
||||
@@ -284,13 +307,14 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
on_progress=_on_progress,
|
||||
)
|
||||
)
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n"
|
||||
text = ""
|
||||
while True:
|
||||
if task.done() and progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
response = await task
|
||||
|
||||
@@ -21,6 +21,8 @@ interface Message {
|
||||
content: string;
|
||||
awaitingFirstToken?: boolean;
|
||||
viz?: MessageViz;
|
||||
progressLogs?: string[];
|
||||
routeInfo?: string;
|
||||
}
|
||||
|
||||
interface MessageViz {
|
||||
@@ -403,9 +405,23 @@ export function ChatInterface() {
|
||||
id: assistantId,
|
||||
role: "assistant",
|
||||
content: "",
|
||||
awaitingFirstToken: true
|
||||
awaitingFirstToken: true,
|
||||
progressLogs: ["请求已提交,准备路由..."],
|
||||
}]);
|
||||
|
||||
const pushProgressLog = (text: string) => {
|
||||
if (!text.trim()) return;
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) => {
|
||||
if (msg.id !== assistantId) return msg;
|
||||
const current = msg.progressLogs || [];
|
||||
if (current[current.length - 1] === text) return msg;
|
||||
const next = [...current, text].slice(-8);
|
||||
return { ...msg, progressLogs: next };
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
const token = localStorage.getItem("token");
|
||||
const effectiveModelId = selectedModelId || currentModel?.id || "";
|
||||
|
||||
@@ -497,6 +513,8 @@ export function ChatInterface() {
|
||||
sql?: string;
|
||||
result?: unknown;
|
||||
error?: string;
|
||||
selected?: string;
|
||||
reason?: string;
|
||||
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
|
||||
};
|
||||
|
||||
@@ -505,13 +523,29 @@ export function ChatInterface() {
|
||||
flushAssistant(false);
|
||||
}
|
||||
|
||||
if (payload.type === "routing") {
|
||||
const selected = payload.selected === "sql" ? "SQL 分析" : "通用对话";
|
||||
const reason = payload.reason ? `(${payload.reason})` : "";
|
||||
pushProgressLog(`路由:${selected}${reason}`);
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, routeInfo: `${selected}${reason}` } : msg
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
if (payload.type === "progress" && payload.content) {
|
||||
pushProgressLog(payload.content);
|
||||
}
|
||||
|
||||
if (payload.type === "final" && payload.content) {
|
||||
hasFinalPayload = true;
|
||||
streamedText = payload.content;
|
||||
flushAssistant(true);
|
||||
pushProgressLog("回答生成完成");
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
)
|
||||
);
|
||||
}
|
||||
@@ -525,6 +559,7 @@ export function ChatInterface() {
|
||||
}
|
||||
|
||||
if (payload.type === "viz") {
|
||||
pushProgressLog("可视化结果已生成");
|
||||
streamedViz = buildMessageViz(payload);
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
@@ -805,25 +840,42 @@ export function ChatInterface() {
|
||||
}`}
|
||||
>
|
||||
{msg.role === "assistant" ? (
|
||||
msg.awaitingFirstToken && !msg.content ? (
|
||||
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>模型思考中,请稍候...</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
|
||||
{msg.content}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
{msg.viz ? (
|
||||
<div className="mt-3 pt-3 border-t border-zinc-100">
|
||||
<InlineVisualizationCard viz={msg.viz} />
|
||||
<>
|
||||
{msg.progressLogs && msg.progressLogs.length > 0 ? (
|
||||
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
|
||||
<div className="flex items-center gap-2 text-zinc-500 text-xs">
|
||||
{msg.awaitingFirstToken ? <Loader2 className="h-3.5 w-3.5 animate-spin" /> : <CheckCircle2 className="h-3.5 w-3.5 text-emerald-500" />}
|
||||
<span>{msg.awaitingFirstToken ? "正在处理中" : "处理完成"}</span>
|
||||
</div>
|
||||
) : null}
|
||||
</>
|
||||
)
|
||||
<div className="mt-1.5 space-y-1">
|
||||
{msg.progressLogs.map((log, idx) => (
|
||||
<div key={`${msg.id}-log-${idx}`} className="text-[12px] text-zinc-500 leading-5">
|
||||
{idx + 1}. {log}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
) : null}
|
||||
{msg.awaitingFirstToken && !msg.content ? (
|
||||
<div className="flex items-center gap-2 text-zinc-500 text-sm py-1">
|
||||
<Loader2 className="h-4 w-4 animate-spin" />
|
||||
<span>模型思考中,请稍候...</span>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="prose prose-sm prose-zinc max-w-none prose-p:leading-normal prose-p:my-2 prose-headings:my-3 prose-ul:my-2 prose-li:my-0.5 prose-pre:bg-zinc-50 prose-pre:text-zinc-800 prose-pre:border prose-pre:border-zinc-200">
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]} rehypePlugins={[rehypeRaw]}>
|
||||
{msg.content}
|
||||
</ReactMarkdown>
|
||||
</div>
|
||||
{msg.viz ? (
|
||||
<div className="mt-3 pt-3 border-t border-zinc-100">
|
||||
<InlineVisualizationCard viz={msg.viz} />
|
||||
</div>
|
||||
) : null}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
) : (
|
||||
msg.content
|
||||
)}
|
||||
|
||||
Reference in New Issue
Block a user