feat: routing

This commit is contained in:
qixinbo
2026-03-17 11:38:02 +08:00
parent b9d1c182cf
commit 4bbecabc20
2 changed files with 63 additions and 11 deletions
+57 -7
View File
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Literal, Tuple
from fastapi import FastAPI, HTTPException
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse
@@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import asyncio
import json
import re
from datetime import datetime
from app.api import upload, llm, skills, users, datasources, projects, semantic
@@ -83,13 +84,58 @@ class ChatRequest(BaseModel):
source: str = "postgres"
prefer_sql_chart: bool = False
file_url: Optional[str] = None
route_mode: Literal["auto", "chat", "sql"] = "auto"
def _should_use_nl2sql(request: ChatRequest) -> bool:
def _session_context_for_routing(session_id: str) -> Dict[str, Any]:
    """Return the metadata of the given session for use in routing decisions.

    Yields an empty dict when the nanobot agent is not available or when
    the session carries no metadata. The session is obtained through
    ``sessions.get_or_create``.
    """
    agent = nanobot_service.agent
    if agent:
        routing_session = agent.sessions.get_or_create(session_id)
        metadata = routing_session.metadata
        return metadata if metadata else {}
    return {}
# ASCII-only word boundaries. Python's ``\b`` is Unicode-aware for ``str``
# patterns, so CJK characters count as word characters and ``\bsql\b`` does
# not match inside "这个sql查询". These lookarounds only treat ASCII
# letters, digits and underscore as part of a keyword.
_ASCII_WORD_START = r"(?<![a-z0-9_])"
_ASCII_WORD_END = r"(?![a-z0-9_])"
_SQL_NOUN = _ASCII_WORD_START + r"(?:sql|query)" + _ASCII_WORD_END

# Messages that talk *about* SQL or ask for code, rather than asking for data.
_SQL_INTENT_DENY_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        _SQL_NOUN + r".*(解释|说明|改写|优化|翻译)",
        r"(解释|说明|改写|优化|翻译).*" + _SQL_NOUN,
        r"(写|生成).*(python|脚本|代码)",
    )
)

# Signals that the message asks for a data lookup, aggregation or chart.
_SQL_INTENT_POSITIVE_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        _ASCII_WORD_START
        + r"(?:select|from|where|group by|order by|having|join|union|limit)"
        + _ASCII_WORD_END,
        r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
        r"(多少|几条|多少条|有多少|查询|检索|按.*(天|周|月|年))",
        r"(chart|plot|visuali[sz]e|dashboard)",
    )
)


def _looks_like_sql_intent(message: str) -> bool:
    """Heuristically decide whether ``message`` asks for a data query.

    The message is stripped and lower-cased, then checked against a deny
    list first (explaining/rewriting/optimising/translating SQL, or asking
    for a Python script or code) and a positive list second (SQL keywords,
    aggregation/ranking vocabulary, counting/lookup phrases, chart words).

    Args:
        message: Raw user message; ``None`` or blank is treated as no intent.

    Returns:
        ``True`` when a positive pattern matches and no deny pattern does,
        otherwise ``False``.
    """
    text = (message or "").strip().lower()
    if not text:
        return False

    # Deny patterns win over positive ones: "optimise this SQL query" must
    # stay in chat even though it contains query vocabulary.
    if any(pattern.search(text) for pattern in _SQL_INTENT_DENY_PATTERNS):
        return False

    return any(pattern.search(text) for pattern in _SQL_INTENT_POSITIVE_PATTERNS)
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str]:
    """Decide whether a chat request should be handled by the NL2SQL path.

    Args:
        request: Incoming chat request; ``route_mode``, ``prefer_sql_chart``,
            ``message``, ``session_id`` and ``source`` are consulted.

    Returns:
        A ``(use_nl2sql, reason)`` pair. ``reason`` is a short machine-readable
        tag naming the rule that made the decision.
    """
    # Explicit routing requested by the caller overrides every heuristic.
    if request.route_mode == "sql":
        return True, "route_mode=sql"
    if request.route_mode == "chat":
        return False, "route_mode=chat"

    if request.prefer_sql_chart:
        return True, "prefer_sql_chart=true"

    if not _looks_like_sql_intent(request.message):
        return False, "message_non_sql_intent"

    # From here on the answer is always SQL; the checks below only refine
    # the reported reason by where a concrete data source was found.
    session_ctx = _session_context_for_routing(request.session_id)
    session_source = (session_ctx.get("selected_data_source") or "").strip().lower()
    if session_source == "upload" or session_source.startswith("ds:"):
        return True, "message_sql_intent_with_session_datasource"

    request_source = (request.source or "").strip().lower()
    if request_source == "upload" or request_source.startswith("ds:"):
        return True, "message_sql_intent_with_request_datasource"

    return True, "message_sql_intent"
