feat: add streaming output
This commit is contained in:
+137
-132
@@ -375,130 +375,141 @@ Language: Chinese (Simplified)
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 5. Call LLM
|
||||
try:
|
||||
llm_started = time.perf_counter()
|
||||
await emit_progress("正在生成 SQL")
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql="", result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql="", result=[], error="LLM returned empty response")
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
content = content.strip()
|
||||
# 5. Call LLM & 6. Execute SQL (with Self-Correction Loop)
|
||||
MAX_SQL_EXEC_RETRIES = int(os.getenv("NL2SQL_MAX_EXEC_RETRIES", "2"))
|
||||
sql_query = ""
|
||||
formatted_results = []
|
||||
chart_response = None
|
||||
timeout_stage = "llm_generation"
|
||||
|
||||
for exec_attempt in range(MAX_SQL_EXEC_RETRIES + 1):
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql="", result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
timeout_stage = "sql_execution"
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
timeout_stage = "sql_execution"
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
# Handle tuple/list results (like ClickHouse withColumnTypes=False, or just in case)
|
||||
# If we have column info (ClickHouse withColumnTypes=True returns (result_rows, column_types))
|
||||
# But execute_query wrapper in ClickHouseConnector now returns (data, columns_with_types)
|
||||
# Wait, client.execute(with_column_types=True) returns (data, columns_with_types)
|
||||
# Let's check what connector.execute_query returns.
|
||||
# PostgresConnector returns list of dicts.
|
||||
# ClickHouseConnector (modified) returns (data, columns_with_types) OR just data if wrapper logic differs.
|
||||
# Let's handle the ClickHouse case explicitly if possible or make it generic.
|
||||
|
||||
# If results is list of tuples/lists, we need headers.
|
||||
# Postgres returns list of dicts, so we are good.
|
||||
# ClickHouse: if modified to return client.execute(..., with_column_types=True),
|
||||
# it returns `(result_rows, column_types_list)`.
|
||||
# So `results` here would be a tuple, not a list.
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
# Likely ClickHouse (rows, columns)
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
llm_started = time.perf_counter()
|
||||
if exec_attempt == 0:
|
||||
await emit_progress("正在生成 SQL")
|
||||
else:
|
||||
# Unknown format, try to return as is or empty
|
||||
formatted_results = []
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
await emit_progress(f"正在尝试修复 SQL ({exec_attempt}/{MAX_SQL_EXEC_RETRIES})")
|
||||
|
||||
response = None
|
||||
last_error = ""
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if request.generate_chart and formatted_results:
|
||||
for attempt in range(NL2SQL_LLM_RETRY_COUNT + 1):
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
provider.chat(
|
||||
messages=messages,
|
||||
max_tokens=NL2SQL_MAX_TOKENS,
|
||||
temperature=NL2SQL_TEMPERATURE,
|
||||
reasoning_effort=NL2SQL_REASONING_EFFORT,
|
||||
request_timeout=NL2SQL_LLM_REQUEST_TIMEOUT_SECONDS,
|
||||
num_retries=0,
|
||||
),
|
||||
timeout=NL2SQL_LLM_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"LLM generation timeout after {NL2SQL_LLM_TIMEOUT_SECONDS}s"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成超时,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
except Exception as e:
|
||||
last_error = f"LLM generation failed: {e}"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"SQL 生成失败,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
|
||||
if response.finish_reason == "error":
|
||||
last_error = response.content or "LLM Error"
|
||||
if attempt < NL2SQL_LLM_RETRY_COUNT:
|
||||
await emit_progress(f"模型返回错误,正在重试 ({attempt + 1}/{NL2SQL_LLM_RETRY_COUNT})")
|
||||
continue
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error)
|
||||
break
|
||||
|
||||
if response is None:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=last_error or "LLM generation failed")
|
||||
|
||||
content = (response.content or "").strip()
|
||||
if not content:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error="LLM returned empty response")
|
||||
|
||||
# Clean up code blocks
|
||||
if "```json" in content:
|
||||
content = content.split("```json")[1].split("```")[0]
|
||||
elif "```" in content:
|
||||
content = content.split("```")[1].split("```")[0]
|
||||
|
||||
content = content.strip()
|
||||
|
||||
try:
|
||||
result_json = json.loads(content)
|
||||
sql_query = result_json.get("sql", "").strip()
|
||||
except json.JSONDecodeError:
|
||||
# Fallback if LLM doesn't return valid JSON despite instructions
|
||||
sql_query = content
|
||||
|
||||
await emit_progress(f"SQL 生成完成 ({time.perf_counter() - llm_started:.2f}s)")
|
||||
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"LLM generation failed: {e}")
|
||||
|
||||
# 6. Execute SQL
|
||||
try:
|
||||
timeout_stage = "sql_execution"
|
||||
sql_exec_started = time.perf_counter()
|
||||
await emit_progress("正在执行 SQL 查询")
|
||||
|
||||
if request.source == "upload":
|
||||
if upload_df is None:
|
||||
upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url)
|
||||
upload_df = upload_payload["df"]
|
||||
formatted_results = await asyncio.wait_for(
|
||||
asyncio.to_thread(_execute_upload_sql, sql_query, upload_df),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
else:
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.to_thread(connector.execute_query, sql_query),
|
||||
timeout=NL2SQL_SQL_EXEC_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
# Format results
|
||||
formatted_results = []
|
||||
if isinstance(results, list):
|
||||
if results and isinstance(results[0], dict):
|
||||
formatted_results = results
|
||||
elif results and isinstance(results[0], (list, tuple)):
|
||||
formatted_results = [list(row) for row in results]
|
||||
else:
|
||||
formatted_results = results
|
||||
elif isinstance(results, tuple) and len(results) == 2:
|
||||
rows, cols = results
|
||||
col_names = [c[0] for c in cols]
|
||||
formatted_results = [dict(zip(col_names, row)) for row in rows]
|
||||
else:
|
||||
formatted_results = []
|
||||
|
||||
await emit_progress(f"SQL 执行完成,返回 {len(formatted_results)} 行 ({time.perf_counter() - sql_exec_started:.2f}s)")
|
||||
break # Execution succeeded, break the retry loop
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
if exec_attempt < MAX_SQL_EXEC_RETRIES:
|
||||
await emit_progress(f"SQL 执行失败,准备自动修复 ({exec_attempt + 1}/{MAX_SQL_EXEC_RETRIES})")
|
||||
messages.append({"role": "assistant", "content": f"```json\n{{\"sql\": \"{sql_query}\"}}\n```"})
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"The generated SQL failed to execute. Database error:\n{str(e)}\n\nPlease fix the SQL query to resolve this error and provide the corrected version following the exact same JSON format."
|
||||
})
|
||||
continue
|
||||
else:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed after {MAX_SQL_EXEC_RETRIES} retries: {e}")
|
||||
|
||||
# 7. Generate Chart
|
||||
if request.generate_chart and formatted_results:
|
||||
try:
|
||||
chart_started = time.perf_counter()
|
||||
await emit_progress("正在生成可视化方案")
|
||||
timeout_stage = "chart_generation"
|
||||
@@ -506,16 +517,8 @@ Language: Chinese (Simplified)
|
||||
generate_chart(formatted_results, request.query),
|
||||
timeout=NL2SQL_CHART_TIMEOUT_SECONDS,
|
||||
)
|
||||
if not chart_response or not chart_response.chart_spec:
|
||||
# Do not fallback automatically if the LLM explicitly decided not to or failed.
|
||||
# Just pass whatever it returned (or lack thereof)
|
||||
pass
|
||||
await emit_progress(f"可视化方案生成完成 ({time.perf_counter() - chart_started:.2f}s)")
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except asyncio.TimeoutError:
|
||||
if timeout_stage == "chart_generation":
|
||||
except asyncio.TimeoutError:
|
||||
fallback_chart = ChartGenerationResponse(
|
||||
reasoning=f"Chart generation timeout after {NL2SQL_CHART_TIMEOUT_SECONDS}s",
|
||||
chart_type="",
|
||||
@@ -523,6 +526,8 @@ Language: Chinese (Simplified)
|
||||
chart_spec=None,
|
||||
)
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=fallback_chart)
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution timeout after {NL2SQL_SQL_EXEC_TIMEOUT_SECONDS}s")
|
||||
except Exception as e:
|
||||
return NL2SQLResponse(sql=sql_query, result=[], error=f"SQL execution failed: {e}")
|
||||
except Exception as e:
|
||||
pass # Ignore chart generation errors, return data only
|
||||
|
||||
await emit_progress(f"NL2SQL 总耗时 {time.perf_counter() - total_started:.2f}s")
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
|
||||
@@ -34,6 +34,8 @@ from nanobot.config.schema import Config
|
||||
from app.api.skills import load_skills
|
||||
from app.services.llm_cache import get_llm_configs
|
||||
|
||||
from app.core.streaming_provider import StreamingLiteLLMProvider
|
||||
|
||||
class NanobotIntegration:
|
||||
def __init__(self):
|
||||
self.agent: AgentLoop | None = None
|
||||
@@ -156,7 +158,7 @@ class NanobotIntegration:
|
||||
spec = find_by_name(provider_name)
|
||||
# Skip API key check for now to allow initialization without full config
|
||||
|
||||
return LiteLLMProvider(
|
||||
return StreamingLiteLLMProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
@@ -211,7 +213,7 @@ class NanobotIntegration:
|
||||
cached = self._model_agent_cache.get(model_id)
|
||||
if cached:
|
||||
return cached
|
||||
provider = LiteLLMProvider(
|
||||
provider = StreamingLiteLLMProvider(
|
||||
api_key=target_config.get("api_key"),
|
||||
api_base=target_config.get("api_base"),
|
||||
default_model=target_config.get("model"),
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import contextvars
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from loguru import logger
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.base import LLMResponse
|
||||
from litellm import acompletion, stream_chunk_builder
|
||||
|
||||
streaming_queue_var = contextvars.ContextVar("streaming_queue", default=None)
|
||||
|
||||
class StreamingLiteLLMProvider(LiteLLMProvider):
|
||||
async def chat(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 4000,
|
||||
reasoning_effort: Optional[str] = None,
|
||||
request_timeout: Optional[int] = None,
|
||||
num_retries: Optional[int] = None,
|
||||
) -> LLMResponse:
|
||||
original_model = model or self.default_model
|
||||
model_name = self._resolve_model(original_model)
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": True, # 强制开启流式
|
||||
}
|
||||
|
||||
if self.api_key and self.api_key != "no-key":
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
if request_timeout is not None:
|
||||
kwargs["timeout"] = request_timeout
|
||||
if num_retries is not None:
|
||||
kwargs["num_retries"] = max(0, int(num_retries))
|
||||
|
||||
if reasoning_effort and self._supports_reasoning_effort(model_name):
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
try:
|
||||
response_stream = await acompletion(**kwargs)
|
||||
chunks = []
|
||||
queue = streaming_queue_var.get()
|
||||
|
||||
async for chunk in response_stream:
|
||||
chunks.append(chunk)
|
||||
|
||||
if queue is not None:
|
||||
# 提取普通内容或 think 内容
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
if delta:
|
||||
content = getattr(delta, "content", None)
|
||||
reasoning_content = getattr(delta, "reasoning_content", None)
|
||||
|
||||
if content:
|
||||
await queue.put({"type": "delta", "content": content})
|
||||
if reasoning_content:
|
||||
await queue.put({"type": "progress", "content": reasoning_content, "is_reasoning": True})
|
||||
|
||||
# 还原为完整的 response 对象供 nanobot 处理
|
||||
full_response = stream_chunk_builder(chunks, messages=messages)
|
||||
return self._parse_response(full_response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("StreamingLiteLLMProvider failed: {}", e)
|
||||
raise
|
||||
+11
-5
@@ -170,6 +170,8 @@ async def nanobot_chat(request: ChatRequest):
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
from app.core.streaming_provider import streaming_queue_var
|
||||
|
||||
@app.post("/nanobot/chat/stream")
|
||||
async def nanobot_chat_stream(request: ChatRequest):
|
||||
async def event_generator():
|
||||
@@ -184,6 +186,8 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
# 设置 streaming_queue_var 为当前请求的 progress_queue
|
||||
streaming_queue_var.set(progress_queue)
|
||||
|
||||
async def _on_progress(content: str, **kwargs: Any) -> None:
|
||||
if content:
|
||||
@@ -237,7 +241,10 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
if isinstance(progress, dict):
|
||||
yield f"data: {json.dumps(progress, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
@@ -266,10 +273,9 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
session.messages[-1]["viz"] = viz_payload
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
|
||||
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
||||
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
||||
yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Since true streaming is enabled via StreamingLiteLLMProvider,
|
||||
# we no longer need to chunk and yield `text` here.
|
||||
# Just yield the final text to signal completion and update final state.
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
|
||||
Reference in New Issue
Block a user