feat: knowledge base first OK
This commit is contained in:
@@ -0,0 +1,257 @@
|
||||
from typing import List, Optional
|
||||
import io
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import UploadFile, File, Form
|
||||
from openai import OpenAI
|
||||
import pandas as pd
|
||||
|
||||
from app.schemas.knowledge import (
|
||||
KnowledgeBase,
|
||||
KnowledgeBaseCreate,
|
||||
KnowledgeConnectionTestRequest,
|
||||
KnowledgeConnectionTestResponse,
|
||||
KnowledgeGlobalConfig,
|
||||
KnowledgeGlobalConfigUpdate,
|
||||
KnowledgeBaseUpdate,
|
||||
KnowledgeDocument,
|
||||
KnowledgeDocumentCreate,
|
||||
KnowledgeDocumentUpdate,
|
||||
KnowledgeSearchRequest,
|
||||
KnowledgeSearchResponse,
|
||||
)
|
||||
from app.services.knowledge_base_store import knowledge_base_store
|
||||
from app.services.knowledge_global_config_store import knowledge_global_config_store
|
||||
from app.services.knowledge_index import knowledge_index_service
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _mask_api_key(value: Optional[str]) -> Optional[str]:
|
||||
if not value:
|
||||
return None
|
||||
if len(value) <= 8:
|
||||
return "*" * len(value)
|
||||
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
|
||||
|
||||
|
||||
def _extract_upload_text(filename: str, content: bytes) -> str:
|
||||
lower = filename.lower()
|
||||
if lower.endswith((".txt", ".md", ".markdown", ".json", ".yaml", ".yml", ".log", ".xml", ".html", ".htm")):
|
||||
try:
|
||||
return content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
return content.decode("utf-8", errors="ignore")
|
||||
if lower.endswith(".csv"):
|
||||
df = pd.read_csv(io.BytesIO(content))
|
||||
return df.to_csv(index=False)
|
||||
if lower.endswith((".xls", ".xlsx")):
|
||||
df = pd.read_excel(io.BytesIO(content))
|
||||
return df.to_csv(index=False)
|
||||
raise ValueError("Unsupported file type")
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def get_knowledge_global_config():
    """Return the saved global knowledge config with the API key withheld.

    The raw key is never echoed: ``api_key`` is always ``None`` and callers
    get a masked rendering plus a ``has_api_key`` flag instead.
    """
    stored = knowledge_global_config_store.get()
    secret = stored.get("api_key")
    response = {
        "api_base": stored.get("api_base"),
        "default_embedding_model": stored.get("default_embedding_model"),
    }
    response["api_key"] = None
    response["api_key_masked"] = _mask_api_key(secret)
    response["has_api_key"] = bool(secret)
    return response
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def update_knowledge_global_config(payload: KnowledgeGlobalConfigUpdate):
    """Persist the fields the client actually sent and return the masked config.

    Only fields set on ``payload`` are forwarded to the store; the response
    never contains the raw API key.
    """
    changes = payload.model_dump(exclude_unset=True)
    stored = knowledge_global_config_store.update(changes)
    secret = stored.get("api_key")
    response = {
        "api_base": stored.get("api_base"),
        "default_embedding_model": stored.get("default_embedding_model"),
    }
    response["api_key"] = None
    response["api_key_masked"] = _mask_api_key(secret)
    response["has_api_key"] = bool(secret)
    return response
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/global-config/test-connection", response_model=KnowledgeConnectionTestResponse)
def test_knowledge_global_connection(payload: KnowledgeConnectionTestRequest):
    """Probe an OpenAI-compatible embeddings endpoint with a single request.

    ``api_base`` and ``api_key`` come from ``payload`` when given and fall back
    to the saved global config; the model name must always be supplied in the
    request.

    Returns:
        A dict matching ``KnowledgeConnectionTestResponse``; on success it
        reports the embedding dimension when the response exposes a list
        vector.

    Raises:
        HTTPException: 400 when the base URL, key or model name is missing, or
            when the embedding call itself fails.
    """
    saved = knowledge_global_config_store.get()
    api_base = normalize_openai_base_url(payload.api_base or saved.get("api_base") or "")
    api_key = payload.api_key or saved.get("api_key")
    model_name = (payload.model_name or "").strip()

    if not api_base:
        raise HTTPException(status_code=400, detail="API Base 未配置")
    if not api_key:
        raise HTTPException(status_code=400, detail="API Key 未配置")
    if not model_name:
        raise HTTPException(status_code=400, detail="测试连接必须显式填写向量模型名称")

    try:
        client = OpenAI(api_key=api_key, base_url=api_base)
        embedding_resp = client.embeddings.create(
            model=model_name,
            input="connection test",
        )
    except Exception as exc:
        # Any client/network/provider failure is reported to the caller as a
        # 400 with the provider's message; keep the cause chained for logs.
        raise HTTPException(status_code=400, detail=f"Embedding调用失败: {exc}") from exc

    dimension = None
    if getattr(embedding_resp, "data", None):
        vector = getattr(embedding_resp.data[0], "embedding", None)
        if isinstance(vector, list):
            dimension = len(vector)

    return {
        "success": True,
        "message": "连接成功,Embedding调用正常",
        "model_name": model_name,
        "embedding_dimension": dimension,
        "resolved_api_base": api_base,
        "available_models": [],
    }
