speed optim

This commit is contained in:
qixinbo
2026-03-17 20:40:56 +08:00
parent cd764fad43
commit c51f51ff69
6 changed files with 198 additions and 135 deletions
+4 -6
View File
@@ -9,8 +9,8 @@ if str(PROJECT_ROOT) not in sys.path:
sys.path.append(str(PROJECT_ROOT)) sys.path.append(str(PROJECT_ROOT))
from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.litellm_provider import LiteLLMProvider
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse from app.schemas.chart import ChartGenerationResponse
from app.services.llm_cache import get_active_llm_config
CHART_INSTRUCTIONS = """ CHART_INSTRUCTIONS = """
### INSTRUCTIONS ### ### INSTRUCTIONS ###
@@ -133,9 +133,7 @@ CHART_EXAMPLES = """
""" """
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse: async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
# 1. Initialize Provider active_config = get_active_llm_config()
llm_configs = load_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
if not active_config: if not active_config:
return ChartGenerationResponse( return ChartGenerationResponse(
@@ -178,7 +176,7 @@ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerat
columns = list(data[0].keys()) columns = list(data[0].keys())
# 3. Construct Prompt # 3. Construct Prompt
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), indent=2) schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), ensure_ascii=False, separators=(",", ":"))
system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type. system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type.
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values. Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values.
@@ -201,7 +199,7 @@ Please provide your chain of thought reasoning, chart type and the vega-lite sch
user_prompt = f""" user_prompt = f"""
### INPUT ### ### INPUT ###
Question: {query} Question: {query}
Sample Data: {json.dumps(sample_data, indent=2, default=str)} Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
Sample Column Values: {columns} Sample Column Values: {columns}
Language: Chinese (Simplified) Language: Chinese (Simplified)
+6 -8
View File
@@ -18,13 +18,13 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
from app.connectors.postgres import postgres_connector from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector from app.connectors.clickhouse import clickhouse_connector
from app.connectors.factory import get_connector from app.connectors.factory import get_connector
from app.api.llm import _load_data as load_llm_config
from app.schemas.chart import ChartGenerationResponse from app.schemas.chart import ChartGenerationResponse
from app.agent.chart import generate_chart from app.agent.chart import generate_chart
from app.database import SessionLocal from app.database import SessionLocal
from app.models.datasource import DataSource from app.models.datasource import DataSource
from app.core.files import resolve_upload_file_path from app.core.files import resolve_upload_file_path
from app.services.mdl import MDLService from app.services.mdl import MDLService
from app.services.llm_cache import get_active_llm_config
SCHEMA_CACHE_TTL_SECONDS = 300 SCHEMA_CACHE_TTL_SECONDS = 300
CONNECTION_CACHE_TTL_SECONDS = 30 CONNECTION_CACHE_TTL_SECONDS = 30
@@ -41,6 +41,7 @@ class NL2SQLRequest(BaseModel):
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})") source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload") file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
session_id: Optional[str] = Field(None, description="Conversation session identifier") session_id: Optional[str] = Field(None, description="Conversation session identifier")
generate_chart: bool = Field(False, description="Whether to generate chart specification")
class NL2SQLResponse(BaseModel): class NL2SQLResponse(BaseModel):
sql: str sql: str
@@ -246,7 +247,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
schema = connector.get_schema() schema = connector.get_schema()
_set_cached_schema(request.source, connector, schema) _set_cached_schema(request.source, connector, schema)
schema_str = json.dumps(schema, indent=2) schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
# Try to load MDL context # Try to load MDL context
mdl_context = "" mdl_context = ""
@@ -280,8 +281,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
print(f"Failed to load MDL: {e}") print(f"Failed to load MDL: {e}")
# 2. Get the active LLM config # 2. Get the active LLM config
llm_configs = load_llm_config() active_config = get_active_llm_config()
active_config = next((c for c in llm_configs if c.get("is_active")), None)
if not active_config: if not active_config:
return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found") return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found")
@@ -383,10 +383,8 @@ Let's think step by step.
# 7. Generate Chart # 7. Generate Chart
chart_response = None chart_response = None
if formatted_results: if request.generate_chart and formatted_results:
# Only try to generate chart if we have results chart_response = await generate_chart(formatted_results, request.query)
# Convert to list of dicts if possible, or pass as is
chart_response = await generate_chart(formatted_results, request.query)
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response) return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
except Exception as e: except Exception as e:
+61 -67
View File
@@ -2,7 +2,7 @@ import asyncio
import sys import sys
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Callable, Awaitable, Any from typing import List, Callable, Awaitable, Any, Dict
# Add project root to sys.path to allow importing nanobot # Add project root to sys.path to allow importing nanobot
# Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root # Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root
@@ -31,6 +31,7 @@ from nanobot.config.schema import Config
# or just import here if we are confident. # or just import here if we are confident.
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py. # Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
from app.api.skills import load_skills from app.api.skills import load_skills
from app.services.llm_cache import get_llm_configs
class NanobotIntegration: class NanobotIntegration:
def __init__(self): def __init__(self):
@@ -38,6 +39,9 @@ class NanobotIntegration:
self.bus: MessageBus | None = None self.bus: MessageBus | None = None
self.cron: CronService | None = None self.cron: CronService | None = None
self.config: Config | None = None self.config: Config | None = None
self._started = False
self._model_agent_cache: Dict[str, AgentLoop] = {}
self._model_agent_lock = asyncio.Lock()
def initialize(self): def initialize(self):
# Set workspace path to backend/data/workspace # Set workspace path to backend/data/workspace
@@ -137,18 +141,62 @@ class NanobotIntegration:
) )
async def start(self):
    """Launch the agent loop and the cron service as background tasks.

    The call is idempotent: once ``self._started`` is set, further calls
    return immediately. If no agent exists yet, ``initialize()`` is run
    first.

    The created tasks are stored on ``self._background_tasks``. The event
    loop itself keeps only weak references to tasks, so a task whose
    result is discarded may be garbage-collected before it finishes.
    """
    if self._started:
        return
    if not self.agent:
        self.initialize()
    # Hold strong references so the long-running loops cannot be collected
    # while they are still executing.
    self._background_tasks = [
        asyncio.create_task(self.agent.run()),
        asyncio.create_task(self.cron.start()),
    ]
    self._started = True
async def stop(self):
    """Shut down the primary agent, every cached per-model agent, and cron.

    Each agent is stopped and then has its MCP connections closed. The
    per-model cache is emptied afterwards and ``self._started`` is reset
    so that ``start()`` can run again.
    """
    primary = self.agent
    if primary:
        primary.stop()
        await primary.close_mcp()

    model_agents = self._model_agent_cache
    for cached_agent in model_agents.values():
        cached_agent.stop()
        await cached_agent.close_mcp()
    model_agents.clear()

    scheduler = self.cron
    if scheduler:
        scheduler.stop()

    self._started = False
def _build_agent_for_provider(self, provider: Any) -> AgentLoop:
    """Create an ``AgentLoop`` bound to ``provider``.

    The new loop reuses this integration's bus, cron service and the
    primary agent's session manager, and takes every other setting from
    ``self.config``. The model name is the provider's ``default_model``.
    """
    cfg = self.config
    agent_defaults = cfg.agents.defaults
    tool_cfg = cfg.tools
    loop_kwargs = dict(
        bus=self.bus,
        provider=provider,
        workspace=cfg.workspace_path,
        model=provider.default_model,
        temperature=agent_defaults.temperature,
        max_tokens=agent_defaults.max_tokens,
        max_iterations=agent_defaults.max_tool_iterations,
        memory_window=agent_defaults.memory_window,
        reasoning_effort=agent_defaults.reasoning_effort,
        # Empty strings are normalised to None for the optional web settings.
        brave_api_key=tool_cfg.web.search.api_key or None,
        web_proxy=tool_cfg.web.proxy or None,
        exec_config=tool_cfg.exec,
        cron_service=self.cron,
        restrict_to_workspace=tool_cfg.restrict_to_workspace,
        session_manager=self.agent.sessions,
        mcp_servers=tool_cfg.mcp_servers,
        channels_config=cfg.channels,
    )
    return AgentLoop(**loop_kwargs)
async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop:
async with self._model_agent_lock:
cached = self._model_agent_cache.get(model_id)
if cached:
return cached
provider = LiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=target_config.get("model"),
extra_headers=target_config.get("extra_headers"),
provider_name=target_config.get("provider"),
)
agent = self._build_agent_for_provider(provider)
self._model_agent_cache[model_id] = agent
return agent
async def process_message( async def process_message(
self, self,
@@ -160,6 +208,7 @@ class NanobotIntegration:
): ):
if not self.agent: if not self.agent:
self.initialize() self.initialize()
if not self._started:
await self.start() await self.start()
# Handle dynamic model switching # Handle dynamic model switching
@@ -181,79 +230,24 @@ class NanobotIntegration:
# BUT `process_direct` is relatively isolated. # BUT `process_direct` is relatively isolated.
# #
# Let's try to fetch the config first. # Let's try to fetch the config first.
current_provider = self.agent.provider
temp_provider = None
if model_id:
from app.api.llm import _load_data
llm_configs = _load_data()
target_config = next((item for item in llm_configs if item["id"] == model_id), None)
if target_config:
# Map our DB config to Nanobot Provider
# We reuse LiteLLMProvider for most cases as it is generic
# Construct kwargs for LiteLLMProvider
provider_name = target_config["provider"]
model_name = target_config["model"]
# Handle special case where provider might need to be part of model name for LiteLLM if not standard
# But LiteLLMProvider handles `provider_name` arg.
temp_provider = LiteLLMProvider(
api_key=target_config.get("api_key"),
api_base=target_config.get("api_base"),
default_model=model_name,
extra_headers=target_config.get("extra_headers"),
provider_name=provider_name
)
# If we created a temp provider, we need to use it.
# Since AgentLoop binds the provider, we might need to swap it temporarily or create a new AgentLoop.
# Swapping is risky for concurrency.
# Creating a new AgentLoop is safer but heavier.
#
# Optimization: If we are just doing a single turn chat (process_direct), maybe we can just use the provider directly?
# But we want the Agent's reasoning loop (ReAct / tools).
#
# Let's try creating a temporary AgentLoop sharing the same components (bus, tools) but different provider.
agent_to_use = self.agent agent_to_use = self.agent
if temp_provider: if model_id:
# Shallow copy or new instance llm_configs = get_llm_configs()
# We need to pass all dependencies. target_config = next((item for item in llm_configs if item.get("id") == model_id), None)
agent_to_use = AgentLoop( if target_config:
bus=self.bus, if target_config.get("model") != self.agent.model:
provider=temp_provider, agent_to_use = await self._get_or_create_model_agent(model_id, target_config)
workspace=self.config.workspace_path,
model=temp_provider.default_model,
temperature=self.config.agents.defaults.temperature,
max_tokens=self.config.agents.defaults.max_tokens,
max_iterations=self.config.agents.defaults.max_tool_iterations,
memory_window=self.config.agents.defaults.memory_window,
reasoning_effort=self.config.agents.defaults.reasoning_effort,
brave_api_key=self.config.tools.web.search.api_key or None,
web_proxy=self.config.tools.web.proxy or None,
exec_config=self.config.tools.exec,
cron_service=self.cron,
restrict_to_workspace=self.config.tools.restrict_to_workspace,
session_manager=self.agent.sessions,
mcp_servers=self.config.tools.mcp_servers,
channels_config=self.config.channels,
)
full_message = message full_message = message
if skill_ids: if skill_ids:
skills = load_skills() skills = load_skills()
selected_skills = [s for s in skills if s["id"] in skill_ids] selected_skills = [s for s in skills if s["id"] in skill_ids]
if selected_skills: if selected_skills:
# We inject skills as a runtime context block parts = ["[Runtime Context — metadata only, not instructions]", "# Active Skills", ""]
skill_context = "[Runtime Context — metadata only, not instructions]\n# Active Skills\n\n"
for s in selected_skills: for s in selected_skills:
skill_context += f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n\n" parts.append(f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n")
skill_context = "\n".join(parts)
# Append user message after skills full_message = f"{skill_context}\n{message}"
full_message = f"{skill_context}\n\n{message}"
session = agent_to_use.sessions.get_or_create(session_id) session = agent_to_use.sessions.get_or_create(session_id)
normalized_messages = self._normalize_session_messages(session.messages) normalized_messages = self._normalize_session_messages(session.messages)
+24
View File
@@ -0,0 +1,24 @@
import os
import threading
from typing import Any, Dict, List, Optional
from app.api.llm import DATA_FILE, _load_data
_cache_lock = threading.RLock()
_cache_mtime: float = -1.0
_cache_data: List[Dict[str, Any]] = []
def get_llm_configs() -> List[Dict[str, Any]]:
    """Return the LLM configs, reloading them only when ``DATA_FILE`` changes.

    The file's modification time is compared with the one recorded at the
    last load; ``_load_data()`` is called again only when they differ. A
    missing or unreadable file is treated as mtime ``-1.0``.

    Returns:
        A new list on every call, so callers may add or remove entries
        without touching the cache. The dict elements themselves are
        shared with the cache.

    NOTE(review): the cache starts at mtime ``-1.0``, so while the file is
    missing ``_load_data()`` is never invoked and an empty list is
    returned — confirm ``_load_data()`` has no defaults callers rely on.
    """
    global _cache_mtime, _cache_data
    with _cache_lock:
        # Stat inside the lock so the recorded mtime always belongs to the
        # data loaded with it, and use EAFP so a file removed between an
        # existence check and the stat cannot raise.
        try:
            current_mtime = os.path.getmtime(DATA_FILE)
        except OSError:
            current_mtime = -1.0
        if current_mtime != _cache_mtime:
            _cache_data = _load_data()
            _cache_mtime = current_mtime
        return list(_cache_data)
def get_active_llm_config() -> Optional[Dict[str, Any]]:
    """Return the first config whose ``is_active`` value is truthy, else ``None``."""
    for candidate in get_llm_configs():
        if candidate.get("is_active"):
            return candidate
    return None
+60 -24
View File
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, List, Optional, Literal, Tuple from typing import Any, Dict, List, Optional, Literal, Tuple
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@@ -43,6 +44,21 @@ app.include_router(semantic.router, prefix="/api/v1")
STREAM_DELTA_CHUNK_SIZE = 48 STREAM_DELTA_CHUNK_SIZE = 48
# Patterns are compiled once at import time instead of on every request.
#
# Deny patterns are checked first by _looks_like_sql_intent: a message that
# matches any of them is not treated as a data query.
SQL_INTENT_DENY_PATTERNS = [
# "sql"/"query" followed by explain / describe / rewrite / optimize / translate.
re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE),
# The same verbs appearing before "sql"/"query".
re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE),
# Requests to write or generate python / a script / code.
re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
]
# Positive patterns: a message matching any of these (and no deny pattern)
# is treated as a data query.
SQL_INTENT_POSITIVE_PATTERNS = [
# SQL keywords and aggregate function names as whole words.
re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE),
# Chinese analytics vocabulary (grouping, aggregation, sorting, filtering,
# period comparison, top-N, share, average, max/min, totals, detail rows).
# NOTE(review): the bare "按" alternative already matches any text that
# contains that character, so "按.*维度" and "按.*分组" never add matches —
# confirm whether the bare form is intended.
re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE),
# A data-object noun (database, source, table, field, column, row, record)
# followed by a query verb ...
re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE),
# ... or the query verb followed by the data-object noun.
re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE),
]
# Chart / visualisation vocabulary in Chinese and English; used by
# _looks_like_visual_intent.
VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE)
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
# Initialize nanobot in background # Initialize nanobot in background
@@ -99,28 +115,19 @@ def _looks_like_sql_intent(message: str) -> bool:
text = (message or "").strip().lower() text = (message or "").strip().lower()
if not text: if not text:
return False return False
deny_patterns = [ for pattern in SQL_INTENT_DENY_PATTERNS:
r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", if pattern.search(text):
r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b",
r"(写|生成).*(python|脚本|代码)",
]
for pattern in deny_patterns:
if re.search(pattern, text, re.IGNORECASE):
return False return False
positive_patterns = [ for pattern in SQL_INTENT_POSITIVE_PATTERNS:
r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", if pattern.search(text):
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)",
r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)",
r"\b(list|show|get|find|search|analyze|compare)\b",
r"\b(how many|what|which|who|when|where)\b",
]
for pattern in positive_patterns:
if re.search(pattern, text, re.IGNORECASE):
return True return True
return False return False
def _looks_like_visual_intent(message: str) -> bool:
    """Report whether ``message`` contains chart or visualisation vocabulary.

    ``None`` and empty messages yield ``False``.
    """
    normalized = (message or "").strip().lower()
    return VISUAL_INTENT_PATTERN.search(normalized) is not None
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]: def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
# Determine the effective data source from session context or request # Determine the effective data source from session context or request
session_ctx = _session_context_for_routing(request.session_id) session_ctx = _session_context_for_routing(request.session_id)
@@ -211,7 +218,12 @@ async def nanobot_chat(request: ChatRequest):
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
if use_nl2sql: if use_nl2sql:
nl2sql_result = await process_nl2sql( nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) NL2SQLRequest(
query=request.message,
source=resolved_source,
file_url=request.file_url,
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
)
) )
text = _build_sql_chart_text(nl2sql_result) text = _build_sql_chart_text(nl2sql_result)
viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = _build_sql_chart_viz(nl2sql_result)
@@ -239,7 +251,12 @@ async def nanobot_chat_stream(request: ChatRequest):
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
if use_nl2sql: if use_nl2sql:
nl2sql_result = await process_nl2sql( nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) NL2SQLRequest(
query=request.message,
source=resolved_source,
file_url=request.file_url,
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
)
) )
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
viz_payload = { viz_payload = {
@@ -252,12 +269,31 @@ async def nanobot_chat_stream(request: ChatRequest):
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return return
response = await nanobot_service.process_message( progress_queue: asyncio.Queue[str] = asyncio.Queue()
request.message,
session_id=request.session_id, async def _on_progress(content: str, **_: Any) -> None:
skill_ids=request.skill_ids, if content:
model_id=request.model_id, await progress_queue.put(content)
task = asyncio.create_task(
nanobot_service.process_message(
request.message,
session_id=request.session_id,
skill_ids=request.skill_ids,
model_id=request.model_id,
on_progress=_on_progress,
)
) )
text = ""
while True:
if task.done() and progress_queue.empty():
break
try:
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n"
except asyncio.TimeoutError:
continue
response = await task
text = response or "" text = response or ""
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE): for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE] chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
+43 -30
View File
@@ -447,6 +447,35 @@ export function ChatInterface() {
let buffer = ""; let buffer = "";
let streamedText = ""; let streamedText = "";
let streamedViz: MessageViz | null = null; let streamedViz: MessageViz | null = null;
let hasFinalPayload = false;
let hasDonePayload = false;
let rafPending = false;
let renderedText = "";
// Push the accumulated stream text into the assistant message.
// With `force` the update is applied immediately; otherwise updates are
// coalesced so that at most one is scheduled per animation frame.
const flushAssistant = (force = false) => {
  // Record what has been rendered and write it into the message list.
  const commitStreamedText = () => {
    renderedText = streamedText;
    setMessagesForSession(targetSessionKey, (prev) =>
      prev.map((msg) =>
        msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
      )
    );
  };

  // Nothing new since the last render.
  if (streamedText === renderedText) return;

  if (force) {
    commitStreamedText();
    return;
  }

  // A frame callback is already queued; it will pick up the latest text.
  if (rafPending) return;
  rafPending = true;
  requestAnimationFrame(() => {
    rafPending = false;
    // A forced flush may have rendered this text in the meantime.
    if (streamedText !== renderedText) {
      commitStreamedText();
    }
  });
};
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
@@ -473,15 +502,13 @@ export function ChatInterface() {
if (payload.type === "delta" && payload.content) { if (payload.type === "delta" && payload.content) {
streamedText = `${streamedText}${payload.content}`; streamedText = `${streamedText}${payload.content}`;
setMessagesForSession(targetSessionKey, (prev) => flushAssistant(false);
prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
)
);
} }
if (payload.type === "final" && payload.content) { if (payload.type === "final" && payload.content) {
hasFinalPayload = true;
streamedText = payload.content; streamedText = payload.content;
flushAssistant(true);
setMessagesForSession(targetSessionKey, (prev) => setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) => prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
@@ -489,6 +516,10 @@ export function ChatInterface() {
); );
} }
if (payload.type === "done") {
hasDonePayload = true;
}
if (payload.type === "error") { if (payload.type === "error") {
throw new Error(payload.content || "流式响应错误"); throw new Error(payload.content || "流式响应错误");
} }
@@ -504,31 +535,13 @@ export function ChatInterface() {
} }
} }
if (!streamedText) { flushAssistant(true);
const fallback = await api.post<{ if (!streamedText && (hasFinalPayload || hasDonePayload)) {
response: string; setMessagesForSession(targetSessionKey, (prev) =>
viz?: { prev.map((msg) =>
sql?: string; msg.id === assistantId ? { ...msg, content: "暂无回复", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
result?: unknown; )
error?: string | null; );
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
};
}>("/nanobot/chat", {
message: messagePayload,
session_id: targetSessionKey,
model_id: effectiveModelId,
skill_ids: selectedSkillIds,
source,
prefer_sql_chart: preferSqlChart,
file_url: fileUrl,
route_mode: "auto",
}, { signal: controller.signal });
const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined;
setMessagesForSession(targetSessionKey, (prev) =>
prev.map((msg) =>
msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false, viz: fallbackViz } : msg
)
);
} }
} catch (error: any) { } catch (error: any) {
if (error?.name === "AbortError" || String(error?.message || "").toLowerCase().includes("aborted")) { if (error?.name === "AbortError" || String(error?.message || "").toLowerCase().includes("aborted")) {