From 4bbecabc20ae165fe5a4eb952ebc96259a56e19b Mon Sep 17 00:00:00 2001 From: qixinbo Date: Tue, 17 Mar 2026 11:38:02 +0800 Subject: [PATCH] feat: routing --- backend/main.py | 64 ++++++++++++++++++++--- frontend/src/components/ChatInterface.tsx | 10 ++-- 2 files changed, 63 insertions(+), 11 deletions(-) diff --git a/backend/main.py b/backend/main.py index 46a2c8a..aa07ea7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Literal, Tuple from fastapi import FastAPI, HTTPException from fastapi.encoders import jsonable_encoder from fastapi.responses import StreamingResponse @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import asyncio import json +import re from datetime import datetime from app.api import upload, llm, skills, users, datasources, projects, semantic @@ -83,13 +84,58 @@ class ChatRequest(BaseModel): source: str = "postgres" prefer_sql_chart: bool = False file_url: Optional[str] = None + route_mode: Literal["auto", "chat", "sql"] = "auto" -def _should_use_nl2sql(request: ChatRequest) -> bool: +def _session_context_for_routing(session_id: str) -> Dict[str, Any]: + if not nanobot_service.agent: + return {} + session = nanobot_service.agent.sessions.get_or_create(session_id) + return session.metadata or {} + + +def _looks_like_sql_intent(message: str) -> bool: + text = (message or "").strip().lower() + if not text: + return False + deny_patterns = [ + r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", + r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", + r"(写|生成).*(python|脚本|代码)", + ] + for pattern in deny_patterns: + if re.search(pattern, text, re.IGNORECASE): + return False + positive_patterns = [ + r"\b(select|from|where|group by|order by|having|join|union|limit)\b", + r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", + r"(多少|几条|多少条|有多少|查询|检索|按.*(天|周|月|年))", + r"(chart|plot|visuali[sz]e|dashboard)", + ] + for pattern in positive_patterns: + if re.search(pattern, text, re.IGNORECASE): + return True + return False + + +def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str]: + if request.route_mode == "sql": + return True, "route_mode=sql" + if request.route_mode == "chat": + return False, "route_mode=chat" if request.prefer_sql_chart: - return True + return True, "prefer_sql_chart=true" + has_sql_intent = _looks_like_sql_intent(request.message) + if not has_sql_intent: + return False, "message_non_sql_intent" + session_ctx = _session_context_for_routing(request.session_id) + selected_data_source = (session_ctx.get("selected_data_source") or "").strip().lower() + if selected_data_source.startswith("ds:") or selected_data_source == "upload": + return True, "message_sql_intent_with_session_datasource" source = (request.source or "").strip().lower() - return source == "upload" or source.startswith("ds:") + if source == "upload" or source.startswith("ds:"): + return True, "message_sql_intent_with_request_datasource" + return True, "message_sql_intent" class SessionAliasUpdateRequest(BaseModel): @@ -146,7 +192,8 @@ def _persist_session_turn( @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: - if _should_use_nl2sql(request): + use_nl2sql, route_reason = _should_use_nl2sql(request) + if use_nl2sql: nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) @@ -156,6 +203,7 @@ async def nanobot_chat(request: ChatRequest): return { "response": text, "viz": viz_payload, + "routing": {"selected": "sql", "reason": route_reason}, } response = await nanobot_service.process_message( request.message, @@ -163,7 +211,7 @@ async def nanobot_chat(request: ChatRequest): skill_ids=request.skill_ids, model_id=request.model_id, ) - return {"response": response} + return {"response": response, "routing": {"selected": "chat", "reason": route_reason}} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -171,7 +219,9 @@ async def nanobot_chat(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): try: - if _should_use_nl2sql(request): + use_nl2sql, route_reason = _should_use_nl2sql(request) + yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" + if use_nl2sql: nl2sql_result = await process_nl2sql( NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) ) diff --git a/frontend/src/components/ChatInterface.tsx b/frontend/src/components/ChatInterface.tsx index af6c3f0..2ceb58c 100644 --- a/frontend/src/components/ChatInterface.tsx +++ b/frontend/src/components/ChatInterface.tsx @@ -359,7 +359,7 @@ export function ChatInterface() { }; const handleSend = async () => { - if (!input.trim() || isLoading || !selectedDataSource) return; + if (!input.trim() || isLoading) return; const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input }; setMessages(prev => [...prev, newMessage]); @@ -388,7 +388,7 @@ export function ChatInterface() { const token = localStorage.getItem("token"); const effectiveModelId = selectedModelId || currentModel?.id || ""; - let source = selectedDataSource; + let source = selectedDataSource || "postgres"; const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://")); if (useUploadSource) { @@ -411,6 +411,7 @@ export function ChatInterface() { source, prefer_sql_chart: preferSqlChart, file_url: fileUrl, + route_mode: "auto", }), signal: controller.signal, }); @@ -499,6 +500,7 @@ export function ChatInterface() { source, prefer_sql_chart: preferSqlChart, file_url: fileUrl, + route_mode: "auto", }, { signal: controller.signal }); const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined; setMessages((prev) => @@ -725,7 +727,7 @@ export function ChatInterface() {