fix: source binded bug
This commit is contained in:
+35
-20
@@ -108,10 +108,12 @@ def _looks_like_sql_intent(message: str) -> bool:
|
|||||||
if re.search(pattern, text, re.IGNORECASE):
|
if re.search(pattern, text, re.IGNORECASE):
|
||||||
return False
|
return False
|
||||||
positive_patterns = [
|
positive_patterns = [
|
||||||
r"\b(select|from|where|group by|order by|having|join|union|limit)\b",
|
r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b",
|
||||||
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
|
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
|
||||||
r"(多少|几条|多少条|有多少|查询|检索|按.*(天|周|月|年))",
|
r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)",
|
||||||
r"(chart|plot|visuali[sz]e|dashboard)",
|
r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)",
|
||||||
|
r"\b(list|show|get|find|search|analyze|compare)\b",
|
||||||
|
r"\b(how many|what|which|who|when|where)\b",
|
||||||
]
|
]
|
||||||
for pattern in positive_patterns:
|
for pattern in positive_patterns:
|
||||||
if re.search(pattern, text, re.IGNORECASE):
|
if re.search(pattern, text, re.IGNORECASE):
|
||||||
@@ -119,24 +121,37 @@ def _looks_like_sql_intent(message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str]:
|
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
|
||||||
|
# Determine the effective data source from session context or request
|
||||||
|
session_ctx = _session_context_for_routing(request.session_id)
|
||||||
|
session_source = (session_ctx.get("selected_data_source") or "").strip().lower()
|
||||||
|
request_source = (request.source or "").strip().lower()
|
||||||
|
|
||||||
|
# Priority: Session bound source > Request source > "postgres"
|
||||||
|
effective_source = request_source
|
||||||
|
if session_source.startswith("ds:") or session_source == "upload":
|
||||||
|
effective_source = session_source
|
||||||
|
|
||||||
if request.route_mode == "sql":
|
if request.route_mode == "sql":
|
||||||
return True, "route_mode=sql"
|
return True, "route_mode=sql", effective_source
|
||||||
if request.route_mode == "chat":
|
if request.route_mode == "chat":
|
||||||
return False, "route_mode=chat"
|
return False, "route_mode=chat", effective_source
|
||||||
if request.prefer_sql_chart:
|
if request.prefer_sql_chart:
|
||||||
return True, "prefer_sql_chart=true"
|
return True, "prefer_sql_chart=true", effective_source
|
||||||
|
|
||||||
has_sql_intent = _looks_like_sql_intent(request.message)
|
has_sql_intent = _looks_like_sql_intent(request.message)
|
||||||
if not has_sql_intent:
|
if not has_sql_intent:
|
||||||
return False, "message_non_sql_intent"
|
return False, "message_non_sql_intent", effective_source
|
||||||
session_ctx = _session_context_for_routing(request.session_id)
|
|
||||||
selected_data_source = (session_ctx.get("selected_data_source") or "").strip().lower()
|
# If we have intent, check if we have a valid source context
|
||||||
if selected_data_source.startswith("ds:") or selected_data_source == "upload":
|
if effective_source.startswith("ds:") or effective_source == "upload":
|
||||||
return True, "message_sql_intent_with_session_datasource"
|
return True, "message_sql_intent_with_datasource", effective_source
|
||||||
source = (request.source or "").strip().lower()
|
|
||||||
if source == "upload" or source.startswith("ds:"):
|
# Even if just "postgres" (default), if intent is strong, we might allow it?
|
||||||
return True, "message_sql_intent_with_request_datasource"
|
# But usually we want a bound source.
|
||||||
return True, "message_sql_intent"
|
# Let's keep existing logic: if intent is strong, return True.
|
||||||
|
# But effectively, if source is "postgres", it might fail later if no tables are there.
|
||||||
|
return True, "message_sql_intent", effective_source
|
||||||
|
|
||||||
|
|
||||||
class SessionAliasUpdateRequest(BaseModel):
|
class SessionAliasUpdateRequest(BaseModel):
|
||||||
@@ -193,10 +208,10 @@ def _persist_session_turn(
|
|||||||
@app.post("/nanobot/chat")
|
@app.post("/nanobot/chat")
|
||||||
async def nanobot_chat(request: ChatRequest):
|
async def nanobot_chat(request: ChatRequest):
|
||||||
try:
|
try:
|
||||||
use_nl2sql, route_reason = _should_use_nl2sql(request)
|
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||||
if use_nl2sql:
|
if use_nl2sql:
|
||||||
nl2sql_result = await process_nl2sql(
|
nl2sql_result = await process_nl2sql(
|
||||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||||
)
|
)
|
||||||
text = _build_sql_chart_text(nl2sql_result)
|
text = _build_sql_chart_text(nl2sql_result)
|
||||||
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||||
@@ -220,11 +235,11 @@ async def nanobot_chat(request: ChatRequest):
|
|||||||
async def nanobot_chat_stream(request: ChatRequest):
|
async def nanobot_chat_stream(request: ChatRequest):
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
try:
|
try:
|
||||||
use_nl2sql, route_reason = _should_use_nl2sql(request)
|
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
||||||
if use_nl2sql:
|
if use_nl2sql:
|
||||||
nl2sql_result = await process_nl2sql(
|
nl2sql_result = await process_nl2sql(
|
||||||
NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url)
|
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||||
)
|
)
|
||||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||||
viz_payload = {
|
viz_payload = {
|
||||||
|
|||||||
Reference in New Issue
Block a user