speed acc

This commit is contained in:
qixinbo
2026-03-17 21:32:01 +08:00
parent c51f51ff69
commit 49d38692cd
4 changed files with 155 additions and 36 deletions
+10 -3
View File
@@ -12,6 +12,10 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config
CHART_MAX_TOKENS = 700
CHART_TEMPERATURE = 0.2
CHART_REASONING_EFFORT = "low"
CHART_INSTRUCTIONS = """
### INSTRUCTIONS ###
@@ -202,8 +206,6 @@ Question: {query}
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
Sample Column Values: {columns}
Language: Chinese (Simplified)
Please think step by step
"""
messages = [
@@ -213,7 +215,12 @@ Please think step by step
# 4. Call LLM
try:
response = await provider.chat(messages=messages)
response = await provider.chat(
messages=messages,
max_tokens=CHART_MAX_TOKENS,
temperature=CHART_TEMPERATURE,
reasoning_effort=CHART_REASONING_EFFORT,
)
content = response.content
# Clean up code blocks
+42 -6
View File
@@ -4,7 +4,7 @@ import json
import time
import threading
from pathlib import Path
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Callable, Awaitable
from pydantic import BaseModel, Field
import duckdb
import pandas as pd
@@ -30,6 +30,9 @@ SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30
UPLOAD_CACHE_TTL_SECONDS = 900
MAX_UPLOAD_CACHE_ITEMS = 8
NL2SQL_MAX_TOKENS = 900
NL2SQL_TEMPERATURE = 0.1
NL2SQL_REASONING_EFFORT = "low"
_schema_cache: Dict[str, Dict[str, Any]] = {}
_connection_cache: Dict[str, Dict[str, Any]] = {}
@@ -84,7 +87,7 @@ DEFAULT_TEXT_TO_SQL_RULES = """
SQL_GENERATION_SYSTEM_PROMPT = f"""
You are a helpful assistant that converts natural language queries into ANSI SQL queries.
Given user's question, database schema, etc., you should think deeply and carefully and generate the SQL query based on the given reasoning plan step by step.
Given user's question and database schema, generate accurate ANSI SQL directly and concisely.
### GENERAL RULES ###
@@ -195,7 +198,15 @@ def _check_connection_with_cache(source: str, connector: Any) -> bool:
_connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS}
return ok
async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
async def process_nl2sql(
request: NL2SQLRequest,
on_progress: Callable[[str], Awaitable[None]] | None = None,
) -> NL2SQLResponse:
async def emit_progress(content: str) -> None:
if on_progress and content:
await on_progress(content)
total_started = time.perf_counter()
# 1. Get the connector and schema
connector = None
schema = {}
@@ -207,13 +218,16 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = clickhouse_connector
elif request.source == "upload":
try:
upload_started = time.perf_counter()
upload_payload = _get_upload_payload(request.file_url)
upload_df = upload_payload["df"]
schema = upload_payload["schema"]
await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"Failed to load uploaded file: {e}")
elif request.source.startswith("ds:"):
try:
ds_started = time.perf_counter()
ds_id = int(request.source.split(":")[1])
db = SessionLocal()
try:
@@ -223,6 +237,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
connector = get_connector(ds)
finally:
db.close()
await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)")
except ValueError:
return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}")
except Exception as e:
@@ -231,21 +246,29 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
return NL2SQLResponse(sql="", result=[], error=f"Unsupported data source: {request.source}")
if connector:
await emit_progress("正在检测数据源连通性")
cached_schema = _get_cached_schema(request.source, connector)
if cached_schema:
schema = cached_schema
await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表")
else:
conn_started = time.perf_counter()
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)")
schema_started = time.perf_counter()
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)")
if connector and not schema:
retry_started = time.perf_counter()
# Double check in case schema was empty but connection is ok (e.g. empty db)
if not _check_connection_with_cache(request.source, connector):
return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}")
schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema)
await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)")
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
@@ -307,8 +330,6 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
### INPUTS ###
User's Question: {request.query}
Language: Chinese (Simplified)
Let's think step by step.
"""
messages = [
@@ -318,7 +339,14 @@ Let's think step by step.
# 5. Call LLM
try:
response = await provider.chat(messages=messages)
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = await provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
)
content = response.content.strip()
# Clean up code blocks
@@ -336,12 +364,15 @@ Let's think step by step.
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL
try:
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_df = _get_upload_payload(request.file_url)["df"]
@@ -380,11 +411,16 @@ Let's think step by step.
else:
# Unknown format, try to return as is or empty
formatted_results = []
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
# 7. Generate Chart
chart_response = None
if request.generate_chart and formatted_results:
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
chart_response = await generate_chart(formatted_results, request.query)
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except Exception as e: