feat: knowledge base first OK

This commit is contained in:
qixinbo
2026-03-29 00:20:53 +08:00
parent bd7776d1b7
commit 92e8c40826
17 changed files with 3357 additions and 10 deletions
@@ -0,0 +1,188 @@
import json
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from app.core.data_root import get_data_root
def _utcnow_iso() -> str:
return datetime.now(timezone.utc).isoformat()
class KnowledgeBaseStore:
def __init__(self) -> None:
self._lock = threading.RLock()
@staticmethod
def _file_path() -> Path:
return get_data_root() / "knowledge_bases.json"
def _read(self) -> List[Dict[str, Any]]:
file_path = self._file_path()
if not file_path.exists():
return []
try:
with file_path.open("r", encoding="utf-8") as f:
data = json.load(f)
except (json.JSONDecodeError, OSError):
return []
if not isinstance(data, list):
return []
return data
def _write(self, data: List[Dict[str, Any]]) -> None:
file_path = self._file_path()
file_path.parent.mkdir(parents=True, exist_ok=True)
with file_path.open("w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
@staticmethod
def _normalize_documents(item: Dict[str, Any]) -> None:
docs = item.get("documents")
if not isinstance(docs, list):
item["documents"] = []
return
normalized: List[Dict[str, Any]] = []
for doc in docs:
if not isinstance(doc, dict):
continue
if not doc.get("id"):
doc["id"] = str(uuid.uuid4())
now = _utcnow_iso()
doc.setdefault("created_at", now)
doc.setdefault("updated_at", now)
doc.setdefault("metadata", {})
normalized.append(doc)
item["documents"] = normalized
def list(self, project_id: Optional[int] = None) -> List[Dict[str, Any]]:
with self._lock:
data = self._read()
for item in data:
self._normalize_documents(item)
if project_id is None:
return data
return [item for item in data if item.get("project_id") == project_id]
def get(self, kb_id: str) -> Optional[Dict[str, Any]]:
with self._lock:
for item in self._read():
if item.get("id") == kb_id:
self._normalize_documents(item)
return item
return None
def create(self, payload: Dict[str, Any]) -> Dict[str, Any]:
with self._lock:
data = self._read()
now = _utcnow_iso()
item = {
"id": str(uuid.uuid4()),
"name": payload["name"],
"description": payload.get("description"),
"project_id": payload.get("project_id"),
"embedding_model": payload.get("embedding_model"),
"chunk_size": payload.get("chunk_size", 512),
"chunk_overlap": payload.get("chunk_overlap", 50),
"top_k": payload.get("top_k", 3),
"is_active": payload.get("is_active", True),
"created_at": now,
"updated_at": now,
"documents": [],
}
data.append(item)
self._write(data)
return item
def update(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
for key, value in payload.items():
item[key] = value
item["updated_at"] = _utcnow_iso()
self._normalize_documents(item)
data[idx] = item
self._write(data)
return item
return None
def delete(self, kb_id: str) -> bool:
with self._lock:
data = self._read()
filtered = [item for item in data if item.get("id") != kb_id]
if len(filtered) == len(data):
return False
self._write(filtered)
return True
def create_document(self, kb_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
now = _utcnow_iso()
doc = {
"id": str(uuid.uuid4()),
"title": payload["title"],
"content": payload["content"],
"metadata": payload.get("metadata", {}),
"created_at": now,
"updated_at": now,
}
self._normalize_documents(item)
item["documents"].append(doc)
item["updated_at"] = now
data[idx] = item
self._write(data)
return doc
return None
def update_document(self, kb_id: str, doc_id: str, payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
with self._lock:
data = self._read()
for kb_idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
self._normalize_documents(item)
docs = item["documents"]
for doc_idx, doc in enumerate(docs):
if doc.get("id") != doc_id:
continue
for key, value in payload.items():
doc[key] = value
doc["updated_at"] = _utcnow_iso()
docs[doc_idx] = doc
item["updated_at"] = _utcnow_iso()
data[kb_idx] = item
self._write(data)
return doc
return None
return None
def delete_document(self, kb_id: str, doc_id: str) -> bool:
with self._lock:
data = self._read()
for kb_idx, item in enumerate(data):
if item.get("id") != kb_id:
continue
self._normalize_documents(item)
docs = item["documents"]
filtered = [doc for doc in docs if doc.get("id") != doc_id]
if len(filtered) == len(docs):
return False
item["documents"] = filtered
item["updated_at"] = _utcnow_iso()
data[kb_idx] = item
self._write(data)
return True
return False
knowledge_base_store = KnowledgeBaseStore()
@@ -0,0 +1,58 @@
import json
import threading
from pathlib import Path
from typing import Any, Dict
from app.core.data_root import get_data_root
class KnowledgeGlobalConfigStore:
def __init__(self) -> None:
self._lock = threading.RLock()
@staticmethod
def _file_path() -> Path:
return get_data_root() / "knowledge_global_config.json"
def _read(self) -> Dict[str, Any]:
file_path = self._file_path()
if not file_path.exists():
return {}
try:
with file_path.open("r", encoding="utf-8") as file_obj:
data = json.load(file_obj)
except (OSError, json.JSONDecodeError):
return {}
if not isinstance(data, dict):
return {}
return data
def _write(self, data: Dict[str, Any]) -> None:
file_path = self._file_path()
file_path.parent.mkdir(parents=True, exist_ok=True)
with file_path.open("w", encoding="utf-8") as file_obj:
json.dump(data, file_obj, indent=2, ensure_ascii=False)
def get(self) -> Dict[str, Any]:
with self._lock:
data = self._read()
return {
"api_base": data.get("api_base"),
"api_key": data.get("api_key"),
"default_embedding_model": data.get("default_embedding_model"),
}
def update(self, payload: Dict[str, Any]) -> Dict[str, Any]:
with self._lock:
current = self.get()
if "api_base" in payload:
current["api_base"] = payload.get("api_base")
if "api_key" in payload:
current["api_key"] = payload.get("api_key")
if "default_embedding_model" in payload:
current["default_embedding_model"] = payload.get("default_embedding_model")
self._write(current)
return current
knowledge_global_config_store = KnowledgeGlobalConfigStore()
+250
View File
@@ -0,0 +1,250 @@
import math
import re
import threading
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
from app.services.knowledge_base_store import knowledge_base_store
from app.services.knowledge_global_config_store import knowledge_global_config_store
from app.services.openai_compat import normalize_openai_base_url
try:
from llama_index.core import Document, VectorStoreIndex
from llama_index.core.node_parser import SentenceSplitter
LLAMAINDEX_AVAILABLE = True
except Exception:
Document = Any
VectorStoreIndex = Any
SentenceSplitter = Any
LLAMAINDEX_AVAILABLE = False
def _tokenize(text: str) -> List[str]:
return re.findall(r"[a-zA-Z0-9]+|[\u4e00-\u9fff]", (text or "").lower())
def _normalize_embedding_api_base(api_base: str) -> str:
return normalize_openai_base_url(api_base)
@dataclass
class SearchHit:
doc_id: str
title: str
chunk: str
score: float
metadata: Dict[str, Any]
class KnowledgeIndexService:
def __init__(self) -> None:
self._lock = threading.RLock()
self._cache: Dict[str, Tuple[str, Any, List[Dict[str, Any]]]] = {}
@staticmethod
def _signature(kb: Dict[str, Any]) -> str:
doc_parts = []
for doc in kb.get("documents", []):
doc_parts.append(f"{doc.get('id')}:{doc.get('updated_at')}:{len(doc.get('content', ''))}")
return "|".join(
[
str(kb.get("updated_at")),
str(kb.get("chunk_size")),
str(kb.get("chunk_overlap")),
*doc_parts,
]
)
@staticmethod
def _fallback_chunks(kb: Dict[str, Any]) -> List[Dict[str, Any]]:
chunks: List[Dict[str, Any]] = []
chunk_size = int(kb.get("chunk_size") or 512)
overlap = int(kb.get("chunk_overlap") or 50)
step = max(1, chunk_size - overlap)
for doc in kb.get("documents", []):
text = doc.get("content") or ""
if not text:
continue
if len(text) <= chunk_size:
chunks.append(
{
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
"chunk": text,
"metadata": doc.get("metadata") or {},
}
)
continue
for start in range(0, len(text), step):
piece = text[start : start + chunk_size]
if not piece:
continue
chunks.append(
{
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
"chunk": piece,
"metadata": doc.get("metadata") or {},
}
)
return chunks
def _build_index(self, kb: Dict[str, Any]) -> Tuple[Any, List[Dict[str, Any]]]:
fallback_chunks = self._fallback_chunks(kb)
if not LLAMAINDEX_AVAILABLE:
return None, fallback_chunks
chunk_size = int(kb.get("chunk_size") or 512)
overlap = int(kb.get("chunk_overlap") or 50)
splitter = SentenceSplitter(chunk_size=chunk_size, chunk_overlap=overlap)
docs = [
Document(
text=(doc.get("content") or ""),
metadata={
"doc_id": doc.get("id", ""),
"title": doc.get("title", ""),
**(doc.get("metadata") or {}),
},
)
for doc in kb.get("documents", [])
if (doc.get("content") or "").strip()
]
if not docs:
return None, fallback_chunks
embed_model = self._build_embed_model(kb)
if embed_model is not None:
index = VectorStoreIndex.from_documents(
docs,
transformations=[splitter],
embed_model=embed_model,
)
else:
index = VectorStoreIndex.from_documents(docs, transformations=[splitter])
return index, fallback_chunks
@staticmethod
def _build_embed_model(kb: Dict[str, Any]) -> Any:
global_config = knowledge_global_config_store.get()
api_base = global_config.get("api_base")
api_key = global_config.get("api_key")
model_name = kb.get("embedding_model") or global_config.get("default_embedding_model")
if not api_base or not api_key or not model_name:
return None
api_base = _normalize_embedding_api_base(api_base)
try:
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
return OpenAILikeEmbedding(
model_name=model_name,
api_base=api_base,
api_key=api_key,
embed_batch_size=10,
)
except Exception:
try:
from llama_index.embeddings.openai import OpenAIEmbedding
return OpenAIEmbedding(
model_name=model_name,
api_base=api_base,
api_key=api_key,
embed_batch_size=10,
)
except Exception:
return None
def reindex(self, kb_id: str) -> Dict[str, Any]:
kb = knowledge_base_store.get(kb_id)
if not kb:
raise ValueError("Knowledge base not found")
with self._lock:
signature = self._signature(kb)
index, fallback_chunks = self._build_index(kb)
self._cache[kb_id] = (signature, index, fallback_chunks)
return {
"kb_id": kb_id,
"status": "ok",
"documents": len(kb.get("documents", [])),
"engine": "llamaindex" if LLAMAINDEX_AVAILABLE and index is not None else "fallback",
}
@staticmethod
def _fallback_search(query: str, chunks: List[Dict[str, Any]], top_k: int) -> List[SearchHit]:
q_tokens = _tokenize(query)
if not q_tokens:
return []
q_set = set(q_tokens)
scored: List[SearchHit] = []
for chunk_item in chunks:
c_tokens = _tokenize(chunk_item.get("chunk", ""))
if not c_tokens:
continue
overlap = sum(1 for t in c_tokens if t in q_set)
if overlap == 0:
continue
score = overlap / math.sqrt(len(c_tokens))
scored.append(
SearchHit(
doc_id=chunk_item.get("doc_id", ""),
title=chunk_item.get("title", ""),
chunk=chunk_item.get("chunk", ""),
score=float(score),
metadata=chunk_item.get("metadata") or {},
)
)
scored.sort(key=lambda x: x.score, reverse=True)
return scored[:top_k]
def search(self, kb_id: str, query: str, top_k: int | None = None) -> Dict[str, Any]:
kb = knowledge_base_store.get(kb_id)
if not kb:
raise ValueError("Knowledge base not found")
if not kb.get("documents"):
return {"answer": "", "hits": []}
effective_top_k = int(top_k or kb.get("top_k") or 3)
with self._lock:
signature = self._signature(kb)
cached = self._cache.get(kb_id)
if not cached or cached[0] != signature:
index, fallback_chunks = self._build_index(kb)
cached = (signature, index, fallback_chunks)
self._cache[kb_id] = cached
_, index, fallback_chunks = cached
if index is None:
hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
answer = "\n\n".join(hit.chunk for hit in hits)
return {
"answer": answer,
"hits": [hit.__dict__ for hit in hits],
}
retriever = index.as_retriever(similarity_top_k=effective_top_k)
response_nodes = retriever.retrieve(query)
hits: List[Dict[str, Any]] = []
for node_with_score in response_nodes:
node = getattr(node_with_score, "node", None)
metadata = getattr(node, "metadata", {}) if node is not None else {}
chunk_text = ""
if node is not None and hasattr(node, "get_content"):
chunk_text = node.get_content()
elif node is not None:
chunk_text = str(getattr(node, "text", ""))
hits.append(
{
"doc_id": metadata.get("doc_id", ""),
"title": metadata.get("title", ""),
"chunk": chunk_text,
"score": float(getattr(node_with_score, "score", 0.0) or 0.0),
"metadata": metadata,
}
)
if not hits:
fallback_hits = self._fallback_search(query=query, chunks=fallback_chunks, top_k=effective_top_k)
return {
"answer": "\n\n".join(hit.chunk for hit in fallback_hits),
"hits": [hit.__dict__ for hit in fallback_hits],
}
answer = "\n\n".join(item.get("chunk", "") for item in hits if item.get("chunk"))
return {"answer": answer, "hits": hits}
knowledge_index_service = KnowledgeIndexService()
+5
View File
@@ -0,0 +1,5 @@
def normalize_openai_base_url(api_base: str) -> str:
normalized = (api_base or "").strip().rstrip("/")
if normalized.lower().endswith("/embeddings"):
normalized = normalized[: -len("/embeddings")]
return normalized