class SessionAliasUpdateRequest(BaseModel):
@@ -146,7 +192,8 @@ def _persist_session_turn(
@app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest):
try:
if _should_use_nl2sql(request):
use_nl2sql, route_reason = _should_use_nl2sql(request)
if use_nl2sql:
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
@@ -156,6 +203,7 @@ async def nanobot_chat(request: ChatRequest):
return {
"response": text,
"viz": viz_payload,
"routing": {"selected": "sql", "reason": route_reason},
}
response = await nanobot_service.process_message(
request.message,
@@ -163,7 +211,7 @@ async def nanobot_chat(request: ChatRequest):
skill_ids=request.skill_ids,
model_id=request.model_id,
)
return {"response": response}
return {"response": response, "routing": {"selected": "chat", "reason": route_reason}}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -171,7 +219,9 @@ async def nanobot_chat(request: ChatRequest):
async def nanobot_chat_stream(request: ChatRequest):
async def event_generator():
try:
if _should_use_nl2sql(request):
use_nl2sql, route_reason = _should_use_nl2sql(request)
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
if use_nl2sql:
nl2sql_result = await process_nl2sql(
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
)
+6 -4
View File
@@ -359,7 +359,7 @@ export function ChatInterface() {
};
const handleSend = async () => {
if (!input.trim() || isLoading || !selectedDataSource) return;
if (!input.trim() || isLoading) return;
const newMessage: Message = { id: Date.now().toString(), role: 'user', content: input };
setMessages(prev => [...prev, newMessage]);
@@ -388,7 +388,7 @@ export function ChatInterface() {
const token = localStorage.getItem("token");
const effectiveModelId = selectedModelId || currentModel?.id || "";
let source = selectedDataSource;
let source = selectedDataSource || "postgres";
const useUploadSource = Boolean(currentAttachedFile?.url?.startsWith("local://"));
if (useUploadSource) {
@@ -411,6 +411,7 @@ export function ChatInterface() {
source,
prefer_sql_chart: preferSqlChart,
file_url: fileUrl,
route_mode: "auto",
}),
signal: controller.signal,
});
@@ -499,6 +500,7 @@ export function ChatInterface() {
source,
prefer_sql_chart: preferSqlChart,
file_url: fileUrl,
route_mode: "auto",
}, { signal: controller.signal });
const fallbackViz = fallback.viz ? buildMessageViz(fallback.viz) : undefined;
setMessages((prev) =>
@@ -725,7 +727,7 @@ export function ChatInterface() {
<div className="flex items-center gap-1">
<button
onClick={handleSend}
disabled={isLoading || !selectedDataSource || !input.trim()}
disabled={isLoading || !input.trim()}
className={cn(
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
(input.trim() || attachedFile || activeDataFile) && !isLoading
@@ -931,7 +933,7 @@ export function ChatInterface() {
<div className="flex items-center gap-1">
<button
onClick={isLoading ? handleForceStop : handleSend}
disabled={isLoading ? false : !selectedDataSource || !input.trim()}
disabled={isLoading ? false : !input.trim()}
className={cn(
"flex items-center justify-center h-10 w-10 rounded-full transition-all duration-200",
(input.trim() || isLoading)