feat: knowledge base first OK

This commit is contained in:
qixinbo
2026-03-29 00:20:53 +08:00
parent bd7776d1b7
commit 92e8c40826
17 changed files with 3357 additions and 10 deletions
+257
View File
@@ -0,0 +1,257 @@
from typing import List, Optional
import io
import json
from fastapi import APIRouter, HTTPException
from fastapi import UploadFile, File, Form
from openai import OpenAI
import pandas as pd
from app.schemas.knowledge import (
KnowledgeBase,
KnowledgeBaseCreate,
KnowledgeConnectionTestRequest,
KnowledgeConnectionTestResponse,
KnowledgeGlobalConfig,
KnowledgeGlobalConfigUpdate,
KnowledgeBaseUpdate,
KnowledgeDocument,
KnowledgeDocumentCreate,
KnowledgeDocumentUpdate,
KnowledgeSearchRequest,
KnowledgeSearchResponse,
)
from app.services.knowledge_base_store import knowledge_base_store
from app.services.knowledge_global_config_store import knowledge_global_config_store
from app.services.knowledge_index import knowledge_index_service
from app.services.openai_compat import normalize_openai_base_url
router = APIRouter()
def _mask_api_key(value: Optional[str]) -> Optional[str]:
if not value:
return None
if len(value) <= 8:
return "*" * len(value)
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
def _extract_upload_text(filename: str, content: bytes) -> str:
lower = filename.lower()
if lower.endswith((".txt", ".md", ".markdown", ".json", ".yaml", ".yml", ".log", ".xml", ".html", ".htm")):
try:
return content.decode("utf-8")
except UnicodeDecodeError:
return content.decode("utf-8", errors="ignore")
if lower.endswith(".csv"):
df = pd.read_csv(io.BytesIO(content))
return df.to_csv(index=False)
if lower.endswith((".xls", ".xlsx")):
df = pd.read_excel(io.BytesIO(content))
return df.to_csv(index=False)
raise ValueError("Unsupported file type")
@router.get("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def get_knowledge_global_config():
config = knowledge_global_config_store.get()
raw_api_key = config.get("api_key")
return {
"api_base": config.get("api_base"),
"api_key": None,
"api_key_masked": _mask_api_key(raw_api_key),
"has_api_key": bool(raw_api_key),
"default_embedding_model": config.get("default_embedding_model"),
}
@router.put("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def update_knowledge_global_config(payload: KnowledgeGlobalConfigUpdate):
updated = knowledge_global_config_store.update(payload.model_dump(exclude_unset=True))
raw_api_key = updated.get("api_key")
return {
"api_base": updated.get("api_base"),
"api_key": None,
"api_key_masked": _mask_api_key(raw_api_key),
"has_api_key": bool(raw_api_key),
"default_embedding_model": updated.get("default_embedding_model"),
}
@router.post("/knowledge-bases/global-config/test-connection", response_model=KnowledgeConnectionTestResponse)
def test_knowledge_global_connection(payload: KnowledgeConnectionTestRequest):
saved = knowledge_global_config_store.get()
api_base = normalize_openai_base_url(payload.api_base or saved.get("api_base") or "")
api_key = payload.api_key or saved.get("api_key")
model_name = (payload.model_name or "").strip()
if not api_base:
raise HTTPException(status_code=400, detail="API Base 未配置")
if not api_key:
raise HTTPException(status_code=400, detail="API Key 未配置")
if not model_name:
raise HTTPException(status_code=400, detail="测试连接必须显式填写向量模型名称")
if not api_base:
raise HTTPException(status_code=400, detail="API Base 未配置")
try:
client = OpenAI(
api_key=api_key,
base_url=api_base,
)
embedding_resp = client.embeddings.create(
model=model_name,
input="connection test",
)
except Exception as exc:
raise HTTPException(status_code=400, detail=f"Embedding调用失败: {exc}")
dimension = None
if getattr(embedding_resp, "data", None):
first = embedding_resp.data[0]
vector = getattr(first, "embedding", None)
if isinstance(vector, list):
dimension = len(vector)
return {
"success": True,
"message": "连接成功,Embedding调用正常",
"model_name": model_name,
"embedding_dimension": dimension,
"resolved_api_base": api_base,
"available_models": [],
}
@router.get("/knowledge-bases", response_model=List[KnowledgeBase])
def list_knowledge_bases(project_id: Optional[int] = None):
return knowledge_base_store.list(project_id=project_id)
@router.post("/knowledge-bases", response_model=KnowledgeBase)
def create_knowledge_base(payload: KnowledgeBaseCreate):
return knowledge_base_store.create(payload.model_dump())
@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def get_knowledge_base(kb_id: str):
kb = knowledge_base_store.get(kb_id)
if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found")
return kb
@router.put("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def update_knowledge_base(kb_id: str, payload: KnowledgeBaseUpdate):
kb = knowledge_base_store.update(kb_id, payload.model_dump(exclude_unset=True))
if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found")
return kb
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(kb_id: str):
deleted = knowledge_base_store.delete(kb_id)
if not deleted:
raise HTTPException(status_code=404, detail="Knowledge base not found")
return {"status": "success"}
@router.get("/knowledge-bases/{kb_id}/documents", response_model=List[KnowledgeDocument])
def list_knowledge_documents(kb_id: str):
kb = knowledge_base_store.get(kb_id)
if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found")
return kb.get("documents", [])
@router.post("/knowledge-bases/{kb_id}/documents", response_model=KnowledgeDocument)
def create_knowledge_document(kb_id: str, payload: KnowledgeDocumentCreate):
doc = knowledge_base_store.create_document(kb_id=kb_id, payload=payload.model_dump())
if not doc:
raise HTTPException(status_code=404, detail="Knowledge base not found")
return doc
@router.put("/knowledge-bases/{kb_id}/documents/{doc_id}", response_model=KnowledgeDocument)
def update_knowledge_document(kb_id: str, doc_id: str, payload: KnowledgeDocumentUpdate):
doc = knowledge_base_store.update_document(
kb_id=kb_id,
doc_id=doc_id,
payload=payload.model_dump(exclude_unset=True),
)
if not doc:
raise HTTPException(status_code=404, detail="Knowledge document not found")
return doc
@router.delete("/knowledge-bases/{kb_id}/documents/{doc_id}")
def delete_knowledge_document(kb_id: str, doc_id: str):
deleted = knowledge_base_store.delete_document(kb_id=kb_id, doc_id=doc_id)
if not deleted:
raise HTTPException(status_code=404, detail="Knowledge document not found")
return {"status": "success"}
@router.post("/knowledge-bases/{kb_id}/reindex")
def reindex_knowledge_base(kb_id: str):
try:
return knowledge_index_service.reindex(kb_id)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc))
@router.post("/knowledge-bases/{kb_id}/search", response_model=KnowledgeSearchResponse)
def search_knowledge_base(kb_id: str, payload: KnowledgeSearchRequest):
try:
result = knowledge_index_service.search(
kb_id=kb_id,
query=payload.query,
top_k=payload.top_k,
)
except ValueError as exc:
raise HTTPException(status_code=404, detail=str(exc))
return result
@router.post("/knowledge-bases/{kb_id}/documents/upload")
async def upload_knowledge_documents(
kb_id: str,
files: List[UploadFile] = File(...),
metadata: Optional[str] = Form(default=None),
):
kb = knowledge_base_store.get(kb_id)
if not kb:
raise HTTPException(status_code=404, detail="Knowledge base not found")
metadata_payload: dict[str, Any] = {}
if metadata:
try:
parsed_metadata = json.loads(metadata)
if isinstance(parsed_metadata, dict):
metadata_payload = parsed_metadata
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
created: List[dict[str, Any]] = []
for file in files:
filename = file.filename or "untitled"
content = await file.read()
if not content:
continue
if len(content) > 5 * 1024 * 1024:
raise HTTPException(status_code=400, detail=f"文件过大: {filename}")
try:
text = _extract_upload_text(filename, content)
except Exception:
raise HTTPException(status_code=400, detail=f"不支持的文件类型: {filename}")
doc = knowledge_base_store.create_document(
kb_id=kb_id,
payload={
"title": filename,
"content": text,
"metadata": {**metadata_payload, "source": "upload", "filename": filename},
},
)
if doc:
created.append(doc)
return {"status": "success", "count": len(created), "documents": created}
+2
View File
@@ -19,3 +19,5 @@ current_data_source: ContextVar[str] = ContextVar("current_data_source", default
# Any file URL attached to the request
current_file_url: ContextVar[Optional[str]] = ContextVar("current_file_url", default=None)
current_knowledge_base_id: ContextVar[Optional[str]] = ContextVar("current_knowledge_base_id", default=None)
+2
View File
@@ -204,10 +204,12 @@ class NanobotIntegration:
from app.tools.nl2sql import NL2SQLTool
from app.tools.visualization import VisualizationTool
from app.tools.get_schema import GetDatabaseSchemaTool
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
from app.tools.subagent import ListSubagentsTool, InvokeSubagentTool
agent.tools.register(NL2SQLTool())
agent.tools.register(VisualizationTool())
agent.tools.register(GetDatabaseSchemaTool())
agent.tools.register(KnowledgeBaseRetrieveTool())
agent.tools.register(ListSubagentsTool(project_id=project_id))
agent.tools.register(InvokeSubagentTool(project_id=project_id))
+162
View File
@@ -0,0 +1,162 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
class KnowledgeDocumentBase(BaseModel):
title: str = Field(..., min_length=1, max_length=200)
content: str = Field(..., min_length=1)
metadata: Dict[str, Any] = Field(default_factory=dict)
class KnowledgeDocumentCreate(KnowledgeDocumentBase):
pass
class KnowledgeDocumentUpdate(BaseModel):
title: Optional[str] = Field(None, min_length=1, max_length=200)
content: Optional[str] = Field(None, min_length=1)
metadata: Optional[Dict[str, Any]] = None
class KnowledgeDocument(KnowledgeDocumentBase):
id: str
created_at: datetime
updated_at: datetime
class KnowledgeBaseConfigBase(BaseModel):
name: str = Field(..., min_length=1, max_length=120)
description: Optional[str] = None
project_id: Optional[int] = None
embedding_model: Optional[str] = None
chunk_size: int = Field(default=512, ge=64, le=4096)
chunk_overlap: int = Field(default=50, ge=0, le=512)
top_k: int = Field(default=3, ge=1, le=20)
is_active: bool = True
class KnowledgeBaseCreate(KnowledgeBaseConfigBase):
pass
class KnowledgeBaseUpdate(BaseModel):
name: Optional[str] = Field(None, min_length=1, max_length=120)
description: Optional[str] = None
project_id: Optional[int] = None
embedding_model: Optional[str] = None
chunk_size: Optional[int] = Field(None, ge=64, le=4096)
chunk_overlap: Optional[int] = Field(None, ge=0, le=512)
top_k: Optional[int] = Field(None, ge=1, le=20)
is_active: Optional[bool] = None
class KnowledgeBase(KnowledgeBaseConfigBase):
id: str
created_at: datetime
updated_at: datetime
documents: List[KnowledgeDocument] = Field(default_factory=list)
class KnowledgeSearchRequest(BaseModel):
query: str = Field(..., min_length=1)
top_k: Optional[int] = Field(default=None, ge=1, le=20)
class KnowledgeSearchHit(BaseModel):
doc_id: str
title: str
chunk: str
score: float
metadata: Dict[str, Any] = Field(default_factory=dict)
class KnowledgeSearchResponse(BaseModel):
answer: str
hits: List[KnowledgeSearchHit] = Field(default_factory=list)
class KnowledgeGlobalConfigUpdate(BaseModel):
api_base: Optional[str] = None
api_key: Optional[str] = None
default_embedding_model: Optional[str] = None
@field_validator("api_base")
@classmethod
def validate_api_base(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if not (normalized.startswith("http://") or normalized.startswith("https://")):
raise ValueError("api_base must start with http:// or https://")
return normalized.rstrip("/")
@field_validator("api_key")
@classmethod
def validate_api_key(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 512:
raise ValueError("api_key is too long")
return normalized
@field_validator("default_embedding_model")
@classmethod
def validate_default_embedding_model(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if len(normalized) > 200:
raise ValueError("default_embedding_model is too long")
return normalized
class KnowledgeGlobalConfig(BaseModel):
api_base: Optional[str] = None
api_key: Optional[str] = None
api_key_masked: Optional[str] = None
has_api_key: bool = False
default_embedding_model: Optional[str] = None
class KnowledgeConnectionTestRequest(BaseModel):
api_base: Optional[str] = None
api_key: Optional[str] = None
model_name: Optional[str] = None
@field_validator("api_base")
@classmethod
def validate_test_api_base(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return None
normalized = value.strip()
if not normalized:
return None
if not (normalized.startswith("http://") or normalized.startswith("https://")):
raise ValueError("api_base must start with http:// or https://")
return normalized.rstrip("/")
@field_validator("api_key", "model_name")
@classmethod
def normalize_test_value(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return None
normalized = value.strip()
return normalized or None
class KnowledgeConnectionTestResponse(BaseModel):
success: bool
message: str
model_name: Optional[str] = None
embedding_dimension: Optional[int] = None
resolved_api_base: Optional[str] = None
available_models: List[str] = Field(default_factory=list)
@@ -0,0 +1,188 @@
import json
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from app.core.data_root import get_data_root
def _utcnow_iso() -> str:
return datetime.now(timezone.utc).isoformat()
class KnowledgeBaseStore:
def __init__(self) -> None:
self._lock = threading.RLock()
@staticmethod
def _file_path() -> Path:
return get_data_root() / "knowledge_bases.json"
def _read(self) -> List[Dict[str, Any]]:
file_path = self._file_path()
if not file_path.exists():
return []
try:
with file_path.open("r", encoding="utf-8") as f:
data = json.load(f)
except (json.JSONDecodeError, OSError):
return []
if not isinstance(data, list):
return []
return data
def _write(self, data: List[Dict[str, Any]]) -> None:
file_path = self._file_path()
file_path.parent.mkdir(parents=True, exist_ok=True)
with file_path.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
@staticmethod
def _normalize_documents(item: Dict[str, Any]) -> None:
docs = item.get("documents")
if not isinstance(docs, list):
item["documents"] = []
return
normalized: List[Dict[str, Any]] = []
for doc in docs:
if not isinstance(doc, dict):
continue
if not doc.get("id"):
doc["id"] = str(uuid.uuid4())
now = _utcnow_iso()
doc.setdefault("created_at", now)
doc.setdefault("updated_at", now)
doc.setdefault("metadata", {})
normalized.append(doc)
item["documents"] = normalized
def list(self, project_id: Optional[int] = None) -> List[Dict[str, Any]]:
with self._lock:
data = self._read()
for item in data:
self._normalize_documents(item)
if project_id is None:
return data
return [item for item in data if item.get("project_id") == project_id]
def get(self, kb_id: str) -> Optional[Dict[str, Any]]:
with self._lock:
for item in self._read():
if item.get("id") == kb_id:
self._normalize_documents(item)
return item
return None
def create(self, payload: Dict[str, Any]) -> Dict[str, Any]:
with self._lock:
data = self._read()
now = _utcnow_iso()
item = {
"id": str(uuid.uuid4()),
"name": payload["name"],
"description": payload.get("description"),
"project_id": payload.get("project_id"),
"embedding_model": payload.get("embedding_model"),
"chunk_size": payload.get("chunk_size", 512),
"chunk_overlap": payload.get("chunk_overlap", 50),
"top_k": payload.get("top_k", 3),
"is_active": payload.get("is_active", True),
"created_at": now,
"updated_at": now,
"documents": [],
}
data.append(item)
self._write(data)
return item
def update(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
for key, value in payload.items():
item[key] = value
item["updated_at"] = _utcnow_iso()
self._normalize_documents(item)
data[idx] = item
self._write(data)
return item
return None
def delete(self, kb_id: str) -> bool:
with self._lock:
data = self._read()
filtered = [item for item in data if item.get("id") != kb_id]
if len(filtered) == len(data):
return False
self._write(filtered)
return True
def create_document(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
now = _utcnow_iso()
doc = {
"id": str(uuid.uuid4()),
"title": payload["title"],
"content": payload["content"],
"metadata": payload.get("metadata", {}),
"created_at": now,
"updated_at": now,
}
self._normalize_documents(item)
item["documents"].append(doc)
item["updated_at"] = now
data[idx] = item
self._write(data)
return doc
return None
def update_document(self, kb_id: str, doc_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for kb_idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
self._normalize_documents(item)
docs = item["documents"]
for doc_idx, doc in enumerate(docs):
if doc.get("id") != doc_id:
continue
for key, value in payload.items():
doc[key] = value
doc["updated_at"] = _utcnow_iso()
docs[doc_idx] = doc
item["updated_at"] = _utcnow_iso()
data[kb_idx] = item
self._write(data)
return doc
return None
return None
def delete_document(self, kb_id: str, doc_id: str) -> bool:
with self._lock:
data = self._read()
for kb_idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
self._normalize_documents(item)
docs = item["documents"]
filtered = [doc for doc in docs if doc.get("id") != doc_id]
if len(filtered) == len(docs):
return False
item["documents"] = filtered
item["updated_at"] = _utcnow_iso()
data[kb_idx] = item
self._write(data)
return True
return False
knowledge_base_store = KnowledgeBaseStore()
@@ -0,0 +1,58 @@
import json
import threading
from pathlib import Path
from typing import Any, Dict
from app.core.data_root import get_data_root
class KnowledgeGlobalConfigStore:
def __init__(self) -> None:
self._lock = threading.RLock()
@staticmethod
def _file_path() -> Path:
return get_data_root() / "knowledge_global_config.json"
def _read(self) -> Dict[str, Any]:
file_path = self._file_path()
if not file_path.exists():
return {}
try:
with file_path.open("r", encoding="utf-8") as file_obj:
data = json.load(file_obj)
except (OSError, json.JSONDecodeError):
return {}
if not isinstance(data, dict):
return {}
return data
def _write(self, data: Dict[str, Any]) -> None:
file_path = self._file_path()
file_path.parent.mkdir(parents=True, exist_ok=True)
with file_path.open("w", encoding="utf-8") as file_obj:
json.dump(data, file_obj, indent=2, ensure_ascii=False)
def get(self) -> Dict[str, Any]:
with self._lock:
data = self._read()
return {
"api_base": data.get("api_base"),
"api_key": data.get("api_key"),
"default_embedding_model": data.get("default_embedding_model"),
}
def update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
with self._lock:
current = self.get()
if "api_base" in payload:
current["api_base"] = payload.get("api_base")
if "api_key" in payload:
current["api_key"] = payload.get("api_key")
if "default_embedding_model" in payload:
current["default_embedding_model"] = payload.get("default_embedding_model")
self._write(current)
return current
knowledge_global_config_store = KnowledgeGlobalConfigStore()
+250
View File
@@ -0,0 +1,250 @@
import math
import re
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from app.services.knowledge_base_store import knowledge_base_store
from app.services.knowledge_global_config_store import knowledge_global_config_store
from app.services.openai_compat import normalize_openai_base_url
try:
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
LLAMAINDEX_AVAILABLE = True
except Exception:
Document = Any
VectorStoreIndex = Any
SentenceSplitter = Any
LLAMAINDEX_AVAILABLE = False
def _tokenize(text: str) -> List[str]:
return re.findall(r"[a-zA-Z0-9]+|[\u4e00-\u9fff]", (text or "").lower())
def _normalize_embedding_api_base(api_base: str) -> str:
return normalize_openai_base_url(api_base)
@dataclass
class SearchHit:
doc_id: str
title: str
chunk: str
score: float
metadata: Dict[str, Any]
class KnowledgeIndexService:
def __init__(self) -> None:
self._lock = threading.RLock()
self._cache: Dict[str, Tuple[str, Any, List[Dict[str, Any]]]] = {}
@staticmethod
def _signature(kb: Dict[str, Any]) -> str:
doc_parts = []
for doc in kb.get("documents", []):
doc_parts.append(f"{doc.get('id')}:{doc.get('updated_at')}:{len(doc.get('content', ''))}")
return "|".join(
[
str(kb.get("updated_at")),
str(kb.get("chunk_size")),
str(kb.get("chunk_overlap")),
*doc_parts,
]
)
@staticmethod
def _fallback_chunks(kb: Dict[str, Any]) -> List[Dict[str, Any]]:
chunks: List[Dict[str, Any]] = []
chunk_size = int(kb.get("chunk_size") or 512)
overlap = int(kb.get("chunk_overlap") or 50)
step = max(1, chunk_size - overlap)
for doc in kb.get("documents", []):
text = doc.get("content") or ""
if not text:
continue
if len(text) <= chunk_size:
chunks.append(
{
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
"chunk": text,
"metadata": doc.get("metadata") or {},
}
)
continue
for start in range(0, len(text), step):
piece = text[start : start + chunk_size]
if not piece:
continue
chunks.append(
{
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
"chunk": piece,
"metadata": doc.get("metadata") or {},
}
)
return chunks
def _build_index(self, kb: Dict[str, Any]) -> Tuple[Any, List[Dict[str, Any]]]:
fallback_chunks = self._fallback_chunks(kb)
if not LLAMAINDEX_AVAILABLE:
return None, fallback_chunks
chunk_size = int(kb.get("chunk_size") or 512)
overlap = int(kb.get("chunk_overlap") or 50)
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
docs = [
Document(
text=(doc.get("content") or ""),
metadata={
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
**(doc.get("metadata") or {}),
},
)
for doc in kb.get("documents", [])
if (doc.get("content") or "").strip()
]
if not docs:
return None, fallback_chunks
embed_model = self._build_embed_model(kb)
if embed_model is not None:
index = VectorStoreIndex.from_documents(
docs,
transformations=[splitter],
embed_model=embed_model,
)
else:
index = VectorStoreIndex.from_documents(docs, transformations=[splitter])
return index, fallback_chunks
@staticmethod
def _build_embed_model(kb: Dict[str, Any]) -> Any:
global_config = knowledge_global_config_store.get()
api_base = global_config.get("api_base")
api_key = global_config.get("api_key")
model_name = kb.get("embedding_model") or global_config.get("default_embedding_model")
if not api_base or not api_key or not model_name:
return None
api_base = _normalize_embedding_api_base(api_base)
try:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
return OpenAILikeEmbedding(
model_name=model_name,
api_base=api_base,
api_key=api_key,
embed_batch_size=10,
)
except Exception:
try:
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model_name=model_name,
api_base=api_base,
api_key=api_key,
embed_batch_size=10,
)
except Exception:
return None
def reindex(self, kb_id: str) -> Dict[str, Any]:
kb = knowledge_base_store.get(kb_id)
if not kb:
raise ValueError("Knowledge base not found")
with self._lock:
signature = self._signature(kb)
index, fallback_chunks = self._build_index(kb)
self._cache[kb_id] = (signature, index, fallback_chunks)
return {
"kb_id": kb_id,
"status": "ok",
"documents": len(kb.get("documents", [])),
"engine": "llamaindex" if LLAMAINDEX_AVAILABLE and index is not None else "fallback",
}
@staticmethod
def _fallback_search(query: str, chunks: List[Dict[str, Any]], top_k: int) -> List[SearchHit]:
q_tokens = _tokenize(query)
if not q_tokens:
return []
q_set = set(q_tokens)
scored: List[SearchHit] = []
for chunk_item in chunks:
c_tokens = _tokenize(chunk_item.get("chunk", ""))
if not c_tokens:
continue
overlap = sum(1 for t in c_tokens if t in q_set)
if overlap == 0:
continue
score = overlap / math.sqrt(len(c_tokens))
scored.append(
SearchHit(
doc_id=chunk_item.get("doc_id", ""),
title=chunk_item.get("title", ""),
chunk=chunk_item.get("chunk", ""),
score=float(score),
metadata=chunk_item.get("metadata") or {},
)
)
scored.sort(key=lambda x: x.score, reverse=True)
return scored[:top_k]
def search(self, kb_id: str, query: str, top_k: int | None = None) -> Dict[str, Any]:
kb = knowledge_base_store.get(kb_id)
if not kb:
raise ValueError("Knowledge base not found")
if not kb.get("documents"):
return {"answer": "", "hits": []}
effective_top_k = int(top_k or kb.get("top_k") or 3)
with self._lock:
signature = self._signature(kb)
cached = self._cache.get(kb_id)
if not cached or cached[0] != signature:
index, fallback_chunks = self._build_index(kb)
cached = (signature, index, fallback_chunks)
self._cache[kb_id] = cached
_, index, fallback_chunks = cached
if index is None:
hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
answer = "\n\n".join(hit.chunk for hit in hits)
return {
"answer": answer,
"hits": [hit.__dict__ for hit in hits],
}
retriever = index.as_retriever(similarity_top_k=effective_top_k)
response_nodes = retriever.retrieve(query)
hits: List[Dict[str, Any]] = []
for node_with_score in response_nodes:
node = getattr(node_with_score, "node", None)
metadata = getattr(node, "metadata", {}) if node is not None else {}
chunk_text = ""
if node is not None and hasattr(node, "get_content"):
chunk_text = node.get_content()
elif node is not None:
chunk_text = str(getattr(node, "text", ""))
hits.append(
{
"doc_id": metadata.get("doc_id", ""),
"title": metadata.get("title", ""),
"chunk": chunk_text,
"score": float(getattr(node_with_score, "score", 0.0) or 0.0),
"metadata": metadata,
}
)
if not hits:
fallback_hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
return {
"answer": "\n\n".join(hit.chunk for hit in fallback_hits),
"hits": [hit.__dict__ for hit in fallback_hits],
}
answer = "\n\n".join(item.get("chunk", "") for item in hits if item.get("chunk"))
return {"answer": answer, "hits": hits}
knowledge_index_service = KnowledgeIndexService()
+5
View File
@@ -0,0 +1,5 @@
def normalize_openai_base_url(api_base: str) -> str:
normalized = (api_base or "").strip().rstrip("/")
if normalized.lower().endswith("/embeddings"):
normalized = normalized[: -len("/embeddings")]
return normalized
+59
View File
@@ -0,0 +1,59 @@
import json
from typing import Any
from nanobot.agent.tools.base import Tool
from app.context import current_knowledge_base_id
from app.services.knowledge_index import knowledge_index_service
class KnowledgeBaseRetrieveTool(Tool):
@property
def name(self) -> str:
return "knowledge_retrieve"
@property
def description(self) -> str:
return "Retrieve relevant context from the selected knowledge base to answer user questions."
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "User question or retrieval query.",
},
"knowledge_base_id": {
"type": "string",
"description": "Optional knowledge base id, defaults to current session setting.",
},
"top_k": {
"type": "integer",
"description": "Maximum number of returned chunks.",
"minimum": 1,
"maximum": 20,
},
},
"required": ["query"],
}
async def execute(self, **kwargs: Any) -> str:
query = (kwargs.get("query") or "").strip()
if not query:
return "Query is required."
kb_id = (kwargs.get("knowledge_base_id") or current_knowledge_base_id.get() or "").strip()
if not kb_id:
return "No knowledge base is selected in this session."
top_k = kwargs.get("top_k")
try:
result = knowledge_index_service.search(kb_id=kb_id, query=query, top_k=top_k)
except ValueError as exc:
return str(exc)
payload = {
"knowledge_base_id": kb_id,
"answer": result.get("answer", ""),
"hits": result.get("hits", []),
}
return json.dumps(payload, ensure_ascii=False)