feat: support a2a mode
This commit is contained in:
@@ -0,0 +1,891 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.nanobot import nanobot_service
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import SessionLocal, get_db
|
||||
from app.models.a2a import (
|
||||
A2AAuditLog,
|
||||
A2AProjectConfig,
|
||||
A2ARemoteAgent,
|
||||
A2ATask,
|
||||
A2ATaskEvent,
|
||||
A2ATaskWebhook,
|
||||
A2AWebhookDelivery,
|
||||
)
|
||||
from app.models.project import Project
|
||||
from app.services.a2a_service import _json_dumps, _json_loads, a2a_runtime
|
||||
from app.trace import build_error_attributes, trace_service
|
||||
|
||||
router = APIRouter(prefix="/a2a", tags=["a2a"])
|
||||
|
||||
SUPPORTED_PROTOCOL_VERSION = "1.0"
|
||||
SUPPORTED_CAPABILITIES = ["streaming", "push", "task_management", "subscribe"]
|
||||
SUPPORTED_AUTH = ["bearer", "shared_secret", "none"]
|
||||
|
||||
|
||||
def _mask_error(message: str) -> str:
|
||||
if not message:
|
||||
return "internal_error"
|
||||
return "request_failed"
|
||||
|
||||
|
||||
class AgentCardResponse(BaseModel):
|
||||
name: str
|
||||
protocol_version: str
|
||||
capabilities: List[str]
|
||||
endpoints: Dict[str, str]
|
||||
auth: List[str]
|
||||
|
||||
|
||||
class RemoteAgentCreate(BaseModel):
|
||||
project_id: int
|
||||
name: str = Field(min_length=1, max_length=120)
|
||||
base_url: str = Field(min_length=1, max_length=500)
|
||||
auth_scheme: Literal["none", "bearer"] = "none"
|
||||
auth_token: Optional[str] = None
|
||||
|
||||
|
||||
class RemoteAgentUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
auth_scheme: Optional[Literal["none", "bearer"]] = None
|
||||
auth_token: Optional[str] = None
|
||||
|
||||
|
||||
class RemoteAgentView(BaseModel):
|
||||
id: int
|
||||
project_id: int
|
||||
name: str
|
||||
base_url: str
|
||||
auth_scheme: str
|
||||
protocol_version: Optional[str] = None
|
||||
capabilities: List[str] = []
|
||||
healthy: bool
|
||||
failure_count: int
|
||||
circuit_open_until: Optional[datetime] = None
|
||||
card_fetched_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
project_id: int
|
||||
message: str = Field(min_length=1)
|
||||
session_id: str = "api:a2a"
|
||||
remote_agent_id: Optional[int] = None
|
||||
route_mode: Literal["auto", "local", "a2a", "a2a_first", "local_first", "mcp_first"] = "auto"
|
||||
fallback_chain: Optional[List[Literal["a2a", "local", "mcp"]]] = None
|
||||
idempotency_key: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class TaskView(BaseModel):
|
||||
id: str
|
||||
project_id: int
|
||||
source: str
|
||||
state: str
|
||||
remote_agent_id: Optional[int] = None
|
||||
input_text: str
|
||||
output_text: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
compatibility_mode: bool
|
||||
metadata: Dict[str, Any]
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
finished_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class CancelTaskResponse(BaseModel):
|
||||
task_id: str
|
||||
state: str
|
||||
|
||||
|
||||
class TaskWebhookCreate(BaseModel):
|
||||
target_url: str = Field(min_length=1, max_length=500)
|
||||
secret: Optional[str] = None
|
||||
auth_header: Optional[str] = None
|
||||
|
||||
|
||||
class TaskWebhookView(BaseModel):
|
||||
id: int
|
||||
task_id: str
|
||||
target_url: str
|
||||
enabled: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class RolloutConfigView(BaseModel):
|
||||
project_id: int
|
||||
canary_enabled: bool
|
||||
canary_percent: int
|
||||
rollback_to_local: bool
|
||||
compatibility_mode: bool
|
||||
dual_event_write: bool
|
||||
route_mode_default: str
|
||||
fallback_chain: List[str]
|
||||
alert_thresholds: Dict[str, Any]
|
||||
|
||||
|
||||
class RolloutConfigUpdate(BaseModel):
|
||||
canary_enabled: Optional[bool] = None
|
||||
canary_percent: Optional[int] = Field(default=None, ge=0, le=100)
|
||||
rollback_to_local: Optional[bool] = None
|
||||
compatibility_mode: Optional[bool] = None
|
||||
dual_event_write: Optional[bool] = None
|
||||
route_mode_default: Optional[str] = None
|
||||
fallback_chain: Optional[List[str]] = None
|
||||
alert_thresholds: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def _ensure_project_access(db: Session, project_id: int, user: CurrentUser) -> Project:
|
||||
project = db.query(Project).filter(Project.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
if not user.is_admin and project.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
return project
|
||||
|
||||
|
||||
def _ensure_task_access(db: Session, task_id: str, user: CurrentUser) -> A2ATask:
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
if not user.is_admin and task.tenant_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
return task
|
||||
|
||||
|
||||
def _ensure_agent_access(db: Session, agent_id: int, user: CurrentUser) -> A2ARemoteAgent:
|
||||
agent = db.query(A2ARemoteAgent).filter(A2ARemoteAgent.id == agent_id).first()
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Remote agent not found")
|
||||
project = _ensure_project_access(db, agent.project_id, user)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Remote agent not found")
|
||||
return agent
|
||||
|
||||
|
||||
def _task_to_view(task: A2ATask) -> TaskView:
|
||||
return TaskView(
|
||||
id=task.id,
|
||||
project_id=task.project_id,
|
||||
source=task.source,
|
||||
state=task.state,
|
||||
remote_agent_id=task.remote_agent_id,
|
||||
input_text=task.input_text,
|
||||
output_text=task.output_text,
|
||||
error_message=task.error_message,
|
||||
compatibility_mode=task.compatibility_mode,
|
||||
metadata=_json_loads(task.metadata_json, {}),
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
finished_at=task.finished_at,
|
||||
)
|
||||
|
||||
|
||||
def _agent_to_view(agent: A2ARemoteAgent) -> RemoteAgentView:
|
||||
return RemoteAgentView(
|
||||
id=agent.id,
|
||||
project_id=agent.project_id,
|
||||
name=agent.name,
|
||||
base_url=agent.base_url,
|
||||
auth_scheme=agent.auth_scheme,
|
||||
protocol_version=agent.protocol_version,
|
||||
capabilities=_json_loads(agent.capabilities_json, []),
|
||||
healthy=bool(agent.healthy),
|
||||
failure_count=int(agent.failure_count or 0),
|
||||
circuit_open_until=agent.circuit_open_until,
|
||||
card_fetched_at=agent.card_fetched_at,
|
||||
)
|
||||
|
||||
|
||||
def _build_status_event(task: A2ATask, *, compatibility_mode: bool, dual_event_write: bool) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"type": "TaskStatusUpdateEvent",
|
||||
"task_id": task.id,
|
||||
"task_status": task.state,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"source": task.source,
|
||||
}
|
||||
if compatibility_mode or dual_event_write:
|
||||
payload.update(
|
||||
{
|
||||
"event": "task_status",
|
||||
"status": task.state,
|
||||
"taskId": task.id,
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_artifact_event(task_id: str, content: str, *, compatibility_mode: bool, dual_event_write: bool) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"type": "TaskArtifactUpdateEvent",
|
||||
"task_id": task_id,
|
||||
"artifact": {"content": content},
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
}
|
||||
if compatibility_mode or dual_event_write:
|
||||
payload.update(
|
||||
{
|
||||
"event": "task_output",
|
||||
"taskId": task_id,
|
||||
"output": content,
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
|
||||
async def _delegate_to_remote(task: A2ATask, agent: A2ARemoteAgent, message: str) -> Tuple[str, Dict[str, Any]]:
|
||||
headers: Dict[str, str] = {}
|
||||
if agent.auth_scheme == "bearer" and agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {agent.auth_token}"
|
||||
payload = {
|
||||
"project_id": task.project_id,
|
||||
"message": message,
|
||||
"session_id": f"a2a-delegate:{task.id}",
|
||||
"idempotency_key": task.idempotency_key,
|
||||
"route_mode": "local_first",
|
||||
"metadata": {"delegated_by": "dataclaw", "task_id": task.id},
|
||||
}
|
||||
url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/messages/send"
|
||||
async with httpx.AsyncClient(timeout=25.0, verify=True) as client:
|
||||
resp = await client.post(url, json=payload, headers=headers)
|
||||
if resp.status_code >= 400:
|
||||
raise RuntimeError(f"remote_http_{resp.status_code}")
|
||||
body = resp.json()
|
||||
content = ""
|
||||
if isinstance(body, dict):
|
||||
task_obj = body.get("task") or {}
|
||||
content = str(task_obj.get("output_text") or body.get("message") or "")
|
||||
return content, body
|
||||
|
||||
|
||||
async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
if not task:
|
||||
return
|
||||
config = a2a_runtime.get_project_config(db, task.project_id, tenant_id)
|
||||
if task.state in {"CANCELED", "REJECTED"}:
|
||||
return
|
||||
with trace_service.start_span("a2a.task.execute", attributes={"task_id": task.id, "project_id": task.project_id, "source": task.source}) as span:
|
||||
start_ts = datetime.utcnow().timestamp()
|
||||
try:
|
||||
task = a2a_runtime.transition_task(db, task, to_state="WORKING")
|
||||
status_event = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
|
||||
status_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", status_event)
|
||||
await a2a_runtime.publish(task.id, status_event)
|
||||
await a2a_runtime.notify_webhooks(db, task, status_row)
|
||||
|
||||
if task.source == "a2a" and task.remote_agent_id:
|
||||
agent = db.query(A2ARemoteAgent).filter(A2ARemoteAgent.id == task.remote_agent_id).first()
|
||||
if not agent:
|
||||
raise RuntimeError("remote_agent_missing")
|
||||
response_text, metadata = await _delegate_to_remote(task, agent, request.message)
|
||||
else:
|
||||
response_text = await nanobot_service.process_message(
|
||||
request.message,
|
||||
session_id=f"a2a-task:{task.id}",
|
||||
project_id=task.project_id,
|
||||
)
|
||||
metadata = {"executor": "local"}
|
||||
artifact_event_payload = _build_artifact_event(task.id, response_text or "", compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
|
||||
artifact_event = a2a_runtime.append_event(db, task, "TaskArtifactUpdateEvent", artifact_event_payload)
|
||||
await a2a_runtime.publish(task.id, artifact_event_payload)
|
||||
await a2a_runtime.notify_webhooks(db, task, artifact_event)
|
||||
task = a2a_runtime.transition_task(
|
||||
db,
|
||||
task,
|
||||
to_state="COMPLETED",
|
||||
output_text=response_text or "",
|
||||
metadata=metadata,
|
||||
)
|
||||
done_event = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
|
||||
done_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", done_event)
|
||||
await a2a_runtime.publish(task.id, done_event)
|
||||
await a2a_runtime.notify_webhooks(db, task, done_row)
|
||||
elapsed = (datetime.utcnow().timestamp() - start_ts) * 1000
|
||||
await a2a_runtime.metrics.observe_latency("a2a.execute", elapsed)
|
||||
except Exception as exc:
|
||||
span.set_attributes(build_error_attributes(exc, stage="a2a_task_execute"))
|
||||
await a2a_runtime.metrics.incr("a2a.requests.error")
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task.id).first()
|
||||
if task and task.state not in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
|
||||
task = a2a_runtime.transition_task(db, task, to_state="FAILED", error_message=_json_dumps({"message": _mask_error(str(exc))}))
|
||||
fail_event = _build_status_event(task, compatibility_mode=task.compatibility_mode, dual_event_write=True)
|
||||
fail_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", fail_event)
|
||||
await a2a_runtime.publish(task.id, fail_event)
|
||||
await a2a_runtime.notify_webhooks(db, task, fail_row)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/agent-card", response_model=AgentCardResponse)
|
||||
def get_agent_card() -> AgentCardResponse:
|
||||
return AgentCardResponse(
|
||||
name="DataClaw A2A Gateway",
|
||||
protocol_version=SUPPORTED_PROTOCOL_VERSION,
|
||||
capabilities=SUPPORTED_CAPABILITIES,
|
||||
endpoints={
|
||||
"send_message": "/api/v1/a2a/messages/send",
|
||||
"send_streaming_message": "/api/v1/a2a/messages/stream",
|
||||
"get_task": "/api/v1/a2a/tasks/{task_id}",
|
||||
"list_tasks": "/api/v1/a2a/tasks",
|
||||
"cancel_task": "/api/v1/a2a/tasks/{task_id}/cancel",
|
||||
"subscribe_task": "/api/v1/a2a/tasks/{task_id}/subscribe",
|
||||
},
|
||||
auth=SUPPORTED_AUTH,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/remote-agents", response_model=List[RemoteAgentView])
|
||||
def list_remote_agents(
|
||||
project_id: Optional[int] = Query(default=None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> List[RemoteAgentView]:
|
||||
query = db.query(A2ARemoteAgent)
|
||||
if project_id is not None:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
query = query.filter(A2ARemoteAgent.project_id == project_id)
|
||||
if not current_user.is_admin:
|
||||
owned_ids = [p.id for p in db.query(Project).filter(Project.owner_id == current_user.id).all()]
|
||||
if not owned_ids:
|
||||
return []
|
||||
query = query.filter(A2ARemoteAgent.project_id.in_(owned_ids))
|
||||
return [_agent_to_view(item) for item in query.order_by(A2ARemoteAgent.id.desc()).all()]
|
||||
|
||||
|
||||
@router.post("/remote-agents", response_model=RemoteAgentView, status_code=status.HTTP_201_CREATED)
|
||||
async def create_remote_agent(
|
||||
payload: RemoteAgentCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> RemoteAgentView:
|
||||
_ensure_project_access(db, payload.project_id, current_user)
|
||||
item = A2ARemoteAgent(
|
||||
project_id=payload.project_id,
|
||||
name=payload.name.strip(),
|
||||
base_url=payload.base_url.strip().rstrip("/"),
|
||||
auth_scheme=payload.auth_scheme,
|
||||
auth_token=payload.auth_token,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
try:
|
||||
await a2a_runtime.fetch_agent_card(db, item)
|
||||
except Exception:
|
||||
pass
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="create_remote_agent",
|
||||
target_type="remote_agent",
|
||||
target_id=str(item.id),
|
||||
result="ok",
|
||||
project_id=item.project_id,
|
||||
)
|
||||
return _agent_to_view(item)
|
||||
|
||||
|
||||
@router.put("/remote-agents/{agent_id}", response_model=RemoteAgentView)
|
||||
async def update_remote_agent(
|
||||
agent_id: int,
|
||||
payload: RemoteAgentUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> RemoteAgentView:
|
||||
item = _ensure_agent_access(db, agent_id, current_user)
|
||||
update_data = payload.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(item, key, value)
|
||||
if item.base_url:
|
||||
item.base_url = item.base_url.rstrip("/")
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
try:
|
||||
await a2a_runtime.fetch_agent_card(db, item)
|
||||
except Exception:
|
||||
pass
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="update_remote_agent",
|
||||
target_type="remote_agent",
|
||||
target_id=str(item.id),
|
||||
result="ok",
|
||||
project_id=item.project_id,
|
||||
)
|
||||
return _agent_to_view(item)
|
||||
|
||||
|
||||
@router.delete("/remote-agents/{agent_id}")
|
||||
def delete_remote_agent(
|
||||
agent_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, str]:
|
||||
item = _ensure_agent_access(db, agent_id, current_user)
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="delete_remote_agent",
|
||||
target_type="remote_agent",
|
||||
target_id=str(agent_id),
|
||||
result="ok",
|
||||
project_id=item.project_id,
|
||||
)
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/remote-agents/{agent_id}/refresh-card", response_model=RemoteAgentView)
|
||||
async def refresh_remote_agent_card(
|
||||
agent_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> RemoteAgentView:
|
||||
item = _ensure_agent_access(db, agent_id, current_user)
|
||||
try:
|
||||
card = await a2a_runtime.fetch_agent_card(db, item)
|
||||
except Exception as exc:
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="refresh_remote_agent_card",
|
||||
target_type="remote_agent",
|
||||
target_id=str(agent_id),
|
||||
result="failed",
|
||||
project_id=item.project_id,
|
||||
detail={"error": str(exc)},
|
||||
)
|
||||
raise HTTPException(status_code=502, detail="Remote card fetch failed")
|
||||
version = str(card.get("protocol_version") or "")
|
||||
if version and version.split(".")[0] != SUPPORTED_PROTOCOL_VERSION.split(".")[0]:
|
||||
raise HTTPException(status_code=400, detail="Protocol version incompatible")
|
||||
return _agent_to_view(item)
|
||||
|
||||
|
||||
@router.post("/remote-agents/{agent_id}/health-check")
|
||||
async def health_check_remote_agent(
|
||||
agent_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
item = _ensure_agent_access(db, agent_id, current_user)
|
||||
try:
|
||||
await a2a_runtime.fetch_agent_card(db, item, timeout_s=5.0)
|
||||
return {"healthy": True, "failure_count": item.failure_count}
|
||||
except Exception:
|
||||
return {"healthy": False, "failure_count": item.failure_count}
|
||||
|
||||
|
||||
@router.post("/messages/send")
|
||||
async def send_message(
|
||||
request: SendMessageRequest,
|
||||
x_a2a_token: Optional[str] = Header(default=None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
_ensure_project_access(db, request.project_id, current_user)
|
||||
config = a2a_runtime.get_project_config(db, request.project_id, current_user.id)
|
||||
route = a2a_runtime.resolve_route(
|
||||
project_config=config,
|
||||
session_id=request.session_id,
|
||||
requested_mode=request.route_mode,
|
||||
requested_fallback=request.fallback_chain,
|
||||
)
|
||||
selected_source = "local"
|
||||
remote_agent_id = None
|
||||
if route.selected == "a2a" and request.remote_agent_id:
|
||||
agent = _ensure_agent_access(db, request.remote_agent_id, current_user)
|
||||
if not agent.healthy and config.rollback_to_local:
|
||||
selected_source = "local"
|
||||
else:
|
||||
selected_source = "a2a"
|
||||
remote_agent_id = agent.id
|
||||
task = a2a_runtime.create_task(
|
||||
db,
|
||||
project_id=request.project_id,
|
||||
tenant_id=current_user.id,
|
||||
source=selected_source,
|
||||
input_text=request.message,
|
||||
idempotency_key=request.idempotency_key,
|
||||
remote_agent_id=remote_agent_id,
|
||||
compatibility_mode=config.compatibility_mode,
|
||||
metadata={"route": route.model_dump() if hasattr(route, "model_dump") else route.__dict__, "token_present": bool(x_a2a_token), "request_metadata": request.metadata or {}},
|
||||
)
|
||||
event_payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
|
||||
event_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", event_payload)
|
||||
await a2a_runtime.publish(task.id, event_payload)
|
||||
await a2a_runtime.notify_webhooks(db, task, event_row)
|
||||
asyncio.create_task(_run_task(task.id, request, current_user.id))
|
||||
await a2a_runtime.metrics.incr("a2a.requests.total")
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="send_message",
|
||||
target_type="task",
|
||||
target_id=task.id,
|
||||
result="accepted",
|
||||
project_id=task.project_id,
|
||||
task_id=task.id,
|
||||
)
|
||||
return {"task": _task_to_view(task).model_dump(), "routing": route.__dict__}
|
||||
|
||||
|
||||
@router.post("/messages/stream")
|
||||
async def send_streaming_message(
|
||||
request: SendMessageRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
response = await send_message(request=request, x_a2a_token=None, db=db, current_user=current_user)
|
||||
task_id = response["task"]["id"]
|
||||
|
||||
async def event_generator():
|
||||
history = (
|
||||
db.query(A2ATaskEvent)
|
||||
.filter(A2ATaskEvent.task_id == task_id)
|
||||
.order_by(A2ATaskEvent.id.asc())
|
||||
.all()
|
||||
)
|
||||
for item in history:
|
||||
payload = _json_loads(item.payload_json, {})
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
async for payload in a2a_runtime.subscribe(task_id):
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
|
||||
break
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=TaskView)
|
||||
def get_task(
|
||||
task_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> TaskView:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
return _task_to_view(task)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=List[TaskView])
|
||||
def list_tasks(
|
||||
project_id: Optional[int] = Query(default=None),
|
||||
state: Optional[str] = Query(default=None),
|
||||
skip: int = Query(default=0, ge=0),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> List[TaskView]:
|
||||
query = db.query(A2ATask)
|
||||
if not current_user.is_admin:
|
||||
query = query.filter(A2ATask.tenant_id == current_user.id)
|
||||
if project_id is not None:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
query = query.filter(A2ATask.project_id == project_id)
|
||||
if state:
|
||||
query = query.filter(A2ATask.state == state)
|
||||
tasks = query.order_by(A2ATask.created_at.desc()).offset(skip).limit(limit).all()
|
||||
return [_task_to_view(item) for item in tasks]
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/cancel", response_model=CancelTaskResponse)
|
||||
async def cancel_task(
|
||||
task_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> CancelTaskResponse:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
|
||||
return CancelTaskResponse(task_id=task.id, state=task.state)
|
||||
try:
|
||||
task = a2a_runtime.transition_task(db, task, to_state="CANCELED")
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=409, detail="Task state transition conflict")
|
||||
config = a2a_runtime.get_project_config(db, task.project_id, current_user.id)
|
||||
payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
|
||||
row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", payload)
|
||||
await a2a_runtime.publish(task.id, payload)
|
||||
await a2a_runtime.notify_webhooks(db, task, row)
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="cancel_task",
|
||||
target_type="task",
|
||||
target_id=task.id,
|
||||
result="ok",
|
||||
project_id=task.project_id,
|
||||
task_id=task.id,
|
||||
)
|
||||
return CancelTaskResponse(task_id=task.id, state=task.state)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}/subscribe")
|
||||
async def subscribe_task(
|
||||
task_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> StreamingResponse:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
initial_events = (
|
||||
db.query(A2ATaskEvent)
|
||||
.filter(A2ATaskEvent.task_id == task.id)
|
||||
.order_by(A2ATaskEvent.id.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
async def event_generator():
|
||||
for event in initial_events:
|
||||
payload = _json_loads(event.payload_json, {})
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
return
|
||||
async for payload in a2a_runtime.subscribe(task.id):
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
|
||||
break
|
||||
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}/webhooks", response_model=List[TaskWebhookView])
|
||||
def list_task_webhooks(
|
||||
task_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> List[TaskWebhookView]:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
items = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id).order_by(A2ATaskWebhook.id.desc()).all()
|
||||
return [
|
||||
TaskWebhookView(
|
||||
id=item.id,
|
||||
task_id=item.task_id,
|
||||
target_url=item.target_url,
|
||||
enabled=item.enabled,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
)
|
||||
for item in items
|
||||
]
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/webhooks", response_model=TaskWebhookView, status_code=status.HTTP_201_CREATED)
|
||||
def create_task_webhook(
|
||||
task_id: str,
|
||||
payload: TaskWebhookCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> TaskWebhookView:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
item = A2ATaskWebhook(
|
||||
task_id=task.id,
|
||||
target_url=payload.target_url.strip(),
|
||||
secret=payload.secret,
|
||||
auth_header=payload.auth_header,
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="create_task_webhook",
|
||||
target_type="task_webhook",
|
||||
target_id=str(item.id),
|
||||
result="ok",
|
||||
project_id=task.project_id,
|
||||
task_id=task.id,
|
||||
)
|
||||
return TaskWebhookView(
|
||||
id=item.id,
|
||||
task_id=item.task_id,
|
||||
target_url=item.target_url,
|
||||
enabled=item.enabled,
|
||||
created_at=item.created_at,
|
||||
updated_at=item.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}/webhooks/{webhook_id}")
|
||||
def delete_task_webhook(
|
||||
task_id: str,
|
||||
webhook_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, str]:
|
||||
task = _ensure_task_access(db, task_id, current_user)
|
||||
item = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.id == webhook_id, A2ATaskWebhook.task_id == task.id).first()
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Webhook not found")
|
||||
db.delete(item)
|
||||
db.commit()
|
||||
return {"status": "success"}
|
||||
|
||||
|
||||
@router.post("/webhook-deliveries/{delivery_id}/replay")
|
||||
async def replay_delivery(
|
||||
delivery_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
delivery = db.query(A2AWebhookDelivery).filter(A2AWebhookDelivery.id == delivery_id).first()
|
||||
if not delivery:
|
||||
raise HTTPException(status_code=404, detail="Delivery not found")
|
||||
task = _ensure_task_access(db, delivery.task_id, current_user)
|
||||
webhook = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.id == delivery.webhook_id).first()
|
||||
event = db.query(A2ATaskEvent).filter(A2ATaskEvent.id == delivery.event_id).first()
|
||||
if not webhook or not event:
|
||||
raise HTTPException(status_code=404, detail="Delivery dependencies not found")
|
||||
await a2a_runtime._deliver_once(db, webhook, event, delivery)
|
||||
return {"status": delivery.status, "attempt": delivery.attempt, "dead_letter": delivery.dead_letter, "task_id": task.id}
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def get_metrics(current_user: CurrentUser = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin permission required")
|
||||
return await a2a_runtime.metrics.snapshot()
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/rollout", response_model=RolloutConfigView)
|
||||
def get_rollout_config(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> RolloutConfigView:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
item = a2a_runtime.get_project_config(db, project_id, current_user.id)
|
||||
return RolloutConfigView(
|
||||
project_id=item.project_id,
|
||||
canary_enabled=item.canary_enabled,
|
||||
canary_percent=item.canary_percent,
|
||||
rollback_to_local=item.rollback_to_local,
|
||||
compatibility_mode=item.compatibility_mode,
|
||||
dual_event_write=item.dual_event_write,
|
||||
route_mode_default=item.route_mode_default,
|
||||
fallback_chain=_json_loads(item.fallback_chain_json, ["local"]),
|
||||
alert_thresholds=_json_loads(item.alert_thresholds_json, {}),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/projects/{project_id}/rollout", response_model=RolloutConfigView)
|
||||
def update_rollout_config(
|
||||
project_id: int,
|
||||
payload: RolloutConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> RolloutConfigView:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
item = a2a_runtime.get_project_config(db, project_id, current_user.id)
|
||||
data = payload.model_dump(exclude_unset=True)
|
||||
for key, value in data.items():
|
||||
if key == "fallback_chain":
|
||||
item.fallback_chain_json = _json_dumps(value)
|
||||
continue
|
||||
if key == "alert_thresholds":
|
||||
item.alert_thresholds_json = _json_dumps(value)
|
||||
continue
|
||||
setattr(item, key, value)
|
||||
item.updated_by = current_user.id
|
||||
db.add(item)
|
||||
db.commit()
|
||||
db.refresh(item)
|
||||
a2a_runtime.record_audit(
|
||||
db,
|
||||
actor_user_id=current_user.id,
|
||||
action="update_rollout_config",
|
||||
target_type="project_rollout",
|
||||
target_id=str(project_id),
|
||||
result="ok",
|
||||
project_id=project_id,
|
||||
)
|
||||
return RolloutConfigView(
|
||||
project_id=item.project_id,
|
||||
canary_enabled=item.canary_enabled,
|
||||
canary_percent=item.canary_percent,
|
||||
rollback_to_local=item.rollback_to_local,
|
||||
compatibility_mode=item.compatibility_mode,
|
||||
dual_event_write=item.dual_event_write,
|
||||
route_mode_default=item.route_mode_default,
|
||||
fallback_chain=_json_loads(item.fallback_chain_json, ["local"]),
|
||||
alert_thresholds=_json_loads(item.alert_thresholds_json, {}),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/alerts")
|
||||
def get_alert_panel(
|
||||
project_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
config = a2a_runtime.get_project_config(db, project_id, current_user.id)
|
||||
thresholds = _json_loads(config.alert_thresholds_json, {})
|
||||
defaults = {"error_rate": 0.05, "p95_ms": 3000, "retry_rate": 0.2, "circuit_open_rate": 0.05}
|
||||
merged = {**defaults, **thresholds}
|
||||
return {
|
||||
"project_id": project_id,
|
||||
"thresholds": merged,
|
||||
"panel": {"metrics_endpoint": "/api/v1/a2a/metrics", "task_list_endpoint": "/api/v1/a2a/tasks"},
|
||||
}
|
||||
|
||||
|
||||
@router.get("/audit-logs")
|
||||
def list_audit_logs(
|
||||
project_id: Optional[int] = Query(default=None),
|
||||
skip: int = Query(default=0, ge=0),
|
||||
limit: int = Query(default=100, ge=1, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
) -> List[Dict[str, Any]]:
|
||||
query = db.query(A2AAuditLog)
|
||||
if project_id is not None:
|
||||
_ensure_project_access(db, project_id, current_user)
|
||||
query = query.filter(A2AAuditLog.project_id == project_id)
|
||||
elif not current_user.is_admin:
|
||||
query = query.filter(A2AAuditLog.actor_user_id == current_user.id)
|
||||
rows = query.order_by(A2AAuditLog.created_at.desc()).offset(skip).limit(limit).all()
|
||||
return [
|
||||
{
|
||||
"id": row.id,
|
||||
"actor_user_id": row.actor_user_id,
|
||||
"action": row.action,
|
||||
"target_type": row.target_type,
|
||||
"target_id": row.target_id,
|
||||
"project_id": row.project_id,
|
||||
"task_id": row.task_id,
|
||||
"result": row.result,
|
||||
"detail": _json_loads(row.detail_json, {}),
|
||||
"created_at": row.created_at,
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
@@ -0,0 +1,134 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class A2ARemoteAgent(Base):
|
||||
__tablename__ = "a2a_remote_agents"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
base_url = Column(String, nullable=False)
|
||||
auth_scheme = Column(String, nullable=False, default="none")
|
||||
auth_token = Column(String, nullable=True)
|
||||
protocol_version = Column(String, nullable=True)
|
||||
capabilities_json = Column(Text, nullable=False, default="[]")
|
||||
card_json = Column(Text, nullable=True)
|
||||
card_fetched_at = Column(DateTime, nullable=True)
|
||||
healthy = Column(Boolean, nullable=False, default=False)
|
||||
failure_count = Column(Integer, nullable=False, default=0)
|
||||
circuit_open_until = Column(DateTime, nullable=True)
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
project = relationship("Project")
|
||||
|
||||
|
||||
class A2ATask(Base):
|
||||
__tablename__ = "a2a_tasks"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
|
||||
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
source = Column(String, nullable=False, default="local")
|
||||
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
|
||||
idempotency_key = Column(String, nullable=True, index=True)
|
||||
state = Column(String, nullable=False, index=True, default="SUBMITTED")
|
||||
input_text = Column(Text, nullable=False, default="")
|
||||
output_text = Column(Text, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
compatibility_mode = Column(Boolean, nullable=False, default=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
finished_at = Column(DateTime, nullable=True)
|
||||
|
||||
project = relationship("Project")
|
||||
remote_agent = relationship("A2ARemoteAgent")
|
||||
|
||||
|
||||
class A2ATaskEvent(Base):
|
||||
__tablename__ = "a2a_task_events"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
event_type = Column(String, nullable=False)
|
||||
payload_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now(), index=True)
|
||||
|
||||
task = relationship("A2ATask")
|
||||
|
||||
|
||||
class A2ATaskWebhook(Base):
|
||||
__tablename__ = "a2a_task_webhooks"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
target_url = Column(String, nullable=False)
|
||||
secret = Column(String, nullable=True)
|
||||
auth_header = Column(String, nullable=True)
|
||||
enabled = Column(Boolean, nullable=False, default=True)
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
task = relationship("A2ATask")
|
||||
|
||||
|
||||
class A2AWebhookDelivery(Base):
|
||||
__tablename__ = "a2a_webhook_deliveries"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
webhook_id = Column(Integer, ForeignKey("a2a_task_webhooks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
event_id = Column(Integer, ForeignKey("a2a_task_events.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
attempt = Column(Integer, nullable=False, default=0)
|
||||
status = Column(String, nullable=False, default="PENDING")
|
||||
response_code = Column(Integer, nullable=True)
|
||||
response_body = Column(Text, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
next_retry_at = Column(DateTime, nullable=True)
|
||||
delivered_at = Column(DateTime, nullable=True)
|
||||
dead_letter = Column(Boolean, nullable=False, default=False, index=True)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
task = relationship("A2ATask")
|
||||
webhook = relationship("A2ATaskWebhook")
|
||||
event = relationship("A2ATaskEvent")
|
||||
|
||||
|
||||
class A2AProjectConfig(Base):
|
||||
__tablename__ = "a2a_project_configs"
|
||||
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), primary_key=True)
|
||||
canary_enabled = Column(Boolean, nullable=False, default=False)
|
||||
canary_percent = Column(Integer, nullable=False, default=0)
|
||||
rollback_to_local = Column(Boolean, nullable=False, default=True)
|
||||
compatibility_mode = Column(Boolean, nullable=False, default=True)
|
||||
dual_event_write = Column(Boolean, nullable=False, default=True)
|
||||
route_mode_default = Column(String, nullable=False, default="local_first")
|
||||
fallback_chain_json = Column(Text, nullable=False, default='["local"]')
|
||||
alert_thresholds_json = Column(Text, nullable=False, default="{}")
|
||||
updated_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
project = relationship("Project")
|
||||
|
||||
|
||||
class A2AAuditLog(Base):
|
||||
__tablename__ = "a2a_audit_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
actor_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
action = Column(String, nullable=False)
|
||||
target_type = Column(String, nullable=False)
|
||||
target_id = Column(String, nullable=False)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=True, index=True)
|
||||
task_id = Column(String, nullable=True, index=True)
|
||||
result = Column(String, nullable=False)
|
||||
detail_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now(), index=True)
|
||||
@@ -0,0 +1,384 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.a2a import (
|
||||
A2AAuditLog,
|
||||
A2AProjectConfig,
|
||||
A2ARemoteAgent,
|
||||
A2ATask,
|
||||
A2ATaskEvent,
|
||||
A2ATaskWebhook,
|
||||
A2AWebhookDelivery,
|
||||
)
|
||||
from app.models.project import Project
|
||||
from app.trace import build_error_attributes, trace_service
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
def _json_loads(raw: Optional[str], default: Any) -> Any:
|
||||
if not raw:
|
||||
return default
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
|
||||
def _json_dumps(raw: Any) -> str:
|
||||
return json.dumps(raw, ensure_ascii=False)
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _mask_error(message: str) -> str:
|
||||
if not message:
|
||||
return "internal_error"
|
||||
return "request_failed"
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AResolvedRoute:
|
||||
selected: str
|
||||
fallback_chain: List[str]
|
||||
canary_hit: bool
|
||||
reason: str
|
||||
|
||||
|
||||
class A2AMetrics:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self._counters: Dict[str, int] = defaultdict(int)
|
||||
self._latency_ms: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=2000))
|
||||
|
||||
async def incr(self, key: str, value: int = 1) -> None:
|
||||
async with self._lock:
|
||||
self._counters[key] += value
|
||||
|
||||
async def observe_latency(self, key: str, elapsed_ms: float) -> None:
|
||||
async with self._lock:
|
||||
self._latency_ms[key].append(float(elapsed_ms))
|
||||
|
||||
async def snapshot(self) -> Dict[str, Any]:
|
||||
async with self._lock:
|
||||
counters = dict(self._counters)
|
||||
p95 = {}
|
||||
for key, values in self._latency_ms.items():
|
||||
series = sorted(values)
|
||||
if not series:
|
||||
p95[f"{key}.p95_ms"] = 0.0
|
||||
continue
|
||||
idx = int(0.95 * (len(series) - 1))
|
||||
p95[f"{key}.p95_ms"] = round(series[idx], 2)
|
||||
total = counters.get("a2a.requests.total", 0)
|
||||
errors = counters.get("a2a.requests.error", 0)
|
||||
retries = counters.get("a2a.requests.retry", 0)
|
||||
breakers = counters.get("a2a.circuit.open", 0)
|
||||
return {
|
||||
"counters": counters,
|
||||
"derived": {
|
||||
"error_rate": round(errors / total, 4) if total else 0.0,
|
||||
"retry_rate": round(retries / total, 4) if total else 0.0,
|
||||
"circuit_open_rate": round(breakers / total, 4) if total else 0.0,
|
||||
},
|
||||
"latency": p95,
|
||||
}
|
||||
|
||||
|
||||
class A2ARuntime:
|
||||
def __init__(self) -> None:
|
||||
self._subscribers: Dict[str, List[asyncio.Queue[Dict[str, Any]]]] = defaultdict(list)
|
||||
self.metrics = A2AMetrics()
|
||||
self.protocol_version = "1.0"
|
||||
self._circuit_state: Dict[int, datetime] = {}
|
||||
|
||||
async def publish(self, task_id: str, event: Dict[str, Any]) -> None:
|
||||
queues = list(self._subscribers.get(task_id, []))
|
||||
for queue in queues:
|
||||
await queue.put(event)
|
||||
|
||||
async def subscribe(self, task_id: str) -> AsyncIterator[Dict[str, Any]]:
|
||||
queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=200)
|
||||
self._subscribers[task_id].append(queue)
|
||||
try:
|
||||
while True:
|
||||
payload = await queue.get()
|
||||
yield payload
|
||||
finally:
|
||||
self._subscribers[task_id] = [q for q in self._subscribers.get(task_id, []) if q is not queue]
|
||||
if not self._subscribers[task_id]:
|
||||
self._subscribers.pop(task_id, None)
|
||||
|
||||
def get_project_config(self, db: Session, project_id: int, user_id: int) -> A2AProjectConfig:
|
||||
item = db.query(A2AProjectConfig).filter(A2AProjectConfig.project_id == project_id).first()
|
||||
if item:
|
||||
return item
|
||||
config = A2AProjectConfig(project_id=project_id, updated_by=user_id)
|
||||
db.add(config)
|
||||
db.commit()
|
||||
db.refresh(config)
|
||||
return config
|
||||
|
||||
def resolve_route(self, *, project_config: A2AProjectConfig, session_id: str, requested_mode: str, requested_fallback: Optional[List[str]]) -> A2AResolvedRoute:
|
||||
selected = requested_mode or project_config.route_mode_default or "local_first"
|
||||
fallback = requested_fallback or _json_loads(project_config.fallback_chain_json, ["local"])
|
||||
fallback_chain = [item for item in fallback if item in {"a2a", "local", "mcp"}]
|
||||
if not fallback_chain:
|
||||
fallback_chain = ["local"]
|
||||
canary_hit = False
|
||||
if project_config.canary_enabled and project_config.canary_percent > 0:
|
||||
digest = hashlib.sha256(f"{project_config.project_id}:{session_id}".encode()).hexdigest()
|
||||
bucket = int(digest[:8], 16) % 100
|
||||
canary_hit = bucket < project_config.canary_percent
|
||||
if selected in {"a2a_first", "a2a"} and not canary_hit:
|
||||
return A2AResolvedRoute(
|
||||
selected="local",
|
||||
fallback_chain=fallback_chain,
|
||||
canary_hit=False,
|
||||
reason="canary_not_hit_fallback_local",
|
||||
)
|
||||
if selected in {"a2a_first", "a2a"}:
|
||||
return A2AResolvedRoute(selected="a2a", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="a2a_selected")
|
||||
if selected in {"mcp_first", "mcp"}:
|
||||
return A2AResolvedRoute(selected="mcp", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="mcp_selected")
|
||||
return A2AResolvedRoute(selected="local", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="local_selected")
|
||||
|
||||
def can_transition(self, from_state: str, to_state: str) -> bool:
|
||||
if from_state == to_state:
|
||||
return True
|
||||
return to_state in _STATE_TRANSITIONS.get(from_state, set())
|
||||
|
||||
def create_task(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
project_id: int,
|
||||
tenant_id: int,
|
||||
source: str,
|
||||
input_text: str,
|
||||
idempotency_key: Optional[str],
|
||||
remote_agent_id: Optional[int],
|
||||
compatibility_mode: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> A2ATask:
|
||||
if idempotency_key:
|
||||
existing = (
|
||||
db.query(A2ATask)
|
||||
.filter(
|
||||
A2ATask.project_id == project_id,
|
||||
A2ATask.tenant_id == tenant_id,
|
||||
A2ATask.idempotency_key == idempotency_key,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
task = A2ATask(
|
||||
id=f"task_{uuid.uuid4().hex}",
|
||||
project_id=project_id,
|
||||
tenant_id=tenant_id,
|
||||
source=source,
|
||||
remote_agent_id=remote_agent_id,
|
||||
state="SUBMITTED",
|
||||
input_text=input_text,
|
||||
idempotency_key=idempotency_key,
|
||||
compatibility_mode=compatibility_mode,
|
||||
metadata_json=_json_dumps(metadata or {}),
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
def append_event(self, db: Session, task: A2ATask, event_type: str, payload: Dict[str, Any]) -> A2ATaskEvent:
|
||||
event = A2ATaskEvent(task_id=task.id, event_type=event_type, payload_json=_json_dumps(payload))
|
||||
db.add(event)
|
||||
db.commit()
|
||||
db.refresh(event)
|
||||
return event
|
||||
|
||||
def transition_task(
|
||||
self,
|
||||
db: Session,
|
||||
task: A2ATask,
|
||||
*,
|
||||
to_state: str,
|
||||
output_text: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> A2ATask:
|
||||
if not self.can_transition(task.state, to_state):
|
||||
raise ValueError(f"Invalid task transition: {task.state} -> {to_state}")
|
||||
task.state = to_state
|
||||
if output_text is not None:
|
||||
task.output_text = output_text
|
||||
if error_message is not None:
|
||||
task.error_message = error_message
|
||||
if metadata is not None:
|
||||
task.metadata_json = _json_dumps(metadata)
|
||||
if to_state in _TERMINAL_STATES:
|
||||
task.finished_at = _utc_now()
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return task
|
||||
|
||||
def record_audit(
|
||||
self,
|
||||
db: Session,
|
||||
*,
|
||||
actor_user_id: int,
|
||||
action: str,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
result: str,
|
||||
project_id: Optional[int] = None,
|
||||
task_id: Optional[str] = None,
|
||||
detail: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
audit = A2AAuditLog(
|
||||
actor_user_id=actor_user_id,
|
||||
action=action,
|
||||
target_type=target_type,
|
||||
target_id=target_id,
|
||||
result=result,
|
||||
project_id=project_id,
|
||||
task_id=task_id,
|
||||
detail_json=_json_dumps(detail or {}),
|
||||
)
|
||||
db.add(audit)
|
||||
db.commit()
|
||||
|
||||
async def fetch_agent_card(self, db: Session, agent: A2ARemoteAgent, *, timeout_s: float = 10.0) -> Dict[str, Any]:
|
||||
if agent.id in self._circuit_state and self._circuit_state[agent.id] > _utc_now():
|
||||
raise RuntimeError("circuit_open")
|
||||
started = time.perf_counter()
|
||||
await self.metrics.incr("a2a.requests.total")
|
||||
headers = {}
|
||||
if agent.auth_scheme == "bearer" and agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {agent.auth_token}"
|
||||
url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/agent-card"
|
||||
with trace_service.start_span("a2a.card.fetch", attributes={"agent_id": agent.id, "url": url}) as span:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout_s, verify=True) as client:
|
||||
resp = await client.get(url, headers=headers)
|
||||
if resp.status_code >= 400:
|
||||
raise RuntimeError(f"http_{resp.status_code}")
|
||||
payload = resp.json()
|
||||
elapsed_ms = (time.perf_counter() - started) * 1000
|
||||
await self.metrics.observe_latency("a2a.card.fetch", elapsed_ms)
|
||||
agent.card_json = _json_dumps(payload)
|
||||
agent.protocol_version = str(payload.get("protocol_version") or "")
|
||||
agent.capabilities_json = _json_dumps(payload.get("capabilities") or [])
|
||||
agent.card_fetched_at = _utc_now()
|
||||
agent.healthy = True
|
||||
agent.failure_count = 0
|
||||
agent.circuit_open_until = None
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
return payload
|
||||
except Exception as exc:
|
||||
span.set_attributes(build_error_attributes(exc, stage="a2a_card_fetch"))
|
||||
await self.metrics.incr("a2a.requests.error")
|
||||
if attempt < 2:
|
||||
await self.metrics.incr("a2a.requests.retry")
|
||||
await asyncio.sleep(0.2 * (2 ** attempt))
|
||||
continue
|
||||
agent.failure_count = (agent.failure_count or 0) + 1
|
||||
if agent.failure_count >= 3:
|
||||
reopen_at = _utc_now() + timedelta(seconds=90)
|
||||
agent.circuit_open_until = reopen_at
|
||||
self._circuit_state[agent.id] = reopen_at
|
||||
await self.metrics.incr("a2a.circuit.open")
|
||||
agent.healthy = False
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
raise
|
||||
|
||||
async def notify_webhooks(self, db: Session, task: A2ATask, event: A2ATaskEvent) -> None:
|
||||
webhooks = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id, A2ATaskWebhook.enabled == True).all()
|
||||
if not webhooks:
|
||||
return
|
||||
for hook in webhooks:
|
||||
delivery = A2AWebhookDelivery(task_id=task.id, webhook_id=hook.id, event_id=event.id, attempt=0, status="PENDING")
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
db.refresh(delivery)
|
||||
await self._deliver_once(db, hook, event, delivery)
|
||||
|
||||
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
|
||||
event_payload = _json_loads(event.payload_json, {})
|
||||
request_payload = {
|
||||
"task_id": event.task_id,
|
||||
"event_type": event.event_type,
|
||||
"event_id": event.id,
|
||||
"payload": event_payload,
|
||||
}
|
||||
body = _json_dumps(request_payload).encode("utf-8")
|
||||
for attempt in range(1, 5):
|
||||
delivery.attempt = attempt
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
|
||||
if hook.secret:
|
||||
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
|
||||
headers["X-A2A-Signature"] = f"sha256={digest}"
|
||||
if hook.auth_header:
|
||||
headers["Authorization"] = hook.auth_header
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
|
||||
resp = await client.post(hook.target_url, content=body, headers=headers)
|
||||
delivery.response_code = resp.status_code
|
||||
delivery.response_body = (resp.text or "")[:1000]
|
||||
if 200 <= resp.status_code < 300:
|
||||
delivery.status = "DELIVERED"
|
||||
delivery.dead_letter = False
|
||||
delivery.delivered_at = _utc_now()
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
return
|
||||
raise RuntimeError(f"http_{resp.status_code}")
|
||||
except Exception as exc:
|
||||
delivery.error_message = str(exc)[:500]
|
||||
if attempt < 4:
|
||||
delivery.status = "RETRYING"
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
await asyncio.sleep(2 ** attempt)
|
||||
continue
|
||||
delivery.status = "FAILED"
|
||||
delivery.dead_letter = True
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
return
|
||||
|
||||
|
||||
a2a_runtime = A2ARuntime()
|
||||
+14
-2
@@ -23,7 +23,7 @@ import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge, embedding_models, web_search
|
||||
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge, embedding_models, web_search, a2a
|
||||
from app.connectors.postgres import postgres_connector
|
||||
from app.connectors.clickhouse import clickhouse_connector
|
||||
from app.core.artifacts import extract_artifacts
|
||||
@@ -52,6 +52,15 @@ from app.models.user import User, EmailVerification
|
||||
from app.models.project import Project
|
||||
from app.models.datasource import DataSource
|
||||
from app.models.subagent import Subagent
|
||||
from app.models.a2a import (
|
||||
A2ARemoteAgent,
|
||||
A2ATask,
|
||||
A2ATaskEvent,
|
||||
A2ATaskWebhook,
|
||||
A2AWebhookDelivery,
|
||||
A2AProjectConfig,
|
||||
A2AAuditLog,
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -86,6 +95,7 @@ app.include_router(subagents.router, prefix="/api/v1")
|
||||
app.include_router(knowledge.router, prefix="/api/v1")
|
||||
app.include_router(embedding_models.router, prefix="/api/v1")
|
||||
app.include_router(web_search.router, prefix="/api/v1")
|
||||
app.include_router(a2a.router, prefix="/api/v1")
|
||||
|
||||
STREAM_DELTA_CHUNK_SIZE = 48
|
||||
PREVIEWABLE_TEXT_EXTENSIONS = {
|
||||
@@ -292,7 +302,9 @@ class ChatRequest(BaseModel):
|
||||
source: str = "postgres"
|
||||
prefer_sql_chart: bool = False
|
||||
file_url: Optional[str] = None
|
||||
route_mode: Literal["auto", "chat", "sql"] = "auto"
|
||||
route_mode: Literal["auto", "chat", "sql", "a2a", "a2a_first", "local_first", "mcp_first"] = "auto"
|
||||
route_fallback_chain: Optional[List[Literal["a2a", "local", "mcp"]]] = None
|
||||
a2a_agent_id: Optional[int] = None
|
||||
knowledge_base_id: Optional[str] = None
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,319 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
BACKEND_ROOT = Path(__file__).resolve().parents[1]
|
||||
REPO_ROOT = BACKEND_ROOT.parent
|
||||
NANOBOT_ROOT = REPO_ROOT / "nanobot"
|
||||
if str(BACKEND_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
if str(NANOBOT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(NANOBOT_ROOT))
|
||||
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import Base, get_db
|
||||
from app.models.a2a import A2ARemoteAgent
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.services.a2a_service import a2a_runtime
|
||||
from main import app
|
||||
|
||||
|
||||
def _seed(db: Session) -> tuple[int, str, int, str, int]:
|
||||
owner = User(username="a2a_owner", email="a2a_owner@example.com", hashed_password="x", is_admin=False)
|
||||
other = User(username="a2a_other", email="a2a_other@example.com", hashed_password="x", is_admin=False)
|
||||
db.add(owner)
|
||||
db.add(other)
|
||||
db.commit()
|
||||
db.refresh(owner)
|
||||
db.refresh(other)
|
||||
project = Project(name="a2a_project", description="a2a", owner_id=owner.id)
|
||||
db.add(project)
|
||||
db.commit()
|
||||
db.refresh(project)
|
||||
return owner.id, owner.username, other.id, other.username, project.id
|
||||
|
||||
|
||||
def test_a2a_send_list_cancel_and_rollout() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
send_resp = client.post(
|
||||
"/api/v1/a2a/messages/send",
|
||||
json={
|
||||
"project_id": project_id,
|
||||
"message": "hello a2a",
|
||||
"session_id": "test-a2a-session",
|
||||
"route_mode": "local_first",
|
||||
},
|
||||
)
|
||||
assert send_resp.status_code == 200
|
||||
task_id = send_resp.json()["task"]["id"]
|
||||
|
||||
get_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["project_id"] == project_id
|
||||
|
||||
list_resp = client.get("/api/v1/a2a/tasks", params={"project_id": project_id})
|
||||
assert list_resp.status_code == 200
|
||||
assert any(item["id"] == task_id for item in list_resp.json())
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/a2a/tasks/{task_id}/cancel")
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["state"] in {"CANCELED", "COMPLETED", "FAILED"}
|
||||
|
||||
rollout_resp = client.put(
|
||||
f"/api/v1/a2a/projects/{project_id}/rollout",
|
||||
json={"canary_enabled": True, "canary_percent": 30, "route_mode_default": "a2a_first", "fallback_chain": ["a2a", "local"]},
|
||||
)
|
||||
assert rollout_resp.status_code == 200
|
||||
assert rollout_resp.json()["canary_enabled"] is True
|
||||
assert rollout_resp.json()["canary_percent"] == 30
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_task_tenant_isolation() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
send_resp = client.post(
|
||||
"/api/v1/a2a/messages/send",
|
||||
json={"project_id": project_id, "message": "tenant isolation", "session_id": "tenant-isolation", "route_mode": "local"},
|
||||
)
|
||||
assert send_resp.status_code == 200
|
||||
task_id = send_resp.json()["task"]["id"]
|
||||
|
||||
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
|
||||
forbidden_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
|
||||
assert forbidden_resp.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_metrics_admin_only() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, _ = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
denied = client.get("/api/v1/a2a/metrics")
|
||||
assert denied.status_code == 403
|
||||
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
|
||||
ok = client.get("/api/v1/a2a/metrics")
|
||||
assert ok.status_code == 200
|
||||
assert "counters" in ok.json()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_send_idempotency_key_deduplicates_task() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
payload = {
|
||||
"project_id": project_id,
|
||||
"message": "dedupe-task",
|
||||
"session_id": "idempotency-session",
|
||||
"route_mode": "local_first",
|
||||
"idempotency_key": "same-key-1",
|
||||
}
|
||||
first_resp = client.post("/api/v1/a2a/messages/send", json=payload)
|
||||
second_resp = client.post("/api/v1/a2a/messages/send", json=payload)
|
||||
assert first_resp.status_code == 200
|
||||
assert second_resp.status_code == 200
|
||||
assert first_resp.json()["task"]["id"] == second_resp.json()["task"]["id"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, _, _, _, project_id = _seed(db)
|
||||
agent = A2ARemoteAgent(
|
||||
project_id=project_id,
|
||||
name="auth-fail-agent",
|
||||
base_url="https://remote.example.com",
|
||||
auth_scheme="bearer",
|
||||
auth_token="bad-token",
|
||||
created_by=owner_id,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
a2a_runtime._circuit_state.pop(agent.id, None)
|
||||
|
||||
class _FailResp:
|
||||
status_code = 401
|
||||
|
||||
@staticmethod
|
||||
def json():
|
||||
return {"detail": "unauthorized"}
|
||||
|
||||
class _Client401:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, headers=None):
|
||||
return _FailResp()
|
||||
|
||||
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
|
||||
db.refresh(agent)
|
||||
assert agent.healthy is False
|
||||
assert agent.failure_count == 1
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
db.close()
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, _, _, _, project_id = _seed(db)
|
||||
agent = A2ARemoteAgent(
|
||||
project_id=project_id,
|
||||
name="offline-agent",
|
||||
base_url="https://offline.example.com",
|
||||
auth_scheme="none",
|
||||
created_by=owner_id,
|
||||
)
|
||||
db.add(agent)
|
||||
db.commit()
|
||||
db.refresh(agent)
|
||||
a2a_runtime._circuit_state.pop(agent.id, None)
|
||||
|
||||
class _ClientDown:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get(self, url, headers=None):
|
||||
raise httpx.ConnectError("network down")
|
||||
|
||||
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown)
|
||||
|
||||
for _ in range(3):
|
||||
with pytest.raises(Exception):
|
||||
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
|
||||
db.refresh(agent)
|
||||
assert agent.healthy is False
|
||||
assert agent.failure_count == 3
|
||||
assert agent.circuit_open_until is not None
|
||||
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
db.close()
|
||||
engine.dispose()
|
||||
Reference in New Issue
Block a user