377 lines
13 KiB
Python
377 lines
13 KiB
Python
|
|
import asyncio
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Directories located relative to this test module's own file.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"

# Prepend each root to sys.path unless it is already listed; the later insert
# (NANOBOT_ROOT) therefore ends up ahead of BACKEND_ROOT.
for _import_root in (str(BACKEND_ROOT), str(NANOBOT_ROOT)):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)
del _import_root
|
||
|
|
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
|
||
|
|
import main
|
||
|
|
from app.context import current_knowledge_base_id
|
||
|
|
from app.schemas.knowledge import KnowledgeSearchResponse
|
||
|
|
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_base_crud_and_search_routes(monkeypatch, tmp_path) -> None:
    """Walk one knowledge base through create, list, add-document, reindex, search, rename and delete."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    # Point DATA_ROOT at pytest's temporary directory and swap in the no-op hooks.
    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)

    api = TestClient(main.app)

    # Create a knowledge base and remember its id for the remaining calls.
    new_kb_payload = {
        "name": "产品手册",
        "description": "用于问答",
        "top_k": 2,
        "chunk_size": 256,
        "chunk_overlap": 20,
    }
    created = api.post("/api/v1/knowledge-bases", json=new_kb_payload)
    assert created.status_code == 200
    kb_id = created.json()["id"]

    # The new knowledge base must show up in the listing.
    listing = api.get("/api/v1/knowledge-bases")
    assert listing.status_code == 200
    assert any(entry["id"] == kb_id for entry in listing.json())

    # Add one document to it.
    new_doc_payload = {
        "title": "退款规则",
        "content": "苹果手机支持7天无理由退款",
        "metadata": {"lang": "zh"},
    }
    added_doc = api.post(f"/api/v1/knowledge-bases/{kb_id}/documents", json=new_doc_payload)
    assert added_doc.status_code == 200
    doc_id = added_doc.json()["id"]

    reindexed = api.post(f"/api/v1/knowledge-bases/{kb_id}/reindex")
    assert reindexed.status_code == 200

    # The search response must fit the schema, contain hits, and mention the query term.
    searched = api.post(
        f"/api/v1/knowledge-bases/{kb_id}/search",
        json={"query": "苹果退款", "top_k": 2},
    )
    assert searched.status_code == 200
    search_result = KnowledgeSearchResponse(**searched.json())
    assert search_result.hits
    assert "苹果" in search_result.answer

    # Rename the knowledge base.
    renamed = api.put(f"/api/v1/knowledge-bases/{kb_id}", json={"name": "售后知识库"})
    assert renamed.status_code == 200
    assert renamed.json()["name"] == "售后知识库"

    # Delete the document first, then the knowledge base itself.
    removed_doc = api.delete(f"/api/v1/knowledge-bases/{kb_id}/documents/{doc_id}")
    assert removed_doc.status_code == 200
    removed_kb = api.delete(f"/api/v1/knowledge-bases/{kb_id}")
    assert removed_kb.status_code == 200
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_global_config_mask_and_validation(monkeypatch, tmp_path) -> None:
    """Check the global-config route: empty defaults, masked API key after saving, and 422 on a bad URL."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)
    api = TestClient(main.app)

    # Before anything is saved, every field is empty.
    expected_empty_config = {
        "api_base": None,
        "api_key": None,
        "api_key_masked": None,
        "has_api_key": False,
        "default_embedding_model": None,
    }
    first_read = api.get("/api/v1/knowledge-bases/global-config")
    assert first_read.status_code == 200
    assert first_read.json() == expected_empty_config

    # Save a config; the response is expected to drop the trailing slash and mask the key.
    saved = api.put(
        "/api/v1/knowledge-bases/global-config",
        json={
            "api_base": "https://kb.example.com/",
            "api_key": "sk-knowledge-secret",
            "default_embedding_model": "text-embedding-3-small",
        },
    )
    assert saved.status_code == 200
    saved_body = saved.json()
    assert saved_body["api_base"] == "https://kb.example.com"
    assert saved_body["api_key"] is None
    assert saved_body["has_api_key"] is True
    assert saved_body["api_key_masked"] == "sk-k***********cret"
    assert saved_body["default_embedding_model"] == "text-embedding-3-small"

    # A later read returns the same masked view.
    second_read = api.get("/api/v1/knowledge-bases/global-config")
    assert second_read.status_code == 200
    reread_body = second_read.json()
    assert reread_body["api_key"] is None
    assert reread_body["api_key_masked"] == "sk-k***********cret"
    assert reread_body["default_embedding_model"] == "text-embedding-3-small"

    # An ftp:// api_base is rejected with HTTP 422.
    rejected = api.put(
        "/api/v1/knowledge-bases/global-config",
        json={"api_base": "ftp://kb.example.com"},
    )
    assert rejected.status_code == 422
