From a0e72111cd210634687c4a33d3a41147c3df8c58 Mon Sep 17 00:00:00 2001 From: qixinbo Date: Wed, 18 Mar 2026 21:58:11 +0800 Subject: [PATCH] refactor: convert to nl2sql skills --- backend/app/agent/nl2sql.py | 67 +++++--- backend/app/context.py | 21 +++ backend/app/core/nanobot.py | 9 ++ backend/app/tools/nl2sql.py | 117 ++++++++++++++ backend/app/tools/visualization.py | 80 ++++++++++ backend/main.py | 236 +++++++++-------------------- 6 files changed, 350 insertions(+), 180 deletions(-) create mode 100644 backend/app/context.py create mode 100644 backend/app/tools/nl2sql.py create mode 100644 backend/app/tools/visualization.py diff --git a/backend/app/agent/nl2sql.py b/backend/app/agent/nl2sql.py index c620f94..5eaae3a 100644 --- a/backend/app/agent/nl2sql.py +++ b/backend/app/agent/nl2sql.py @@ -191,14 +191,24 @@ def _set_cached_schema(source: str, connector: Any, schema: Dict[str, List[Dict[ with _cache_lock: _schema_cache[key] = {"schema": schema, "expires_at": time.time() + SCHEMA_CACHE_TTL_SECONDS} -def _check_connection_with_cache(source: str, connector: Any) -> bool: +async def _check_connection_with_cache(source: str, connector: Any) -> bool: cache_key = _build_schema_cache_key(source, connector) now = time.time() with _cache_lock: cached = _connection_cache.get(cache_key) if cached and now < cached["expires_at"]: return bool(cached["ok"]) - ok = connector.test_connection() + + # Run synchronous test_connection in a separate thread to avoid blocking event loop + try: + ok = await asyncio.wait_for( + asyncio.to_thread(connector.test_connection), + timeout=10.0 + ) + except Exception as e: + print(f"Connection test failed or timed out: {e}") + ok = False + with _cache_lock: _connection_cache[cache_key] = {"ok": ok, "expires_at": now + CONNECTION_CACHE_TTL_SECONDS} return ok @@ -224,7 +234,7 @@ async def process_nl2sql( elif request.source == "upload": try: upload_started = time.perf_counter() - upload_payload = _get_upload_payload(request.file_url) + upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url) upload_df = upload_payload["df"] schema = upload_payload["schema"] await emit_progress(f"上传文件加载完成 ({time.perf_counter() - upload_started:.2f}s)") @@ -234,14 +244,21 @@ async def process_nl2sql( try: ds_started = time.perf_counter() ds_id = int(request.source.split(":")[1]) - db = SessionLocal() - try: - ds = db.query(DataSource).filter(DataSource.id == ds_id).first() - if not ds: - return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}") - connector = get_connector(ds) - finally: - db.close() + + def _get_ds_connector(): + db = SessionLocal() + try: + ds = db.query(DataSource).filter(DataSource.id == ds_id).first() + if not ds: + return None + return get_connector(ds) + finally: + db.close() + + connector = await asyncio.to_thread(_get_ds_connector) + if not connector: + return NL2SQLResponse(sql="", result=[], error=f"Data source not found: {request.source}") + await emit_progress(f"数据源配置读取完成 ({time.perf_counter() - ds_started:.2f}s)") except ValueError: return NL2SQLResponse(sql="", result=[], error=f"Invalid data source ID: {request.source}") @@ -258,20 +275,35 @@ async def process_nl2sql( await emit_progress(f"命中 Schema 缓存,已加载 {len(schema)} 张表") else: conn_started = time.perf_counter() - if not _check_connection_with_cache(request.source, connector): + if not await _check_connection_with_cache(request.source, connector): return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") await emit_progress(f"连接检测完成 ({time.perf_counter() - conn_started:.2f}s)") schema_started = time.perf_counter() - schema = connector.get_schema() + try: + schema = await asyncio.wait_for( + asyncio.to_thread(connector.get_schema), + timeout=30.0 + ) + except Exception as e: + return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema: {e}") + _set_cached_schema(request.source, connector, schema) await emit_progress(f"Schema 拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - schema_started:.2f}s)") if connector and not schema: retry_started = time.perf_counter() # Double check in case schema was empty but connection is ok (e.g. empty db) - if not _check_connection_with_cache(request.source, connector): + if not await _check_connection_with_cache(request.source, connector): return NL2SQLResponse(sql="", result=[], error=f"Failed to connect to {request.source}") - schema = connector.get_schema() + + try: + schema = await asyncio.wait_for( + asyncio.to_thread(connector.get_schema), + timeout=30.0 + ) + except Exception as e: + return NL2SQLResponse(sql="", result=[], error=f"Failed to fetch schema on retry: {e}") + _set_cached_schema(request.source, connector, schema) await emit_progress(f"Schema 二次拉取完成,共 {len(schema)} 张表 ({time.perf_counter() - retry_started:.2f}s)") @@ -282,7 +314,7 @@ async def process_nl2sql( if request.source.startswith("ds:"): try: ds_id = int(request.source.split(":")[1]) - mdl = MDLService.get_mdl(ds_id) + mdl = await asyncio.to_thread(MDLService.get_mdl, ds_id) if mdl: mdl_lines = ["\n### SEMANTIC MODEL (WrenMDL) ###"] @@ -392,7 +424,8 @@ Language: Chinese (Simplified) await emit_progress("正在执行 SQL 查询") if request.source == "upload": if upload_df is None: - upload_df = _get_upload_payload(request.file_url)["df"] + upload_payload = await asyncio.to_thread(_get_upload_payload, request.file_url) + upload_df = upload_payload["df"] timeout_stage = "sql_execution" formatted_results = await asyncio.wait_for( asyncio.to_thread(_execute_upload_sql, sql_query, upload_df), diff --git a/backend/app/context.py b/backend/app/context.py new file mode 100644 index 0000000..1a44367 --- /dev/null +++ b/backend/app/context.py @@ -0,0 +1,21 @@ +from contextvars import ContextVar +from typing import Any, Callable, Awaitable, Dict, Optional + +# The current session ID processing the request +current_session_id: ContextVar[str] = ContextVar("current_session_id", default="") + +# A callback to send progress updates to the frontend during tool execution +current_progress_callback: ContextVar[Optional[Callable[[str], Awaitable[None]]]] = ContextVar("current_progress_callback", default=None) + +# A payload dictionary to store visualization results generated by tools +# This will be picked up by the stream handler and sent to the frontend +current_viz_data: ContextVar[Optional[Dict[str, Any]]] = ContextVar("current_viz_data", default=None) + +# Store the last queried data so the Visualization Tool can access it +current_data: ContextVar[Optional[list]] = ContextVar("current_data", default=None) + +# The data source requested by the user or bound to the session +current_data_source: ContextVar[str] = ContextVar("current_data_source", default="postgres") + +# Any file URL attached to the request +current_file_url: ContextVar[Optional[str]] = ContextVar("current_file_url", default=None) diff --git a/backend/app/core/nanobot.py b/backend/app/core/nanobot.py index 185500b..df092b3 100644 --- a/backend/app/core/nanobot.py +++ b/backend/app/core/nanobot.py @@ -85,6 +85,14 @@ class NanobotIntegration: channels_config=self.config.channels, ) + self._register_custom_tools(self.agent) + + def _register_custom_tools(self, agent: AgentLoop): + from app.tools.nl2sql import NL2SQLTool + from app.tools.visualization import VisualizationTool + agent.tools.register(NL2SQLTool()) + agent.tools.register(VisualizationTool()) + def _make_provider(self, config: Config): # Logic adapted from nanobot/cli/commands.py model = config.agents.defaults.model @@ -195,6 +203,7 @@ class NanobotIntegration: provider_name=target_config.get("provider"), ) agent = self._build_agent_for_provider(provider) + self._register_custom_tools(agent) self._model_agent_cache[model_id] = agent return agent diff --git a/backend/app/tools/nl2sql.py b/backend/app/tools/nl2sql.py new file mode 100644 index 0000000..c5cd8e5 --- /dev/null +++ b/backend/app/tools/nl2sql.py @@ -0,0 +1,117 @@ +import json +import logging +from typing import Any, Dict + +from nanobot.agent.tools.base import Tool +from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse +from app.context import current_progress_callback, current_viz_data, current_data_source, current_file_url, current_data +from fastapi.encoders import jsonable_encoder + +logger = logging.getLogger(__name__) + +def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: + chart = nl2sql_result.chart + payload = { + "sql": nl2sql_result.sql, + "result": nl2sql_result.result, + "chart": chart.model_dump() if chart else None, + "error": nl2sql_result.error, + } + return jsonable_encoder(payload) + +class NL2SQLTool(Tool): + """ + Tool for translating natural language queries into SQL, executing them, + and optionally generating visualizations. + """ + + @property + def name(self) -> str: + return "nl2sql" + + @property + def description(self) -> str: + return ( + "Query the connected database or data source using natural language. " + "Use this tool when the user asks to query, analyze, aggregate, or fetch data from the database. " + "Set generate_chart=True if the user also wants to visualize or plot the data." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The natural language query describing what data to fetch or analyze.", + }, + "generate_chart": { + "type": "boolean", + "description": "Whether to automatically generate a visualization chart for the result. Default is False.", + } + }, + "required": ["query"], + } + + async def execute(self, **kwargs: Any) -> str: + query = kwargs.get("query", "") + generate_chart = kwargs.get("generate_chart", False) + + # Get context + source = current_data_source.get() + file_url = current_file_url.get() + on_progress = current_progress_callback.get() + + request = NL2SQLRequest( + query=query, + source=source, + file_url=file_url, + generate_chart=generate_chart, + ) + + try: + # Call the core logic + result = await process_nl2sql(request, on_progress=on_progress) + + if result.error: + return f"Error executing query: {result.error}" + + # Save the result data to context for potential later use by VisualizationTool + if result.result: + current_data.set(result.result) + + # Save visualization payload to context so the chat stream can pick it up + viz_payload = _build_sql_chart_viz(result) + existing_viz = current_viz_data.get() + if isinstance(existing_viz, dict): + existing_viz.clear() + existing_viz.update(viz_payload) + current_viz_data.set(existing_viz) + else: + current_viz_data.set(viz_payload) + + # Build a summary string for the Agent to read + row_count = len(result.result) if result.result else 0 + + summary_parts = [f"Successfully executed SQL query."] + summary_parts.append(f"SQL: {result.sql}") + summary_parts.append(f"Rows returned: {row_count}") + + if generate_chart: + if result.chart and result.chart.can_visualize: + summary_parts.append("Chart was successfully generated.") + if result.chart.reasoning: + summary_parts.append(f"Chart Reasoning: {result.chart.reasoning}") + else: + summary_parts.append("Requested a chart, but the data was not suitable for visualization.") + + summary_parts.append("\nSample data (first 5 rows):") + sample = result.result[:5] if result.result else [] + summary_parts.append(json.dumps(jsonable_encoder(sample), ensure_ascii=False)) + + return "\n".join(summary_parts) + + except Exception as e: + logger.error(f"NL2SQL Tool error: {e}", exc_info=True) + return f"An unexpected error occurred during NL2SQL execution: {str(e)}" diff --git a/backend/app/tools/visualization.py b/backend/app/tools/visualization.py new file mode 100644 index 0000000..4e071de --- /dev/null +++ b/backend/app/tools/visualization.py @@ -0,0 +1,80 @@ +import logging +from typing import Any + +from nanobot.agent.tools.base import Tool +from app.agent.chart import generate_chart +from app.context import current_data, current_viz_data, current_progress_callback +from fastapi.encoders import jsonable_encoder + +logger = logging.getLogger(__name__) + +class VisualizationTool(Tool): + """ + Tool for generating a visualization (chart) from existing data. + """ + + @property + def name(self) -> str: + return "visualization" + + @property + def description(self) -> str: + return ( + "Generate a chart or visualization based on the most recently queried data. " + "Use this tool when the user asks to plot, visualize, or create a chart from data that has already been retrieved. " + "Note: This tool relies on the data from the last executed SQL query. If no query has been executed yet, you must use the nl2sql tool first." + ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The user's request describing how they want the data visualized (e.g., 'plot sales by month as a bar chart').", + } + }, + "required": ["query"], + } + + async def execute(self, **kwargs: Any) -> str: + query = kwargs.get("query", "") + data = current_data.get() + on_progress = current_progress_callback.get() + + if not data: + return "Error: No data available to visualize. Please query the data first using the nl2sql tool." + + try: + if on_progress: + await on_progress("正在分析数据特征并生成可视化方案...") + + chart_response = await generate_chart(data, query) + + if chart_response.can_visualize: + # Build the viz payload (similar to NL2SQL, but without the SQL part) + # We reuse the previous viz_data if it exists (to keep SQL), or create a new one + existing_viz = current_viz_data.get() or {} + + viz_payload = { + "sql": existing_viz.get("sql", ""), + "result": data, + "chart": chart_response.model_dump(), + "error": None, + } + encoded_viz = jsonable_encoder(viz_payload) + if isinstance(existing_viz, dict): + existing_viz.clear() + existing_viz.update(encoded_viz) + current_viz_data.set(existing_viz) + else: + current_viz_data.set(encoded_viz) + + return f"Successfully generated a {chart_response.chart_type} chart.\nReasoning: {chart_response.reasoning}" + else: + return f"Could not generate a chart: {chart_response.reasoning}" + + except Exception as e: + logger.error(f"Visualization Tool error: {e}", exc_info=True) + return f"An error occurred while generating the visualization: {str(e)}" diff --git a/backend/main.py b/backend/main.py index d838734..4ff3e5a 100644 --- a/backend/main.py +++ b/backend/main.py @@ -14,7 +14,7 @@ from app.connectors.postgres import postgres_connector from app.connectors.clickhouse import clickhouse_connector from app.core.nanobot import nanobot_service from app.core.session_alias_store import session_alias_store -from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse +from app.context import current_session_id, current_progress_callback, current_viz_data, current_data_source, current_file_url from app.database import engine, Base # Import all models to ensure they are registered from app.models.user import User @@ -44,21 +44,6 @@ app.include_router(semantic.router, prefix="/api/v1") STREAM_DELTA_CHUNK_SIZE = 48 -SQL_INTENT_DENY_PATTERNS = [ - re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE), - re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE), - re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE), -] - -SQL_INTENT_POSITIVE_PATTERNS = [ - re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE), - re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE), - re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE), - re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE), -] - -VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE) - @app.on_event("startup") async def startup_event(): # Initialize nanobot in background @@ -110,56 +95,15 @@ def _session_context_for_routing(session_id: str) -> Dict[str, Any]: session = nanobot_service.agent.sessions.get_or_create(session_id) return session.metadata or {} - -def _looks_like_sql_intent(message: str) -> bool: - text = (message or "").strip().lower() - if not text: - return False - for pattern in SQL_INTENT_DENY_PATTERNS: - if pattern.search(text): - return False - for pattern in SQL_INTENT_POSITIVE_PATTERNS: - if pattern.search(text): - return True - return False - - -def _looks_like_visual_intent(message: str) -> bool: - return bool(VISUAL_INTENT_PATTERN.search((message or "").strip().lower())) - - -def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]: - # Determine the effective data source from session context or request +def _resolve_effective_source(request: ChatRequest) -> str: session_ctx = _session_context_for_routing(request.session_id) session_source = (session_ctx.get("selected_data_source") or "").strip().lower() request_source = (request.source or "").strip().lower() - # Priority: Session bound source > Request source > "postgres" effective_source = request_source if session_source.startswith("ds:") or session_source == "upload": effective_source = session_source - - if request.route_mode == "sql": - return True, "route_mode=sql", effective_source - if request.route_mode == "chat": - return False, "route_mode=chat", effective_source - if request.prefer_sql_chart: - return True, "prefer_sql_chart=true", effective_source - - has_sql_intent = _looks_like_sql_intent(request.message) - if not has_sql_intent: - return False, "message_non_sql_intent", effective_source - - # If we have intent, check if we have a valid source context - if effective_source.startswith("ds:") or effective_source == "upload": - return True, "message_sql_intent_with_datasource", effective_source - - # Even if just "postgres" (default), if intent is strong, we might allow it? - # But usually we want a bound source. - # Let's keep existing logic: if intent is strong, return True. - # But effectively, if source is "postgres", it might fail later if no tables are there. - return True, "message_sql_intent", effective_source - + return effective_source class SessionAliasUpdateRequest(BaseModel): title: Optional[str] = None @@ -175,71 +119,42 @@ class SessionFileContextUpdateRequest(BaseModel): active_data_file: Optional[Dict[str, Any]] = None selected_data_source: Optional[str] = None - -def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str: - chart = nl2sql_result.chart - can_visualize = bool(chart and chart.can_visualize and chart.chart_spec) - text = ( - f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。" - f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}" - ) - if chart and chart.reasoning: - return f"{text}\n\n可视化说明:{chart.reasoning}" - return text - - -def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict: - chart = nl2sql_result.chart - payload = { - "sql": nl2sql_result.sql, - "result": nl2sql_result.result, - "chart": chart.model_dump() if chart else None, - "error": nl2sql_result.error, - } - return jsonable_encoder(payload) - - -def _persist_session_turn( - session_id: str, - user_message: str, - assistant_message: str, - assistant_extra: Optional[dict] = None, -) -> None: - if not nanobot_service.agent: - return - session = nanobot_service.agent.sessions.get_or_create(session_id) - session.add_message("user", user_message) - session.add_message("assistant", assistant_message, **(assistant_extra or {})) - nanobot_service.agent.sessions.save(session) - @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: - use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) - if use_nl2sql: - nl2sql_result = await process_nl2sql( - NL2SQLRequest( - query=request.message, - source=resolved_source, - file_url=request.file_url, - generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), - ) - ) - text = _build_sql_chart_text(nl2sql_result) - viz_payload = _build_sql_chart_viz(nl2sql_result) - _persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload}) - return { - "response": text, - "viz": viz_payload, - "routing": {"selected": "sql", "reason": route_reason}, - } + resolved_source = _resolve_effective_source(request) + current_data_source.set(resolved_source) + current_file_url.set(request.file_url) + current_session_id.set(request.session_id) + current_viz_data.set({}) + + # Inject instructions if explicitly routed + message = request.message + if request.route_mode == "sql" or request.prefer_sql_chart: + message = f"[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n{message}" + elif request.route_mode == "chat": + message = f"[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n{message}" + response = await nanobot_service.process_message( - request.message, + message, session_id=request.session_id, skill_ids=request.skill_ids, model_id=request.model_id, ) - return {"response": response, "routing": {"selected": "chat", "reason": route_reason}} + + viz_payload = current_viz_data.get() + if viz_payload and nanobot_service.agent: + # Update the last assistant message with viz data + session = nanobot_service.agent.sessions.get_or_create(request.session_id) + if session.messages and session.messages[-1].get("role") == "assistant": + session.messages[-1]["viz"] = viz_payload + nanobot_service.agent.sessions.save(session) + + return { + "response": response, + "viz": viz_payload, + "routing": {"selected": "agent", "reason": "auto_routed_by_agent"}, + } except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -248,69 +163,49 @@ async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): current_task = None try: - use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) - yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" - if use_nl2sql: - yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n" - sql_progress_queue: asyncio.Queue[str] = asyncio.Queue() + resolved_source = _resolve_effective_source(request) + current_data_source.set(resolved_source) + current_file_url.set(request.file_url) + current_session_id.set(request.session_id) + current_viz_data.set({}) - async def _on_sql_progress(content: str) -> None: - if content: - await sql_progress_queue.put(content) - - current_task = asyncio.create_task( - process_nl2sql( - NL2SQLRequest( - query=request.message, - source=resolved_source, - file_url=request.file_url, - generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message), - ), - on_progress=_on_sql_progress, - ) - ) - while True: - if current_task.done() and sql_progress_queue.empty(): - break - try: - progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2) - yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" - except asyncio.TimeoutError: - continue - nl2sql_result = await current_task - if nl2sql_result.error: - yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n" - else: - yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n" - persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) - viz_payload = { - "type": "viz", - **persisted_viz_payload, - } - yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n" - text = _build_sql_chart_text(nl2sql_result) - _persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload}) - yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" - yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" - return + yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n" + progress_queue: asyncio.Queue[str] = asyncio.Queue() - async def _on_progress(content: str, **_: Any) -> None: + async def _on_progress(content: str, **kwargs: Any) -> None: if content: await progress_queue.put(content) + current_progress_callback.set(_on_progress) + + # Inject instructions if explicitly routed + message = request.message + if request.route_mode == "sql" or request.prefer_sql_chart: + message = f"[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n{message}" + elif request.route_mode == "chat": + message = f"[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n{message}" + current_task = asyncio.create_task( nanobot_service.process_message( - request.message, + message, session_id=request.session_id, skill_ids=request.skill_ids, model_id=request.model_id, on_progress=_on_progress, ) ) - yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n" + text = "" + viz_sent = False + while True: + # Check for viz payload during processing + viz_payload = current_viz_data.get() + if viz_payload and not viz_sent: + yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n" + viz_sent = True + if current_task.done() and progress_queue.empty(): break try: @@ -318,11 +213,26 @@ async def nanobot_chat_stream(request: ChatRequest): yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n" except asyncio.TimeoutError: continue + response = await current_task text = response or "" + + # Check again for viz payload after task completes if not sent yet + viz_payload = current_viz_data.get() + if viz_payload and not viz_sent: + yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n" + + # Persist viz payload to session + if viz_payload and nanobot_service.agent: + session = nanobot_service.agent.sessions.get_or_create(request.session_id) + if session.messages and session.messages[-1].get("role") == "assistant": + session.messages[-1]["viz"] = viz_payload + nanobot_service.agent.sessions.save(session) + for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE): chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE] yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n" + yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n" yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" except asyncio.CancelledError: