refactor: convert to nl2sql skills
This commit is contained in:
+73
-163
@@ -14,7 +14,7 @@ from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.core.session_alias_store import session_alias_store
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
from app.context import current_session_id, current_progress_callback, current_viz_data, current_data_source, current_file_url
|
||||
from app.database import engine, Base
|
||||
# Import all models to ensure they are registered
|
||||
from app.models.user import User
|
||||
@@ -44,21 +44,6 @@ app.include_router(semantic.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
|
||||
SQL_INTENT_DENY_PATTERNS = [
|
||||
re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE),
|
||||
re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE),
|
||||
re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
SQL_INTENT_POSITIVE_PATTERNS = [
|
||||
re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE),
|
||||
re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE),
|
||||
re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE),
|
||||
re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Initialize nanobot in background
|
||||
@@ -110,56 +95,15 @@ def _session_context_for_routing(session_id: str) -> Dict[str, Any]:
|
||||
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
||||
return session.metadata or {}
|
||||
|
||||
|
||||
def _looks_like_sql_intent(message: str) -> bool:
|
||||
text = (message or "").strip().lower()
|
||||
if not text:
|
||||
return False
|
||||
for pattern in SQL_INTENT_DENY_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return False
|
||||
for pattern in SQL_INTENT_POSITIVE_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _looks_like_visual_intent(message: str) -> bool:
|
||||
return bool(VISUAL_INTENT_PATTERN.search((message or "").strip().lower()))
|
||||
|
||||
|
||||
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
|
||||
# Determine the effective data source from session context or request
|
||||
def _resolve_effective_source(request: ChatRequest) -> str:
|
||||
session_ctx = _session_context_for_routing(request.session_id)
|
||||
session_source = (session_ctx.get("selected_data_source") or "").strip().lower()
|
||||
request_source = (request.source or "").strip().lower()
|
||||
|
||||
# Priority: Session bound source > Request source > "postgres"
|
||||
effective_source = request_source
|
||||
if session_source.startswith("ds:") or session_source == "upload":
|
||||
effective_source = session_source
|
||||
|
||||
if request.route_mode == "sql":
|
||||
return True, "route_mode=sql", effective_source
|
||||
if request.route_mode == "chat":
|
||||
return False, "route_mode=chat", effective_source
|
||||
if request.prefer_sql_chart:
|
||||
return True, "prefer_sql_chart=true", effective_source
|
||||
|
||||
has_sql_intent = _looks_like_sql_intent(request.message)
|
||||
if not has_sql_intent:
|
||||
return False, "message_non_sql_intent", effective_source
|
||||
|
||||
# If we have intent, check if we have a valid source context
|
||||
if effective_source.startswith("ds:") or effective_source == "upload":
|
||||
return True, "message_sql_intent_with_datasource", effective_source
|
||||
|
||||
# Even if just "postgres" (default), if intent is strong, we might allow it?
|
||||
# But usually we want a bound source.
|
||||
# Let's keep existing logic: if intent is strong, return True.
|
||||
# But effectively, if source is "postgres", it might fail later if no tables are there.
|
||||
return True, "message_sql_intent", effective_source
|
||||
|
||||
return effective_source
|
||||
|
||||
class SessionAliasUpdateRequest(BaseModel):
|
||||
title: Optional[str] = None
|
||||
@@ -175,71 +119,42 @@ class SessionFileContextUpdateRequest(BaseModel):
|
||||
active_data_file: Optional[Dict[str, Any]] = None
|
||||
selected_data_source: Optional[str] = None
|
||||
|
||||
|
||||
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
|
||||
chart = nl2sql_result.chart
|
||||
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
|
||||
text = (
|
||||
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
|
||||
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
|
||||
)
|
||||
if chart and chart.reasoning:
|
||||
return f"{text}\n\n可视化说明:{chart.reasoning}"
|
||||
return text
|
||||
|
||||
|
||||
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
|
||||
chart = nl2sql_result.chart
|
||||
payload = {
|
||||
"sql": nl2sql_result.sql,
|
||||
"result": nl2sql_result.result,
|
||||
"chart": chart.model_dump() if chart else None,
|
||||
"error": nl2sql_result.error,
|
||||
}
|
||||
return jsonable_encoder(payload)
|
||||
|
||||
|
||||
def _persist_session_turn(
|
||||
session_id: str,
|
||||
user_message: str,
|
||||
assistant_message: str,
|
||||
assistant_extra: Optional[dict] = None,
|
||||
) -> None:
|
||||
if not nanobot_service.agent:
|
||||
return
|
||||
session = nanobot_service.agent.sessions.get_or_create(session_id)
|
||||
session.add_message("user", user_message)
|
||||
session.add_message("assistant", assistant_message, **(assistant_extra or {}))
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
|
||||
@app.post("/nanobot/chat")
|
||||
async def nanobot_chat(request: ChatRequest):
|
||||
try:
|
||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
)
|
||||
)
|
||||
text = _build_sql_chart_text(nl2sql_result)
|
||||
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
_persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload})
|
||||
return {
|
||||
"response": text,
|
||||
"viz": viz_payload,
|
||||
"routing": {"selected": "sql", "reason": route_reason},
|
||||
}
|
||||
resolved_source = _resolve_effective_source(request)
|
||||
current_data_source.set(resolved_source)
|
||||
current_file_url.set(request.file_url)
|
||||
current_session_id.set(request.session_id)
|
||||
current_viz_data.set({})
|
||||
|
||||
# Inject instructions if explicitly routed
|
||||
message = request.message
|
||||
if request.route_mode == "sql" or request.prefer_sql_chart:
|
||||
message = f"[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n{message}"
|
||||
elif request.route_mode == "chat":
|
||||
message = f"[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n{message}"
|
||||
|
||||
response = await nanobot_service.process_message(
|
||||
request.message,
|
||||
message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
)
|
||||
return {"response": response, "routing": {"selected": "chat", "reason": route_reason}}
|
||||
|
||||
viz_payload = current_viz_data.get()
|
||||
if viz_payload and nanobot_service.agent:
|
||||
# Update the last assistant message with viz data
|
||||
session = nanobot_service.agent.sessions.get_or_create(request.session_id)
|
||||
if session.messages and session.messages[-1].get("role") == "assistant":
|
||||
session.messages[-1]["viz"] = viz_payload
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
|
||||
return {
|
||||
"response": response,
|
||||
"viz": viz_payload,
|
||||
"routing": {"selected": "agent", "reason": "auto_routed_by_agent"},
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -248,69 +163,49 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
async def event_generator():
|
||||
current_task = None
|
||||
try:
|
||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
||||
if use_nl2sql:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已识别为数据分析请求,正在连接数据源'}, ensure_ascii=False)}\n\n"
|
||||
sql_progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
resolved_source = _resolve_effective_source(request)
|
||||
current_data_source.set(resolved_source)
|
||||
current_file_url.set(request.file_url)
|
||||
current_session_id.set(request.session_id)
|
||||
current_viz_data.set({})
|
||||
|
||||
async def _on_sql_progress(content: str) -> None:
|
||||
if content:
|
||||
await sql_progress_queue.put(content)
|
||||
|
||||
current_task = asyncio.create_task(
|
||||
process_nl2sql(
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
),
|
||||
on_progress=_on_sql_progress,
|
||||
)
|
||||
)
|
||||
while True:
|
||||
if current_task.done() and sql_progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(sql_progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
nl2sql_result = await current_task
|
||||
if nl2sql_result.error:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': f'出错:{nl2sql_result.error},正在整理结果'}, ensure_ascii=False)}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': 'SQL 已执行完成,正在整理回答'}, ensure_ascii=False)}\n\n"
|
||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
viz_payload = {
|
||||
"type": "viz",
|
||||
**persisted_viz_payload,
|
||||
}
|
||||
yield f"data: {json.dumps(viz_payload, ensure_ascii=False)}\n\n"
|
||||
text = _build_sql_chart_text(nl2sql_result)
|
||||
_persist_session_turn(request.session_id, request.message, text, {"viz": persisted_viz_payload})
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'agent', 'reason': 'auto_routed_by_agent'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def _on_progress(content: str, **_: Any) -> None:
|
||||
async def _on_progress(content: str, **kwargs: Any) -> None:
|
||||
if content:
|
||||
await progress_queue.put(content)
|
||||
|
||||
current_progress_callback.set(_on_progress)
|
||||
|
||||
# Inject instructions if explicitly routed
|
||||
message = request.message
|
||||
if request.route_mode == "sql" or request.prefer_sql_chart:
|
||||
message = f"[System: User explicitly requested data analysis. Please use the nl2sql tool to answer the following query.]\n{message}"
|
||||
elif request.route_mode == "chat":
|
||||
message = f"[System: User explicitly requested normal chat. Do NOT use the nl2sql tool.]\n{message}"
|
||||
|
||||
current_task = asyncio.create_task(
|
||||
nanobot_service.process_message(
|
||||
request.message,
|
||||
message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
on_progress=_on_progress,
|
||||
)
|
||||
)
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': '已发送给模型,正在分析问题'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
text = ""
|
||||
viz_sent = False
|
||||
|
||||
while True:
|
||||
# Check for viz payload during processing
|
||||
viz_payload = current_viz_data.get()
|
||||
if viz_payload and not viz_sent:
|
||||
yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n"
|
||||
viz_sent = True
|
||||
|
||||
if current_task.done() and progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
@@ -318,11 +213,26 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'progress', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
response = await current_task
|
||||
text = response or ""
|
||||
|
||||
# Check again for viz payload after task completes if not sent yet
|
||||
viz_payload = current_viz_data.get()
|
||||
if viz_payload and not viz_sent:
|
||||
yield f"data: {json.dumps({'type': 'viz', **viz_payload}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# Persist viz payload to session
|
||||
if viz_payload and nanobot_service.agent:
|
||||
session = nanobot_service.agent.sessions.get_or_create(request.session_id)
|
||||
if session.messages and session.messages[-1].get("role") == "assistant":
|
||||
session.messages[-1]["viz"] = viz_payload
|
||||
nanobot_service.agent.sessions.save(session)
|
||||
|
||||
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
||||
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
||||
yield f"data: {json.dumps({'type': 'delta', 'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
|
||||
Reference in New Issue
Block a user