Files

873 lines
34 KiB
Python
Raw Permalink Normal View History

2026-04-01 11:21:55 +08:00
import asyncio
2026-04-04 07:24:09 +08:00
import hashlib
import hmac
import json
2026-04-01 11:21:55 +08:00
import sys
2026-04-04 07:24:09 +08:00
import time
2026-04-01 11:21:55 +08:00
from collections.abc import Generator
2026-04-04 07:24:09 +08:00
from datetime import datetime
2026-04-01 11:21:55 +08:00
from pathlib import Path
2026-04-04 07:24:09 +08:00
from typing import Any, Dict, Optional
2026-04-01 11:21:55 +08:00
import httpx
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
# Make the backend package (``app``, ``main``) and the sibling ``nanobot``
# checkout importable when the tests run from any working directory.
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
# Prepend in this order so NANOBOT_ROOT ends up ahead of BACKEND_ROOT when
# both were missing; paths already present are left where they are.
for _extra_root in (BACKEND_ROOT, NANOBOT_ROOT):
    if str(_extra_root) not in sys.path:
        sys.path.insert(0, str(_extra_root))
del _extra_root
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
2026-04-04 07:24:09 +08:00
from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState
2026-04-01 11:21:55 +08:00
from app.models.project import Project
from app.models.user import User
2026-04-04 07:24:09 +08:00
from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode
from app.services.a2a_service import a2a_runtime, SharedSecretAuth
2026-04-01 11:21:55 +08:00
from main import app
def _seed(db: Session) -> tuple[int, str, int, str, int]:
    """Insert the fixture rows shared by the A2A tests.

    Creates two non-admin users (``a2a_owner`` and ``a2a_other``) and one
    project owned by the first user, committing after each step.

    Returns:
        ``(owner_id, owner_username, other_id, other_username, project_id)``.
    """
    owner = User(username="a2a_owner", email="a2a_owner@example.com", hashed_password="x", is_admin=False)
    other = User(username="a2a_other", email="a2a_other@example.com", hashed_password="x", is_admin=False)
    db.add_all([owner, other])
    db.commit()
    for user in (owner, other):
        db.refresh(user)

    # The project needs the owner's generated primary key, hence the commit above.
    project = Project(name="a2a_project", description="a2a", owner_id=owner.id)
    db.add(project)
    db.commit()
    db.refresh(project)
    return (owner.id, owner.username, other.id, other.username, project.id)
2026-04-04 07:24:09 +08:00
def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"message": {
"messageId": f"msg-{int(time.time()*1000)}",
"role": "user",
"parts": [
{
"part_type": "data",
"data": {
"project_id": project_id,
"route_mode": route_mode,
"session_id": session_id,
**( {"idempotency_key": idempotency_key} if idempotency_key else {} )
},
"mediaType": "application/json",
},
{
"part_type": "text",
"text": text,
}
],
}
}
return payload
class TestPartSerialization:
    """``A2APartCreateSchema.model_dump`` keeps the fields set for each part type."""

    @staticmethod
    def _dump_part(**fields: Any) -> Dict[str, Any]:
        """Validate ``fields`` as an ``A2APartCreateSchema`` and return its dump."""
        from app.schemas.a2a import A2APartCreateSchema

        return A2APartCreateSchema(**fields).model_dump()

    def test_part_text_serialization(self):
        dumped = self._dump_part(
            part_type=A2APartType.TEXT,
            text="Hello world",
            mediaType="text/plain",
        )
        assert dumped["part_type"] == "text"
        assert dumped["text"] == "Hello world"
        assert dumped["mediaType"] == "text/plain"

    def test_part_data_serialization(self):
        dumped = self._dump_part(
            part_type=A2APartType.DATA,
            data={"project_id": 123, "route_mode": "local"},
            mediaType="application/json",
        )
        assert dumped["part_type"] == "data"
        assert dumped["data"]["project_id"] == 123

    def test_part_url_serialization(self):
        dumped = self._dump_part(
            part_type=A2APartType.URL,
            url="https://example.com/file.pdf",
            mediaType="application/pdf",
            filename="file.pdf",
        )
        assert dumped["part_type"] == "url"
        assert dumped["url"] == "https://example.com/file.pdf"
        assert dumped["filename"] == "file.pdf"

    def test_part_raw_serialization(self):
        dumped = self._dump_part(
            part_type=A2APartType.RAW,
            raw="\x00\x01\x02\x03",
            mediaType="application/octet-stream",
        )
        assert dumped["part_type"] == "raw"
        assert dumped["raw"] == "\x00\x01\x02\x03"
