speed optim
This commit is contained in:
+60
-24
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Literal, Tuple
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
@@ -43,6 +44,21 @@ app.include_router(semantic.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
|
||||
SQL_INTENT_DENY_PATTERNS = [
|
||||
re.compile(r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)", re.IGNORECASE),
|
||||
re.compile(r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b", re.IGNORECASE),
|
||||
re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
SQL_INTENT_POSITIVE_PATTERNS = [
|
||||
re.compile(r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b", re.IGNORECASE),
|
||||
re.compile(r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)", re.IGNORECASE),
|
||||
re.compile(r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)", re.IGNORECASE),
|
||||
re.compile(r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)", re.IGNORECASE),
|
||||
]
|
||||
|
||||
VISUAL_INTENT_PATTERN = re.compile(r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)", re.IGNORECASE)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# Initialize nanobot in background
|
||||
@@ -99,28 +115,19 @@ def _looks_like_sql_intent(message: str) -> bool:
|
||||
text = (message or "").strip().lower()
|
||||
if not text:
|
||||
return False
|
||||
deny_patterns = [
|
||||
r"\b(sql|query)\b.*(解释|说明|改写|优化|翻译)",
|
||||
r"(解释|说明|改写|优化|翻译).*\b(sql|query)\b",
|
||||
r"(写|生成).*(python|脚本|代码)",
|
||||
]
|
||||
for pattern in deny_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
for pattern in SQL_INTENT_DENY_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return False
|
||||
positive_patterns = [
|
||||
r"\b(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)\b",
|
||||
r"(统计|汇总|分组|排序|筛选|过滤|环比|同比|趋势|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
|
||||
r"(多少|几条|多少条|有多少|查询|检索|列出|列表|清单|显示|展示|查看|分析|对比|情况|数据|信息|记录)",
|
||||
r"(chart|plot|visuali[sz]e|dashboard|画图|图表|可视化)",
|
||||
r"\b(list|show|get|find|search|analyze|compare)\b",
|
||||
r"\b(how many|what|which|who|when|where)\b",
|
||||
]
|
||||
for pattern in positive_patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
for pattern in SQL_INTENT_POSITIVE_PATTERNS:
|
||||
if pattern.search(text):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _looks_like_visual_intent(message: str) -> bool:
|
||||
return bool(VISUAL_INTENT_PATTERN.search((message or "").strip().lower()))
|
||||
|
||||
|
||||
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
|
||||
# Determine the effective data source from session context or request
|
||||
session_ctx = _session_context_for_routing(request.session_id)
|
||||
@@ -211,7 +218,12 @@ async def nanobot_chat(request: ChatRequest):
|
||||
use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
)
|
||||
)
|
||||
text = _build_sql_chart_text(nl2sql_result)
|
||||
viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
@@ -239,7 +251,12 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'routing', 'selected': 'sql' if use_nl2sql else 'chat', 'reason': route_reason}, ensure_ascii=False)}\n\n"
|
||||
if use_nl2sql:
|
||||
nl2sql_result = await process_nl2sql(
|
||||
NL2SQLRequest(query=request.message, source=resolved_source, file_url=request.file_url)
|
||||
NL2SQLRequest(
|
||||
query=request.message,
|
||||
source=resolved_source,
|
||||
file_url=request.file_url,
|
||||
generate_chart=request.prefer_sql_chart or _looks_like_visual_intent(request.message),
|
||||
)
|
||||
)
|
||||
persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
|
||||
viz_payload = {
|
||||
@@ -252,12 +269,31 @@ async def nanobot_chat_stream(request: ChatRequest):
|
||||
yield f"data: {json.dumps({'type': 'final', 'content': text}, ensure_ascii=False)}\n\n"
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
response = await nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
progress_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
|
||||
async def _on_progress(content: str, **_: Any) -> None:
|
||||
if content:
|
||||
await progress_queue.put(content)
|
||||
|
||||
task = asyncio.create_task(
|
||||
nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=request.session_id,
|
||||
skill_ids=request.skill_ids,
|
||||
model_id=request.model_id,
|
||||
on_progress=_on_progress,
|
||||
)
|
||||
)
|
||||
text = ""
|
||||
while True:
|
||||
if task.done() and progress_queue.empty():
|
||||
break
|
||||
try:
|
||||
progress = await asyncio.wait_for(progress_queue.get(), timeout=0.2)
|
||||
yield f"data: {json.dumps({'type': 'delta', 'content': progress}, ensure_ascii=False)}\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
response = await task
|
||||
text = response or ""
|
||||
for idx in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
|
||||
chunk = text[idx: idx + STREAM_DELTA_CHUNK_SIZE]
|
||||
|
||||
Reference in New Issue
Block a user