refactor: a2a

This commit is contained in:
qixinbo
2026-04-04 07:24:09 +08:00
parent 02677763cb
commit bdc61fe651
7 changed files with 3345 additions and 496 deletions
+1113 -151
View File
File diff suppressed because it is too large Load Diff
+95 -2
View File
@@ -1,9 +1,34 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Enum as SQLEnum, func
from sqlalchemy.orm import relationship
import enum
from app.database import Base
class A2ATaskState(str, enum.Enum):
SUBMITTED = "SUBMITTED"
WORKING = "WORKING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
INPUT_REQUIRED = "INPUT_REQUIRED"
AUTH_REQUIRED = "AUTH_REQUIRED"
REJECTED = "REJECTED"
class A2APartType(str, enum.Enum):
TEXT = "text"
RAW = "raw"
URL = "url"
DATA = "data"
class A2AMessageRole(str, enum.Enum):
USER = "user"
AGENT = "agent"
SYSTEM = "system"
class A2ARemoteAgent(Base):
__tablename__ = "a2a_remote_agents"
@@ -13,6 +38,17 @@ class A2ARemoteAgent(Base):
base_url = Column(String, nullable=False)
auth_scheme = Column(String, nullable=False, default="none")
auth_token = Column(String, nullable=True)
shared_secret = Column(String, nullable=True)
mtls_ca_cert = Column(Text, nullable=True)
mtls_client_cert = Column(Text, nullable=True)
mtls_client_key = Column(Text, nullable=True)
oauth2_client_id = Column(String, nullable=True)
oauth2_client_secret = Column(String, nullable=True)
oauth2_token_url = Column(String, nullable=True)
oauth2_scopes = Column(String, nullable=True)
oidc_issuer_url = Column(String, nullable=True)
oidc_client_id = Column(String, nullable=True)
oidc_client_secret = Column(String, nullable=True)
protocol_version = Column(String, nullable=True)
capabilities_json = Column(Text, nullable=False, default="[]")
card_json = Column(Text, nullable=True)
@@ -27,27 +63,84 @@ class A2ARemoteAgent(Base):
project = relationship("Project")
class A2APart(Base):
__tablename__ = "a2a_parts"
id = Column(Integer, primary_key=True, index=True)
message_id = Column(Integer, ForeignKey("a2a_messages.id", ondelete="CASCADE"), nullable=True, index=True)
artifact_id = Column(Integer, ForeignKey("a2a_artifacts.id", ondelete="CASCADE"), nullable=True, index=True)
part_type = Column(SQLEnum(A2APartType), nullable=False)
text_content = Column(Text, nullable=True)
raw_content = Column(Text, nullable=True)
url_content = Column(String, nullable=True)
data_content = Column(Text, nullable=True)
media_type = Column(String, nullable=True)
filename = Column(String, nullable=True)
metadata_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now())
message = relationship("A2AMessage", back_populates="parts", foreign_keys=[message_id])
artifact = relationship("A2AArtifact", back_populates="parts", foreign_keys=[artifact_id])
class A2AMessage(Base):
__tablename__ = "a2a_messages"
id = Column(Integer, primary_key=True, index=True)
message_id = Column(String, nullable=False, unique=True, index=True)
context_id = Column(String, nullable=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=True, index=True)
role = Column(SQLEnum(A2AMessageRole), nullable=False)
extensions_json = Column(Text, nullable=False, default="{}")
reference_task_ids_json = Column(Text, nullable=False, default="[]")
created_at = Column(DateTime, default=func.now(), index=True)
task = relationship("A2ATask", back_populates="messages", foreign_keys=[task_id])
parts = relationship("A2APart", back_populates="message", cascade="all, delete-orphan")
class A2AArtifact(Base):
__tablename__ = "a2a_artifacts"
id = Column(Integer, primary_key=True, index=True)
artifact_id = Column(String, nullable=False, unique=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
name = Column(String, nullable=True)
description = Column(Text, nullable=True)
metadata_json = Column(Text, nullable=False, default="{}")
extensions_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
task = relationship("A2ATask", back_populates="artifacts")
parts = relationship("A2APart", back_populates="artifact", cascade="all, delete-orphan")
class A2ATask(Base):
__tablename__ = "a2a_tasks"
id = Column(String, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
context_id = Column(String, nullable=True, index=True)
source = Column(String, nullable=False, default="local")
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
idempotency_key = Column(String, nullable=True, index=True)
state = Column(String, nullable=False, index=True, default="SUBMITTED")
state = Column(SQLEnum(A2ATaskState), nullable=False, index=True, default=A2ATaskState.SUBMITTED)
input_text = Column(Text, nullable=False, default="")
output_text = Column(Text, nullable=True)
error_message = Column(Text, nullable=True)
compatibility_mode = Column(Boolean, nullable=False, default=True)
metadata_json = Column(Text, nullable=False, default="{}")
history_length = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
finished_at = Column(DateTime, nullable=True)
project = relationship("Project")
remote_agent = relationship("A2ARemoteAgent")
messages = relationship("A2AMessage", back_populates="task", cascade="all, delete-orphan", foreign_keys=[A2AMessage.task_id])
artifacts = relationship("A2AArtifact", back_populates="task", cascade="all, delete-orphan")
class A2ATaskEvent(Base):
+361
View File
@@ -0,0 +1,361 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List, Dict, Any, Literal, Union
from datetime import datetime
from enum import Enum
class A2ATaskState(str, Enum):
SUBMITTED = "SUBMITTED"
WORKING = "WORKING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
INPUT_REQUIRED = "INPUT_REQUIRED"
AUTH_REQUIRED = "AUTH_REQUIRED"
REJECTED = "REJECTED"
class A2APartType(str, Enum):
TEXT = "text"
RAW = "raw"
URL = "url"
DATA = "data"
class A2AMessageRole(str, Enum):
USER = "user"
AGENT = "agent"
SYSTEM = "system"
class A2APartSchema(BaseModel):
part_type: A2APartType
text: Optional[str] = None
raw: Optional[bytes] = None
url: Optional[str] = None
data: Optional[Any] = None
mediaType: Optional[str] = None
filename: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2APartCreateSchema(BaseModel):
part_type: A2APartType
text: Optional[str] = None
raw: Optional[str] = None
url: Optional[str] = None
data: Optional[Any] = None
mediaType: Optional[str] = None
filename: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2AMessageSchema(BaseModel):
messageId: str
contextId: Optional[str] = None
taskId: Optional[str] = None
role: A2AMessageRole
parts: List[A2APartSchema] = Field(default_factory=list)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
createdAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2AMessageCreateSchema(BaseModel):
messageId: str
contextId: Optional[str] = None
taskId: Optional[str] = None
role: A2AMessageRole
parts: List[A2APartCreateSchema] = Field(default_factory=list)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
class A2AArtifactSchema(BaseModel):
artifactId: str
name: Optional[str] = None
description: Optional[str] = None
parts: List[A2APartSchema] = Field(default_factory=list)
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
createdAt: Optional[datetime] = None
updatedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2AArtifactCreateSchema(BaseModel):
artifactId: str
name: Optional[str] = None
description: Optional[str] = None
parts: List[A2APartCreateSchema] = Field(default_factory=list)
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2ATaskStatusSchema(BaseModel):
state: A2ATaskState
timestamp: datetime
class A2ATaskSchema(BaseModel):
id: str
contextId: Optional[str] = None
projectId: int
tenantId: int
source: str
remoteAgentId: Optional[int] = None
idempotencyKey: Optional[str] = None
state: A2ATaskState
inputText: str
outputText: Optional[str] = None
errorMessage: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
historyLength: int = 0
createdAt: datetime
updatedAt: datetime
finishedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2ATaskWithMessagesSchema(A2ATaskSchema):
messages: List[A2AMessageSchema] = Field(default_factory=list)
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
class A2ATaskWithHistorySchema(BaseModel):
id: str
contextId: Optional[str] = None
projectId: int
tenantId: int
state: A2ATaskState
history: List[A2AMessageSchema] = Field(default_factory=list)
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
createdAt: datetime
updatedAt: datetime
finishedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class TaskStatusUpdateEvent(BaseModel):
taskId: str
contextId: Optional[str] = None
status: A2ATaskStatusSchema
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class TaskArtifactUpdateEvent(BaseModel):
taskId: str
contextId: Optional[str] = None
artifact: A2AArtifactSchema
append: bool = False
lastChunk: bool = True
class TaskMessageEvent(BaseModel):
message: A2AMessageSchema
class StreamResponseTask(BaseModel):
id: str
contextId: Optional[str] = None
state: A2ATaskState
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
class StreamResponse(BaseModel):
task: Optional[StreamResponseTask] = None
message: Optional[A2AMessageSchema] = None
statusUpdate: Optional[TaskStatusUpdateEvent] = None
artifactUpdate: Optional[TaskArtifactUpdateEvent] = None
class SendMessageRequest(BaseModel):
message: A2AMessageCreateSchema
taskId: Optional[str] = None
contextId: Optional[str] = None
class SendStreamingMessageRequest(BaseModel):
message: A2AMessageCreateSchema
taskId: Optional[str] = None
contextId: Optional[str] = None
class GetTaskRequest(BaseModel):
historyLength: Optional[int] = None
class TaskListRequest(BaseModel):
contextId: Optional[str] = None
status: Optional[A2ATaskState] = None
pageSize: int = 20
pageToken: Optional[str] = None
class CancelTaskRequest(BaseModel):
pass
class PushNotificationConfigCreate(BaseModel):
targetUrl: str
secret: Optional[str] = None
authHeader: Optional[str] = None
enabled: bool = True
class PushNotificationConfig(BaseModel):
id: int
taskId: str
targetUrl: str
secret: Optional[str] = None
authHeader: Optional[str] = None
enabled: bool
createdBy: int
createdAt: datetime
model_config = ConfigDict(from_attributes=True)
class VersionNotSupportedError(BaseModel):
code: int = -32009
message: str = "Version not supported"
data: Optional[Dict[str, Any]] = None
class AgentSkillInputMode(str, Enum):
TEXT = "text"
DATA = "data"
RAW = "raw"
URL = "url"
class AgentSkillOutputMode(str, Enum):
TEXT = "text"
DATA = "data"
ARTIFACT = "artifact"
STREAM = "stream"
class AgentSkillSecurityRequirement(BaseModel):
scheme: str
scopes: Optional[List[str]] = None
class AgentSkillExample(BaseModel):
input: Dict[str, Any]
output: Dict[str, Any]
class AgentSkill(BaseModel):
id: str
name: str
description: Optional[str] = None
tags: List[str] = Field(default_factory=list)
examples: List[AgentSkillExample] = Field(default_factory=list)
inputModes: List[AgentSkillInputMode] = Field(default_factory=list)
outputModes: List[AgentSkillOutputMode] = Field(default_factory=list)
securityRequirements: List[AgentSkillSecurityRequirement] = Field(default_factory=list)
class AgentProvider(BaseModel):
organization: str
url: Optional[str] = None
class AgentSupportedInterface(BaseModel):
url: str
protocolBinding: str
protocolVersion: str
tenant: Optional[str] = None
class SecuritySchemeApiKey(BaseModel):
type: Literal["apiKey"] = "apiKey"
name: str
in_: str = Field(alias="in")
description: Optional[str] = None
model_config = ConfigDict(populate_by_name=True)
class SecuritySchemeHttpAuth(BaseModel):
type: Literal["http"] = "http"
scheme: str
description: Optional[str] = None
class OAuth2AuthorizationCodeFlow(BaseModel):
authorizationUrl: str
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
refreshUrl: Optional[str] = None
class OAuth2ClientCredentialsFlow(BaseModel):
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
refreshUrl: Optional[str] = None
class OAuth2DeviceCodeFlow(BaseModel):
authorizationUrl: str
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
deviceAuthorizationUrl: Optional[str] = None
class OAuth2Flows(BaseModel):
authorizationCode: Optional[OAuth2AuthorizationCodeFlow] = None
clientCredentials: Optional[OAuth2ClientCredentialsFlow] = None
deviceCode: Optional[OAuth2DeviceCodeFlow] = None
implicit: Optional[Dict[str, Any]] = None
password: Optional[Dict[str, Any]] = None
class SecuritySchemeOAuth2(BaseModel):
type: Literal["oauth2"] = "oauth2"
flows: OAuth2Flows
description: Optional[str] = None
class SecuritySchemeOpenIdConnect(BaseModel):
type: Literal["openIdConnect"] = "openIdConnect"
openIdConnectUrl: str
description: Optional[str] = None
scopes: Dict[str, str] = Field(default_factory=dict)
class SecuritySchemeMtls(BaseModel):
type: Literal["mutualTLS"] = "mutualTLS"
description: Optional[str] = None
caCerts: Optional[List[str]] = None
clientCert: Optional[str] = None
clientKey: Optional[str] = None
class AgentCardPublicSchema(BaseModel):
name: str
protocol_version: str = "1.0"
capabilities: List[str]
endpoints: Dict[str, str]
auth: List[str]
skills: List[AgentSkill] = Field(default_factory=list)
provider: Optional[AgentProvider] = None
supportedInterfaces: List[AgentSupportedInterface] = Field(default_factory=list)
defaultInputModes: List[str] = Field(default_factory=list)
defaultOutputModes: List[str] = Field(default_factory=list)
iconUrl: Optional[str] = None
documentationUrl: Optional[str] = None
class AgentCardExtendedSchema(AgentCardPublicSchema):
securitySchemes: Optional[Dict[str, Union[SecuritySchemeApiKey, SecuritySchemeHttpAuth, SecuritySchemeOAuth2, SecuritySchemeOpenIdConnect, SecuritySchemeMtls]]] = None
security: List[Dict[str, List[str]]] = Field(default_factory=list)
signatures: List[str] = Field(default_factory=list)
tenantId: Optional[int] = None
isAdmin: Optional[bool] = None
+341 -21
View File
@@ -4,6 +4,7 @@ import asyncio
import hashlib
import hmac
import json
import ssl
import time
import uuid
from collections import defaultdict, deque
@@ -24,20 +25,15 @@ from app.models.a2a import (
A2AWebhookDelivery,
)
from app.models.project import Project
from app.schemas.a2a import (
A2AArtifactSchema,
A2APartSchema,
A2ATaskStatusSchema,
TaskArtifactUpdateEvent,
TaskStatusUpdateEvent,
)
from app.trace import build_error_attributes, trace_service
_STATE_TRANSITIONS = {
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
"REJECTED": set(),
"FAILED": set(),
"COMPLETED": set(),
"CANCELED": set(),
}
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
def _json_loads(raw: Optional[str], default: Any) -> Any:
if not raw:
@@ -62,6 +58,292 @@ def _mask_error(message: str) -> str:
return "request_failed"
class SharedSecretAuth:
@staticmethod
def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]:
if timestamp is None:
timestamp = int(time.time())
message = f"{timestamp}".encode() + payload
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
return f"sha256={signature}", timestamp
@staticmethod
def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool:
if abs(time.time() - timestamp) > max_age_seconds:
return False
expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp)
return hmac.compare_digest(signature, expected_sig)
@staticmethod
def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]:
timestamp = int(time.time())
payload = body or b""
message = f"{timestamp}.{method.upper()}.{path}".encode() + payload
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
return {
"X-A2A-Signature": f"sha256={signature}",
"X-A2A-Timestamp": str(timestamp),
}
class MtlsConfig:
def __init__(
self,
ca_cert: Optional[str] = None,
client_cert: Optional[str] = None,
client_key: Optional[str] = None,
):
self.ca_cert = ca_cert
self.client_cert = client_cert
self.client_key = client_key
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
if not self.client_cert or not self.client_key:
return None
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_cert_chain(self.client_cert, self.client_key)
if self.ca_cert:
ctx.load_verify_locations(self.ca_cert)
ctx.verify_mode = ssl.CERT_REQUIRED
else:
ctx.verify_mode = ssl.CERT_NONE
return ctx
class OAuth2TokenStore:
def __init__(self):
self._tokens: Dict[str, Tuple[str, datetime]] = {}
self._lock = asyncio.Lock()
async def get_token(self, key: str) -> Optional[str]:
async with self._lock:
if key in self._tokens:
token, expires_at = self._tokens[key]
if expires_at > _utc_now() + timedelta(minutes=1):
return token
return None
async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None:
async with self._lock:
self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in))
class OAuth2Auth:
def __init__(
self,
client_id: str,
client_secret: str,
token_url: str,
scopes: Optional[List[str]] = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.scopes = scopes or []
self._token_store = OAuth2TokenStore()
def _get_cache_key(self) -> str:
return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}"
async def get_access_token(self, grant_type: str = "client_credentials") -> str:
cache_key = self._get_cache_key()
cached = await self._token_store.get_token(cache_key)
if cached:
return cached
async with httpx.AsyncClient() as client:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": grant_type,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
resp = await client.post(self.token_url, data=data)
resp.raise_for_status()
token_data = resp.json()
token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
await self._token_store.set_token(cache_key, token, expires_in)
return token
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
token = await self.get_access_token()
return {"Authorization": f"Bearer {token}"}
class OIDCAuth:
def __init__(
self,
issuer_url: str,
client_id: str,
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
):
self.issuer_url = issuer_url.rstrip("/")
self.client_id = client_id
self.client_secret = client_secret
self.scopes = scopes or ["openid", "profile"]
self._oauth2: Optional[OAuth2Auth] = None
self._discovery_cache: Optional[Dict[str, Any]] = None
async def _get_discovery(self) -> Dict[str, Any]:
if self._discovery_cache:
return self._discovery_cache
discovery_url = f"{self.issuer_url}/.well-known/openid-configuration"
async with httpx.AsyncClient() as client:
resp = await client.get(discovery_url)
resp.raise_for_status()
self._discovery_cache = resp.json()
return self._discovery_cache
async def get_access_token(self) -> str:
discovery = await self._get_discovery()
token_url = discovery.get("token_endpoint")
if not token_url:
raise RuntimeError("OIDC discovery missing token_endpoint")
if not self._oauth2:
self._oauth2 = OAuth2Auth(
client_id=self.client_id,
client_secret=self.client_secret or "",
token_url=token_url,
scopes=self.scopes,
)
return await self._oauth2.get_access_token()
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
token = await self.get_access_token()
return {"Authorization": f"Bearer {token}"}
class RemoteAgentSecuritySelector:
def __init__(self, agent: A2ARemoteAgent):
self.agent = agent
self._card_security_schemes: Optional[Dict[str, Any]] = None
def load_security_from_card(self) -> None:
card = _json_loads(self.agent.card_json, {})
if card:
self._card_security_schemes = card.get("securitySchemes", {})
def get_preferred_auth_scheme(self) -> str:
card = _json_loads(self.agent.card_json, {})
security_reqs = card.get("security", [])
if security_reqs:
first_req = security_reqs[0]
if isinstance(first_req, dict):
for scheme_name in first_req.keys():
return scheme_name
return self.agent.auth_scheme or "bearer"
def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]:
headers: Dict[str, str] = {}
preferred = self.get_preferred_auth_scheme()
if preferred == "bearer" or self.agent.auth_scheme == "bearer":
if self.agent.auth_token:
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
elif user_token:
headers["Authorization"] = f"Bearer {user_token}"
elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret":
pass
elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
pass
elif preferred == "openIdConnect":
pass
elif preferred == "mutualTLS":
pass
return headers
def get_mtls_context(self) -> Optional[ssl.SSLContext]:
if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS":
if self.agent.mtls_client_cert and self.agent.mtls_client_key:
config = MtlsConfig(
ca_cert=self.agent.mtls_ca_cert,
client_cert=self.agent.mtls_client_cert,
client_key=self.agent.mtls_client_key,
)
return config.create_ssl_context()
return None
def create_signed_request_headers(
self,
method: str,
path: str,
body: Optional[bytes] = None,
) -> Dict[str, str]:
headers: Dict[str, str] = {}
preferred = self.get_preferred_auth_scheme()
if preferred == "shared_secret" and self.agent.shared_secret:
sig_headers = SharedSecretAuth.sign_request(
self.agent.shared_secret,
method,
path,
body,
)
headers.update(sig_headers)
elif self.agent.auth_scheme == "bearer" and self.agent.auth_token:
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
return headers
async def get_oauth2_auth(self) -> Optional[OAuth2Auth]:
if self.agent.oauth2_client_id and self.agent.oauth2_token_url:
scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else []
return OAuth2Auth(
client_id=self.agent.oauth2_client_id,
client_secret=self.agent.oauth2_client_secret or "",
token_url=self.agent.oauth2_token_url,
scopes=scopes,
)
return None
async def get_oidc_auth(self) -> Optional[OIDCAuth]:
if self.agent.oidc_issuer_url:
return OIDCAuth(
issuer_url=self.agent.oidc_issuer_url,
client_id=self.agent.oidc_client_id or "",
client_secret=self.agent.oidc_client_secret,
scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [],
)
return None
async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]:
headers = self.get_auth_headers(user_token)
preferred = self.get_preferred_auth_scheme()
if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
oauth2_auth = await self.get_oauth2_auth()
if oauth2_auth:
headers.update(await oauth2_auth.authorize_request(method, url))
elif preferred == "openIdConnect":
oidc_auth = await self.get_oidc_auth()
if oidc_auth:
headers.update(await oidc_auth.authorize_request(method, url))
return headers
_STATE_TRANSITIONS = {
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
"REJECTED": set(),
"FAILED": set(),
"COMPLETED": set(),
"CANCELED": set(),
}
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
@dataclass
class A2AResolvedRoute:
selected: str
@@ -185,6 +467,7 @@ class A2ARuntime:
remote_agent_id: Optional[int],
compatibility_mode: bool,
metadata: Optional[Dict[str, Any]] = None,
context_id: Optional[str] = None,
) -> A2ATask:
if idempotency_key:
existing = (
@@ -209,6 +492,7 @@ class A2ARuntime:
idempotency_key=idempotency_key,
compatibility_mode=compatibility_mode,
metadata_json=_json_dumps(metadata or {}),
context_id=context_id,
)
db.add(task)
db.commit()
@@ -335,23 +619,21 @@ class A2ARuntime:
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
event_payload = _json_loads(event.payload_json, {})
request_payload = {
"task_id": event.task_id,
"event_type": event.event_type,
"event_id": event.id,
"payload": event_payload,
}
body = _json_dumps(request_payload).encode("utf-8")
stream_response_payload = self._build_stream_response_payload(event, event_payload)
body = _json_dumps(stream_response_payload).encode("utf-8")
for attempt in range(1, 5):
delivery.attempt = attempt
db.add(delivery)
db.commit()
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
if hook.secret:
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
headers["X-A2A-Signature"] = f"sha256={digest}"
if hook.auth_header:
headers["Authorization"] = hook.auth_header
try:
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
resp = await client.post(hook.target_url, content=body, headers=headers)
@@ -368,11 +650,12 @@ class A2ARuntime:
except Exception as exc:
delivery.error_message = str(exc)[:500]
if attempt < 4:
backoff_seconds = 2 ** attempt
delivery.status = "RETRYING"
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds)
db.add(delivery)
db.commit()
await asyncio.sleep(2 ** attempt)
await asyncio.sleep(backoff_seconds)
continue
delivery.status = "FAILED"
delivery.dead_letter = True
@@ -380,5 +663,42 @@ class A2ARuntime:
db.commit()
return
def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]:
event_type = event.event_type
task_id = event_payload.get("task_id", event.task_id)
if event_type == "TaskStatusUpdateEvent":
status_state = event_payload.get("task_status", "WORKING")
status_timestamp = event_payload.get("timestamp", _utc_now().isoformat())
status_schema = A2ATaskStatusSchema(
state=status_state,
timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp,
)
return {
"statusUpdate": TaskStatusUpdateEvent(
taskId=task_id,
contextId=event_payload.get("context_id"),
status=status_schema,
metadata=event_payload.get("metadata", {}),
).model_dump()
}
elif event_type == "TaskArtifactUpdateEvent":
artifact_content = event_payload.get("artifact", {}).get("content", "")
artifact_schema = A2AArtifactSchema(
artifactId=f"artifact-{event.id}",
parts=[A2APartSchema(part_type="text", text=artifact_content)],
)
return {
"artifactUpdate": TaskArtifactUpdateEvent(
taskId=task_id,
contextId=event_payload.get("context_id"),
artifact=artifact_schema,
append=False,
lastChunk=True,
).model_dump()
}
else:
return {"message": event_payload}
a2a_runtime = A2ARuntime()
File diff suppressed because it is too large Load Diff