class TestStateMachine:
    """Task lifecycle transitions driven via the REST API and ``a2a_runtime``."""

    @pytest.fixture(autouse=True)
    def _a2a_env(self) -> Generator[None, None, None]:
        """Give each test an isolated, seeded in-memory database wired into ``app``.

        Sets three attributes:

        - ``self.client``: a ``TestClient`` whose ``get_current_user`` is the
          seeded project owner.
        - ``self.db``: a dedicated session for direct task queries and transitions.
        - ``self.project_id``: the seeded project's id.

        Teardown runs even when the test fails. It closes the assertion
        session, removes the dependency overrides and drops the schema.
        """
        # StaticPool keeps a single connection, so every session (including the
        # ones opened per request) sees the same in-memory SQLite database;
        # check_same_thread=False allows that connection to be used across threads.
        engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        self.db: Optional[Session] = None
        try:
            Base.metadata.create_all(bind=engine)
            seed_session = session_factory()
            try:
                owner_id, owner_username, _, _, self.project_id = _seed(seed_session)
            finally:
                seed_session.close()
            owner = CurrentUser(id=owner_id, username=owner_username, is_admin=False)

            def override_get_db() -> Generator[Session, None, None]:
                request_db = session_factory()
                try:
                    yield request_db
                finally:
                    request_db.close()

            def override_current_user() -> CurrentUser:
                return owner

            app.dependency_overrides[get_db] = override_get_db
            app.dependency_overrides[get_current_user] = override_current_user
            # Dedicated, still-open session for assertions; closed in teardown.
            self.db = session_factory()
            self.client = TestClient(app)
            yield
        finally:
            if self.db is not None:
                self.db.close()
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=engine)
            engine.dispose()

    def _send(self, text: str) -> Any:
        """POST ``text`` to ``message:send`` and return the created task's id."""
        resp = self.client.post("/api/v1/message:send", json=_make_message_payload(self.project_id, text))
        assert resp.status_code == 200
        return resp.json()["task"]["id"]

    def _load_task(self, task_id: Any) -> A2ATask:
        """Fetch the persisted ``A2ATask`` row, failing clearly if it is missing."""
        task = self.db.query(A2ATask).filter(A2ATask.id == task_id).first()
        assert task is not None, f"task {task_id!r} was not persisted"
        return task

    def _assert_direct_transition(self, text: str, target: A2ATaskState) -> None:
        """Create a task and move it straight to ``target`` via ``a2a_runtime``."""
        task = self._load_task(self._send(text))
        task = a2a_runtime.transition_task(self.db, task, to_state=target)
        assert task.state == target

    def test_state_transitions_submit_to_complete(self):
        task = self._load_task(self._send("test state machine"))
        assert task.state == A2ATaskState.SUBMITTED

        task = a2a_runtime.transition_task(self.db, task, to_state=A2ATaskState.WORKING)
        assert task.state == A2ATaskState.WORKING
        task = a2a_runtime.transition_task(self.db, task, to_state=A2ATaskState.COMPLETED)
        assert task.state == A2ATaskState.COMPLETED

    def test_state_cancel(self):
        task_id = self._send("cancel test")
        cancel_resp = self.client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
        assert cancel_resp.status_code == 200
        assert cancel_resp.json()["state"] == "CANCELED"

    def test_state_failed(self):
        task = self._load_task(self._send("fail test"))
        task = a2a_runtime.transition_task(
            self.db,
            task,
            to_state=A2ATaskState.FAILED,
            error_message='{"message": "test error"}',
        )
        assert task.state == A2ATaskState.FAILED
        assert task.error_message is not None

    def test_state_rejected(self):
        self._assert_direct_transition("reject test", A2ATaskState.REJECTED)

    def test_state_input_required(self):
        self._assert_direct_transition("input required test", A2ATaskState.INPUT_REQUIRED)

    def test_state_auth_required(self):
        self._assert_direct_transition("auth required test", A2ATaskState.AUTH_REQUIRED)
