feat: support a2a mode

This commit is contained in:
qixinbo
2026-04-01 11:21:55 +08:00
parent 9952af198a
commit 86447049a9
12 changed files with 3092 additions and 26 deletions
+891
View File
@@ -0,0 +1,891 @@
import asyncio
import json
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple
import httpx
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.nanobot import nanobot_service
from app.core.security import CurrentUser, get_current_user
from app.database import SessionLocal, get_db
from app.models.a2a import (
A2AAuditLog,
A2AProjectConfig,
A2ARemoteAgent,
A2ATask,
A2ATaskEvent,
A2ATaskWebhook,
A2AWebhookDelivery,
)
from app.models.project import Project
from app.services.a2a_service import _json_dumps, _json_loads, a2a_runtime
from app.trace import build_error_attributes, trace_service
router = APIRouter(prefix="/a2a", tags=["a2a"])
SUPPORTED_PROTOCOL_VERSION = "1.0"
SUPPORTED_CAPABILITIES = ["streaming", "push", "task_management", "subscribe"]
SUPPORTED_AUTH = ["bearer", "shared_secret", "none"]
def _mask_error(message: str) -> str:
if not message:
return "internal_error"
return "request_failed"
class AgentCardResponse(BaseModel):
name: str
protocol_version: str
capabilities: List[str]
endpoints: Dict[str, str]
auth: List[str]
class RemoteAgentCreate(BaseModel):
project_id: int
name: str = Field(min_length=1, max_length=120)
base_url: str = Field(min_length=1, max_length=500)
auth_scheme: Literal["none", "bearer"] = "none"
auth_token: Optional[str] = None
class RemoteAgentUpdate(BaseModel):
name: Optional[str] = None
base_url: Optional[str] = None
auth_scheme: Optional[Literal["none", "bearer"]] = None
auth_token: Optional[str] = None
class RemoteAgentView(BaseModel):
id: int
project_id: int
name: str
base_url: str
auth_scheme: str
protocol_version: Optional[str] = None
capabilities: List[str] = []
healthy: bool
failure_count: int
circuit_open_until: Optional[datetime] = None
card_fetched_at: Optional[datetime] = None
class SendMessageRequest(BaseModel):
project_id: int
message: str = Field(min_length=1)
session_id: str = "api:a2a"
remote_agent_id: Optional[int] = None
route_mode: Literal["auto", "local", "a2a", "a2a_first", "local_first", "mcp_first"] = "auto"
fallback_chain: Optional[List[Literal["a2a", "local", "mcp"]]] = None
idempotency_key: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None
class TaskView(BaseModel):
id: str
project_id: int
source: str
state: str
remote_agent_id: Optional[int] = None
input_text: str
output_text: Optional[str] = None
error_message: Optional[str] = None
compatibility_mode: bool
metadata: Dict[str, Any]
created_at: datetime
updated_at: datetime
finished_at: Optional[datetime] = None
class CancelTaskResponse(BaseModel):
task_id: str
state: str
class TaskWebhookCreate(BaseModel):
target_url: str = Field(min_length=1, max_length=500)
secret: Optional[str] = None
auth_header: Optional[str] = None
class TaskWebhookView(BaseModel):
id: int
task_id: str
target_url: str
enabled: bool
created_at: datetime
updated_at: datetime
class RolloutConfigView(BaseModel):
project_id: int
canary_enabled: bool
canary_percent: int
rollback_to_local: bool
compatibility_mode: bool
dual_event_write: bool
route_mode_default: str
fallback_chain: List[str]
alert_thresholds: Dict[str, Any]
class RolloutConfigUpdate(BaseModel):
canary_enabled: Optional[bool] = None
canary_percent: Optional[int] = Field(default=None, ge=0, le=100)
rollback_to_local: Optional[bool] = None
compatibility_mode: Optional[bool] = None
dual_event_write: Optional[bool] = None
route_mode_default: Optional[str] = None
fallback_chain: Optional[List[str]] = None
alert_thresholds: Optional[Dict[str, Any]] = None
def _ensure_project_access(db: Session, project_id: int, user: CurrentUser) -> Project:
project = db.query(Project).filter(Project.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
if not user.is_admin and project.owner_id != user.id:
raise HTTPException(status_code=404, detail="Resource not found")
return project
def _ensure_task_access(db: Session, task_id: str, user: CurrentUser) -> A2ATask:
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
if not user.is_admin and task.tenant_id != user.id:
raise HTTPException(status_code=404, detail="Task not found")
return task
def _ensure_agent_access(db: Session, agent_id: int, user: CurrentUser) -> A2ARemoteAgent:
agent = db.query(A2ARemoteAgent).filter(A2ARemoteAgent.id == agent_id).first()
if not agent:
raise HTTPException(status_code=404, detail="Remote agent not found")
project = _ensure_project_access(db, agent.project_id, user)
if not project:
raise HTTPException(status_code=404, detail="Remote agent not found")
return agent
def _task_to_view(task: A2ATask) -> TaskView:
return TaskView(
id=task.id,
project_id=task.project_id,
source=task.source,
state=task.state,
remote_agent_id=task.remote_agent_id,
input_text=task.input_text,
output_text=task.output_text,
error_message=task.error_message,
compatibility_mode=task.compatibility_mode,
metadata=_json_loads(task.metadata_json, {}),
created_at=task.created_at,
updated_at=task.updated_at,
finished_at=task.finished_at,
)
def _agent_to_view(agent: A2ARemoteAgent) -> RemoteAgentView:
return RemoteAgentView(
id=agent.id,
project_id=agent.project_id,
name=agent.name,
base_url=agent.base_url,
auth_scheme=agent.auth_scheme,
protocol_version=agent.protocol_version,
capabilities=_json_loads(agent.capabilities_json, []),
healthy=bool(agent.healthy),
failure_count=int(agent.failure_count or 0),
circuit_open_until=agent.circuit_open_until,
card_fetched_at=agent.card_fetched_at,
)
def _build_status_event(task: A2ATask, *, compatibility_mode: bool, dual_event_write: bool) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"type": "TaskStatusUpdateEvent",
"task_id": task.id,
"task_status": task.state,
"timestamp": datetime.utcnow().isoformat(),
"source": task.source,
}
if compatibility_mode or dual_event_write:
payload.update(
{
"event": "task_status",
"status": task.state,
"taskId": task.id,
}
)
return payload
def _build_artifact_event(task_id: str, content: str, *, compatibility_mode: bool, dual_event_write: bool) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"type": "TaskArtifactUpdateEvent",
"task_id": task_id,
"artifact": {"content": content},
"timestamp": datetime.utcnow().isoformat(),
}
if compatibility_mode or dual_event_write:
payload.update(
{
"event": "task_output",
"taskId": task_id,
"output": content,
}
)
return payload
async def _delegate_to_remote(task: A2ATask, agent: A2ARemoteAgent, message: str) -> Tuple[str, Dict[str, Any]]:
headers: Dict[str, str] = {}
if agent.auth_scheme == "bearer" and agent.auth_token:
headers["Authorization"] = f"Bearer {agent.auth_token}"
payload = {
"project_id": task.project_id,
"message": message,
"session_id": f"a2a-delegate:{task.id}",
"idempotency_key": task.idempotency_key,
"route_mode": "local_first",
"metadata": {"delegated_by": "dataclaw", "task_id": task.id},
}
url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/messages/send"
async with httpx.AsyncClient(timeout=25.0, verify=True) as client:
resp = await client.post(url, json=payload, headers=headers)
if resp.status_code >= 400:
raise RuntimeError(f"remote_http_{resp.status_code}")
body = resp.json()
content = ""
if isinstance(body, dict):
task_obj = body.get("task") or {}
content = str(task_obj.get("output_text") or body.get("message") or "")
return content, body
async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) -> None:
db = SessionLocal()
try:
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
if not task:
return
config = a2a_runtime.get_project_config(db, task.project_id, tenant_id)
if task.state in {"CANCELED", "REJECTED"}:
return
with trace_service.start_span("a2a.task.execute", attributes={"task_id": task.id, "project_id": task.project_id, "source": task.source}) as span:
start_ts = datetime.utcnow().timestamp()
try:
task = a2a_runtime.transition_task(db, task, to_state="WORKING")
status_event = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
status_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", status_event)
await a2a_runtime.publish(task.id, status_event)
await a2a_runtime.notify_webhooks(db, task, status_row)
if task.source == "a2a" and task.remote_agent_id:
agent = db.query(A2ARemoteAgent).filter(A2ARemoteAgent.id == task.remote_agent_id).first()
if not agent:
raise RuntimeError("remote_agent_missing")
response_text, metadata = await _delegate_to_remote(task, agent, request.message)
else:
response_text = await nanobot_service.process_message(
request.message,
session_id=f"a2a-task:{task.id}",
project_id=task.project_id,
)
metadata = {"executor": "local"}
artifact_event_payload = _build_artifact_event(task.id, response_text or "", compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
artifact_event = a2a_runtime.append_event(db, task, "TaskArtifactUpdateEvent", artifact_event_payload)
await a2a_runtime.publish(task.id, artifact_event_payload)
await a2a_runtime.notify_webhooks(db, task, artifact_event)
task = a2a_runtime.transition_task(
db,
task,
to_state="COMPLETED",
output_text=response_text or "",
metadata=metadata,
)
done_event = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
done_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", done_event)
await a2a_runtime.publish(task.id, done_event)
await a2a_runtime.notify_webhooks(db, task, done_row)
elapsed = (datetime.utcnow().timestamp() - start_ts) * 1000
await a2a_runtime.metrics.observe_latency("a2a.execute", elapsed)
except Exception as exc:
span.set_attributes(build_error_attributes(exc, stage="a2a_task_execute"))
await a2a_runtime.metrics.incr("a2a.requests.error")
task = db.query(A2ATask).filter(A2ATask.id == task.id).first()
if task and task.state not in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
task = a2a_runtime.transition_task(db, task, to_state="FAILED", error_message=_json_dumps({"message": _mask_error(str(exc))}))
fail_event = _build_status_event(task, compatibility_mode=task.compatibility_mode, dual_event_write=True)
fail_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", fail_event)
await a2a_runtime.publish(task.id, fail_event)
await a2a_runtime.notify_webhooks(db, task, fail_row)
finally:
db.close()
@router.get("/agent-card", response_model=AgentCardResponse)
def get_agent_card() -> AgentCardResponse:
return AgentCardResponse(
name="DataClaw A2A Gateway",
protocol_version=SUPPORTED_PROTOCOL_VERSION,
capabilities=SUPPORTED_CAPABILITIES,
endpoints={
"send_message": "/api/v1/a2a/messages/send",
"send_streaming_message": "/api/v1/a2a/messages/stream",
"get_task": "/api/v1/a2a/tasks/{task_id}",
"list_tasks": "/api/v1/a2a/tasks",
"cancel_task": "/api/v1/a2a/tasks/{task_id}/cancel",
"subscribe_task": "/api/v1/a2a/tasks/{task_id}/subscribe",
},
auth=SUPPORTED_AUTH,
)
@router.get("/remote-agents", response_model=List[RemoteAgentView])
def list_remote_agents(
project_id: Optional[int] = Query(default=None),
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> List[RemoteAgentView]:
query = db.query(A2ARemoteAgent)
if project_id is not None:
_ensure_project_access(db, project_id, current_user)
query = query.filter(A2ARemoteAgent.project_id == project_id)
if not current_user.is_admin:
owned_ids = [p.id for p in db.query(Project).filter(Project.owner_id == current_user.id).all()]
if not owned_ids:
return []
query = query.filter(A2ARemoteAgent.project_id.in_(owned_ids))
return [_agent_to_view(item) for item in query.order_by(A2ARemoteAgent.id.desc()).all()]
@router.post("/remote-agents", response_model=RemoteAgentView, status_code=status.HTTP_201_CREATED)
async def create_remote_agent(
payload: RemoteAgentCreate,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> RemoteAgentView:
_ensure_project_access(db, payload.project_id, current_user)
item = A2ARemoteAgent(
project_id=payload.project_id,
name=payload.name.strip(),
base_url=payload.base_url.strip().rstrip("/"),
auth_scheme=payload.auth_scheme,
auth_token=payload.auth_token,
created_by=current_user.id,
)
db.add(item)
db.commit()
db.refresh(item)
try:
await a2a_runtime.fetch_agent_card(db, item)
except Exception:
pass
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="create_remote_agent",
target_type="remote_agent",
target_id=str(item.id),
result="ok",
project_id=item.project_id,
)
return _agent_to_view(item)
@router.put("/remote-agents/{agent_id}", response_model=RemoteAgentView)
async def update_remote_agent(
agent_id: int,
payload: RemoteAgentUpdate,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> RemoteAgentView:
item = _ensure_agent_access(db, agent_id, current_user)
update_data = payload.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(item, key, value)
if item.base_url:
item.base_url = item.base_url.rstrip("/")
db.add(item)
db.commit()
db.refresh(item)
try:
await a2a_runtime.fetch_agent_card(db, item)
except Exception:
pass
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="update_remote_agent",
target_type="remote_agent",
target_id=str(item.id),
result="ok",
project_id=item.project_id,
)
return _agent_to_view(item)
@router.delete("/remote-agents/{agent_id}")
def delete_remote_agent(
agent_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, str]:
item = _ensure_agent_access(db, agent_id, current_user)
db.delete(item)
db.commit()
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="delete_remote_agent",
target_type="remote_agent",
target_id=str(agent_id),
result="ok",
project_id=item.project_id,
)
return {"status": "success"}
@router.post("/remote-agents/{agent_id}/refresh-card", response_model=RemoteAgentView)
async def refresh_remote_agent_card(
agent_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> RemoteAgentView:
item = _ensure_agent_access(db, agent_id, current_user)
try:
card = await a2a_runtime.fetch_agent_card(db, item)
except Exception as exc:
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="refresh_remote_agent_card",
target_type="remote_agent",
target_id=str(agent_id),
result="failed",
project_id=item.project_id,
detail={"error": str(exc)},
)
raise HTTPException(status_code=502, detail="Remote card fetch failed")
version = str(card.get("protocol_version") or "")
if version and version.split(".")[0] != SUPPORTED_PROTOCOL_VERSION.split(".")[0]:
raise HTTPException(status_code=400, detail="Protocol version incompatible")
return _agent_to_view(item)
@router.post("/remote-agents/{agent_id}/health-check")
async def health_check_remote_agent(
agent_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, Any]:
item = _ensure_agent_access(db, agent_id, current_user)
try:
await a2a_runtime.fetch_agent_card(db, item, timeout_s=5.0)
return {"healthy": True, "failure_count": item.failure_count}
except Exception:
return {"healthy": False, "failure_count": item.failure_count}
@router.post("/messages/send")
async def send_message(
request: SendMessageRequest,
x_a2a_token: Optional[str] = Header(default=None),
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, Any]:
_ensure_project_access(db, request.project_id, current_user)
config = a2a_runtime.get_project_config(db, request.project_id, current_user.id)
route = a2a_runtime.resolve_route(
project_config=config,
session_id=request.session_id,
requested_mode=request.route_mode,
requested_fallback=request.fallback_chain,
)
selected_source = "local"
remote_agent_id = None
if route.selected == "a2a" and request.remote_agent_id:
agent = _ensure_agent_access(db, request.remote_agent_id, current_user)
if not agent.healthy and config.rollback_to_local:
selected_source = "local"
else:
selected_source = "a2a"
remote_agent_id = agent.id
task = a2a_runtime.create_task(
db,
project_id=request.project_id,
tenant_id=current_user.id,
source=selected_source,
input_text=request.message,
idempotency_key=request.idempotency_key,
remote_agent_id=remote_agent_id,
compatibility_mode=config.compatibility_mode,
metadata={"route": route.model_dump() if hasattr(route, "model_dump") else route.__dict__, "token_present": bool(x_a2a_token), "request_metadata": request.metadata or {}},
)
event_payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
event_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", event_payload)
await a2a_runtime.publish(task.id, event_payload)
await a2a_runtime.notify_webhooks(db, task, event_row)
asyncio.create_task(_run_task(task.id, request, current_user.id))
await a2a_runtime.metrics.incr("a2a.requests.total")
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="send_message",
target_type="task",
target_id=task.id,
result="accepted",
project_id=task.project_id,
task_id=task.id,
)
return {"task": _task_to_view(task).model_dump(), "routing": route.__dict__}
@router.post("/messages/stream")
async def send_streaming_message(
request: SendMessageRequest,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> StreamingResponse:
response = await send_message(request=request, x_a2a_token=None, db=db, current_user=current_user)
task_id = response["task"]["id"]
async def event_generator():
history = (
db.query(A2ATaskEvent)
.filter(A2ATaskEvent.task_id == task_id)
.order_by(A2ATaskEvent.id.asc())
.all()
)
for item in history:
payload = _json_loads(item.payload_json, {})
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
async for payload in a2a_runtime.subscribe(task_id):
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
break
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.get("/tasks/{task_id}", response_model=TaskView)
def get_task(
task_id: str,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> TaskView:
task = _ensure_task_access(db, task_id, current_user)
return _task_to_view(task)
@router.get("/tasks", response_model=List[TaskView])
def list_tasks(
project_id: Optional[int] = Query(default=None),
state: Optional[str] = Query(default=None),
skip: int = Query(default=0, ge=0),
limit: int = Query(default=50, ge=1, le=200),
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> List[TaskView]:
query = db.query(A2ATask)
if not current_user.is_admin:
query = query.filter(A2ATask.tenant_id == current_user.id)
if project_id is not None:
_ensure_project_access(db, project_id, current_user)
query = query.filter(A2ATask.project_id == project_id)
if state:
query = query.filter(A2ATask.state == state)
tasks = query.order_by(A2ATask.created_at.desc()).offset(skip).limit(limit).all()
return [_task_to_view(item) for item in tasks]
@router.post("/tasks/{task_id}/cancel", response_model=CancelTaskResponse)
async def cancel_task(
task_id: str,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> CancelTaskResponse:
task = _ensure_task_access(db, task_id, current_user)
if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
return CancelTaskResponse(task_id=task.id, state=task.state)
try:
task = a2a_runtime.transition_task(db, task, to_state="CANCELED")
except ValueError:
raise HTTPException(status_code=409, detail="Task state transition conflict")
config = a2a_runtime.get_project_config(db, task.project_id, current_user.id)
payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write)
row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", payload)
await a2a_runtime.publish(task.id, payload)
await a2a_runtime.notify_webhooks(db, task, row)
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="cancel_task",
target_type="task",
target_id=task.id,
result="ok",
project_id=task.project_id,
task_id=task.id,
)
return CancelTaskResponse(task_id=task.id, state=task.state)
@router.get("/tasks/{task_id}/subscribe")
async def subscribe_task(
task_id: str,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> StreamingResponse:
task = _ensure_task_access(db, task_id, current_user)
initial_events = (
db.query(A2ATaskEvent)
.filter(A2ATaskEvent.task_id == task.id)
.order_by(A2ATaskEvent.id.asc())
.all()
)
async def event_generator():
for event in initial_events:
payload = _json_loads(event.payload_json, {})
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return
async for payload in a2a_runtime.subscribe(task.id):
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}:
break
yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
)
@router.get("/tasks/{task_id}/webhooks", response_model=List[TaskWebhookView])
def list_task_webhooks(
task_id: str,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> List[TaskWebhookView]:
task = _ensure_task_access(db, task_id, current_user)
items = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id).order_by(A2ATaskWebhook.id.desc()).all()
return [
TaskWebhookView(
id=item.id,
task_id=item.task_id,
target_url=item.target_url,
enabled=item.enabled,
created_at=item.created_at,
updated_at=item.updated_at,
)
for item in items
]
@router.post("/tasks/{task_id}/webhooks", response_model=TaskWebhookView, status_code=status.HTTP_201_CREATED)
def create_task_webhook(
task_id: str,
payload: TaskWebhookCreate,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> TaskWebhookView:
task = _ensure_task_access(db, task_id, current_user)
item = A2ATaskWebhook(
task_id=task.id,
target_url=payload.target_url.strip(),
secret=payload.secret,
auth_header=payload.auth_header,
created_by=current_user.id,
)
db.add(item)
db.commit()
db.refresh(item)
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="create_task_webhook",
target_type="task_webhook",
target_id=str(item.id),
result="ok",
project_id=task.project_id,
task_id=task.id,
)
return TaskWebhookView(
id=item.id,
task_id=item.task_id,
target_url=item.target_url,
enabled=item.enabled,
created_at=item.created_at,
updated_at=item.updated_at,
)
@router.delete("/tasks/{task_id}/webhooks/{webhook_id}")
def delete_task_webhook(
task_id: str,
webhook_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, str]:
task = _ensure_task_access(db, task_id, current_user)
item = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.id == webhook_id, A2ATaskWebhook.task_id == task.id).first()
if not item:
raise HTTPException(status_code=404, detail="Webhook not found")
db.delete(item)
db.commit()
return {"status": "success"}
@router.post("/webhook-deliveries/{delivery_id}/replay")
async def replay_delivery(
delivery_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, Any]:
delivery = db.query(A2AWebhookDelivery).filter(A2AWebhookDelivery.id == delivery_id).first()
if not delivery:
raise HTTPException(status_code=404, detail="Delivery not found")
task = _ensure_task_access(db, delivery.task_id, current_user)
webhook = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.id == delivery.webhook_id).first()
event = db.query(A2ATaskEvent).filter(A2ATaskEvent.id == delivery.event_id).first()
if not webhook or not event:
raise HTTPException(status_code=404, detail="Delivery dependencies not found")
await a2a_runtime._deliver_once(db, webhook, event, delivery)
return {"status": delivery.status, "attempt": delivery.attempt, "dead_letter": delivery.dead_letter, "task_id": task.id}
@router.get("/metrics")
async def get_metrics(current_user: CurrentUser = Depends(get_current_user)) -> Dict[str, Any]:
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="Admin permission required")
return await a2a_runtime.metrics.snapshot()
@router.get("/projects/{project_id}/rollout", response_model=RolloutConfigView)
def get_rollout_config(
project_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> RolloutConfigView:
_ensure_project_access(db, project_id, current_user)
item = a2a_runtime.get_project_config(db, project_id, current_user.id)
return RolloutConfigView(
project_id=item.project_id,
canary_enabled=item.canary_enabled,
canary_percent=item.canary_percent,
rollback_to_local=item.rollback_to_local,
compatibility_mode=item.compatibility_mode,
dual_event_write=item.dual_event_write,
route_mode_default=item.route_mode_default,
fallback_chain=_json_loads(item.fallback_chain_json, ["local"]),
alert_thresholds=_json_loads(item.alert_thresholds_json, {}),
)
@router.put("/projects/{project_id}/rollout", response_model=RolloutConfigView)
def update_rollout_config(
project_id: int,
payload: RolloutConfigUpdate,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> RolloutConfigView:
_ensure_project_access(db, project_id, current_user)
item = a2a_runtime.get_project_config(db, project_id, current_user.id)
data = payload.model_dump(exclude_unset=True)
for key, value in data.items():
if key == "fallback_chain":
item.fallback_chain_json = _json_dumps(value)
continue
if key == "alert_thresholds":
item.alert_thresholds_json = _json_dumps(value)
continue
setattr(item, key, value)
item.updated_by = current_user.id
db.add(item)
db.commit()
db.refresh(item)
a2a_runtime.record_audit(
db,
actor_user_id=current_user.id,
action="update_rollout_config",
target_type="project_rollout",
target_id=str(project_id),
result="ok",
project_id=project_id,
)
return RolloutConfigView(
project_id=item.project_id,
canary_enabled=item.canary_enabled,
canary_percent=item.canary_percent,
rollback_to_local=item.rollback_to_local,
compatibility_mode=item.compatibility_mode,
dual_event_write=item.dual_event_write,
route_mode_default=item.route_mode_default,
fallback_chain=_json_loads(item.fallback_chain_json, ["local"]),
alert_thresholds=_json_loads(item.alert_thresholds_json, {}),
)
@router.get("/alerts")
def get_alert_panel(
project_id: int,
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> Dict[str, Any]:
_ensure_project_access(db, project_id, current_user)
config = a2a_runtime.get_project_config(db, project_id, current_user.id)
thresholds = _json_loads(config.alert_thresholds_json, {})
defaults = {"error_rate": 0.05, "p95_ms": 3000, "retry_rate": 0.2, "circuit_open_rate": 0.05}
merged = {**defaults, **thresholds}
return {
"project_id": project_id,
"thresholds": merged,
"panel": {"metrics_endpoint": "/api/v1/a2a/metrics", "task_list_endpoint": "/api/v1/a2a/tasks"},
}
@router.get("/audit-logs")
def list_audit_logs(
project_id: Optional[int] = Query(default=None),
skip: int = Query(default=0, ge=0),
limit: int = Query(default=100, ge=1, le=500),
db: Session = Depends(get_db),
current_user: CurrentUser = Depends(get_current_user),
) -> List[Dict[str, Any]]:
query = db.query(A2AAuditLog)
if project_id is not None:
_ensure_project_access(db, project_id, current_user)
query = query.filter(A2AAuditLog.project_id == project_id)
elif not current_user.is_admin:
query = query.filter(A2AAuditLog.actor_user_id == current_user.id)
rows = query.order_by(A2AAuditLog.created_at.desc()).offset(skip).limit(limit).all()
return [
{
"id": row.id,
"actor_user_id": row.actor_user_id,
"action": row.action,
"target_type": row.target_type,
"target_id": row.target_id,
"project_id": row.project_id,
"task_id": row.task_id,
"result": row.result,
"detail": _json_loads(row.detail_json, {}),
"created_at": row.created_at,
}
for row in rows
]
+134
View File
@@ -0,0 +1,134 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy.orm import relationship
from app.database import Base
class A2ARemoteAgent(Base):
__tablename__ = "a2a_remote_agents"
id = Column(Integer, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
name = Column(String, nullable=False)
base_url = Column(String, nullable=False)
auth_scheme = Column(String, nullable=False, default="none")
auth_token = Column(String, nullable=True)
protocol_version = Column(String, nullable=True)
capabilities_json = Column(Text, nullable=False, default="[]")
card_json = Column(Text, nullable=True)
card_fetched_at = Column(DateTime, nullable=True)
healthy = Column(Boolean, nullable=False, default=False)
failure_count = Column(Integer, nullable=False, default=0)
circuit_open_until = Column(DateTime, nullable=True)
created_by = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
project = relationship("Project")
class A2ATask(Base):
__tablename__ = "a2a_tasks"
id = Column(String, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
source = Column(String, nullable=False, default="local")
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
idempotency_key = Column(String, nullable=True, index=True)
state = Column(String, nullable=False, index=True, default="SUBMITTED")
input_text = Column(Text, nullable=False, default="")
output_text = Column(Text, nullable=True)
error_message = Column(Text, nullable=True)
compatibility_mode = Column(Boolean, nullable=False, default=True)
metadata_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
finished_at = Column(DateTime, nullable=True)
project = relationship("Project")
remote_agent = relationship("A2ARemoteAgent")
class A2ATaskEvent(Base):
__tablename__ = "a2a_task_events"
id = Column(Integer, primary_key=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
event_type = Column(String, nullable=False)
payload_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now(), index=True)
task = relationship("A2ATask")
class A2ATaskWebhook(Base):
__tablename__ = "a2a_task_webhooks"
id = Column(Integer, primary_key=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
target_url = Column(String, nullable=False)
secret = Column(String, nullable=True)
auth_header = Column(String, nullable=True)
enabled = Column(Boolean, nullable=False, default=True)
created_by = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
task = relationship("A2ATask")
class A2AWebhookDelivery(Base):
__tablename__ = "a2a_webhook_deliveries"
id = Column(Integer, primary_key=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
webhook_id = Column(Integer, ForeignKey("a2a_task_webhooks.id", ondelete="CASCADE"), nullable=False, index=True)
event_id = Column(Integer, ForeignKey("a2a_task_events.id", ondelete="CASCADE"), nullable=False, index=True)
attempt = Column(Integer, nullable=False, default=0)
status = Column(String, nullable=False, default="PENDING")
response_code = Column(Integer, nullable=True)
response_body = Column(Text, nullable=True)
error_message = Column(Text, nullable=True)
next_retry_at = Column(DateTime, nullable=True)
delivered_at = Column(DateTime, nullable=True)
dead_letter = Column(Boolean, nullable=False, default=False, index=True)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
task = relationship("A2ATask")
webhook = relationship("A2ATaskWebhook")
event = relationship("A2ATaskEvent")
class A2AProjectConfig(Base):
__tablename__ = "a2a_project_configs"
project_id = Column(Integer, ForeignKey("projects.id"), primary_key=True)
canary_enabled = Column(Boolean, nullable=False, default=False)
canary_percent = Column(Integer, nullable=False, default=0)
rollback_to_local = Column(Boolean, nullable=False, default=True)
compatibility_mode = Column(Boolean, nullable=False, default=True)
dual_event_write = Column(Boolean, nullable=False, default=True)
route_mode_default = Column(String, nullable=False, default="local_first")
fallback_chain_json = Column(Text, nullable=False, default='["local"]')
alert_thresholds_json = Column(Text, nullable=False, default="{}")
updated_by = Column(Integer, ForeignKey("users.id"), nullable=False)
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
project = relationship("Project")
class A2AAuditLog(Base):
__tablename__ = "a2a_audit_logs"
id = Column(Integer, primary_key=True, index=True)
actor_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
action = Column(String, nullable=False)
target_type = Column(String, nullable=False)
target_id = Column(String, nullable=False)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=True, index=True)
task_id = Column(String, nullable=True, index=True)
result = Column(String, nullable=False)
detail_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now(), index=True)
+384
View File
@@ -0,0 +1,384 @@
from __future__ import annotations
import asyncio
import hashlib
import hmac
import json
import time
import uuid
from collections import defaultdict, deque
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple
import httpx
from sqlalchemy.orm import Session
from app.models.a2a import (
A2AAuditLog,
A2AProjectConfig,
A2ARemoteAgent,
A2ATask,
A2ATaskEvent,
A2ATaskWebhook,
A2AWebhookDelivery,
)
from app.models.project import Project
from app.trace import build_error_attributes, trace_service
_STATE_TRANSITIONS = {
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
"REJECTED": set(),
"FAILED": set(),
"COMPLETED": set(),
"CANCELED": set(),
}
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
def _json_loads(raw: Optional[str], default: Any) -> Any:
if not raw:
return default
try:
return json.loads(raw)
except Exception:
return default
def _json_dumps(raw: Any) -> str:
return json.dumps(raw, ensure_ascii=False)
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _mask_error(message: str) -> str:
if not message:
return "internal_error"
return "request_failed"
@dataclass
class A2AResolvedRoute:
selected: str
fallback_chain: List[str]
canary_hit: bool
reason: str
class A2AMetrics:
def __init__(self) -> None:
self._lock = asyncio.Lock()
self._counters: Dict[str, int] = defaultdict(int)
self._latency_ms: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=2000))
async def incr(self, key: str, value: int = 1) -> None:
async with self._lock:
self._counters[key] += value
async def observe_latency(self, key: str, elapsed_ms: float) -> None:
async with self._lock:
self._latency_ms[key].append(float(elapsed_ms))
async def snapshot(self) -> Dict[str, Any]:
async with self._lock:
counters = dict(self._counters)
p95 = {}
for key, values in self._latency_ms.items():
series = sorted(values)
if not series:
p95[f"{key}.p95_ms"] = 0.0
continue
idx = int(0.95 * (len(series) - 1))
p95[f"{key}.p95_ms"] = round(series[idx], 2)
total = counters.get("a2a.requests.total", 0)
errors = counters.get("a2a.requests.error", 0)
retries = counters.get("a2a.requests.retry", 0)
breakers = counters.get("a2a.circuit.open", 0)
return {
"counters": counters,
"derived": {
"error_rate": round(errors / total, 4) if total else 0.0,
"retry_rate": round(retries / total, 4) if total else 0.0,
"circuit_open_rate": round(breakers / total, 4) if total else 0.0,
},
"latency": p95,
}
class A2ARuntime:
def __init__(self) -> None:
self._subscribers: Dict[str, List[asyncio.Queue[Dict[str, Any]]]] = defaultdict(list)
self.metrics = A2AMetrics()
self.protocol_version = "1.0"
self._circuit_state: Dict[int, datetime] = {}
async def publish(self, task_id: str, event: Dict[str, Any]) -> None:
queues = list(self._subscribers.get(task_id, []))
for queue in queues:
await queue.put(event)
async def subscribe(self, task_id: str) -> AsyncIterator[Dict[str, Any]]:
queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=200)
self._subscribers[task_id].append(queue)
try:
while True:
payload = await queue.get()
yield payload
finally:
self._subscribers[task_id] = [q for q in self._subscribers.get(task_id, []) if q is not queue]
if not self._subscribers[task_id]:
self._subscribers.pop(task_id, None)
def get_project_config(self, db: Session, project_id: int, user_id: int) -> A2AProjectConfig:
item = db.query(A2AProjectConfig).filter(A2AProjectConfig.project_id == project_id).first()
if item:
return item
config = A2AProjectConfig(project_id=project_id, updated_by=user_id)
db.add(config)
db.commit()
db.refresh(config)
return config
def resolve_route(self, *, project_config: A2AProjectConfig, session_id: str, requested_mode: str, requested_fallback: Optional[List[str]]) -> A2AResolvedRoute:
selected = requested_mode or project_config.route_mode_default or "local_first"
fallback = requested_fallback or _json_loads(project_config.fallback_chain_json, ["local"])
fallback_chain = [item for item in fallback if item in {"a2a", "local", "mcp"}]
if not fallback_chain:
fallback_chain = ["local"]
canary_hit = False
if project_config.canary_enabled and project_config.canary_percent > 0:
digest = hashlib.sha256(f"{project_config.project_id}:{session_id}".encode()).hexdigest()
bucket = int(digest[:8], 16) % 100
canary_hit = bucket < project_config.canary_percent
if selected in {"a2a_first", "a2a"} and not canary_hit:
return A2AResolvedRoute(
selected="local",
fallback_chain=fallback_chain,
canary_hit=False,
reason="canary_not_hit_fallback_local",
)
if selected in {"a2a_first", "a2a"}:
return A2AResolvedRoute(selected="a2a", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="a2a_selected")
if selected in {"mcp_first", "mcp"}:
return A2AResolvedRoute(selected="mcp", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="mcp_selected")
return A2AResolvedRoute(selected="local", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="local_selected")
def can_transition(self, from_state: str, to_state: str) -> bool:
if from_state == to_state:
return True
return to_state in _STATE_TRANSITIONS.get(from_state, set())
def create_task(
self,
db: Session,
*,
project_id: int,
tenant_id: int,
source: str,
input_text: str,
idempotency_key: Optional[str],
remote_agent_id: Optional[int],
compatibility_mode: bool,
metadata: Optional[Dict[str, Any]] = None,
) -> A2ATask:
if idempotency_key:
existing = (
db.query(A2ATask)
.filter(
A2ATask.project_id == project_id,
A2ATask.tenant_id == tenant_id,
A2ATask.idempotency_key == idempotency_key,
)
.first()
)
if existing:
return existing
task = A2ATask(
id=f"task_{uuid.uuid4().hex}",
project_id=project_id,
tenant_id=tenant_id,
source=source,
remote_agent_id=remote_agent_id,
state="SUBMITTED",
input_text=input_text,
idempotency_key=idempotency_key,
compatibility_mode=compatibility_mode,
metadata_json=_json_dumps(metadata or {}),
)
db.add(task)
db.commit()
db.refresh(task)
return task
def append_event(self, db: Session, task: A2ATask, event_type: str, payload: Dict[str, Any]) -> A2ATaskEvent:
event = A2ATaskEvent(task_id=task.id, event_type=event_type, payload_json=_json_dumps(payload))
db.add(event)
db.commit()
db.refresh(event)
return event
def transition_task(
self,
db: Session,
task: A2ATask,
*,
to_state: str,
output_text: Optional[str] = None,
error_message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> A2ATask:
if not self.can_transition(task.state, to_state):
raise ValueError(f"Invalid task transition: {task.state} -> {to_state}")
task.state = to_state
if output_text is not None:
task.output_text = output_text
if error_message is not None:
task.error_message = error_message
if metadata is not None:
task.metadata_json = _json_dumps(metadata)
if to_state in _TERMINAL_STATES:
task.finished_at = _utc_now()
db.add(task)
db.commit()
db.refresh(task)
return task
def record_audit(
self,
db: Session,
*,
actor_user_id: int,
action: str,
target_type: str,
target_id: str,
result: str,
project_id: Optional[int] = None,
task_id: Optional[str] = None,
detail: Optional[Dict[str, Any]] = None,
) -> None:
audit = A2AAuditLog(
actor_user_id=actor_user_id,
action=action,
target_type=target_type,
target_id=target_id,
result=result,
project_id=project_id,
task_id=task_id,
detail_json=_json_dumps(detail or {}),
)
db.add(audit)
db.commit()
async def fetch_agent_card(self, db: Session, agent: A2ARemoteAgent, *, timeout_s: float = 10.0) -> Dict[str, Any]:
if agent.id in self._circuit_state and self._circuit_state[agent.id] > _utc_now():
raise RuntimeError("circuit_open")
started = time.perf_counter()
await self.metrics.incr("a2a.requests.total")
headers = {}
if agent.auth_scheme == "bearer" and agent.auth_token:
headers["Authorization"] = f"Bearer {agent.auth_token}"
url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/agent-card"
with trace_service.start_span("a2a.card.fetch", attributes={"agent_id": agent.id, "url": url}) as span:
for attempt in range(3):
try:
async with httpx.AsyncClient(timeout=timeout_s, verify=True) as client:
resp = await client.get(url, headers=headers)
if resp.status_code >= 400:
raise RuntimeError(f"http_{resp.status_code}")
payload = resp.json()
elapsed_ms = (time.perf_counter() - started) * 1000
await self.metrics.observe_latency("a2a.card.fetch", elapsed_ms)
agent.card_json = _json_dumps(payload)
agent.protocol_version = str(payload.get("protocol_version") or "")
agent.capabilities_json = _json_dumps(payload.get("capabilities") or [])
agent.card_fetched_at = _utc_now()
agent.healthy = True
agent.failure_count = 0
agent.circuit_open_until = None
db.add(agent)
db.commit()
db.refresh(agent)
return payload
except Exception as exc:
span.set_attributes(build_error_attributes(exc, stage="a2a_card_fetch"))
await self.metrics.incr("a2a.requests.error")
if attempt < 2:
await self.metrics.incr("a2a.requests.retry")
await asyncio.sleep(0.2 * (2 ** attempt))
continue
agent.failure_count = (agent.failure_count or 0) + 1
if agent.failure_count >= 3:
reopen_at = _utc_now() + timedelta(seconds=90)
agent.circuit_open_until = reopen_at
self._circuit_state[agent.id] = reopen_at
await self.metrics.incr("a2a.circuit.open")
agent.healthy = False
db.add(agent)
db.commit()
raise
async def notify_webhooks(self, db: Session, task: A2ATask, event: A2ATaskEvent) -> None:
webhooks = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id, A2ATaskWebhook.enabled == True).all()
if not webhooks:
return
for hook in webhooks:
delivery = A2AWebhookDelivery(task_id=task.id, webhook_id=hook.id, event_id=event.id, attempt=0, status="PENDING")
db.add(delivery)
db.commit()
db.refresh(delivery)
await self._deliver_once(db, hook, event, delivery)
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
event_payload = _json_loads(event.payload_json, {})
request_payload = {
"task_id": event.task_id,
"event_type": event.event_type,
"event_id": event.id,
"payload": event_payload,
}
body = _json_dumps(request_payload).encode("utf-8")
for attempt in range(1, 5):
delivery.attempt = attempt
db.add(delivery)
db.commit()
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
if hook.secret:
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
headers["X-A2A-Signature"] = f"sha256={digest}"
if hook.auth_header:
headers["Authorization"] = hook.auth_header
try:
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
resp = await client.post(hook.target_url, content=body, headers=headers)
delivery.response_code = resp.status_code
delivery.response_body = (resp.text or "")[:1000]
if 200 <= resp.status_code < 300:
delivery.status = "DELIVERED"
delivery.dead_letter = False
delivery.delivered_at = _utc_now()
db.add(delivery)
db.commit()
return
raise RuntimeError(f"http_{resp.status_code}")
except Exception as exc:
delivery.error_message = str(exc)[:500]
if attempt < 4:
delivery.status = "RETRYING"
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
db.add(delivery)
db.commit()
await asyncio.sleep(2 ** attempt)
continue
delivery.status = "FAILED"
delivery.dead_letter = True
db.add(delivery)
db.commit()
return
a2a_runtime = A2ARuntime()
+14 -2
View File
@@ -23,7 +23,7 @@ import re
import os
from datetime import datetime
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge, embedding_models, web_search
from app.api import upload, llm, skills, users, datasources, projects, semantic, mcp, subagents, knowledge, embedding_models, web_search, a2a
from app.connectors.postgres import postgres_connector
from app.connectors.clickhouse import clickhouse_connector
from app.core.artifacts import extract_artifacts
@@ -52,6 +52,15 @@ from app.models.user import User, EmailVerification
from app.models.project import Project
from app.models.datasource import DataSource
from app.models.subagent import Subagent
from app.models.a2a import (
A2ARemoteAgent,
A2ATask,
A2ATaskEvent,
A2ATaskWebhook,
A2AWebhookDelivery,
A2AProjectConfig,
A2AAuditLog,
)
app = FastAPI()
@@ -86,6 +95,7 @@ app.include_router(subagents.router, prefix="/api/v1")
app.include_router(knowledge.router, prefix="/api/v1")
app.include_router(embedding_models.router, prefix="/api/v1")
app.include_router(web_search.router, prefix="/api/v1")
app.include_router(a2a.router, prefix="/api/v1")
STREAM_DELTA_CHUNK_SIZE = 48
PREVIEWABLE_TEXT_EXTENSIONS = {
@@ -292,7 +302,9 @@ class ChatRequest(BaseModel):
source: str = "postgres"
prefer_sql_chart: bool = False
file_url: Optional[str] = None
route_mode: Literal["auto", "chat", "sql"] = "auto"
route_mode: Literal["auto", "chat", "sql", "a2a", "a2a_first", "local_first", "mcp_first"] = "auto"
route_fallback_chain: Optional[List[Literal["a2a", "local", "mcp"]]] = None
a2a_agent_id: Optional[int] = None
knowledge_base_id: Optional[str] = None
+319
View File
@@ -0,0 +1,319 @@
import asyncio
import sys
from collections.abc import Generator
from pathlib import Path
import httpx
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
BACKEND_ROOT = Path(__file__).resolve().parents[1]
REPO_ROOT = BACKEND_ROOT.parent
NANOBOT_ROOT = REPO_ROOT / "nanobot"
if str(BACKEND_ROOT) not in sys.path:
sys.path.insert(0, str(BACKEND_ROOT))
if str(NANOBOT_ROOT) not in sys.path:
sys.path.insert(0, str(NANOBOT_ROOT))
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
from app.models.a2a import A2ARemoteAgent
from app.models.project import Project
from app.models.user import User
from app.services.a2a_service import a2a_runtime
from main import app
def _seed(db: Session) -> tuple[int, str, int, str, int]:
owner = User(username="a2a_owner", email="a2a_owner@example.com", hashed_password="x", is_admin=False)
other = User(username="a2a_other", email="a2a_other@example.com", hashed_password="x", is_admin=False)
db.add(owner)
db.add(other)
db.commit()
db.refresh(owner)
db.refresh(other)
project = Project(name="a2a_project", description="a2a", owner_id=owner.id)
db.add(project)
db.commit()
db.refresh(project)
return owner.id, owner.username, other.id, other.username, project.id
def test_a2a_send_list_cancel_and_rollout() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
send_resp = client.post(
"/api/v1/a2a/messages/send",
json={
"project_id": project_id,
"message": "hello a2a",
"session_id": "test-a2a-session",
"route_mode": "local_first",
},
)
assert send_resp.status_code == 200
task_id = send_resp.json()["task"]["id"]
get_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
assert get_resp.status_code == 200
assert get_resp.json()["project_id"] == project_id
list_resp = client.get("/api/v1/a2a/tasks", params={"project_id": project_id})
assert list_resp.status_code == 200
assert any(item["id"] == task_id for item in list_resp.json())
cancel_resp = client.post(f"/api/v1/a2a/tasks/{task_id}/cancel")
assert cancel_resp.status_code == 200
assert cancel_resp.json()["state"] in {"CANCELED", "COMPLETED", "FAILED"}
rollout_resp = client.put(
f"/api/v1/a2a/projects/{project_id}/rollout",
json={"canary_enabled": True, "canary_percent": 30, "route_mode_default": "a2a_first", "fallback_chain": ["a2a", "local"]},
)
assert rollout_resp.status_code == 200
assert rollout_resp.json()["canary_enabled"] is True
assert rollout_resp.json()["canary_percent"] == 30
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_task_tenant_isolation() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
send_resp = client.post(
"/api/v1/a2a/messages/send",
json={"project_id": project_id, "message": "tenant isolation", "session_id": "tenant-isolation", "route_mode": "local"},
)
assert send_resp.status_code == 200
task_id = send_resp.json()["task"]["id"]
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
forbidden_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
assert forbidden_resp.status_code == 404
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_metrics_admin_only() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, _ = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
denied = client.get("/api/v1/a2a/metrics")
assert denied.status_code == 403
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
ok = client.get("/api/v1/a2a/metrics")
assert ok.status_code == 200
assert "counters" in ok.json()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_send_idempotency_key_deduplicates_task() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
payload = {
"project_id": project_id,
"message": "dedupe-task",
"session_id": "idempotency-session",
"route_mode": "local_first",
"idempotency_key": "same-key-1",
}
first_resp = client.post("/api/v1/a2a/messages/send", json=payload)
second_resp = client.post("/api/v1/a2a/messages/send", json=payload)
assert first_resp.status_code == 200
assert second_resp.status_code == 200
assert first_resp.json()["task"]["id"] == second_resp.json()["task"]["id"]
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, _, _, _, project_id = _seed(db)
agent = A2ARemoteAgent(
project_id=project_id,
name="auth-fail-agent",
base_url="https://remote.example.com",
auth_scheme="bearer",
auth_token="bad-token",
created_by=owner_id,
)
db.add(agent)
db.commit()
db.refresh(agent)
a2a_runtime._circuit_state.pop(agent.id, None)
class _FailResp:
status_code = 401
@staticmethod
def json():
return {"detail": "unauthorized"}
class _Client401:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url, headers=None):
return _FailResp()
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401)
with pytest.raises(RuntimeError):
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
db.refresh(agent)
assert agent.healthy is False
assert agent.failure_count == 1
Base.metadata.drop_all(bind=engine)
db.close()
engine.dispose()
def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, _, _, _, project_id = _seed(db)
agent = A2ARemoteAgent(
project_id=project_id,
name="offline-agent",
base_url="https://offline.example.com",
auth_scheme="none",
created_by=owner_id,
)
db.add(agent)
db.commit()
db.refresh(agent)
a2a_runtime._circuit_state.pop(agent.id, None)
class _ClientDown:
def __init__(self, *args, **kwargs):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return False
async def get(self, url, headers=None):
raise httpx.ConnectError("network down")
monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown)
for _ in range(3):
with pytest.raises(Exception):
asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01))
db.refresh(agent)
assert agent.healthy is False
assert agent.failure_count == 3
assert agent.circuit_open_until is not None
Base.metadata.drop_all(bind=engine)
db.close()
engine.dispose()