Update 2026-05-13 16:43:53
This commit is contained in:
@@ -0,0 +1,872 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import Base, get_db
|
||||
from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode
|
||||
from app.services.a2a_service import a2a_runtime, SharedSecretAuth
|
||||
from main import app
|
||||
|
||||
|
||||
def _seed(db: Session) -> tuple[int, str, int, str, int]:
|
||||
owner = User(username="a2a_owner", email="a2a_owner@example.com", hashed_password="x", is_admin=False)
|
||||
other = User(username="a2a_other", email="a2a_other@example.com", hashed_password="x", is_admin=False)
|
||||
db.add(owner)
|
||||
db.add(other)
|
||||
db.commit()
|
||||
db.refresh(owner)
|
||||
db.refresh(other)
|
||||
project = Project(name="a2a_project", description="a2a", owner_id=owner.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return owner.id, owner.username, other.id, other.username, project.id
|
||||
|
||||
|
||||
def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"message": {
|
||||
"messageId": f"msg-{int(time.time()*1000)}",
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"part_type": "data",
|
||||
"data": {
|
||||
"project_id": project_id,
|
||||
"route_mode": route_mode,
|
||||
"session_id": session_id,
|
||||
**( {"idempotency_key": idempotency_key} if idempotency_key else {} )
|
||||
},
|
||||
"mediaType": "application/json",
|
||||
},
|
||||
{
|
||||
"part_type": "text",
|
||||
"text": text,
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
class TestPartSerialization:
|
||||
def test_part_text_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.TEXT,
|
||||
text="Hello world",
|
||||
mediaType="text/plain",
|
||||
)
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "text"
|
||||
assert data["text"] == "Hello world"
|
||||
assert data["mediaType"] == "text/plain"
|
||||
|
||||
def test_part_data_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.DATA,
|
||||
data={"project_id": 123, "route_mode": "local"},
|
||||
mediaType="application/json",
|
||||
)
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "data"
|
||||
assert data["data"]["project_id"] == 123
|
||||
|
||||
def test_part_url_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.URL,
|
||||
url="https://example.com/file.pdf",
|
||||
mediaType="application/pdf",
|
||||
filename="file.pdf",
|
||||
)
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "url"
|
||||
assert data["url"] == "https://example.com/file.pdf"
|
||||
assert data["filename"] == "file.pdf"
|
||||
|
||||
def test_part_raw_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.RAW,
|
||||
raw="\x00\x01\x02\x03",
|
||||
mediaType="application/octet-stream",
|
||||
)
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "raw"
|
||||
assert data["raw"] == "\x00\x01\x02\x03"
|
||||
|
||||
|
||||
class TestStateMachine:
|
||||
def test_state_transitions_submit_to_complete(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "test state machine"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
assert task.state == A2ATaskState.SUBMITTED
|
||||
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING)
|
||||
assert task.state == A2ATaskState.WORKING
|
||||
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.COMPLETED)
|
||||
assert task.state == A2ATaskState.COMPLETED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_cancel(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["state"] == "CANCELED"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_failed(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "fail test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message='{"message": "test error"}')
|
||||
assert task.state == A2ATaskState.FAILED
|
||||
assert task.error_message is not None
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_rejected(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "reject test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.REJECTED)
|
||||
assert task.state == A2ATaskState.REJECTED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_input_required(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "input required test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.INPUT_REQUIRED)
|
||||
assert task.state == A2ATaskState.INPUT_REQUIRED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_auth_required(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "auth required test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.AUTH_REQUIRED)
|
||||
assert task.state == A2ATaskState.AUTH_REQUIRED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestA2APathNormalization:
|
||||
def test_message_send_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "path test"))
|
||||
assert resp.status_code == 200
|
||||
assert "task" in resp.json()
|
||||
assert "id" in resp.json()["task"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_message_stream_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:stream", json=_make_message_payload(project_id, "stream path test"))
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_tasks_cancel_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel path test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["task_id"] == task_id
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_agent_card_public_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
db.close()
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/v1/.well-known/agent-card.json")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "name" in data
|
||||
assert "protocol_version" in data
|
||||
assert "endpoints" in data
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestVersionNegotiation:
|
||||
def test_version_not_supported_error(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post(
|
||||
"/api/v1/message:send",
|
||||
json=_make_message_payload(project_id, "version test"),
|
||||
headers={"A2A-Version": "2.0"}
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
detail = json.loads(resp.json()["detail"])
|
||||
assert detail["code"] == -32009
|
||||
assert "not supported" in detail["message"].lower()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_version_response_header(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post(
|
||||
"/api/v1/message:send",
|
||||
json=_make_message_payload(project_id, "version header test"),
|
||||
headers={"A2A-Version": "1.0"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers.get("A2A-Version") == "1.0"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestWebhookStreamResponse:
|
||||
def test_webhook_payload_format(self):
|
||||
from app.schemas.a2a import StreamResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent, TaskMessageEvent, A2ATaskStatusSchema, A2ATaskState, A2AArtifactSchema
|
||||
|
||||
status_event = TaskStatusUpdateEvent(
|
||||
taskId="task-123",
|
||||
contextId="ctx-456",
|
||||
status=A2ATaskStatusSchema(
|
||||
state=A2ATaskState.SUBMITTED,
|
||||
timestamp=datetime.utcnow(),
|
||||
),
|
||||
metadata={},
|
||||
)
|
||||
status_dump = status_event.model_dump()
|
||||
assert "taskId" in status_dump
|
||||
assert status_dump["taskId"] == "task-123"
|
||||
assert status_dump["status"]["state"] == "SUBMITTED"
|
||||
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
taskId="task-123",
|
||||
contextId="ctx-456",
|
||||
artifact=A2AArtifactSchema(
|
||||
artifactId="art-789",
|
||||
parts=[],
|
||||
),
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
)
|
||||
artifact_dump = artifact_event.model_dump()
|
||||
assert "taskId" in artifact_dump
|
||||
assert artifact_dump["artifact"]["artifactId"] == "art-789"
|
||||
|
||||
def test_stream_response_task_field(self):
|
||||
from app.schemas.a2a import StreamResponse, StreamResponseTask, A2ATaskState
|
||||
resp = StreamResponse(
|
||||
task=StreamResponseTask(
|
||||
id="task-123",
|
||||
contextId="ctx-456",
|
||||
state=A2ATaskState.WORKING,
|
||||
artifacts=[],
|
||||
)
|
||||
)
|
||||
data = resp.model_dump()
|
||||
assert "task" in data
|
||||
assert data["task"]["id"] == "task-123"
|
||||
|
||||
|
||||
class TestSSEFIFO:
|
||||
def test_sse_event_order(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
with client.stream("POST", "/api/v1/message:stream", json=_make_message_payload(project_id, "fifo test")) as resp:
|
||||
assert resp.status_code == 200
|
||||
chunks = []
|
||||
for line in resp.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
chunks.append(json.loads(line[6:]))
|
||||
|
||||
event_types = [c.get("type") for c in chunks if "type" in c]
|
||||
status_idx = next((i for i, t in enumerate(event_types) if t == "TaskStatusUpdateEvent"), -1)
|
||||
assert status_idx >= 0
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestAuthSchemes:
|
||||
def test_shared_secret_auth(self):
|
||||
secret = "test-secret-key-12345"
|
||||
timestamp = int(time.time())
|
||||
body = b'{"test":"data"}'
|
||||
sig, _ = SharedSecretAuth.generate_signature(secret, body, timestamp)
|
||||
assert sig.startswith("sha256=")
|
||||
|
||||
assert SharedSecretAuth.verify_signature(secret, body, sig, timestamp) is True
|
||||
|
||||
def test_auth_scheme_none(self):
|
||||
from app.schemas.a2a import SecuritySchemeHttpAuth
|
||||
scheme = SecuritySchemeHttpAuth(scheme="bearer", description="Bearer auth")
|
||||
assert scheme.scheme == "bearer"
|
||||
|
||||
|
||||
class TestExceptionPaths:
|
||||
def test_auth_failure_marks_agent_unhealthy(self, monkeypatch):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, _, _, _, project_id = _seed(db)
|
||||
agent = A2ARemoteAgent(
|
||||
project_id=project_id,
|
||||
name="auth-fail-agent",
|
||||
base_url="https://remote.example.com",
|
||||
auth_scheme="bearer",
|
||||
auth_token="bad-token",
|
||||
created_by=owner_id,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
a2a_runtime._circuit_state.pop(agent.id, None)
|
||||
|
||||
class _FailResp:
|
||||
status_code = 401
|
||||
|
||||
@staticmethod
|
||||
def json():
|
||||
return {"detail": "unauthorized"}
|
||||
|
||||
class _Client401:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, headers=None):
|
||||
return _FailResp()
|
||||
|
||||
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
|
||||
db.refresh(agent)
|
||||
assert agent.healthy is False
|
||||
assert agent.failure_count == 1
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
db.close()
|
||||
engine.dispose()
|
||||
|
||||
def test_remote_unavailable_opens_circuit(self, monkeypatch):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, _, _, _, project_id = _seed(db)
|
||||
agent = A2ARemoteAgent(
|
||||
project_id=project_id,
|
||||
name="offline-agent",
|
||||
base_url="https://offline.example.com",
|
||||
auth_scheme="none",
|
||||
created_by=owner_id,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
a2a_runtime._circuit_state.pop(agent.id, None)
|
||||
|
||||
class _ClientDown:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, headers=None):
|
||||
raise httpx.ConnectError("network down")
|
||||
|
||||
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown)
|
||||
|
||||
for _ in range(3):
|
||||
with pytest.raises(Exception):
|
||||
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
|
||||
db.refresh(agent)
|
||||
assert agent.healthy is False
|
||||
assert agent.failure_count == 3
|
||||
assert agent.circuit_open_until is not None
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
db.close()
|
||||
engine.dispose()
|
||||
|
||||
def test_idempotency_key_deduplicates_task(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
idempotency_key = f"idem-key-{int(time.time())}"
|
||||
|
||||
payload1 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
|
||||
resp1 = client.post("/api/v1/message:send", json=payload1)
|
||||
assert resp1.status_code == 200
|
||||
|
||||
payload2 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
|
||||
payload2["message"]["messageId"] = f"msg-{int(time.time()*1000) + 1}"
|
||||
resp2 = client.post("/api/v1/message:send", json=payload2)
|
||||
|
||||
assert resp2.status_code == 200
|
||||
assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_tenant_isolation(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "isolation test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
|
||||
get_resp = client.get(f"/api/v1/tasks/{task_id}")
|
||||
assert get_resp.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestMetricsAdminOnly:
|
||||
def test_metrics_admin_only(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, _ = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
denied = client.get("/api/v1/a2a/metrics")
|
||||
assert denied.status_code == 403
|
||||
|
||||
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
|
||||
ok = client.get("/api/v1/a2a/metrics")
|
||||
assert ok.status_code == 200
|
||||
assert "counters" in ok.json()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
@@ -0,0 +1,132 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.core.data_root import get_data_root
|
||||
from main import app
|
||||
|
||||
|
||||
def _backend_data_root() -> Path:
|
||||
return get_data_root()
|
||||
|
||||
|
||||
def test_download_artifact_within_whitelist() -> None:
|
||||
uploads_dir = _backend_data_root() / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
sample = uploads_dir / "task2-download.csv"
|
||||
sample.write_text("id,name\n1,a\n", encoding="utf-8")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/nanobot/artifacts/download", params={"target": "local://task2-download.csv"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"].startswith("application/octet-stream")
|
||||
assert response.headers["content-disposition"].startswith("attachment;")
|
||||
assert response.content == sample.read_bytes()
|
||||
|
||||
|
||||
def test_download_artifact_rejects_outside_paths() -> None:
|
||||
client = TestClient(app)
|
||||
response = client.get("/nanobot/artifacts/download", params={"target": "/etc/hosts"})
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "非法路径访问"
|
||||
|
||||
|
||||
def test_preview_artifact_returns_unsupported_for_binary() -> None:
|
||||
uploads_dir = _backend_data_root() / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
sample = uploads_dir / "task2-unsupported.bin"
|
||||
sample.write_bytes(b"\x00\x01\x02")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/nanobot/artifacts/preview", params={"target": f"local://{sample.name}"})
|
||||
|
||||
assert response.status_code == 415
|
||||
assert response.json()["detail"] == "当前文件类型不支持预览,请使用下载"
|
||||
download = client.get("/nanobot/artifacts/download", params={"target": f"local://{sample.name}"})
|
||||
assert download.status_code == 200
|
||||
assert download.content == sample.read_bytes()
|
||||
|
||||
|
||||
def test_preview_html_supports_directory_resources() -> None:
|
||||
web_dir = _backend_data_root() / "workspace" / "task2-web"
|
||||
web_dir.mkdir(parents=True, exist_ok=True)
|
||||
html_file = web_dir / "index.html"
|
||||
css_file = web_dir / "styles.css"
|
||||
html_file.write_text("<html><head><link rel='stylesheet' href='styles.css'></head><body>ok</body></html>", encoding="utf-8")
|
||||
css_file.write_text("body{color:#333;}", encoding="utf-8")
|
||||
|
||||
client = TestClient(app)
|
||||
preview = client.get(
|
||||
"/nanobot/artifacts/preview",
|
||||
params={"target": str(html_file)},
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert preview.status_code == 307
|
||||
location = preview.headers["location"]
|
||||
assert location.startswith("/nanobot/artifacts/web/")
|
||||
|
||||
html_response = client.get(location)
|
||||
assert html_response.status_code == 200
|
||||
assert "text/html" in html_response.headers["content-type"]
|
||||
assert "styles.css" in html_response.text
|
||||
|
||||
css_response = client.get(location.replace("index.html", "styles.css"))
|
||||
assert css_response.status_code == 200
|
||||
assert "text/css" in css_response.headers["content-type"]
|
||||
assert "color:#333" in css_response.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("filename", "payload", "expected_mime"),
|
||||
[
|
||||
("task4-image.png", b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR", "image/png"),
|
||||
("task4-preview.pdf", b"%PDF-1.4\n1 0 obj\n<<>>\nendobj\n", "application/pdf"),
|
||||
(
|
||||
"task4-preview.pptx",
|
||||
b"PK\x03\x04\x14\x00\x00\x00\x08\x00\x00\x00!\x00",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_preview_and_download_supported_types(filename: str, payload: bytes, expected_mime: str) -> None:
|
||||
uploads_dir = _backend_data_root() / "uploads"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
sample = uploads_dir / filename
|
||||
sample.write_bytes(payload)
|
||||
|
||||
client = TestClient(app)
|
||||
preview = client.get("/nanobot/artifacts/preview", params={"target": f"local://{filename}"})
|
||||
assert preview.status_code == 200
|
||||
assert preview.headers["content-type"].startswith(expected_mime)
|
||||
|
||||
download = client.get("/nanobot/artifacts/download", params={"target": f"local://{filename}"})
|
||||
assert download.status_code == 200
|
||||
assert download.content == sample.read_bytes()
|
||||
|
||||
|
||||
def test_web_preview_missing_resource_returns_error_and_download_still_works() -> None:
|
||||
web_dir = _backend_data_root() / "workspace" / "task4-web-missing"
|
||||
web_dir.mkdir(parents=True, exist_ok=True)
|
||||
html_file = web_dir / "index.html"
|
||||
html_file.write_text("<html><head><script src='missing.js'></script></head><body>ok</body></html>", encoding="utf-8")
|
||||
|
||||
client = TestClient(app)
|
||||
preview = client.get(
|
||||
"/nanobot/artifacts/preview",
|
||||
params={"target": str(html_file)},
|
||||
follow_redirects=False,
|
||||
)
|
||||
assert preview.status_code == 307
|
||||
location = preview.headers["location"]
|
||||
|
||||
missing = client.get(location.replace("index.html", "missing.js"))
|
||||
assert missing.status_code == 404
|
||||
assert missing.json()["detail"] == "Web 资源不存在"
|
||||
|
||||
download = client.get("/nanobot/artifacts/download", params={"target": str(html_file)})
|
||||
assert download.status_code == 200
|
||||
assert download.content == html_file.read_bytes()
|
||||
@@ -0,0 +1,56 @@
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.artifacts import extract_artifacts
|
||||
from app.core.data_root import get_data_root
|
||||
|
||||
|
||||
def _backend_data_root() -> Path:
|
||||
return get_data_root()
|
||||
|
||||
|
||||
def test_extract_artifacts_from_local_and_tool_paths() -> None:
|
||||
data_root = _backend_data_root()
|
||||
uploads_dir = data_root / "uploads"
|
||||
workspace_dir = data_root / "workspace" / "reports"
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
upload_file = uploads_dir / "task1-sample.csv"
|
||||
upload_file.write_text("a,b\n1,2\n", encoding="utf-8")
|
||||
report_file = workspace_dir / "task1-report.html"
|
||||
report_file.write_text("<html><body>ok</body></html>", encoding="utf-8")
|
||||
|
||||
content = "请下载 local://task1-sample.csv"
|
||||
session_messages = [
|
||||
{"role": "user", "content": "生成报告"},
|
||||
{"role": "tool", "content": f"输出文件:{report_file}"},
|
||||
]
|
||||
|
||||
artifacts = extract_artifacts(content, session_messages)
|
||||
|
||||
by_name = {item["name"]: item for item in artifacts}
|
||||
assert "task1-sample.csv" in by_name
|
||||
assert "task1-report.html" in by_name
|
||||
assert by_name["task1-sample.csv"]["download_url"].startswith("/nanobot/artifacts/download?target=")
|
||||
assert by_name["task1-sample.csv"]["previewable"] is True
|
||||
assert by_name["task1-report.html"]["previewable"] is True
|
||||
assert by_name["task1-report.html"]["preview_url"].startswith("/nanobot/artifacts/preview?target=")
|
||||
|
||||
|
||||
def test_extract_artifacts_deduplicate_and_skip_missing() -> None:
|
||||
data_root = _backend_data_root()
|
||||
workspace_dir = data_root / "workspace"
|
||||
workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pdf_file = workspace_dir / "task1-dedup.pdf"
|
||||
pdf_file.write_bytes(b"%PDF-1.4 test")
|
||||
missing_file = workspace_dir / "task1-missing.pdf"
|
||||
|
||||
content = f"{pdf_file} and {pdf_file} and {missing_file}"
|
||||
artifacts = extract_artifacts(content, [])
|
||||
|
||||
assert len(artifacts) == 1
|
||||
item = artifacts[0]
|
||||
assert item["name"] == "task1-dedup.pdf"
|
||||
assert item["mime_type"] == "application/pdf"
|
||||
assert item["previewable"] is True
|
||||
@@ -0,0 +1,145 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
import main
|
||||
|
||||
|
||||
def test_nanobot_chat_syncs_project_id(monkeypatch) -> None:
|
||||
calls: list[dict[str, object]] = []
|
||||
process_kwargs: list[dict[str, object]] = []
|
||||
|
||||
def fake_update_alias_meta(**kwargs):
|
||||
calls.append(kwargs)
|
||||
return kwargs
|
||||
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
process_kwargs.append(kwargs)
|
||||
return "ok"
|
||||
|
||||
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", None)
|
||||
|
||||
request = main.ChatRequest(message="hello", session_id="api:test-1", project_id=101)
|
||||
response = asyncio.run(main.nanobot_chat(request))
|
||||
|
||||
assert response["response"] == "ok"
|
||||
assert calls == [{"session_key": "api:test-1", "project_id": 101}]
|
||||
assert process_kwargs and process_kwargs[0]["project_id"] == 101
|
||||
|
||||
|
||||
def test_nanobot_chat_without_project_id_does_not_sync(monkeypatch) -> None:
|
||||
calls: list[dict[str, object]] = []
|
||||
process_kwargs: list[dict[str, object]] = []
|
||||
|
||||
def fake_update_alias_meta(**kwargs):
|
||||
calls.append(kwargs)
|
||||
return kwargs
|
||||
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
process_kwargs.append(kwargs)
|
||||
return "ok"
|
||||
|
||||
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", None)
|
||||
|
||||
request = main.ChatRequest(message="hello", session_id="api:test-2")
|
||||
response = asyncio.run(main.nanobot_chat(request))
|
||||
|
||||
assert response["response"] == "ok"
|
||||
assert calls == []
|
||||
assert process_kwargs and process_kwargs[0]["project_id"] is None
|
||||
|
||||
|
||||
def test_nanobot_chat_stream_syncs_project_id(monkeypatch) -> None:
|
||||
calls: list[dict[str, object]] = []
|
||||
process_kwargs: list[dict[str, object]] = []
|
||||
|
||||
def fake_update_alias_meta(**kwargs):
|
||||
calls.append(kwargs)
|
||||
return kwargs
|
||||
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
process_kwargs.append(kwargs)
|
||||
on_stream = kwargs.get("on_stream")
|
||||
if on_stream:
|
||||
await on_stream("stream-token")
|
||||
return "stream-complete"
|
||||
|
||||
async def collect_stream_chunks(response) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
if isinstance(chunk, bytes):
|
||||
chunks.append(chunk.decode("utf-8"))
|
||||
else:
|
||||
chunks.append(chunk)
|
||||
return chunks
|
||||
|
||||
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", None)
|
||||
|
||||
request = main.ChatRequest(message="hello", session_id="api:test-3", project_id=202)
|
||||
response = asyncio.run(main.nanobot_chat_stream(request))
|
||||
chunks = asyncio.run(collect_stream_chunks(response))
|
||||
content = "".join(chunks)
|
||||
|
||||
assert "stream-token" in content
|
||||
assert "stream-complete" in content
|
||||
assert calls == [{"session_key": "api:test-3", "project_id": 202}]
|
||||
assert process_kwargs and process_kwargs[0]["project_id"] == 202
|
||||
|
||||
|
||||
def test_nanobot_chat_stream_emits_reasoning_flags_and_final_reasoning(monkeypatch) -> None:
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
on_progress = kwargs.get("on_progress")
|
||||
on_stream = kwargs.get("on_stream")
|
||||
if on_progress:
|
||||
await on_progress("模型正在拆解问题", is_reasoning=True)
|
||||
await on_progress("开始执行工具", tool_hint=True)
|
||||
if on_stream:
|
||||
await on_stream("answer-token")
|
||||
return "final-answer"
|
||||
|
||||
class _DummySession:
|
||||
def __init__(self):
|
||||
self.metadata = {}
|
||||
self.messages = [
|
||||
{"role": "assistant", "content": "final-answer", "reasoning_content": "完整思考过程"}
|
||||
]
|
||||
|
||||
class _DummySessions:
|
||||
def get_or_create(self, _key):
|
||||
return _DummySession()
|
||||
|
||||
class _DummyAgent:
|
||||
def __init__(self):
|
||||
self.sessions = _DummySessions()
|
||||
|
||||
async def collect_stream_chunks(response) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk)
|
||||
return chunks
|
||||
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
|
||||
|
||||
request = main.ChatRequest(message="hello", session_id="api:test-4", project_id=303)
|
||||
response = asyncio.run(main.nanobot_chat_stream(request))
|
||||
chunks = asyncio.run(collect_stream_chunks(response))
|
||||
content = "".join(chunks)
|
||||
|
||||
assert '"type": "progress", "content": "模型正在拆解问题", "is_reasoning": true' in content
|
||||
assert '"type": "progress", "content": "开始执行工具", "tool_hint": true' in content
|
||||
assert '"type": "final", "content": "final-answer", "reasoning_content": "完整思考过程"' in content
|
||||
@@ -0,0 +1,28 @@
|
||||
from pathlib import Path
|
||||
|
||||
from app.core import data_root
|
||||
|
||||
|
||||
def test_data_root_prefers_env(monkeypatch, tmp_path: Path) -> None:
|
||||
custom = tmp_path / "custom-data-root"
|
||||
monkeypatch.setenv("DATA_ROOT", str(custom))
|
||||
assert data_root.get_data_root() == custom.resolve()
|
||||
|
||||
|
||||
def test_data_root_falls_back_to_legacy(monkeypatch, tmp_path: Path) -> None:
|
||||
monkeypatch.delenv("DATA_ROOT", raising=False)
|
||||
legacy = tmp_path / "legacy-data"
|
||||
default = tmp_path / "default-data"
|
||||
legacy.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setattr(data_root, "LEGACY_DATA_ROOT", legacy)
|
||||
monkeypatch.setattr(data_root, "DEFAULT_DATA_ROOT", default)
|
||||
assert data_root.get_data_root() == legacy
|
||||
|
||||
|
||||
def test_ensure_data_layout_creates_children(monkeypatch, tmp_path: Path) -> None:
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path / "dr"))
|
||||
data_root.ensure_data_layout()
|
||||
root = data_root.get_data_root()
|
||||
assert (root / "workspace").exists()
|
||||
assert (root / "uploads").exists()
|
||||
assert (root / "data").exists()
|
||||
@@ -0,0 +1,102 @@
|
||||
import json
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from typer.testing import CliRunner
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
app = import_module("app.cli").app
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class _FakeProcess:
|
||||
def __init__(self, pid: int = 9527, exit_code: int | None = None) -> None:
|
||||
self.pid = pid
|
||||
self._exit_code = exit_code
|
||||
|
||||
def poll(self):
|
||||
return self._exit_code
|
||||
|
||||
|
||||
def test_start_command_writes_state(monkeypatch, tmp_path) -> None:
|
||||
pid_file = tmp_path / "run" / "state.json"
|
||||
log_file = tmp_path / "run" / "service.log"
|
||||
|
||||
monkeypatch.setattr("app.cli.subprocess.Popen", lambda *args, **kwargs: _FakeProcess())
|
||||
monkeypatch.setattr("app.cli._wait_for_server_ready", lambda *_args, **_kwargs: True)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"start",
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"18999",
|
||||
"--pid-file",
|
||||
str(pid_file),
|
||||
"--log-file",
|
||||
str(log_file),
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "已启动" in result.stdout
|
||||
assert pid_file.exists()
|
||||
state = json.loads(pid_file.read_text(encoding="utf-8"))
|
||||
assert state["pid"] == 9527
|
||||
assert state["host"] == "127.0.0.1"
|
||||
assert state["port"] == 18999
|
||||
|
||||
|
||||
def test_status_command_reports_running(monkeypatch, tmp_path) -> None:
|
||||
pid_file = tmp_path / "run" / "state.json"
|
||||
pid_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
pid_file.write_text(
|
||||
json.dumps({"pid": 9527, "host": "127.0.0.1", "port": 18080}, ensure_ascii=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.cli._is_process_running", lambda pid: pid == 9527)
|
||||
result = runner.invoke(app, ["status", "--pid-file", str(pid_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "running" in result.stdout
|
||||
assert "127.0.0.1:18080" in result.stdout
|
||||
|
||||
|
||||
def test_stop_command_cleans_state(monkeypatch, tmp_path) -> None:
|
||||
pid_file = tmp_path / "run" / "state.json"
|
||||
pid_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
pid_file.write_text(json.dumps({"pid": 9527}, ensure_ascii=False), encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("app.cli._is_process_running", lambda pid: pid == 9527)
|
||||
monkeypatch.setattr("app.cli._stop_pid", lambda pid, timeout: pid == 9527)
|
||||
|
||||
result = runner.invoke(app, ["stop", "--pid-file", str(pid_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "已停止" in result.stdout
|
||||
assert not pid_file.exists()
|
||||
|
||||
|
||||
def test_status_command_cleans_stale_state(monkeypatch, tmp_path) -> None:
|
||||
pid_file = tmp_path / "run" / "state.json"
|
||||
pid_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
pid_file.write_text(json.dumps({"pid": 9527}, ensure_ascii=False), encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("app.cli._is_process_running", lambda _pid: False)
|
||||
result = runner.invoke(app, ["status", "--pid-file", str(pid_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "stopped" in result.stdout
|
||||
assert not pid_file.exists()
|
||||
@@ -0,0 +1,376 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import main
|
||||
from app.context import current_knowledge_base_id
|
||||
from app.schemas.knowledge import KnowledgeSearchResponse
|
||||
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
|
||||
|
||||
|
||||
def test_knowledge_base_crud_and_search_routes(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
|
||||
client = TestClient(main.app)
|
||||
|
||||
create_resp = client.post(
|
||||
"/api/v1/knowledge-bases",
|
||||
json={"name": "产品手册", "description": "用于问答", "top_k": 2, "chunk_size": 256, "chunk_overlap": 20},
|
||||
)
|
||||
assert create_resp.status_code == 200
|
||||
kb = create_resp.json()
|
||||
kb_id = kb["id"]
|
||||
|
||||
list_resp = client.get("/api/v1/knowledge-bases")
|
||||
assert list_resp.status_code == 200
|
||||
assert any(item["id"] == kb_id for item in list_resp.json())
|
||||
|
||||
doc_resp = client.post(
|
||||
f"/api/v1/knowledge-bases/{kb_id}/documents",
|
||||
json={"title": "退款规则", "content": "苹果手机支持7天无理由退款", "metadata": {"lang": "zh"}},
|
||||
)
|
||||
assert doc_resp.status_code == 200
|
||||
doc_id = doc_resp.json()["id"]
|
||||
|
||||
reindex_resp = client.post(f"/api/v1/knowledge-bases/{kb_id}/reindex")
|
||||
assert reindex_resp.status_code == 200
|
||||
|
||||
search_resp = client.post(f"/api/v1/knowledge-bases/{kb_id}/search", json={"query": "苹果退款", "top_k": 2})
|
||||
assert search_resp.status_code == 200
|
||||
parsed = KnowledgeSearchResponse(**search_resp.json())
|
||||
assert parsed.hits
|
||||
assert "苹果" in parsed.answer
|
||||
|
||||
update_resp = client.put(f"/api/v1/knowledge-bases/{kb_id}", json={"name": "售后知识库"})
|
||||
assert update_resp.status_code == 200
|
||||
assert update_resp.json()["name"] == "售后知识库"
|
||||
|
||||
delete_doc_resp = client.delete(f"/api/v1/knowledge-bases/{kb_id}/documents/{doc_id}")
|
||||
assert delete_doc_resp.status_code == 200
|
||||
delete_kb_resp = client.delete(f"/api/v1/knowledge-bases/{kb_id}")
|
||||
assert delete_kb_resp.status_code == 200
|
||||
|
||||
|
||||
def test_knowledge_global_config_mask_and_validation(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
client = TestClient(main.app)
|
||||
|
||||
initial_resp = client.get("/api/v1/knowledge-bases/global-config")
|
||||
assert initial_resp.status_code == 200
|
||||
assert initial_resp.json() == {
|
||||
"api_base": None,
|
||||
"api_key": None,
|
||||
"api_key_masked": None,
|
||||
"has_api_key": False,
|
||||
"default_embedding_model": None,
|
||||
}
|
||||
|
||||
update_resp = client.put(
|
||||
"/api/v1/knowledge-bases/global-config",
|
||||
json={"api_base": "https://kb.example.com/", "api_key": "sk-knowledge-secret", "default_embedding_model": "text-embedding-3-small"},
|
||||
)
|
||||
assert update_resp.status_code == 200
|
||||
body = update_resp.json()
|
||||
assert body["api_base"] == "https://kb.example.com"
|
||||
assert body["api_key"] is None
|
||||
assert body["has_api_key"] is True
|
||||
assert body["api_key_masked"] == "sk-k***********cret"
|
||||
assert body["default_embedding_model"] == "text-embedding-3-small"
|
||||
|
||||
get_resp = client.get("/api/v1/knowledge-bases/global-config")
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["api_key"] is None
|
||||
assert get_resp.json()["api_key_masked"] == "sk-k***********cret"
|
||||
assert get_resp.json()["default_embedding_model"] == "text-embedding-3-small"
|
||||
|
||||
invalid_resp = client.put("/api/v1/knowledge-bases/global-config", json={"api_base": "ftp://kb.example.com"})
|
||||
assert invalid_resp.status_code == 422
|
||||
|
||||
|
||||
def test_chat_request_syncs_knowledge_base_metadata(monkeypatch) -> None:
|
||||
captured_kb_ids: list[str | None] = []
|
||||
captured_messages: list[str] = []
|
||||
|
||||
class _DummySession:
|
||||
def __init__(self):
|
||||
self.metadata = {}
|
||||
self.messages = []
|
||||
self.updated_at = None
|
||||
|
||||
class _DummySessions:
|
||||
def __init__(self):
|
||||
self._sessions: dict[str, _DummySession] = {}
|
||||
|
||||
def get_or_create(self, key: str):
|
||||
if key not in self._sessions:
|
||||
self._sessions[key] = _DummySession()
|
||||
return self._sessions[key]
|
||||
|
||||
def save(self, _session):
|
||||
return None
|
||||
|
||||
class _DummyAgent:
|
||||
def __init__(self):
|
||||
self.sessions = _DummySessions()
|
||||
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
captured_kb_ids.append(current_knowledge_base_id.get())
|
||||
if args and isinstance(args[0], str):
|
||||
captured_messages.append(args[0])
|
||||
return "ok"
|
||||
|
||||
def fake_search(*, kb_id: str, query: str, top_k=None):
|
||||
assert kb_id == "kb-123"
|
||||
assert query == "请回答售后规则"
|
||||
return {
|
||||
"answer": "命中结果",
|
||||
"hits": [
|
||||
{"doc_id": "d1", "title": "退款规则", "chunk": "7天无理由退款", "score": 0.9, "metadata": {}},
|
||||
{"doc_id": "d2", "title": "售后电话", "chunk": "客服电话 400-1234", "score": 0.7, "metadata": {}},
|
||||
],
|
||||
}
|
||||
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr("main.knowledge_index_service.search", fake_search)
|
||||
|
||||
request = main.ChatRequest(message="请回答售后规则", session_id="api:kb-1", knowledge_base_id="kb-123")
|
||||
response = asyncio.run(main.nanobot_chat(request))
|
||||
assert response["response"] == "ok"
|
||||
assert "kb_citations" in response
|
||||
assert len(response["kb_citations"]) == 2
|
||||
assert response["kb_citations"][0]["title"] == "退款规则"
|
||||
assert captured_kb_ids == ["kb-123"]
|
||||
assert captured_messages and "7天无理由退款" in captured_messages[0]
|
||||
session = main.nanobot_service.agent.sessions.get_or_create("api:kb-1")
|
||||
assert session.metadata["selected_knowledge_base_id"] == "kb-123"
|
||||
|
||||
|
||||
def test_knowledge_tool_uses_session_context(monkeypatch) -> None:
|
||||
tool = KnowledgeBaseRetrieveTool()
|
||||
token = current_knowledge_base_id.set("kb-session")
|
||||
called: list[dict] = []
|
||||
|
||||
def fake_search(*, kb_id: str, query: str, top_k=None):
|
||||
called.append({"kb_id": kb_id, "query": query, "top_k": top_k})
|
||||
return {
|
||||
"answer": "命中结果",
|
||||
"hits": [{"doc_id": "d1", "title": "t1", "chunk": "命中结果", "score": 1.0, "metadata": {}}],
|
||||
}
|
||||
|
||||
monkeypatch.setattr("app.tools.knowledge_base.knowledge_index_service.search", fake_search)
|
||||
try:
|
||||
output = asyncio.run(tool.execute(query="售后政策"))
|
||||
finally:
|
||||
current_knowledge_base_id.reset(token)
|
||||
assert called and called[0]["kb_id"] == "kb-session"
|
||||
assert "命中结果" in output
|
||||
|
||||
|
||||
def test_update_session_context_file_supports_knowledge_base(monkeypatch) -> None:
|
||||
class _DummySession:
|
||||
def __init__(self):
|
||||
self.metadata = {}
|
||||
self.updated_at = None
|
||||
|
||||
class _DummySessions:
|
||||
def __init__(self):
|
||||
self.session = _DummySession()
|
||||
|
||||
def get_or_create(self, _key: str):
|
||||
return self.session
|
||||
|
||||
def save(self, _session):
|
||||
return None
|
||||
|
||||
class _DummyAgent:
|
||||
def __init__(self):
|
||||
self.sessions = _DummySessions()
|
||||
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
|
||||
payload = main.SessionFileContextUpdateRequest(selected_knowledge_base_id="kb-ctx")
|
||||
response = main.update_session_context_file("api:ctx", payload)
|
||||
assert response["status"] == "success"
|
||||
assert response["metadata"]["selected_knowledge_base_id"] == "kb-ctx"
|
||||
|
||||
|
||||
def test_knowledge_global_connection_test_route(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
class _DummyEmbeddingData:
|
||||
embedding = [0.1, 0.2, 0.3]
|
||||
|
||||
class _DummyEmbeddingResp:
|
||||
data = [_DummyEmbeddingData()]
|
||||
|
||||
class _DummyEmbeddingsAPI:
|
||||
@staticmethod
|
||||
def create(model: str, input: str):
|
||||
assert model == "text-embedding-3-small"
|
||||
assert input == "connection test"
|
||||
return _DummyEmbeddingResp()
|
||||
|
||||
class _DummyOpenAI:
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
assert api_key == "sk-knowledge-secret"
|
||||
assert base_url == "https://kb.example.com"
|
||||
self.embeddings = _DummyEmbeddingsAPI()
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)
|
||||
|
||||
client = TestClient(main.app)
|
||||
save_resp = client.put(
|
||||
"/api/v1/knowledge-bases/global-config",
|
||||
json={
|
||||
"api_base": "https://kb.example.com",
|
||||
"api_key": "sk-knowledge-secret",
|
||||
"default_embedding_model": "text-embedding-3-small",
|
||||
},
|
||||
)
|
||||
assert save_resp.status_code == 200
|
||||
|
||||
test_resp = client.post(
|
||||
"/api/v1/knowledge-bases/global-config/test-connection",
|
||||
json={"model_name": "text-embedding-3-small"},
|
||||
)
|
||||
assert test_resp.status_code == 200
|
||||
body = test_resp.json()
|
||||
assert body["success"] is True
|
||||
assert body["model_name"] == "text-embedding-3-small"
|
||||
assert body["embedding_dimension"] == 3
|
||||
assert body["resolved_api_base"] == "https://kb.example.com"
|
||||
assert body["available_models"] == []
|
||||
|
||||
|
||||
def test_knowledge_global_connection_test_route_requires_model_name(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
|
||||
client = TestClient(main.app)
|
||||
resp = client.post(
|
||||
"/api/v1/knowledge-bases/global-config/test-connection",
|
||||
json={
|
||||
"api_base": "https://api.siliconflow.cn/v1/embeddings",
|
||||
"api_key": "ark-key",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "测试连接必须显式填写向量模型名称" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_knowledge_global_connection_test_route_returns_remote_error(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
class _DummyEmbeddingsAPI:
|
||||
@staticmethod
|
||||
def create(model: str, input: str):
|
||||
assert model == "BAAI/bge-large-zh-v1.5"
|
||||
assert input == "connection test"
|
||||
raise RuntimeError("Not Found")
|
||||
|
||||
class _DummyOpenAI:
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
assert api_key == "sf-key"
|
||||
assert base_url == "https://api.siliconflow.cn/v1"
|
||||
self.embeddings = _DummyEmbeddingsAPI()
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)
|
||||
|
||||
client = TestClient(main.app)
|
||||
resp = client.post(
|
||||
"/api/v1/knowledge-bases/global-config/test-connection",
|
||||
json={
|
||||
"api_base": "https://api.siliconflow.cn/v1/embeddings",
|
||||
"api_key": "sf-key",
|
||||
"model_name": "BAAI/bge-large-zh-v1.5",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert "Embedding调用失败" in resp.json()["detail"]
|
||||
|
||||
|
||||
def test_knowledge_document_upload_route(monkeypatch, tmp_path) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
|
||||
client = TestClient(main.app)
|
||||
create_resp = client.post(
|
||||
"/api/v1/knowledge-bases",
|
||||
json={"name": "上传测试库", "description": "用于上传", "top_k": 2, "chunk_size": 256, "chunk_overlap": 20},
|
||||
)
|
||||
assert create_resp.status_code == 200
|
||||
kb_id = create_resp.json()["id"]
|
||||
|
||||
files = [
|
||||
("files", ("doc1.txt", b"hello knowledge", "text/plain")),
|
||||
("files", ("doc2.md", b"# title\ncontent", "text/markdown")),
|
||||
]
|
||||
upload_resp = client.post(
|
||||
f"/api/v1/knowledge-bases/{kb_id}/documents/upload",
|
||||
files=files,
|
||||
data={"metadata": "{\"source\":\"batch\"}"},
|
||||
)
|
||||
assert upload_resp.status_code == 200
|
||||
body = upload_resp.json()
|
||||
assert body["status"] == "success"
|
||||
assert body["count"] == 2
|
||||
assert len(body["documents"]) == 2
|
||||
|
||||
list_resp = client.get(f"/api/v1/knowledge-bases/{kb_id}/documents")
|
||||
assert list_resp.status_code == 200
|
||||
docs = list_resp.json()
|
||||
assert len(docs) == 2
|
||||
@@ -0,0 +1,110 @@
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.core.nanobot import NanobotIntegration
|
||||
from app.context import current_session_id
|
||||
|
||||
|
||||
class _DummySessions:
|
||||
def __init__(self) -> None:
|
||||
self.saved = []
|
||||
self._session = SimpleNamespace(messages=[])
|
||||
|
||||
def get_or_create(self, _session_id: str):
|
||||
return self._session
|
||||
|
||||
def save(self, session) -> None:
|
||||
self.saved.append(session)
|
||||
|
||||
|
||||
class _DummyAgent:
|
||||
def __init__(self) -> None:
|
||||
self.sessions = _DummySessions()
|
||||
self.provider = SimpleNamespace(default_model="demo-model")
|
||||
self.model = "demo-model"
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return "ok"
|
||||
|
||||
|
||||
def test_process_message_project_id_fallback_from_session_alias(monkeypatch) -> None:
|
||||
service = NanobotIntegration()
|
||||
base_agent = _DummyAgent()
|
||||
custom_agent = _DummyAgent()
|
||||
service.agent = base_agent
|
||||
service._started = True
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_get_or_create_model_agent(model_id, target_config, project_id):
|
||||
captured["project_id"] = project_id
|
||||
return custom_agent
|
||||
|
||||
monkeypatch.setattr(service, "_get_or_create_model_agent", fake_get_or_create_model_agent)
|
||||
monkeypatch.setattr("app.core.nanobot.get_llm_configs", lambda: [])
|
||||
monkeypatch.setattr("app.core.nanobot.get_active_llm_config", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"app.core.session_alias_store.session_alias_store.get_alias_meta",
|
||||
lambda _session_id: {"project_id": 77},
|
||||
)
|
||||
|
||||
response = asyncio.run(service.process_message("hello", session_id="api:s1"))
|
||||
|
||||
assert response == "ok"
|
||||
assert captured["project_id"] == 77
|
||||
|
||||
|
||||
def test_process_message_project_id_prefers_request_value(monkeypatch) -> None:
|
||||
service = NanobotIntegration()
|
||||
base_agent = _DummyAgent()
|
||||
custom_agent = _DummyAgent()
|
||||
service.agent = base_agent
|
||||
service._started = True
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_get_or_create_model_agent(model_id, target_config, project_id):
|
||||
captured["project_id"] = project_id
|
||||
return custom_agent
|
||||
|
||||
monkeypatch.setattr(service, "_get_or_create_model_agent", fake_get_or_create_model_agent)
|
||||
monkeypatch.setattr("app.core.nanobot.get_llm_configs", lambda: [])
|
||||
monkeypatch.setattr("app.core.nanobot.get_active_llm_config", lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"app.core.session_alias_store.session_alias_store.get_alias_meta",
|
||||
lambda _session_id: {"project_id": 88},
|
||||
)
|
||||
|
||||
response = asyncio.run(service.process_message("hello", session_id="api:s2", project_id=9))
|
||||
|
||||
assert response == "ok"
|
||||
assert captured["project_id"] == 9
|
||||
|
||||
|
||||
def test_register_custom_tools_always_contains_subagent_tools() -> None:
|
||||
service = NanobotIntegration()
|
||||
names: list[str] = []
|
||||
|
||||
class _ToolRegistry:
|
||||
def register(self, tool) -> None:
|
||||
names.append(tool.name)
|
||||
|
||||
fake_agent = SimpleNamespace(tools=_ToolRegistry())
|
||||
service._register_custom_tools(fake_agent, project_id=None)
|
||||
|
||||
assert "list_subagents" in names
|
||||
assert "invoke_subagent" in names
|
||||
|
||||
|
||||
def test_subagent_tool_resolves_project_from_session_alias(monkeypatch) -> None:
|
||||
from app.tools.subagent import _resolve_project_id
|
||||
|
||||
token = current_session_id.set("api:subagent-test")
|
||||
try:
|
||||
monkeypatch.setattr(
|
||||
"app.tools.subagent.session_alias_store.get_alias_meta",
|
||||
lambda _session_id: {"project_id": 66},
|
||||
)
|
||||
assert _resolve_project_id(None) == 66
|
||||
finally:
|
||||
current_session_id.reset(token)
|
||||
@@ -0,0 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
|
||||
|
||||
async def main():
|
||||
req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
|
||||
res = await process_nl2sql(req)
|
||||
print("SQL:", res.sql)
|
||||
print("Error:", res.error)
|
||||
print("Result:", res.result)
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nl2sql_optimized_flow():
|
||||
# Mock parameters
|
||||
query = "Show me the top 5 sales"
|
||||
source = "ds:1"
|
||||
|
||||
# Mock connector and schema
|
||||
mock_connector = MagicMock()
|
||||
mock_connector.get_schema.return_value = {
|
||||
"sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
|
||||
"users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
|
||||
"logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
|
||||
"products": {"columns": [{"name": "id", "type": "INT"}]},
|
||||
"categories": {"columns": [{"name": "id", "type": "INT"}]},
|
||||
"inventory": {"columns": [{"name": "id", "type": "INT"}]}
|
||||
}
|
||||
mock_connector.test_connection.return_value = True
|
||||
mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]
|
||||
|
||||
# Mock LLM provider
|
||||
mock_provider = AsyncMock()
|
||||
|
||||
# First response for Table Selector
|
||||
mock_resp_tables = MagicMock()
|
||||
mock_resp_tables.content = '["sales"]'
|
||||
mock_resp_tables.finish_reason = "stop"
|
||||
|
||||
# Second response for SQL Generation
|
||||
mock_resp_sql = MagicMock()
|
||||
mock_resp_sql.content = '{"reasoning": "Plan...", "sql": "SELECT * FROM sales LIMIT 5"}'
|
||||
mock_resp_sql.finish_reason = "stop"
|
||||
|
||||
mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]
|
||||
|
||||
# Patch dependencies
|
||||
with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
|
||||
patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
|
||||
patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
|
||||
patch("app.agent.nl2sql.SessionLocal"), \
|
||||
patch("app.agent.nl2sql.DataSource"), \
|
||||
patch("app.agent.nl2sql.postgres_connector", mock_connector), \
|
||||
patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
|
||||
|
||||
request = NL2SQLRequest(query=query, source=source)
|
||||
response = await process_nl2sql(request)
|
||||
|
||||
print(f"DEBUG: Response SQL: '{response.sql}'")
|
||||
print(f"DEBUG: Response Error: '{response.error}'")
|
||||
|
||||
assert response.sql == "SELECT * FROM sales LIMIT 5"
|
||||
assert len(response.result) == 1
|
||||
assert response.error is None
|
||||
|
||||
# Verify provider was called twice
|
||||
assert mock_provider.chat.call_count == 2
|
||||
|
||||
# Verify first call was for table selection
|
||||
args, kwargs = mock_provider.chat.call_args_list[0]
|
||||
assert "TABLE_SELECTOR_SYSTEM_PROMPT" in str(args) or "Identifying relevant tables" in str(args) or any("system" in m["role"] for m in kwargs["messages"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(test_nl2sql_optimized_flow())
|
||||
@@ -0,0 +1,62 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from app.core.llm_provider import build_llm_provider
|
||||
from app.core.nanobot import NanobotIntegration
|
||||
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
|
||||
|
||||
|
||||
def test_build_llm_provider_uses_max_completion_tokens_for_gpt5() -> None:
|
||||
provider = build_llm_provider(
|
||||
model="gpt-5.4-nano",
|
||||
provider="openai",
|
||||
api_key="test-key",
|
||||
api_base="https://example.com/v1",
|
||||
)
|
||||
|
||||
assert isinstance(provider, PatchedOpenAICompatProvider)
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
model="gpt-5.4-nano",
|
||||
max_tokens=5,
|
||||
temperature=0,
|
||||
reasoning_effort=None,
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
assert kwargs["max_completion_tokens"] == 5
|
||||
assert "max_tokens" not in kwargs
|
||||
|
||||
|
||||
def test_nanobot_provider_keeps_max_tokens_for_legacy_models() -> None:
|
||||
integration = NanobotIntegration()
|
||||
provider = integration._build_provider(
|
||||
model="gpt-4o-mini",
|
||||
provider_name="openai",
|
||||
api_key="test-key",
|
||||
api_base="https://example.com/v1",
|
||||
extra_headers=None,
|
||||
)
|
||||
|
||||
assert isinstance(provider, PatchedOpenAICompatProvider)
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
model="gpt-4o-mini",
|
||||
max_tokens=5,
|
||||
temperature=0,
|
||||
reasoning_effort=None,
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
assert kwargs["max_tokens"] == 5
|
||||
assert "max_completion_tokens" not in kwargs
|
||||
@@ -0,0 +1,102 @@
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import Base, get_db
|
||||
from app.models.project import Project
|
||||
from app.models.subagent import Subagent
|
||||
from app.models.user import User
|
||||
from main import app
|
||||
|
||||
|
||||
def _seed_subagent(db: Session) -> tuple[User, Project, Subagent]:
|
||||
user = User(
|
||||
username="task3-owner",
|
||||
email="task3-owner@example.com",
|
||||
hashed_password="test",
|
||||
is_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
project = Project(
|
||||
name="task3-project",
|
||||
description="task3",
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
|
||||
subagent = Subagent(
|
||||
project_id=project.id,
|
||||
name="task3-subagent",
|
||||
description="task3",
|
||||
instructions="do task3",
|
||||
model="gpt",
|
||||
)
|
||||
db.add(subagent)
|
||||
db.commit()
|
||||
db.refresh(subagent)
|
||||
return user, project, subagent
|
||||
|
||||
|
||||
def test_subagent_detail_route_is_global_and_project_scoped_route_is_invalid() -> None:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
db = testing_session_local()
|
||||
user, project, subagent = _seed_subagent(db)
|
||||
user_id = user.id
|
||||
username = user.username
|
||||
project_id = project.id
|
||||
subagent_id = subagent.id
|
||||
db.close()
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return CurrentUser(id=user_id, username=username, is_admin=False)
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
|
||||
try:
|
||||
client = TestClient(app)
|
||||
response = client.get(f"/api/v1/subagents/{subagent_id}")
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["id"] == subagent_id
|
||||
assert body["project_id"] == project_id
|
||||
|
||||
legacy_path_response = client.get(f"/api/v1/projects/{project_id}/subagents/{subagent_id}")
|
||||
assert legacy_path_response.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
@@ -0,0 +1,136 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from app.context import current_session_id
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import Base, get_db
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.tools.subagent import InvokeSubagentTool, ListSubagentsTool
|
||||
from main import app
|
||||
|
||||
|
||||
def _seed_owner_and_project(db: Session) -> tuple[User, Project]:
|
||||
user = User(
|
||||
username="task4-owner",
|
||||
email="task4-owner@example.com",
|
||||
hashed_password="test",
|
||||
is_admin=False,
|
||||
)
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
|
||||
project = Project(
|
||||
name="task4-project",
|
||||
description="task4",
|
||||
owner_id=user.id,
|
||||
)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return user, project
|
||||
|
||||
|
||||
def test_create_subagent_then_list_and_invoke_success(monkeypatch) -> None:
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
db = testing_session_local()
|
||||
user, project = _seed_owner_and_project(db)
|
||||
user_id = user.id
|
||||
username = user.username
|
||||
project_id = project.id
|
||||
db.close()
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return CurrentUser(id=user_id, username=username, is_admin=False)
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
|
||||
token = current_session_id.set("api:task4-regression")
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_process_message(message, session_id, project_id, model_id):
|
||||
captured["message"] = message
|
||||
captured["session_id"] = session_id
|
||||
captured["project_id"] = project_id
|
||||
captured["model_id"] = model_id
|
||||
return "invoke-ok"
|
||||
|
||||
try:
|
||||
monkeypatch.setattr("app.tools.subagent.SessionLocal", testing_session_local)
|
||||
monkeypatch.setattr(
|
||||
"app.tools.subagent.session_alias_store.get_alias_meta",
|
||||
lambda _session_id: {"project_id": project_id},
|
||||
)
|
||||
monkeypatch.setattr("app.tools.subagent.get_llm_configs", lambda: [])
|
||||
monkeypatch.setattr("app.tools.subagent.get_active_llm_config", lambda: None)
|
||||
monkeypatch.setattr("app.tools.subagent.nanobot_service.process_message", fake_process_message)
|
||||
|
||||
client = TestClient(app)
|
||||
create_response = client.post(
|
||||
f"/api/v1/projects/{project_id}/subagents",
|
||||
json={
|
||||
"name": "task4-subagent",
|
||||
"description": "task4-desc",
|
||||
"instructions": "focus on regression",
|
||||
"model": "gpt-x",
|
||||
},
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
created = create_response.json()
|
||||
assert created["project_id"] == project_id
|
||||
assert created["name"] == "task4-subagent"
|
||||
|
||||
listed = asyncio.run(ListSubagentsTool().execute())
|
||||
listed_payload = json.loads(listed)
|
||||
assert len(listed_payload) == 1
|
||||
assert listed_payload[0]["name"] == "task4-subagent"
|
||||
assert listed_payload[0]["description"] == "task4-desc"
|
||||
|
||||
invoke_result = asyncio.run(
|
||||
InvokeSubagentTool().execute(
|
||||
subagent_name="task4-subagent",
|
||||
task="run regression task",
|
||||
)
|
||||
)
|
||||
assert "completed the task" in invoke_result
|
||||
assert "invoke-ok" in invoke_result
|
||||
assert captured["project_id"] == project_id
|
||||
assert captured["session_id"] == f"api:task4-regression:subagent:{created['id']}"
|
||||
assert "focus on regression" in str(captured["message"])
|
||||
finally:
|
||||
current_session_id.reset(token)
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
@@ -0,0 +1,115 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
import main
|
||||
from app.trace.attributes import build_chat_trace_attributes, sanitize_attributes
|
||||
from app.trace.service import TraceService
|
||||
|
||||
|
||||
def test_trace_service_initialize_without_keys(monkeypatch) -> None:
|
||||
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
|
||||
monkeypatch.delenv("LANGFUSE_SECRET_KEY", raising=False)
|
||||
service = TraceService()
|
||||
assert service.initialize() is False
|
||||
assert service.enabled is False
|
||||
assert service.initialized is True
|
||||
|
||||
|
||||
def test_trace_attribute_helpers() -> None:
|
||||
attrs = sanitize_attributes(
|
||||
{
|
||||
"session_id": "api:test",
|
||||
"project_id": 1,
|
||||
"skip_none": None,
|
||||
"obj": {"a": 1},
|
||||
}
|
||||
)
|
||||
assert attrs["session_id"] == "api:test"
|
||||
assert attrs["project_id"] == 1
|
||||
assert "skip_none" not in attrs
|
||||
assert attrs["obj"] == "{'a': 1}"
|
||||
chat_attrs = build_chat_trace_attributes(
|
||||
session_id="api:test",
|
||||
project_id=9,
|
||||
model_id="model-a",
|
||||
route_mode="auto",
|
||||
source="postgres",
|
||||
knowledge_base_id=None,
|
||||
)
|
||||
assert chat_attrs["component"] == "chat_stream"
|
||||
assert chat_attrs["session_id"] == "api:test"
|
||||
assert chat_attrs["project_id"] == 9
|
||||
|
||||
|
||||
def test_nanobot_chat_stream_uses_trace_span(monkeypatch) -> None:
|
||||
calls: list[tuple[str, dict]] = []
|
||||
updates: list[dict] = []
|
||||
trace_updates: list[dict] = []
|
||||
|
||||
class _Span:
|
||||
def set_attributes(self, attributes):
|
||||
updates.append(attributes)
|
||||
|
||||
def update(self, **kwargs):
|
||||
updates.append(kwargs)
|
||||
|
||||
def update_trace(self, **kwargs):
|
||||
trace_updates.append(kwargs)
|
||||
|
||||
def record_error(self, _exc, *, stage: str = "unknown"):
|
||||
updates.append({"stage": stage})
|
||||
|
||||
class _SpanCtx:
|
||||
def __init__(self, name: str, attributes: dict):
|
||||
self._name = name
|
||||
self._attributes = attributes
|
||||
|
||||
def __enter__(self):
|
||||
calls.append((self._name, self._attributes))
|
||||
return _Span()
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def fake_start_span(name: str, *, attributes=None, input_payload=None):
|
||||
payload = dict(attributes or {})
|
||||
if input_payload is not None:
|
||||
payload["input_payload"] = input_payload
|
||||
return _SpanCtx(name, payload)
|
||||
|
||||
async def fake_process_message(*args, **kwargs):
|
||||
on_stream = kwargs.get("on_stream")
|
||||
if on_stream:
|
||||
await on_stream("token")
|
||||
return "ok"
|
||||
|
||||
async def collect_stream_chunks(response) -> list[str]:
|
||||
chunks: list[str] = []
|
||||
async for chunk in response.body_iterator:
|
||||
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk)
|
||||
return chunks
|
||||
|
||||
monkeypatch.setattr(main.trace_service, "start_span", fake_start_span)
|
||||
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
|
||||
monkeypatch.setattr(main.nanobot_service, "agent", None)
|
||||
|
||||
request = main.ChatRequest(message="hello", session_id="api:trace-test", project_id=7)
|
||||
response = asyncio.run(main.nanobot_chat_stream(request))
|
||||
chunks = asyncio.run(collect_stream_chunks(response))
|
||||
content = "".join(chunks)
|
||||
|
||||
assert "token" in content
|
||||
assert "ok" in content
|
||||
assert calls
|
||||
assert calls[0][0] == "chat.stream"
|
||||
assert trace_updates and trace_updates[0]["session_id"] == "api:trace-test"
|
||||
@@ -0,0 +1,78 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
import main
|
||||
|
||||
|
||||
def _prepare_webui(monkeypatch, tmp_path: Path) -> None:
|
||||
webui_dir = tmp_path / "webui"
|
||||
assets_dir = webui_dir / "assets"
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
(webui_dir / "index.html").write_text("<html><body>dataclaw-webui</body></html>", encoding="utf-8")
|
||||
(assets_dir / "app.js").write_text("window.__TASK2__=true;", encoding="utf-8")
|
||||
monkeypatch.setattr(main, "_WEBUI_DIR", webui_dir)
|
||||
monkeypatch.setattr(main, "_WEBUI_INDEX", webui_dir / "index.html")
|
||||
monkeypatch.setattr(main, "_WEBUI_STATIC", StaticFiles(directory=str(webui_dir), html=False))
|
||||
|
||||
|
||||
def _prepare_lifecycle(monkeypatch) -> None:
|
||||
async def fake_start():
|
||||
return None
|
||||
|
||||
async def fake_stop():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
|
||||
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
|
||||
|
||||
|
||||
def test_webui_static_assets_served_from_backend(monkeypatch, tmp_path) -> None:
|
||||
_prepare_webui(monkeypatch, tmp_path)
|
||||
_prepare_lifecycle(monkeypatch)
|
||||
client = TestClient(main.app)
|
||||
|
||||
index_resp = client.get("/")
|
||||
assert index_resp.status_code == 200
|
||||
assert "dataclaw-webui" in index_resp.text
|
||||
|
||||
asset_resp = client.get("/assets/app.js")
|
||||
assert asset_resp.status_code == 200
|
||||
assert "window.__TASK2__=true;" in asset_resp.text
|
||||
|
||||
|
||||
def test_spa_route_fallback_to_index_html(monkeypatch, tmp_path) -> None:
|
||||
_prepare_webui(monkeypatch, tmp_path)
|
||||
_prepare_lifecycle(monkeypatch)
|
||||
client = TestClient(main.app)
|
||||
|
||||
spa_resp = client.get("/settings/users")
|
||||
assert spa_resp.status_code == 200
|
||||
assert "dataclaw-webui" in spa_resp.text
|
||||
|
||||
missing_asset_resp = client.get("/assets/missing.js")
|
||||
assert missing_asset_resp.status_code == 404
|
||||
|
||||
|
||||
def test_backend_accessible_without_frontend_dev_server(monkeypatch, tmp_path) -> None:
|
||||
_prepare_webui(monkeypatch, tmp_path)
|
||||
_prepare_lifecycle(monkeypatch)
|
||||
client = TestClient(main.app)
|
||||
|
||||
ui_resp = client.get("/")
|
||||
assert ui_resp.status_code == 200
|
||||
assert "dataclaw-webui" in ui_resp.text
|
||||
|
||||
api_resp = client.get("/nanobot/status")
|
||||
assert api_resp.status_code == 200
|
||||
assert api_resp.json()["status"] in {"running", "stopped"}
|
||||
Reference in New Issue
Block a user