feat: add streaming output

This commit is contained in:
qixinbo
2026-03-20 16:54:21 +08:00
parent e3f67d38f8
commit 50352a3653
5 changed files with 258 additions and 148 deletions
+137 -132
View File
@@ -375,130 +375,141 @@ Language: Chinese (Simplified)
{"role": "user", "content": user_prompt}
]
# 5. Call LLM
try:
llm_started = time.perf_counter()
await emit_progress("正在生成 SQL")
response = None
last_error = ""
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
if response.finish_reason == "error":
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql="", result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip()
if not content:
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
# Clean up code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
# 5. Call LLM & 6. Execute SQL (with Self-Correction Loop)
MAX_SQL_EXEC_RETRIES = int(os.getenv("NL2SQL_MAX_EXEC_RETRIES", "2"))
sql_query = ""
formatted_results = []
chart_response = None
timeout_stage = "llm_generation"
for exec_attempt in range(MAX_SQL_EXEC_RETRIES + 1):
try:
result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip()
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL
try:
timeout_stage = "sql_execution"
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
timeout_stage = "sql_execution"
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
timeout_stage = "sql_execution"
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
# Format results
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
# Let's check what connector.execute_query returns.
# PostgresConnector returns list of dicts.
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
# Let's handle the ClickHouse case explicitly if possible or make it generic.
# If results is list of tuples/lists, we need headers.
# Postgres returns list of dicts, so we are good.
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
# it returns `(result_rows, column_types_list)`.
# So `results` here would be a tuple, not a list.
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
# Likely ClickHouse (rows, columns)
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
llm_started = time.perf_counter()
if exec_attempt == 0:
await emit_progress("正在生成 SQL")
else:
# Unknown format, try to return as is or empty
formatted_results = []
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
await emit_progress(f"正在尝试修复 SQL ({exec_attempt}/{MAX_SQL_EXEC_RETRIES})")
# 7. Generate Chart
chart_response = None
if request.generate_chart and formatted_results:
response = None
last_error = ""
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
try:
response = await asyncio.wait_for(
provider.chat(
messages=messages,
max_tokens=NL2SQL_MAX_TOKENS,
temperature=NL2SQL_TEMPERATURE,
reasoning_effort=NL2SQL_REASONING_EFFORT,
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
num_retries=0,
),
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
except Exception as e:
last_error = f"LLM generation failed: {e}"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
if response.finish_reason == "error":
last_error = response.content or "LLM Error"
if attempt < NL2SQL_LLM_RETRY_COUNT:
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
continue
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
break
if response is None:
return NL2SQLResponse(sql=sql_query, result=[], error=last_error or "LLM generation failed")
content = (response.content or "").strip()
if not content:
return NL2SQLResponse(sql=sql_query, result=[], error="LLM returned empty response")
# Clean up code blocks
if "```json" in content:
content = content.split("```json")[1].split("```")[0]
elif "```" in content:
content = content.split("```")[1].split("```")[0]
content = content.strip()
try:
result_json = json.loads(content)
sql_query = result_json.get("sql", "").strip()
except json.JSONDecodeError:
# Fallback if LLM doesn't return valid JSON despite instructions
sql_query = content
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
# 6. Execute SQL
try:
timeout_stage = "sql_execution"
sql_exec_started = time.perf_counter()
await emit_progress("正在执行 SQL 查询")
if request.source == "upload":
if upload_df is None:
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
upload_df = upload_payload["df"]
formatted_results = await asyncio.wait_for(
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
else:
results = await asyncio.wait_for(
asyncio.to_thread(connector.execute_query, sql_query),
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
)
# Format results
formatted_results = []
if isinstance(results, list):
if results and isinstance(results[0], dict):
formatted_results = results
elif results and isinstance(results[0], (list, tuple)):
formatted_results = [list(row) for row in results]
else:
formatted_results = results
elif isinstance(results, tuple) and len(results) == 2:
rows, cols = results
col_names = [c[0] for c in cols]
formatted_results = [dict(zip(col_names, row)) for row in rows]
else:
formatted_results = []
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
break # Execution succeeded, break the retry loop
except asyncio.TimeoutError:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
if exec_attempt < MAX_SQL_EXEC_RETRIES:
await emit_progress(f"SQL 执行失败,准备自动修复 ({exec_attempt + 1}/{MAX_SQL_EXEC_RETRIES})")
messages.append({"role": "assistant", "content": f"```json\n{{\"sql\": \"{sql_query}\"}}\n```"})
messages.append({
"role": "user",
"content": f"The generated SQL failed to execute. Database error:\n{str(e)}\n\nPlease fix the SQL query to resolve this error and provide the corrected version following the exact same JSON format."
})
continue
else:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed after {MAX_SQL_EXEC_RETRIES} retries: {e}")
# 7. Generate Chart
if request.generate_chart and formatted_results:
try:
chart_started = time.perf_counter()
await emit_progress("正在生成可视化方案")
timeout_stage = "chart_generation"
@@ -506,16 +517,8 @@ Language: Chinese (Simplified)
generate_chart(formatted_results, request.query),
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
)
if not chart_response or not chart_response.chart_spec:
# Do not fallback automatically if the LLM explicitly decided not to or failed.
# Just pass whatever it returned (or lack thereof)
pass
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except asyncio.TimeoutError:
if timeout_stage == "chart_generation":
except asyncio.TimeoutError:
fallback_chart = ChartGenerationResponse(
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
chart_type="",
@@ -523,6 +526,8 @@ Language: Chinese (Simplified)
chart_spec=None,
)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
except Exception as e:
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
except Exception as e:
pass # Ignore chart generation errors, return data only
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
+4 -2
View File
@@ -34,6 +34,8 @@ from nanobot.config.schema import Config
from app.api.skills import load_skills
from app.services.llm_cache import get_llm_configs
from app.core.streaming_provider import StreamingLiteLLMProvider
class NanobotIntegration:
def __init__(self):
self.agent: AgentLoop | None = None
@@ -156,7 +158,7 @@ class NanobotIntegration:
spec = find_by_name(provider_name)
# Skip API key check for now to allow initialization without full config
return LiteLLMProvider(
return StreamingLiteLLMProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
@@ -211,7 +213,7 @@ class NanobotIntegration:
cached = self._model_agent_cache.get(model_id)
if cached:
return cached
provider = LiteLLMProvider(
provider = StreamingLiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=target_config.get("model"),
+76
View File
@@ -0,0 +1,76 @@
import contextvars
import json
from typing import Any, Dict, List, Optional
from loguru import logger
from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.providers.base import LLMResponse
from litellm import acompletion, stream_chunk_builder
streaming_queue_var = contextvars.ContextVar("streaming_queue", default=None)
class StreamingLiteLLMProvider(LiteLLMProvider):
async def chat(
self,
messages: List[Dict[str, Any]],
tools: Optional[List[Dict[str, Any]]] = None,
model: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 4000,
reasoning_effort: Optional[str] = None,
request_timeout: Optional[int] = None,
num_retries: Optional[int] = None,
) -> LLMResponse:
original_model = model or self.default_model
model_name = self._resolve_model(original_model)
kwargs: Dict[str, Any] = {
"model": model_name,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens,
"stream": True, # 强制开启流式
}
if self.api_key and self.api_key != "no-key":
kwargs["api_key"] = self.api_key
if self.api_base:
kwargs["api_base"] = self.api_base
if self.extra_headers:
kwargs["extra_headers"] = self.extra_headers
if tools:
kwargs["tools"] = tools
if request_timeout is not None:
kwargs["timeout"] = request_timeout
if num_retries is not None:
kwargs["num_retries"] = max(0, int(num_retries))
if reasoning_effort and self._supports_reasoning_effort(model_name):
kwargs["reasoning_effort"] = reasoning_effort
try:
response_stream = await acompletion(**kwargs)
chunks = []
queue = streaming_queue_var.get()
async for chunk in response_stream:
chunks.append(chunk)
if queue is not None:
# 提取普通内容或 think 内容
delta = chunk.choices[0].delta if chunk.choices else None
if delta:
content = getattr(delta, "content", None)
reasoning_content = getattr(delta, "reasoning_content", None)
if content:
await queue.put({"type": "delta", "content": content})
if reasoning_content:
await queue.put({"type": "progress", "content": reasoning_content, "is_reasoning": True})
# 还原为完整的 response 对象供 nanobot 处理
full_response = stream_chunk_builder(chunks, messages=messages)
return self._parse_response(full_response)
except Exception as e:
logger.error("StreamingLiteLLMProvider failed: {}", e)
raise
+11 -5
View File
@@ -170,6 +170,8 @@ async def nanobot_chat(request: ChatRequest):
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
from app.core.streaming_provider import streaming_queue_var
@app.post("/nanobot/chat/stream")
async def nanobot_chat_stream(request: ChatRequest):
async def event_generator():
@@ -184,6 +186,8 @@ async def nanobot_chat_stream(request: ChatRequest):
yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n"
progress_queue: asyncio.Queue[str] = asyncio.Queue()
# 设置 streaming_queue_var 为当前请求的 progress_queue
streaming_queue_var.set(progress_queue)
async def _on_progress(content: str, **kwargs: Any) -> None:
if content:
@@ -237,7 +241,10 @@ async def nanobot_chat_stream(request: ChatRequest):
break
try:
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
if isinstance(progress, dict):
yield f"data: {json.dumps(progress, ensure_ascii=False)}\n\n"
else:
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
continue
@@ -266,10 +273,9 @@ async def nanobot_chat_stream(request: ChatRequest):
session.messages[-1]["viz"] = viz_payload
nanobot_service.agent.sessions.save(session)
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n"
# Since true streaming is enabled via StreamingLiteLLMProvider,
# we no longer need to chunk and yield `text` here.
# Just yield the final text to signal completion and update final state.
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
+30 -9
View File
@@ -24,6 +24,7 @@ interface Message {
viz?: MessageViz;
progressLogs?: string[];
routeInfo?: string;
reasoningContent?: string;
}
interface MessageViz {
@@ -526,18 +527,26 @@ export function ChatInterface() {
progressLogs: ["请求已提交,准备路由..."],
}]);
const pushProgressLog = (text: string) => {
if (!text.trim()) return;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => {
if (msg.id !== assistantId) return msg;
const pushProgressLog = (text: string, isReasoningToken: boolean = false) => {
if (!text.trim() && !isReasoningToken) return;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => {
if (msg.id !== assistantId) return msg;
if (isReasoningToken) {
// 对于流式推理内容,拼接而不是创建新条目
const currentReasoning = msg.reasoningContent || "";
return { ...msg, reasoningContent: currentReasoning + text };
} else {
// 对于普通的阶段性日志,保留最近的 8 条
const current = msg.progressLogs || [];
if (current[current.length - 1] === text) return msg;
const next = [...current, text].slice(-8);
return { ...msg, progressLogs: next };
})
);
};
}
})
);
};
const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || "";
@@ -627,6 +636,7 @@ export function ChatInterface() {
const payload = JSON.parse(payloadText) as {
type: string;
content?: string;
is_reasoning?: boolean;
sql?: string;
result?: unknown;
error?: string;
@@ -652,7 +662,9 @@ export function ChatInterface() {
}
if (payload.type === "progress" && payload.content) {
pushProgressLog(payload.content);
// 如果 progress 内容带有空格或者换行,并且不是典型的系统提示词,很可能这是 reasoning_content
// 为了安全起见,我们在后端应该加上 is_reasoning 标记,这里我们通过启发式或者统一拼接
pushProgressLog(payload.content, payload.is_reasoning || false);
}
if (payload.type === "final" && payload.content) {
@@ -968,6 +980,15 @@ export function ChatInterface() {
>
{msg.role === "assistant" ? (
<>
{msg.reasoningContent && (
<div className="mb-3 rounded-xl border border-zinc-200 bg-zinc-50/50 p-3 text-sm text-zinc-600 font-mono whitespace-pre-wrap leading-relaxed shadow-inner max-h-[300px] overflow-y-auto">
<div className="flex items-center gap-2 mb-2 text-xs font-semibold text-zinc-500 uppercase tracking-wider">
<Settings className={`h-3.5 w-3.5 ${msg.awaitingFirstToken ? 'animate-spin' : ''}`} />
</div>
{msg.reasoningContent}
</div>
)}
{msg.progressLogs && msg.progressLogs.length > 0 ? (
<div className="mb-2 rounded-xl border border-zinc-100 bg-zinc-50/70 px-3 py-2">
<div className="flex items-center gap-2 text-zinc-500 text-xs">