Files
DataClaw/backend/main.py
T

448 lines
18 KiB
Python
Raw Normal View History

2026-03-17 20:40:56 +08:00
import asyncio
2026-03-17 11:38:02 +08:00
from typing import Any, Dict, List, Optional, Literal, Tuple
2026-03-14 23:15:41 +08:00
from fastapi import FastAPI, HTTPException
2026-03-16 23:16:33 +08:00
from fastapi.encoders import jsonable_encoder
2026-03-14 22:00:36 +08:00
from fastapi.responses import StreamingResponse
2026-03-14 15:44:48 +08:00
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
2026-03-14 22:00:36 +08:00
import json
2026-03-17 11:38:02 +08:00
import re
2026-03-15 18:25:38 +08:00
from datetime import datetime
2026-03-14 15:44:48 +08:00
2026-03-16 22:18:23 +08:00
from app.api import upload, llm, skills, users, datasources, projects, semantic
2026-03-14 15:44:48 +08:00
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.core.nanobot import nanobot_service
2026-03-14 23:15:41 +08:00
from app.core.session_alias_store import session_alias_store
2026-03-14 15:44:48 +08:00
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
2026-03-15 19:36:02 +08:00
from app.database import engine, Base
# Import all models to ensure they are registered
from app.models.user import User
2026-03-16 16:12:35 +08:00
from app.models.project import Project
2026-03-15 19:36:02 +08:00
from app.models.datasource import DataSource
2026-03-14 15:44:48 +08:00
# Application object shared by every route and router registered in this module.
app = FastAPI()