class TestA2APathNormalization:
    """The colon-style A2A routes (``message:send``, ``tasks/{id}:cancel``, ...) resolve."""

    @pytest.fixture(autouse=True)
    def _a2a_env(self) -> Generator[None, None, None]:
        """Give each test an isolated, seeded in-memory database wired into ``app``.

        Sets ``self.client`` (acting as the seeded project owner) and
        ``self.project_id``. Teardown always removes the overrides and drops
        the schema.
        """
        # StaticPool + check_same_thread=False: one shared in-memory connection
        # usable from the threads that serve TestClient requests.
        engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        try:
            Base.metadata.create_all(bind=engine)
            seed_session = session_factory()
            try:
                owner_id, owner_username, _, _, self.project_id = _seed(seed_session)
            finally:
                seed_session.close()
            owner = CurrentUser(id=owner_id, username=owner_username, is_admin=False)

            def override_get_db() -> Generator[Session, None, None]:
                request_db = session_factory()
                try:
                    yield request_db
                finally:
                    request_db.close()

            def override_current_user() -> CurrentUser:
                return owner

            app.dependency_overrides[get_db] = override_get_db
            app.dependency_overrides[get_current_user] = override_current_user
            self.client = TestClient(app)
            yield
        finally:
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=engine)
            engine.dispose()

    def _send(self, text: str) -> Any:
        """POST ``text`` to ``message:send`` and return the created task's id."""
        resp = self.client.post("/api/v1/message:send", json=_make_message_payload(self.project_id, text))
        assert resp.status_code == 200
        return resp.json()["task"]["id"]

    def test_message_send_path(self):
        resp = self.client.post("/api/v1/message:send", json=_make_message_payload(self.project_id, "path test"))
        assert resp.status_code == 200
        body = resp.json()
        assert "task" in body
        assert "id" in body["task"]

    def test_message_stream_path(self):
        resp = self.client.post("/api/v1/message:stream", json=_make_message_payload(self.project_id, "stream path test"))
        assert resp.status_code == 200
        assert resp.headers["content-type"].startswith("text/event-stream")

    def test_tasks_cancel_path(self):
        task_id = self._send("cancel path test")
        cancel_resp = self.client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
        assert cancel_resp.status_code == 200
        assert cancel_resp.json()["task_id"] == task_id

    def test_agent_card_public_path(self):
        # The card is expected to be public: remove the identity override so
        # the real ``get_current_user`` dependency stays in effect, and only
        # the database is substituted.
        app.dependency_overrides.pop(get_current_user, None)
        resp = self.client.get("/api/v1/.well-known/agent-card.json")
        assert resp.status_code == 200
        data = resp.json()
        assert "name" in data
        assert "protocol_version" in data
        assert "endpoints" in data
class TestVersionNegotiation:
    """Handling of the ``A2A-Version`` request/response header."""

    @pytest.fixture(autouse=True)
    def _a2a_env(self) -> Generator[None, None, None]:
        """Give each test an isolated, seeded in-memory database wired into ``app``.

        Sets ``self.client`` (acting as the seeded project owner) and
        ``self.project_id``. Teardown always removes the overrides and drops
        the schema.
        """
        # StaticPool + check_same_thread=False: one shared in-memory connection
        # usable from the threads that serve TestClient requests.
        engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        try:
            Base.metadata.create_all(bind=engine)
            seed_session = session_factory()
            try:
                owner_id, owner_username, _, _, self.project_id = _seed(seed_session)
            finally:
                seed_session.close()
            owner = CurrentUser(id=owner_id, username=owner_username, is_admin=False)

            def override_get_db() -> Generator[Session, None, None]:
                request_db = session_factory()
                try:
                    yield request_db
                finally:
                    request_db.close()

            def override_current_user() -> CurrentUser:
                return owner

            app.dependency_overrides[get_db] = override_get_db
            app.dependency_overrides[get_current_user] = override_current_user
            self.client = TestClient(app)
            yield
        finally:
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=engine)
            engine.dispose()

    def test_version_not_supported_error(self):
        resp = self.client.post(
            "/api/v1/message:send",
            json=_make_message_payload(self.project_id, "version test"),
            headers={"A2A-Version": "2.0"},
        )
        assert resp.status_code == 400
        # ``detail`` is itself a JSON-encoded string carrying code + message.
        detail = json.loads(resp.json()["detail"])
        assert detail["code"] == -32009
        assert "not supported" in detail["message"].lower()

    def test_version_response_header(self):
        resp = self.client.post(
            "/api/v1/message:send",
            json=_make_message_payload(self.project_id, "version header test"),
            headers={"A2A-Version": "1.0"},
        )
        assert resp.status_code == 200
        assert resp.headers.get("A2A-Version") == "1.0"