|
||
|
|
|
||
|
|
|
||
|
|
def test_chat_request_syncs_knowledge_base_metadata(monkeypatch) -> None:
    """Call ``main.nanobot_chat`` with a knowledge base selected and check what it propagates.

    Checked: the response carries the citations returned by the stubbed search,
    the context variable holds the knowledge-base id while the message is
    processed, the processed message contains the retrieved chunk, and the
    session metadata records the selected knowledge base.
    """
    seen_kb_ids: list[str | None] = []
    seen_messages: list[str] = []

    class _DummySession:
        """Bare session object exposing only the attributes set below."""

        def __init__(self):
            self.metadata = {}
            self.messages = []
            self.updated_at = None

    class _DummySessions:
        """In-memory session store keyed by session id."""

        def __init__(self):
            self._sessions: dict[str, _DummySession] = {}

        def get_or_create(self, key: str):
            existing = self._sessions.get(key)
            if existing is None:
                existing = self._sessions[key] = _DummySession()
            return existing

        def save(self, _session):
            return None

    class _DummyAgent:
        """Agent stand-in that only provides a ``sessions`` store."""

        def __init__(self):
            self.sessions = _DummySessions()

    async def fake_process_message(*args, **kwargs):
        # Record the context variable's value at call time, plus the first
        # positional argument when it is a string.
        seen_kb_ids.append(current_knowledge_base_id.get())
        if args and isinstance(args[0], str):
            seen_messages.append(args[0])
        return "ok"

    def fake_search(*, kb_id: str, query: str, top_k=None):
        # The search must be called with the request's knowledge base and message.
        assert kb_id == "kb-123"
        assert query == "请回答售后规则"
        hits = [
            {"doc_id": "d1", "title": "退款规则", "chunk": "7天无理由退款", "score": 0.9, "metadata": {}},
            {"doc_id": "d2", "title": "售后电话", "chunk": "客服电话 400-1234", "score": 0.7, "metadata": {}},
        ]
        return {"answer": "命中结果", "hits": hits}

    monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
    monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
    monkeypatch.setattr("main.knowledge_index_service.search", fake_search)

    chat_request = main.ChatRequest(
        message="请回答售后规则",
        session_id="api:kb-1",
        knowledge_base_id="kb-123",
    )
    result = asyncio.run(main.nanobot_chat(chat_request))

    # The response echoes the stubbed reply and both citations.
    assert result["response"] == "ok"
    assert "kb_citations" in result
    assert len(result["kb_citations"]) == 2
    assert result["kb_citations"][0]["title"] == "退款规则"

    # The knowledge-base id and the retrieved chunk reached message processing.
    assert seen_kb_ids == ["kb-123"]
    assert seen_messages
    assert "7天无理由退款" in seen_messages[0]

    # The session metadata stores the selection.
    stored_session = main.nanobot_service.agent.sessions.get_or_create("api:kb-1")
    assert stored_session.metadata["selected_knowledge_base_id"] == "kb-123"
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_tool_uses_session_context(monkeypatch) -> None:
    """Check that ``KnowledgeBaseRetrieveTool`` takes its knowledge-base id from the context variable.

    The tool is executed with only a ``query``; the stubbed search records the
    arguments it receives so the test can assert that ``kb_id`` equals the
    value stored in ``current_knowledge_base_id``.
    """
    tool = KnowledgeBaseRetrieveTool()
    called: list[dict] = []

    def fake_search(*, kb_id: str, query: str, top_k=None):
        # Record every call so the assertions below can inspect the arguments.
        called.append({"kb_id": kb_id, "query": query, "top_k": top_k})
        return {
            "answer": "命中结果",
            "hits": [{"doc_id": "d1", "title": "t1", "chunk": "命中结果", "score": 1.0, "metadata": {}}],
        }

    monkeypatch.setattr("app.tools.knowledge_base.knowledge_index_service.search", fake_search)

    # Set the context variable immediately before the try block, so nothing
    # that can raise (such as the monkeypatch above) runs between set() and the
    # guaranteed reset(); otherwise "kb-session" could leak into later tests.
    token = current_knowledge_base_id.set("kb-session")
    try:
        output = asyncio.run(tool.execute(query="售后政策"))
    finally:
        current_knowledge_base_id.reset(token)

    assert called and called[0]["kb_id"] == "kb-session"
    assert "命中结果" in output
