diff --git a/backend/app/api/a2a.py b/backend/app/api/a2a.py index faa8bbd..556326c 100644 --- a/backend/app/api/a2a.py +++ b/backend/app/api/a2a.py @@ -4,8 +4,8 @@ from datetime import datetime from typing import Any, Dict, List, Literal, Optional, Tuple import httpx -from fastapi import APIRouter, Depends, Header, HTTPException, Query, status -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, status +from fastapi.responses import StreamingResponse, JSONResponse from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -14,51 +14,278 @@ from app.core.security import CurrentUser, get_current_user from app.database import SessionLocal, get_db from app.models.a2a import ( A2AAuditLog, + A2AArtifact, + A2AMessage, + A2APart, A2AProjectConfig, A2ARemoteAgent, A2ATask, A2ATaskEvent, A2ATaskWebhook, A2AWebhookDelivery, + A2ATaskState, ) from app.models.project import Project -from app.services.a2a_service import _json_dumps, _json_loads, a2a_runtime +from app.schemas.a2a import ( + A2AMessageCreateSchema, + A2AMessageSchema, + A2AMessageRole, + A2APartSchema, + A2ATaskSchema, + A2ATaskWithHistorySchema, + A2ATaskWithMessagesSchema, + A2AArtifactSchema, + A2ATaskStatusSchema, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + TaskMessageEvent, + StreamResponse, + StreamResponseTask, + SendMessageRequest, + SendStreamingMessageRequest, + GetTaskRequest, + TaskListRequest, + CancelTaskRequest, + PushNotificationConfigCreate, + PushNotificationConfig, + A2ATaskState as SchemaTaskState, + AgentCardPublicSchema, + AgentCardExtendedSchema, + AgentSkill, + AgentProvider, + AgentSupportedInterface, + SecuritySchemeApiKey, + SecuritySchemeHttpAuth, + SecuritySchemeOAuth2, + SecuritySchemeOpenIdConnect, + SecuritySchemeMtls, + OAuth2Flows, + VersionNotSupportedError, +) +from app.services.a2a_service import _json_dumps, _json_loads, a2a_runtime, RemoteAgentSecuritySelector from app.trace import build_error_attributes, trace_service -router = APIRouter(prefix="/a2a", tags=["a2a"]) +router = APIRouter() +A2A_API_PREFIX = "/a2a" SUPPORTED_PROTOCOL_VERSION = "1.0" SUPPORTED_CAPABILITIES = ["streaming", "push", "task_management", "subscribe"] SUPPORTED_AUTH = ["bearer", "shared_secret", "none"] +async def verify_a2a_version( + response: Response, + a2a_version: Optional[str] = Header(default=None, alias="A2A-Version"), +) -> None: + if a2a_version is not None and a2a_version != SUPPORTED_PROTOCOL_VERSION: + error = VersionNotSupportedError( + code=-32009, + message=f"Protocol version '{a2a_version}' not supported. Supported version is '{SUPPORTED_PROTOCOL_VERSION}'.", + data={"supportedVersion": SUPPORTED_PROTOCOL_VERSION}, + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=json.dumps(error.model_dump(), ensure_ascii=False), + ) + response.headers["A2A-Version"] = SUPPORTED_PROTOCOL_VERSION + + +async def verify_shared_secret( + request_data: bytes, + x_a2a_signature: Optional[str] = Header(default=None, alias="X-A2A-Signature"), + x_a2a_timestamp: Optional[str] = Header(default=None, alias="X-A2A-Timestamp"), + shared_secret: Optional[str] = None, +) -> bool: + if not x_a2a_signature or not x_a2a_timestamp: + return False + if not shared_secret: + return False + try: + from app.services.a2a_service import SharedSecretAuth + timestamp = int(x_a2a_timestamp) + return SharedSecretAuth.verify_signature(shared_secret, request_data, x_a2a_signature, timestamp) + except (ValueError, TypeError): + return False + + +def get_user_bearer_token(current_user: CurrentUser) -> str: + from app.core.security import create_access_token + return create_access_token({"sub": str(current_user.id), "is_admin": current_user.is_admin}) + + +class A2AStreamingResponse(StreamingResponse): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.headers["A2A-Version"] = SUPPORTED_PROTOCOL_VERSION + + def _mask_error(message: str) -> str: if not message: return "internal_error" return "request_failed" -class AgentCardResponse(BaseModel): - name: str - protocol_version: str - capabilities: List[str] - endpoints: Dict[str, str] - auth: List[str] +def _json_serialize(obj: Any) -> Any: + if isinstance(obj, datetime): + return obj.isoformat() + if isinstance(obj, enum.Enum): + return obj.value + raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") + + +AGENT_SKILLS = [ + AgentSkill( + id="dataclaw-data-analysis", + name="Data Analysis", + description="Analyze datasets, generate insights, and produce visualizations", + tags=["data", "analysis", "analytics", "visualization"], + examples=[], + inputModes=["text", "data"], + outputModes=["text", "artifact", "stream"], + securityRequirements=[], + ), + AgentSkill( + id="dataclaw-nl2sql", + name="Natural Language to SQL", + description="Convert natural language queries into SQL statements", + tags=["nl2sql", "sql", "query", "database"], + examples=[], + inputModes=["text"], + outputModes=["text", "data"], + securityRequirements=[], + ), + AgentSkill( + id="dataclaw-artifact-preview", + name="Artifact Preview & Download", + description="Generate and serve previews for data artifacts", + tags=["artifact", "preview", "download", "export"], + examples=[], + inputModes=["text", "data"], + outputModes=["artifact", "stream"], + securityRequirements=[], + ), +] + +AGENT_PROVIDER = AgentProvider( + organization="DataClaw", + url="https://dataclaw.io", +) + +AGENT_SUPPORTED_INTERFACES = [ + AgentSupportedInterface( + url="/message:send", + protocolBinding="http", + protocolVersion="1.0", + ), + AgentSupportedInterface( + url="/message:stream", + protocolBinding="http", + protocolVersion="1.0", + ), +] + +AGENT_SECURITY_SCHEMES = { + "bearer": SecuritySchemeHttpAuth(scheme="bearer", description="JWT Bearer token authentication"), + "apiKey": SecuritySchemeApiKey(name="X-API-Key", in_="header", description="API key authentication"), + "oauth2": SecuritySchemeOAuth2( + flows=OAuth2Flows( + authorizationCode={ + "authorizationUrl": "/oauth/authorize", + "tokenUrl": "/oauth/token", + "scopes": {"read": "Read access", "write": "Write access"}, + }, + clientCredentials={"tokenUrl": "/oauth/token", "scopes": {"read": "Read access", "write": "Write access"}}, + deviceCode={"authorizationUrl": "/oauth/device", "tokenUrl": "/oauth/token", "scopes": {"read": "Read access", "write": "Write access"}}, + ), + description="OAuth2 authentication", + ), + "openIdConnect": SecuritySchemeOpenIdConnect( + openIdConnectUrl="/.well-known/openid-configuration", + description="OpenID Connect authentication", + scopes={"openid": "OpenID scope", "profile": "Profile scope"}, + ), + "mutualTLS": SecuritySchemeMtls( + description="Mutual TLS authentication", + caCerts=[], + clientCert=None, + clientKey=None, + ), + "shared_secret": SecuritySchemeHttpAuth(scheme="hmac-sha256", description="HMAC-SHA256 shared secret authentication"), +} + + +def _build_public_agent_card() -> AgentCardPublicSchema: + return AgentCardPublicSchema( + name="DataClaw A2A Gateway", + protocol_version=SUPPORTED_PROTOCOL_VERSION, + capabilities=SUPPORTED_CAPABILITIES, + endpoints={ + "sendMessage": "/message:send", + "sendStreamingMessage": "/message:stream", + "getTask": "/tasks/{task_id}", + "listTasks": "/tasks", + "cancelTask": "/tasks/{task_id}:cancel", + "subscribeTask": "/tasks/{task_id}:subscribe", + "pushNotificationConfig": "/tasks/{task_id}/pushNotificationConfigs", + }, + auth=SUPPORTED_AUTH, + skills=AGENT_SKILLS, + provider=AGENT_PROVIDER, + supportedInterfaces=AGENT_SUPPORTED_INTERFACES, + defaultInputModes=["text", "data"], + defaultOutputModes=["text", "artifact", "stream"], + iconUrl="https://dataclaw.io/icon.png", + documentationUrl="https://docs.dataclaw.io/a2a", + ) + + +def _build_extended_agent_card(current_user: CurrentUser) -> AgentCardExtendedSchema: + public_card = _build_public_agent_card() + return AgentCardExtendedSchema( + **public_card.model_dump(), + securitySchemes=AGENT_SECURITY_SCHEMES, + security=[{"bearer": []}, {"apiKey": []}], + signatures=[], + tenantId=current_user.id, + isAdmin=current_user.is_admin, + ) class RemoteAgentCreate(BaseModel): project_id: int name: str = Field(min_length=1, max_length=120) base_url: str = Field(min_length=1, max_length=500) - auth_scheme: Literal["none", "bearer"] = "none" + auth_scheme: Literal["none", "bearer", "shared_secret", "oauth2", "openIdConnect", "mutualTLS"] = "none" auth_token: Optional[str] = None + shared_secret: Optional[str] = None + mtls_ca_cert: Optional[str] = None + mtls_client_cert: Optional[str] = None + mtls_client_key: Optional[str] = None + oauth2_client_id: Optional[str] = None + oauth2_client_secret: Optional[str] = None + oauth2_token_url: Optional[str] = None + oauth2_scopes: Optional[str] = None + oidc_issuer_url: Optional[str] = None + oidc_client_id: Optional[str] = None + oidc_client_secret: Optional[str] = None class RemoteAgentUpdate(BaseModel): name: Optional[str] = None base_url: Optional[str] = None - auth_scheme: Optional[Literal["none", "bearer"]] = None + auth_scheme: Optional[Literal["none", "bearer", "shared_secret", "oauth2", "openIdConnect", "mutualTLS"]] = None auth_token: Optional[str] = None + shared_secret: Optional[str] = None + mtls_ca_cert: Optional[str] = None + mtls_client_cert: Optional[str] = None + mtls_client_key: Optional[str] = None + oauth2_client_id: Optional[str] = None + oauth2_client_secret: Optional[str] = None + oauth2_token_url: Optional[str] = None + oauth2_scopes: Optional[str] = None + oidc_issuer_url: Optional[str] = None + oidc_client_id: Optional[str] = None + oidc_client_secret: Optional[str] = None class RemoteAgentView(BaseModel): @@ -73,17 +300,10 @@ class RemoteAgentView(BaseModel): failure_count: int circuit_open_until: Optional[datetime] = None card_fetched_at: Optional[datetime] = None - - -class SendMessageRequest(BaseModel): - project_id: int - message: str = Field(min_length=1) - session_id: str = "api:a2a" - remote_agent_id: Optional[int] = None - route_mode: Literal["auto", "local", "a2a", "a2a_first", "local_first", "mcp_first"] = "auto" - fallback_chain: Optional[List[Literal["a2a", "local", "mcp"]]] = None - idempotency_key: Optional[str] = None - metadata: Optional[Dict[str, Any]] = None + shared_secret_configured: bool = False + mtls_configured: bool = False + oauth2_configured: bool = False + oidc_configured: bool = False class TaskView(BaseModel): @@ -191,6 +411,180 @@ def _task_to_view(task: A2ATask) -> TaskView: ) +def _task_to_schema(task: A2ATask) -> A2ATaskSchema: + return A2ATaskSchema( + id=task.id, + contextId=task.context_id, + projectId=task.project_id, + tenantId=task.tenant_id, + source=task.source, + remoteAgentId=task.remote_agent_id, + idempotencyKey=task.idempotency_key, + state=SchemaTaskState(task.state.value), + inputText=task.input_text, + outputText=task.output_text, + errorMessage=task.error_message, + metadata=_json_loads(task.metadata_json, {}), + historyLength=task.history_length or 0, + createdAt=task.created_at, + updatedAt=task.updated_at, + finishedAt=task.finished_at, + ) + + +def _task_to_with_history(task: A2ATask, history_length: Optional[int] = None) -> A2ATaskWithHistorySchema: + query = db.query(A2AMessage).filter(A2AMessage.task_id == task.id) + if history_length is not None and history_length > 0: + query = query.order_by(A2AMessage.id.desc()).limit(history_length) + messages = query.all() + messages = list(reversed(messages)) + else: + messages = query.order_by(A2AMessage.id.asc()).all() + + message_schemas = [] + for msg in messages: + parts = db.query(A2APart).filter(A2APart.message_id == msg.id).all() + part_schemas = [] + for p in parts: + part_schemas.append(A2APartSchema( + part_type=p.part_type, + text=p.text_content, + raw=p.raw_content, + url=p.url_content, + data=p.data_content, + mediaType=p.media_type, + filename=p.filename, + metadata=_json_loads(p.metadata_json, {}), + )) + message_schemas.append(A2AMessageSchema( + messageId=msg.message_id, + contextId=msg.context_id, + taskId=msg.task_id, + role=msg.role, + parts=part_schemas, + extensions=_json_loads(msg.extensions_json, {}), + referenceTaskIds=_json_loads(msg.reference_task_ids_json, []), + createdAt=msg.created_at, + )) + + artifacts = db.query(A2AArtifact).filter(A2AArtifact.task_id == task.id).all() + artifact_schemas = [] + for art in artifacts: + parts = db.query(A2APart).filter(A2APart.artifact_id == art.id).all() + part_schemas = [] + for p in parts: + part_schemas.append(A2APartSchema( + part_type=p.part_type, + text=p.text_content, + raw=p.raw_content, + url=p.url_content, + data=p.data_content, + mediaType=p.media_type, + filename=p.filename, + metadata=_json_loads(p.metadata_json, {}), + )) + artifact_schemas.append(A2AArtifactSchema( + artifactId=art.artifact_id, + name=art.name, + description=art.description, + parts=part_schemas, + metadata=_json_loads(art.metadata_json, {}), + extensions=_json_loads(art.extensions_json, {}), + createdAt=art.created_at, + updatedAt=art.updated_at, + )) + + return A2ATaskWithHistorySchema( + id=task.id, + contextId=task.context_id, + projectId=task.project_id, + tenantId=task.tenant_id, + state=SchemaTaskState(task.state.value), + history=message_schemas, + artifacts=artifact_schemas, + createdAt=task.created_at, + updatedAt=task.updated_at, + finishedAt=task.finished_at, + ) + + +def _task_to_with_messages(task: A2ATask) -> A2ATaskWithMessagesSchema: + messages = db.query(A2AMessage).filter(A2AMessage.task_id == task.id).order_by(A2AMessage.id.asc()).all() + message_schemas = [] + for msg in messages: + parts = db.query(A2APart).filter(A2APart.message_id == msg.id).all() + part_schemas = [] + for p in parts: + part_schemas.append(A2APartSchema( + part_type=p.part_type, + text=p.text_content, + raw=p.raw_content, + url=p.url_content, + data=p.data_content, + mediaType=p.media_type, + filename=p.filename, + metadata=_json_loads(p.metadata_json, {}), + )) + message_schemas.append(A2AMessageSchema( + messageId=msg.message_id, + contextId=msg.context_id, + taskId=msg.task_id, + role=msg.role, + parts=part_schemas, + extensions=_json_loads(msg.extensions_json, {}), + referenceTaskIds=_json_loads(msg.reference_task_ids_json, []), + createdAt=msg.created_at, + )) + + artifacts = db.query(A2AArtifact).filter(A2AArtifact.task_id == task.id).all() + artifact_schemas = [] + for art in artifacts: + parts = db.query(A2APart).filter(A2APart.artifact_id == art.id).all() + part_schemas = [] + for p in parts: + part_schemas.append(A2APartSchema( + part_type=p.part_type, + text=p.text_content, + raw=p.raw_content, + url=p.url_content, + data=p.data_content, + mediaType=p.media_type, + filename=p.filename, + metadata=_json_loads(p.metadata_json, {}), + )) + artifact_schemas.append(A2AArtifactSchema( + artifactId=art.artifact_id, + name=art.name, + description=art.description, + parts=part_schemas, + metadata=_json_loads(art.metadata_json, {}), + extensions=_json_loads(art.extensions_json, {}), + createdAt=art.created_at, + updatedAt=art.updated_at, + )) + + return A2ATaskWithMessagesSchema( + id=task.id, + contextId=task.context_id, + projectId=task.project_id, + tenantId=task.tenant_id, + source=task.source, + remoteAgentId=task.remote_agent_id, + idempotencyKey=task.idempotency_key, + state=SchemaTaskState(task.state.value), + inputText=task.input_text, + outputText=task.output_text, + errorMessage=task.error_message, + metadata=_json_loads(task.metadata_json, {}), + historyLength=task.history_length or 0, + createdAt=task.created_at, + updatedAt=task.updated_at, + finishedAt=task.finished_at, + messages=message_schemas, + artifacts=artifact_schemas, + ) + + def _agent_to_view(agent: A2ARemoteAgent) -> RemoteAgentView: return RemoteAgentView( id=agent.id, @@ -204,6 +598,10 @@ def _agent_to_view(agent: A2ARemoteAgent) -> RemoteAgentView: failure_count=int(agent.failure_count or 0), circuit_open_until=agent.circuit_open_until, card_fetched_at=agent.card_fetched_at, + shared_secret_configured=bool(agent.shared_secret), + mtls_configured=bool(agent.mtls_client_cert and agent.mtls_client_key), + oauth2_configured=bool(agent.oauth2_client_id and agent.oauth2_token_url), + oidc_configured=bool(agent.oidc_issuer_url), ) @@ -244,10 +642,33 @@ def _build_artifact_event(task_id: str, content: str, *, compatibility_mode: boo return payload +def _part_to_model(part_schema: A2APartSchema) -> Dict[str, Any]: + return { + "text": part_schema.text, + "raw": part_schema.raw, + "url": part_schema.url, + "data": part_schema.data, + "mediaType": part_schema.mediaType, + "filename": part_schema.filename, + "metadata": part_schema.metadata or {}, + } + + +def _message_to_task_input(message: A2AMessageCreateSchema) -> str: + text_parts = [] + for part in message.parts: + if part.part_type.value == "text" and part.text: + text_parts.append(part.text) + elif part.part_type.value == "data" and part.data: + text_parts.append(str(part.data)) + return "\n".join(text_parts) if text_parts else "" + + async def _delegate_to_remote(task: A2ATask, agent: A2ARemoteAgent, message: str) -> Tuple[str, Dict[str, Any]]: - headers: Dict[str, str] = {} - if agent.auth_scheme == "bearer" and agent.auth_token: - headers["Authorization"] = f"Bearer {agent.auth_token}" + security_selector = RemoteAgentSecuritySelector(agent) + security_selector.load_security_from_card() + preferred_scheme = security_selector.get_preferred_auth_scheme() + payload = { "project_id": task.project_id, "message": message, @@ -256,9 +677,28 @@ async def _delegate_to_remote(task: A2ATask, agent: A2ARemoteAgent, message: str "route_mode": "local_first", "metadata": {"delegated_by": "dataclaw", "task_id": task.id}, } + body_bytes = json.dumps(payload).encode("utf-8") url = f"{agent.base_url.rstrip('/')}/api/v1/a2a/messages/send" - async with httpx.AsyncClient(timeout=25.0, verify=True) as client: - resp = await client.post(url, json=payload, headers=headers) + path = "/api/v1/a2a/messages/send" + + headers: Dict[str, str] = {"Content-Type": "application/json"} + + if preferred_scheme == "shared_secret" and agent.shared_secret: + sig_headers = security_selector.create_signed_request_headers("POST", path, body_bytes) + headers.update(sig_headers) + else: + auth_headers = await security_selector.authorize_request("POST", url) + headers.update(auth_headers) + + mtls_context = security_selector.get_mtls_context() + + if mtls_context: + async with httpx.AsyncClient(timeout=25.0, ssl=mtls_context) as client: + resp = await client.post(url, content=body_bytes, headers=headers) + else: + async with httpx.AsyncClient(timeout=25.0, verify=True) as client: + resp = await client.post(url, content=body_bytes, headers=headers) + if resp.status_code >= 400: raise RuntimeError(f"remote_http_{resp.status_code}") body = resp.json() @@ -276,25 +716,28 @@ async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) - if not task: return config = a2a_runtime.get_project_config(db, task.project_id, tenant_id) - if task.state in {"CANCELED", "REJECTED"}: + if task.state in {A2ATaskState.CANCELED, A2ATaskState.REJECTED}: return with trace_service.start_span("a2a.task.execute", attributes={"task_id": task.id, "project_id": task.project_id, "source": task.source}) as span: start_ts = datetime.utcnow().timestamp() try: - task = a2a_runtime.transition_task(db, task, to_state="WORKING") + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING) status_event = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write) status_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", status_event) await a2a_runtime.publish(task.id, status_event) await a2a_runtime.notify_webhooks(db, task, status_row) + input_message = db.query(A2AMessage).filter(A2AMessage.task_id == task.id).order_by(A2AMessage.id.asc()).first() + message_text = _message_to_task_input(request.message) if input_message is None else task.input_text + if task.source == "a2a" and task.remote_agent_id: agent = db.query(A2ARemoteAgent).filter(A2ARemoteAgent.id == task.remote_agent_id).first() if not agent: raise RuntimeError("remote_agent_missing") - response_text, metadata = await _delegate_to_remote(task, agent, request.message) + response_text, metadata = await _delegate_to_remote(task, agent, message_text) else: response_text = await nanobot_service.process_message( - request.message, + message_text, session_id=f"a2a-task:{task.id}", project_id=task.project_id, ) @@ -306,7 +749,7 @@ async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) - task = a2a_runtime.transition_task( db, task, - to_state="COMPLETED", + to_state=A2ATaskState.COMPLETED, output_text=response_text or "", metadata=metadata, ) @@ -320,8 +763,8 @@ async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) - span.set_attributes(build_error_attributes(exc, stage="a2a_task_execute")) await a2a_runtime.metrics.incr("a2a.requests.error") task = db.query(A2ATask).filter(A2ATask.id == task.id).first() - if task and task.state not in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: - task = a2a_runtime.transition_task(db, task, to_state="FAILED", error_message=_json_dumps({"message": _mask_error(str(exc))})) + if task and task.state not in {A2ATaskState.COMPLETED, A2ATaskState.FAILED, A2ATaskState.CANCELED, A2ATaskState.REJECTED}: + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message=_json_dumps({"message": _mask_error(str(exc))})) fail_event = _build_status_event(task, compatibility_mode=task.compatibility_mode, dual_event_write=True) fail_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", fail_event) await a2a_runtime.publish(task.id, fail_event) @@ -330,25 +773,17 @@ async def _run_task(task_id: str, request: SendMessageRequest, tenant_id: int) - db.close() -@router.get("/agent-card", response_model=AgentCardResponse) -def get_agent_card() -> AgentCardResponse: - return AgentCardResponse( - name="DataClaw A2A Gateway", - protocol_version=SUPPORTED_PROTOCOL_VERSION, - capabilities=SUPPORTED_CAPABILITIES, - endpoints={ - "send_message": "/api/v1/a2a/messages/send", - "send_streaming_message": "/api/v1/a2a/messages/stream", - "get_task": "/api/v1/a2a/tasks/{task_id}", - "list_tasks": "/api/v1/a2a/tasks", - "cancel_task": "/api/v1/a2a/tasks/{task_id}/cancel", - "subscribe_task": "/api/v1/a2a/tasks/{task_id}/subscribe", - }, - auth=SUPPORTED_AUTH, - ) +@router.get("/.well-known/agent-card.json", response_model=AgentCardPublicSchema) +def get_agent_card_public() -> AgentCardPublicSchema: + return _build_public_agent_card() -@router.get("/remote-agents", response_model=List[RemoteAgentView]) +@router.get(f"{A2A_API_PREFIX}/agent-card", response_model=AgentCardPublicSchema) +def get_agent_card() -> AgentCardPublicSchema: + return _build_public_agent_card() + + +@router.get(f"{A2A_API_PREFIX}/remote-agents", response_model=List[RemoteAgentView]) def list_remote_agents( project_id: Optional[int] = Query(default=None), db: Session = Depends(get_db), @@ -366,7 +801,7 @@ def list_remote_agents( return [_agent_to_view(item) for item in query.order_by(A2ARemoteAgent.id.desc()).all()] -@router.post("/remote-agents", response_model=RemoteAgentView, status_code=status.HTTP_201_CREATED) +@router.post(f"{A2A_API_PREFIX}/remote-agents", response_model=RemoteAgentView, status_code=status.HTTP_201_CREATED) async def create_remote_agent( payload: RemoteAgentCreate, db: Session = Depends(get_db), @@ -379,6 +814,17 @@ async def create_remote_agent( base_url=payload.base_url.strip().rstrip("/"), auth_scheme=payload.auth_scheme, auth_token=payload.auth_token, + shared_secret=payload.shared_secret, + mtls_ca_cert=payload.mtls_ca_cert, + mtls_client_cert=payload.mtls_client_cert, + mtls_client_key=payload.mtls_client_key, + oauth2_client_id=payload.oauth2_client_id, + oauth2_client_secret=payload.oauth2_client_secret, + oauth2_token_url=payload.oauth2_token_url, + oauth2_scopes=payload.oauth2_scopes, + oidc_issuer_url=payload.oidc_issuer_url, + oidc_client_id=payload.oidc_client_id, + oidc_client_secret=payload.oidc_client_secret, created_by=current_user.id, ) db.add(item) @@ -400,7 +846,7 @@ async def create_remote_agent( return _agent_to_view(item) -@router.put("/remote-agents/{agent_id}", response_model=RemoteAgentView) +@router.put(f"{A2A_API_PREFIX}/remote-agents/{{agent_id}}", response_model=RemoteAgentView) async def update_remote_agent( agent_id: int, payload: RemoteAgentUpdate, @@ -432,7 +878,7 @@ async def update_remote_agent( return _agent_to_view(item) -@router.delete("/remote-agents/{agent_id}") +@router.delete(f"{A2A_API_PREFIX}/remote-agents/{{agent_id}}") def delete_remote_agent( agent_id: int, db: Session = Depends(get_db), @@ -453,7 +899,7 @@ def delete_remote_agent( return {"status": "success"} -@router.post("/remote-agents/{agent_id}/refresh-card", response_model=RemoteAgentView) +@router.post(f"{A2A_API_PREFIX}/remote-agents/{{agent_id}}/refresh-card", response_model=RemoteAgentView) async def refresh_remote_agent_card( agent_id: int, db: Session = Depends(get_db), @@ -480,7 +926,7 @@ async def refresh_remote_agent_card( return _agent_to_view(item) -@router.post("/remote-agents/{agent_id}/health-check") +@router.post(f"{A2A_API_PREFIX}/remote-agents/{{agent_id}}/health-check") async def health_check_remote_agent( agent_id: int, db: Session = Depends(get_db), @@ -494,41 +940,125 @@ async def health_check_remote_agent( return {"healthy": False, "failure_count": item.failure_count} -@router.post("/messages/send") +@router.post("/message:send") async def send_message( request: SendMessageRequest, + response: Response, x_a2a_token: Optional[str] = Header(default=None), db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), -) -> Dict[str, Any]: - _ensure_project_access(db, request.project_id, current_user) - config = a2a_runtime.get_project_config(db, request.project_id, current_user.id) - route = a2a_runtime.resolve_route( - project_config=config, - session_id=request.session_id, - requested_mode=request.route_mode, - requested_fallback=request.fallback_chain, - ) - selected_source = "local" + _version_check: None = Depends(verify_a2a_version), +) -> StreamResponse: + message = request.message + project_id = message.parts[0].data.get("project_id") if message.parts and message.parts[0].data else None + if not project_id: + raise HTTPException(status_code=400, detail="project_id required in message part data") + + _ensure_project_access(db, project_id, current_user) + config = a2a_runtime.get_project_config(db, project_id, current_user.id) + + message_id_str = message.messageId + context_id = request.contextId or message.contextId + task_id = request.taskId or message.taskId + + existing_task = None + if task_id: + existing_task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + if existing_task and existing_task.tenant_id != current_user.id and not current_user.is_admin: + raise HTTPException(status_code=404, detail="Task not found") + + input_text = _message_to_task_input(message) + + if existing_task: + msg_record = A2AMessage( + message_id=message_id_str, + context_id=context_id, + task_id=existing_task.id, + role=message.role, + extensions_json=_json_dumps(message.extensions or {}), + reference_task_ids_json=_json_dumps(message.referenceTaskIds or []), + ) + db.add(msg_record) + for idx, part in enumerate(message.parts): + part_record = A2APart( + message_id=msg_record.id, + part_type=part.part_type, + text_content=part.text, + raw_content=part.raw, + url_content=part.url, + data_content=str(part.data) if part.data else None, + media_type=part.mediaType, + filename=part.filename, + metadata_json=_json_dumps(part.metadata or {}), + ) + db.add(part_record) + db.commit() + asyncio.create_task(_run_task(existing_task.id, request, current_user.id)) + return StreamResponse( + task=StreamResponseTask( + id=existing_task.id, + contextId=existing_task.context_id, + state=SchemaTaskState(existing_task.state.value), + artifacts=[], + ) + ) + + route_selected = "local" remote_agent_id = None - if route.selected == "a2a" and request.remote_agent_id: - agent = _ensure_agent_access(db, request.remote_agent_id, current_user) - if not agent.healthy and config.rollback_to_local: - selected_source = "local" - else: - selected_source = "a2a" - remote_agent_id = agent.id + agent = None + if message.parts and message.parts[0].data: + route_mode = message.parts[0].data.get("route_mode", "auto") + remote_agent_id_param = message.parts[0].data.get("remote_agent_id") + if route_mode == "a2a" and remote_agent_id_param: + agent = _ensure_agent_access(db, remote_agent_id_param, current_user) + if not agent.healthy and config.rollback_to_local: + route_selected = "local" + else: + route_selected = "a2a" + remote_agent_id = agent.id + + idempotency_key = message.parts[0].data.get("idempotency_key") if message.parts and message.parts[0].data else None + metadata = message.parts[0].data.get("metadata", {}) if message.parts and message.parts[0].data else {} + task = a2a_runtime.create_task( db, - project_id=request.project_id, + project_id=project_id, tenant_id=current_user.id, - source=selected_source, - input_text=request.message, - idempotency_key=request.idempotency_key, + source=route_selected, + input_text=input_text, + idempotency_key=idempotency_key, remote_agent_id=remote_agent_id, compatibility_mode=config.compatibility_mode, - metadata={"route": route.model_dump() if hasattr(route, "model_dump") else route.__dict__, "token_present": bool(x_a2a_token), "request_metadata": request.metadata or {}}, + metadata={"route_selected": route_selected, "token_present": bool(x_a2a_token), "request_metadata": metadata}, + context_id=context_id, ) + + msg_record = A2AMessage( + message_id=message_id_str, + context_id=context_id, + task_id=task.id, + role=message.role, + extensions_json=_json_dumps(message.extensions or {}), + reference_task_ids_json=_json_dumps(message.referenceTaskIds or []), + ) + db.add(msg_record) + for idx, part in enumerate(message.parts): + part_record = A2APart( + message_id=msg_record.id, + part_type=part.part_type, + text_content=part.text, + raw_content=part.raw, + url_content=part.url, + data_content=str(part.data) if part.data else None, + media_type=part.mediaType, + filename=part.filename, + metadata_json=_json_dumps(part.metadata or {}), + ) + db.add(part_record) + + task.context_id = context_id + db.commit() + event_payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write) event_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", event_payload) await a2a_runtime.publish(task.id, event_payload) @@ -545,83 +1075,327 @@ async def send_message( project_id=task.project_id, task_id=task.id, ) - return {"task": _task_to_view(task).model_dump(), "routing": route.__dict__} + + task_record = db.query(A2ATask).filter(A2ATask.id == task.id).first() + return StreamResponse( + task=StreamResponseTask( + id=task_record.id, + contextId=task_record.context_id, + state=SchemaTaskState(task_record.state.value), + artifacts=[], + ) + ) -@router.post("/messages/stream") +@router.post("/message:stream") async def send_streaming_message( - request: SendMessageRequest, + request: SendStreamingMessageRequest, + response: Response, + x_a2a_token: Optional[str] = Header(default=None), db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), ) -> StreamingResponse: - response = await send_message(request=request, x_a2a_token=None, db=db, current_user=current_user) - task_id = response["task"]["id"] + message = request.message + project_id = message.parts[0].data.get("project_id") if message.parts and message.parts[0].data else None + if not project_id: + raise HTTPException(status_code=400, detail="project_id required in message part data") + + _ensure_project_access(db, project_id, current_user) + config = a2a_runtime.get_project_config(db, project_id, current_user.id) + + message_id_str = message.messageId + context_id = request.contextId or message.contextId + task_id = request.taskId or message.taskId + + existing_task = None + if task_id: + existing_task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + + input_text = _message_to_task_input(message) + + task_context_id = None + if existing_task: + msg_record = A2AMessage( + message_id=message_id_str, + context_id=context_id, + task_id=existing_task.id, + role=message.role, + extensions_json=_json_dumps(message.extensions or {}), + reference_task_ids_json=_json_dumps(message.referenceTaskIds or []), + ) + db.add(msg_record) + for idx, part in enumerate(message.parts): + part_record = A2APart( + message_id=msg_record.id, + part_type=part.part_type, + text_content=part.text, + raw_content=part.raw, + url_content=part.url, + data_content=str(part.data) if part.data else None, + media_type=part.mediaType, + filename=part.filename, + metadata_json=_json_dumps(part.metadata or {}), + ) + db.add(part_record) + db.commit() + task_context_id = existing_task.context_id + asyncio.create_task(_run_task(existing_task.id, request, current_user.id)) + task_id = existing_task.id + else: + route_selected = "local" + remote_agent_id = None + if message.parts and message.parts[0].data: + route_mode = message.parts[0].data.get("route_mode", "auto") + remote_agent_id_param = message.parts[0].data.get("remote_agent_id") + if route_mode == "a2a" and remote_agent_id_param: + agent = _ensure_agent_access(db, remote_agent_id_param, current_user) + if not agent.healthy and config.rollback_to_local: + route_selected = "local" + else: + route_selected = "a2a" + remote_agent_id = agent.id + + idempotency_key = message.parts[0].data.get("idempotency_key") if message.parts and message.parts[0].data else None + metadata = message.parts[0].data.get("metadata", {}) if message.parts and message.parts[0].data else {} + + task = a2a_runtime.create_task( + db, + project_id=project_id, + tenant_id=current_user.id, + source=route_selected, + input_text=input_text, + idempotency_key=idempotency_key, + remote_agent_id=remote_agent_id, + compatibility_mode=config.compatibility_mode, + metadata={"route_selected": route_selected, "token_present": bool(x_a2a_token), "request_metadata": metadata}, + context_id=context_id, + ) + + msg_record = A2AMessage( + message_id=message_id_str, + context_id=context_id, + task_id=task.id, + role=message.role, + extensions_json=_json_dumps(message.extensions or {}), + reference_task_ids_json=_json_dumps(message.referenceTaskIds or []), + ) + db.add(msg_record) + for idx, part in enumerate(message.parts): + part_record = A2APart( + message_id=msg_record.id, + part_type=part.part_type, + text_content=part.text, + raw_content=part.raw, + url_content=part.url, + data_content=str(part.data) if part.data else None, + media_type=part.mediaType, + filename=part.filename, + metadata_json=_json_dumps(part.metadata or {}), + ) + db.add(part_record) + + task.context_id = context_id + db.commit() + task_context_id = task.context_id + + event_payload = _build_status_event(task, compatibility_mode=config.compatibility_mode, dual_event_write=config.dual_event_write) + event_row = a2a_runtime.append_event(db, task, "TaskStatusUpdateEvent", event_payload) + await a2a_runtime.publish(task.id, event_payload) + await a2a_runtime.notify_webhooks(db, task, event_row) + asyncio.create_task(_run_task(task.id, request, current_user.id)) + task_id = task.id + + async def _collect_events_to_queue(task_id: str, queue: asyncio.Queue, context_id: Optional[str]) -> None: + try: + history = ( + db.query(A2ATaskEvent) + .filter(A2ATaskEvent.task_id == task_id) + .order_by(A2ATaskEvent.id.asc()) + .all() + ) + for item in history: + payload = _json_loads(item.payload_json, {}) + if payload.get("type") == "TaskStatusUpdateEvent": + task_obj = db.query(A2ATask).filter(A2ATask.id == task_id).first() + event = TaskStatusUpdateEvent( + taskId=task_id, + contextId=task_obj.context_id if task_obj else context_id, + status=A2ATaskStatusSchema( + state=SchemaTaskState(payload.get("task_status", "WORKING")), + timestamp=datetime.utcnow(), + ), + metadata=payload.get("metadata", {}), + ) + await queue.put(("TaskStatusUpdateEvent", event.model_dump(mode='json'))) + elif payload.get("type") == "TaskArtifactUpdateEvent": + task_obj = db.query(A2ATask).filter(A2ATask.id == task_id).first() + content = payload.get("artifact", {}).get("content", "") + event = TaskArtifactUpdateEvent( + taskId=task_id, + contextId=task_obj.context_id if task_obj else context_id, + artifact=A2AArtifactSchema( + artifactId=f"artifact-{item.id}", + parts=[A2APartSchema(part_type="text", text=content)], + ), + append=False, + lastChunk=True, + ) + await queue.put(("TaskArtifactUpdateEvent", event.model_dump(mode='json'))) + elif payload.get("type") == "Message": + msg_event = TaskMessageEvent( + message=A2AMessageSchema( + messageId=payload.get("messageId", ""), + contextId=payload.get("contextId", context_id), + taskId=task_id, + role=A2AMessageRole(payload.get("role", "agent")), + parts=[A2APartSchema(part_type="text", text=payload.get("content", ""))], + ) + ) + await queue.put(("Message", msg_event.model_dump(mode='json'))) + else: + await queue.put(("raw", payload)) + + async for payload in a2a_runtime.subscribe(task_id): + if payload.get("type") == "TaskStatusUpdateEvent": + task_obj = db.query(A2ATask).filter(A2ATask.id == task_id).first() + event = TaskStatusUpdateEvent( + taskId=task_id, + contextId=task_obj.context_id if task_obj else context_id, + status=A2ATaskStatusSchema( + state=SchemaTaskState(payload.get("task_status", "WORKING")), + timestamp=datetime.utcnow(), + ), + metadata=payload.get("metadata", {}), + ) + await queue.put(("TaskStatusUpdateEvent", event.model_dump(mode='json'))) + elif payload.get("type") == "TaskArtifactUpdateEvent": + task_obj = db.query(A2ATask).filter(A2ATask.id == task_id).first() + content = payload.get("artifact", {}).get("content", "") + event = TaskArtifactUpdateEvent( + taskId=task_id, + contextId=task_obj.context_id if task_obj else context_id, + artifact=A2AArtifactSchema( + artifactId=f"artifact-stream-{datetime.utcnow().timestamp()}", + parts=[A2APartSchema(part_type="text", text=content)], + ), + append=False, + lastChunk=True, + ) + await queue.put(("TaskArtifactUpdateEvent", event.model_dump(mode='json'))) + elif payload.get("type") == "Message": + msg_event = TaskMessageEvent( + message=A2AMessageSchema( + messageId=payload.get("messageId", ""), + contextId=payload.get("contextId", context_id), + taskId=task_id, + role=A2AMessageRole(payload.get("role", "agent")), + parts=[A2APartSchema(part_type="text", text=payload.get("content", ""))], + ) + ) + await queue.put(("Message", msg_event.model_dump(mode='json'))) + else: + await queue.put(("raw", payload)) + + if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: + await queue.put(("terminal", None)) + break + except Exception: + await queue.put(("error", None)) + finally: + await queue.put(("close", None)) async def event_generator(): - history = ( - db.query(A2ATaskEvent) - .filter(A2ATaskEvent.task_id == task_id) - .order_by(A2ATaskEvent.id.asc()) - .all() - ) - for item in history: - payload = _json_loads(item.payload_json, {}) - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - async for payload in a2a_runtime.subscribe(task_id): - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: - break - yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + queue: asyncio.Queue = asyncio.Queue(maxsize=200) + collector = asyncio.create_task(_collect_events_to_queue(task_id, queue, task_context_id)) - return StreamingResponse( + message_only = True + while True: + event_type, event_data = await queue.get() + if event_type == "close": + break + if event_type == "error": + break + if event_type == "terminal": + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + break + if event_type in ("TaskStatusUpdateEvent", "TaskArtifactUpdateEvent"): + message_only = False + yield f"data: {json.dumps(event_data, ensure_ascii=False, default=_json_serialize)}\n\n" + if event_type == "Message": + break + + if message_only: + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + + collector.cancel() + + return A2AStreamingResponse( event_generator(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) -@router.get("/tasks/{task_id}", response_model=TaskView) +@router.get("/tasks/{{task_id}}") def get_task( task_id: str, + response: Response, + historyLength: Optional[int] = Query(default=None, description="Number of history messages to include"), db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), -) -> TaskView: + _version_check: None = Depends(verify_a2a_version), +) -> A2ATaskWithHistorySchema: task = _ensure_task_access(db, task_id, current_user) - return _task_to_view(task) + return _task_to_with_history(task, history_length=historyLength) -@router.get("/tasks", response_model=List[TaskView]) +@router.get("/tasks") def list_tasks( - project_id: Optional[int] = Query(default=None), - state: Optional[str] = Query(default=None), - skip: int = Query(default=0, ge=0), - limit: int = Query(default=50, ge=1, le=200), + response: Response, + contextId: Optional[str] = Query(default=None, description="Filter by context ID"), + status: Optional[SchemaTaskState] = Query(default=None, description="Filter by task status"), + pageSize: int = Query(default=20, ge=1, le=100, description="Number of items per page"), + pageToken: Optional[str] = Query(default=None, description="Pagination token"), db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), -) -> List[TaskView]: + _version_check: None = Depends(verify_a2a_version), +) -> Dict[str, Any]: query = db.query(A2ATask) if not current_user.is_admin: query = query.filter(A2ATask.tenant_id == current_user.id) - if project_id is not None: - _ensure_project_access(db, project_id, current_user) - query = query.filter(A2ATask.project_id == project_id) - if state: - query = query.filter(A2ATask.state == state) - tasks = query.order_by(A2ATask.created_at.desc()).offset(skip).limit(limit).all() - return [_task_to_view(item) for item in tasks] + + if contextId: + query = query.filter(A2ATask.context_id == contextId) + if status: + query = query.filter(A2ATask.state == A2ATaskState(status.value)) + + total = query.count() + tasks = query.order_by(A2ATask.created_at.desc()).offset(0).limit(pageSize).all() + + task_schemas = [_task_to_schema(item) for item in tasks] + + return { + "items": [t.model_dump(mode='json') for t in task_schemas], + "nextPageToken": str(tasks[-1].id) if tasks else None, + "contextId": contextId, + } -@router.post("/tasks/{task_id}/cancel", response_model=CancelTaskResponse) +@router.post("/tasks/{task_id}:cancel") async def cancel_task( task_id: str, + request: CancelTaskRequest, + response: Response, db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), ) -> CancelTaskResponse: task = _ensure_task_access(db, task_id, current_user) - if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: - return CancelTaskResponse(task_id=task.id, state=task.state) + if task.state in {A2ATaskState.COMPLETED, A2ATaskState.FAILED, A2ATaskState.CANCELED, A2ATaskState.REJECTED}: + return CancelTaskResponse(task_id=task.id, state=task.state.value) try: - task = a2a_runtime.transition_task(db, task, to_state="CANCELED") + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.CANCELED) except ValueError: raise HTTPException(status_code=409, detail="Task state transition conflict") config = a2a_runtime.get_project_config(db, task.project_id, current_user.id) @@ -639,44 +1413,232 @@ async def cancel_task( project_id=task.project_id, task_id=task.id, ) - return CancelTaskResponse(task_id=task.id, state=task.state) + return CancelTaskResponse(task_id=task.id, state=task.state.value) -@router.get("/tasks/{task_id}/subscribe") +@router.get("/tasks/{task_id}:subscribe") async def subscribe_task( task_id: str, + response: Response, db: Session = Depends(get_db), current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), ) -> StreamingResponse: task = _ensure_task_access(db, task_id, current_user) - initial_events = ( - db.query(A2ATaskEvent) - .filter(A2ATaskEvent.task_id == task.id) - .order_by(A2ATaskEvent.id.asc()) - .all() - ) + + async def _collect_subscribe_events_to_queue(task_id: str, queue: asyncio.Queue, context_id: Optional[str]) -> None: + try: + initial_events = ( + db.query(A2ATaskEvent) + .filter(A2ATaskEvent.task_id == task_id) + .order_by(A2ATaskEvent.id.asc()) + .all() + ) + for event in initial_events: + payload = _json_loads(event.payload_json, {}) + if payload.get("type") == "TaskStatusUpdateEvent": + evt = TaskStatusUpdateEvent( + taskId=task_id, + contextId=context_id, + status=A2ATaskStatusSchema( + state=SchemaTaskState(payload.get("task_status", "WORKING")), + timestamp=datetime.utcnow(), + ), + metadata=payload.get("metadata", {}), + ) + await queue.put(("TaskStatusUpdateEvent", evt.model_dump(mode='json'))) + elif payload.get("type") == "TaskArtifactUpdateEvent": + content = payload.get("artifact", {}).get("content", "") + evt = TaskArtifactUpdateEvent( + taskId=task_id, + contextId=context_id, + artifact=A2AArtifactSchema( + artifactId=f"artifact-{event.id}", + parts=[A2APartSchema(part_type="text", text=content)], + ), + append=False, + lastChunk=True, + ) + await queue.put(("TaskArtifactUpdateEvent", evt.model_dump(mode='json'))) + elif payload.get("type") == "Message": + msg_event = TaskMessageEvent( + message=A2AMessageSchema( + messageId=payload.get("messageId", ""), + contextId=payload.get("contextId", context_id), + taskId=task_id, + role=A2AMessageRole(payload.get("role", "agent")), + parts=[A2APartSchema(part_type="text", text=payload.get("content", ""))], + ) + ) + await queue.put(("Message", msg_event.model_dump(mode='json'))) + else: + await queue.put(("raw", payload)) + + if task.state in {A2ATaskState.COMPLETED, A2ATaskState.FAILED, A2ATaskState.CANCELED, A2ATaskState.REJECTED}: + await queue.put(("terminal", None)) + return + + async for payload in a2a_runtime.subscribe(task.id): + if payload.get("type") == "TaskStatusUpdateEvent": + evt = TaskStatusUpdateEvent( + taskId=task_id, + contextId=context_id, + status=A2ATaskStatusSchema( + state=SchemaTaskState(payload.get("task_status", "WORKING")), + timestamp=datetime.utcnow(), + ), + metadata=payload.get("metadata", {}), + ) + await queue.put(("TaskStatusUpdateEvent", evt.model_dump(mode='json'))) + elif payload.get("type") == "TaskArtifactUpdateEvent": + content = payload.get("artifact", {}).get("content", "") + evt = TaskArtifactUpdateEvent( + taskId=task_id, + contextId=context_id, + artifact=A2AArtifactSchema( + artifactId=f"artifact-stream-{datetime.utcnow().timestamp()}", + parts=[A2APartSchema(part_type="text", text=content)], + ), + append=False, + lastChunk=True, + ) + await queue.put(("TaskArtifactUpdateEvent", evt.model_dump(mode='json'))) + elif payload.get("type") == "Message": + msg_event = TaskMessageEvent( + message=A2AMessageSchema( + messageId=payload.get("messageId", ""), + contextId=payload.get("contextId", context_id), + taskId=task_id, + role=A2AMessageRole(payload.get("role", "agent")), + parts=[A2APartSchema(part_type="text", text=payload.get("content", ""))], + ) + ) + await queue.put(("Message", msg_event.model_dump(mode='json'))) + else: + await queue.put(("raw", payload)) + + if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: + await queue.put(("terminal", None)) + break + except Exception: + await queue.put(("error", None)) + finally: + await queue.put(("close", None)) async def event_generator(): - for event in initial_events: - payload = _json_loads(event.payload_json, {}) - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - if task.state in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: - yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" - return - async for payload in a2a_runtime.subscribe(task.id): - yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" - if payload.get("task_status") in {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}: - break - yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + queue: asyncio.Queue = asyncio.Queue(maxsize=200) + collector = asyncio.create_task(_collect_subscribe_events_to_queue(task_id, queue, task.context_id)) - return StreamingResponse( + while True: + event_type, event_data = await queue.get() + if event_type == "close": + break + if event_type == "error": + break + if event_type == "terminal": + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + break + yield f"data: {json.dumps(event_data, ensure_ascii=False, default=_json_serialize)}\n\n" + if event_type == "Message": + break + + yield f"data: {json.dumps({'type': 'done'}, ensure_ascii=False)}\n\n" + collector.cancel() + + return A2AStreamingResponse( event_generator(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}, ) -@router.get("/tasks/{task_id}/webhooks", response_model=List[TaskWebhookView]) +@router.post("/tasks/{task_id}/pushNotificationConfigs", response_model=PushNotificationConfig, status_code=status.HTTP_201_CREATED) +def create_push_notification_config( + task_id: str, + payload: PushNotificationConfigCreate, + response: Response, + db: Session = Depends(get_db), + current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), +) -> PushNotificationConfig: + task = _ensure_task_access(db, task_id, current_user) + item = A2ATaskWebhook( + task_id=task.id, + target_url=payload.targetUrl, + secret=payload.secret, + auth_header=payload.authHeader, + enabled=payload.enabled, + created_by=current_user.id, + ) + db.add(item) + db.commit() + db.refresh(item) + return PushNotificationConfig( + id=item.id, + taskId=item.task_id, + targetUrl=item.target_url, + secret=item.secret, + authHeader=item.auth_header, + enabled=item.enabled, + createdBy=item.created_by, + createdAt=item.created_at, + ) + + +@router.get("/tasks/{task_id}/pushNotificationConfigs", response_model=List[PushNotificationConfig]) +def list_push_notification_configs( + task_id: str, + response: Response, + db: Session = Depends(get_db), + current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), +) -> List[PushNotificationConfig]: + task = _ensure_task_access(db, task_id, current_user) + items = db.query(A2ATaskWebhook).filter(A2ATaskWebhook.task_id == task.id).order_by(A2ATaskWebhook.id.desc()).all() + return [ + PushNotificationConfig( + id=item.id, + taskId=item.task_id, + targetUrl=item.target_url, + secret=item.secret, + authHeader=item.auth_header, + enabled=item.enabled, + createdBy=item.created_by, + createdAt=item.created_at, + ) + for item in items + ] + + +@router.delete("/tasks/{task_id}/pushNotificationConfigs/{config_id}") +def delete_push_notification_config( + task_id: str, + config_id: int, + response: Response, + db: Session = Depends(get_db), + current_user: CurrentUser = Depends(get_current_user), + _version_check: None = Depends(verify_a2a_version), +) -> Dict[str, str]: + task = _ensure_task_access(db, task_id, current_user) + item = db.query(A2ATaskWebhook).filter( + A2ATaskWebhook.id == config_id, + A2ATaskWebhook.task_id == task.id, + ).first() + if not item: + raise HTTPException(status_code=404, detail="Push notification config not found") + db.delete(item) + db.commit() + return {"status": "success"} + + +@router.get(f"{A2A_API_PREFIX}/extendedAgentCard", response_model=AgentCardExtendedSchema) +def get_extended_agent_card( + current_user: CurrentUser = Depends(get_current_user), +) -> AgentCardExtendedSchema: + return _build_extended_agent_card(current_user) + + +@router.get(f"{A2A_API_PREFIX}/tasks/{{task_id}}/webhooks", response_model=List[TaskWebhookView]) def list_task_webhooks( task_id: str, db: Session = Depends(get_db), @@ -697,7 +1659,7 @@ def list_task_webhooks( ] -@router.post("/tasks/{task_id}/webhooks", response_model=TaskWebhookView, status_code=status.HTTP_201_CREATED) +@router.post(f"{A2A_API_PREFIX}/tasks/{{task_id}}/webhooks", response_model=TaskWebhookView, status_code=status.HTTP_201_CREATED) def create_task_webhook( task_id: str, payload: TaskWebhookCreate, @@ -735,7 +1697,7 @@ def create_task_webhook( ) -@router.delete("/tasks/{task_id}/webhooks/{webhook_id}") +@router.delete(f"{A2A_API_PREFIX}/tasks/{{task_id}}/webhooks/{{webhook_id}}") def delete_task_webhook( task_id: str, webhook_id: int, @@ -751,7 +1713,7 @@ def delete_task_webhook( return {"status": "success"} -@router.post("/webhook-deliveries/{delivery_id}/replay") +@router.post(f"{A2A_API_PREFIX}/webhook-deliveries/{{delivery_id}}/replay") async def replay_delivery( delivery_id: int, db: Session = Depends(get_db), @@ -769,14 +1731,14 @@ async def replay_delivery( return {"status": delivery.status, "attempt": delivery.attempt, "dead_letter": delivery.dead_letter, "task_id": task.id} -@router.get("/metrics") +@router.get(f"{A2A_API_PREFIX}/metrics") async def get_metrics(current_user: CurrentUser = Depends(get_current_user)) -> Dict[str, Any]: if not current_user.is_admin: raise HTTPException(status_code=403, detail="Admin permission required") return await a2a_runtime.metrics.snapshot() -@router.get("/projects/{project_id}/rollout", response_model=RolloutConfigView) +@router.get(f"{A2A_API_PREFIX}/projects/{{project_id}}/rollout", response_model=RolloutConfigView) def get_rollout_config( project_id: int, db: Session = Depends(get_db), @@ -797,7 +1759,7 @@ def get_rollout_config( ) -@router.put("/projects/{project_id}/rollout", response_model=RolloutConfigView) +@router.put(f"{A2A_API_PREFIX}/projects/{{project_id}}/rollout", response_model=RolloutConfigView) def update_rollout_config( project_id: int, payload: RolloutConfigUpdate, @@ -841,7 +1803,7 @@ def update_rollout_config( ) -@router.get("/alerts") +@router.get(f"{A2A_API_PREFIX}/alerts") def get_alert_panel( project_id: int, db: Session = Depends(get_db), @@ -855,11 +1817,11 @@ def get_alert_panel( return { "project_id": project_id, "thresholds": merged, - "panel": {"metrics_endpoint": "/api/v1/a2a/metrics", "task_list_endpoint": "/api/v1/a2a/tasks"}, + "panel": {"metrics_endpoint": "/api/v1/a2a/metrics", "task_list_endpoint": "/tasks"}, } -@router.get("/audit-logs") +@router.get(f"{A2A_API_PREFIX}/audit-logs") def list_audit_logs( project_id: Optional[int] = Query(default=None), skip: int = Query(default=0, ge=0), @@ -888,4 +1850,4 @@ def list_audit_logs( "created_at": row.created_at, } for row in rows - ] + ] \ No newline at end of file diff --git a/backend/app/models/a2a.py b/backend/app/models/a2a.py index 84ca3ef..9ee5ce0 100644 --- a/backend/app/models/a2a.py +++ b/backend/app/models/a2a.py @@ -1,9 +1,34 @@ -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Enum as SQLEnum, func from sqlalchemy.orm import relationship +import enum from app.database import Base +class A2ATaskState(str, enum.Enum): + SUBMITTED = "SUBMITTED" + WORKING = "WORKING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELED = "CANCELED" + INPUT_REQUIRED = "INPUT_REQUIRED" + AUTH_REQUIRED = "AUTH_REQUIRED" + REJECTED = "REJECTED" + + +class A2APartType(str, enum.Enum): + TEXT = "text" + RAW = "raw" + URL = "url" + DATA = "data" + + +class A2AMessageRole(str, enum.Enum): + USER = "user" + AGENT = "agent" + SYSTEM = "system" + + class A2ARemoteAgent(Base): __tablename__ = "a2a_remote_agents" @@ -13,6 +38,17 @@ class A2ARemoteAgent(Base): base_url = Column(String, nullable=False) auth_scheme = Column(String, nullable=False, default="none") auth_token = Column(String, nullable=True) + shared_secret = Column(String, nullable=True) + mtls_ca_cert = Column(Text, nullable=True) + mtls_client_cert = Column(Text, nullable=True) + mtls_client_key = Column(Text, nullable=True) + oauth2_client_id = Column(String, nullable=True) + oauth2_client_secret = Column(String, nullable=True) + oauth2_token_url = Column(String, nullable=True) + oauth2_scopes = Column(String, nullable=True) + oidc_issuer_url = Column(String, nullable=True) + oidc_client_id = Column(String, nullable=True) + oidc_client_secret = Column(String, nullable=True) protocol_version = Column(String, nullable=True) capabilities_json = Column(Text, nullable=False, default="[]") card_json = Column(Text, nullable=True) @@ -27,27 +63,84 @@ class A2ARemoteAgent(Base): project = relationship("Project") +class A2APart(Base): + __tablename__ = "a2a_parts" + + id = Column(Integer, primary_key=True, index=True) + message_id = Column(Integer, ForeignKey("a2a_messages.id", ondelete="CASCADE"), nullable=True, index=True) + artifact_id = Column(Integer, ForeignKey("a2a_artifacts.id", ondelete="CASCADE"), nullable=True, index=True) + part_type = Column(SQLEnum(A2APartType), nullable=False) + text_content = Column(Text, nullable=True) + raw_content = Column(Text, nullable=True) + url_content = Column(String, nullable=True) + data_content = Column(Text, nullable=True) + media_type = Column(String, nullable=True) + filename = Column(String, nullable=True) + metadata_json = Column(Text, nullable=False, default="{}") + created_at = Column(DateTime, default=func.now()) + + message = relationship("A2AMessage", back_populates="parts", foreign_keys=[message_id]) + artifact = relationship("A2AArtifact", back_populates="parts", foreign_keys=[artifact_id]) + + +class A2AMessage(Base): + __tablename__ = "a2a_messages" + + id = Column(Integer, primary_key=True, index=True) + message_id = Column(String, nullable=False, unique=True, index=True) + context_id = Column(String, nullable=True, index=True) + task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=True, index=True) + role = Column(SQLEnum(A2AMessageRole), nullable=False) + extensions_json = Column(Text, nullable=False, default="{}") + reference_task_ids_json = Column(Text, nullable=False, default="[]") + created_at = Column(DateTime, default=func.now(), index=True) + + task = relationship("A2ATask", back_populates="messages", foreign_keys=[task_id]) + parts = relationship("A2APart", back_populates="message", cascade="all, delete-orphan") + + +class A2AArtifact(Base): + __tablename__ = "a2a_artifacts" + + id = Column(Integer, primary_key=True, index=True) + artifact_id = Column(String, nullable=False, unique=True, index=True) + task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True) + name = Column(String, nullable=True) + description = Column(Text, nullable=True) + metadata_json = Column(Text, nullable=False, default="{}") + extensions_json = Column(Text, nullable=False, default="{}") + created_at = Column(DateTime, default=func.now()) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) + + task = relationship("A2ATask", back_populates="artifacts") + parts = relationship("A2APart", back_populates="artifact", cascade="all, delete-orphan") + + class A2ATask(Base): __tablename__ = "a2a_tasks" id = Column(String, primary_key=True, index=True) project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True) tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) + context_id = Column(String, nullable=True, index=True) source = Column(String, nullable=False, default="local") remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True) idempotency_key = Column(String, nullable=True, index=True) - state = Column(String, nullable=False, index=True, default="SUBMITTED") + state = Column(SQLEnum(A2ATaskState), nullable=False, index=True, default=A2ATaskState.SUBMITTED) input_text = Column(Text, nullable=False, default="") output_text = Column(Text, nullable=True) error_message = Column(Text, nullable=True) compatibility_mode = Column(Boolean, nullable=False, default=True) metadata_json = Column(Text, nullable=False, default="{}") + history_length = Column(Integer, nullable=False, default=0) created_at = Column(DateTime, default=func.now()) updated_at = Column(DateTime, default=func.now(), onupdate=func.now()) finished_at = Column(DateTime, nullable=True) project = relationship("Project") remote_agent = relationship("A2ARemoteAgent") + messages = relationship("A2AMessage", back_populates="task", cascade="all, delete-orphan", foreign_keys=[A2AMessage.task_id]) + artifacts = relationship("A2AArtifact", back_populates="task", cascade="all, delete-orphan") class A2ATaskEvent(Base): diff --git a/backend/app/schemas/a2a.py b/backend/app/schemas/a2a.py new file mode 100644 index 0000000..5fbd8b4 --- /dev/null +++ b/backend/app/schemas/a2a.py @@ -0,0 +1,361 @@ +from pydantic import BaseModel, ConfigDict, Field +from typing import Optional, List, Dict, Any, Literal, Union +from datetime import datetime +from enum import Enum + + +class A2ATaskState(str, Enum): + SUBMITTED = "SUBMITTED" + WORKING = "WORKING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + CANCELED = "CANCELED" + INPUT_REQUIRED = "INPUT_REQUIRED" + AUTH_REQUIRED = "AUTH_REQUIRED" + REJECTED = "REJECTED" + + +class A2APartType(str, Enum): + TEXT = "text" + RAW = "raw" + URL = "url" + DATA = "data" + + +class A2AMessageRole(str, Enum): + USER = "user" + AGENT = "agent" + SYSTEM = "system" + + +class A2APartSchema(BaseModel): + part_type: A2APartType + text: Optional[str] = None + raw: Optional[bytes] = None + url: Optional[str] = None + data: Optional[Any] = None + mediaType: Optional[str] = None + filename: Optional[str] = None + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class A2APartCreateSchema(BaseModel): + part_type: A2APartType + text: Optional[str] = None + raw: Optional[str] = None + url: Optional[str] = None + data: Optional[Any] = None + mediaType: Optional[str] = None + filename: Optional[str] = None + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class A2AMessageSchema(BaseModel): + messageId: str + contextId: Optional[str] = None + taskId: Optional[str] = None + role: A2AMessageRole + parts: List[A2APartSchema] = Field(default_factory=list) + extensions: Optional[Dict[str, Any]] = Field(default_factory=dict) + referenceTaskIds: Optional[List[str]] = Field(default_factory=list) + createdAt: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class A2AMessageCreateSchema(BaseModel): + messageId: str + contextId: Optional[str] = None + taskId: Optional[str] = None + role: A2AMessageRole + parts: List[A2APartCreateSchema] = Field(default_factory=list) + extensions: Optional[Dict[str, Any]] = Field(default_factory=dict) + referenceTaskIds: Optional[List[str]] = Field(default_factory=list) + + +class A2AArtifactSchema(BaseModel): + artifactId: str + name: Optional[str] = None + description: Optional[str] = None + parts: List[A2APartSchema] = Field(default_factory=list) + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + extensions: Optional[Dict[str, Any]] = Field(default_factory=dict) + createdAt: Optional[datetime] = None + updatedAt: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class A2AArtifactCreateSchema(BaseModel): + artifactId: str + name: Optional[str] = None + description: Optional[str] = None + parts: List[A2APartCreateSchema] = Field(default_factory=list) + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + extensions: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class A2ATaskStatusSchema(BaseModel): + state: A2ATaskState + timestamp: datetime + + +class A2ATaskSchema(BaseModel): + id: str + contextId: Optional[str] = None + projectId: int + tenantId: int + source: str + remoteAgentId: Optional[int] = None + idempotencyKey: Optional[str] = None + state: A2ATaskState + inputText: str + outputText: Optional[str] = None + errorMessage: Optional[str] = None + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + historyLength: int = 0 + createdAt: datetime + updatedAt: datetime + finishedAt: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class A2ATaskWithMessagesSchema(A2ATaskSchema): + messages: List[A2AMessageSchema] = Field(default_factory=list) + artifacts: List[A2AArtifactSchema] = Field(default_factory=list) + + +class A2ATaskWithHistorySchema(BaseModel): + id: str + contextId: Optional[str] = None + projectId: int + tenantId: int + state: A2ATaskState + history: List[A2AMessageSchema] = Field(default_factory=list) + artifacts: List[A2AArtifactSchema] = Field(default_factory=list) + createdAt: datetime + updatedAt: datetime + finishedAt: Optional[datetime] = None + + model_config = ConfigDict(from_attributes=True) + + +class TaskStatusUpdateEvent(BaseModel): + taskId: str + contextId: Optional[str] = None + status: A2ATaskStatusSchema + metadata: Optional[Dict[str, Any]] = Field(default_factory=dict) + + +class TaskArtifactUpdateEvent(BaseModel): + taskId: str + contextId: Optional[str] = None + artifact: A2AArtifactSchema + append: bool = False + lastChunk: bool = True + + +class TaskMessageEvent(BaseModel): + message: A2AMessageSchema + + +class StreamResponseTask(BaseModel): + id: str + contextId: Optional[str] = None + state: A2ATaskState + artifacts: List[A2AArtifactSchema] = Field(default_factory=list) + + +class StreamResponse(BaseModel): + task: Optional[StreamResponseTask] = None + message: Optional[A2AMessageSchema] = None + statusUpdate: Optional[TaskStatusUpdateEvent] = None + artifactUpdate: Optional[TaskArtifactUpdateEvent] = None + + +class SendMessageRequest(BaseModel): + message: A2AMessageCreateSchema + taskId: Optional[str] = None + contextId: Optional[str] = None + + +class SendStreamingMessageRequest(BaseModel): + message: A2AMessageCreateSchema + taskId: Optional[str] = None + contextId: Optional[str] = None + + +class GetTaskRequest(BaseModel): + historyLength: Optional[int] = None + + +class TaskListRequest(BaseModel): + contextId: Optional[str] = None + status: Optional[A2ATaskState] = None + pageSize: int = 20 + pageToken: Optional[str] = None + + +class CancelTaskRequest(BaseModel): + pass + + +class PushNotificationConfigCreate(BaseModel): + targetUrl: str + secret: Optional[str] = None + authHeader: Optional[str] = None + enabled: bool = True + + +class PushNotificationConfig(BaseModel): + id: int + taskId: str + targetUrl: str + secret: Optional[str] = None + authHeader: Optional[str] = None + enabled: bool + createdBy: int + createdAt: datetime + + model_config = ConfigDict(from_attributes=True) + + +class VersionNotSupportedError(BaseModel): + code: int = -32009 + message: str = "Version not supported" + data: Optional[Dict[str, Any]] = None + + +class AgentSkillInputMode(str, Enum): + TEXT = "text" + DATA = "data" + RAW = "raw" + URL = "url" + + +class AgentSkillOutputMode(str, Enum): + TEXT = "text" + DATA = "data" + ARTIFACT = "artifact" + STREAM = "stream" + + +class AgentSkillSecurityRequirement(BaseModel): + scheme: str + scopes: Optional[List[str]] = None + + +class AgentSkillExample(BaseModel): + input: Dict[str, Any] + output: Dict[str, Any] + + +class AgentSkill(BaseModel): + id: str + name: str + description: Optional[str] = None + tags: List[str] = Field(default_factory=list) + examples: List[AgentSkillExample] = Field(default_factory=list) + inputModes: List[AgentSkillInputMode] = Field(default_factory=list) + outputModes: List[AgentSkillOutputMode] = Field(default_factory=list) + securityRequirements: List[AgentSkillSecurityRequirement] = Field(default_factory=list) + + +class AgentProvider(BaseModel): + organization: str + url: Optional[str] = None + + +class AgentSupportedInterface(BaseModel): + url: str + protocolBinding: str + protocolVersion: str + tenant: Optional[str] = None + + +class SecuritySchemeApiKey(BaseModel): + type: Literal["apiKey"] = "apiKey" + name: str + in_: str = Field(alias="in") + description: Optional[str] = None + + model_config = ConfigDict(populate_by_name=True) + + +class SecuritySchemeHttpAuth(BaseModel): + type: Literal["http"] = "http" + scheme: str + description: Optional[str] = None + + +class OAuth2AuthorizationCodeFlow(BaseModel): + authorizationUrl: str + tokenUrl: str + scopes: Dict[str, str] = Field(default_factory=dict) + refreshUrl: Optional[str] = None + + +class OAuth2ClientCredentialsFlow(BaseModel): + tokenUrl: str + scopes: Dict[str, str] = Field(default_factory=dict) + refreshUrl: Optional[str] = None + + +class OAuth2DeviceCodeFlow(BaseModel): + authorizationUrl: str + tokenUrl: str + scopes: Dict[str, str] = Field(default_factory=dict) + deviceAuthorizationUrl: Optional[str] = None + + +class OAuth2Flows(BaseModel): + authorizationCode: Optional[OAuth2AuthorizationCodeFlow] = None + clientCredentials: Optional[OAuth2ClientCredentialsFlow] = None + deviceCode: Optional[OAuth2DeviceCodeFlow] = None + implicit: Optional[Dict[str, Any]] = None + password: Optional[Dict[str, Any]] = None + + +class SecuritySchemeOAuth2(BaseModel): + type: Literal["oauth2"] = "oauth2" + flows: OAuth2Flows + description: Optional[str] = None + + +class SecuritySchemeOpenIdConnect(BaseModel): + type: Literal["openIdConnect"] = "openIdConnect" + openIdConnectUrl: str + description: Optional[str] = None + scopes: Dict[str, str] = Field(default_factory=dict) + + +class SecuritySchemeMtls(BaseModel): + type: Literal["mutualTLS"] = "mutualTLS" + description: Optional[str] = None + caCerts: Optional[List[str]] = None + clientCert: Optional[str] = None + clientKey: Optional[str] = None + + +class AgentCardPublicSchema(BaseModel): + name: str + protocol_version: str = "1.0" + capabilities: List[str] + endpoints: Dict[str, str] + auth: List[str] + skills: List[AgentSkill] = Field(default_factory=list) + provider: Optional[AgentProvider] = None + supportedInterfaces: List[AgentSupportedInterface] = Field(default_factory=list) + defaultInputModes: List[str] = Field(default_factory=list) + defaultOutputModes: List[str] = Field(default_factory=list) + iconUrl: Optional[str] = None + documentationUrl: Optional[str] = None + + +class AgentCardExtendedSchema(AgentCardPublicSchema): + securitySchemes: Optional[Dict[str, Union[SecuritySchemeApiKey, SecuritySchemeHttpAuth, SecuritySchemeOAuth2, SecuritySchemeOpenIdConnect, SecuritySchemeMtls]]] = None + security: List[Dict[str, List[str]]] = Field(default_factory=list) + signatures: List[str] = Field(default_factory=list) + tenantId: Optional[int] = None + isAdmin: Optional[bool] = None diff --git a/backend/app/services/a2a_service.py b/backend/app/services/a2a_service.py index d5ab069..b27251b 100644 --- a/backend/app/services/a2a_service.py +++ b/backend/app/services/a2a_service.py @@ -4,6 +4,7 @@ import asyncio import hashlib import hmac import json +import ssl import time import uuid from collections import defaultdict, deque @@ -24,20 +25,15 @@ from app.models.a2a import ( A2AWebhookDelivery, ) from app.models.project import Project +from app.schemas.a2a import ( + A2AArtifactSchema, + A2APartSchema, + A2ATaskStatusSchema, + TaskArtifactUpdateEvent, + TaskStatusUpdateEvent, +) from app.trace import build_error_attributes, trace_service -_STATE_TRANSITIONS = { - "SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"}, - "WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"}, - "INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"}, - "AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"}, - "REJECTED": set(), - "FAILED": set(), - "COMPLETED": set(), - "CANCELED": set(), -} -_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"} - def _json_loads(raw: Optional[str], default: Any) -> Any: if not raw: @@ -62,6 +58,292 @@ def _mask_error(message: str) -> str: return "request_failed" +class SharedSecretAuth: + @staticmethod + def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]: + if timestamp is None: + timestamp = int(time.time()) + message = f"{timestamp}".encode() + payload + signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest() + return f"sha256={signature}", timestamp + + @staticmethod + def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool: + if abs(time.time() - timestamp) > max_age_seconds: + return False + expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp) + return hmac.compare_digest(signature, expected_sig) + + @staticmethod + def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]: + timestamp = int(time.time()) + payload = body or b"" + message = f"{timestamp}.{method.upper()}.{path}".encode() + payload + signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest() + return { + "X-A2A-Signature": f"sha256={signature}", + "X-A2A-Timestamp": str(timestamp), + } + + +class MtlsConfig: + def __init__( + self, + ca_cert: Optional[str] = None, + client_cert: Optional[str] = None, + client_key: Optional[str] = None, + ): + self.ca_cert = ca_cert + self.client_cert = client_cert + self.client_key = client_key + + def create_ssl_context(self) -> Optional[ssl.SSLContext]: + if not self.client_cert or not self.client_key: + return None + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ctx.load_cert_chain(self.client_cert, self.client_key) + if self.ca_cert: + ctx.load_verify_locations(self.ca_cert) + ctx.verify_mode = ssl.CERT_REQUIRED + else: + ctx.verify_mode = ssl.CERT_NONE + return ctx + + +class OAuth2TokenStore: + def __init__(self): + self._tokens: Dict[str, Tuple[str, datetime]] = {} + self._lock = asyncio.Lock() + + async def get_token(self, key: str) -> Optional[str]: + async with self._lock: + if key in self._tokens: + token, expires_at = self._tokens[key] + if expires_at > _utc_now() + timedelta(minutes=1): + return token + return None + + async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None: + async with self._lock: + self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in)) + + +class OAuth2Auth: + def __init__( + self, + client_id: str, + client_secret: str, + token_url: str, + scopes: Optional[List[str]] = None, + ): + self.client_id = client_id + self.client_secret = client_secret + self.token_url = token_url + self.scopes = scopes or [] + self._token_store = OAuth2TokenStore() + + def _get_cache_key(self) -> str: + return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}" + + async def get_access_token(self, grant_type: str = "client_credentials") -> str: + cache_key = self._get_cache_key() + cached = await self._token_store.get_token(cache_key) + if cached: + return cached + + async with httpx.AsyncClient() as client: + data = { + "client_id": self.client_id, + "client_secret": self.client_secret, + "grant_type": grant_type, + } + if self.scopes: + data["scope"] = " ".join(self.scopes) + resp = await client.post(self.token_url, data=data) + resp.raise_for_status() + token_data = resp.json() + token = token_data["access_token"] + expires_in = token_data.get("expires_in", 3600) + await self._token_store.set_token(cache_key, token, expires_in) + return token + + async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]: + token = await self.get_access_token() + return {"Authorization": f"Bearer {token}"} + + +class OIDCAuth: + def __init__( + self, + issuer_url: str, + client_id: str, + client_secret: Optional[str] = None, + scopes: Optional[List[str]] = None, + ): + self.issuer_url = issuer_url.rstrip("/") + self.client_id = client_id + self.client_secret = client_secret + self.scopes = scopes or ["openid", "profile"] + self._oauth2: Optional[OAuth2Auth] = None + self._discovery_cache: Optional[Dict[str, Any]] = None + + async def _get_discovery(self) -> Dict[str, Any]: + if self._discovery_cache: + return self._discovery_cache + discovery_url = f"{self.issuer_url}/.well-known/openid-configuration" + async with httpx.AsyncClient() as client: + resp = await client.get(discovery_url) + resp.raise_for_status() + self._discovery_cache = resp.json() + return self._discovery_cache + + async def get_access_token(self) -> str: + discovery = await self._get_discovery() + token_url = discovery.get("token_endpoint") + if not token_url: + raise RuntimeError("OIDC discovery missing token_endpoint") + if not self._oauth2: + self._oauth2 = OAuth2Auth( + client_id=self.client_id, + client_secret=self.client_secret or "", + token_url=token_url, + scopes=self.scopes, + ) + return await self._oauth2.get_access_token() + + async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]: + token = await self.get_access_token() + return {"Authorization": f"Bearer {token}"} + + +class RemoteAgentSecuritySelector: + def __init__(self, agent: A2ARemoteAgent): + self.agent = agent + self._card_security_schemes: Optional[Dict[str, Any]] = None + + def load_security_from_card(self) -> None: + card = _json_loads(self.agent.card_json, {}) + if card: + self._card_security_schemes = card.get("securitySchemes", {}) + + def get_preferred_auth_scheme(self) -> str: + card = _json_loads(self.agent.card_json, {}) + security_reqs = card.get("security", []) + if security_reqs: + first_req = security_reqs[0] + if isinstance(first_req, dict): + for scheme_name in first_req.keys(): + return scheme_name + return self.agent.auth_scheme or "bearer" + + def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]: + headers: Dict[str, str] = {} + preferred = self.get_preferred_auth_scheme() + + if preferred == "bearer" or self.agent.auth_scheme == "bearer": + if self.agent.auth_token: + headers["Authorization"] = f"Bearer {self.agent.auth_token}" + elif user_token: + headers["Authorization"] = f"Bearer {user_token}" + + elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret": + pass + + elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"): + pass + + elif preferred == "openIdConnect": + pass + + elif preferred == "mutualTLS": + pass + + return headers + + def get_mtls_context(self) -> Optional[ssl.SSLContext]: + if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS": + if self.agent.mtls_client_cert and self.agent.mtls_client_key: + config = MtlsConfig( + ca_cert=self.agent.mtls_ca_cert, + client_cert=self.agent.mtls_client_cert, + client_key=self.agent.mtls_client_key, + ) + return config.create_ssl_context() + return None + + def create_signed_request_headers( + self, + method: str, + path: str, + body: Optional[bytes] = None, + ) -> Dict[str, str]: + headers: Dict[str, str] = {} + preferred = self.get_preferred_auth_scheme() + + if preferred == "shared_secret" and self.agent.shared_secret: + sig_headers = SharedSecretAuth.sign_request( + self.agent.shared_secret, + method, + path, + body, + ) + headers.update(sig_headers) + + elif self.agent.auth_scheme == "bearer" and self.agent.auth_token: + headers["Authorization"] = f"Bearer {self.agent.auth_token}" + + return headers + + async def get_oauth2_auth(self) -> Optional[OAuth2Auth]: + if self.agent.oauth2_client_id and self.agent.oauth2_token_url: + scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [] + return OAuth2Auth( + client_id=self.agent.oauth2_client_id, + client_secret=self.agent.oauth2_client_secret or "", + token_url=self.agent.oauth2_token_url, + scopes=scopes, + ) + return None + + async def get_oidc_auth(self) -> Optional[OIDCAuth]: + if self.agent.oidc_issuer_url: + return OIDCAuth( + issuer_url=self.agent.oidc_issuer_url, + client_id=self.agent.oidc_client_id or "", + client_secret=self.agent.oidc_client_secret, + scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [], + ) + return None + + async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]: + headers = self.get_auth_headers(user_token) + preferred = self.get_preferred_auth_scheme() + + if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"): + oauth2_auth = await self.get_oauth2_auth() + if oauth2_auth: + headers.update(await oauth2_auth.authorize_request(method, url)) + + elif preferred == "openIdConnect": + oidc_auth = await self.get_oidc_auth() + if oidc_auth: + headers.update(await oidc_auth.authorize_request(method, url)) + + return headers + +_STATE_TRANSITIONS = { + "SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"}, + "WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"}, + "INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"}, + "AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"}, + "REJECTED": set(), + "FAILED": set(), + "COMPLETED": set(), + "CANCELED": set(), +} +_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"} + + @dataclass class A2AResolvedRoute: selected: str @@ -185,6 +467,7 @@ class A2ARuntime: remote_agent_id: Optional[int], compatibility_mode: bool, metadata: Optional[Dict[str, Any]] = None, + context_id: Optional[str] = None, ) -> A2ATask: if idempotency_key: existing = ( @@ -209,6 +492,7 @@ class A2ARuntime: idempotency_key=idempotency_key, compatibility_mode=compatibility_mode, metadata_json=_json_dumps(metadata or {}), + context_id=context_id, ) db.add(task) db.commit() @@ -335,23 +619,21 @@ class A2ARuntime: async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None: event_payload = _json_loads(event.payload_json, {}) - request_payload = { - "task_id": event.task_id, - "event_type": event.event_type, - "event_id": event.id, - "payload": event_payload, - } - body = _json_dumps(request_payload).encode("utf-8") + stream_response_payload = self._build_stream_response_payload(event, event_payload) + body = _json_dumps(stream_response_payload).encode("utf-8") + for attempt in range(1, 5): delivery.attempt = attempt db.add(delivery) db.commit() + headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)} if hook.secret: digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest() headers["X-A2A-Signature"] = f"sha256={digest}" if hook.auth_header: headers["Authorization"] = hook.auth_header + try: async with httpx.AsyncClient(timeout=8.0, verify=True) as client: resp = await client.post(hook.target_url, content=body, headers=headers) @@ -368,11 +650,12 @@ class A2ARuntime: except Exception as exc: delivery.error_message = str(exc)[:500] if attempt < 4: + backoff_seconds = 2 ** attempt delivery.status = "RETRYING" - delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt) + delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds) db.add(delivery) db.commit() - await asyncio.sleep(2 ** attempt) + await asyncio.sleep(backoff_seconds) continue delivery.status = "FAILED" delivery.dead_letter = True @@ -380,5 +663,42 @@ class A2ARuntime: db.commit() return + def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]: + event_type = event.event_type + task_id = event_payload.get("task_id", event.task_id) + + if event_type == "TaskStatusUpdateEvent": + status_state = event_payload.get("task_status", "WORKING") + status_timestamp = event_payload.get("timestamp", _utc_now().isoformat()) + status_schema = A2ATaskStatusSchema( + state=status_state, + timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp, + ) + return { + "statusUpdate": TaskStatusUpdateEvent( + taskId=task_id, + contextId=event_payload.get("context_id"), + status=status_schema, + metadata=event_payload.get("metadata", {}), + ).model_dump() + } + elif event_type == "TaskArtifactUpdateEvent": + artifact_content = event_payload.get("artifact", {}).get("content", "") + artifact_schema = A2AArtifactSchema( + artifactId=f"artifact-{event.id}", + parts=[A2APartSchema(part_type="text", text=artifact_content)], + ) + return { + "artifactUpdate": TaskArtifactUpdateEvent( + taskId=task_id, + contextId=event_payload.get("context_id"), + artifact=artifact_schema, + append=False, + lastChunk=True, + ).model_dump() + } + else: + return {"message": event_payload} + a2a_runtime = A2ARuntime() diff --git a/backend/tests/test_a2a_backend.py b/backend/tests/test_a2a_backend.py index c2e3240..78d7f70 100644 --- a/backend/tests/test_a2a_backend.py +++ b/backend/tests/test_a2a_backend.py @@ -1,7 +1,13 @@ import asyncio +import hashlib +import hmac +import json import sys +import time from collections.abc import Generator +from datetime import datetime from pathlib import Path +from typing import Any, Dict, Optional import httpx import pytest @@ -20,10 +26,11 @@ if str(NANOBOT_ROOT) not in sys.path: from app.core.security import CurrentUser, get_current_user from app.database import Base, get_db -from app.models.a2a import A2ARemoteAgent +from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState from app.models.project import Project from app.models.user import User -from app.services.a2a_service import a2a_runtime +from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode +from app.services.a2a_service import a2a_runtime, SharedSecretAuth from main import app @@ -42,278 +49,824 @@ def _seed(db: Session) -> tuple[int, str, int, str, int]: return owner.id, owner.username, other.id, other.username, project.id -def test_a2a_send_list_cancel_and_rollout() -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, owner_username, _, _, project_id = _seed(db) - db.close() - - state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - - def override_get_db() -> Generator[Session, None, None]: - override_db = testing_session_local() - try: - yield override_db - finally: - override_db.close() - - def override_current_user() -> CurrentUser: - return state["user"] - - app.dependency_overrides[get_db] = override_get_db - app.dependency_overrides[get_current_user] = override_current_user - try: - client = TestClient(app) - send_resp = client.post( - "/api/v1/a2a/messages/send", - json={ - "project_id": project_id, - "message": "hello a2a", - "session_id": "test-a2a-session", - "route_mode": "local_first", - }, - ) - assert send_resp.status_code == 200 - task_id = send_resp.json()["task"]["id"] - - get_resp = client.get(f"/api/v1/a2a/tasks/{task_id}") - assert get_resp.status_code == 200 - assert get_resp.json()["project_id"] == project_id - - list_resp = client.get("/api/v1/a2a/tasks", params={"project_id": project_id}) - assert list_resp.status_code == 200 - assert any(item["id"] == task_id for item in list_resp.json()) - - cancel_resp = client.post(f"/api/v1/a2a/tasks/{task_id}/cancel") - assert cancel_resp.status_code == 200 - assert cancel_resp.json()["state"] in {"CANCELED", "COMPLETED", "FAILED"} - - rollout_resp = client.put( - f"/api/v1/a2a/projects/{project_id}/rollout", - json={"canary_enabled": True, "canary_percent": 30, "route_mode_default": "a2a_first", "fallback_chain": ["a2a", "local"]}, - ) - assert rollout_resp.status_code == 200 - assert rollout_resp.json()["canary_enabled"] is True - assert rollout_resp.json()["canary_percent"] == 30 - finally: - app.dependency_overrides.clear() - Base.metadata.drop_all(bind=engine) - engine.dispose() - - -def test_a2a_task_tenant_isolation() -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, owner_username, other_id, other_username, project_id = _seed(db) - db.close() - - state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - - def override_get_db() -> Generator[Session, None, None]: - override_db = testing_session_local() - try: - yield override_db - finally: - override_db.close() - - def override_current_user() -> CurrentUser: - return state["user"] - - app.dependency_overrides[get_db] = override_get_db - app.dependency_overrides[get_current_user] = override_current_user - try: - client = TestClient(app) - send_resp = client.post( - "/api/v1/a2a/messages/send", - json={"project_id": project_id, "message": "tenant isolation", "session_id": "tenant-isolation", "route_mode": "local"}, - ) - assert send_resp.status_code == 200 - task_id = send_resp.json()["task"]["id"] - - state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False) - forbidden_resp = client.get(f"/api/v1/a2a/tasks/{task_id}") - assert forbidden_resp.status_code == 404 - finally: - app.dependency_overrides.clear() - Base.metadata.drop_all(bind=engine) - engine.dispose() - - -def test_a2a_metrics_admin_only() -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, owner_username, _, _, _ = _seed(db) - db.close() - - state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - - def override_get_db() -> Generator[Session, None, None]: - override_db = testing_session_local() - try: - yield override_db - finally: - override_db.close() - - def override_current_user() -> CurrentUser: - return state["user"] - - app.dependency_overrides[get_db] = override_get_db - app.dependency_overrides[get_current_user] = override_current_user - try: - client = TestClient(app) - denied = client.get("/api/v1/a2a/metrics") - assert denied.status_code == 403 - state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True) - ok = client.get("/api/v1/a2a/metrics") - assert ok.status_code == 200 - assert "counters" in ok.json() - finally: - app.dependency_overrides.clear() - Base.metadata.drop_all(bind=engine) - engine.dispose() - - -def test_a2a_send_idempotency_key_deduplicates_task() -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, owner_username, _, _, project_id = _seed(db) - db.close() - - state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - - def override_get_db() -> Generator[Session, None, None]: - override_db = testing_session_local() - try: - yield override_db - finally: - override_db.close() - - def override_current_user() -> CurrentUser: - return state["user"] - - app.dependency_overrides[get_db] = override_get_db - app.dependency_overrides[get_current_user] = override_current_user - try: - client = TestClient(app) - payload = { - "project_id": project_id, - "message": "dedupe-task", - "session_id": "idempotency-session", - "route_mode": "local_first", - "idempotency_key": "same-key-1", +def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "message": { + "messageId": f"msg-{int(time.time()*1000)}", + "role": "user", + "parts": [ + { + "part_type": "data", + "data": { + "project_id": project_id, + "route_mode": route_mode, + "session_id": session_id, + **( {"idempotency_key": idempotency_key} if idempotency_key else {} ) + }, + "mediaType": "application/json", + }, + { + "part_type": "text", + "text": text, + } + ], } - first_resp = client.post("/api/v1/a2a/messages/send", json=payload) - second_resp = client.post("/api/v1/a2a/messages/send", json=payload) - assert first_resp.status_code == 200 - assert second_resp.status_code == 200 - assert first_resp.json()["task"]["id"] == second_resp.json()["task"]["id"] - finally: - app.dependency_overrides.clear() + } + return payload + + +class TestPartSerialization: + def test_part_text_serialization(self): + from app.schemas.a2a import A2APartCreateSchema + part = A2APartCreateSchema( + part_type=A2APartType.TEXT, + text="Hello world", + mediaType="text/plain", + ) + data = part.model_dump() + assert data["part_type"] == "text" + assert data["text"] == "Hello world" + assert data["mediaType"] == "text/plain" + + def test_part_data_serialization(self): + from app.schemas.a2a import A2APartCreateSchema + part = A2APartCreateSchema( + part_type=A2APartType.DATA, + data={"project_id": 123, "route_mode": "local"}, + mediaType="application/json", + ) + data = part.model_dump() + assert data["part_type"] == "data" + assert data["data"]["project_id"] == 123 + + def test_part_url_serialization(self): + from app.schemas.a2a import A2APartCreateSchema + part = A2APartCreateSchema( + part_type=A2APartType.URL, + url="https://example.com/file.pdf", + mediaType="application/pdf", + filename="file.pdf", + ) + data = part.model_dump() + assert data["part_type"] == "url" + assert data["url"] == "https://example.com/file.pdf" + assert data["filename"] == "file.pdf" + + def test_part_raw_serialization(self): + from app.schemas.a2a import A2APartCreateSchema + part = A2APartCreateSchema( + part_type=A2APartType.RAW, + raw="\x00\x01\x02\x03", + mediaType="application/octet-stream", + ) + data = part.model_dump() + assert data["part_type"] == "raw" + assert data["raw"] == "\x00\x01\x02\x03" + + +class TestStateMachine: + def test_state_transitions_submit_to_complete(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "test state machine")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + assert task.state == A2ATaskState.SUBMITTED + + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING) + assert task.state == A2ATaskState.WORKING + + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.COMPLETED) + assert task.state == A2ATaskState.COMPLETED + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_state_cancel(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={}) + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["state"] == "CANCELED" + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_state_failed(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "fail test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message='{"message": "test error"}') + assert task.state == A2ATaskState.FAILED + assert task.error_message is not None + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_state_rejected(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "reject test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.REJECTED) + assert task.state == A2ATaskState.REJECTED + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_state_input_required(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "input required test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.INPUT_REQUIRED) + assert task.state == A2ATaskState.INPUT_REQUIRED + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_state_auth_required(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "auth required test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + task = db.query(A2ATask).filter(A2ATask.id == task_id).first() + task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.AUTH_REQUIRED) + assert task.state == A2ATaskState.AUTH_REQUIRED + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +class TestA2APathNormalization: + def test_message_send_path(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "path test")) + assert resp.status_code == 200 + assert "task" in resp.json() + assert "id" in resp.json()["task"] + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_message_stream_path(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:stream", json=_make_message_payload(project_id, "stream path test")) + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("text/event-stream") + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_tasks_cancel_path(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel path test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={}) + assert cancel_resp.status_code == 200 + assert cancel_resp.json()["task_id"] == task_id + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_agent_card_public_path(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + db.close() + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + app.dependency_overrides[get_db] = override_get_db + try: + client = TestClient(app) + resp = client.get("/api/v1/.well-known/agent-card.json") + assert resp.status_code == 200 + data = resp.json() + assert "name" in data + assert "protocol_version" in data + assert "endpoints" in data + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +class TestVersionNegotiation: + def test_version_not_supported_error(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post( + "/api/v1/message:send", + json=_make_message_payload(project_id, "version test"), + headers={"A2A-Version": "2.0"} + ) + assert resp.status_code == 400 + detail = json.loads(resp.json()["detail"]) + assert detail["code"] == -32009 + assert "not supported" in detail["message"].lower() + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_version_response_header(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post( + "/api/v1/message:send", + json=_make_message_payload(project_id, "version header test"), + headers={"A2A-Version": "1.0"} + ) + assert resp.status_code == 200 + assert resp.headers.get("A2A-Version") == "1.0" + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +class TestWebhookStreamResponse: + def test_webhook_payload_format(self): + from app.schemas.a2a import StreamResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent, TaskMessageEvent, A2ATaskStatusSchema, A2ATaskState, A2AArtifactSchema + + status_event = TaskStatusUpdateEvent( + taskId="task-123", + contextId="ctx-456", + status=A2ATaskStatusSchema( + state=A2ATaskState.SUBMITTED, + timestamp=datetime.utcnow(), + ), + metadata={}, + ) + status_dump = status_event.model_dump() + assert "taskId" in status_dump + assert status_dump["taskId"] == "task-123" + assert status_dump["status"]["state"] == "SUBMITTED" + + artifact_event = TaskArtifactUpdateEvent( + taskId="task-123", + contextId="ctx-456", + artifact=A2AArtifactSchema( + artifactId="art-789", + parts=[], + ), + append=False, + lastChunk=True, + ) + artifact_dump = artifact_event.model_dump() + assert "taskId" in artifact_dump + assert artifact_dump["artifact"]["artifactId"] == "art-789" + + def test_stream_response_task_field(self): + from app.schemas.a2a import StreamResponse, StreamResponseTask, A2ATaskState + resp = StreamResponse( + task=StreamResponseTask( + id="task-123", + contextId="ctx-456", + state=A2ATaskState.WORKING, + artifacts=[], + ) + ) + data = resp.model_dump() + assert "task" in data + assert data["task"]["id"] == "task-123" + + +class TestSSEFIFO: + def test_sse_event_order(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + with client.stream("POST", "/api/v1/message:stream", json=_make_message_payload(project_id, "fifo test")) as resp: + assert resp.status_code == 200 + chunks = [] + for line in resp.iter_lines(): + if line.startswith("data: "): + chunks.append(json.loads(line[6:])) + + event_types = [c.get("type") for c in chunks if "type" in c] + status_idx = next((i for i, t in enumerate(event_types) if t == "TaskStatusUpdateEvent"), -1) + assert status_idx >= 0 + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + +class TestAuthSchemes: + def test_shared_secret_auth(self): + secret = "test-secret-key-12345" + timestamp = int(time.time()) + body = b'{"test":"data"}' + sig, _ = SharedSecretAuth.generate_signature(secret, body, timestamp) + assert sig.startswith("sha256=") + + assert SharedSecretAuth.verify_signature(secret, body, sig, timestamp) is True + + def test_auth_scheme_none(self): + from app.schemas.a2a import SecuritySchemeHttpAuth + scheme = SecuritySchemeHttpAuth(scheme="bearer", description="Bearer auth") + assert scheme.scheme == "bearer" + + +class TestExceptionPaths: + def test_auth_failure_marks_agent_unhealthy(self, monkeypatch): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, _, _, _, project_id = _seed(db) + agent = A2ARemoteAgent( + project_id=project_id, + name="auth-fail-agent", + base_url="https://remote.example.com", + auth_scheme="bearer", + auth_token="bad-token", + created_by=owner_id, + ) + db.add(agent) + db.commit() + db.refresh(agent) + a2a_runtime._circuit_state.pop(agent.id, None) + + class _FailResp: + status_code = 401 + + @staticmethod + def json(): + return {"detail": "unauthorized"} + + class _Client401: + def __init__(self, *args, **kwargs): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def get(self, url, headers=None): + return _FailResp() + + monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401) + + with pytest.raises(RuntimeError): + asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01)) + db.refresh(agent) + assert agent.healthy is False + assert agent.failure_count == 1 + Base.metadata.drop_all(bind=engine) + db.close() engine.dispose() + def test_remote_unavailable_opens_circuit(self, monkeypatch): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, _, _, _, project_id = _seed(db) + agent = A2ARemoteAgent( + project_id=project_id, + name="offline-agent", + base_url="https://offline.example.com", + auth_scheme="none", + created_by=owner_id, + ) + db.add(agent) + db.commit() + db.refresh(agent) + a2a_runtime._circuit_state.pop(agent.id, None) -def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, _, _, _, project_id = _seed(db) - agent = A2ARemoteAgent( - project_id=project_id, - name="auth-fail-agent", - base_url="https://remote.example.com", - auth_scheme="bearer", - auth_token="bad-token", - created_by=owner_id, - ) - db.add(agent) - db.commit() - db.refresh(agent) - a2a_runtime._circuit_state.pop(agent.id, None) + class _ClientDown: + def __init__(self, *args, **kwargs): + pass - class _FailResp: - status_code = 401 + async def __aenter__(self): + return self - @staticmethod - def json(): - return {"detail": "unauthorized"} + async def __aexit__(self, exc_type, exc, tb): + return False - class _Client401: - def __init__(self, *args, **kwargs): - pass + async def get(self, url, headers=None): + raise httpx.ConnectError("network down") - async def __aenter__(self): - return self + monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown) - async def __aexit__(self, exc_type, exc, tb): - return False + for _ in range(3): + with pytest.raises(Exception): + asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01)) + db.refresh(agent) + assert agent.healthy is False + assert agent.failure_count == 3 + assert agent.circuit_open_until is not None - async def get(self, url, headers=None): - return _FailResp() + Base.metadata.drop_all(bind=engine) + db.close() + engine.dispose() - monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _Client401) + def test_idempotency_key_deduplicates_task(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, project_id = _seed(db) + db.close() - with pytest.raises(RuntimeError): - asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01)) - db.refresh(agent) - assert agent.healthy is False - assert agent.failure_count == 1 + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - Base.metadata.drop_all(bind=engine) - db.close() - engine.dispose() + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + idempotency_key = f"idem-key-{int(time.time())}" + + payload1 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key) + resp1 = client.post("/api/v1/message:send", json=payload1) + assert resp1.status_code == 200 + + payload2 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key) + payload2["message"]["messageId"] = f"msg-{int(time.time()*1000) + 1}" + resp2 = client.post("/api/v1/message:send", json=payload2) + + assert resp2.status_code == 200 + assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"] + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() + + def test_tenant_isolation(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, other_id, other_username, project_id = _seed(db) + db.close() + + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} + + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() + + def override_current_user() -> CurrentUser: + return state["user"] + + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "isolation test")) + assert resp.status_code == 200 + task_id = resp.json()["task"]["id"] + + state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False) + get_resp = client.get(f"/api/v1/tasks/{task_id}") + assert get_resp.status_code == 404 + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() -def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> None: - engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) - testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) - Base.metadata.create_all(bind=engine) - db = testing_session_local() - owner_id, _, _, _, project_id = _seed(db) - agent = A2ARemoteAgent( - project_id=project_id, - name="offline-agent", - base_url="https://offline.example.com", - auth_scheme="none", - created_by=owner_id, - ) - db.add(agent) - db.commit() - db.refresh(agent) - a2a_runtime._circuit_state.pop(agent.id, None) +class TestMetricsAdminOnly: + def test_metrics_admin_only(self): + engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool) + testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + Base.metadata.create_all(bind=engine) + db = testing_session_local() + owner_id, owner_username, _, _, _ = _seed(db) + db.close() - class _ClientDown: - def __init__(self, *args, **kwargs): - pass + state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)} - async def __aenter__(self): - return self + def override_get_db() -> Generator[Session, None, None]: + override_db = testing_session_local() + try: + yield override_db + finally: + override_db.close() - async def __aexit__(self, exc_type, exc, tb): - return False + def override_current_user() -> CurrentUser: + return state["user"] - async def get(self, url, headers=None): - raise httpx.ConnectError("network down") + app.dependency_overrides[get_db] = override_get_db + app.dependency_overrides[get_current_user] = override_current_user + try: + client = TestClient(app) + denied = client.get("/api/v1/a2a/metrics") + assert denied.status_code == 403 - monkeypatch.setattr("app.services.a2a_service.httpx.AsyncClient", _ClientDown) - - for _ in range(3): - with pytest.raises(Exception): - asyncio.run(a2a_runtime.fetch_agent_card(db, agent, timeout_s=0.01)) - db.refresh(agent) - assert agent.healthy is False - assert agent.failure_count == 3 - assert agent.circuit_open_until is not None - - Base.metadata.drop_all(bind=engine) - db.close() - engine.dispose() + state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True) + ok = client.get("/api/v1/a2a/metrics") + assert ok.status_code == 200 + assert "counters" in ok.json() + finally: + app.dependency_overrides.clear() + Base.metadata.drop_all(bind=engine) + engine.dispose() diff --git a/frontend/src/api/a2a.ts b/frontend/src/api/a2a.ts index aca8b37..8723bf2 100644 --- a/frontend/src/api/a2a.ts +++ b/frontend/src/api/a2a.ts @@ -1,5 +1,100 @@ import { api } from "@/lib/api"; +export interface A2APartText { + kind: "text"; + text: string; +} + +export interface A2APartUrl { + kind: "url"; + url: string; +} + +export interface A2APartFile { + kind: "file"; + data: string; + mediaType?: string; + filename?: string; +} + +export type A2APart = A2APartText | A2APartUrl | A2APartFile; + +export interface A2AMessage { + messageId?: string; + contextId?: string; + taskId?: string; + role: "user" | "agent" | "system"; + parts: A2APart[]; + extensions?: Record[]; + referenceTaskIds?: string[]; +} + +export interface A2AArtifact { + artifactId?: string; + name?: string; + description?: string; + parts: A2APart[]; + metadata?: Record; + extensions?: Record[]; +} + +export interface A2ATask { + id: string; + project_id?: number; + context_id?: string; + source: string; + state: string; + remote_agent_id?: number | null; + input_text: string; + input_parts?: A2APart[]; + output_text?: string | null; + output_parts?: A2APart[]; + error_message?: string | null; + compatibility_mode: boolean; + metadata: Record; + artifacts?: A2AArtifact[]; + history?: A2AMessage[]; + history_length?: number; + created_at: string; + updated_at: string; + finished_at?: string | null; +} + +export interface A2AAgentCard { + id?: string; + name: string; + description?: string; + url?: string; + provider?: { + organization?: string; + url?: string; + }; + skills?: Array<{ + id?: string; + name: string; + description?: string; + tags?: string[]; + examples?: string[]; + inputModes?: string[]; + outputModes?: string[]; + securityRequirements?: Array>; + }>; + supportedInterfaces?: Array<{ + type: string; + url?: string; + protocolBinding?: string; + protocolVersion?: string; + tenant?: string; + }>; + defaultInputModes?: string[]; + defaultOutputModes?: string[]; + securitySchemes?: Record; + security?: Array>; + signatures?: string[]; + iconUrl?: string; + documentationUrl?: string; +} + export interface A2ARemoteAgent { id: number; project_id: number; @@ -12,27 +107,12 @@ export interface A2ARemoteAgent { failure_count: number; circuit_open_until?: string | null; card_fetched_at?: string | null; -} - -export interface A2ATask { - id: string; - project_id: number; - source: string; - state: string; - remote_agent_id?: number | null; - input_text: string; - output_text?: string | null; - error_message?: string | null; - compatibility_mode: boolean; - metadata: Record; - created_at: string; - updated_at: string; - finished_at?: string | null; + agent_card?: A2AAgentCard; } export interface A2ASendMessagePayload { project_id: number; - message: string; + message: A2AMessage; session_id?: string; remote_agent_id?: number; route_mode?: "auto" | "local" | "a2a" | "a2a_first" | "local_first" | "mcp_first"; @@ -55,11 +135,13 @@ export interface A2ASubscribeEvent { type?: string; event?: string; task_id?: string; + context_id?: string; task_status?: string; status?: string; - artifact?: { - content?: string; - }; + artifact?: A2AArtifact; + append?: boolean; + last_chunk?: boolean; + message?: A2AMessage; output?: string; source?: string; timestamp?: string; @@ -122,54 +204,202 @@ export const a2aApi = { healthCheckRemoteAgent(agentId: number) { return api.post<{ healthy: boolean; failure_count: number }>(`/api/v1/a2a/remote-agents/${agentId}/health-check`, {}); }, - listTasks(projectId: number, state?: string) { + listTasks(projectId: number, state?: string, contextId?: string) { const params = new URLSearchParams({ project_id: String(projectId), limit: "100" }); if (state && state !== "all") { params.set("state", state); } + if (contextId) { + params.set("context_id", contextId); + } return api.get(`/api/v1/a2a/tasks?${params.toString()}`); }, - getTask(taskId: string) { - return api.get(`/api/v1/a2a/tasks/${taskId}`); + getTask(taskId: string, historyLength?: number) { + const params = new URLSearchParams(); + if (historyLength !== undefined) { + params.set("historyLength", String(historyLength)); + } + const queryString = params.toString(); + return api.get(`/api/v1/a2a/tasks/${taskId}${queryString ? `?${queryString}` : ""}`); }, cancelTask(taskId: string) { - return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}/cancel`, {}); + return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}:cancel`, {}); }, sendMessage(payload: A2ASendMessagePayload) { - return api.post("/api/v1/a2a/messages/send", payload); + return api.post("/api/v1/a2a/message:send", payload); }, - async subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): Promise { - const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, { - method: "GET", - headers: { - ...getAuthHeaders(), - }, - signal, - }); - if (!response.ok || !response.body) { - throw new Error(`Subscribe failed: ${response.status}`); - } - const reader = response.body.getReader(); - const decoder = new TextDecoder("utf-8"); - let buffer = ""; - while (true) { - const { done, value } = await reader.read(); - if (done) break; - buffer += decoder.decode(value, { stream: true }); - const splitIndex = buffer.lastIndexOf("\n\n"); - if (splitIndex === -1) continue; - const complete = buffer.slice(0, splitIndex); - buffer = buffer.slice(splitIndex + 2); - const events = parseSseEvents(complete); - for (const event of events) { - onEvent(event); + streamMessage(payload: A2ASendMessagePayload) { + return api.post("/api/v1/a2a/message:stream", payload); + }, + subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void { + const controller = new AbortController(); + void (async () => { + const response = await fetch(`/api/v1/a2a/tasks/${taskId}:subscribe`, { + method: "GET", + headers: { + ...getAuthHeaders(), + }, + signal: signal || controller.signal, + }); + if (!response.ok || !response.body) { + throw new Error(`Subscribe failed: ${response.status}`); } - } - if (buffer.trim()) { - const events = parseSseEvents(buffer); - for (const event of events) { - onEvent(event); + const reader = response.body.getReader(); + const decoder = new TextDecoder("utf-8"); + let buffer = ""; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const splitIndex = buffer.lastIndexOf("\n\n"); + if (splitIndex === -1) continue; + const complete = buffer.slice(0, splitIndex); + buffer = buffer.slice(splitIndex + 2); + const events = parseSseEvents(complete); + for (const event of events) { + onEvent(event); + } } - } + if (buffer.trim()) { + const events = parseSseEvents(buffer); + for (const event of events) { + onEvent(event); + } + } + })(); + return () => controller.abort(); + }, + subscribeTaskSSE(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void { + const controller = new AbortController(); + void (async () => { + const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, { + method: "GET", + headers: { + ...getAuthHeaders(), + }, + signal: signal || controller.signal, + }); + if (!response.ok || !response.body) { + throw new Error(`Subscribe failed: ${response.status}`); + } + const reader = response.body.getReader(); + const decoder = new TextDecoder("utf-8"); + let buffer = ""; + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buffer += decoder.decode(value, { stream: true }); + const splitIndex = buffer.lastIndexOf("\n\n"); + if (splitIndex === -1) continue; + const complete = buffer.slice(0, splitIndex); + buffer = buffer.slice(splitIndex + 2); + const events = parseSseEvents(complete); + for (const event of events) { + onEvent(event); + } + } + if (buffer.trim()) { + const events = parseSseEvents(buffer); + for (const event of events) { + onEvent(event); + } + } + })(); + return () => controller.abort(); }, }; + +export function renderPart(part: A2APart): string { + switch (part.kind) { + case "text": + return part.text; + case "url": + return `[URL: ${part.url}]`; + case "file": + if (part.mediaType?.startsWith("image/")) { + return `[Image: ${part.filename || "image"}]`; + } + if (part.mediaType?.includes("json")) { + try { + const decoded = atob(part.data); + return `[JSON File: ${part.filename || "data.json"}]\n${decoded}`; + } catch { + return `[Binary File: ${part.filename || "data"}]`; + } + } + return `[File: ${part.filename || "file"}]`; + default: + return "[Unknown Part]"; + } +} + +export function renderParts(parts: A2APart[]): string { + return parts.map(renderPart).join("\n"); +} + +export function extractTextFromParts(parts: A2APart[]): string { + return parts + .filter((p): p is A2APartText => p.kind === "text") + .map((p) => p.text) + .join("\n"); +} + +export function getArtifactPreview(artifact: A2AArtifact): { type: "text" | "image" | "html" | "json" | "unknown"; content: string } { + if (!artifact.parts || artifact.parts.length === 0) { + return { type: "unknown", content: "" }; + } + + const firstPart = artifact.parts[0]; + + if (firstPart.kind === "text") { + return { type: "text", content: firstPart.text }; + } + + if (firstPart.kind === "url") { + const url = firstPart.url.toLowerCase(); + if (url.endsWith(".png") || url.endsWith(".jpg") || url.endsWith(".jpeg") || url.endsWith(".gif") || url.endsWith(".webp")) { + return { type: "image", content: firstPart.url }; + } + if (url.endsWith(".html") || url.endsWith(".htm")) { + return { type: "html", content: firstPart.url }; + } + return { type: "unknown", content: firstPart.url }; + } + + if (firstPart.kind === "file") { + const mediaType = firstPart.mediaType || ""; + if (mediaType.startsWith("image/")) { + return { type: "image", content: `data:${mediaType};base64,${firstPart.data}` }; + } + if (mediaType.includes("html")) { + try { + const decoded = atob(firstPart.data); + return { type: "html", content: decoded }; + } catch { + return { type: "unknown", content: "[HTML content]" }; + } + } + if (mediaType.includes("json")) { + try { + const decoded = atob(firstPart.data); + return { type: "json", content: decoded }; + } catch { + return { type: "unknown", content: "[JSON content]" }; + } + } + return { type: "unknown", content: `[File: ${firstPart.filename || "file"}]` }; + } + + return { type: "unknown", content: "" }; +} + +export function groupTasksByContextId(tasks: A2ATask[]): Map { + const grouped = new Map(); + for (const task of tasks) { + const contextId = task.context_id || "no-context"; + const existing = grouped.get(contextId) || []; + existing.push(task); + grouped.set(contextId, existing); + } + return grouped; +} \ No newline at end of file diff --git a/frontend/src/pages/Skills.tsx b/frontend/src/pages/Skills.tsx index 13cfcc4..096d1ec 100644 --- a/frontend/src/pages/Skills.tsx +++ b/frontend/src/pages/Skills.tsx @@ -9,7 +9,7 @@ import { Textarea } from "@/components/ui/textarea"; import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select"; import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table"; import { api } from "@/lib/api"; -import { a2aApi, type A2ARemoteAgent, type A2ATask } from "@/api/a2a"; +import { a2aApi, type A2ARemoteAgent, type A2ATask, type A2AArtifact, renderPart, renderParts, getArtifactPreview, groupTasksByContextId } from "@/api/a2a"; import { useProjectStore } from "@/store/projectStore"; import { useMcpHealthStore } from "@/store/mcpHealthStore"; import { useRef } from 'react'; @@ -118,6 +118,11 @@ export function Skills() { auth_token: '', }); const [isA2aRefreshingHealth, setIsA2aRefreshingHealth] = useState(false); + const [selectedA2aAgent, setSelectedA2aAgent] = useState(null); + const [selectedTask, setSelectedTask] = useState(null); + const [taskArtifactPreview, setTaskArtifactPreview] = useState<{ type: string; content: string } | null>(null); + const [contextIdFilter, setContextIdFilter] = useState('all'); + const [groupedByContextId, setGroupedByContextId] = useState>(new Map()); const { currentProject } = useProjectStore(); const { hasMcpError, refresh: refreshMcpHealth } = useMcpHealthStore(); @@ -245,6 +250,7 @@ export function Skills() { ]); setA2aAgents(agents || []); setA2aTasks(tasks || []); + setGroupedByContextId(groupTasksByContextId(tasks || [])); } catch (error) { console.error("Failed to fetch A2A data", error); } finally { @@ -252,6 +258,25 @@ export function Skills() { } }; + const handlePreviewArtifact = (artifact: A2AArtifact) => { + const preview = getArtifactPreview(artifact); + setTaskArtifactPreview(preview); + }; + + const handleTaskClick = (task: A2ATask) => { + setSelectedTask(task); + if (task.artifacts && task.artifacts.length > 0) { + const preview = getArtifactPreview(task.artifacts[0]); + setTaskArtifactPreview(preview); + } else { + setTaskArtifactPreview(null); + } + }; + + const handleAgentCardClick = (agent: A2ARemoteAgent) => { + setSelectedA2aAgent(agent); + }; + const handleRefreshA2aHealth = async () => { if (!currentProject || a2aAgents.length === 0) return; setIsA2aRefreshingHealth(true); @@ -602,6 +627,20 @@ export function Skills() { {t('refresh')} +