feat: routing

This commit is contained in:
qixinbo
2026-03-17 11:38:02 +08:00
parent b9d1c182cf
commit 4bbecabc20
2 changed files with 63 additions and 11 deletions
+57 -7
View File
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Literal, Tuple
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
import asyncio import asyncio
import json import json
import re
from datetime import datetime from datetime import datetime
from app.api import upload, llm, skills, users, datasources, projects, semantic from app.api import upload, llm, skills, users, datasources, projects, semantic
@@ -83,13 +84,58 @@ class ChatRequest(BaseModel):
source: str = "postgres" source: str = "postgres"
prefer_sql_chart: bool = False prefer_sql_chart: bool = False
file_url: Optional[str] = None file_url: Optional[str] = None
route_mode: Literal["auto", "chat", "sql"] = "auto"
def _session_context_for_routing(session_id: str) -> Dict[str, Any]:
    """Return the stored session metadata used to make a routing decision.

    Args:
        session_id: Identifier passed to the agent's session store.

    Returns:
        The session's ``metadata`` mapping, or an empty dict when the nanobot
        agent is not available or the session carries no metadata.
    """
    agent = nanobot_service.agent
    if not agent:
        # Without an agent there is no session store to consult.
        return {}
    # NOTE(review): get_or_create presumably creates a missing session as a
    # side effect of this lookup -- confirm that is acceptable for routing.
    session = agent.sessions.get_or_create(session_id)
    return session.metadata or {}
def _looks_like_sql_intent(message: str) -> bool:
text = (message or "").strip().lower()
if not text:
return False
deny_patterns = [
r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)",
r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b",
r"(写|生成).*(python|脚本|代码)",
]
for pattern in deny_patterns:
if re.search(pattern, text, re.IGNORECASE):
return False
positive_patterns = [
r"\b(select|from|where|group by|order by|having|join|union|limit)\b",
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
r"(多少|几条|多少条|有多少|查询|检索|按.*(天|周|月|年))",
r"(chart|plot|visuali[sz]e|dashboard)",
]
for pattern in positive_patterns:
if re.search(pattern, text, re.IGNORECASE):
return True return True
return False
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str]:
    """Decide whether a chat request is routed to the NL2SQL pipeline.

    Precedence, highest first:
      1. An explicit ``route_mode`` of ``"sql"`` or ``"chat"``.
      2. ``prefer_sql_chart`` set on the request.
      3. Heuristic intent detection on the message text.

    Args:
        request: Incoming chat request.

    Returns:
        ``(use_nl2sql, reason)``, where ``reason`` is a short machine-readable
        tag naming the rule that decided the route.
    """

    def _names_data_source(value) -> bool:
        # A routable source is either an upload or a "ds:"-prefixed id.
        normalized = (value or "").strip().lower()
        return normalized == "upload" or normalized.startswith("ds:")

    if request.route_mode == "sql":
        return True, "route_mode=sql"
    if request.route_mode == "chat":
        return False, "route_mode=chat"
    if request.prefer_sql_chart:
        return True, "prefer_sql_chart=true"

    if not _looks_like_sql_intent(request.message):
        return False, "message_non_sql_intent"

    # From here on the route is always SQL; the data source only refines the
    # reported reason (session selection first, then the request's source).
    session_ctx = _session_context_for_routing(request.session_id)
    if _names_data_source(session_ctx.get("selected_data_source")):
        return True, "message_sql_intent_with_session_datasource"
    if _names_data_source(request.source):
        return True, "message_sql_intent_with_request_datasource"
    return True, "message_sql_intent"
class SessionAliasUpdateRequest(BaseModel): class SessionAliasUpdateRequest(BaseModel):
@@ -146,7 +192,8 @@ def _persist_session_turn(
@app.post("/nanobot/chat") @app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest): async def nanobot_chat(request: ChatRequest):
try: try:
if _should_use_nl2sql(request): use_nl2sql, route_reason = _should_use_nl2sql(request)
if use_nl2sql:
nl2sql_result = await process_nl2sql( nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
) )
@@ -156,6 +203,7 @@ async def nanobot_chat(request: ChatRequest):
return { return {
"response": text, "response": text,
"viz": viz_payload, "viz": viz_payload,
"routing": {"selected": "sql", "reason": route_reason},
} }
response = await nanobot_service.process_message( response = await nanobot_service.process_message(
request.message, request.message,
@@ -163,7 +211,7 @@ async def nanobot_chat(request: ChatRequest):
skill_ids=request.skill_ids, skill_ids=request.skill_ids,
model_id=request.model_id, model_id=request.model_id,
) )
return {"response": response} return {"response": response, "routing": {"selected": "chat", "reason": route_reason}}
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@@ -171,7 +219,9 @@ async def nanobot_chat(request: ChatRequest):
async def nanobot_chat_stream(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest):
async def event_generator(): async def event_generator():
try: try:
if _should_use_nl2sql(request): use_nl2sql, route_reason = _should_use_nl2sql(request)
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
if use_nl2sql:
nl2sql_result = await process_nl2sql( nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
) )
+6 -4
View File
@@ -359,7 +359,7 @@ export function ChatInterface() {
}; };
const handleSend = async () => { const handleSend = async () => {
if (!input.trim() || isLoading || !selectedDataSource) return; if (!input.trim() || isLoading) return;
const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input }; const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input };
setMessages(prev => [...prev, newMessage]); setMessages(prev => [...prev, newMessage]);
@@ -388,7 +388,7 @@ export function ChatInterface() {
const token = localStorage.getItem("token"); const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || ""; const effectiveModelId = selectedModelId || currentModel?.id || "";
let source = selectedDataSource; let source = selectedDataSource || "postgres";
const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://")); const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://"));
if (useUploadSource) { if (useUploadSource) {
@@ -411,6 +411,7 @@ export function ChatInterface() {
source, source,
prefer_sql_chart: preferSqlChart, prefer_sql_chart: preferSqlChart,
file_url: fileUrl, file_url: fileUrl,
route_mode: "auto",
}), }),
signal: controller.signal, signal: controller.signal,
}); });
@@ -499,6 +500,7 @@ export function ChatInterface() {
source, source,
prefer_sql_chart: preferSqlChart, prefer_sql_chart: preferSqlChart,
file_url: fileUrl, file_url: fileUrl,
route_mode: "auto",
}, { signal: controller.signal }); }, { signal: controller.signal });
const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined; const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined;
setMessages((prev) => setMessages((prev) =>
@@ -725,7 +727,7 @@ export function ChatInterface() {
<div className="flex items-center gap-1"> <div className="flex items-center gap-1">
<button <button
onClick={handleSend} onClick={handleSend}
disabled={isLoading || !selectedDataSource || !input.trim()} disabled={isLoading || !input.trim()}
className={cn( className={cn(
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200", "flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
(input.trim() || attachedFile || activeDataFile) && !isLoading (input.trim() || attachedFile || activeDataFile) && !isLoading
@@ -931,7 +933,7 @@ export function ChatInterface() {
<div className="flex items-center gap-1"> <div className="flex items-center gap-1">
<button <button
onClick={isLoading ? handleForceStop : handleSend} onClick={isLoading ? handleForceStop : handleSend}
disabled={isLoading ? false : !selectedDataSource || !input.trim()} disabled={isLoading ? false : !input.trim()}
className={cn( className={cn(
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200", "flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
(input.trim() || isLoading) (input.trim() || isLoading)