class TestWebhookStreamResponse:
    """Field names of the stream/webhook event models after ``model_dump``."""

    def test_webhook_payload_format(self):
        from datetime import timezone

        from app.schemas.a2a import (
            A2AArtifactSchema,
            A2ATaskState,
            A2ATaskStatusSchema,
            TaskArtifactUpdateEvent,
            TaskStatusUpdateEvent,
        )

        status_event = TaskStatusUpdateEvent(
            taskId="task-123",
            contextId="ctx-456",
            status=A2ATaskStatusSchema(
                state=A2ATaskState.SUBMITTED,
                # datetime.utcnow() is deprecated (Python 3.12+) and naive;
                # use an explicit UTC-aware timestamp instead.
                timestamp=datetime.now(timezone.utc),
            ),
            metadata={},
        )
        status_dump = status_event.model_dump()
        assert "taskId" in status_dump
        assert status_dump["taskId"] == "task-123"
        assert status_dump["status"]["state"] == "SUBMITTED"

        artifact_event = TaskArtifactUpdateEvent(
            taskId="task-123",
            contextId="ctx-456",
            artifact=A2AArtifactSchema(
                artifactId="art-789",
                parts=[],
            ),
            append=False,
            lastChunk=True,
        )
        artifact_dump = artifact_event.model_dump()
        assert "taskId" in artifact_dump
        assert artifact_dump["artifact"]["artifactId"] == "art-789"

    def test_stream_response_task_field(self):
        from app.schemas.a2a import A2ATaskState, StreamResponse, StreamResponseTask

        resp = StreamResponse(
            task=StreamResponseTask(
                id="task-123",
                contextId="ctx-456",
                state=A2ATaskState.WORKING,
                artifacts=[],
            )
        )
        data = resp.model_dump()
        assert "task" in data
        assert data["task"]["id"] == "task-123"
class TestSSEFIFO:
    """Events emitted on ``message:stream``."""

    @pytest.fixture(autouse=True)
    def _a2a_env(self) -> Generator[None, None, None]:
        """Give each test an isolated, seeded in-memory database wired into ``app``.

        Sets ``self.client`` (acting as the seeded project owner) and
        ``self.project_id``. Teardown always removes the overrides and drops
        the schema.
        """
        # StaticPool + check_same_thread=False: one shared in-memory connection
        # usable from the threads that serve TestClient requests.
        engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        try:
            Base.metadata.create_all(bind=engine)
            seed_session = session_factory()
            try:
                owner_id, owner_username, _, _, self.project_id = _seed(seed_session)
            finally:
                seed_session.close()
            owner = CurrentUser(id=owner_id, username=owner_username, is_admin=False)

            def override_get_db() -> Generator[Session, None, None]:
                request_db = session_factory()
                try:
                    yield request_db
                finally:
                    request_db.close()

            def override_current_user() -> CurrentUser:
                return owner

            app.dependency_overrides[get_db] = override_get_db
            app.dependency_overrides[get_current_user] = override_current_user
            self.client = TestClient(app)
            yield
        finally:
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=engine)
            engine.dispose()

    @staticmethod
    def _sse_data_payloads(lines: Any) -> list[Any]:
        """JSON-decode the value of every ``data:`` line in an SSE stream.

        The SSE format makes the single space after ``data:`` optional, so
        both ``data:{...}`` and ``data: {...}`` are accepted.
        """
        payloads: list[Any] = []
        for line in lines:
            if not line.startswith("data:"):
                continue
            value = line[len("data:"):]
            if value.startswith(" "):
                value = value[1:]
            payloads.append(json.loads(value))
        return payloads

    def test_sse_event_order(self):
        payload = _make_message_payload(self.project_id, "fifo test")
        with self.client.stream("POST", "/api/v1/message:stream", json=payload) as resp:
            assert resp.status_code == 200
            events = self._sse_data_payloads(resp.iter_lines())
        event_types = [event["type"] for event in events if isinstance(event, dict) and "type" in event]
        # NOTE(review): despite the test name, only the presence of a status
        # update is asserted here, not the relative order of events.
        assert "TaskStatusUpdateEvent" in event_types, event_types
