speed optim
This commit is contained in:
@@ -9,8 +9,8 @@ if str(PROJECT_ROOT) not in sys.path:
|
||||
sys.path.append(str(PROJECT_ROOT))
|
||||
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
|
||||
CHART_INSTRUCTIONS = """
|
||||
### INSTRUCTIONS ###
|
||||
@@ -133,9 +133,7 @@ CHART_EXAMPLES = """
|
||||
"""
|
||||
|
||||
async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerationResponse:
|
||||
# 1. Initialize Provider
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
return ChartGenerationResponse(
|
||||
@@ -178,7 +176,7 @@ async def generate_chart(data: List[Dict[str, Any]], query: str) -> ChartGenerat
|
||||
columns = list(data[0].keys())
|
||||
|
||||
# 3. Construct Prompt
|
||||
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), indent=2)
|
||||
schema_json = json.dumps(ChartGenerationResponse.model_json_schema(), ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
system_prompt = f"""You are a data analyst great at visualizing data using vega-lite! Given the user's question, sample data and sample column values, you need to generate vega-lite schema in JSON and provide suitable chart type.
|
||||
Besides, you need to give a concise and easy-to-understand reasoning to describe why you provide such vega-lite schema based on the question, sample data and sample column values.
|
||||
@@ -201,7 +199,7 @@ Please provide your chain of thought reasoning, chart type and the vega-lite sch
|
||||
user_prompt = f"""
|
||||
### INPUT ###
|
||||
Question: {query}
|
||||
Sample Data: {json.dumps(sample_data, indent=2, default=str)}
|
||||
Sample Data: {json.dumps(sample_data, ensure_ascii=False, separators=(",", ":"), default=str)}
|
||||
Sample Column Values: {columns}
|
||||
Language: Chinese (Simplified)
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.connectors.factory import get_connector
|
||||
from app.api.llm import _load_data as load_llm_config
|
||||
from app.schemas.chart import ChartGenerationResponse
|
||||
from app.agent.chart import generate_chart
|
||||
from app.database import SessionLocal
|
||||
from app.models.datasource import DataSource
|
||||
from app.core.files import resolve_upload_file_path
|
||||
from app.services.mdl import MDLService
|
||||
from app.services.llm_cache import get_active_llm_config
|
||||
|
||||
SCHEMA_CACHE_TTL_SECONDS = 300
|
||||
CONNECTION_CACHE_TTL_SECONDS = 30
|
||||
@@ -41,6 +41,7 @@ class NL2SQLRequest(BaseModel):
|
||||
source: str = Field(..., description="Data source to query (postgres, clickhouse, upload, ds:{id})")
|
||||
file_url: Optional[str] = Field(None, description="Uploaded file URL when source is upload")
|
||||
session_id: Optional[str] = Field(None, description="Conversation session identifier")
|
||||
generate_chart: bool = Field(False, description="Whether to generate chart specification")
|
||||
|
||||
class NL2SQLResponse(BaseModel):
|
||||
sql: str
|
||||
@@ -246,7 +247,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
schema = connector.get_schema()
|
||||
_set_cached_schema(request.source, connector, schema)
|
||||
|
||||
schema_str = json.dumps(schema, indent=2)
|
||||
schema_str = json.dumps(schema, ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
# Try to load MDL context
|
||||
mdl_context = ""
|
||||
@@ -280,8 +281,7 @@ async def process_nl2sql(request: NL2SQLRequest) -> NL2SQLResponse:
|
||||
print(f"Failed to load MDL: {e}")
|
||||
|
||||
# 2. Get the active LLM config
|
||||
llm_configs = load_llm_config()
|
||||
active_config = next((c for c in llm_configs if c.get("is_active")), None)
|
||||
active_config = get_active_llm_config()
|
||||
|
||||
if not active_config:
|
||||
return NL2SQLResponse(sql="", result=[], error="No active LLM configuration found")
|
||||
@@ -383,10 +383,8 @@ Let's think step by step.
|
||||
|
||||
# 7. Generate Chart
|
||||
chart_response = None
|
||||
if formatted_results:
|
||||
# Only try to generate chart if we have results
|
||||
# Convert to list of dicts if possible, or pass as is
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
if request.generate_chart and formatted_results:
|
||||
chart_response = await generate_chart(formatted_results, request.query)
|
||||
|
||||
return NL2SQLResponse(sql=sql_query, result=formatted_results, chart=chart_response)
|
||||
except Exception as e:
|
||||
|
||||
+61
-67
@@ -2,7 +2,7 @@ import asyncio
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Callable, Awaitable, Any
|
||||
from typing import List, Callable, Awaitable, Any, Dict
|
||||
|
||||
# Add project root to sys.path to allow importing nanobot
|
||||
# Assuming backend/app/core/nanobot.py -> backend/app/core -> backend/app -> backend -> root
|
||||
@@ -31,6 +31,7 @@ from nanobot.config.schema import Config
|
||||
# or just import here if we are confident.
|
||||
# Given the structure, importing here should be fine as long as skills.py doesn't import nanobot.py.
|
||||
from app.api.skills import load_skills
|
||||
from app.services.llm_cache import get_llm_configs
|
||||
|
||||
class NanobotIntegration:
|
||||
def __init__(self):
|
||||
@@ -38,6 +39,9 @@ class NanobotIntegration:
|
||||
self.bus: MessageBus | None = None
|
||||
self.cron: CronService | None = None
|
||||
self.config: Config | None = None
|
||||
self._started = False
|
||||
self._model_agent_cache: Dict[str, AgentLoop] = {}
|
||||
self._model_agent_lock = asyncio.Lock()
|
||||
|
||||
def initialize(self):
|
||||
# Set workspace path to backend/data/workspace
|
||||
@@ -137,18 +141,62 @@ class NanobotIntegration:
|
||||
)
|
||||
|
||||
async def start(self):
    """Start the agent loop and the cron service as background tasks.

    Calling this again after a successful start is a no-op. If no agent has
    been built yet, ``initialize()`` is called first.
    """
    if self._started:
        return
    if not self.agent:
        self.initialize()
    # The event loop only keeps weak references to tasks, so a task whose
    # result is discarded can be garbage-collected before it finishes.
    # Keep strong references on the instance for as long as it lives.
    self._background_tasks = [
        asyncio.create_task(self.agent.run()),
        asyncio.create_task(self.cron.start()),
    ]
    self._started = True
|
||||
|
||||
async def stop(self):
    """Stop the main agent, every cached per-model agent, and the cron service.

    Each agent is stopped and its MCP connections are closed. The per-model
    cache is left empty and the integration is marked as not started.
    """
    if self.agent:
        self.agent.stop()
        await self.agent.close_mcp()
    # Drain the cache entry by entry instead of iterating over it: the loop
    # body awaits, and a dict that grows while it is being iterated raises
    # RuntimeError. Popping also shuts down any agent that was added while
    # an earlier close was in flight.
    while self._model_agent_cache:
        _, cached_agent = self._model_agent_cache.popitem()
        cached_agent.stop()
        await cached_agent.close_mcp()
    if self.cron:
        self.cron.stop()
    self._started = False
|
||||
|
||||
def _build_agent_for_provider(self, provider: Any) -> AgentLoop:
    """Create an AgentLoop bound to ``provider``.

    Everything except the provider and its default model comes from this
    integration: the message bus, the cron service, the main agent's session
    manager and the values stored in ``self.config``.
    """
    settings = self.config
    agent_defaults = settings.agents.defaults
    tool_settings = settings.tools
    web_settings = tool_settings.web

    loop_kwargs = dict(
        bus=self.bus,
        provider=provider,
        workspace=settings.workspace_path,
        model=provider.default_model,
        temperature=agent_defaults.temperature,
        max_tokens=agent_defaults.max_tokens,
        max_iterations=agent_defaults.max_tool_iterations,
        memory_window=agent_defaults.memory_window,
        reasoning_effort=agent_defaults.reasoning_effort,
        # Falsy values (e.g. empty strings) are passed on as None.
        brave_api_key=web_settings.search.api_key or None,
        web_proxy=web_settings.proxy or None,
        exec_config=tool_settings.exec,
        cron_service=self.cron,
        restrict_to_workspace=tool_settings.restrict_to_workspace,
        # Reuse the main agent's session manager.
        session_manager=self.agent.sessions,
        mcp_servers=tool_settings.mcp_servers,
        channels_config=settings.channels,
    )
    return AgentLoop(**loop_kwargs)
|
||||
|
||||
async def _get_or_create_model_agent(self, model_id: str, target_config: Dict[str, Any]) -> AgentLoop:
    """Return the cached agent for ``model_id``, building and caching it on first use.

    The lookup and the creation both happen while holding
    ``self._model_agent_lock``.
    """
    # NOTE(review): entries are keyed by model_id only, so a cached agent is
    # returned even if target_config differs from the one it was built with.
    # Confirm that is acceptable when a config is edited at runtime.
    async with self._model_agent_lock:
        agent = self._model_agent_cache.get(model_id)
        if not agent:
            llm_provider = LiteLLMProvider(
                api_key=target_config.get("api_key"),
                api_base=target_config.get("api_base"),
                default_model=target_config.get("model"),
                extra_headers=target_config.get("extra_headers"),
                provider_name=target_config.get("provider"),
            )
            agent = self._build_agent_for_provider(llm_provider)
            self._model_agent_cache[model_id] = agent
        return agent
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
@@ -160,6 +208,7 @@ class NanobotIntegration:
|
||||
):
|
||||
if not self.agent:
|
||||
self.initialize()
|
||||
if not self._started:
|
||||
await self.start()
|
||||
|
||||
# Handle dynamic model switching
|
||||
@@ -181,79 +230,24 @@ class NanobotIntegration:
|
||||
# BUT `process_direct` is relatively isolated.
|
||||
#
|
||||
# Let's try to fetch the config first.
|
||||
current_provider = self.agent.provider
|
||||
temp_provider = None
|
||||
|
||||
if model_id:
|
||||
from app.api.llm import _load_data
|
||||
llm_configs = _load_data()
|
||||
target_config = next((item for item in llm_configs if item["id"] == model_id), None)
|
||||
|
||||
if target_config:
|
||||
# Map our DB config to Nanobot Provider
|
||||
# We reuse LiteLLMProvider for most cases as it is generic
|
||||
|
||||
# Construct kwargs for LiteLLMProvider
|
||||
provider_name = target_config["provider"]
|
||||
model_name = target_config["model"]
|
||||
|
||||
# Handle special case where provider might need to be part of model name for LiteLLM if not standard
|
||||
# But LiteLLMProvider handles `provider_name` arg.
|
||||
|
||||
temp_provider = LiteLLMProvider(
|
||||
api_key=target_config.get("api_key"),
|
||||
api_base=target_config.get("api_base"),
|
||||
default_model=model_name,
|
||||
extra_headers=target_config.get("extra_headers"),
|
||||
provider_name=provider_name
|
||||
)
|
||||
|
||||
# If we created a temp provider, we need to use it.
|
||||
# Since AgentLoop binds the provider, we might need to swap it temporarily or create a new AgentLoop.
|
||||
# Swapping is risky for concurrency.
|
||||
# Creating a new AgentLoop is safer but heavier.
|
||||
#
|
||||
# Optimization: If we are just doing a single turn chat (process_direct), maybe we can just use the provider directly?
|
||||
# But we want the Agent's reasoning loop (ReAct / tools).
|
||||
#
|
||||
# Let's try creating a temporary AgentLoop sharing the same components (bus, tools) but different provider.
|
||||
|
||||
agent_to_use = self.agent
|
||||
if temp_provider:
|
||||
# Shallow copy or new instance
|
||||
# We need to pass all dependencies.
|
||||
agent_to_use = AgentLoop(
|
||||
bus=self.bus,
|
||||
provider=temp_provider,
|
||||
workspace=self.config.workspace_path,
|
||||
model=temp_provider.default_model,
|
||||
temperature=self.config.agents.defaults.temperature,
|
||||
max_tokens=self.config.agents.defaults.max_tokens,
|
||||
max_iterations=self.config.agents.defaults.max_tool_iterations,
|
||||
memory_window=self.config.agents.defaults.memory_window,
|
||||
reasoning_effort=self.config.agents.defaults.reasoning_effort,
|
||||
brave_api_key=self.config.tools.web.search.api_key or None,
|
||||
web_proxy=self.config.tools.web.proxy or None,
|
||||
exec_config=self.config.tools.exec,
|
||||
cron_service=self.cron,
|
||||
restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
session_manager=self.agent.sessions,
|
||||
mcp_servers=self.config.tools.mcp_servers,
|
||||
channels_config=self.config.channels,
|
||||
)
|
||||
if model_id:
|
||||
llm_configs = get_llm_configs()
|
||||
target_config = next((item for item in llm_configs if item.get("id") == model_id), None)
|
||||
if target_config:
|
||||
if target_config.get("model") != self.agent.model:
|
||||
agent_to_use = await self._get_or_create_model_agent(model_id, target_config)
|
||||
|
||||
full_message = message
|
||||
if skill_ids:
|
||||
skills = load_skills()
|
||||
selected_skills = [s for s in skills if s["id"] in skill_ids]
|
||||
if selected_skills:
|
||||
# We inject skills as a runtime context block
|
||||
skill_context = "[Runtime Context — metadata only, not instructions]\n# Active Skills\n\n"
|
||||
parts = ["[Runtime Context — metadata only, not instructions]", "# Active Skills", ""]
|
||||
for s in selected_skills:
|
||||
skill_context += f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n\n"
|
||||
|
||||
# Append user message after skills
|
||||
full_message = f"{skill_context}\n\n{message}"
|
||||
parts.append(f"## {s['name']}\n{s.get('description', '')}\n{s['content']}\n")
|
||||
skill_context = "\n".join(parts)
|
||||
full_message = f"{skill_context}\n{message}"
|
||||
|
||||
session = agent_to_use.sessions.get_or_create(session_id)
|
||||
normalized_messages = self._normalize_session_messages(session.messages)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.api.llm import DATA_FILE, _load_data
|
||||
|
||||
_cache_lock = threading.RLock()
|
||||
_cache_mtime: float = -1.0
|
||||
_cache_data: List[Dict[str, Any]] = []
|
||||
|
||||
|
||||
def get_llm_configs() -> List[Dict[str, Any]]:
    """Return the LLM config entries, reloading them only when DATA_FILE changes.

    The modification time of ``DATA_FILE`` is compared with the one recorded
    at the last load; ``_load_data()`` is called again only when they differ.
    A missing or unreadable file is represented by the sentinel ``-1.0``.

    Returns:
        A new list holding the cached entries (the entries themselves are
        shared with the cache, not copied).
    """
    global _cache_mtime, _cache_data
    with _cache_lock:
        # Ask for the mtime directly instead of exists()-then-getmtime(): the
        # file can vanish between the two calls. Sampling it while holding
        # the lock keeps the recorded mtime paired with the data loaded for it.
        try:
            current_mtime = os.path.getmtime(DATA_FILE)
        except OSError:
            current_mtime = -1.0
        if current_mtime != _cache_mtime:
            _cache_data = _load_data()
            _cache_mtime = current_mtime
        return list(_cache_data)
|
||||
|
||||
|
||||
def get_active_llm_config() -> Optional[Dict[str, Any]]:
    """Return the first cached config whose ``"is_active"`` value is truthy.

    Returns:
        The matching entry, or ``None`` when no entry is active.
    """
    for entry in get_llm_configs():
        if entry.get("is_active"):
            return entry
    return None
|
||||
+60
-24
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Literal, Tuple
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
@@ -43,6 +44,21 @@ app.include_router(semantic.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
|
||||
# Patterns are compiled once at import time rather than on every request.
#
# The patterns that use \b are compiled with re.ASCII. Under the default
# Unicode matching, CJK characters count as word characters, so there is no
# word boundary between e.g. "段" and "sql" in "解释这段sql" and the keyword
# would never match when it directly touches Chinese text.

# Messages that mention SQL but ask to explain/rewrite/optimize/translate it,
# or that ask for Python/script/code to be written.
SQL_INTENT_DENY_PATTERNS = [
    re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE | re.ASCII),
    re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE | re.ASCII),
    re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
]

# Messages that look like a data query: SQL keywords, aggregation/ranking
# vocabulary, or a query verb combined with a database/table/column noun.
SQL_INTENT_POSITIVE_PATTERNS = [
    re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE | re.ASCII),
    # Left on Unicode matching so \d and \s keep accepting non-ASCII digits and spaces.
    re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE),
    re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE),
    re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE),
]

# Words that indicate the user wants a chart or other visualization.
VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Initialize nanobot in background
|
||||
@@ -99,28 +115,19 @@ def _looks_like_sql_intent(message: str) -> bool:
|
||||
text = (message or "").strip().lower()
|
||||
if not text:
|
||||
return False
|
||||
deny_patterns = [
|
||||
r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)",
|
||||
r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b",
|
||||
r"(写|生成).*(python|脚本|代码)",
|
||||
]
|
||||
for pattern in deny_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
for pattern in SQL_INTENT_DENY_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return False
|
||||
positive_patterns = [
|
||||
r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b",
|
||||
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
|
||||
r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)",
|
||||
r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)",
|
||||
r"\b(list|show|get|find|search|analyze|compare)\b",
|
||||
r"\b(how many|what|which|who|when|where)\b",
|
||||
]
|
||||
for pattern in positive_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
for pattern in SQL_INTENT_POSITIVE_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _looks_like_visual_intent(message: str) -> bool:
    """Return True when ``message`` contains a chart/visualization keyword.

    A falsy message (``None`` or empty) is treated as an empty string.
    """
    normalized = (message or "").strip().lower()
    return VISUAL_INTENT_PATTERN.search(normalized) is not None
|
||||
|
||||
|
||||
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
|
||||
# Determine the effective data source from session context or request
|
||||
session_ctx = _session_context_for_routing(request.session_id)
|
||||
@@ -211,7 +218,12 @@ async def nanobot_chat(request: ChatRequest):
|
||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
)
|
||||
)
|
||||
text = _build_sql_chart_text(nl2sql_result)
|
||||
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
@@ -239,7 +251,12 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
)
|
||||
)
|
||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
viz_payload = {
|
||||
@@ -252,12 +269,31 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
response = await nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def _on_progress(content: str, **_: Any) -> None:
|
||||
if content:
|
||||
await progress_queue.put(content)
|
||||
|
||||
task = asyncio.create_task(
|
||||
nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
on_progress=_on_progress,
|
||||
)
|
||||
)
|
||||
text = ""
|
||||
while True:
|
||||
if task.done() and progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
response = await task
|
||||
text = response or ""
|
||||
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
||||
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
||||
|
||||
@@ -447,6 +447,35 @@ export function ChatInterface() {
|
||||
let buffer = "";
|
||||
let streamedText = "";
|
||||
let streamedViz: MessageViz | null = null;
|
||||
let hasFinalPayload = false;
|
||||
let hasDonePayload = false;
|
||||
let rafPending = false;
|
||||
let renderedText = "";
|
||||
|
||||
// Push the accumulated stream text into the assistant message.
// With force=true the update is applied immediately; otherwise it is
// scheduled through requestAnimationFrame, with at most one frame pending.
const flushAssistant = (force = false) => {
  // Nothing new has arrived since the last render.
  if (streamedText === renderedText) return;

  const commitStreamedText = () => {
    renderedText = streamedText;
    setMessagesForSession(targetSessionKey, (prev) =>
      prev.map((msg) =>
        msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
      )
    );
  };

  if (force) {
    commitStreamedText();
    return;
  }

  // A frame is already scheduled; it reads the latest streamedText when it runs.
  if (rafPending) return;
  rafPending = true;
  requestAnimationFrame(() => {
    rafPending = false;
    if (streamedText !== renderedText) {
      commitStreamedText();
    }
  });
};
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
@@ -473,15 +502,13 @@ export function ChatInterface() {
|
||||
|
||||
if (payload.type === "delta" && payload.content) {
|
||||
streamedText = `${streamedText}${payload.content}`;
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: streamedText, awaitingFirstToken: false } : msg
|
||||
)
|
||||
);
|
||||
flushAssistant(false);
|
||||
}
|
||||
|
||||
if (payload.type === "final" && payload.content) {
|
||||
hasFinalPayload = true;
|
||||
streamedText = payload.content;
|
||||
flushAssistant(true);
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: payload.content || "", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
@@ -489,6 +516,10 @@ export function ChatInterface() {
|
||||
);
|
||||
}
|
||||
|
||||
if (payload.type === "done") {
|
||||
hasDonePayload = true;
|
||||
}
|
||||
|
||||
if (payload.type === "error") {
|
||||
throw new Error(payload.content || "流式响应错误");
|
||||
}
|
||||
@@ -504,31 +535,13 @@ export function ChatInterface() {
|
||||
}
|
||||
}
|
||||
|
||||
if (!streamedText) {
|
||||
const fallback = await api.post<{
|
||||
response: string;
|
||||
viz?: {
|
||||
sql?: string;
|
||||
result?: unknown;
|
||||
error?: string | null;
|
||||
chart?: { chart_spec?: ChartSpec | null; reasoning?: string; can_visualize?: boolean; chart_type?: string } | null;
|
||||
};
|
||||
}>("/nanobot/chat", {
|
||||
message: messagePayload,
|
||||
session_id: targetSessionKey,
|
||||
model_id: effectiveModelId,
|
||||
skill_ids: selectedSkillIds,
|
||||
source,
|
||||
prefer_sql_chart: preferSqlChart,
|
||||
file_url: fileUrl,
|
||||
route_mode: "auto",
|
||||
}, { signal: controller.signal });
|
||||
const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined;
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: fallback.response || "暂无回复", awaitingFirstToken: false, viz: fallbackViz } : msg
|
||||
)
|
||||
);
|
||||
flushAssistant(true);
|
||||
if (!streamedText && (hasFinalPayload || hasDonePayload)) {
|
||||
setMessagesForSession(targetSessionKey, (prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === assistantId ? { ...msg, content: "暂无回复", awaitingFirstToken: false, viz: streamedViz ?? msg.viz } : msg
|
||||
)
|
||||
);
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error?.name === "AbortError" || String(error?.message || "").toLowerCase().includes("aborted")) {
|
||||
|
||||
Reference in New Issue
Block a user