|
||
|
|
|
||
|
|
|
||
|
|
def test_update_session_context_file_supports_knowledge_base(monkeypatch) -> None:
    """Check that ``main.update_session_context_file`` returns the selected knowledge-base id in its metadata."""

    class _DummySession:
        """Bare session object with empty metadata."""

        def __init__(self):
            self.metadata = {}
            self.updated_at = None

    class _DummySessions:
        """Session store that always hands back the same single session."""

        def __init__(self):
            self.session = _DummySession()

        def get_or_create(self, _key: str):
            return self.session

        def save(self, _session):
            return None

    class _DummyAgent:
        """Agent stand-in that only provides a ``sessions`` store."""

        def __init__(self):
            self.sessions = _DummySessions()

    monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())

    update_request = main.SessionFileContextUpdateRequest(selected_knowledge_base_id="kb-ctx")
    result = main.update_session_context_file("api:ctx", update_request)

    assert result["status"] == "success"
    assert result["metadata"]["selected_knowledge_base_id"] == "kb-ctx"
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_global_connection_test_route(monkeypatch, tmp_path) -> None:
    """Save a global config, then check the test-connection route reports success using a stubbed OpenAI client."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    class _DummyEmbeddingData:
        # Three components, matching the embedding_dimension asserted below.
        embedding = [0.1, 0.2, 0.3]

    class _DummyEmbeddingResp:
        data = [_DummyEmbeddingData()]

    class _DummyEmbeddingsAPI:
        @staticmethod
        def create(model: str, input: str):
            # Parameter names are kept as-is: the caller may pass them by keyword.
            assert model == "text-embedding-3-small"
            assert input == "connection test"
            return _DummyEmbeddingResp()

    class _DummyOpenAI:
        """Client stand-in that verifies the credentials it is constructed with."""

        def __init__(self, api_key: str, base_url: str):
            assert api_key == "sk-knowledge-secret"
            assert base_url == "https://kb.example.com"
            self.embeddings = _DummyEmbeddingsAPI()

    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)
    monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)

    api = TestClient(main.app)

    # Store the credentials that the connection test is then expected to use.
    stored_config = {
        "api_base": "https://kb.example.com",
        "api_key": "sk-knowledge-secret",
        "default_embedding_model": "text-embedding-3-small",
    }
    saved = api.put("/api/v1/knowledge-bases/global-config", json=stored_config)
    assert saved.status_code == 200

    # Only the model name is sent with the connection test itself.
    probe = api.post(
        "/api/v1/knowledge-bases/global-config/test-connection",
        json={"model_name": "text-embedding-3-small"},
    )
    assert probe.status_code == 200
    report = probe.json()
    assert report["success"] is True
    assert report["model_name"] == "text-embedding-3-small"
    assert report["embedding_dimension"] == 3
    assert report["resolved_api_base"] == "https://kb.example.com"
    assert report["available_models"] == []
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_global_connection_test_route_requires_model_name(monkeypatch, tmp_path) -> None:
    """Check the test-connection route answers HTTP 400 when the request has no ``model_name``."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)

    api = TestClient(main.app)

    # The request carries an API base and key but deliberately no model_name.
    request_without_model = {
        "api_base": "https://api.siliconflow.cn/v1/embeddings",
        "api_key": "ark-key",
    }
    rejected = api.post(
        "/api/v1/knowledge-bases/global-config/test-connection",
        json=request_without_model,
    )

    assert rejected.status_code == 400
    assert "测试连接必须显式填写向量模型名称" in rejected.json()["detail"]
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_global_connection_test_route_returns_remote_error(monkeypatch, tmp_path) -> None:
    """Check the test-connection route answers HTTP 400 when the stubbed embeddings call raises."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    class _DummyEmbeddingsAPI:
        @staticmethod
        def create(model: str, input: str):
            # Parameter names are kept as-is: the caller may pass them by keyword.
            assert model == "BAAI/bge-large-zh-v1.5"
            assert input == "connection test"
            raise RuntimeError("Not Found")

    class _DummyOpenAI:
        """Client stand-in; the base URL is expected without the ``/embeddings`` suffix that the request sends."""

        def __init__(self, api_key: str, base_url: str):
            assert api_key == "sf-key"
            assert base_url == "https://api.siliconflow.cn/v1"
            self.embeddings = _DummyEmbeddingsAPI()

    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)
    monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)

    api = TestClient(main.app)

    probe_request = {
        "api_base": "https://api.siliconflow.cn/v1/embeddings",
        "api_key": "sf-key",
        "model_name": "BAAI/bge-large-zh-v1.5",
    }
    failed = api.post(
        "/api/v1/knowledge-bases/global-config/test-connection",
        json=probe_request,
    )

    assert failed.status_code == 400
    assert "Embedding调用失败" in failed.json()["detail"]
|
||
|
|
|
||
|
|
|
||
|
|
def test_knowledge_document_upload_route(monkeypatch, tmp_path) -> None:
    """Upload two files to a new knowledge base in one request and check both are listed afterwards."""

    async def _noop_start():
        """Replacement for ``main.nanobot_service.start`` that does nothing."""

    async def _noop_stop():
        """Replacement for ``main.nanobot_service.stop`` that does nothing."""

    monkeypatch.setenv("DATA_ROOT", str(tmp_path))
    monkeypatch.setattr(main.nanobot_service, "start", _noop_start)
    monkeypatch.setattr(main.nanobot_service, "stop", _noop_stop)

    api = TestClient(main.app)

    # Create the knowledge base that receives the upload.
    new_kb_payload = {
        "name": "上传测试库",
        "description": "用于上传",
        "top_k": 2,
        "chunk_size": 256,
        "chunk_overlap": 20,
    }
    created = api.post("/api/v1/knowledge-bases", json=new_kb_payload)
    assert created.status_code == 200
    kb_id = created.json()["id"]

    # Two multipart parts under the same "files" field, plus a JSON string in the form data.
    upload_parts = [
        ("files", ("doc1.txt", b"hello knowledge", "text/plain")),
        ("files", ("doc2.md", b"# title\ncontent", "text/markdown")),
    ]
    uploaded = api.post(
        f"/api/v1/knowledge-bases/{kb_id}/documents/upload",
        files=upload_parts,
        data={"metadata": "{\"source\":\"batch\"}"},
    )
    assert uploaded.status_code == 200
    upload_report = uploaded.json()
    assert upload_report["status"] == "success"
    assert upload_report["count"] == 2
    assert len(upload_report["documents"]) == 2

    # Both uploaded documents appear in the document listing.
    listing = api.get(f"/api/v1/knowledge-bases/{kb_id}/documents")
    assert listing.status_code == 200
    listed_docs = listing.json()
    assert len(listed_docs) == 2
|