class TestAuthSchemes:
    """Shared-secret signing and HTTP security-scheme schema."""

    def test_shared_secret_auth(self):
        """A signature produced by ``generate_signature`` verifies with the same inputs."""
        shared_secret = "test-secret-key-12345"
        issued_at = int(time.time())
        payload_bytes = b'{"test":"data"}'

        signature, _unused = SharedSecretAuth.generate_signature(shared_secret, payload_bytes, issued_at)

        assert signature.startswith("sha256=")
        verified = SharedSecretAuth.verify_signature(shared_secret, payload_bytes, signature, issued_at)
        assert verified is True

    def test_auth_scheme_none(self):
        # NOTE(review): the name says "none", but the body builds a bearer scheme.
        from app.schemas.a2a import SecuritySchemeHttpAuth

        http_scheme = SecuritySchemeHttpAuth(description="Bearer auth", scheme="bearer")
        assert http_scheme.scheme == "bearer"
class TestExceptionPaths:
    """Error handling: remote-agent health and circuit breaking, idempotent
    task creation, and per-user task isolation."""

    @pytest.fixture(autouse=True)
    def _a2a_env(self) -> Generator[None, None, None]:
        """Give each test an isolated, seeded in-memory database wired into ``app``.

        Sets the following attributes:

        - ``self.client``: a ``TestClient`` for API calls.
        - ``self.db``: a session for direct model access.
        - ``self.project_id`` and ``self.owner_id``.
        - ``self.owner_user`` and ``self.other_user``.
        - ``self.current_user``: whatever the ``get_current_user`` override
          returns; tests may reassign it.

        Teardown always runs, even when the test fails:

        - clears ``a2a_runtime._circuit_state`` entries for agents created via
          ``_add_agent``;
        - closes the session, then removes the overrides and drops the schema.
        """
        # StaticPool + check_same_thread=False: one shared in-memory connection
        # usable from the threads that serve TestClient requests.
        engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        self.db: Optional[Session] = None
        self._agent_ids: list[Any] = []
        try:
            Base.metadata.create_all(bind=engine)
            seed_session = session_factory()
            try:
                owner_id, owner_username, other_id, other_username, self.project_id = _seed(seed_session)
            finally:
                seed_session.close()
            self.owner_id = owner_id
            self.owner_user = CurrentUser(id=owner_id, username=owner_username, is_admin=False)
            self.other_user = CurrentUser(id=other_id, username=other_username, is_admin=False)
            self.current_user = self.owner_user

            def override_get_db() -> Generator[Session, None, None]:
                request_db = session_factory()
                try:
                    yield request_db
                finally:
                    request_db.close()

            def override_current_user() -> CurrentUser:
                # Read at request time so tests can switch identity mid-test.
                return self.current_user

            app.dependency_overrides[get_db] = override_get_db
            app.dependency_overrides[get_current_user] = override_current_user
            self.db = session_factory()
            self.client = TestClient(app)
            yield
        finally:
            # ``_circuit_state`` lives on the shared ``a2a_runtime`` object, so
            # drop this test's entries instead of leaking them to later tests.
            for agent_id in self._agent_ids:
                a2a_runtime._circuit_state.pop(agent_id, None)
            if self.db is not None:
                self.db.close()
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=engine)
            engine.dispose()

    def _add_agent(self, **fields: Any) -> A2ARemoteAgent:
        """Persist a remote agent in the seeded project, starting from a clean circuit."""
        agent = A2ARemoteAgent(project_id=self.project_id, created_by=self.owner_id, **fields)
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        self._agent_ids.append(agent.id)
        a2a_runtime._circuit_state.pop(agent.id, None)
        return agent

    def test_auth_failure_marks_agent_unhealthy(self, monkeypatch):
        agent = self._add_agent(
            name="auth-fail-agent",
            base_url="https://remote.example.com",
            auth_scheme="bearer",
            auth_token="bad-token",
        )

        class _FailResp:
            status_code = 401

            @staticmethod
            def json():
                return {"detail": "unauthorized"}

        class _Client401:
            """Stand-in for ``httpx.AsyncClient`` whose GET always answers 401."""

            def __init__(self, *args, **kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            async def get(self, url, headers=None):
                return _FailResp()

        monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401)
        with pytest.raises(RuntimeError):
            asyncio.run(a2a_runtime.fetch_agent_card(self.db, agent, timeout_s=0.01))

        self.db.refresh(agent)
        assert agent.healthy is False
        assert agent.failure_count == 1

    def test_remote_unavailable_opens_circuit(self, monkeypatch):
        agent = self._add_agent(
            name="offline-agent",
            base_url="https://offline.example.com",
            auth_scheme="none",
        )

        class _ClientDown:
            """Stand-in for ``httpx.AsyncClient`` whose GET always fails to connect."""

            def __init__(self, *args, **kwargs):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            async def get(self, url, headers=None):
                raise httpx.ConnectError("network down")

        monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown)
        for _ in range(3):
            # Broad on purpose: the runtime may wrap the connect error.
            with pytest.raises(Exception):
                asyncio.run(a2a_runtime.fetch_agent_card(self.db, agent, timeout_s=0.01))

        self.db.refresh(agent)
        assert agent.healthy is False
        assert agent.failure_count == 3
        assert agent.circuit_open_until is not None

    def test_idempotency_key_deduplicates_task(self):
        idempotency_key = f"idem-key-{int(time.time())}"

        first = _make_message_payload(self.project_id, "dedupe test", idempotency_key=idempotency_key)
        resp1 = self.client.post("/api/v1/message:send", json=first)
        assert resp1.status_code == 200

        second = _make_message_payload(self.project_id, "dedupe test", idempotency_key=idempotency_key)
        # Guarantee a different messageId so only the idempotency key can
        # cause the deduplication.
        second["message"]["messageId"] = f"{first['message']['messageId']}-retry"
        resp2 = self.client.post("/api/v1/message:send", json=second)

        assert resp2.status_code == 200
        assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"]

    def test_tenant_isolation(self):
        resp = self.client.post("/api/v1/message:send", json=_make_message_payload(self.project_id, "isolation test"))
        assert resp.status_code == 200
        task_id = resp.json()["task"]["id"]

        self.current_user = self.other_user
        get_resp = self.client.get(f"/api/v1/tasks/{task_id}")
        assert get_resp.status_code == 404
