refactor: a2a
This commit is contained in:
+1113
-151
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,34 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Enum as SQLEnum, func
|
||||
from sqlalchemy.orm import relationship
|
||||
import enum
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class A2ATaskState(str, enum.Enum):
|
||||
SUBMITTED = "SUBMITTED"
|
||||
WORKING = "WORKING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
INPUT_REQUIRED = "INPUT_REQUIRED"
|
||||
AUTH_REQUIRED = "AUTH_REQUIRED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class A2APartType(str, enum.Enum):
|
||||
TEXT = "text"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
DATA = "data"
|
||||
|
||||
|
||||
class A2AMessageRole(str, enum.Enum):
|
||||
USER = "user"
|
||||
AGENT = "agent"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class A2ARemoteAgent(Base):
|
||||
__tablename__ = "a2a_remote_agents"
|
||||
|
||||
@@ -13,6 +38,17 @@ class A2ARemoteAgent(Base):
|
||||
base_url = Column(String, nullable=False)
|
||||
auth_scheme = Column(String, nullable=False, default="none")
|
||||
auth_token = Column(String, nullable=True)
|
||||
shared_secret = Column(String, nullable=True)
|
||||
mtls_ca_cert = Column(Text, nullable=True)
|
||||
mtls_client_cert = Column(Text, nullable=True)
|
||||
mtls_client_key = Column(Text, nullable=True)
|
||||
oauth2_client_id = Column(String, nullable=True)
|
||||
oauth2_client_secret = Column(String, nullable=True)
|
||||
oauth2_token_url = Column(String, nullable=True)
|
||||
oauth2_scopes = Column(String, nullable=True)
|
||||
oidc_issuer_url = Column(String, nullable=True)
|
||||
oidc_client_id = Column(String, nullable=True)
|
||||
oidc_client_secret = Column(String, nullable=True)
|
||||
protocol_version = Column(String, nullable=True)
|
||||
capabilities_json = Column(Text, nullable=False, default="[]")
|
||||
card_json = Column(Text, nullable=True)
|
||||
@@ -27,27 +63,84 @@ class A2ARemoteAgent(Base):
|
||||
project = relationship("Project")
|
||||
|
||||
|
||||
class A2APart(Base):
|
||||
__tablename__ = "a2a_parts"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(Integer, ForeignKey("a2a_messages.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
artifact_id = Column(Integer, ForeignKey("a2a_artifacts.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
part_type = Column(SQLEnum(A2APartType), nullable=False)
|
||||
text_content = Column(Text, nullable=True)
|
||||
raw_content = Column(Text, nullable=True)
|
||||
url_content = Column(String, nullable=True)
|
||||
data_content = Column(Text, nullable=True)
|
||||
media_type = Column(String, nullable=True)
|
||||
filename = Column(String, nullable=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
|
||||
message = relationship("A2AMessage", back_populates="parts", foreign_keys=[message_id])
|
||||
artifact = relationship("A2AArtifact", back_populates="parts", foreign_keys=[artifact_id])
|
||||
|
||||
|
||||
class A2AMessage(Base):
|
||||
__tablename__ = "a2a_messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(String, nullable=False, unique=True, index=True)
|
||||
context_id = Column(String, nullable=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
role = Column(SQLEnum(A2AMessageRole), nullable=False)
|
||||
extensions_json = Column(Text, nullable=False, default="{}")
|
||||
reference_task_ids_json = Column(Text, nullable=False, default="[]")
|
||||
created_at = Column(DateTime, default=func.now(), index=True)
|
||||
|
||||
task = relationship("A2ATask", back_populates="messages", foreign_keys=[task_id])
|
||||
parts = relationship("A2APart", back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2AArtifact(Base):
|
||||
__tablename__ = "a2a_artifacts"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
artifact_id = Column(String, nullable=False, unique=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
name = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
extensions_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
task = relationship("A2ATask", back_populates="artifacts")
|
||||
parts = relationship("A2APart", back_populates="artifact", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2ATask(Base):
|
||||
__tablename__ = "a2a_tasks"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
|
||||
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
context_id = Column(String, nullable=True, index=True)
|
||||
source = Column(String, nullable=False, default="local")
|
||||
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
|
||||
idempotency_key = Column(String, nullable=True, index=True)
|
||||
state = Column(String, nullable=False, index=True, default="SUBMITTED")
|
||||
state = Column(SQLEnum(A2ATaskState), nullable=False, index=True, default=A2ATaskState.SUBMITTED)
|
||||
input_text = Column(Text, nullable=False, default="")
|
||||
output_text = Column(Text, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
compatibility_mode = Column(Boolean, nullable=False, default=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
history_length = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
finished_at = Column(DateTime, nullable=True)
|
||||
|
||||
project = relationship("Project")
|
||||
remote_agent = relationship("A2ARemoteAgent")
|
||||
messages = relationship("A2AMessage", back_populates="task", cascade="all, delete-orphan", foreign_keys=[A2AMessage.task_id])
|
||||
artifacts = relationship("A2AArtifact", back_populates="task", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2ATaskEvent(Base):
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing import Optional, List, Dict, Any, Literal, Union
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class A2ATaskState(str, Enum):
|
||||
SUBMITTED = "SUBMITTED"
|
||||
WORKING = "WORKING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
INPUT_REQUIRED = "INPUT_REQUIRED"
|
||||
AUTH_REQUIRED = "AUTH_REQUIRED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class A2APartType(str, Enum):
|
||||
TEXT = "text"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
DATA = "data"
|
||||
|
||||
|
||||
class A2AMessageRole(str, Enum):
|
||||
USER = "user"
|
||||
AGENT = "agent"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class A2APartSchema(BaseModel):
|
||||
part_type: A2APartType
|
||||
text: Optional[str] = None
|
||||
raw: Optional[bytes] = None
|
||||
url: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
mediaType: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2APartCreateSchema(BaseModel):
|
||||
part_type: A2APartType
|
||||
text: Optional[str] = None
|
||||
raw: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
mediaType: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2AMessageSchema(BaseModel):
|
||||
messageId: str
|
||||
contextId: Optional[str] = None
|
||||
taskId: Optional[str] = None
|
||||
role: A2AMessageRole
|
||||
parts: List[A2APartSchema] = Field(default_factory=list)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
|
||||
createdAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2AMessageCreateSchema(BaseModel):
|
||||
messageId: str
|
||||
contextId: Optional[str] = None
|
||||
taskId: Optional[str] = None
|
||||
role: A2AMessageRole
|
||||
parts: List[A2APartCreateSchema] = Field(default_factory=list)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class A2AArtifactSchema(BaseModel):
|
||||
artifactId: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[A2APartSchema] = Field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
createdAt: Optional[datetime] = None
|
||||
updatedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2AArtifactCreateSchema(BaseModel):
|
||||
artifactId: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[A2APartCreateSchema] = Field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2ATaskStatusSchema(BaseModel):
|
||||
state: A2ATaskState
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class A2ATaskSchema(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
projectId: int
|
||||
tenantId: int
|
||||
source: str
|
||||
remoteAgentId: Optional[int] = None
|
||||
idempotencyKey: Optional[str] = None
|
||||
state: A2ATaskState
|
||||
inputText: str
|
||||
outputText: Optional[str] = None
|
||||
errorMessage: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
historyLength: int = 0
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
finishedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2ATaskWithMessagesSchema(A2ATaskSchema):
|
||||
messages: List[A2AMessageSchema] = Field(default_factory=list)
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
|
||||
|
||||
class A2ATaskWithHistorySchema(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
projectId: int
|
||||
tenantId: int
|
||||
state: A2ATaskState
|
||||
history: List[A2AMessageSchema] = Field(default_factory=list)
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
finishedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
taskId: str
|
||||
contextId: Optional[str] = None
|
||||
status: A2ATaskStatusSchema
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
taskId: str
|
||||
contextId: Optional[str] = None
|
||||
artifact: A2AArtifactSchema
|
||||
append: bool = False
|
||||
lastChunk: bool = True
|
||||
|
||||
|
||||
class TaskMessageEvent(BaseModel):
|
||||
message: A2AMessageSchema
|
||||
|
||||
|
||||
class StreamResponseTask(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
state: A2ATaskState
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
task: Optional[StreamResponseTask] = None
|
||||
message: Optional[A2AMessageSchema] = None
|
||||
statusUpdate: Optional[TaskStatusUpdateEvent] = None
|
||||
artifactUpdate: Optional[TaskArtifactUpdateEvent] = None
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
message: A2AMessageCreateSchema
|
||||
taskId: Optional[str] = None
|
||||
contextId: Optional[str] = None
|
||||
|
||||
|
||||
class SendStreamingMessageRequest(BaseModel):
|
||||
message: A2AMessageCreateSchema
|
||||
taskId: Optional[str] = None
|
||||
contextId: Optional[str] = None
|
||||
|
||||
|
||||
class GetTaskRequest(BaseModel):
|
||||
historyLength: Optional[int] = None
|
||||
|
||||
|
||||
class TaskListRequest(BaseModel):
|
||||
contextId: Optional[str] = None
|
||||
status: Optional[A2ATaskState] = None
|
||||
pageSize: int = 20
|
||||
pageToken: Optional[str] = None
|
||||
|
||||
|
||||
class CancelTaskRequest(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PushNotificationConfigCreate(BaseModel):
|
||||
targetUrl: str
|
||||
secret: Optional[str] = None
|
||||
authHeader: Optional[str] = None
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
id: int
|
||||
taskId: str
|
||||
targetUrl: str
|
||||
secret: Optional[str] = None
|
||||
authHeader: Optional[str] = None
|
||||
enabled: bool
|
||||
createdBy: int
|
||||
createdAt: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class VersionNotSupportedError(BaseModel):
|
||||
code: int = -32009
|
||||
message: str = "Version not supported"
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AgentSkillInputMode(str, Enum):
|
||||
TEXT = "text"
|
||||
DATA = "data"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
|
||||
|
||||
class AgentSkillOutputMode(str, Enum):
|
||||
TEXT = "text"
|
||||
DATA = "data"
|
||||
ARTIFACT = "artifact"
|
||||
STREAM = "stream"
|
||||
|
||||
|
||||
class AgentSkillSecurityRequirement(BaseModel):
|
||||
scheme: str
|
||||
scopes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentSkillExample(BaseModel):
|
||||
input: Dict[str, Any]
|
||||
output: Dict[str, Any]
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
examples: List[AgentSkillExample] = Field(default_factory=list)
|
||||
inputModes: List[AgentSkillInputMode] = Field(default_factory=list)
|
||||
outputModes: List[AgentSkillOutputMode] = Field(default_factory=list)
|
||||
securityRequirements: List[AgentSkillSecurityRequirement] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
organization: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class AgentSupportedInterface(BaseModel):
|
||||
url: str
|
||||
protocolBinding: str
|
||||
protocolVersion: str
|
||||
tenant: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeApiKey(BaseModel):
|
||||
type: Literal["apiKey"] = "apiKey"
|
||||
name: str
|
||||
in_: str = Field(alias="in")
|
||||
description: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class SecuritySchemeHttpAuth(BaseModel):
|
||||
type: Literal["http"] = "http"
|
||||
scheme: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2AuthorizationCodeFlow(BaseModel):
|
||||
authorizationUrl: str
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
refreshUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2ClientCredentialsFlow(BaseModel):
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
refreshUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2DeviceCodeFlow(BaseModel):
|
||||
authorizationUrl: str
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
deviceAuthorizationUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2Flows(BaseModel):
|
||||
authorizationCode: Optional[OAuth2AuthorizationCodeFlow] = None
|
||||
clientCredentials: Optional[OAuth2ClientCredentialsFlow] = None
|
||||
deviceCode: Optional[OAuth2DeviceCodeFlow] = None
|
||||
implicit: Optional[Dict[str, Any]] = None
|
||||
password: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SecuritySchemeOAuth2(BaseModel):
|
||||
type: Literal["oauth2"] = "oauth2"
|
||||
flows: OAuth2Flows
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeOpenIdConnect(BaseModel):
|
||||
type: Literal["openIdConnect"] = "openIdConnect"
|
||||
openIdConnectUrl: str
|
||||
description: Optional[str] = None
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SecuritySchemeMtls(BaseModel):
|
||||
type: Literal["mutualTLS"] = "mutualTLS"
|
||||
description: Optional[str] = None
|
||||
caCerts: Optional[List[str]] = None
|
||||
clientCert: Optional[str] = None
|
||||
clientKey: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCardPublicSchema(BaseModel):
|
||||
name: str
|
||||
protocol_version: str = "1.0"
|
||||
capabilities: List[str]
|
||||
endpoints: Dict[str, str]
|
||||
auth: List[str]
|
||||
skills: List[AgentSkill] = Field(default_factory=list)
|
||||
provider: Optional[AgentProvider] = None
|
||||
supportedInterfaces: List[AgentSupportedInterface] = Field(default_factory=list)
|
||||
defaultInputModes: List[str] = Field(default_factory=list)
|
||||
defaultOutputModes: List[str] = Field(default_factory=list)
|
||||
iconUrl: Optional[str] = None
|
||||
documentationUrl: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCardExtendedSchema(AgentCardPublicSchema):
|
||||
securitySchemes: Optional[Dict[str, Union[SecuritySchemeApiKey, SecuritySchemeHttpAuth, SecuritySchemeOAuth2, SecuritySchemeOpenIdConnect, SecuritySchemeMtls]]] = None
|
||||
security: List[Dict[str, List[str]]] = Field(default_factory=list)
|
||||
signatures: List[str] = Field(default_factory=list)
|
||||
tenantId: Optional[int] = None
|
||||
isAdmin: Optional[bool] = None
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
@@ -24,20 +25,15 @@ from app.models.a2a import (
|
||||
A2AWebhookDelivery,
|
||||
)
|
||||
from app.models.project import Project
|
||||
from app.schemas.a2a import (
|
||||
A2AArtifactSchema,
|
||||
A2APartSchema,
|
||||
A2ATaskStatusSchema,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from app.trace import build_error_attributes, trace_service
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
def _json_loads(raw: Optional[str], default: Any) -> Any:
|
||||
if not raw:
|
||||
@@ -62,6 +58,292 @@ def _mask_error(message: str) -> str:
|
||||
return "request_failed"
|
||||
|
||||
|
||||
class SharedSecretAuth:
|
||||
@staticmethod
|
||||
def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]:
|
||||
if timestamp is None:
|
||||
timestamp = int(time.time())
|
||||
message = f"{timestamp}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return f"sha256={signature}", timestamp
|
||||
|
||||
@staticmethod
|
||||
def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool:
|
||||
if abs(time.time() - timestamp) > max_age_seconds:
|
||||
return False
|
||||
expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp)
|
||||
return hmac.compare_digest(signature, expected_sig)
|
||||
|
||||
@staticmethod
|
||||
def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]:
|
||||
timestamp = int(time.time())
|
||||
payload = body or b""
|
||||
message = f"{timestamp}.{method.upper()}.{path}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return {
|
||||
"X-A2A-Signature": f"sha256={signature}",
|
||||
"X-A2A-Timestamp": str(timestamp),
|
||||
}
|
||||
|
||||
|
||||
class MtlsConfig:
|
||||
def __init__(
|
||||
self,
|
||||
ca_cert: Optional[str] = None,
|
||||
client_cert: Optional[str] = None,
|
||||
client_key: Optional[str] = None,
|
||||
):
|
||||
self.ca_cert = ca_cert
|
||||
self.client_cert = client_cert
|
||||
self.client_key = client_key
|
||||
|
||||
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||
if not self.client_cert or not self.client_key:
|
||||
return None
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(self.client_cert, self.client_key)
|
||||
if self.ca_cert:
|
||||
ctx.load_verify_locations(self.ca_cert)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
||||
|
||||
class OAuth2TokenStore:
|
||||
def __init__(self):
|
||||
self._tokens: Dict[str, Tuple[str, datetime]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_token(self, key: str) -> Optional[str]:
|
||||
async with self._lock:
|
||||
if key in self._tokens:
|
||||
token, expires_at = self._tokens[key]
|
||||
if expires_at > _utc_now() + timedelta(minutes=1):
|
||||
return token
|
||||
return None
|
||||
|
||||
async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None:
|
||||
async with self._lock:
|
||||
self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in))
|
||||
|
||||
|
||||
class OAuth2Auth:
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_url = token_url
|
||||
self.scopes = scopes or []
|
||||
self._token_store = OAuth2TokenStore()
|
||||
|
||||
def _get_cache_key(self) -> str:
|
||||
return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}"
|
||||
|
||||
async def get_access_token(self, grant_type: str = "client_credentials") -> str:
|
||||
cache_key = self._get_cache_key()
|
||||
cached = await self._token_store.get_token(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"grant_type": grant_type,
|
||||
}
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
resp = await client.post(self.token_url, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
await self._token_store.set_token(cache_key, token, expires_in)
|
||||
return token
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class OIDCAuth:
|
||||
def __init__(
|
||||
self,
|
||||
issuer_url: str,
|
||||
client_id: str,
|
||||
client_secret: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.issuer_url = issuer_url.rstrip("/")
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.scopes = scopes or ["openid", "profile"]
|
||||
self._oauth2: Optional[OAuth2Auth] = None
|
||||
self._discovery_cache: Optional[Dict[str, Any]] = None
|
||||
|
||||
async def _get_discovery(self) -> Dict[str, Any]:
|
||||
if self._discovery_cache:
|
||||
return self._discovery_cache
|
||||
discovery_url = f"{self.issuer_url}/.well-known/openid-configuration"
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(discovery_url)
|
||||
resp.raise_for_status()
|
||||
self._discovery_cache = resp.json()
|
||||
return self._discovery_cache
|
||||
|
||||
async def get_access_token(self) -> str:
|
||||
discovery = await self._get_discovery()
|
||||
token_url = discovery.get("token_endpoint")
|
||||
if not token_url:
|
||||
raise RuntimeError("OIDC discovery missing token_endpoint")
|
||||
if not self._oauth2:
|
||||
self._oauth2 = OAuth2Auth(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret or "",
|
||||
token_url=token_url,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
return await self._oauth2.get_access_token()
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class RemoteAgentSecuritySelector:
|
||||
def __init__(self, agent: A2ARemoteAgent):
|
||||
self.agent = agent
|
||||
self._card_security_schemes: Optional[Dict[str, Any]] = None
|
||||
|
||||
def load_security_from_card(self) -> None:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
if card:
|
||||
self._card_security_schemes = card.get("securitySchemes", {})
|
||||
|
||||
def get_preferred_auth_scheme(self) -> str:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
security_reqs = card.get("security", [])
|
||||
if security_reqs:
|
||||
first_req = security_reqs[0]
|
||||
if isinstance(first_req, dict):
|
||||
for scheme_name in first_req.keys():
|
||||
return scheme_name
|
||||
return self.agent.auth_scheme or "bearer"
|
||||
|
||||
def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "bearer" or self.agent.auth_scheme == "bearer":
|
||||
if self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
elif user_token:
|
||||
headers["Authorization"] = f"Bearer {user_token}"
|
||||
|
||||
elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret":
|
||||
pass
|
||||
|
||||
elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
pass
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
pass
|
||||
|
||||
elif preferred == "mutualTLS":
|
||||
pass
|
||||
|
||||
return headers
|
||||
|
||||
def get_mtls_context(self) -> Optional[ssl.SSLContext]:
|
||||
if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS":
|
||||
if self.agent.mtls_client_cert and self.agent.mtls_client_key:
|
||||
config = MtlsConfig(
|
||||
ca_cert=self.agent.mtls_ca_cert,
|
||||
client_cert=self.agent.mtls_client_cert,
|
||||
client_key=self.agent.mtls_client_key,
|
||||
)
|
||||
return config.create_ssl_context()
|
||||
return None
|
||||
|
||||
def create_signed_request_headers(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: Optional[bytes] = None,
|
||||
) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "shared_secret" and self.agent.shared_secret:
|
||||
sig_headers = SharedSecretAuth.sign_request(
|
||||
self.agent.shared_secret,
|
||||
method,
|
||||
path,
|
||||
body,
|
||||
)
|
||||
headers.update(sig_headers)
|
||||
|
||||
elif self.agent.auth_scheme == "bearer" and self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def get_oauth2_auth(self) -> Optional[OAuth2Auth]:
|
||||
if self.agent.oauth2_client_id and self.agent.oauth2_token_url:
|
||||
scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else []
|
||||
return OAuth2Auth(
|
||||
client_id=self.agent.oauth2_client_id,
|
||||
client_secret=self.agent.oauth2_client_secret or "",
|
||||
token_url=self.agent.oauth2_token_url,
|
||||
scopes=scopes,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_oidc_auth(self) -> Optional[OIDCAuth]:
|
||||
if self.agent.oidc_issuer_url:
|
||||
return OIDCAuth(
|
||||
issuer_url=self.agent.oidc_issuer_url,
|
||||
client_id=self.agent.oidc_client_id or "",
|
||||
client_secret=self.agent.oidc_client_secret,
|
||||
scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [],
|
||||
)
|
||||
return None
|
||||
|
||||
async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers = self.get_auth_headers(user_token)
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
oauth2_auth = await self.get_oauth2_auth()
|
||||
if oauth2_auth:
|
||||
headers.update(await oauth2_auth.authorize_request(method, url))
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
oidc_auth = await self.get_oidc_auth()
|
||||
if oidc_auth:
|
||||
headers.update(await oidc_auth.authorize_request(method, url))
|
||||
|
||||
return headers
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AResolvedRoute:
|
||||
selected: str
|
||||
@@ -185,6 +467,7 @@ class A2ARuntime:
|
||||
remote_agent_id: Optional[int],
|
||||
compatibility_mode: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
context_id: Optional[str] = None,
|
||||
) -> A2ATask:
|
||||
if idempotency_key:
|
||||
existing = (
|
||||
@@ -209,6 +492,7 @@ class A2ARuntime:
|
||||
idempotency_key=idempotency_key,
|
||||
compatibility_mode=compatibility_mode,
|
||||
metadata_json=_json_dumps(metadata or {}),
|
||||
context_id=context_id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
@@ -335,23 +619,21 @@ class A2ARuntime:
|
||||
|
||||
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
|
||||
event_payload = _json_loads(event.payload_json, {})
|
||||
request_payload = {
|
||||
"task_id": event.task_id,
|
||||
"event_type": event.event_type,
|
||||
"event_id": event.id,
|
||||
"payload": event_payload,
|
||||
}
|
||||
body = _json_dumps(request_payload).encode("utf-8")
|
||||
stream_response_payload = self._build_stream_response_payload(event, event_payload)
|
||||
body = _json_dumps(stream_response_payload).encode("utf-8")
|
||||
|
||||
for attempt in range(1, 5):
|
||||
delivery.attempt = attempt
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
|
||||
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
|
||||
if hook.secret:
|
||||
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
|
||||
headers["X-A2A-Signature"] = f"sha256={digest}"
|
||||
if hook.auth_header:
|
||||
headers["Authorization"] = hook.auth_header
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
|
||||
resp = await client.post(hook.target_url, content=body, headers=headers)
|
||||
@@ -368,11 +650,12 @@ class A2ARuntime:
|
||||
except Exception as exc:
|
||||
delivery.error_message = str(exc)[:500]
|
||||
if attempt < 4:
|
||||
backoff_seconds = 2 ** attempt
|
||||
delivery.status = "RETRYING"
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds)
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
await asyncio.sleep(2 ** attempt)
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
continue
|
||||
delivery.status = "FAILED"
|
||||
delivery.dead_letter = True
|
||||
@@ -380,5 +663,42 @@ class A2ARuntime:
|
||||
db.commit()
|
||||
return
|
||||
|
||||
def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = event.event_type
|
||||
task_id = event_payload.get("task_id", event.task_id)
|
||||
|
||||
if event_type == "TaskStatusUpdateEvent":
|
||||
status_state = event_payload.get("task_status", "WORKING")
|
||||
status_timestamp = event_payload.get("timestamp", _utc_now().isoformat())
|
||||
status_schema = A2ATaskStatusSchema(
|
||||
state=status_state,
|
||||
timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp,
|
||||
)
|
||||
return {
|
||||
"statusUpdate": TaskStatusUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
status=status_schema,
|
||||
metadata=event_payload.get("metadata", {}),
|
||||
).model_dump()
|
||||
}
|
||||
elif event_type == "TaskArtifactUpdateEvent":
|
||||
artifact_content = event_payload.get("artifact", {}).get("content", "")
|
||||
artifact_schema = A2AArtifactSchema(
|
||||
artifactId=f"artifact-{event.id}",
|
||||
parts=[A2APartSchema(part_type="text", text=artifact_content)],
|
||||
)
|
||||
return {
|
||||
"artifactUpdate": TaskArtifactUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
artifact=artifact_schema,
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
).model_dump()
|
||||
}
|
||||
else:
|
||||
return {"message": event_payload}
|
||||
|
||||
|
||||
a2a_runtime = A2ARuntime()
|
||||
|
||||
+809
-256
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user