From 68a44f38373fc5e46d4b5ab02723100be2377f35 Mon Sep 17 00:00:00 2001 From: qixinbo Date: Tue, 17 Mar 2026 17:49:34 +0800 Subject: [PATCH] fix: source binded bug --- backend/main.py | 55 +++++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 20 deletions(-) diff --git a/backend/main.py b/backend/main.py index bfa117b..dfa293b 100644 --- a/backend/main.py +++ b/backend/main.py @@ -108,10 +108,12 @@ def _looks_like_sql_intent(message: str) -> bool: if re.search(pattern, text, re.IGNORECASE): return False positive_patterns = [ - r"\b(select|from|where|group by|order by|having|join|union|limit)\b", + r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", - r"(多少|几条|多少条|有多少|查询|检索|按.*(天|周|月|年))", - r"(chart|plot|visuali[sz]e|dashboard)", + r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)", + r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)", + r"\b(list|show|get|find|search|analyze|compare)\b", + r"\b(how many|what|which|who|when|where)\b", ] for pattern in positive_patterns: if re.search(pattern, text, re.IGNORECASE): @@ -119,24 +121,37 @@ def _looks_like_sql_intent(message: str) -> bool: return False -def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str]: +def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]: + # Determine the effective data source from session context or request + session_ctx = _session_context_for_routing(request.session_id) + session_source = (session_ctx.get("selected_data_source") or "").strip().lower() + request_source = (request.source or "").strip().lower() + + # Priority: Session bound source > Request source > "postgres" + effective_source = request_source + if session_source.startswith("ds:") or session_source == "upload": + effective_source = session_source + if request.route_mode == "sql": - return True, "route_mode=sql" + return True, "route_mode=sql", effective_source if request.route_mode == "chat": - return False, "route_mode=chat" + return False, "route_mode=chat", effective_source if request.prefer_sql_chart: - return True, "prefer_sql_chart=true" + return True, "prefer_sql_chart=true", effective_source + has_sql_intent = _looks_like_sql_intent(request.message) if not has_sql_intent: - return False, "message_non_sql_intent" - session_ctx = _session_context_for_routing(request.session_id) - selected_data_source = (session_ctx.get("selected_data_source") or "").strip().lower() - if selected_data_source.startswith("ds:") or selected_data_source == "upload": - return True, "message_sql_intent_with_session_datasource" - source = (request.source or "").strip().lower() - if source == "upload" or source.startswith("ds:"): - return True, "message_sql_intent_with_request_datasource" - return True, "message_sql_intent" + return False, "message_non_sql_intent", effective_source + + # If we have intent, check if we have a valid source context + if effective_source.startswith("ds:") or effective_source == "upload": + return True, "message_sql_intent_with_datasource", effective_source + + # Even if just "postgres" (default), if intent is strong, we might allow it? + # But usually we want a bound source. + # Let's keep existing logic: if intent is strong, return True. + # But effectively, if source is "postgres", it might fail later if no tables are there. + return True, "message_sql_intent", effective_source class SessionAliasUpdateRequest(BaseModel): @@ -193,10 +208,10 @@ def _persist_session_turn( @app.post("/nanobot/chat") async def nanobot_chat(request: ChatRequest): try: - use_nl2sql, route_reason = _should_use_nl2sql(request) + use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) if use_nl2sql: nl2sql_result = await process_nl2sql( - NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) + NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) ) text = _build_sql_chart_text(nl2sql_result) viz_payload = _build_sql_chart_viz(nl2sql_result) @@ -220,11 +235,11 @@ async def nanobot_chat(request: ChatRequest): async def nanobot_chat_stream(request: ChatRequest): async def event_generator(): try: - use_nl2sql, route_reason = _should_use_nl2sql(request) + use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request) yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n" if use_nl2sql: nl2sql_result = await process_nl2sql( - NL2SQLRequest(query=request.message, source=request.source, file_url=request.file_url) + NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url) ) persisted_viz_payload = _build_sql_chart_viz(nl2sql_result) viz_payload = {