From 50352a3653ce9282184ee5783e890283ce63b14b Mon Sep 17 00:00:00 2001 From: qixinbo Date: Fri, 20 Mar 2026 16:54:21 +0800 Subject: [PATCH] feat: add streaming output --- backend/app/agent/nl2sql.py | 269 +++++++++++----------- backend/app/core/nanobot.py | 6 +- backend/app/core/streaming_provider.py | 76 ++++++ backend/main.py | 16 +- frontend/src/components/ChatInterface.tsx | 39 +++- 5 files changed, 258 insertions(+), 148 deletions(-) create mode 100644 backend/app/core/streaming_provider.py diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index 8a3907a..3f7aa04 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -375,130 +375,141 @@ Language: Chinese (Simplified) {"role": "user", "content": user_prompt} ] - # 5. Call LLM - try: - llm_started = time.perf_counter() - await emit_progress("正在生成 SQL") - response = None - last_error = "" - - for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1): - try: - response = await asyncio.wait_for( - provider.chat( - messages=messages, - max_tokens=NL2SQL_MAX_TOKENS, - temperature=NL2SQL_TEMPERATURE, - reasoning_effort=NL2SQL_REASONING_EFFORT, - request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS, - num_retries=0, - ), - timeout=NL2SQL_LLM_TIMEOUT_SECONDS, - ) - except asyncio.TimeoutError: - last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s" - if attempt < NL2SQL_LLM_RETRY_COUNT: - await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") - continue - return NL2SQLResponse(sql="", result=[], error=last_error) - except Exception as e: - last_error = f"LLM generation failed: {e}" - if attempt < NL2SQL_LLM_RETRY_COUNT: - await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") - continue - return NL2SQLResponse(sql="", result=[], error=last_error) - - if response.finish_reason == "error": - last_error = response.content or "LLM Error" - if attempt < NL2SQL_LLM_RETRY_COUNT: - await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") - continue - return NL2SQLResponse(sql="", result=[], error=last_error) - break - - if response is None: - return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed") - - content = (response.content or "").strip() - if not content: - return NL2SQLResponse(sql="", result=[], error="LLM returned empty response") - - # Clean up code blocks - if "```json" in content: - content = content.split("```json")[1].split("```")[0] - elif "```" in content: - content = content.split("```")[1].split("```")[0] - - content = content.strip() + # 5. Call LLM & 6. Execute SQL (with Self-Correction Loop) + MAX_SQL_EXEC_RETRIES = int(os.getenv("NL2SQL_MAX_EXEC_RETRIES", "2")) + sql_query = "" + formatted_results = [] + chart_response = None + timeout_stage = "llm_generation" + for exec_attempt in range(MAX_SQL_EXEC_RETRIES + 1): try: - result_json = json.loads(content) - sql_query = result_json.get("sql", "").strip() - except json.JSONDecodeError: - # Fallback if LLM doesn't return valid JSON despite instructions - sql_query = content - await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") - - except Exception as e: - return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}") - - # 6. Execute SQL - try: - timeout_stage = "sql_execution" - sql_exec_started = time.perf_counter() - await emit_progress("正在执行 SQL 查询") - if request.source == "upload": - if upload_df is None: - upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url) - upload_df = upload_payload["df"] - timeout_stage = "sql_execution" - formatted_results = await asyncio.wait_for( - asyncio.to_thread(_execute_upload_sql, sql_query, upload_df), - timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, - ) - else: - timeout_stage = "sql_execution" - results = await asyncio.wait_for( - asyncio.to_thread(connector.execute_query, sql_query), - timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, - ) - - # Format results - formatted_results = [] - if isinstance(results, list): - if results and isinstance(results[0], dict): - formatted_results = results - elif results and isinstance(results[0], (list, tuple)): - # Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case) - # If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types)) - # But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types) - # Wait, client.execute(with_column_types=True) returns (data, columns_with_types) - # Let's check what connector.execute_query returns. - # PostgresConnector returns list of dicts. - # ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs. - # Let's handle the ClickHouse case explicitly if possible or make it generic. - - # If results is list of tuples/lists, we need headers. - # Postgres returns list of dicts, so we are good. - # ClickHouse: if modified to return client.execute(..., with_column_types=True), - # it returns `(result_rows, column_types_list)`. - # So `results` here would be a tuple, not a list. - formatted_results = [list(row) for row in results] - else: - formatted_results = results - elif isinstance(results, tuple) and len(results) == 2: - # Likely ClickHouse (rows, columns) - rows, cols = results - col_names = [c[0] for c in cols] - formatted_results = [dict(zip(col_names, row)) for row in rows] + llm_started = time.perf_counter() + if exec_attempt == 0: + await emit_progress("正在生成 SQL") else: - # Unknown format, try to return as is or empty - formatted_results = [] - await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)") + await emit_progress(f"正在尝试修复 SQL ({exec_attempt}/{MAX_SQL_EXEC_RETRIES})") + + response = None + last_error = "" - # 7. Generate Chart - chart_response = None - if request.generate_chart and formatted_results: + for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1): + try: + response = await asyncio.wait_for( + provider.chat( + messages=messages, + max_tokens=NL2SQL_MAX_TOKENS, + temperature=NL2SQL_TEMPERATURE, + reasoning_effort=NL2SQL_REASONING_EFFORT, + request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS, + num_retries=0, + ), + timeout=NL2SQL_LLM_TIMEOUT_SECONDS, + ) + except asyncio.TimeoutError: + last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql=sql_query, result=[], error=last_error) + except Exception as e: + last_error = f"LLM generation failed: {e}" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql=sql_query, result=[], error=last_error) + + if response.finish_reason == "error": + last_error = response.content or "LLM Error" + if attempt < NL2SQL_LLM_RETRY_COUNT: + await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})") + continue + return NL2SQLResponse(sql=sql_query, result=[], error=last_error) + break + + if response is None: + return NL2SQLResponse(sql=sql_query, result=[], error=last_error or "LLM generation failed") + + content = (response.content or "").strip() + if not content: + return NL2SQLResponse(sql=sql_query, result=[], error="LLM returned empty response") + + # Clean up code blocks + if "```json" in content: + content = content.split("```json")[1].split("```")[0] + elif "```" in content: + content = content.split("```")[1].split("```")[0] + + content = content.strip() + + try: + result_json = json.loads(content) + sql_query = result_json.get("sql", "").strip() + except json.JSONDecodeError: + # Fallback if LLM doesn't return valid JSON despite instructions + sql_query = content + + await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)") + + except Exception as e: + return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}") + + # 6. Execute SQL + try: + timeout_stage = "sql_execution" + sql_exec_started = time.perf_counter() + await emit_progress("正在执行 SQL 查询") + + if request.source == "upload": + if upload_df is None: + upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url) + upload_df = upload_payload["df"] + formatted_results = await asyncio.wait_for( + asyncio.to_thread(_execute_upload_sql, sql_query, upload_df), + timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, + ) + else: + results = await asyncio.wait_for( + asyncio.to_thread(connector.execute_query, sql_query), + timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS, + ) + + # Format results + formatted_results = [] + if isinstance(results, list): + if results and isinstance(results[0], dict): + formatted_results = results + elif results and isinstance(results[0], (list, tuple)): + formatted_results = [list(row) for row in results] + else: + formatted_results = results + elif isinstance(results, tuple) and len(results) == 2: + rows, cols = results + col_names = [c[0] for c in cols] + formatted_results = [dict(zip(col_names, row)) for row in rows] + else: + formatted_results = [] + + await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)") + break # Execution succeeded, break the retry loop + + except asyncio.TimeoutError: + return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s") + except Exception as e: + if exec_attempt < MAX_SQL_EXEC_RETRIES: + await emit_progress(f"SQL 执行失败,准备自动修复 ({exec_attempt + 1}/{MAX_SQL_EXEC_RETRIES})") + messages.append({"role": "assistant", "content": f"```json\n{{\"sql\": \"{sql_query}\"}}\n```"}) + messages.append({ + "role": "user", + "content": f"The generated SQL failed to execute. Database error:\n{str(e)}\n\nPlease fix the SQL query to resolve this error and provide the corrected version following the exact same JSON format." + }) + continue + else: + return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed after {MAX_SQL_EXEC_RETRIES} retries: {e}") + + # 7. Generate Chart + if request.generate_chart and formatted_results: + try: chart_started = time.perf_counter() await emit_progress("正在生成可视化方案") timeout_stage = "chart_generation" @@ -506,16 +517,8 @@ Language: Chinese (Simplified) generate_chart(formatted_results, request.query), timeout=NL2SQL_CHART_TIMEOUT_SECONDS, ) - if not chart_response or not chart_response.chart_spec: - # Do not fallback automatically if the LLM explicitly decided not to or failed. - # Just pass whatever it returned (or lack thereof) - pass await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)") - await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") - - return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) - except asyncio.TimeoutError: - if timeout_stage == "chart_generation": + except asyncio.TimeoutError: fallback_chart = ChartGenerationResponse( reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s", chart_type="", @@ -523,6 +526,8 @@ Language: Chinese (Simplified) chart_spec=None, ) return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart) - return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s") - except Exception as e: - return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}") + except Exception as e: + pass # Ignore chart generation errors, return data only + + await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s") + return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) diff --git a/backend/app/core/nanobot.py b/backend/app/core/nanobot.py index c494a47..659a6a8 100644 --- a/backend/app/core/nanobot.py +++ b/backend/app/core/nanobot.py @@ -34,6 +34,8 @@ from nanobot.config.schema import Config from app.api.skills import load_skills from app.services.llm_cache import get_llm_configs +from app.core.streaming_provider import StreamingLiteLLMProvider + class NanobotIntegration: def __init__(self): self.agent: AgentLoop | None = None @@ -156,7 +158,7 @@ class NanobotIntegration: spec = find_by_name(provider_name) # Skip API key check for now to allow initialization without full config - return LiteLLMProvider( + return StreamingLiteLLMProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), default_model=model, @@ -211,7 +213,7 @@ class NanobotIntegration: cached = self._model_agent_cache.get(model_id) if cached: return cached - provider = LiteLLMProvider( + provider = StreamingLiteLLMProvider( api_key=target_config.get("api_key"), api_base=target_config.get("api_base"), default_model=target_config.get("model"), diff --git a/backend/app/core/streaming_provider.py b/backend/app/core/streaming_provider.py new file mode 100644 index 0000000..00faf9f --- /dev/null +++ b/backend/app/core/streaming_provider.py @@ -0,0 +1,76 @@ +import contextvars +import json +from typing import Any, Dict, List, Optional +from loguru import logger +from nanobot.providers.litellm_provider import LiteLLMProvider +from nanobot.providers.base import LLMResponse +from litellm import acompletion, stream_chunk_builder + +streaming_queue_var = contextvars.ContextVar("streaming_queue", default=None) + +class StreamingLiteLLMProvider(LiteLLMProvider): + async def chat( + self, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + model: Optional[str] = None, + temperature: float = 0.7, + max_tokens: int = 4000, + reasoning_effort: Optional[str] = None, + request_timeout: Optional[int] = None, + num_retries: Optional[int] = None, + ) -> LLMResponse: + original_model = model or self.default_model + model_name = self._resolve_model(original_model) + + kwargs: Dict[str, Any] = { + "model": model_name, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "stream": True, # 强制开启流式 + } + + if self.api_key and self.api_key != "no-key": + kwargs["api_key"] = self.api_key + if self.api_base: + kwargs["api_base"] = self.api_base + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + if tools: + kwargs["tools"] = tools + if request_timeout is not None: + kwargs["timeout"] = request_timeout + if num_retries is not None: + kwargs["num_retries"] = max(0, int(num_retries)) + + if reasoning_effort and self._supports_reasoning_effort(model_name): + kwargs["reasoning_effort"] = reasoning_effort + + try: + response_stream = await acompletion(**kwargs) + chunks = [] + queue = streaming_queue_var.get() + + async for chunk in response_stream: + chunks.append(chunk) + + if queue is not None: + # 提取普通内容或 think 内容 + delta = chunk.choices[0].delta if chunk.choices else None + if delta: + content = getattr(delta, "content", None) + reasoning_content = getattr(delta, "reasoning_content", None) + + if content: + await queue.put({"type": "delta", "content": content}) + if reasoning_content: + await queue.put({"type": "progress", "content": reasoning_content, "is_reasoning": True}) + + # 还原为完整的 response 对象供 nanobot 处理 + full_response = stream_chunk_builder(chunks, messages=messages) + return self._parse_response(full_response) + + except Exception as e: + logger.error("StreamingLiteLLMProvider failed: {}", e) + raise diff --git a/backend/main.py b/backend/main.py index 339ffe8..4ce7673 100644 --- a/backend/main.py +++ b/backend/main.py @@ -170,6 +170,8 @@ async def nanobot_chat(request: ChatRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +from app.core.streaming_provider import streaming_queue_var + @app.post("/nanobot/chat/stream") async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): @@ -184,6 +186,8 @@ async def nanobot_chat_stream(request: ChatRequest): yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n" progress_queue: asyncio.Queue[str] = asyncio.Queue() + # 设置 streaming_queue_var 为当前请求的 progress_queue + streaming_queue_var.set(progress_queue) async def _on_progress(content: str, **kwargs: Any) -> None: if content: @@ -237,7 +241,10 @@ async def nanobot_chat_stream(request: ChatRequest): break try: progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2) - yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" + if isinstance(progress, dict): + yield f"data: {json.dumps(progress, ensure_ascii=False)}\n\n" + else: + yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" except asyncio.TimeoutError: continue @@ -266,10 +273,9 @@ async def nanobot_chat_stream(request: ChatRequest): session.messages[-1]["viz"] = viz_payload nanobot_service.agent.sessions.save(session) - for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE): - chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE] - yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n" - + # Since true streaming is enabled via StreamingLiteLLMProvider, + # we no longer need to chunk and yield `text` here. + # Just yield the final text to signal completion and update final state. yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" except asyncio.CancelledError: diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index ce8df4e..1871245 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -24,6 +24,7 @@ interface Message { viz?: MessageViz; progressLogs?: string[]; routeInfo?: string; + reasoningContent?: string; } interface MessageViz { @@ -526,18 +527,26 @@ export function ChatInterface() { progressLogs: ["请求已提交,准备路由..."], }]); - const pushProgressLog = (text: string) => { - if (!text.trim()) return; - setMessagesForSession(targetSessionKey, (prev) => - prev.map((msg) => { - if (msg.id !== assistantId) return msg; + const pushProgressLog = (text: string, isReasoningToken: boolean = false) => { + if (!text.trim() && !isReasoningToken) return; + setMessagesForSession(targetSessionKey, (prev) => + prev.map((msg) => { + if (msg.id !== assistantId) return msg; + + if (isReasoningToken) { + // 对于流式推理内容,拼接而不是创建新条目 + const currentReasoning = msg.reasoningContent || ""; + return { ...msg, reasoningContent: currentReasoning + text }; + } else { + // 对于普通的阶段性日志,保留最近的 8 条 const current = msg.progressLogs || []; if (current[current.length - 1] === text) return msg; const next = [...current, text].slice(-8); return { ...msg, progressLogs: next }; - }) - ); - }; + } + }) + ); + }; const token = localStorage.getItem("token"); const effectiveModelId = selectedModelId || currentModel?.id || ""; @@ -627,6 +636,7 @@ export function ChatInterface() { const payload = JSON.parse(payloadText) as { type: string; content?: string; + is_reasoning?: boolean; sql?: string; result?: unknown; error?: string; @@ -652,7 +662,9 @@ export function ChatInterface() { } if (payload.type === "progress" && payload.content) { - pushProgressLog(payload.content); + // 如果 progress 内容带有空格或者换行,并且不是典型的系统提示词,很可能这是 reasoning_content + // 为了安全起见,我们在后端应该加上 is_reasoning 标记,这里我们通过启发式或者统一拼接 + pushProgressLog(payload.content, payload.is_reasoning || false); } if (payload.type === "final" && payload.content) { @@ -968,6 +980,15 @@ export function ChatInterface() { > {msg.role === "assistant" ? ( <> + {msg.reasoningContent && ( +
+
+ + 思考过程 +
+ {msg.reasoningContent} +
+ )} {msg.progressLogs && msg.progressLogs.length > 0 ? (