2026-03-29 00:20:53 +08:00
|
|
|
from typing import List, Optional
|
|
|
|
|
import io
|
|
|
|
|
import json
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
|
|
from fastapi import UploadFile, File, Form
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
|
|
|
|
from app.schemas.knowledge import (
|
|
|
|
|
KnowledgeBase,
|
|
|
|
|
KnowledgeBaseCreate,
|
|
|
|
|
KnowledgeConnectionTestRequest,
|
|
|
|
|
KnowledgeConnectionTestResponse,
|
|
|
|
|
KnowledgeGlobalConfig,
|
|
|
|
|
KnowledgeGlobalConfigUpdate,
|
|
|
|
|
KnowledgeBaseUpdate,
|
|
|
|
|
KnowledgeDocument,
|
|
|
|
|
KnowledgeDocumentCreate,
|
|
|
|
|
KnowledgeDocumentUpdate,
|
|
|
|
|
KnowledgeSearchRequest,
|
|
|
|
|
KnowledgeSearchResponse,
|
|
|
|
|
)
|
|
|
|
|
from app.services.knowledge_base_store import knowledge_base_store
|
|
|
|
|
from app.services.knowledge_global_config_store import knowledge_global_config_store
|
|
|
|
|
from app.services.knowledge_index import knowledge_index_service
|
|
|
|
|
from app.services.openai_compat import normalize_openai_base_url
|
|
|
|
|
|
|
|
|
|
# Router on which the knowledge-base endpoints in this module are registered.
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mask_api_key(value: Optional[str]) -> Optional[str]:
|
|
|
|
|
if not value:
|
|
|
|
|
return None
|
|
|
|
|
if len(value) <= 8:
|
|
|
|
|
return "*" * len(value)
|
|
|
|
|
return f"{value[:4]}{'*' * (len(value) - 8)}{value[-4:]}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _extract_upload_text(filename: str, content: bytes) -> str:
|
|
|
|
|
lower = filename.lower()
|
|
|
|
|
if lower.endswith((".txt", ".md", ".markdown", ".json", ".yaml", ".yml", ".log", ".xml", ".html", ".htm")):
|
|
|
|
|
try:
|
|
|
|
|
return content.decode("utf-8")
|
|
|
|
|
except UnicodeDecodeError:
|
|
|
|
|
return content.decode("utf-8", errors="ignore")
|
|
|
|
|
if lower.endswith(".csv"):
|
|
|
|
|
df = pd.read_csv(io.BytesIO(content))
|
|
|
|
|
return df.to_csv(index=False)
|
|
|
|
|
if lower.endswith((".xls", ".xlsx")):
|
|
|
|
|
df = pd.read_excel(io.BytesIO(content))
|
|
|
|
|
return df.to_csv(index=False)
|
2026-03-29 17:24:26 +08:00
|
|
|
|
|
|
|
|
# 增加对 PDF 的文本提取支持
|
|
|
|
|
if lower.endswith(".pdf"):
|
|
|
|
|
try:
|
|
|
|
|
import PyPDF2
|
|
|
|
|
pdf_reader = PyPDF2.PdfReader(io.BytesIO(content))
|
|
|
|
|
text = []
|
|
|
|
|
for page in pdf_reader.pages:
|
|
|
|
|
page_text = page.extract_text()
|
|
|
|
|
if page_text:
|
|
|
|
|
text.append(page_text)
|
|
|
|
|
return "\n".join(text)
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ValueError("PyPDF2 is not installed. Cannot parse PDF files.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ValueError(f"Failed to parse PDF: {str(e)}")
|
|
|
|
|
|
2026-03-29 20:35:38 +08:00
|
|
|
# 增加对 Word 文档的文本提取支持
|
|
|
|
|
if lower.endswith((".doc", ".docx")):
|
|
|
|
|
try:
|
|
|
|
|
import docx
|
|
|
|
|
doc = docx.Document(io.BytesIO(content))
|
|
|
|
|
return "\n".join([para.text for para in doc.paragraphs])
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ValueError("python-docx is not installed. Cannot parse Word files.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ValueError(f"Failed to parse Word document: {str(e)}")
|
|
|
|
|
|
|
|
|
|
# 增加对 PPT 文档的文本提取支持
|
|
|
|
|
if lower.endswith((".ppt", ".pptx")):
|
|
|
|
|
try:
|
|
|
|
|
import pptx
|
|
|
|
|
prs = pptx.Presentation(io.BytesIO(content))
|
|
|
|
|
text = []
|
|
|
|
|
for slide in prs.slides:
|
|
|
|
|
for shape in slide.shapes:
|
|
|
|
|
if hasattr(shape, "text"):
|
|
|
|
|
text.append(shape.text)
|
|
|
|
|
return "\n".join(text)
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ValueError("python-pptx is not installed. Cannot parse PPT files.")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ValueError(f"Failed to parse PPT document: {str(e)}")
|
|
|
|
|
|
2026-03-29 00:20:53 +08:00
|
|
|
raise ValueError("Unsupported file type")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def get_knowledge_global_config():
    """Return the stored global knowledge configuration with the API key hidden.

    The raw key is never echoed back: ``api_key`` is always ``None``, and the
    response carries a masked rendering plus a flag saying whether a key is
    stored.
    """
    stored = knowledge_global_config_store.get()
    secret = stored.get("api_key")
    response = {
        "api_base": stored.get("api_base"),
        "api_key": None,
        "api_key_masked": _mask_api_key(secret),
        "has_api_key": bool(secret),
        "default_embedding_model": stored.get("default_embedding_model"),
    }
    return response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/knowledge-bases/global-config", response_model=KnowledgeGlobalConfig)
def update_knowledge_global_config(payload: KnowledgeGlobalConfigUpdate):
    """Apply the explicitly-set fields of ``payload`` to the global configuration.

    Only fields the client actually sent are forwarded to the store. The
    response mirrors the read endpoint: ``api_key`` is always ``None`` and the
    key is exposed only in masked form.
    """
    changes = payload.model_dump(exclude_unset=True)
    current = knowledge_global_config_store.update(changes)
    secret = current.get("api_key")

    body = {
        "api_base": current.get("api_base"),
        "api_key": None,
        "api_key_masked": _mask_api_key(secret),
        "has_api_key": bool(secret),
        "default_embedding_model": current.get("default_embedding_model"),
    }
    return body
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases/global-config/test-connection", response_model=KnowledgeConnectionTestResponse)
def test_knowledge_global_connection(payload: KnowledgeConnectionTestRequest):
    """Check embedding connectivity by issuing one embeddings request.

    Values supplied in the request take precedence; a missing API base or API
    key falls back to the saved global configuration. The model name must be
    supplied in the request itself.

    Returns:
        A dict reporting success, the model name, the embedding dimension
        (``None`` when the response carries no list-typed vector), the
        normalised API base, and an empty ``available_models`` list.

    Raises:
        HTTPException: 400 when the API base, API key or model name is
            missing, or when the embeddings call fails.
    """
    saved = knowledge_global_config_store.get()
    api_base = normalize_openai_base_url(payload.api_base or saved.get("api_base") or "")
    api_key = payload.api_key or saved.get("api_key")
    model_name = (payload.model_name or "").strip()

    if not api_base:
        raise HTTPException(status_code=400, detail="API Base 未配置")
    if not api_key:
        raise HTTPException(status_code=400, detail="API Key 未配置")
    if not model_name:
        raise HTTPException(status_code=400, detail="测试连接必须显式填写向量模型名称")

    try:
        client = OpenAI(
            api_key=api_key,
            base_url=api_base,
        )
        embedding_resp = client.embeddings.create(
            model=model_name,
            input="connection test",
        )
    except Exception as exc:
        # Any client or transport failure is reported to the caller as a 400;
        # the original error stays attached for server-side tracebacks.
        raise HTTPException(status_code=400, detail=f"Embedding调用失败: {exc}") from exc

    # Derive the vector size defensively: the response may lack data, and the
    # embedding may not be a plain list.
    dimension = None
    if getattr(embedding_resp, "data", None):
        first = embedding_resp.data[0]
        vector = getattr(first, "embedding", None)
        if isinstance(vector, list):
            dimension = len(vector)

    return {
        "success": True,
        "message": "连接成功,Embedding调用正常",
        "model_name": model_name,
        "embedding_dimension": dimension,
        "resolved_api_base": api_base,
        "available_models": [],
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/knowledge-bases", response_model=List[KnowledgeBase])
def list_knowledge_bases(project_id: Optional[int] = None):
    """List knowledge bases, forwarding the optional ``project_id`` to the store."""
    bases = knowledge_base_store.list(project_id=project_id)
    return bases
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases", response_model=KnowledgeBase)
def create_knowledge_base(payload: KnowledgeBaseCreate):
    """Create a knowledge base from the request payload and return the store's result."""
    fields = payload.model_dump()
    return knowledge_base_store.create(fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def get_knowledge_base(kb_id: str):
    """Return the knowledge base with the given id.

    Raises:
        HTTPException: 404 when the store returns a falsy result for ``kb_id``.
    """
    found = knowledge_base_store.get(kb_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/knowledge-bases/{kb_id}", response_model=KnowledgeBase)
def update_knowledge_base(kb_id: str, payload: KnowledgeBaseUpdate):
    """Update a knowledge base with the fields explicitly set in ``payload``.

    Raises:
        HTTPException: 404 when the store returns a falsy result for ``kb_id``.
    """
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update(kb_id, changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}")
def delete_knowledge_base(kb_id: str):
    """Delete a knowledge base and report success.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    if knowledge_base_store.delete(kb_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.get("/knowledge-bases/{kb_id}/documents", response_model=List[KnowledgeDocument])
def list_knowledge_documents(kb_id: str):
    """Return the documents stored under a knowledge base.

    A knowledge base record without a ``documents`` entry yields an empty list.

    Raises:
        HTTPException: 404 when the store returns a falsy result for ``kb_id``.
    """
    base = knowledge_base_store.get(kb_id)
    if base:
        return base.get("documents", [])
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/documents", response_model=KnowledgeDocument)
def create_knowledge_document(kb_id: str, payload: KnowledgeDocumentCreate):
    """Create a document inside the given knowledge base.

    Raises:
        HTTPException: 404 when the store returns a falsy result, which this
            endpoint reports as the knowledge base not being found.
    """
    fields = payload.model_dump()
    created = knowledge_base_store.create_document(kb_id=kb_id, payload=fields)
    if created:
        return created
    raise HTTPException(status_code=404, detail="Knowledge base not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.put("/knowledge-bases/{kb_id}/documents/{doc_id}", response_model=KnowledgeDocument)
def update_knowledge_document(kb_id: str, doc_id: str, payload: KnowledgeDocumentUpdate):
    """Update a document with the fields explicitly set in ``payload``.

    Raises:
        HTTPException: 404 when the store returns a falsy result for the
            ``kb_id``/``doc_id`` pair.
    """
    changes = payload.model_dump(exclude_unset=True)
    updated = knowledge_base_store.update_document(kb_id=kb_id, doc_id=doc_id, payload=changes)
    if updated:
        return updated
    raise HTTPException(status_code=404, detail="Knowledge document not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.delete("/knowledge-bases/{kb_id}/documents/{doc_id}")
def delete_knowledge_document(kb_id: str, doc_id: str):
    """Delete a document from a knowledge base and report success.

    Raises:
        HTTPException: 404 when the store reports that nothing was deleted.
    """
    if knowledge_base_store.delete_document(kb_id=kb_id, doc_id=doc_id):
        return {"status": "success"}
    raise HTTPException(status_code=404, detail="Knowledge document not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/reindex")
def reindex_knowledge_base(kb_id: str):
    """Rebuild the index of a knowledge base and return the service's result.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` raised
            by the index service.
    """
    try:
        outcome = knowledge_index_service.reindex(kb_id)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
    return outcome
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/search", response_model=KnowledgeSearchResponse)
def search_knowledge_base(kb_id: str, payload: KnowledgeSearchRequest):
    """Search a knowledge base with the query and ``top_k`` from the request.

    Raises:
        HTTPException: 404 carrying the message of any ``ValueError`` raised
            by the index service.
    """
    try:
        return knowledge_index_service.search(kb_id=kb_id, query=payload.query, top_k=payload.top_k)
    except ValueError as exc:
        raise HTTPException(status_code=404, detail=str(exc))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/knowledge-bases/{kb_id}/documents/upload")
async def upload_knowledge_documents(
    kb_id: str,
    files: List[UploadFile] = File(...),
    metadata: Optional[str] = Form(default=None),
):
    """Create one knowledge document per uploaded file.

    Args:
        kb_id: Identifier of the target knowledge base.
        files: Uploaded files. Files with no content are skipped.
        metadata: Optional JSON object, sent as a string, merged into the
            metadata of every created document. The ``source`` and
            ``filename`` keys are always set by this endpoint and override
            same-named keys from ``metadata``.

    Returns:
        A dict with ``status``, the number of created documents (``count``)
        and the created ``documents``.

    Raises:
        HTTPException: 404 if the knowledge base does not exist; 400 if
            ``metadata`` is not a JSON object, a file exceeds 15 MB, or text
            extraction fails for a file.
    """
    kb = knowledge_base_store.get(kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="Knowledge base not found")

    metadata_payload: dict = {}
    if metadata:
        try:
            parsed_metadata = json.loads(metadata)
        except json.JSONDecodeError as exc:
            raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象") from exc
        # JSON null is treated as "no metadata". Any other non-object value
        # is rejected instead of being silently dropped, matching the error
        # message's contract.
        if parsed_metadata is not None:
            if not isinstance(parsed_metadata, dict):
                raise HTTPException(status_code=400, detail="metadata 必须是合法 JSON 对象")
            metadata_payload = parsed_metadata

    # The size limit was relaxed from 5 MB to 15 MB to better support PDF
    # files that contain images.
    max_upload_bytes = 15 * 1024 * 1024

    # NOTE: files are processed in order; documents created before a failing
    # file are not rolled back when a later file raises.
    created: List[dict] = []
    for file in files:
        filename = file.filename or "untitled"
        content = await file.read()
        if not content:
            # Empty uploads are skipped rather than rejected.
            continue
        if len(content) > max_upload_bytes:
            raise HTTPException(status_code=400, detail=f"文件过大 (超过 15MB): {filename}")

        try:
            text = _extract_upload_text(filename, content)
        except Exception as exc:
            # Every extraction failure is reported to the client as an
            # unsupported file; the real cause stays chained for server logs.
            raise HTTPException(status_code=400, detail=f"不支持的文件类型: {filename}") from exc

        doc = knowledge_base_store.create_document(
            kb_id=kb_id,
            payload={
                "title": filename,
                "content": text,
                "metadata": {**metadata_payload, "source": "upload", "filename": filename},
            },
        )
        if doc:
            created.append(doc)

    return {"status": "success", "count": len(created), "documents": created}
|