# NOTE(review): "*" is listed next to explicit origins while credentials are
# enabled. FastAPI's CORS documentation says wildcard origins should not be
# combined with allow_credentials=True -- confirm the intended origin list.
_cors_options = {
    "allow_origins": ["http://localhost:5173", "http://localhost:5174", "*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Initialize database tables (runs once, when this module is imported).
Base.metadata.create_all(bind=engine)
2026-03-14 15:44:48 +08:00
# Mount every feature router under the shared /api/v1 prefix. The tuple order
# is the registration order.
for _api_module in (upload, llm, skills, users, projects, datasources, semantic):
    app.include_router(_api_module.router, prefix="/api/v1")
2026-03-14 15:44:48 +08:00
2026-03-17 16:43:55 +08:00
# Number of characters carried by each "delta" event when a finished chat
# reply is streamed back to the client.
STREAM_DELTA_CHUNK_SIZE = 48

# ASCII-only keyword boundaries. `\b` is unsuitable here: in a str pattern CJK
# characters count as word characters, so there is no word boundary between
# e.g. "条" and "sql", and keywords embedded in Chinese text would never match.
_KW_START = r"(?<![A-Za-z0-9_])"
_KW_END = r"(?![A-Za-z0-9_])"

# Messages that talk *about* SQL/code (explain, rewrite, translate, write a
# script ...) rather than asking for data; a hit vetoes the SQL route.
SQL_INTENT_DENY_PATTERNS = [
    re.compile(
        _KW_START + r"(sql|query)" + _KW_END + r".*(解释|说明|改写|优化|翻译)",
        re.IGNORECASE,
    ),
    re.compile(
        r"(解释|说明|改写|优化|翻译).*" + _KW_START + r"(sql|query)" + _KW_END,
        re.IGNORECASE,
    ),
    re.compile(r"(写|生成).*(python|脚本|代码)", re.IGNORECASE),
]

# Signals that the message is a data question: SQL keywords, aggregation /
# ranking vocabulary, or a query verb combined with a table-like noun.
SQL_INTENT_POSITIVE_PATTERNS = [
    re.compile(
        _KW_START
        + r"(select|from|where|group by|order by|having|join|union|limit|count|sum|avg|max|min)"
        + _KW_END,
        re.IGNORECASE,
    ),
    re.compile(
        r"(按|按.*维度|按.*分组|统计|汇总|分组|排序|筛选|过滤|环比|同比|top\s*\d+|前\d+|占比|均值|平均|最大|最小|总数|总量|明细)",
        re.IGNORECASE,
    ),
    re.compile(
        r"(数据库|数据源|数据表|表|字段|列|行|记录).*(查询|检索|列出|统计|分析|对比|查看)",
        re.IGNORECASE,
    ),
    re.compile(
        r"(查询|检索|列出|统计|分析|对比|查看).*(数据库|数据源|数据表|表|字段|列|行|记录)",
        re.IGNORECASE,
    ),
]

# Vocabulary indicating that the user wants a chart or other visual output.
VISUAL_INTENT_PATTERN = re.compile(
    r"(图表|可视化|画图|作图|柱状图|折线图|饼图|趋势|分布|dashboard|chart|plot|visuali[sz]e)",
    re.IGNORECASE,
)
2026-03-14 15:44:48 +08:00
@app.on_event("startup")
async def startup_event():
    """Start the nanobot service when the application boots.

    A failure to start is printed and otherwise swallowed, so the exception
    does not propagate out of the startup hook.
    """
    try:
        await nanobot_service.start()
    except Exception as exc:
        print(f"Nanobot startup failed: {exc}")
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the nanobot service when the application shuts down."""
    await nanobot_service.stop()
@app.get("/")
def read_root():
    """Return a small static payload identifying this backend."""
    greeting = {"Hello": "DataClaw Backend"}
    return greeting
@app.get("/connect/postgres")
def test_postgres():
    """Report whether the PostgreSQL connector's connection test succeeds.

    Raises:
        HTTPException: 500 when ``postgres_connector.test_connection()`` is falsy.
    """
    if not postgres_connector.test_connection():
        raise HTTPException(status_code=500, detail="Failed to connect to PostgreSQL")
    return {"status": "success", "message": "Connected to PostgreSQL"}
@app.get("/connect/clickhouse")
def test_clickhouse():
    """Report whether the ClickHouse connector's connection test succeeds.

    Raises:
        HTTPException: 500 when ``clickhouse_connector.test_connection()`` is falsy.
    """
    if not clickhouse_connector.test_connection():
        raise HTTPException(status_code=500, detail="Failed to connect to ClickHouse")
    return {"status": "success", "message": "Connected to ClickHouse"}
@app.get("/nanobot/status")
def nanobot_status():
    """Return "running" plus the agent's model when an agent exists, else "stopped"."""
    agent = nanobot_service.agent
    if not agent:
        return {"status": "stopped"}
    return {"status": "running", "model": agent.model}
class ChatRequest(BaseModel):
    """Request body shared by the /nanobot/chat and /nanobot/chat/stream endpoints."""

    # Text typed by the user; also used as the NL2SQL query on the SQL route.
    message: str
    # Key of the session the turn is read from / persisted to.
    session_id: str = "api:default"
    # Forwarded unchanged to nanobot_service.process_message on the chat route.
    skill_ids: Optional[List[str]] = None
    # Forwarded unchanged to nanobot_service.process_message on the chat route.
    model_id: Optional[str] = None
    # Request-level data source; a source bound on the session takes priority.
    source: str = "postgres"
    # In "auto" mode forces the SQL route, and always requests chart generation.
    prefer_sql_chart: bool = False
    # Forwarded to NL2SQLRequest on the SQL route.
    file_url: Optional[str] = None
    # "sql" / "chat" force a route; "auto" decides from the message content.
    route_mode: Literal["auto", "chat", "sql"] = "auto"
2026-03-14 15:44:48 +08:00
2026-03-14 23:15:41 +08:00
2026-03-17 11:38:02 +08:00
def _session_context_for_routing(session_id: str) -> Dict[str, Any]:
    """Return the metadata stored on a session for use in routing decisions.

    Yields an empty dict when no agent is available or the session carries no
    metadata. The session is fetched with ``get_or_create``.
    """
    agent = nanobot_service.agent
    if agent:
        return agent.sessions.get_or_create(session_id).metadata or {}
    return {}
def _looks_like_sql_intent(message: str) -> bool:
    """Decide whether a message reads like a request to query data.

    The text is stripped and lower-cased first. Any deny pattern vetoes the
    match; otherwise a single positive pattern is enough. Empty or missing
    text is never an SQL intent.
    """
    normalized = (message or "").strip().lower()
    if not normalized:
        return False
    # Deny patterns are consulted first so they override positive matches.
    if any(deny.search(normalized) for deny in SQL_INTENT_DENY_PATTERNS):
        return False
    return any(hit.search(normalized) for hit in SQL_INTENT_POSITIVE_PATTERNS)
2026-03-17 20:40:56 +08:00
def _looks_like_visual_intent(message: str) -> bool:
    """Return True when the message mentions charts or visualisation vocabulary."""
    normalized = (message or "").strip().lower()
    return VISUAL_INTENT_PATTERN.search(normalized) is not None
2026-03-17 17:49:34 +08:00
def _should_use_nl2sql(request: ChatRequest) -> Tuple[bool, str, str]:
    """Choose between the NL2SQL route and the plain chat route.

    Source priority: a source bound on the session ("ds:..." or "upload")
    wins over the request's ``source``, which in turn falls back to
    "postgres" when it is empty.

    Returns:
        ``(use_nl2sql, reason, effective_source)`` where ``reason`` is a short
        machine-readable tag explaining the decision.
    """

    def _is_bound(source: str) -> bool:
        # A concrete datasource id or an uploaded file.
        return source.startswith("ds:") or source == "upload"

    session_ctx = _session_context_for_routing(request.session_id)
    session_source = (session_ctx.get("selected_data_source") or "").strip().lower()
    request_source = (request.source or "").strip().lower()

    if _is_bound(session_source):
        effective_source = session_source
    else:
        # Apply the documented "postgres" fallback instead of passing an empty
        # source on to NL2SQL.
        effective_source = request_source or "postgres"

    # Explicit routing instructions take precedence over any heuristics.
    if request.route_mode == "sql":
        return True, "route_mode=sql", effective_source
    if request.route_mode == "chat":
        return False, "route_mode=chat", effective_source
    if request.prefer_sql_chart:
        return True, "prefer_sql_chart=true", effective_source

    if not _looks_like_sql_intent(request.message):
        return False, "message_non_sql_intent", effective_source

    if _is_bound(effective_source):
        return True, "message_sql_intent_with_datasource", effective_source
    # No bound datasource: still route to SQL on intent alone. With the
    # default source the query may fail later if no tables are available.
    return True, "message_sql_intent", effective_source
2026-03-16 23:16:33 +08:00
2026-03-14 23:15:41 +08:00
class SessionAliasUpdateRequest(BaseModel):
    """Body of PUT /nanobot/sessions/{session_id}; every field is optional."""

    # Passed to the alias store as the session's alias.
    title: Optional[str] = None
    # Passed through to the alias store unchanged.
    pinned: Optional[bool] = None
    # Passed through to the alias store unchanged.
    archived: Optional[bool] = None
2026-03-15 17:57:09 +08:00
2026-03-15 20:55:42 +08:00
class BatchDeleteRequest(BaseModel):
    """Body of POST /nanobot/sessions/batch-delete."""

    # Keys of the sessions to delete, processed in the given order.
    session_ids: List[str]
2026-03-15 18:25:38 +08:00
class SessionFileContextUpdateRequest(BaseModel):
    """Body of PUT /nanobot/sessions/{session_id}/context-file.

    Only fields explicitly present in the request are applied to the session
    metadata; the handler inspects ``model_fields_set`` to tell them apart.
    """

    # Stored under metadata["active_data_file"]; an explicit null removes it.
    active_data_file: Optional[Dict[str, Any]] = None
    # Stored under metadata["selected_data_source"]; a falsy value removes it.
    selected_data_source: Optional[str] = None
2026-03-15 18:25:38 +08:00
2026-03-15 17:57:09 +08:00
def _build_sql_chart_text(nl2sql_result: NL2SQLResponse) -> str:
chart = nl2sql_result.chart
can_visualize = bool(chart and chart.can_visualize and chart.chart_spec)
text = (
f"已为你生成 SQL 并查询到 {len(nl2sql_result.result)} 行数据。"
f"{'可视化面板已同步更新图表。' if can_visualize else '本次结果不适合图表展示。'}"
)
if chart and chart.reasoning:
return f"{text}\n\n可视化说明:{chart.reasoning}"
return text
def _build_sql_chart_viz(nl2sql_result: NL2SQLResponse) -> dict:
    """Return a JSON-encodable dict with the SQL, rows, chart dump and error.

    The chart is dumped with ``model_dump()`` when present, else ``None``; the
    whole payload is passed through ``jsonable_encoder`` before returning.
    """
    chart = nl2sql_result.chart
    chart_dump = chart.model_dump() if chart else None
    return jsonable_encoder(
        {
            "sql": nl2sql_result.sql,
            "result": nl2sql_result.result,
            "chart": chart_dump,
            "error": nl2sql_result.error,
        }
    )
2026-03-15 17:57:09 +08:00
def _persist_session_turn(
    session_id: str,
    user_message: str,
    assistant_message: str,
    assistant_extra: Optional[dict] = None,
) -> None:
    """Append one user/assistant exchange to a session and save it.

    Does nothing when no agent is available. ``assistant_extra`` is expanded
    into keyword arguments on the assistant message.
    """
    agent = nanobot_service.agent
    if not agent:
        return
    extra_fields = assistant_extra or {}
    store = agent.sessions
    session = store.get_or_create(session_id)
    session.add_message("user", user_message)
    session.add_message("assistant", assistant_message, **extra_fields)
    store.save(session)
2026-03-14 15:44:48 +08:00
@app.post("/nanobot/chat")
async def nanobot_chat(request: ChatRequest):
    """Answer one chat message, via NL2SQL or the nanobot agent.

    The route is chosen by ``_should_use_nl2sql``. On the SQL route the result
    text and viz payload are persisted to the session and returned together;
    on the chat route the agent's reply is returned. Both responses carry a
    ``routing`` entry naming the selected route and the reason.

    Raises:
        HTTPException: re-raised unchanged when raised by inner code; any
            other exception is reported as a 500 with its message as detail.
    """
    try:
        use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)

        if not use_nl2sql:
            response = await nanobot_service.process_message(
                request.message,
                session_id=request.session_id,
                skill_ids=request.skill_ids,
                model_id=request.model_id,
            )
            return {
                "response": response,
                "routing": {"selected": "chat", "reason": route_reason},
            }

        nl2sql_result = await process_nl2sql(
            NL2SQLRequest(
                query=request.message,
                source=resolved_source,
                file_url=request.file_url,
                generate_chart=request.prefer_sql_chart
                or _looks_like_visual_intent(request.message),
            )
        )
        text = _build_sql_chart_text(nl2sql_result)
        viz_payload = _build_sql_chart_viz(nl2sql_result)
        _persist_session_turn(request.session_id, request.message, text, {"viz": viz_payload})
        return {
            "response": text,
            "viz": viz_payload,
            "routing": {"selected": "sql", "reason": route_reason},
        }
    except HTTPException:
        # Keep the status code and detail chosen by inner code instead of
        # rewrapping them as a generic 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
2026-03-14 22:00:36 +08:00
@app.post("/nanobot/chat/stream")
async def nanobot_chat_stream(request: ChatRequest):
    """Stream one chat turn as server-sent events.

    Events are JSON objects with a ``type`` of "routing", "progress", "viz",
    "delta", "final", "error" or "done". The work (NL2SQL or the nanobot
    agent) runs as a background task while progress messages it reports are
    relayed to the client; "done" is always the last event.
    """

    def _sse(event: Dict[str, Any]) -> str:
        # One SSE frame: "data: <json>" followed by a blank line.
        return f"data: {json.dumps(event, ensure_ascii=False)}\n\n"

    async def _relay_progress(worker: asyncio.Task, pending: asyncio.Queue):
        # Forward queued progress notes until the worker has finished AND the
        # queue is drained; the short timeout lets the loop re-check the task.
        while not (worker.done() and pending.empty()):
            try:
                note = await asyncio.wait_for(pending.get(), timeout=0.2)
            except asyncio.TimeoutError:
                continue
            yield _sse({"type": "progress", "content": note})

    # NOTE(review): the background task is never cancelled by this handler;
    # presumably it keeps running if the client disconnects -- confirm whether
    # that is intended.
    async def event_generator():
        try:
            use_nl2sql, route_reason, resolved_source = _should_use_nl2sql(request)
            yield _sse(
                {
                    "type": "routing",
                    "selected": "sql" if use_nl2sql else "chat",
                    "reason": route_reason,
                }
            )

            if use_nl2sql:
                yield _sse({"type": "progress", "content": "已识别为数据分析请求,正在连接数据源"})
                sql_notes: asyncio.Queue = asyncio.Queue()

                async def _on_sql_progress(content: str) -> None:
                    if content:
                        await sql_notes.put(content)

                sql_task = asyncio.create_task(
                    process_nl2sql(
                        NL2SQLRequest(
                            query=request.message,
                            source=resolved_source,
                            file_url=request.file_url,
                            generate_chart=request.prefer_sql_chart
                            or _looks_like_visual_intent(request.message),
                        ),
                        on_progress=_on_sql_progress,
                    )
                )
                async for frame in _relay_progress(sql_task, sql_notes):
                    yield frame
                nl2sql_result = await sql_task

                if nl2sql_result.error:
                    status_note = "数据查询阶段返回错误,正在整理结果"
                else:
                    status_note = "SQL 已执行完成,正在整理回答"
                yield _sse({"type": "progress", "content": status_note})

                persisted_viz_payload = _build_sql_chart_viz(nl2sql_result)
                yield _sse({"type": "viz", **persisted_viz_payload})

                text = _build_sql_chart_text(nl2sql_result)
                _persist_session_turn(
                    request.session_id,
                    request.message,
                    text,
                    {"viz": persisted_viz_payload},
                )
                yield _sse({"type": "final", "content": text})
                yield _sse({"type": "done"})
                return

            chat_notes: asyncio.Queue = asyncio.Queue()

            async def _on_progress(content: str, **_: Any) -> None:
                if content:
                    await chat_notes.put(content)

            chat_task = asyncio.create_task(
                nanobot_service.process_message(
                    request.message,
                    session_id=request.session_id,
                    skill_ids=request.skill_ids,
                    model_id=request.model_id,
                    on_progress=_on_progress,
                )
            )
            yield _sse({"type": "progress", "content": "已发送给模型,正在分析问题"})
            async for frame in _relay_progress(chat_task, chat_notes):
                yield frame

            text = (await chat_task) or ""
            # The reply is already complete; slice it into fixed-size deltas.
            for start in range(0, len(text), STREAM_DELTA_CHUNK_SIZE):
                piece = text[start: start + STREAM_DELTA_CHUNK_SIZE]
                yield _sse({"type": "delta", "content": piece})
            yield _sse({"type": "final", "content": text})
            yield _sse({"type": "done"})
        except Exception as e:
            yield _sse({"type": "error", "content": str(e)})
            yield _sse({"type": "done"})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
2026-03-14 22:25:01 +08:00
@app.get("/nanobot/sessions")
def get_sessions():
    """List sessions via the alias store.

    With an agent available, the agent's session list is synced into the alias
    store and returned; otherwise the alias store's cached list is returned.
    """
    agent = nanobot_service.agent
    if agent:
        return session_alias_store.sync_and_list(agent.sessions.list_sessions())
    return session_alias_store.list_cached_sessions()
2026-03-14 22:25:01 +08:00
@app.get("/nanobot/sessions/{session_id}")
def get_session(session_id: str):
    """Return a session's key, timestamps, metadata, alias and messages.

    The session is obtained with ``get_or_create``.

    Raises:
        HTTPException: 400 when no agent is available.
    """
    agent = nanobot_service.agent
    if not agent:
        raise HTTPException(status_code=400, detail="Nanobot not running")

    session = agent.sessions.get_or_create(session_id)
    details = {
        "key": session.key,
        "created_at": session.created_at,
        "updated_at": session.updated_at,
        "metadata": session.metadata,
        "alias": session_alias_store.get_alias(session_id),
        "messages": session.messages,
    }
    return details
2026-03-15 17:05:16 +08:00
@app.post("/nanobot/sessions/{session_id}/ensure")
def ensure_session(session_id: str):
    """Get or create a session, save it, and return its summary (no messages).

    Raises:
        HTTPException: 400 when no agent is available.
    """
    agent = nanobot_service.agent
    if not agent:
        raise HTTPException(status_code=400, detail="Nanobot not running")

    store = agent.sessions
    session = store.get_or_create(session_id)
    store.save(session)
    summary = {
        "key": session.key,
        "created_at": session.created_at,
        "updated_at": session.updated_at,
        "metadata": session.metadata,
        "alias": session_alias_store.get_alias(session_id),
    }
    return summary
2026-03-14 22:25:01 +08:00
@app.delete("/nanobot/sessions/{session_id}")
def delete_session(session_id: str):
    """Delete a session: drop it from the cache, remove its file and its alias.

    Raises:
        HTTPException: 400 when no agent is available; 404 when the session
            object obtained from the store is falsy.
    """
    agent = nanobot_service.agent
    if not agent:
        raise HTTPException(status_code=400, detail="Nanobot not running")

    store = agent.sessions
    session = store.get_or_create(session_id)
    # NOTE(review): get_or_create presumably always returns a session, which
    # would make this 404 unreachable -- confirm against the session store.
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    # Remove from the cache first, then delete the backing file if present.
    store.invalidate(session_id)
    session_file = store._get_session_path(session_id)
    if session_file.exists():
        session_file.unlink()
    session_alias_store.delete_session(session_id)
    return {"status": "success"}
2026-03-15 20:55:42 +08:00
@app.post("/nanobot/sessions/batch-delete")
def batch_delete_sessions(request: BatchDeleteRequest):
    """Delete several sessions, continuing past individual failures.

    Each session is dropped from the cache, its file removed if present, and
    its alias deleted. A failure for one id is printed and skipped. The
    response lists the ids that were processed without raising.

    Raises:
        HTTPException: 400 when no agent is available.
    """
    agent = nanobot_service.agent
    if not agent:
        raise HTTPException(status_code=400, detail="Nanobot not running")

    store = agent.sessions
    deleted_ids = []
    for session_id in request.session_ids:
        try:
            session = store.get_or_create(session_id)
            if not session:
                continue
            store.invalidate(session_id)
            session_file = store._get_session_path(session_id)
            if session_file.exists():
                session_file.unlink()
            session_alias_store.delete_session(session_id)
            deleted_ids.append(session_id)
        except Exception as exc:
            print(f"Failed to delete session {session_id}: {exc}")

    return {
        "status": "success",
        "deleted_count": len(deleted_ids),
        "deleted_ids": deleted_ids,
    }
2026-03-14 22:25:01 +08:00
@app.put("/nanobot/sessions/{session_id}")
def update_session(session_id: str, payload: SessionAliasUpdateRequest):
    """Update a session's alias metadata (title, pinned, archived).

    The alias store's returned mapping is merged into the response next to
    ``"status": "success"``.
    """
    alias_meta = session_alias_store.update_alias_meta(
        session_key=session_id,
        alias=payload.title,
        pinned=payload.pinned,
        archived=payload.archived,
    )
    result = {"status": "success"}
    result.update(alias_meta)
    return result
2026-03-14 22:25:01 +08:00
2026-03-15 18:25:38 +08:00
@app.put("/nanobot/sessions/{session_id}/context-file")
def update_session_context_file(session_id: str, payload: SessionFileContextUpdateRequest):
    """Set or clear the data-file / data-source context stored on a session.

    Only fields explicitly present in the request body are applied. An
    explicit null ``active_data_file`` and a falsy ``selected_data_source``
    remove the corresponding metadata key. The session's ``updated_at`` is
    refreshed and the session saved.

    Raises:
        HTTPException: 400 when no agent is available.
    """
    agent = nanobot_service.agent
    if not agent:
        raise HTTPException(status_code=400, detail="Nanobot not running")

    session = agent.sessions.get_or_create(session_id)
    provided = payload.model_fields_set

    if "active_data_file" in provided:
        if payload.active_data_file is not None:
            session.metadata["active_data_file"] = payload.active_data_file
        else:
            session.metadata.pop("active_data_file", None)

    if "selected_data_source" in provided:
        if not payload.selected_data_source:
            session.metadata.pop("selected_data_source", None)
        else:
            session.metadata["selected_data_source"] = payload.selected_data_source

    session.updated_at = datetime.now()
    agent.sessions.save(session)
    return {"status": "success", "metadata": session.metadata}