|
||||
|
||||
|
||||
@router.get("/knowledge-bases", response_model=List[KnowledgeBase])
def list_knowledge_bases(project_id: Optional[int] = None):
    """List knowledge bases, optionally restricted to one project id."""
    knowledge_bases = knowledge_base_store.list(project_id=project_id)
    return knowledge_bases
|
||||
|
||||
|
||||
@router.post("/knowledge-bases", response_model=KnowledgeBase)
def create_knowledge_base(payload: KnowledgeBaseCreate):
    """Create a knowledge base from the validated payload and return it."""
    fields = payload.model_dump()
    return knowledge_base_store.create(fields)
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def get_knowledge_base(kb_id: str):
    """Fetch one knowledge base by id.

    Raises:
        HTTPException: 404 when the store has no entry for ``kb_id``.
    """
    found = knowledge_base_store.get(kb_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def update_knowledge_base(kb_id: str, payload: KnowledgeBaseUpdate):
    """Apply the fields the client sent to an existing knowledge base.

    Raises:
        HTTPException: 404 when the store has no entry for ``kb_id``.
    """
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update(kb_id, changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(kb_id: str):
    """Delete a knowledge base by id.

    Raises:
        HTTPException: 404 when nothing was removed.
    """
    if knowledge_base_store.delete(kb_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
|
||||
@router.get("/knowledge-bases/{kb_id}/documents", response_model=List[KnowledgeDocument])
def list_knowledge_documents(kb_id: str):
    """Return the documents stored on a knowledge base.

    Raises:
        HTTPException: 404 when the knowledge base does not exist.
    """
    found = knowledge_base_store.get(kb_id)
    if found:
        return found.get("documents", [])
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents", response_model=KnowledgeDocument)
def create_knowledge_document(kb_id: str, payload: KnowledgeDocumentCreate):
    """Add a document to a knowledge base and return the stored record.

    Raises:
        HTTPException: 404 when the store returns no document, which it does
            for an unknown ``kb_id``.
    """
    fields = payload.model_dump()
    created = knowledge_base_store.create_document(kb_id=kb_id, payload=fields)
    if created:
        return created
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
|
||||
|
||||
@router.put("/knowledge-bases/{kb_id}/documents/{doc_id}", response_model=KnowledgeDocument)
def update_knowledge_document(kb_id: str, doc_id: str, payload: KnowledgeDocumentUpdate):
    """Apply the fields the client sent to one document of a knowledge base.

    Raises:
        HTTPException: 404 when the store finds no matching knowledge base or
            document.
    """
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update_document(kb_id=kb_id, doc_id=doc_id, payload=changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge document not found")
|
||||
|
||||
|
||||
@router.delete("/knowledge-bases/{kb_id}/documents/{doc_id}")
def delete_knowledge_document(kb_id: str, doc_id: str):
    """Remove one document from a knowledge base.

    Raises:
        HTTPException: 404 when nothing was removed.
    """
    if knowledge_base_store.delete_document(kb_id=kb_id, doc_id=doc_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge document not found")
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/reindex")
def reindex_knowledge_base(kb_id: str):
    """Rebuild the retrieval index of a knowledge base and return its summary.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` raised
            by the index service.
    """
    try:
        summary = knowledge_index_service.reindex(kb_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
    return summary
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/search", response_model=KnowledgeSearchResponse)
def search_knowledge_base(kb_id: str, payload: KnowledgeSearchRequest):
    """Run a retrieval query against one knowledge base.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` raised
            by the index service.
    """
    try:
        return knowledge_index_service.search(
            kb_id=kb_id,
            query=payload.query,
            top_k=payload.top_k,
        )
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/knowledge-bases/{kb_id}/documents/upload")
|
||||
async def upload_knowledge_documents(
|
||||
kb_id: str,
|
||||
files: List[UploadFile] = File(...),
|
||||
metadata: Optional[str] = Form(default=None),
|
||||
):
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise HTTPException(status_code=404, detail="Knowledge base not found")
|
||||
metadata_payload: dict[str, Any] = {}
|
||||
if metadata:
|
||||
try:
|
||||
parsed_metadata = json.loads(metadata)
|
||||
if isinstance(parsed_metadata, dict):
|
||||
metadata_payload = parsed_metadata
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
|
||||
|
||||
created: List[dict[str, Any]] = []
|
||||
for file in files:
|
||||
filename = file.filename or "untitled"
|
||||
content = await file.read()
|
||||
if not content:
|
||||
continue
|
||||
if len(content) > 5 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail=f"文件过大: {filename}")
|
||||
try:
|
||||
text = _extract_upload_text(filename, content)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {filename}")
|
||||
doc = knowledge_base_store.create_document(
|
||||
kb_id=kb_id,
|
||||
payload={
|
||||
"title": filename,
|
||||
"content": text,
|
||||
"metadata": {**metadata_payload, "source": "upload", "filename": filename},
|
||||
},
|
||||
)
|
||||
if doc:
|
||||
created.append(doc)
|
||||
return {"status": "success", "count": len(created), "documents": created}
|
||||
@@ -19,3 +19,5 @@ current_data_source: ContextVar[str] = ContextVar("current_data_source", default
|
||||
|
||||
# Any file URL attached to the request
current_file_url: ContextVar[Optional[str]] = ContextVar("current_file_url", default=None)

# Knowledge base id for the current context (None when unset); the
# knowledge_retrieve tool reads it when no explicit id is passed.
current_knowledge_base_id: ContextVar[Optional[str]] = ContextVar("current_knowledge_base_id", default=None)
|
||||
|
||||
@@ -204,10 +204,12 @@ class NanobotIntegration:
|
||||
from app.tools.nl2sql import NL2SQLTool
|
||||
from app.tools.visualization import VisualizationTool
|
||||
from app.tools.get_schema import GetDatabaseSchemaTool
|
||||
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
|
||||
from app.tools.subagent import ListSubagentsTool, InvokeSubagentTool
|
||||
agent.tools.register(NL2SQLTool())
|
||||
agent.tools.register(VisualizationTool())
|
||||
agent.tools.register(GetDatabaseSchemaTool())
|
||||
agent.tools.register(KnowledgeBaseRetrieveTool())
|
||||
agent.tools.register(ListSubagentsTool(project_id=project_id))
|
||||
agent.tools.register(InvokeSubagentTool(project_id=project_id))
|
||||
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class KnowledgeDocumentBase(BaseModel):
    """Fields shared by knowledge document payloads and stored documents."""

    # Non-empty title, at most 200 characters.
    title: str = Field(..., min_length=1, max_length=200)
    # Non-empty document body.
    content: str = Field(..., min_length=1)
    # Free-form key/value data; defaults to an empty dict.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class KnowledgeDocumentCreate(KnowledgeDocumentBase):
    """Request body for creating a document; adds nothing to the base fields."""
|
||||
|
||||
|
||||
class KnowledgeDocumentUpdate(BaseModel):
    """Partial update for a document; every field is optional."""

    title: Optional[str] = Field(None, min_length=1, max_length=200)
    content: Optional[str] = Field(None, min_length=1)
    metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class KnowledgeDocument(KnowledgeDocumentBase):
    """A stored document: the base fields plus id and timestamps."""

    id: str
    created_at: datetime
    updated_at: datetime
|
||||
|
||||
|
||||
class KnowledgeBaseConfigBase(BaseModel):
    """Configuration fields shared by knowledge base payloads and records."""

    name: str = Field(..., min_length=1, max_length=120)
    description: Optional[str] = None
    project_id: Optional[int] = None
    embedding_model: Optional[str] = None
    # Chunking parameters, bounded to 64..4096 and 0..512 respectively.
    chunk_size: int = Field(default=512, ge=64, le=4096)
    chunk_overlap: int = Field(default=50, ge=0, le=512)
    # Default number of hits returned by a search, 1..20.
    top_k: int = Field(default=3, ge=1, le=20)
    is_active: bool = True
|
||||
|
||||
|
||||
class KnowledgeBaseCreate(KnowledgeBaseConfigBase):
    """Request body for creating a knowledge base; same fields as the base."""
|
||||
|
||||
|
||||
class KnowledgeBaseUpdate(BaseModel):
    """Partial update for a knowledge base; every field is optional."""

    name: Optional[str] = Field(None, min_length=1, max_length=120)
    description: Optional[str] = None
    project_id: Optional[int] = None
    embedding_model: Optional[str] = None
    chunk_size: Optional[int] = Field(None, ge=64, le=4096)
    chunk_overlap: Optional[int] = Field(None, ge=0, le=512)
    top_k: Optional[int] = Field(None, ge=1, le=20)
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class KnowledgeBase(KnowledgeBaseConfigBase):
    """A stored knowledge base: configuration, identity, timestamps, documents."""

    id: str
    created_at: datetime
    updated_at: datetime
    documents: List[KnowledgeDocument] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgeSearchRequest(BaseModel):
    """Search request: a non-empty query and an optional hit limit (1..20)."""

    query: str = Field(..., min_length=1)
    top_k: Optional[int] = Field(default=None, ge=1, le=20)
|
||||
|
||||
|
||||
class KnowledgeSearchHit(BaseModel):
    """One retrieved chunk together with its source document and score."""

    doc_id: str
    title: str
    chunk: str
    score: float
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class KnowledgeSearchResponse(BaseModel):
    """Search result: the combined answer text and the individual hits."""

    answer: str
    hits: List[KnowledgeSearchHit] = Field(default_factory=list)
|
||||
|
||||
|
||||
class KnowledgeGlobalConfigUpdate(BaseModel):
    """Partial update of the global embedding configuration.

    Each validator trims whitespace and maps blank input to ``None``.
    """

    api_base: Optional[str] = None
    api_key: Optional[str] = None
    default_embedding_model: Optional[str] = None

    @field_validator("api_base")
    @classmethod
    def validate_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Require an http(s) URL and drop trailing slashes."""
        trimmed = (value or "").strip()
        if not trimmed:
            return None
        if not trimmed.startswith(("http://", "https://")):
            raise ValueError("api_base must start with http:// or https://")
        return trimmed.rstrip("/")

    @field_validator("api_key")
    @classmethod
    def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
        """Reject keys longer than 512 characters after trimming."""
        trimmed = (value or "").strip()
        if not trimmed:
            return None
        if len(trimmed) > 512:
            raise ValueError("api_key is too long")
        return trimmed

    @field_validator("default_embedding_model")
    @classmethod
    def validate_default_embedding_model(cls, value: Optional[str]) -> Optional[str]:
        """Reject model names longer than 200 characters after trimming."""
        trimmed = (value or "").strip()
        if not trimmed:
            return None
        if len(trimmed) > 200:
            raise ValueError("default_embedding_model is too long")
        return trimmed
|
||||
|
||||
|
||||
class KnowledgeGlobalConfig(BaseModel):
    """Global embedding configuration as returned to API clients."""

    api_base: Optional[str] = None
    api_key: Optional[str] = None
    # Masked rendering of the stored key, for display.
    api_key_masked: Optional[str] = None
    # True when a key is stored.
    has_api_key: bool = False
    default_embedding_model: Optional[str] = None
|
||||
|
||||
|
||||
class KnowledgeConnectionTestRequest(BaseModel):
    """Parameters for a one-off embedding connection test.

    All fields are optional; validators trim whitespace and map blank input
    to ``None``.
    """

    api_base: Optional[str] = None
    api_key: Optional[str] = None
    model_name: Optional[str] = None

    @field_validator("api_base")
    @classmethod
    def validate_test_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Require an http(s) URL and drop trailing slashes."""
        trimmed = (value or "").strip()
        if not trimmed:
            return None
        if not trimmed.startswith(("http://", "https://")):
            raise ValueError("api_base must start with http:// or https://")
        return trimmed.rstrip("/")

    @field_validator("api_key", "model_name")
    @classmethod
    def normalize_test_value(cls, value: Optional[str]) -> Optional[str]:
        """Trim the value; blank strings become ``None``."""
        trimmed = (value or "").strip()
        return trimmed if trimmed else None
|
||||
|
||||
|
||||
class KnowledgeConnectionTestResponse(BaseModel):
    """Outcome of an embedding connection test."""

    success: bool
    message: str
    model_name: Optional[str] = None
    embedding_dimension: Optional[int] = None
    resolved_api_base: Optional[str] = None
    available_models: List[str] = Field(default_factory=list)
|
||||
@@ -0,0 +1,188 @@
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
def _utcnow_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
class KnowledgeBaseStore:
    """Persist knowledge bases and their embedded documents in one JSON file.

    Each public method re-reads the file, applies its change and rewrites the
    whole file while holding a re-entrant lock. An unreadable or malformed
    file is treated as an empty store.
    """

    # Fields the KnowledgeBase / KnowledgeDocument response schemas declare as
    # non-optional. An explicit null for one of these in an update payload is
    # ignored, because persisting it would make every later read fail response
    # validation.
    _NON_NULL_KB_FIELDS = frozenset({"name", "chunk_size", "chunk_overlap", "top_k", "is_active"})
    _NON_NULL_DOC_FIELDS = frozenset({"title", "content", "metadata"})

    def __init__(self) -> None:
        self._lock = threading.RLock()

    @staticmethod
    def _file_path() -> Path:
        """Return the location of the JSON file under the data root."""
        return get_data_root() / "knowledge_bases.json"

    def _read(self) -> List[Dict[str, Any]]:
        """Load all knowledge bases; return ``[]`` if the file is missing or invalid."""
        file_path = self._file_path()
        if not file_path.exists():
            return []
        try:
            with file_path.open("r", encoding="utf-8") as f:
                data = json.load(f)
        except (json.JSONDecodeError, OSError):
            return []
        if not isinstance(data, list):
            return []
        return data

    def _write(self, data: List[Dict[str, Any]]) -> None:
        """Replace the stored file with ``data``.

        The JSON is written to a sibling temp file and then renamed over the
        target, so an interrupted write cannot leave a truncated file that
        ``_read`` would interpret as an empty store.
        """
        file_path = self._file_path()
        file_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        tmp_path.replace(file_path)

    @staticmethod
    def _normalize_documents(item: Dict[str, Any]) -> None:
        """Ensure ``item["documents"]`` is a list of dicts with id, timestamps and metadata.

        Mutates ``item`` in place; non-dict entries are dropped.
        """
        docs = item.get("documents")
        if not isinstance(docs, list):
            item["documents"] = []
            return
        normalized: List[Dict[str, Any]] = []
        for doc in docs:
            if not isinstance(doc, dict):
                continue
            if not doc.get("id"):
                doc["id"] = str(uuid.uuid4())
            now = _utcnow_iso()
            doc.setdefault("created_at", now)
            doc.setdefault("updated_at", now)
            doc.setdefault("metadata", {})
            normalized.append(doc)
        item["documents"] = normalized

    def list(self, project_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return all knowledge bases, or only those whose ``project_id`` matches."""
        with self._lock:
            data = self._read()
            for item in data:
                self._normalize_documents(item)
            if project_id is None:
                return data
            return [item for item in data if item.get("project_id") == project_id]

    def get(self, kb_id: str) -> Optional[Dict[str, Any]]:
        """Return the knowledge base with ``kb_id``, or ``None`` if absent."""
        with self._lock:
            for item in self._read():
                if item.get("id") == kb_id:
                    self._normalize_documents(item)
                    return item
            return None

    def create(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Append a new knowledge base built from ``payload`` and return it.

        ``payload["name"]`` is required; the other fields fall back to the
        defaults shown below.
        """
        with self._lock:
            data = self._read()
            now = _utcnow_iso()
            item = {
                "id": str(uuid.uuid4()),
                "name": payload["name"],
                "description": payload.get("description"),
                "project_id": payload.get("project_id"),
                "embedding_model": payload.get("embedding_model"),
                "chunk_size": payload.get("chunk_size", 512),
                "chunk_overlap": payload.get("chunk_overlap", 50),
                "top_k": payload.get("top_k", 3),
                "is_active": payload.get("is_active", True),
                "created_at": now,
                "updated_at": now,
                "documents": [],
            }
            data.append(item)
            self._write(data)
            return item

    def update(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Merge ``payload`` into the knowledge base; return it, or ``None`` if absent."""
        with self._lock:
            data = self._read()
            for item in data:
                if item.get("id") != kb_id:
                    continue
                for key, value in payload.items():
                    if value is None and key in self._NON_NULL_KB_FIELDS:
                        continue
                    item[key] = value
                item["updated_at"] = _utcnow_iso()
                self._normalize_documents(item)
                self._write(data)
                return item
            return None

    def delete(self, kb_id: str) -> bool:
        """Remove the knowledge base; return ``False`` when no entry matched."""
        with self._lock:
            data = self._read()
            filtered = [item for item in data if item.get("id") != kb_id]
            if len(filtered) == len(data):
                return False
            self._write(filtered)
            return True

    def create_document(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Append a document to a knowledge base; return it, or ``None`` if the base is absent."""
        with self._lock:
            data = self._read()
            for item in data:
                if item.get("id") != kb_id:
                    continue
                now = _utcnow_iso()
                doc = {
                    "id": str(uuid.uuid4()),
                    "title": payload["title"],
                    "content": payload["content"],
                    "metadata": payload.get("metadata", {}),
                    "created_at": now,
                    "updated_at": now,
                }
                self._normalize_documents(item)
                item["documents"].append(doc)
                # Touch the parent so consumers keyed on updated_at notice.
                item["updated_at"] = now
                self._write(data)
                return doc
            return None

    def update_document(self, kb_id: str, doc_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Merge ``payload`` into one document; return it, or ``None`` if base or document is absent."""
        with self._lock:
            data = self._read()
            for item in data:
                if item.get("id") != kb_id:
                    continue
                self._normalize_documents(item)
                for doc in item["documents"]:
                    if doc.get("id") != doc_id:
                        continue
                    for key, value in payload.items():
                        if value is None and key in self._NON_NULL_DOC_FIELDS:
                            continue
                        doc[key] = value
                    doc["updated_at"] = _utcnow_iso()
                    item["updated_at"] = _utcnow_iso()
                    self._write(data)
                    return doc
                return None
            return None

    def delete_document(self, kb_id: str, doc_id: str) -> bool:
        """Remove one document; return ``False`` when base or document is absent."""
        with self._lock:
            data = self._read()
            for item in data:
                if item.get("id") != kb_id:
                    continue
                self._normalize_documents(item)
                docs = item["documents"]
                filtered = [doc for doc in docs if doc.get("id") != doc_id]
                if len(filtered) == len(docs):
                    return False
                item["documents"] = filtered
                item["updated_at"] = _utcnow_iso()
                self._write(data)
                return True
            return False


knowledge_base_store = KnowledgeBaseStore()
|
||||
@@ -0,0 +1,58 @@
|
||||
import json
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
class KnowledgeGlobalConfigStore:
    """Persist the global embedding configuration in one JSON file.

    The file holds ``api_base``, ``api_key`` and ``default_embedding_model``
    as plain JSON. An unreadable or malformed file is treated as an empty
    configuration.
    """

    def __init__(self) -> None:
        self._lock = threading.RLock()

    @staticmethod
    def _file_path() -> Path:
        """Return the location of the JSON file under the data root."""
        return get_data_root() / "knowledge_global_config.json"

    def _read(self) -> Dict[str, Any]:
        """Load the raw config dict; return ``{}`` if the file is missing or invalid."""
        file_path = self._file_path()
        if not file_path.exists():
            return {}
        try:
            with file_path.open("r", encoding="utf-8") as file_obj:
                data = json.load(file_obj)
        except (OSError, json.JSONDecodeError):
            return {}
        if not isinstance(data, dict):
            return {}
        return data

    def _write(self, data: Dict[str, Any]) -> None:
        """Replace the stored file with ``data``.

        The JSON is written to a sibling temp file and then renamed over the
        target, so an interrupted write cannot leave a truncated file that
        ``_read`` would interpret as an empty configuration.
        """
        file_path = self._file_path()
        file_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as file_obj:
            json.dump(data, file_obj, indent=2, ensure_ascii=False)
        tmp_path.replace(file_path)

    def get(self) -> Dict[str, Any]:
        """Return the three known config keys; missing ones are ``None``."""
        with self._lock:
            data = self._read()
            return {
                "api_base": data.get("api_base"),
                "api_key": data.get("api_key"),
                "default_embedding_model": data.get("default_embedding_model"),
            }

    def update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Overwrite the keys present in ``payload`` and return the full config.

        A key present with value ``None`` clears the stored value; keys absent
        from ``payload`` are left unchanged.
        """
        with self._lock:
            current = self.get()
            for key in ("api_base", "api_key", "default_embedding_model"):
                if key in payload:
                    current[key] = payload.get(key)
            self._write(current)
            return current


knowledge_global_config_store = KnowledgeGlobalConfigStore()
|
||||
@@ -0,0 +1,250 @@
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from app.services.knowledge_base_store import knowledge_base_store
|
||||
from app.services.knowledge_global_config_store import knowledge_global_config_store
|
||||
from app.services.openai_compat import normalize_openai_base_url
|
||||
|
||||
try:
|
||||
from llama_index.core import Document, VectorStoreIndex
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
LLAMAINDEX_AVAILABLE = True
|
||||
except Exception:
|
||||
Document = Any
|
||||
VectorStoreIndex = Any
|
||||
SentenceSplitter = Any
|
||||
LLAMAINDEX_AVAILABLE = False
|
||||
|
||||
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
return re.findall(r"[a-zA-Z0-9]+|[\u4e00-\u9fff]", (text or "").lower())
|
||||
|
||||
|
||||
def _normalize_embedding_api_base(api_base: str) -> str:
    """Delegate to ``normalize_openai_base_url`` for the embedding base URL."""
    normalized = normalize_openai_base_url(api_base)
    return normalized
|
||||
|
||||
|
||||
@dataclass
class SearchHit:
    """One chunk returned by the keyword fallback search."""

    doc_id: str  # id of the source document
    title: str  # title of the source document
    chunk: str  # matched chunk text
    score: float  # relevance score; results are sorted by it, highest first
    metadata: Dict[str, Any]  # metadata copied from the source document
|
||||
|
||||
|
||||
class KnowledgeIndexService:
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.RLock()
|
||||
self._cache: Dict[str, Tuple[str, Any, List[Dict[str, Any]]]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _signature(kb: Dict[str, Any]) -> str:
|
||||
doc_parts = []
|
||||
for doc in kb.get("documents", []):
|
||||
doc_parts.append(f"{doc.get('id')}:{doc.get('updated_at')}:{len(doc.get('content', ''))}")
|
||||
return "|".join(
|
||||
[
|
||||
str(kb.get("updated_at")),
|
||||
str(kb.get("chunk_size")),
|
||||
str(kb.get("chunk_overlap")),
|
||||
*doc_parts,
|
||||
]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _fallback_chunks(kb: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
chunks: List[Dict[str, Any]] = []
|
||||
chunk_size = int(kb.get("chunk_size") or 512)
|
||||
overlap = int(kb.get("chunk_overlap") or 50)
|
||||
step = max(1, chunk_size - overlap)
|
||||
for doc in kb.get("documents", []):
|
||||
text = doc.get("content") or ""
|
||||
if not text:
|
||||
continue
|
||||
if len(text) <= chunk_size:
|
||||
chunks.append(
|
||||
{
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
"chunk": text,
|
||||
"metadata": doc.get("metadata") or {},
|
||||
}
|
||||
)
|
||||
continue
|
||||
for start in range(0, len(text), step):
|
||||
piece = text[start : start + chunk_size]
|
||||
if not piece:
|
||||
continue
|
||||
chunks.append(
|
||||
{
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
"chunk": piece,
|
||||
"metadata": doc.get("metadata") or {},
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _build_index(self, kb: Dict[str, Any]) -> Tuple[Any, List[Dict[str, Any]]]:
|
||||
fallback_chunks = self._fallback_chunks(kb)
|
||||
if not LLAMAINDEX_AVAILABLE:
|
||||
return None, fallback_chunks
|
||||
chunk_size = int(kb.get("chunk_size") or 512)
|
||||
overlap = int(kb.get("chunk_overlap") or 50)
|
||||
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
|
||||
docs = [
|
||||
Document(
|
||||
text=(doc.get("content") or ""),
|
||||
metadata={
|
||||
"doc_id": doc.get("id", ""),
|
||||
"title": doc.get("title", ""),
|
||||
**(doc.get("metadata") or {}),
|
||||
},
|
||||
)
|
||||
for doc in kb.get("documents", [])
|
||||
if (doc.get("content") or "").strip()
|
||||
]
|
||||
if not docs:
|
||||
return None, fallback_chunks
|
||||
embed_model = self._build_embed_model(kb)
|
||||
if embed_model is not None:
|
||||
index = VectorStoreIndex.from_documents(
|
||||
docs,
|
||||
transformations=[splitter],
|
||||
embed_model=embed_model,
|
||||
)
|
||||
else:
|
||||
index = VectorStoreIndex.from_documents(docs, transformations=[splitter])
|
||||
return index, fallback_chunks
|
||||
|
||||
@staticmethod
|
||||
def _build_embed_model(kb: Dict[str, Any]) -> Any:
|
||||
global_config = knowledge_global_config_store.get()
|
||||
api_base = global_config.get("api_base")
|
||||
api_key = global_config.get("api_key")
|
||||
model_name = kb.get("embedding_model") or global_config.get("default_embedding_model")
|
||||
if not api_base or not api_key or not model_name:
|
||||
return None
|
||||
api_base = _normalize_embedding_api_base(api_base)
|
||||
try:
|
||||
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
|
||||
|
||||
return OpenAILikeEmbedding(
|
||||
model_name=model_name,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
embed_batch_size=10,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
return OpenAIEmbedding(
|
||||
model_name=model_name,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
embed_batch_size=10,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def reindex(self, kb_id: str) -> Dict[str, Any]:
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise ValueError("Knowledge base not found")
|
||||
with self._lock:
|
||||
signature = self._signature(kb)
|
||||
index, fallback_chunks = self._build_index(kb)
|
||||
self._cache[kb_id] = (signature, index, fallback_chunks)
|
||||
return {
|
||||
"kb_id": kb_id,
|
||||
"status": "ok",
|
||||
"documents": len(kb.get("documents", [])),
|
||||
"engine": "llamaindex" if LLAMAINDEX_AVAILABLE and index is not None else "fallback",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _fallback_search(query: str, chunks: List[Dict[str, Any]], top_k: int) -> List[SearchHit]:
|
||||
q_tokens = _tokenize(query)
|
||||
if not q_tokens:
|
||||
return []
|
||||
q_set = set(q_tokens)
|
||||
scored: List[SearchHit] = []
|
||||
for chunk_item in chunks:
|
||||
c_tokens = _tokenize(chunk_item.get("chunk", ""))
|
||||
if not c_tokens:
|
||||
continue
|
||||
overlap = sum(1 for t in c_tokens if t in q_set)
|
||||
if overlap == 0:
|
||||
continue
|
||||
score = overlap / math.sqrt(len(c_tokens))
|
||||
scored.append(
|
||||
SearchHit(
|
||||
doc_id=chunk_item.get("doc_id", ""),
|
||||
title=chunk_item.get("title", ""),
|
||||
chunk=chunk_item.get("chunk", ""),
|
||||
score=float(score),
|
||||
metadata=chunk_item.get("metadata") or {},
|
||||
)
|
||||
)
|
||||
scored.sort(key=lambda x: x.score, reverse=True)
|
||||
return scored[:top_k]
|
||||
|
||||
def search(self, kb_id: str, query: str, top_k: int | None = None) -> Dict[str, Any]:
|
||||
kb = knowledge_base_store.get(kb_id)
|
||||
if not kb:
|
||||
raise ValueError("Knowledge base not found")
|
||||
if not kb.get("documents"):
|
||||
return {"answer": "", "hits": []}
|
||||
effective_top_k = int(top_k or kb.get("top_k") or 3)
|
||||
with self._lock:
|
||||
signature = self._signature(kb)
|
||||
cached = self._cache.get(kb_id)
|
||||
if not cached or cached[0] != signature:
|
||||
index, fallback_chunks = self._build_index(kb)
|
||||
cached = (signature, index, fallback_chunks)
|
||||
self._cache[kb_id] = cached
|
||||
_, index, fallback_chunks = cached
|
||||
if index is None:
|
||||
hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
|
||||
answer = "\n\n".join(hit.chunk for hit in hits)
|
||||
return {
|
||||
"answer": answer,
|
||||
"hits": [hit.__dict__ for hit in hits],
|
||||
}
|
||||
retriever = index.as_retriever(similarity_top_k=effective_top_k)
|
||||
response_nodes = retriever.retrieve(query)
|
||||
hits: List[Dict[str, Any]] = []
|
||||
for node_with_score in response_nodes:
|
||||
node = getattr(node_with_score, "node", None)
|
||||
metadata = getattr(node, "metadata", {}) if node is not None else {}
|
||||
chunk_text = ""
|
||||
if node is not None and hasattr(node, "get_content"):
|
||||
chunk_text = node.get_content()
|
||||
elif node is not None:
|
||||
chunk_text = str(getattr(node, "text", ""))
|
||||
hits.append(
|
||||
{
|
||||
"doc_id": metadata.get("doc_id", ""),
|
||||
"title": metadata.get("title", ""),
|
||||
"chunk": chunk_text,
|
||||
"score": float(getattr(node_with_score, "score", 0.0) or 0.0),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
if not hits:
|
||||
fallback_hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
|
||||
return {
|
||||
"answer": "\n\n".join(hit.chunk for hit in fallback_hits),
|
||||
"hits": [hit.__dict__ for hit in fallback_hits],
|
||||
}
|
||||
answer = "\n\n".join(item.get("chunk", "") for item in hits if item.get("chunk"))
|
||||
return {"answer": answer, "hits": hits}
|
||||
|
||||
|
||||
knowledge_index_service = KnowledgeIndexService()
|
||||
@@ -0,0 +1,5 @@
|
||||
def normalize_openai_base_url(api_base: str) -> str:
    """Reduce a user-supplied endpoint to an OpenAI-style base URL.

    Surrounding whitespace and trailing slashes are removed, and a trailing
    ``/embeddings`` path segment (any letter case) is dropped, so a pasted
    full embeddings URL still works as a base URL. Falsy input yields ``""``.
    """
    normalized = (api_base or "").strip().rstrip("/")
    suffix = "/embeddings"
    if normalized.lower().endswith(suffix):
        # Strip again: "…/v1//embeddings" would otherwise keep a trailing "/".
        normalized = normalized[: -len(suffix)].rstrip("/")
    return normalized
|
||||
@@ -0,0 +1,59 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool
|
||||
|
||||
from app.context import current_knowledge_base_id
|
||||
from app.services.knowledge_index import knowledge_index_service
|
||||
|
||||
|
||||
class KnowledgeBaseRetrieveTool(Tool):
    """Agent tool that retrieves chunks from a knowledge base.

    The knowledge base is taken from the ``knowledge_base_id`` argument or,
    when absent, from the ``current_knowledge_base_id`` context variable.
    """

    @property
    def name(self) -> str:
        return "knowledge_retrieve"

    @property
    def description(self) -> str:
        return "Retrieve relevant context from the selected knowledge base to answer user questions."

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments."""
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "User question or retrieval query.",
                },
                "knowledge_base_id": {
                    "type": "string",
                    "description": "Optional knowledge base id, defaults to current session setting.",
                },
                "top_k": {
                    "type": "integer",
                    "description": "Maximum number of returned chunks.",
                    "minimum": 1,
                    "maximum": 20,
                },
            },
            "required": ["query"],
        }

    @staticmethod
    def _coerce_top_k(raw: Any) -> int | None:
        """Turn a caller-supplied ``top_k`` into an int within the declared bounds.

        Missing, non-numeric or non-positive values return ``None`` so the
        knowledge base default applies; values above 20 are capped at 20.
        """
        if raw is None or isinstance(raw, bool):
            return None
        try:
            value = int(raw)
        except (TypeError, ValueError):
            return None
        if value <= 0:
            return None
        return min(value, 20)

    async def execute(self, **kwargs: Any) -> str:
        """Run the retrieval and return a JSON string, or a plain error message.

        The JSON payload has the keys ``knowledge_base_id``, ``answer`` and
        ``hits``. A ``ValueError`` from the index service is returned as its
        message text instead of being raised.
        """
        query = (kwargs.get("query") or "").strip()
        if not query:
            return "Query is required."
        kb_id = (kwargs.get("knowledge_base_id") or current_knowledge_base_id.get() or "").strip()
        if not kb_id:
            return "No knowledge base is selected in this session."
        # Arguments come from the model, so enforce the schema bounds here.
        top_k = self._coerce_top_k(kwargs.get("top_k"))
        try:
            result = knowledge_index_service.search(kb_id=kb_id, query=query, top_k=top_k)
        except ValueError as exc:
            return str(exc)
        payload = {
            "knowledge_base_id": kb_id,
            "answer": result.get("answer", ""),
            "hits": result.get("hits", []),
        }
        return json.dumps(payload, ensure_ascii=False)
|
||||
Reference in New Issue
Block a user