Update 2026-05-13 16:43:53

This commit is contained in:
yi
2026-05-13 16:43:53 +08:00
parent 6af5c584f4
commit afd7c5fe85
490 changed files with 850 additions and 922 deletions
+872
View File
@@ -0,0 +1,872 @@
import asyncio
import hashlib
import hmac
import json
import sys
import time
from collections.abc import Generator
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
import httpx
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState
from app.models.project import Project
from app.models.user import User
from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode
from app.services.a2a_service import a2a_runtime, SharedSecretAuth
from main import app
def _seed(db: Session) -> tuple[int, str, int, str, int]:
owner = User(username="a2a_owner", email="a2a_owner@example.com", hashed_password="x", is_admin=False)
other = User(username="a2a_other", email="a2a_other@example.com", hashed_password="x", is_admin=False)
db.add(owner)
db.add(other)
db.commit()
db.refresh(owner)
db.refresh(other)
project = Project(name="a2a_project", description="a2a", owner_id=owner.id)
db.add(project)
db.commit()
db.refresh(project)
return owner.id, owner.username, other.id, other.username, project.id
def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"message": {
"messageId": f"msg-{int(time.time()*1000)}",
"role": "user",
"parts": [
{
"part_type": "data",
"data": {
"project_id": project_id,
"route_mode": route_mode,
"session_id": session_id,
**( {"idempotency_key": idempotency_key} if idempotency_key else {} )
},
"mediaType": "application/json",
},
{
"part_type": "text",
"text": text,
}
],
}
}
return payload
class TestPartSerialization:
def test_part_text_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.TEXT,
text="Hello world",
mediaType="text/plain",
)
data = part.model_dump()
assert data["part_type"] == "text"
assert data["text"] == "Hello world"
assert data["mediaType"] == "text/plain"
def test_part_data_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.DATA,
data={"project_id": 123, "route_mode": "local"},
mediaType="application/json",
)
data = part.model_dump()
assert data["part_type"] == "data"
assert data["data"]["project_id"] == 123
def test_part_url_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.URL,
url="https://example.com/file.pdf",
mediaType="application/pdf",
filename="file.pdf",
)
data = part.model_dump()
assert data["part_type"] == "url"
assert data["url"] == "https://example.com/file.pdf"
assert data["filename"] == "file.pdf"
def test_part_raw_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.RAW,
raw="\x00\x01\x02\x03",
mediaType="application/octet-stream",
)
data = part.model_dump()
assert data["part_type"] == "raw"
assert data["raw"] == "\x00\x01\x02\x03"
class TestStateMachine:
def test_state_transitions_submit_to_complete(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "test state machine"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
assert task.state == A2ATaskState.SUBMITTED
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING)
assert task.state == A2ATaskState.WORKING
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.COMPLETED)
assert task.state == A2ATaskState.COMPLETED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_cancel(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
assert cancel_resp.status_code == 200
assert cancel_resp.json()["state"] == "CANCELED"
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_failed(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "fail test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message='{"message": "test error"}')
assert task.state == A2ATaskState.FAILED
assert task.error_message is not None
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_rejected(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "reject test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.REJECTED)
assert task.state == A2ATaskState.REJECTED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_input_required(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "input required test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.INPUT_REQUIRED)
assert task.state == A2ATaskState.INPUT_REQUIRED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_auth_required(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "auth required test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.AUTH_REQUIRED)
assert task.state == A2ATaskState.AUTH_REQUIRED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestA2APathNormalization:
def test_message_send_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "path test"))
assert resp.status_code == 200
assert "task" in resp.json()
assert "id" in resp.json()["task"]
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_message_stream_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:stream", json=_make_message_payload(project_id, "stream path test"))
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_tasks_cancel_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel path test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
assert cancel_resp.status_code == 200
assert cancel_resp.json()["task_id"] == task_id
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_agent_card_public_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
db.close()
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
app.dependency_overrides[get_db] = override_get_db
try:
client = TestClient(app)
resp = client.get("/api/v1/.well-known/agent-card.json")
assert resp.status_code == 200
data = resp.json()
assert "name" in data
assert "protocol_version" in data
assert "endpoints" in data
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestVersionNegotiation:
def test_version_not_supported_error(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post(
"/api/v1/message:send",
json=_make_message_payload(project_id, "version test"),
headers={"A2A-Version": "2.0"}
)
assert resp.status_code == 400
detail = json.loads(resp.json()["detail"])
assert detail["code"] == -32009
assert "not supported" in detail["message"].lower()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_version_response_header(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post(
"/api/v1/message:send",
json=_make_message_payload(project_id, "version header test"),
headers={"A2A-Version": "1.0"}
)
assert resp.status_code == 200
assert resp.headers.get("A2A-Version") == "1.0"
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestWebhookStreamResponse:
def test_webhook_payload_format(self):
from app.schemas.a2a import StreamResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent, TaskMessageEvent, A2ATaskStatusSchema, A2ATaskState, A2AArtifactSchema
status_event = TaskStatusUpdateEvent(
taskId="task-123",
contextId="ctx-456",
status=A2ATaskStatusSchema(
state=A2ATaskState.SUBMITTED,
timestamp=datetime.utcnow(),
),
metadata={},
)
status_dump = status_event.model_dump()
assert "taskId" in status_dump
assert status_dump["taskId"] == "task-123"
assert status_dump["status"]["state"] == "SUBMITTED"
artifact_event = TaskArtifactUpdateEvent(
taskId="task-123",
contextId="ctx-456",
artifact=A2AArtifactSchema(
artifactId="art-789",
parts=[],
),
append=False,
lastChunk=True,
)
artifact_dump = artifact_event.model_dump()
assert "taskId" in artifact_dump
assert artifact_dump["artifact"]["artifactId"] == "art-789"
def test_stream_response_task_field(self):
from app.schemas.a2a import StreamResponse, StreamResponseTask, A2ATaskState
resp = StreamResponse(
task=StreamResponseTask(
id="task-123",
contextId="ctx-456",
state=A2ATaskState.WORKING,
artifacts=[],
)
)
data = resp.model_dump()
assert "task" in data
assert data["task"]["id"] == "task-123"
class TestSSEFIFO:
def test_sse_event_order(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
with client.stream("POST", "/api/v1/message:stream", json=_make_message_payload(project_id, "fifo test")) as resp:
assert resp.status_code == 200
chunks = []
for line in resp.iter_lines():
if line.startswith("data: "):
chunks.append(json.loads(line[6:]))
event_types = [c.get("type") for c in chunks if "type" in c]
status_idx = next((i for i, t in enumerate(event_types) if t == "TaskStatusUpdateEvent"), -1)
assert status_idx >= 0
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestAuthSchemes:
def test_shared_secret_auth(self):
secret = "test-secret-key-12345"
timestamp = int(time.time())
body = b'{"test":"data"}'
sig, _ = SharedSecretAuth.generate_signature(secret, body, timestamp)
assert sig.startswith("sha256=")
assert SharedSecretAuth.verify_signature(secret, body, sig, timestamp) is True
def test_auth_scheme_none(self):
from app.schemas.a2a import SecuritySchemeHttpAuth
scheme = SecuritySchemeHttpAuth(scheme="bearer", description="Bearer auth")
assert scheme.scheme == "bearer"
class TestExceptionPaths:
def test_auth_failure_marks_agent_unhealthy(self, monkeypatch):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, _, _, _, project_id = _seed(db)
agent = A2ARemoteAgent(
project_id=project_id,
name="auth-fail-agent",
base_url="https://remote.example.com",
auth_scheme="bearer",
auth_token="bad-token",
created_by=owner_id,
)
db.add(agent)
db.commit()
db.refresh(agent)
a2a_runtime._circuit_state.pop(agent.id, None)
class _FailResp:
status_code = 401
@staticmethod
def json():
return {"detail": "unauthorized"}
class _Client401:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url, headers=None):
return _FailResp()
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401)
with pytest.raises(RuntimeError):
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
db.refresh(agent)
assert agent.healthy is False
assert agent.failure_count == 1
Base.metadata.drop_all(bind=engine)
db.close()
engine.dispose()
def test_remote_unavailable_opens_circuit(self, monkeypatch):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, _, _, _, project_id = _seed(db)
agent = A2ARemoteAgent(
project_id=project_id,
name="offline-agent",
base_url="https://offline.example.com",
auth_scheme="none",
created_by=owner_id,
)
db.add(agent)
db.commit()
db.refresh(agent)
a2a_runtime._circuit_state.pop(agent.id, None)
class _ClientDown:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url, headers=None):
raise httpx.ConnectError("network down")
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown)
for _ in range(3):
with pytest.raises(Exception):
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
db.refresh(agent)
assert agent.healthy is False
assert agent.failure_count == 3
assert agent.circuit_open_until is not None
Base.metadata.drop_all(bind=engine)
db.close()
engine.dispose()
def test_idempotency_key_deduplicates_task(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
idempotency_key = f"idem-key-{int(time.time())}"
payload1 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
resp1 = client.post("/api/v1/message:send", json=payload1)
assert resp1.status_code == 200
payload2 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
payload2["message"]["messageId"] = f"msg-{int(time.time()*1000) + 1}"
resp2 = client.post("/api/v1/message:send", json=payload2)
assert resp2.status_code == 200
assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"]
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_tenant_isolation(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "isolation test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
get_resp = client.get(f"/api/v1/tasks/{task_id}")
assert get_resp.status_code == 404
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestMetricsAdminOnly:
def test_metrics_admin_only(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, _ = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
denied = client.get("/api/v1/a2a/metrics")
assert denied.status_code == 403
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
ok = client.get("/api/v1/a2a/metrics")
assert ok.status_code == 200
assert "counters" in ok.json()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
@@ -0,0 +1,132 @@
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from app.core.data_root import get_data_root
from main import app
def _backend_data_root() -> Path:
return get_data_root()
def test_download_artifact_within_whitelist() -> None:
uploads_dir = _backend_data_root() / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
sample = uploads_dir / "task2-download.csv"
sample.write_text("id,name\n1,a\n", encoding="utf-8")
client = TestClient(app)
response = client.get("/nanobot/artifacts/download", params={"target": "local://task2-download.csv"})
assert response.status_code == 200
assert response.headers["content-type"].startswith("application/octet-stream")
assert response.headers["content-disposition"].startswith("attachment;")
assert response.content == sample.read_bytes()
def test_download_artifact_rejects_outside_paths() -> None:
client = TestClient(app)
response = client.get("/nanobot/artifacts/download", params={"target": "/etc/hosts"})
assert response.status_code == 403
assert response.json()["detail"] == "非法路径访问"
def test_preview_artifact_returns_unsupported_for_binary() -> None:
uploads_dir = _backend_data_root() / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
sample = uploads_dir / "task2-unsupported.bin"
sample.write_bytes(b"\x00\x01\x02")
client = TestClient(app)
response = client.get("/nanobot/artifacts/preview", params={"target": f"local://{sample.name}"})
assert response.status_code == 415
assert response.json()["detail"] == "当前文件类型不支持预览,请使用下载"
download = client.get("/nanobot/artifacts/download", params={"target": f"local://{sample.name}"})
assert download.status_code == 200
assert download.content == sample.read_bytes()
def test_preview_html_supports_directory_resources() -> None:
web_dir = _backend_data_root() / "workspace" / "task2-web"
web_dir.mkdir(parents=True, exist_ok=True)
html_file = web_dir / "index.html"
css_file = web_dir / "styles.css"
html_file.write_text("<html><head><link rel='stylesheet' href='styles.css'></head><body>ok</body></html>", encoding="utf-8")
css_file.write_text("body{color:#333;}", encoding="utf-8")
client = TestClient(app)
preview = client.get(
"/nanobot/artifacts/preview",
params={"target": str(html_file)},
follow_redirects=False,
)
assert preview.status_code == 307
location = preview.headers["location"]
assert location.startswith("/nanobot/artifacts/web/")
html_response = client.get(location)
assert html_response.status_code == 200
assert "text/html" in html_response.headers["content-type"]
assert "styles.css" in html_response.text
css_response = client.get(location.replace("index.html", "styles.css"))
assert css_response.status_code == 200
assert "text/css" in css_response.headers["content-type"]
assert "color:#333" in css_response.text
@pytest.mark.parametrize(
("filename", "payload", "expected_mime"),
[
("task4-image.png", b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR", "image/png"),
("task4-preview.pdf", b"%PDF-1.4\n1 0 obj\n<<>>\nendobj\n", "application/pdf"),
(
"task4-preview.pptx",
b"PK\x03\x04\x14\x00\x00\x00\x08\x00\x00\x00!\x00",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
),
],
)
def test_preview_and_download_supported_types(filename: str, payload: bytes, expected_mime: str) -> None:
uploads_dir = _backend_data_root() / "uploads"
uploads_dir.mkdir(parents=True, exist_ok=True)
sample = uploads_dir / filename
sample.write_bytes(payload)
client = TestClient(app)
preview = client.get("/nanobot/artifacts/preview", params={"target": f"local://{filename}"})
assert preview.status_code == 200
assert preview.headers["content-type"].startswith(expected_mime)
download = client.get("/nanobot/artifacts/download", params={"target": f"local://{filename}"})
assert download.status_code == 200
assert download.content == sample.read_bytes()
def test_web_preview_missing_resource_returns_error_and_download_still_works() -> None:
web_dir = _backend_data_root() / "workspace" / "task4-web-missing"
web_dir.mkdir(parents=True, exist_ok=True)
html_file = web_dir / "index.html"
html_file.write_text("<html><head><script src='missing.js'></script></head><body>ok</body></html>", encoding="utf-8")
client = TestClient(app)
preview = client.get(
"/nanobot/artifacts/preview",
params={"target": str(html_file)},
follow_redirects=False,
)
assert preview.status_code == 307
location = preview.headers["location"]
missing = client.get(location.replace("index.html", "missing.js"))
assert missing.status_code == 404
assert missing.json()["detail"] == "Web 资源不存在"
download = client.get("/nanobot/artifacts/download", params={"target": str(html_file)})
assert download.status_code == 200
assert download.content == html_file.read_bytes()
+56
View File
@@ -0,0 +1,56 @@
from pathlib import Path
from app.core.artifacts import extract_artifacts
from app.core.data_root import get_data_root
def _backend_data_root() -> Path:
return get_data_root()
def test_extract_artifacts_from_local_and_tool_paths() -> None:
data_root = _backend_data_root()
uploads_dir = data_root / "uploads"
workspace_dir = data_root / "workspace" / "reports"
uploads_dir.mkdir(parents=True, exist_ok=True)
workspace_dir.mkdir(parents=True, exist_ok=True)
upload_file = uploads_dir / "task1-sample.csv"
upload_file.write_text("a,b\n1,2\n", encoding="utf-8")
report_file = workspace_dir / "task1-report.html"
report_file.write_text("<html><body>ok</body></html>", encoding="utf-8")
content = "请下载 local://task1-sample.csv"
session_messages = [
{"role": "user", "content": "生成报告"},
{"role": "tool", "content": f"输出文件:{report_file}"},
]
artifacts = extract_artifacts(content, session_messages)
by_name = {item["name"]: item for item in artifacts}
assert "task1-sample.csv" in by_name
assert "task1-report.html" in by_name
assert by_name["task1-sample.csv"]["download_url"].startswith("/nanobot/artifacts/download?target=")
assert by_name["task1-sample.csv"]["previewable"] is True
assert by_name["task1-report.html"]["previewable"] is True
assert by_name["task1-report.html"]["preview_url"].startswith("/nanobot/artifacts/preview?target=")
def test_extract_artifacts_deduplicate_and_skip_missing() -> None:
data_root = _backend_data_root()
workspace_dir = data_root / "workspace"
workspace_dir.mkdir(parents=True, exist_ok=True)
pdf_file = workspace_dir / "task1-dedup.pdf"
pdf_file.write_bytes(b"%PDF-1.4 test")
missing_file = workspace_dir / "task1-missing.pdf"
content = f"{pdf_file} and {pdf_file} and {missing_file}"
artifacts = extract_artifacts(content, [])
assert len(artifacts) == 1
item = artifacts[0]
assert item["name"] == "task1-dedup.pdf"
assert item["mime_type"] == "application/pdf"
assert item["previewable"] is True
+145
View File
@@ -0,0 +1,145 @@
import asyncio
import sys
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
import main
def test_nanobot_chat_syncs_project_id(monkeypatch) -> None:
calls: list[dict[str, object]] = []
process_kwargs: list[dict[str, object]] = []
def fake_update_alias_meta(**kwargs):
calls.append(kwargs)
return kwargs
async def fake_process_message(*args, **kwargs):
process_kwargs.append(kwargs)
return "ok"
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", None)
request = main.ChatRequest(message="hello", session_id="api:test-1", project_id=101)
response = asyncio.run(main.nanobot_chat(request))
assert response["response"] == "ok"
assert calls == [{"session_key": "api:test-1", "project_id": 101}]
assert process_kwargs and process_kwargs[0]["project_id"] == 101
def test_nanobot_chat_without_project_id_does_not_sync(monkeypatch) -> None:
calls: list[dict[str, object]] = []
process_kwargs: list[dict[str, object]] = []
def fake_update_alias_meta(**kwargs):
calls.append(kwargs)
return kwargs
async def fake_process_message(*args, **kwargs):
process_kwargs.append(kwargs)
return "ok"
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", None)
request = main.ChatRequest(message="hello", session_id="api:test-2")
response = asyncio.run(main.nanobot_chat(request))
assert response["response"] == "ok"
assert calls == []
assert process_kwargs and process_kwargs[0]["project_id"] is None
def test_nanobot_chat_stream_syncs_project_id(monkeypatch) -> None:
calls: list[dict[str, object]] = []
process_kwargs: list[dict[str, object]] = []
def fake_update_alias_meta(**kwargs):
calls.append(kwargs)
return kwargs
async def fake_process_message(*args, **kwargs):
process_kwargs.append(kwargs)
on_stream = kwargs.get("on_stream")
if on_stream:
await on_stream("stream-token")
return "stream-complete"
async def collect_stream_chunks(response) -> list[str]:
chunks: list[str] = []
async for chunk in response.body_iterator:
if isinstance(chunk, bytes):
chunks.append(chunk.decode("utf-8"))
else:
chunks.append(chunk)
return chunks
monkeypatch.setattr(main.session_alias_store, "update_alias_meta", fake_update_alias_meta)
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", None)
request = main.ChatRequest(message="hello", session_id="api:test-3", project_id=202)
response = asyncio.run(main.nanobot_chat_stream(request))
chunks = asyncio.run(collect_stream_chunks(response))
content = "".join(chunks)
assert "stream-token" in content
assert "stream-complete" in content
assert calls == [{"session_key": "api:test-3", "project_id": 202}]
assert process_kwargs and process_kwargs[0]["project_id"] == 202
def test_nanobot_chat_stream_emits_reasoning_flags_and_final_reasoning(monkeypatch) -> None:
async def fake_process_message(*args, **kwargs):
on_progress = kwargs.get("on_progress")
on_stream = kwargs.get("on_stream")
if on_progress:
await on_progress("模型正在拆解问题", is_reasoning=True)
await on_progress("开始执行工具", tool_hint=True)
if on_stream:
await on_stream("answer-token")
return "final-answer"
class _DummySession:
def __init__(self):
self.metadata = {}
self.messages = [
{"role": "assistant", "content": "final-answer", "reasoning_content": "完整思考过程"}
]
class _DummySessions:
def get_or_create(self, _key):
return _DummySession()
class _DummyAgent:
def __init__(self):
self.sessions = _DummySessions()
async def collect_stream_chunks(response) -> list[str]:
chunks: list[str] = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk)
return chunks
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
request = main.ChatRequest(message="hello", session_id="api:test-4", project_id=303)
response = asyncio.run(main.nanobot_chat_stream(request))
chunks = asyncio.run(collect_stream_chunks(response))
content = "".join(chunks)
assert '"type": "progress", "content": "模型正在拆解问题", "is_reasoning": true' in content
assert '"type": "progress", "content": "开始执行工具", "tool_hint": true' in content
assert '"type": "final", "content": "final-answer", "reasoning_content": "完整思考过程"' in content
+28
View File
@@ -0,0 +1,28 @@
from pathlib import Path
from app.core import data_root
def test_data_root_prefers_env(monkeypatch, tmp_path: Path) -> None:
custom = tmp_path / "custom-data-root"
monkeypatch.setenv("DATA_ROOT", str(custom))
assert data_root.get_data_root() == custom.resolve()
def test_data_root_falls_back_to_legacy(monkeypatch, tmp_path: Path) -> None:
monkeypatch.delenv("DATA_ROOT", raising=False)
legacy = tmp_path / "legacy-data"
default = tmp_path / "default-data"
legacy.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(data_root, "LEGACY_DATA_ROOT", legacy)
monkeypatch.setattr(data_root, "DEFAULT_DATA_ROOT", default)
assert data_root.get_data_root() == legacy
def test_ensure_data_layout_creates_children(monkeypatch, tmp_path: Path) -> None:
monkeypatch.setenv("DATA_ROOT", str(tmp_path / "dr"))
data_root.ensure_data_layout()
root = data_root.get_data_root()
assert (root / "workspace").exists()
assert (root / "uploads").exists()
assert (root / "data").exists()
+102
View File
@@ -0,0 +1,102 @@
import json
import sys
from importlib import import_module
from pathlib import Path
from typer.testing import CliRunner
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
app = import_module("app.cli").app
runner = CliRunner()
class _FakeProcess:
def __init__(self, pid: int = 9527, exit_code: int | None = None) -> None:
self.pid = pid
self._exit_code = exit_code
def poll(self):
return self._exit_code
def test_start_command_writes_state(monkeypatch, tmp_path) -> None:
pid_file = tmp_path / "run" / "state.json"
log_file = tmp_path / "run" / "service.log"
monkeypatch.setattr("app.cli.subprocess.Popen", lambda *args, **kwargs: _FakeProcess())
monkeypatch.setattr("app.cli._wait_for_server_ready", lambda *_args, **_kwargs: True)
result = runner.invoke(
app,
[
"start",
"--host",
"127.0.0.1",
"--port",
"18999",
"--pid-file",
str(pid_file),
"--log-file",
str(log_file),
],
)
assert result.exit_code == 0
assert "已启动" in result.stdout
assert pid_file.exists()
state = json.loads(pid_file.read_text(encoding="utf-8"))
assert state["pid"] == 9527
assert state["host"] == "127.0.0.1"
assert state["port"] == 18999
def test_status_command_reports_running(monkeypatch, tmp_path) -> None:
pid_file = tmp_path / "run" / "state.json"
pid_file.parent.mkdir(parents=True, exist_ok=True)
pid_file.write_text(
json.dumps({"pid": 9527, "host": "127.0.0.1", "port": 18080}, ensure_ascii=False),
encoding="utf-8",
)
monkeypatch.setattr("app.cli._is_process_running", lambda pid: pid == 9527)
result = runner.invoke(app, ["status", "--pid-file", str(pid_file)])
assert result.exit_code == 0
assert "running" in result.stdout
assert "127.0.0.1:18080" in result.stdout
def test_stop_command_cleans_state(monkeypatch, tmp_path) -> None:
pid_file = tmp_path / "run" / "state.json"
pid_file.parent.mkdir(parents=True, exist_ok=True)
pid_file.write_text(json.dumps({"pid": 9527}, ensure_ascii=False), encoding="utf-8")
monkeypatch.setattr("app.cli._is_process_running", lambda pid: pid == 9527)
monkeypatch.setattr("app.cli._stop_pid", lambda pid, timeout: pid == 9527)
result = runner.invoke(app, ["stop", "--pid-file", str(pid_file)])
assert result.exit_code == 0
assert "已停止" in result.stdout
assert not pid_file.exists()
def test_status_command_cleans_stale_state(monkeypatch, tmp_path) -> None:
pid_file = tmp_path / "run" / "state.json"
pid_file.parent.mkdir(parents=True, exist_ok=True)
pid_file.write_text(json.dumps({"pid": 9527}, ensure_ascii=False), encoding="utf-8")
monkeypatch.setattr("app.cli._is_process_running", lambda _pid: False)
result = runner.invoke(app, ["status", "--pid-file", str(pid_file)])
assert result.exit_code == 0
assert "stopped" in result.stdout
assert not pid_file.exists()
@@ -0,0 +1,376 @@
import asyncio
import sys
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from fastapi.testclient import TestClient
import main
from app.context import current_knowledge_base_id
from app.schemas.knowledge import KnowledgeSearchResponse
from app.tools.knowledge_base import KnowledgeBaseRetrieveTool
def test_knowledge_base_crud_and_search_routes(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
client = TestClient(main.app)
create_resp = client.post(
"/api/v1/knowledge-bases",
json={"name": "产品手册", "description": "用于问答", "top_k": 2, "chunk_size": 256, "chunk_overlap": 20},
)
assert create_resp.status_code == 200
kb = create_resp.json()
kb_id = kb["id"]
list_resp = client.get("/api/v1/knowledge-bases")
assert list_resp.status_code == 200
assert any(item["id"] == kb_id for item in list_resp.json())
doc_resp = client.post(
f"/api/v1/knowledge-bases/{kb_id}/documents",
json={"title": "退款规则", "content": "苹果手机支持7天无理由退款", "metadata": {"lang": "zh"}},
)
assert doc_resp.status_code == 200
doc_id = doc_resp.json()["id"]
reindex_resp = client.post(f"/api/v1/knowledge-bases/{kb_id}/reindex")
assert reindex_resp.status_code == 200
search_resp = client.post(f"/api/v1/knowledge-bases/{kb_id}/search", json={"query": "苹果退款", "top_k": 2})
assert search_resp.status_code == 200
parsed = KnowledgeSearchResponse(**search_resp.json())
assert parsed.hits
assert "苹果" in parsed.answer
update_resp = client.put(f"/api/v1/knowledge-bases/{kb_id}", json={"name": "售后知识库"})
assert update_resp.status_code == 200
assert update_resp.json()["name"] == "售后知识库"
delete_doc_resp = client.delete(f"/api/v1/knowledge-bases/{kb_id}/documents/{doc_id}")
assert delete_doc_resp.status_code == 200
delete_kb_resp = client.delete(f"/api/v1/knowledge-bases/{kb_id}")
assert delete_kb_resp.status_code == 200
def test_knowledge_global_config_mask_and_validation(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
client = TestClient(main.app)
initial_resp = client.get("/api/v1/knowledge-bases/global-config")
assert initial_resp.status_code == 200
assert initial_resp.json() == {
"api_base": None,
"api_key": None,
"api_key_masked": None,
"has_api_key": False,
"default_embedding_model": None,
}
update_resp = client.put(
"/api/v1/knowledge-bases/global-config",
json={"api_base": "https://kb.example.com/", "api_key": "sk-knowledge-secret", "default_embedding_model": "text-embedding-3-small"},
)
assert update_resp.status_code == 200
body = update_resp.json()
assert body["api_base"] == "https://kb.example.com"
assert body["api_key"] is None
assert body["has_api_key"] is True
assert body["api_key_masked"] == "sk-k***********cret"
assert body["default_embedding_model"] == "text-embedding-3-small"
get_resp = client.get("/api/v1/knowledge-bases/global-config")
assert get_resp.status_code == 200
assert get_resp.json()["api_key"] is None
assert get_resp.json()["api_key_masked"] == "sk-k***********cret"
assert get_resp.json()["default_embedding_model"] == "text-embedding-3-small"
invalid_resp = client.put("/api/v1/knowledge-bases/global-config", json={"api_base": "ftp://kb.example.com"})
assert invalid_resp.status_code == 422
def test_chat_request_syncs_knowledge_base_metadata(monkeypatch) -> None:
captured_kb_ids: list[str | None] = []
captured_messages: list[str] = []
class _DummySession:
def __init__(self):
self.metadata = {}
self.messages = []
self.updated_at = None
class _DummySessions:
def __init__(self):
self._sessions: dict[str, _DummySession] = {}
def get_or_create(self, key: str):
if key not in self._sessions:
self._sessions[key] = _DummySession()
return self._sessions[key]
def save(self, _session):
return None
class _DummyAgent:
def __init__(self):
self.sessions = _DummySessions()
async def fake_process_message(*args, **kwargs):
captured_kb_ids.append(current_knowledge_base_id.get())
if args and isinstance(args[0], str):
captured_messages.append(args[0])
return "ok"
def fake_search(*, kb_id: str, query: str, top_k=None):
assert kb_id == "kb-123"
assert query == "请回答售后规则"
return {
"answer": "命中结果",
"hits": [
{"doc_id": "d1", "title": "退款规则", "chunk": "7天无理由退款", "score": 0.9, "metadata": {}},
{"doc_id": "d2", "title": "售后电话", "chunk": "客服电话 400-1234", "score": 0.7, "metadata": {}},
],
}
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr("main.knowledge_index_service.search", fake_search)
request = main.ChatRequest(message="请回答售后规则", session_id="api:kb-1", knowledge_base_id="kb-123")
response = asyncio.run(main.nanobot_chat(request))
assert response["response"] == "ok"
assert "kb_citations" in response
assert len(response["kb_citations"]) == 2
assert response["kb_citations"][0]["title"] == "退款规则"
assert captured_kb_ids == ["kb-123"]
assert captured_messages and "7天无理由退款" in captured_messages[0]
session = main.nanobot_service.agent.sessions.get_or_create("api:kb-1")
assert session.metadata["selected_knowledge_base_id"] == "kb-123"
def test_knowledge_tool_uses_session_context(monkeypatch) -> None:
tool = KnowledgeBaseRetrieveTool()
token = current_knowledge_base_id.set("kb-session")
called: list[dict] = []
def fake_search(*, kb_id: str, query: str, top_k=None):
called.append({"kb_id": kb_id, "query": query, "top_k": top_k})
return {
"answer": "命中结果",
"hits": [{"doc_id": "d1", "title": "t1", "chunk": "命中结果", "score": 1.0, "metadata": {}}],
}
monkeypatch.setattr("app.tools.knowledge_base.knowledge_index_service.search", fake_search)
try:
output = asyncio.run(tool.execute(query="售后政策"))
finally:
current_knowledge_base_id.reset(token)
assert called and called[0]["kb_id"] == "kb-session"
assert "命中结果" in output
def test_update_session_context_file_supports_knowledge_base(monkeypatch) -> None:
class _DummySession:
def __init__(self):
self.metadata = {}
self.updated_at = None
class _DummySessions:
def __init__(self):
self.session = _DummySession()
def get_or_create(self, _key: str):
return self.session
def save(self, _session):
return None
class _DummyAgent:
def __init__(self):
self.sessions = _DummySessions()
monkeypatch.setattr(main.nanobot_service, "agent", _DummyAgent())
payload = main.SessionFileContextUpdateRequest(selected_knowledge_base_id="kb-ctx")
response = main.update_session_context_file("api:ctx", payload)
assert response["status"] == "success"
assert response["metadata"]["selected_knowledge_base_id"] == "kb-ctx"
def test_knowledge_global_connection_test_route(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
class _DummyEmbeddingData:
embedding = [0.1, 0.2, 0.3]
class _DummyEmbeddingResp:
data = [_DummyEmbeddingData()]
class _DummyEmbeddingsAPI:
@staticmethod
def create(model: str, input: str):
assert model == "text-embedding-3-small"
assert input == "connection test"
return _DummyEmbeddingResp()
class _DummyOpenAI:
def __init__(self, api_key: str, base_url: str):
assert api_key == "sk-knowledge-secret"
assert base_url == "https://kb.example.com"
self.embeddings = _DummyEmbeddingsAPI()
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)
client = TestClient(main.app)
save_resp = client.put(
"/api/v1/knowledge-bases/global-config",
json={
"api_base": "https://kb.example.com",
"api_key": "sk-knowledge-secret",
"default_embedding_model": "text-embedding-3-small",
},
)
assert save_resp.status_code == 200
test_resp = client.post(
"/api/v1/knowledge-bases/global-config/test-connection",
json={"model_name": "text-embedding-3-small"},
)
assert test_resp.status_code == 200
body = test_resp.json()
assert body["success"] is True
assert body["model_name"] == "text-embedding-3-small"
assert body["embedding_dimension"] == 3
assert body["resolved_api_base"] == "https://kb.example.com"
assert body["available_models"] == []
def test_knowledge_global_connection_test_route_requires_model_name(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
client = TestClient(main.app)
resp = client.post(
"/api/v1/knowledge-bases/global-config/test-connection",
json={
"api_base": "https://api.siliconflow.cn/v1/embeddings",
"api_key": "ark-key",
},
)
assert resp.status_code == 400
assert "测试连接必须显式填写向量模型名称" in resp.json()["detail"]
def test_knowledge_global_connection_test_route_returns_remote_error(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
class _DummyEmbeddingsAPI:
@staticmethod
def create(model: str, input: str):
assert model == "BAAI/bge-large-zh-v1.5"
assert input == "connection test"
raise RuntimeError("Not Found")
class _DummyOpenAI:
def __init__(self, api_key: str, base_url: str):
assert api_key == "sf-key"
assert base_url == "https://api.siliconflow.cn/v1"
self.embeddings = _DummyEmbeddingsAPI()
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
monkeypatch.setattr("app.api.knowledge.OpenAI", _DummyOpenAI)
client = TestClient(main.app)
resp = client.post(
"/api/v1/knowledge-bases/global-config/test-connection",
json={
"api_base": "https://api.siliconflow.cn/v1/embeddings",
"api_key": "sf-key",
"model_name": "BAAI/bge-large-zh-v1.5",
},
)
assert resp.status_code == 400
assert "Embedding调用失败" in resp.json()["detail"]
def test_knowledge_document_upload_route(monkeypatch, tmp_path) -> None:
async def fake_start():
return None
async def fake_stop():
return None
monkeypatch.setenv("DATA_ROOT", str(tmp_path))
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
client = TestClient(main.app)
create_resp = client.post(
"/api/v1/knowledge-bases",
json={"name": "上传测试库", "description": "用于上传", "top_k": 2, "chunk_size": 256, "chunk_overlap": 20},
)
assert create_resp.status_code == 200
kb_id = create_resp.json()["id"]
files = [
("files", ("doc1.txt", b"hello knowledge", "text/plain")),
("files", ("doc2.md", b"# title\ncontent", "text/markdown")),
]
upload_resp = client.post(
f"/api/v1/knowledge-bases/{kb_id}/documents/upload",
files=files,
data={"metadata": "{\"source\":\"batch\"}"},
)
assert upload_resp.status_code == 200
body = upload_resp.json()
assert body["status"] == "success"
assert body["count"] == 2
assert len(body["documents"]) == 2
list_resp = client.get(f"/api/v1/knowledge-bases/{kb_id}/documents")
assert list_resp.status_code == 200
docs = list_resp.json()
assert len(docs) == 2
@@ -0,0 +1,110 @@
import asyncio
from types import SimpleNamespace
from app.core.nanobot import NanobotIntegration
from app.context import current_session_id
class _DummySessions:
def __init__(self) -> None:
self.saved = []
self._session = SimpleNamespace(messages=[])
def get_or_create(self, _session_id: str):
return self._session
def save(self, session) -> None:
self.saved.append(session)
class _DummyAgent:
def __init__(self) -> None:
self.sessions = _DummySessions()
self.provider = SimpleNamespace(default_model="demo-model")
self.model = "demo-model"
async def process_direct(self, *_args, **_kwargs):
return "ok"
def test_process_message_project_id_fallback_from_session_alias(monkeypatch) -> None:
service = NanobotIntegration()
base_agent = _DummyAgent()
custom_agent = _DummyAgent()
service.agent = base_agent
service._started = True
captured: dict[str, object] = {}
async def fake_get_or_create_model_agent(model_id, target_config, project_id):
captured["project_id"] = project_id
return custom_agent
monkeypatch.setattr(service, "_get_or_create_model_agent", fake_get_or_create_model_agent)
monkeypatch.setattr("app.core.nanobot.get_llm_configs", lambda: [])
monkeypatch.setattr("app.core.nanobot.get_active_llm_config", lambda: None)
monkeypatch.setattr(
"app.core.session_alias_store.session_alias_store.get_alias_meta",
lambda _session_id: {"project_id": 77},
)
response = asyncio.run(service.process_message("hello", session_id="api:s1"))
assert response == "ok"
assert captured["project_id"] == 77
def test_process_message_project_id_prefers_request_value(monkeypatch) -> None:
service = NanobotIntegration()
base_agent = _DummyAgent()
custom_agent = _DummyAgent()
service.agent = base_agent
service._started = True
captured: dict[str, object] = {}
async def fake_get_or_create_model_agent(model_id, target_config, project_id):
captured["project_id"] = project_id
return custom_agent
monkeypatch.setattr(service, "_get_or_create_model_agent", fake_get_or_create_model_agent)
monkeypatch.setattr("app.core.nanobot.get_llm_configs", lambda: [])
monkeypatch.setattr("app.core.nanobot.get_active_llm_config", lambda: None)
monkeypatch.setattr(
"app.core.session_alias_store.session_alias_store.get_alias_meta",
lambda _session_id: {"project_id": 88},
)
response = asyncio.run(service.process_message("hello", session_id="api:s2", project_id=9))
assert response == "ok"
assert captured["project_id"] == 9
def test_register_custom_tools_always_contains_subagent_tools() -> None:
service = NanobotIntegration()
names: list[str] = []
class _ToolRegistry:
def register(self, tool) -> None:
names.append(tool.name)
fake_agent = SimpleNamespace(tools=_ToolRegistry())
service._register_custom_tools(fake_agent, project_id=None)
assert "list_subagents" in names
assert "invoke_subagent" in names
def test_subagent_tool_resolves_project_from_session_alias(monkeypatch) -> None:
from app.tools.subagent import _resolve_project_id
token = current_session_id.set("api:subagent-test")
try:
monkeypatch.setattr(
"app.tools.subagent.session_alias_store.get_alias_meta",
lambda _session_id: {"project_id": 66},
)
assert _resolve_project_id(None) == 66
finally:
current_session_id.reset(token)
+12
View File
@@ -0,0 +1,12 @@
import asyncio
import json
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest
async def main():
req = NL2SQLRequest(query="列出所有表", source="postgres", generate_chart=False)
res = await process_nl2sql(req)
print("SQL:", res.sql)
print("Error:", res.error)
print("Result:", res.result)
asyncio.run(main())
@@ -0,0 +1,68 @@
import pytest
import json
from unittest.mock import AsyncMock, MagicMock, patch
from app.agent.nl2sql import process_nl2sql, NL2SQLRequest, NL2SQLResponse
@pytest.mark.asyncio
async def test_nl2sql_optimized_flow():
# Mock parameters
query = "Show me the top 5 sales"
source = "ds:1"
# Mock connector and schema
mock_connector = MagicMock()
mock_connector.get_schema.return_value = {
"sales": {"columns": [{"name": "id", "type": "INT"}, {"name": "amount", "type": "DECIMAL"}]},
"users": {"columns": [{"name": "id", "type": "INT"}, {"name": "name", "type": "TEXT"}]},
"logs": {"columns": [{"name": "id", "type": "INT"}, {"name": "event", "type": "TEXT"}]},
"products": {"columns": [{"name": "id", "type": "INT"}]},
"categories": {"columns": [{"name": "id", "type": "INT"}]},
"inventory": {"columns": [{"name": "id", "type": "INT"}]}
}
mock_connector.test_connection.return_value = True
mock_connector.execute_query.return_value = [{"id": 1, "amount": 100}]
# Mock LLM provider
mock_provider = AsyncMock()
# First response for Table Selector
mock_resp_tables = MagicMock()
mock_resp_tables.content = '["sales"]'
mock_resp_tables.finish_reason = "stop"
# Second response for SQL Generation
mock_resp_sql = MagicMock()
mock_resp_sql.content = '{"reasoning": "Plan...", "sql": "SELECT * FROM sales LIMIT 5"}'
mock_resp_sql.finish_reason = "stop"
mock_provider.chat.side_effect = [mock_resp_tables, mock_resp_sql]
# Patch dependencies
with patch("app.agent.nl2sql.get_active_llm_config", return_value={"model": "gpt-4"}), \
patch("app.agent.nl2sql.build_llm_provider", return_value=mock_provider), \
patch("app.agent.nl2sql.get_connector", return_value=mock_connector), \
patch("app.agent.nl2sql.SessionLocal"), \
patch("app.agent.nl2sql.DataSource"), \
patch("app.agent.nl2sql.postgres_connector", mock_connector), \
patch("app.agent.nl2sql._check_connection_with_cache", return_value=True):
request = NL2SQLRequest(query=query, source=source)
response = await process_nl2sql(request)
print(f"DEBUG: Response SQL: '{response.sql}'")
print(f"DEBUG: Response Error: '{response.error}'")
assert response.sql == "SELECT * FROM sales LIMIT 5"
assert len(response.result) == 1
assert response.error is None
# Verify provider was called twice
assert mock_provider.chat.call_count == 2
# Verify first call was for table selection
args, kwargs = mock_provider.chat.call_args_list[0]
assert "TABLE_SELECTOR_SYSTEM_PROMPT" in str(args) or "Identifying relevant tables" in str(args) or any("system" in m["role"] for m in kwargs["messages"])
if __name__ == "__main__":
import asyncio
asyncio.run(test_nl2sql_optimized_flow())
@@ -0,0 +1,62 @@
import sys
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from app.core.llm_provider import build_llm_provider
from app.core.nanobot import NanobotIntegration
from app.core.patched_openai_compat_provider import PatchedOpenAICompatProvider
def test_build_llm_provider_uses_max_completion_tokens_for_gpt5() -> None:
provider = build_llm_provider(
model="gpt-5.4-nano",
provider="openai",
api_key="test-key",
api_base="https://example.com/v1",
)
assert isinstance(provider, PatchedOpenAICompatProvider)
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hello"}],
tools=None,
model="gpt-5.4-nano",
max_tokens=5,
temperature=0,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["max_completion_tokens"] == 5
assert "max_tokens" not in kwargs
def test_nanobot_provider_keeps_max_tokens_for_legacy_models() -> None:
integration = NanobotIntegration()
provider = integration._build_provider(
model="gpt-4o-mini",
provider_name="openai",
api_key="test-key",
api_base="https://example.com/v1",
extra_headers=None,
)
assert isinstance(provider, PatchedOpenAICompatProvider)
kwargs = provider._build_kwargs(
messages=[{"role": "user", "content": "hello"}],
tools=None,
model="gpt-4o-mini",
max_tokens=5,
temperature=0,
reasoning_effort=None,
tool_choice=None,
)
assert kwargs["max_tokens"] == 5
assert "max_completion_tokens" not in kwargs
@@ -0,0 +1,102 @@
import sys
from collections.abc import Generator
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
from app.models.project import Project
from app.models.subagent import Subagent
from app.models.user import User
from main import app
def _seed_subagent(db: Session) -> tuple[User, Project, Subagent]:
user = User(
username="task3-owner",
email="task3-owner@example.com",
hashed_password="test",
is_admin=False,
)
db.add(user)
db.commit()
db.refresh(user)
project = Project(
name="task3-project",
description="task3",
owner_id=user.id,
)
db.add(project)
db.commit()
db.refresh(project)
subagent = Subagent(
project_id=project.id,
name="task3-subagent",
description="task3",
instructions="do task3",
model="gpt",
)
db.add(subagent)
db.commit()
db.refresh(subagent)
return user, project, subagent
def test_subagent_detail_route_is_global_and_project_scoped_route_is_invalid() -> None:
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
user, project, subagent = _seed_subagent(db)
user_id = user.id
username = user.username
project_id = project.id
subagent_id = subagent.id
db.close()
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return CurrentUser(id=user_id, username=username, is_admin=False)
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
response = client.get(f"/api/v1/subagents/{subagent_id}")
assert response.status_code == 200
body = response.json()
assert body["id"] == subagent_id
assert body["project_id"] == project_id
legacy_path_response = client.get(f"/api/v1/projects/{project_id}/subagents/{subagent_id}")
assert legacy_path_response.status_code == 404
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
@@ -0,0 +1,136 @@
import asyncio
import json
import sys
from collections.abc import Generator
from pathlib import Path
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from app.context import current_session_id
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
from app.models.project import Project
from app.models.user import User
from app.tools.subagent import InvokeSubagentTool, ListSubagentsTool
from main import app
def _seed_owner_and_project(db: Session) -> tuple[User, Project]:
user = User(
username="task4-owner",
email="task4-owner@example.com",
hashed_password="test",
is_admin=False,
)
db.add(user)
db.commit()
db.refresh(user)
project = Project(
name="task4-project",
description="task4",
owner_id=user.id,
)
db.add(project)
db.commit()
db.refresh(project)
return user, project
def test_create_subagent_then_list_and_invoke_success(monkeypatch) -> None:
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
user, project = _seed_owner_and_project(db)
user_id = user.id
username = user.username
project_id = project.id
db.close()
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return CurrentUser(id=user_id, username=username, is_admin=False)
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
token = current_session_id.set("api:task4-regression")
captured: dict[str, object] = {}
async def fake_process_message(message, session_id, project_id, model_id):
captured["message"] = message
captured["session_id"] = session_id
captured["project_id"] = project_id
captured["model_id"] = model_id
return "invoke-ok"
try:
monkeypatch.setattr("app.tools.subagent.SessionLocal", testing_session_local)
monkeypatch.setattr(
"app.tools.subagent.session_alias_store.get_alias_meta",
lambda _session_id: {"project_id": project_id},
)
monkeypatch.setattr("app.tools.subagent.get_llm_configs", lambda: [])
monkeypatch.setattr("app.tools.subagent.get_active_llm_config", lambda: None)
monkeypatch.setattr("app.tools.subagent.nanobot_service.process_message", fake_process_message)
client = TestClient(app)
create_response = client.post(
f"/api/v1/projects/{project_id}/subagents",
json={
"name": "task4-subagent",
"description": "task4-desc",
"instructions": "focus on regression",
"model": "gpt-x",
},
)
assert create_response.status_code == 200
created = create_response.json()
assert created["project_id"] == project_id
assert created["name"] == "task4-subagent"
listed = asyncio.run(ListSubagentsTool().execute())
listed_payload = json.loads(listed)
assert len(listed_payload) == 1
assert listed_payload[0]["name"] == "task4-subagent"
assert listed_payload[0]["description"] == "task4-desc"
invoke_result = asyncio.run(
InvokeSubagentTool().execute(
subagent_name="task4-subagent",
task="run regression task",
)
)
assert "completed the task" in invoke_result
assert "invoke-ok" in invoke_result
assert captured["project_id"] == project_id
assert captured["session_id"] == f"api:task4-regression:subagent:{created['id']}"
assert "focus on regression" in str(captured["message"])
finally:
current_session_id.reset(token)
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
@@ -0,0 +1,115 @@
import asyncio
import os
import sys
from pathlib import Path
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
import main
from app.trace.attributes import build_chat_trace_attributes, sanitize_attributes
from app.trace.service import TraceService
def test_trace_service_initialize_without_keys(monkeypatch) -> None:
monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False)
monkeypatch.delenv("LANGFUSE_SECRET_KEY", raising=False)
service = TraceService()
assert service.initialize() is False
assert service.enabled is False
assert service.initialized is True
def test_trace_attribute_helpers() -> None:
attrs = sanitize_attributes(
{
"session_id": "api:test",
"project_id": 1,
"skip_none": None,
"obj": {"a": 1},
}
)
assert attrs["session_id"] == "api:test"
assert attrs["project_id"] == 1
assert "skip_none" not in attrs
assert attrs["obj"] == "{'a': 1}"
chat_attrs = build_chat_trace_attributes(
session_id="api:test",
project_id=9,
model_id="model-a",
route_mode="auto",
source="postgres",
knowledge_base_id=None,
)
assert chat_attrs["component"] == "chat_stream"
assert chat_attrs["session_id"] == "api:test"
assert chat_attrs["project_id"] == 9
def test_nanobot_chat_stream_uses_trace_span(monkeypatch) -> None:
calls: list[tuple[str, dict]] = []
updates: list[dict] = []
trace_updates: list[dict] = []
class _Span:
def set_attributes(self, attributes):
updates.append(attributes)
def update(self, **kwargs):
updates.append(kwargs)
def update_trace(self, **kwargs):
trace_updates.append(kwargs)
def record_error(self, _exc, *, stage: str = "unknown"):
updates.append({"stage": stage})
class _SpanCtx:
def __init__(self, name: str, attributes: dict):
self._name = name
self._attributes = attributes
def __enter__(self):
calls.append((self._name, self._attributes))
return _Span()
def __exit__(self, exc_type, exc, tb):
return False
def fake_start_span(name: str, *, attributes=None, input_payload=None):
payload = dict(attributes or {})
if input_payload is not None:
payload["input_payload"] = input_payload
return _SpanCtx(name, payload)
async def fake_process_message(*args, **kwargs):
on_stream = kwargs.get("on_stream")
if on_stream:
await on_stream("token")
return "ok"
async def collect_stream_chunks(response) -> list[str]:
chunks: list[str] = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk)
return chunks
monkeypatch.setattr(main.trace_service, "start_span", fake_start_span)
monkeypatch.setattr(main.nanobot_service, "process_message", fake_process_message)
monkeypatch.setattr(main.nanobot_service, "agent", None)
request = main.ChatRequest(message="hello", session_id="api:trace-test", project_id=7)
response = asyncio.run(main.nanobot_chat_stream(request))
chunks = asyncio.run(collect_stream_chunks(response))
content = "".join(chunks)
assert "token" in content
assert "ok" in content
assert calls
assert calls[0][0] == "chat.stream"
assert trace_updates and trace_updates[0]["session_id"] == "api:trace-test"
@@ -0,0 +1,78 @@
from pathlib import Path
import sys
from fastapi.staticfiles import StaticFiles
from fastapi.testclient import TestClient
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
import main
def _prepare_webui(monkeypatch, tmp_path: Path) -> None:
webui_dir = tmp_path / "webui"
assets_dir = webui_dir / "assets"
assets_dir.mkdir(parents=True, exist_ok=True)
(webui_dir / "index.html").write_text("<html><body>dataclaw-webui</body></html>", encoding="utf-8")
(assets_dir / "app.js").write_text("window.__TASK2__=true;", encoding="utf-8")
monkeypatch.setattr(main, "_WEBUI_DIR", webui_dir)
monkeypatch.setattr(main, "_WEBUI_INDEX", webui_dir / "index.html")
monkeypatch.setattr(main, "_WEBUI_STATIC", StaticFiles(directory=str(webui_dir), html=False))
def _prepare_lifecycle(monkeypatch) -> None:
async def fake_start():
return None
async def fake_stop():
return None
monkeypatch.setattr(main.nanobot_service, "start", fake_start)
monkeypatch.setattr(main.nanobot_service, "stop", fake_stop)
def test_webui_static_assets_served_from_backend(monkeypatch, tmp_path) -> None:
_prepare_webui(monkeypatch, tmp_path)
_prepare_lifecycle(monkeypatch)
client = TestClient(main.app)
index_resp = client.get("/")
assert index_resp.status_code == 200
assert "dataclaw-webui" in index_resp.text
asset_resp = client.get("/assets/app.js")
assert asset_resp.status_code == 200
assert "window.__TASK2__=true;" in asset_resp.text
def test_spa_route_fallback_to_index_html(monkeypatch, tmp_path) -> None:
_prepare_webui(monkeypatch, tmp_path)
_prepare_lifecycle(monkeypatch)
client = TestClient(main.app)
spa_resp = client.get("/settings/users")
assert spa_resp.status_code == 200
assert "dataclaw-webui" in spa_resp.text
missing_asset_resp = client.get("/assets/missing.js")
assert missing_asset_resp.status_code == 404
def test_backend_accessible_without_frontend_dev_server(monkeypatch, tmp_path) -> None:
_prepare_webui(monkeypatch, tmp_path)
_prepare_lifecycle(monkeypatch)
client = TestClient(main.app)
ui_resp = client.get("/")
assert ui_resp.status_code == 200
assert "dataclaw-webui" in ui_resp.text
api_resp = client.get("/nanobot/status")
assert api_resp.status_code == 200
assert api_resp.json()["status"] in {"running", "stopped"}