class TestMetricsAdminOnly:
    """Access control on the A2A metrics endpoint."""

    def test_metrics_admin_only(self):
        """A non-admin gets 403; the same user flagged as admin gets the counters."""
        db_engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
        make_session = sessionmaker(autocommit=False, autoflush=False, bind=db_engine)
        Base.metadata.create_all(bind=db_engine)

        seed_session = make_session()
        user_id, user_name, _other_id, _other_name, _project_id = _seed(seed_session)
        seed_session.close()

        # Mutable holder so the override below observes identity switches.
        acting = {"user": CurrentUser(id=user_id, username=user_name, is_admin=False)}

        def provide_db() -> Generator[Session, None, None]:
            request_session = make_session()
            try:
                yield request_session
            finally:
                request_session.close()

        def provide_user() -> CurrentUser:
            return acting["user"]

        app.dependency_overrides.update({get_db: provide_db, get_current_user: provide_user})
        try:
            api = TestClient(app)
            forbidden = api.get("/api/v1/a2a/metrics")
            assert forbidden.status_code == 403

            acting["user"] = CurrentUser(id=user_id, username=user_name, is_admin=True)
            allowed = api.get("/api/v1/a2a/metrics")
            assert allowed.status_code == 200
            assert "counters" in allowed.json()
        finally:
            app.dependency_overrides.clear()
            Base.metadata.drop_all(bind=db_engine)
            db_engine.dispose()