from __future__ import annotations import asyncio import hashlib import hmac import json import time import uuid from collections import defaultdict, deque from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple import httpx from sqlalchemy.orm import Session from app.models.a2a import ( A2AAuditLog, A2AProjectConfig, A2ARemoteAgent, A2ATask, A2ATaskEvent, A2ATaskWebhook, A2AWebhookDelivery, ) from app.models.project import Project from app.trace import build_error_attributes, trace_service _STATE_TRANSITIONS = { "SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"}, "WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"}, "INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"}, "AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"}, "REJECTED": set(), "FAILED": set(), "COMPLETED": set(), "CANCELED": set(), } _TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"} def _json_loads(raw: Optional[str], default: Any) -> Any: if not raw: return default try: return json.loads(raw) except Exception: return default def _json_dumps(raw: Any) -> str: return json.dumps(raw, ensure_ascii=False) def _utc_now() -> datetime: return datetime.now(timezone.utc) def _mask_error(message: str) -> str: if not message: return "internal_error" return "request_failed" @dataclass class A2AResolvedRoute: selected: str fallback_chain: List[str] canary_hit: bool reason: str class A2AMetrics: def __init__(self) -> None: self._lock = asyncio.Lock() self._counters: Dict[str, int] = defaultdict(int) self._latency_ms: Dict[str, deque[float]] = defaultdict(lambda: deque(maxlen=2000)) async def incr(self, key: str, value: int = 1) -> None: async with self._lock: self._counters[key] += value async def observe_latency(self, key: str, elapsed_ms: float) -> None: async with self._lock: self._latency_ms[key].append(float(elapsed_ms)) async def snapshot(self) -> Dict[str, Any]: async with self._lock: counters = dict(self._counters) p95 = {} for key, values in self._latency_ms.items(): series = sorted(values) if not series: p95[f"{key}.p95_ms"] = 0.0 continue idx = int(0.95 * (len(series) - 1)) p95[f"{key}.p95_ms"] = round(series[idx], 2) total = counters.get("a2a.requests.total", 0) errors = counters.get("a2a.requests.error", 0) retries = counters.get("a2a.requests.retry", 0) breakers = counters.get("a2a.circuit.open", 0) return { "counters": counters, "derived": { "error_rate": round(errors / total, 4) if total else 0.0, "retry_rate": round(retries / total, 4) if total else 0.0, "circuit_open_rate": round(breakers / total, 4) if total else 0.0, }, "latency": p95, } class A2ARuntime: def __init__(self) -> None: self._subscribers: Dict[str, List[asyncio.Queue[Dict[str, Any]]]] = defaultdict(list) self.metrics = A2AMetrics() self.protocol_version = "1.0" self._circuit_state: Dict[int, datetime] = {} async def publish(self, task_id: str, event: Dict[str, Any]) -> None: queues = list(self._subscribers.get(task_id, [])) for queue in queues: await queue.put(event) async def subscribe(self, task_id: str) -> AsyncIterator[Dict[str, Any]]: queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue(maxsize=200) self._subscribers[task_id].append(queue) try: while True: payload = await queue.get() yield payload finally: self._subscribers[task_id] = [q for q in self._subscribers.get(task_id, []) if q is not queue] if not self._subscribers[task_id]: self._subscribers.pop(task_id, None) def get_project_config(self, db: Session, project_id: int, user_id: int) -> A2AProjectConfig: item = db.query(A2AProjectConfig).filter(A2AProjectConfig.project_id == project_id).first() if item: return item config = A2AProjectConfig(project_id=project_id, updated_by=user_id) db.add(config) db.commit() db.refresh(config) return config def resolve_route(self, *, project_config: A2AProjectConfig, session_id: str, requested_mode: str, requested_fallback: Optional[List[str]]) -> A2AResolvedRoute: selected = requested_mode or project_config.route_mode_default or "local_first" fallback = requested_fallback or _json_loads(project_config.fallback_chain_json, ["local"]) fallback_chain = [item for item in fallback if item in {"a2a", "local", "mcp"}] if not fallback_chain: fallback_chain = ["local"] canary_hit = False if project_config.canary_enabled and project_config.canary_percent > 0: digest = hashlib.sha256(f"{project_config.project_id}:{session_id}".encode()).hexdigest() bucket = int(digest[:8], 16) % 100 canary_hit = bucket < project_config.canary_percent if selected in {"a2a_first", "a2a"} and not canary_hit: return A2AResolvedRoute( selected="local", fallback_chain=fallback_chain, canary_hit=False, reason="canary_not_hit_fallback_local", ) if selected in {"a2a_first", "a2a"}: return A2AResolvedRoute(selected="a2a", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="a2a_selected") if selected in {"mcp_first", "mcp"}: return A2AResolvedRoute(selected="mcp", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="mcp_selected") return A2AResolvedRoute(selected="local", fallback_chain=fallback_chain, canary_hit=canary_hit, reason="local_selected") def can_transition(self, from_state: str, to_state: str) -> bool: if from_state == to_state: return True return to_state in _STATE_TRANSITIONS.get(from_state, set()) def create_task( self, db: Session, *, project_id: int, tenant_id: int, source: str, input_text: str, idempotency_key: Optional[str], remote_agent_id: Optional[int], compatibility_mode: bool, metadata: Optional[Dict[str, Any]] = None, ) -> A2ATask: if idempotency_key: existing = ( db.query(A2ATask) .filter( A2ATask.project_id == project_id, A2ATask.tenant_id == tenant_id, A2ATask.idempotency_key == idempotency_key, ) .first() ) if existing: return existing task = A2ATask( id=f"task_{uuid.uuid4().hex}", project_id=project_id, tenant_id=tenant_id, source=source, remote_agent_id=remote_agent_id, state="SUBMITTED", input_text=input_text, idempotency_key=idempotency_key, compatibility_mode=compatibility_mode, metadata_json=_json_dumps(metadata or {}), ) db.add(task) db.commit() db.refresh(task) return task def append_event(self, db: Session, task: A2ATask, event_type: str, payload: Dict[str, Any]) -> A2ATaskEvent: event = A2ATaskEvent(task_id=task.id, event_type=event_type, payload_json=_json_dumps(payload)) db.add(event) db.commit() db.refresh(event) return event def transition_task( self, db: Session, task: A2ATask, *, to_state: str, output_text: Optional[str] = None, error_message: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, ) -> A2ATask: if not self.can_transition(task.state, to_state): raise ValueError(f"Invalid task transition: {task.state} -> {to_state}") task.state = to_state if output_text is not None: task.output_text = output_text if error_message is not None: task.error_message = error_message if metadata is not None: task.metadata_json = _json_dumps(metadata) if to_state in _TERMINAL_STATES: task.finished_at = _utc_now() db.add(task) db.commit() db.refresh(task) return task def record_audit( self, db: Session, *, actor_user_id: int, action: str, target_type: str, target_id: str, result: str, project_id: Optional[int] = None, task_id: Optional[str] = None, detail: Optional[Dict[str, Any]] = None, ) -> None: audit = A2AAuditLog( actor_user_id=actor_user_id, action=action, target_type=target_type, target_id=target_id, result=result, project_id=project_id, task_id=task_id, detail_json=_json_dumps(detail or {}), ) db.add(audit) db.commit() async def fetch_agent_card(self, db: Session, agent: A2ARemoteAgent, *, timeout_s: float = 10.0) -> Dict[str, Any]: if agent.id in self._circuit_state and self._circuit_state[agent.id] > _utc_now(): raise RuntimeError("circuit_open") started = time.perf_counter() await self.metrics.incr("a2a.requests.total") headers = {} if agent.auth_scheme == "bearer" and agent.auth_token: headers["Authorization"] = f"Bearer {agent.auth_token}" url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/agent-card" with trace_service.start_span("a2a.card.fetch", attributes={"agent_id": agent.id, "url": url}) as span: for attempt in range(3): try: async with httpx.AsyncClient(timeout=timeout_s, verify=True) as client: resp = await client.get(url, headers=headers) if resp.status_code >= 400: raise RuntimeError(f"http_{resp.status_code}") payload = resp.json() elapsed_ms = (time.perf_counter() - started) * 1000 await self.metrics.observe_latency("a2a.card.fetch", elapsed_ms) agent.card_json = _json_dumps(payload) agent.protocol_version = str(payload.get("protocol_version") or "") agent.capabilities_json = _json_dumps(payload.get("capabilities") or []) agent.card_fetched_at = _utc_now() agent.healthy = True agent.failure_count = 0 agent.circuit_open_until = None db.add(agent) db.commit() db.refresh(agent) return payload except Exception as exc: span.set_attributes(build_error_attributes(exc, stage="a2a_card_fetch")) await self.metrics.incr("a2a.requests.error") if attempt < 2: await self.metrics.incr("a2a.requests.retry") await asyncio.sleep(0.2 * (2 ** attempt)) continue agent.failure_count = (agent.failure_count or 0) + 1 if agent.failure_count >= 3: reopen_at = _utc_now() + timedelta(seconds=90) agent.circuit_open_until = reopen_at self._circuit_state[agent.id] = reopen_at await self.metrics.incr("a2a.circuit.open") agent.healthy = False db.add(agent) db.commit() raise async def notify_webhooks(self, db: Session, task: A2ATask, event: A2ATaskEvent) -> None: webhooks = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id, A2ATaskWebhook.enabled == True).all() if not webhooks: return for hook in webhooks: delivery = A2AWebhookDelivery(task_id=task.id, webhook_id=hook.id, event_id=event.id, attempt=0, status="PENDING") db.add(delivery) db.commit() db.refresh(delivery) await self._deliver_once(db, hook, event, delivery) async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None: event_payload = _json_loads(event.payload_json, {}) request_payload = { "task_id": event.task_id, "event_type": event.event_type, "event_id": event.id, "payload": event_payload, } body = _json_dumps(request_payload).encode("utf-8") for attempt in range(1, 5): delivery.attempt = attempt db.add(delivery) db.commit() headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)} if hook.secret: digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest() headers["X-A2A-Signature"] = f"sha256={digest}" if hook.auth_header: headers["Authorization"] = hook.auth_header try: async with httpx.AsyncClient(timeout=8.0, verify=True) as client: resp = await client.post(hook.target_url, content=body, headers=headers) delivery.response_code = resp.status_code delivery.response_body = (resp.text or "")[:1000] if 200 <= resp.status_code < 300: delivery.status = "DELIVERED" delivery.dead_letter = False delivery.delivered_at = _utc_now() db.add(delivery) db.commit() return raise RuntimeError(f"http_{resp.status_code}") except Exception as exc: delivery.error_message = str(exc)[:500] if attempt < 4: delivery.status = "RETRYING" delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt) db.add(delivery) db.commit() await asyncio.sleep(2 ** attempt) continue delivery.status = "FAILED" delivery.dead_letter = True db.add(delivery) db.commit() return a2a_runtime = A2ARuntime()