refactor: a2a

This commit is contained in:
qixinbo
2026-04-04 07:24:09 +08:00
parent 02677763cb
commit bdc61fe651
7 changed files with 3345 additions and 496 deletions
+1091 -129
View File
File diff suppressed because it is too large Load Diff
+95 -2
View File
@@ -1,9 +1,34 @@
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Enum as SQLEnum, func
from sqlalchemy.orm import relationship
import enum
from app.database import Base
class A2ATaskState(str, enum.Enum):
SUBMITTED = "SUBMITTED"
WORKING = "WORKING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
INPUT_REQUIRED = "INPUT_REQUIRED"
AUTH_REQUIRED = "AUTH_REQUIRED"
REJECTED = "REJECTED"
class A2APartType(str, enum.Enum):
TEXT = "text"
RAW = "raw"
URL = "url"
DATA = "data"
class A2AMessageRole(str, enum.Enum):
USER = "user"
AGENT = "agent"
SYSTEM = "system"
class A2ARemoteAgent(Base):
__tablename__ = "a2a_remote_agents"
@@ -13,6 +38,17 @@ class A2ARemoteAgent(Base):
base_url = Column(String, nullable=False)
auth_scheme = Column(String, nullable=False, default="none")
auth_token = Column(String, nullable=True)
shared_secret = Column(String, nullable=True)
mtls_ca_cert = Column(Text, nullable=True)
mtls_client_cert = Column(Text, nullable=True)
mtls_client_key = Column(Text, nullable=True)
oauth2_client_id = Column(String, nullable=True)
oauth2_client_secret = Column(String, nullable=True)
oauth2_token_url = Column(String, nullable=True)
oauth2_scopes = Column(String, nullable=True)
oidc_issuer_url = Column(String, nullable=True)
oidc_client_id = Column(String, nullable=True)
oidc_client_secret = Column(String, nullable=True)
protocol_version = Column(String, nullable=True)
capabilities_json = Column(Text, nullable=False, default="[]")
card_json = Column(Text, nullable=True)
@@ -27,27 +63,84 @@ class A2ARemoteAgent(Base):
project = relationship("Project")
class A2APart(Base):
__tablename__ = "a2a_parts"
id = Column(Integer, primary_key=True, index=True)
message_id = Column(Integer, ForeignKey("a2a_messages.id", ondelete="CASCADE"), nullable=True, index=True)
artifact_id = Column(Integer, ForeignKey("a2a_artifacts.id", ondelete="CASCADE"), nullable=True, index=True)
part_type = Column(SQLEnum(A2APartType), nullable=False)
text_content = Column(Text, nullable=True)
raw_content = Column(Text, nullable=True)
url_content = Column(String, nullable=True)
data_content = Column(Text, nullable=True)
media_type = Column(String, nullable=True)
filename = Column(String, nullable=True)
metadata_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now())
message = relationship("A2AMessage", back_populates="parts", foreign_keys=[message_id])
artifact = relationship("A2AArtifact", back_populates="parts", foreign_keys=[artifact_id])
class A2AMessage(Base):
__tablename__ = "a2a_messages"
id = Column(Integer, primary_key=True, index=True)
message_id = Column(String, nullable=False, unique=True, index=True)
context_id = Column(String, nullable=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=True, index=True)
role = Column(SQLEnum(A2AMessageRole), nullable=False)
extensions_json = Column(Text, nullable=False, default="{}")
reference_task_ids_json = Column(Text, nullable=False, default="[]")
created_at = Column(DateTime, default=func.now(), index=True)
task = relationship("A2ATask", back_populates="messages", foreign_keys=[task_id])
parts = relationship("A2APart", back_populates="message", cascade="all, delete-orphan")
class A2AArtifact(Base):
__tablename__ = "a2a_artifacts"
id = Column(Integer, primary_key=True, index=True)
artifact_id = Column(String, nullable=False, unique=True, index=True)
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
name = Column(String, nullable=True)
description = Column(Text, nullable=True)
metadata_json = Column(Text, nullable=False, default="{}")
extensions_json = Column(Text, nullable=False, default="{}")
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
task = relationship("A2ATask", back_populates="artifacts")
parts = relationship("A2APart", back_populates="artifact", cascade="all, delete-orphan")
class A2ATask(Base):
__tablename__ = "a2a_tasks"
id = Column(String, primary_key=True, index=True)
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
context_id = Column(String, nullable=True, index=True)
source = Column(String, nullable=False, default="local")
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
idempotency_key = Column(String, nullable=True, index=True)
state = Column(String, nullable=False, index=True, default="SUBMITTED")
state = Column(SQLEnum(A2ATaskState), nullable=False, index=True, default=A2ATaskState.SUBMITTED)
input_text = Column(Text, nullable=False, default="")
output_text = Column(Text, nullable=True)
error_message = Column(Text, nullable=True)
compatibility_mode = Column(Boolean, nullable=False, default=True)
metadata_json = Column(Text, nullable=False, default="{}")
history_length = Column(Integer, nullable=False, default=0)
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
finished_at = Column(DateTime, nullable=True)
project = relationship("Project")
remote_agent = relationship("A2ARemoteAgent")
messages = relationship("A2AMessage", back_populates="task", cascade="all, delete-orphan", foreign_keys=[A2AMessage.task_id])
artifacts = relationship("A2AArtifact", back_populates="task", cascade="all, delete-orphan")
class A2ATaskEvent(Base):
+361
View File
@@ -0,0 +1,361 @@
from pydantic import BaseModel, ConfigDict, Field
from typing import Optional, List, Dict, Any, Literal, Union
from datetime import datetime
from enum import Enum
class A2ATaskState(str, Enum):
SUBMITTED = "SUBMITTED"
WORKING = "WORKING"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
CANCELED = "CANCELED"
INPUT_REQUIRED = "INPUT_REQUIRED"
AUTH_REQUIRED = "AUTH_REQUIRED"
REJECTED = "REJECTED"
class A2APartType(str, Enum):
TEXT = "text"
RAW = "raw"
URL = "url"
DATA = "data"
class A2AMessageRole(str, Enum):
USER = "user"
AGENT = "agent"
SYSTEM = "system"
class A2APartSchema(BaseModel):
part_type: A2APartType
text: Optional[str] = None
raw: Optional[bytes] = None
url: Optional[str] = None
data: Optional[Any] = None
mediaType: Optional[str] = None
filename: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2APartCreateSchema(BaseModel):
part_type: A2APartType
text: Optional[str] = None
raw: Optional[str] = None
url: Optional[str] = None
data: Optional[Any] = None
mediaType: Optional[str] = None
filename: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2AMessageSchema(BaseModel):
messageId: str
contextId: Optional[str] = None
taskId: Optional[str] = None
role: A2AMessageRole
parts: List[A2APartSchema] = Field(default_factory=list)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
createdAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2AMessageCreateSchema(BaseModel):
messageId: str
contextId: Optional[str] = None
taskId: Optional[str] = None
role: A2AMessageRole
parts: List[A2APartCreateSchema] = Field(default_factory=list)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
class A2AArtifactSchema(BaseModel):
artifactId: str
name: Optional[str] = None
description: Optional[str] = None
parts: List[A2APartSchema] = Field(default_factory=list)
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
createdAt: Optional[datetime] = None
updatedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2AArtifactCreateSchema(BaseModel):
artifactId: str
name: Optional[str] = None
description: Optional[str] = None
parts: List[A2APartCreateSchema] = Field(default_factory=list)
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
class A2ATaskStatusSchema(BaseModel):
state: A2ATaskState
timestamp: datetime
class A2ATaskSchema(BaseModel):
id: str
contextId: Optional[str] = None
projectId: int
tenantId: int
source: str
remoteAgentId: Optional[int] = None
idempotencyKey: Optional[str] = None
state: A2ATaskState
inputText: str
outputText: Optional[str] = None
errorMessage: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
historyLength: int = 0
createdAt: datetime
updatedAt: datetime
finishedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class A2ATaskWithMessagesSchema(A2ATaskSchema):
messages: List[A2AMessageSchema] = Field(default_factory=list)
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
class A2ATaskWithHistorySchema(BaseModel):
id: str
contextId: Optional[str] = None
projectId: int
tenantId: int
state: A2ATaskState
history: List[A2AMessageSchema] = Field(default_factory=list)
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
createdAt: datetime
updatedAt: datetime
finishedAt: Optional[datetime] = None
model_config = ConfigDict(from_attributes=True)
class TaskStatusUpdateEvent(BaseModel):
taskId: str
contextId: Optional[str] = None
status: A2ATaskStatusSchema
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
class TaskArtifactUpdateEvent(BaseModel):
taskId: str
contextId: Optional[str] = None
artifact: A2AArtifactSchema
append: bool = False
lastChunk: bool = True
class TaskMessageEvent(BaseModel):
message: A2AMessageSchema
class StreamResponseTask(BaseModel):
id: str
contextId: Optional[str] = None
state: A2ATaskState
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
class StreamResponse(BaseModel):
task: Optional[StreamResponseTask] = None
message: Optional[A2AMessageSchema] = None
statusUpdate: Optional[TaskStatusUpdateEvent] = None
artifactUpdate: Optional[TaskArtifactUpdateEvent] = None
class SendMessageRequest(BaseModel):
message: A2AMessageCreateSchema
taskId: Optional[str] = None
contextId: Optional[str] = None
class SendStreamingMessageRequest(BaseModel):
message: A2AMessageCreateSchema
taskId: Optional[str] = None
contextId: Optional[str] = None
class GetTaskRequest(BaseModel):
historyLength: Optional[int] = None
class TaskListRequest(BaseModel):
contextId: Optional[str] = None
status: Optional[A2ATaskState] = None
pageSize: int = 20
pageToken: Optional[str] = None
class CancelTaskRequest(BaseModel):
pass
class PushNotificationConfigCreate(BaseModel):
targetUrl: str
secret: Optional[str] = None
authHeader: Optional[str] = None
enabled: bool = True
class PushNotificationConfig(BaseModel):
id: int
taskId: str
targetUrl: str
secret: Optional[str] = None
authHeader: Optional[str] = None
enabled: bool
createdBy: int
createdAt: datetime
model_config = ConfigDict(from_attributes=True)
class VersionNotSupportedError(BaseModel):
code: int = -32009
message: str = "Version not supported"
data: Optional[Dict[str, Any]] = None
class AgentSkillInputMode(str, Enum):
TEXT = "text"
DATA = "data"
RAW = "raw"
URL = "url"
class AgentSkillOutputMode(str, Enum):
TEXT = "text"
DATA = "data"
ARTIFACT = "artifact"
STREAM = "stream"
class AgentSkillSecurityRequirement(BaseModel):
scheme: str
scopes: Optional[List[str]] = None
class AgentSkillExample(BaseModel):
input: Dict[str, Any]
output: Dict[str, Any]
class AgentSkill(BaseModel):
id: str
name: str
description: Optional[str] = None
tags: List[str] = Field(default_factory=list)
examples: List[AgentSkillExample] = Field(default_factory=list)
inputModes: List[AgentSkillInputMode] = Field(default_factory=list)
outputModes: List[AgentSkillOutputMode] = Field(default_factory=list)
securityRequirements: List[AgentSkillSecurityRequirement] = Field(default_factory=list)
class AgentProvider(BaseModel):
organization: str
url: Optional[str] = None
class AgentSupportedInterface(BaseModel):
url: str
protocolBinding: str
protocolVersion: str
tenant: Optional[str] = None
class SecuritySchemeApiKey(BaseModel):
type: Literal["apiKey"] = "apiKey"
name: str
in_: str = Field(alias="in")
description: Optional[str] = None
model_config = ConfigDict(populate_by_name=True)
class SecuritySchemeHttpAuth(BaseModel):
type: Literal["http"] = "http"
scheme: str
description: Optional[str] = None
class OAuth2AuthorizationCodeFlow(BaseModel):
authorizationUrl: str
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
refreshUrl: Optional[str] = None
class OAuth2ClientCredentialsFlow(BaseModel):
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
refreshUrl: Optional[str] = None
class OAuth2DeviceCodeFlow(BaseModel):
authorizationUrl: str
tokenUrl: str
scopes: Dict[str, str] = Field(default_factory=dict)
deviceAuthorizationUrl: Optional[str] = None
class OAuth2Flows(BaseModel):
authorizationCode: Optional[OAuth2AuthorizationCodeFlow] = None
clientCredentials: Optional[OAuth2ClientCredentialsFlow] = None
deviceCode: Optional[OAuth2DeviceCodeFlow] = None
implicit: Optional[Dict[str, Any]] = None
password: Optional[Dict[str, Any]] = None
class SecuritySchemeOAuth2(BaseModel):
type: Literal["oauth2"] = "oauth2"
flows: OAuth2Flows
description: Optional[str] = None
class SecuritySchemeOpenIdConnect(BaseModel):
type: Literal["openIdConnect"] = "openIdConnect"
openIdConnectUrl: str
description: Optional[str] = None
scopes: Dict[str, str] = Field(default_factory=dict)
class SecuritySchemeMtls(BaseModel):
type: Literal["mutualTLS"] = "mutualTLS"
description: Optional[str] = None
caCerts: Optional[List[str]] = None
clientCert: Optional[str] = None
clientKey: Optional[str] = None
class AgentCardPublicSchema(BaseModel):
name: str
protocol_version: str = "1.0"
capabilities: List[str]
endpoints: Dict[str, str]
auth: List[str]
skills: List[AgentSkill] = Field(default_factory=list)
provider: Optional[AgentProvider] = None
supportedInterfaces: List[AgentSupportedInterface] = Field(default_factory=list)
defaultInputModes: List[str] = Field(default_factory=list)
defaultOutputModes: List[str] = Field(default_factory=list)
iconUrl: Optional[str] = None
documentationUrl: Optional[str] = None
class AgentCardExtendedSchema(AgentCardPublicSchema):
securitySchemes: Optional[Dict[str, Union[SecuritySchemeApiKey, SecuritySchemeHttpAuth, SecuritySchemeOAuth2, SecuritySchemeOpenIdConnect, SecuritySchemeMtls]]] = None
security: List[Dict[str, List[str]]] = Field(default_factory=list)
signatures: List[str] = Field(default_factory=list)
tenantId: Optional[int] = None
isAdmin: Optional[bool] = None
+341 -21
View File
@@ -4,6 +4,7 @@ import asyncio
import hashlib
import hmac
import json
import ssl
import time
import uuid
from collections import defaultdict, deque
@@ -24,20 +25,15 @@ from app.models.a2a import (
A2AWebhookDelivery,
)
from app.models.project import Project
from app.schemas.a2a import (
A2AArtifactSchema,
A2APartSchema,
A2ATaskStatusSchema,
TaskArtifactUpdateEvent,
TaskStatusUpdateEvent,
)
from app.trace import build_error_attributes, trace_service
_STATE_TRANSITIONS = {
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
"REJECTED": set(),
"FAILED": set(),
"COMPLETED": set(),
"CANCELED": set(),
}
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
def _json_loads(raw: Optional[str], default: Any) -> Any:
if not raw:
@@ -62,6 +58,292 @@ def _mask_error(message: str) -> str:
return "request_failed"
class SharedSecretAuth:
@staticmethod
def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]:
if timestamp is None:
timestamp = int(time.time())
message = f"{timestamp}".encode() + payload
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
return f"sha256={signature}", timestamp
@staticmethod
def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool:
if abs(time.time() - timestamp) > max_age_seconds:
return False
expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp)
return hmac.compare_digest(signature, expected_sig)
@staticmethod
def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]:
timestamp = int(time.time())
payload = body or b""
message = f"{timestamp}.{method.upper()}.{path}".encode() + payload
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
return {
"X-A2A-Signature": f"sha256={signature}",
"X-A2A-Timestamp": str(timestamp),
}
class MtlsConfig:
def __init__(
self,
ca_cert: Optional[str] = None,
client_cert: Optional[str] = None,
client_key: Optional[str] = None,
):
self.ca_cert = ca_cert
self.client_cert = client_cert
self.client_key = client_key
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
if not self.client_cert or not self.client_key:
return None
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ctx.load_cert_chain(self.client_cert, self.client_key)
if self.ca_cert:
ctx.load_verify_locations(self.ca_cert)
ctx.verify_mode = ssl.CERT_REQUIRED
else:
ctx.verify_mode = ssl.CERT_NONE
return ctx
class OAuth2TokenStore:
def __init__(self):
self._tokens: Dict[str, Tuple[str, datetime]] = {}
self._lock = asyncio.Lock()
async def get_token(self, key: str) -> Optional[str]:
async with self._lock:
if key in self._tokens:
token, expires_at = self._tokens[key]
if expires_at > _utc_now() + timedelta(minutes=1):
return token
return None
async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None:
async with self._lock:
self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in))
class OAuth2Auth:
def __init__(
self,
client_id: str,
client_secret: str,
token_url: str,
scopes: Optional[List[str]] = None,
):
self.client_id = client_id
self.client_secret = client_secret
self.token_url = token_url
self.scopes = scopes or []
self._token_store = OAuth2TokenStore()
def _get_cache_key(self) -> str:
return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}"
async def get_access_token(self, grant_type: str = "client_credentials") -> str:
cache_key = self._get_cache_key()
cached = await self._token_store.get_token(cache_key)
if cached:
return cached
async with httpx.AsyncClient() as client:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
"grant_type": grant_type,
}
if self.scopes:
data["scope"] = " ".join(self.scopes)
resp = await client.post(self.token_url, data=data)
resp.raise_for_status()
token_data = resp.json()
token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600)
await self._token_store.set_token(cache_key, token, expires_in)
return token
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
token = await self.get_access_token()
return {"Authorization": f"Bearer {token}"}
class OIDCAuth:
def __init__(
self,
issuer_url: str,
client_id: str,
client_secret: Optional[str] = None,
scopes: Optional[List[str]] = None,
):
self.issuer_url = issuer_url.rstrip("/")
self.client_id = client_id
self.client_secret = client_secret
self.scopes = scopes or ["openid", "profile"]
self._oauth2: Optional[OAuth2Auth] = None
self._discovery_cache: Optional[Dict[str, Any]] = None
async def _get_discovery(self) -> Dict[str, Any]:
if self._discovery_cache:
return self._discovery_cache
discovery_url = f"{self.issuer_url}/.well-known/openid-configuration"
async with httpx.AsyncClient() as client:
resp = await client.get(discovery_url)
resp.raise_for_status()
self._discovery_cache = resp.json()
return self._discovery_cache
async def get_access_token(self) -> str:
discovery = await self._get_discovery()
token_url = discovery.get("token_endpoint")
if not token_url:
raise RuntimeError("OIDC discovery missing token_endpoint")
if not self._oauth2:
self._oauth2 = OAuth2Auth(
client_id=self.client_id,
client_secret=self.client_secret or "",
token_url=token_url,
scopes=self.scopes,
)
return await self._oauth2.get_access_token()
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
token = await self.get_access_token()
return {"Authorization": f"Bearer {token}"}
class RemoteAgentSecuritySelector:
def __init__(self, agent: A2ARemoteAgent):
self.agent = agent
self._card_security_schemes: Optional[Dict[str, Any]] = None
def load_security_from_card(self) -> None:
card = _json_loads(self.agent.card_json, {})
if card:
self._card_security_schemes = card.get("securitySchemes", {})
def get_preferred_auth_scheme(self) -> str:
card = _json_loads(self.agent.card_json, {})
security_reqs = card.get("security", [])
if security_reqs:
first_req = security_reqs[0]
if isinstance(first_req, dict):
for scheme_name in first_req.keys():
return scheme_name
return self.agent.auth_scheme or "bearer"
def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]:
headers: Dict[str, str] = {}
preferred = self.get_preferred_auth_scheme()
if preferred == "bearer" or self.agent.auth_scheme == "bearer":
if self.agent.auth_token:
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
elif user_token:
headers["Authorization"] = f"Bearer {user_token}"
elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret":
pass
elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
pass
elif preferred == "openIdConnect":
pass
elif preferred == "mutualTLS":
pass
return headers
def get_mtls_context(self) -> Optional[ssl.SSLContext]:
if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS":
if self.agent.mtls_client_cert and self.agent.mtls_client_key:
config = MtlsConfig(
ca_cert=self.agent.mtls_ca_cert,
client_cert=self.agent.mtls_client_cert,
client_key=self.agent.mtls_client_key,
)
return config.create_ssl_context()
return None
def create_signed_request_headers(
self,
method: str,
path: str,
body: Optional[bytes] = None,
) -> Dict[str, str]:
headers: Dict[str, str] = {}
preferred = self.get_preferred_auth_scheme()
if preferred == "shared_secret" and self.agent.shared_secret:
sig_headers = SharedSecretAuth.sign_request(
self.agent.shared_secret,
method,
path,
body,
)
headers.update(sig_headers)
elif self.agent.auth_scheme == "bearer" and self.agent.auth_token:
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
return headers
async def get_oauth2_auth(self) -> Optional[OAuth2Auth]:
if self.agent.oauth2_client_id and self.agent.oauth2_token_url:
scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else []
return OAuth2Auth(
client_id=self.agent.oauth2_client_id,
client_secret=self.agent.oauth2_client_secret or "",
token_url=self.agent.oauth2_token_url,
scopes=scopes,
)
return None
async def get_oidc_auth(self) -> Optional[OIDCAuth]:
if self.agent.oidc_issuer_url:
return OIDCAuth(
issuer_url=self.agent.oidc_issuer_url,
client_id=self.agent.oidc_client_id or "",
client_secret=self.agent.oidc_client_secret,
scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [],
)
return None
async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]:
headers = self.get_auth_headers(user_token)
preferred = self.get_preferred_auth_scheme()
if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
oauth2_auth = await self.get_oauth2_auth()
if oauth2_auth:
headers.update(await oauth2_auth.authorize_request(method, url))
elif preferred == "openIdConnect":
oidc_auth = await self.get_oidc_auth()
if oidc_auth:
headers.update(await oidc_auth.authorize_request(method, url))
return headers
_STATE_TRANSITIONS = {
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
"REJECTED": set(),
"FAILED": set(),
"COMPLETED": set(),
"CANCELED": set(),
}
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
@dataclass
class A2AResolvedRoute:
selected: str
@@ -185,6 +467,7 @@ class A2ARuntime:
remote_agent_id: Optional[int],
compatibility_mode: bool,
metadata: Optional[Dict[str, Any]] = None,
context_id: Optional[str] = None,
) -> A2ATask:
if idempotency_key:
existing = (
@@ -209,6 +492,7 @@ class A2ARuntime:
idempotency_key=idempotency_key,
compatibility_mode=compatibility_mode,
metadata_json=_json_dumps(metadata or {}),
context_id=context_id,
)
db.add(task)
db.commit()
@@ -335,23 +619,21 @@ class A2ARuntime:
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
event_payload = _json_loads(event.payload_json, {})
request_payload = {
"task_id": event.task_id,
"event_type": event.event_type,
"event_id": event.id,
"payload": event_payload,
}
body = _json_dumps(request_payload).encode("utf-8")
stream_response_payload = self._build_stream_response_payload(event, event_payload)
body = _json_dumps(stream_response_payload).encode("utf-8")
for attempt in range(1, 5):
delivery.attempt = attempt
db.add(delivery)
db.commit()
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
if hook.secret:
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
headers["X-A2A-Signature"] = f"sha256={digest}"
if hook.auth_header:
headers["Authorization"] = hook.auth_header
try:
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
resp = await client.post(hook.target_url, content=body, headers=headers)
@@ -368,11 +650,12 @@ class A2ARuntime:
except Exception as exc:
delivery.error_message = str(exc)[:500]
if attempt < 4:
backoff_seconds = 2 ** attempt
delivery.status = "RETRYING"
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds)
db.add(delivery)
db.commit()
await asyncio.sleep(2 ** attempt)
await asyncio.sleep(backoff_seconds)
continue
delivery.status = "FAILED"
delivery.dead_letter = True
@@ -380,5 +663,42 @@ class A2ARuntime:
db.commit()
return
def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]:
event_type = event.event_type
task_id = event_payload.get("task_id", event.task_id)
if event_type == "TaskStatusUpdateEvent":
status_state = event_payload.get("task_status", "WORKING")
status_timestamp = event_payload.get("timestamp", _utc_now().isoformat())
status_schema = A2ATaskStatusSchema(
state=status_state,
timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp,
)
return {
"statusUpdate": TaskStatusUpdateEvent(
taskId=task_id,
contextId=event_payload.get("context_id"),
status=status_schema,
metadata=event_payload.get("metadata", {}),
).model_dump()
}
elif event_type == "TaskArtifactUpdateEvent":
artifact_content = event_payload.get("artifact", {}).get("content", "")
artifact_schema = A2AArtifactSchema(
artifactId=f"artifact-{event.id}",
parts=[A2APartSchema(part_type="text", text=artifact_content)],
)
return {
"artifactUpdate": TaskArtifactUpdateEvent(
taskId=task_id,
contextId=event_payload.get("context_id"),
artifact=artifact_schema,
append=False,
lastChunk=True,
).model_dump()
}
else:
return {"message": event_payload}
a2a_runtime = A2ARuntime()
+698 -145
View File
@@ -1,7 +1,13 @@
import asyncio
import hashlib
import hmac
import json
import sys
import time
from collections.abc import Generator
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional
import httpx
import pytest
@@ -20,10 +26,11 @@ if str(NANOBOT_ROOT) not in sys.path:
from app.core.security import CurrentUser, get_current_user
from app.database import Base, get_db
from app.models.a2a import A2ARemoteAgent
from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState
from app.models.project import Project
from app.models.user import User
from app.services.a2a_service import a2a_runtime
from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode
from app.services.a2a_service import a2a_runtime, SharedSecretAuth
from main import app
@@ -42,144 +49,83 @@ def _seed(db: Session) -> tuple[int, str, int, str, int]:
return owner.id, owner.username, other.id, other.username, project.id
def test_a2a_send_list_cancel_and_rollout() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
send_resp = client.post(
"/api/v1/a2a/messages/send",
json={
def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]:
payload: Dict[str, Any] = {
"message": {
"messageId": f"msg-{int(time.time()*1000)}",
"role": "user",
"parts": [
{
"part_type": "data",
"data": {
"project_id": project_id,
"message": "hello a2a",
"session_id": "test-a2a-session",
"route_mode": "local_first",
"route_mode": route_mode,
"session_id": session_id,
**( {"idempotency_key": idempotency_key} if idempotency_key else {} )
},
"mediaType": "application/json",
},
{
"part_type": "text",
"text": text,
}
],
}
}
return payload
class TestPartSerialization:
def test_part_text_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.TEXT,
text="Hello world",
mediaType="text/plain",
)
assert send_resp.status_code == 200
task_id = send_resp.json()["task"]["id"]
data = part.model_dump()
assert data["part_type"] == "text"
assert data["text"] == "Hello world"
assert data["mediaType"] == "text/plain"
get_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
assert get_resp.status_code == 200
assert get_resp.json()["project_id"] == project_id
list_resp = client.get("/api/v1/a2a/tasks", params={"project_id": project_id})
assert list_resp.status_code == 200
assert any(item["id"] == task_id for item in list_resp.json())
cancel_resp = client.post(f"/api/v1/a2a/tasks/{task_id}/cancel")
assert cancel_resp.status_code == 200
assert cancel_resp.json()["state"] in {"CANCELED", "COMPLETED", "FAILED"}
rollout_resp = client.put(
f"/api/v1/a2a/projects/{project_id}/rollout",
json={"canary_enabled": True, "canary_percent": 30, "route_mode_default": "a2a_first", "fallback_chain": ["a2a", "local"]},
def test_part_data_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.DATA,
data={"project_id": 123, "route_mode": "local"},
mediaType="application/json",
)
assert rollout_resp.status_code == 200
assert rollout_resp.json()["canary_enabled"] is True
assert rollout_resp.json()["canary_percent"] == 30
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
data = part.model_dump()
assert data["part_type"] == "data"
assert data["data"]["project_id"] == 123
def test_a2a_task_tenant_isolation() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
send_resp = client.post(
"/api/v1/a2a/messages/send",
json={"project_id": project_id, "message": "tenant isolation", "session_id": "tenant-isolation", "route_mode": "local"},
def test_part_url_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.URL,
url="https://example.com/file.pdf",
mediaType="application/pdf",
filename="file.pdf",
)
assert send_resp.status_code == 200
task_id = send_resp.json()["task"]["id"]
data = part.model_dump()
assert data["part_type"] == "url"
assert data["url"] == "https://example.com/file.pdf"
assert data["filename"] == "file.pdf"
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
forbidden_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
assert forbidden_resp.status_code == 404
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_part_raw_serialization(self):
from app.schemas.a2a import A2APartCreateSchema
part = A2APartCreateSchema(
part_type=A2APartType.RAW,
raw="\x00\x01\x02\x03",
mediaType="application/octet-stream",
)
data = part.model_dump()
assert data["part_type"] == "raw"
assert data["raw"] == "\x00\x01\x02\x03"
def test_a2a_metrics_admin_only() -> None:
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, _ = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
denied = client.get("/api/v1/a2a/metrics")
assert denied.status_code == 403
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
ok = client.get("/api/v1/a2a/metrics")
assert ok.status_code == 200
assert "counters" in ok.json()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_send_idempotency_key_deduplicates_task() -> None:
class TestStateMachine:
def test_state_transitions_submit_to_complete(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
@@ -203,25 +149,518 @@ def test_a2a_send_idempotency_key_deduplicates_task() -> None:
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
payload = {
"project_id": project_id,
"message": "dedupe-task",
"session_id": "idempotency-session",
"route_mode": "local_first",
"idempotency_key": "same-key-1",
}
first_resp = client.post("/api/v1/a2a/messages/send", json=payload)
second_resp = client.post("/api/v1/a2a/messages/send", json=payload)
assert first_resp.status_code == 200
assert second_resp.status_code == 200
assert first_resp.json()["task"]["id"] == second_resp.json()["task"]["id"]
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "test state machine"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
assert task.state == A2ATaskState.SUBMITTED
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING)
assert task.state == A2ATaskState.WORKING
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.COMPLETED)
assert task.state == A2ATaskState.COMPLETED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_cancel(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
assert cancel_resp.status_code == 200
assert cancel_resp.json()["state"] == "CANCELED"
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_failed(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "fail test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message='{"message": "test error"}')
assert task.state == A2ATaskState.FAILED
assert task.error_message is not None
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_rejected(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "reject test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.REJECTED)
assert task.state == A2ATaskState.REJECTED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_input_required(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "input required test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.INPUT_REQUIRED)
assert task.state == A2ATaskState.INPUT_REQUIRED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_state_auth_required(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "auth required test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.AUTH_REQUIRED)
assert task.state == A2ATaskState.AUTH_REQUIRED
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) -> None:
class TestA2APathNormalization:
def test_message_send_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "path test"))
assert resp.status_code == 200
assert "task" in resp.json()
assert "id" in resp.json()["task"]
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_message_stream_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:stream", json=_make_message_payload(project_id, "stream path test"))
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("text/event-stream")
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_tasks_cancel_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel path test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
assert cancel_resp.status_code == 200
assert cancel_resp.json()["task_id"] == task_id
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_agent_card_public_path(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
db.close()
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
app.dependency_overrides[get_db] = override_get_db
try:
client = TestClient(app)
resp = client.get("/api/v1/.well-known/agent-card.json")
assert resp.status_code == 200
data = resp.json()
assert "name" in data
assert "protocol_version" in data
assert "endpoints" in data
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestVersionNegotiation:
def test_version_not_supported_error(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post(
"/api/v1/message:send",
json=_make_message_payload(project_id, "version test"),
headers={"A2A-Version": "2.0"}
)
assert resp.status_code == 400
detail = json.loads(resp.json()["detail"])
assert detail["code"] == -32009
assert "not supported" in detail["message"].lower()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_version_response_header(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post(
"/api/v1/message:send",
json=_make_message_payload(project_id, "version header test"),
headers={"A2A-Version": "1.0"}
)
assert resp.status_code == 200
assert resp.headers.get("A2A-Version") == "1.0"
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestWebhookStreamResponse:
def test_webhook_payload_format(self):
from app.schemas.a2a import StreamResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent, TaskMessageEvent, A2ATaskStatusSchema, A2ATaskState, A2AArtifactSchema
status_event = TaskStatusUpdateEvent(
taskId="task-123",
contextId="ctx-456",
status=A2ATaskStatusSchema(
state=A2ATaskState.SUBMITTED,
timestamp=datetime.utcnow(),
),
metadata={},
)
status_dump = status_event.model_dump()
assert "taskId" in status_dump
assert status_dump["taskId"] == "task-123"
assert status_dump["status"]["state"] == "SUBMITTED"
artifact_event = TaskArtifactUpdateEvent(
taskId="task-123",
contextId="ctx-456",
artifact=A2AArtifactSchema(
artifactId="art-789",
parts=[],
),
append=False,
lastChunk=True,
)
artifact_dump = artifact_event.model_dump()
assert "taskId" in artifact_dump
assert artifact_dump["artifact"]["artifactId"] == "art-789"
def test_stream_response_task_field(self):
from app.schemas.a2a import StreamResponse, StreamResponseTask, A2ATaskState
resp = StreamResponse(
task=StreamResponseTask(
id="task-123",
contextId="ctx-456",
state=A2ATaskState.WORKING,
artifacts=[],
)
)
data = resp.model_dump()
assert "task" in data
assert data["task"]["id"] == "task-123"
class TestSSEFIFO:
def test_sse_event_order(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
with client.stream("POST", "/api/v1/message:stream", json=_make_message_payload(project_id, "fifo test")) as resp:
assert resp.status_code == 200
chunks = []
for line in resp.iter_lines():
if line.startswith("data: "):
chunks.append(json.loads(line[6:]))
event_types = [c.get("type") for c in chunks if "type" in c]
status_idx = next((i for i, t in enumerate(event_types) if t == "TaskStatusUpdateEvent"), -1)
assert status_idx >= 0
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestAuthSchemes:
def test_shared_secret_auth(self):
secret = "test-secret-key-12345"
timestamp = int(time.time())
body = b'{"test":"data"}'
sig, _ = SharedSecretAuth.generate_signature(secret, body, timestamp)
assert sig.startswith("sha256=")
assert SharedSecretAuth.verify_signature(secret, body, sig, timestamp) is True
def test_auth_scheme_none(self):
from app.schemas.a2a import SecuritySchemeHttpAuth
scheme = SecuritySchemeHttpAuth(scheme="bearer", description="Bearer auth")
assert scheme.scheme == "bearer"
class TestExceptionPaths:
def test_auth_failure_marks_agent_unhealthy(self, monkeypatch):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
@@ -272,8 +711,7 @@ def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) ->
db.close()
engine.dispose()
def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> None:
def test_remote_unavailable_opens_circuit(self, monkeypatch):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
@@ -317,3 +755,118 @@ def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> N
Base.metadata.drop_all(bind=engine)
db.close()
engine.dispose()
def test_idempotency_key_deduplicates_task(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
idempotency_key = f"idem-key-{int(time.time())}"
payload1 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
resp1 = client.post("/api/v1/message:send", json=payload1)
assert resp1.status_code == 200
payload2 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
payload2["message"]["messageId"] = f"msg-{int(time.time()*1000) + 1}"
resp2 = client.post("/api/v1/message:send", json=payload2)
assert resp2.status_code == 200
assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"]
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
def test_tenant_isolation(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "isolation test"))
assert resp.status_code == 200
task_id = resp.json()["task"]["id"]
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
get_resp = client.get(f"/api/v1/tasks/{task_id}")
assert get_resp.status_code == 404
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
class TestMetricsAdminOnly:
def test_metrics_admin_only(self):
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base.metadata.create_all(bind=engine)
db = testing_session_local()
owner_id, owner_username, _, _, _ = _seed(db)
db.close()
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
def override_get_db() -> Generator[Session, None, None]:
override_db = testing_session_local()
try:
yield override_db
finally:
override_db.close()
def override_current_user() -> CurrentUser:
return state["user"]
app.dependency_overrides[get_db] = override_get_db
app.dependency_overrides[get_current_user] = override_current_user
try:
client = TestClient(app)
denied = client.get("/api/v1/a2a/metrics")
assert denied.status_code == 403
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
ok = client.get("/api/v1/a2a/metrics")
assert ok.status_code == 200
assert "counters" in ok.json()
finally:
app.dependency_overrides.clear()
Base.metadata.drop_all(bind=engine)
engine.dispose()
+258 -28
View File
@@ -1,5 +1,100 @@
import { api } from "@/lib/api";
export interface A2APartText {
kind: "text";
text: string;
}
export interface A2APartUrl {
kind: "url";
url: string;
}
export interface A2APartFile {
kind: "file";
data: string;
mediaType?: string;
filename?: string;
}
export type A2APart = A2APartText | A2APartUrl | A2APartFile;
export interface A2AMessage {
messageId?: string;
contextId?: string;
taskId?: string;
role: "user" | "agent" | "system";
parts: A2APart[];
extensions?: Record<string, unknown>[];
referenceTaskIds?: string[];
}
export interface A2AArtifact {
artifactId?: string;
name?: string;
description?: string;
parts: A2APart[];
metadata?: Record<string, unknown>;
extensions?: Record<string, unknown>[];
}
export interface A2ATask {
id: string;
project_id?: number;
context_id?: string;
source: string;
state: string;
remote_agent_id?: number | null;
input_text: string;
input_parts?: A2APart[];
output_text?: string | null;
output_parts?: A2APart[];
error_message?: string | null;
compatibility_mode: boolean;
metadata: Record<string, unknown>;
artifacts?: A2AArtifact[];
history?: A2AMessage[];
history_length?: number;
created_at: string;
updated_at: string;
finished_at?: string | null;
}
export interface A2AAgentCard {
id?: string;
name: string;
description?: string;
url?: string;
provider?: {
organization?: string;
url?: string;
};
skills?: Array<{
id?: string;
name: string;
description?: string;
tags?: string[];
examples?: string[];
inputModes?: string[];
outputModes?: string[];
securityRequirements?: Array<Record<string, unknown>>;
}>;
supportedInterfaces?: Array<{
type: string;
url?: string;
protocolBinding?: string;
protocolVersion?: string;
tenant?: string;
}>;
defaultInputModes?: string[];
defaultOutputModes?: string[];
securitySchemes?: Record<string, unknown>;
security?: Array<Record<string, unknown>>;
signatures?: string[];
iconUrl?: string;
documentationUrl?: string;
}
export interface A2ARemoteAgent {
id: number;
project_id: number;
@@ -12,27 +107,12 @@ export interface A2ARemoteAgent {
failure_count: number;
circuit_open_until?: string | null;
card_fetched_at?: string | null;
}
export interface A2ATask {
id: string;
project_id: number;
source: string;
state: string;
remote_agent_id?: number | null;
input_text: string;
output_text?: string | null;
error_message?: string | null;
compatibility_mode: boolean;
metadata: Record<string, unknown>;
created_at: string;
updated_at: string;
finished_at?: string | null;
agent_card?: A2AAgentCard;
}
export interface A2ASendMessagePayload {
project_id: number;
message: string;
message: A2AMessage;
session_id?: string;
remote_agent_id?: number;
route_mode?: "auto" | "local" | "a2a" | "a2a_first" | "local_first" | "mcp_first";
@@ -55,11 +135,13 @@ export interface A2ASubscribeEvent {
type?: string;
event?: string;
task_id?: string;
context_id?: string;
task_status?: string;
status?: string;
artifact?: {
content?: string;
};
artifact?: A2AArtifact;
append?: boolean;
last_chunk?: boolean;
message?: A2AMessage;
output?: string;
source?: string;
timestamp?: string;
@@ -122,29 +204,42 @@ export const a2aApi = {
healthCheckRemoteAgent(agentId: number) {
return api.post<{ healthy: boolean; failure_count: number }>(`/api/v1/a2a/remote-agents/${agentId}/health-check`, {});
},
listTasks(projectId: number, state?: string) {
listTasks(projectId: number, state?: string, contextId?: string) {
const params = new URLSearchParams({ project_id: String(projectId), limit: "100" });
if (state && state !== "all") {
params.set("state", state);
}
if (contextId) {
params.set("context_id", contextId);
}
return api.get<A2ATask[]>(`/api/v1/a2a/tasks?${params.toString()}`);
},
getTask(taskId: string) {
return api.get<A2ATask>(`/api/v1/a2a/tasks/${taskId}`);
getTask(taskId: string, historyLength?: number) {
const params = new URLSearchParams();
if (historyLength !== undefined) {
params.set("historyLength", String(historyLength));
}
const queryString = params.toString();
return api.get<A2ATask>(`/api/v1/a2a/tasks/${taskId}${queryString ? `?${queryString}` : ""}`);
},
cancelTask(taskId: string) {
return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}/cancel`, {});
return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}:cancel`, {});
},
sendMessage(payload: A2ASendMessagePayload) {
return api.post<A2ASendMessageResponse>("/api/v1/a2a/messages/send", payload);
return api.post<A2ASendMessageResponse>("/api/v1/a2a/message:send", payload);
},
async subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): Promise<void> {
const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, {
streamMessage(payload: A2ASendMessagePayload) {
return api.post<A2ASendMessageResponse>("/api/v1/a2a/message:stream", payload);
},
subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void {
const controller = new AbortController();
void (async () => {
const response = await fetch(`/api/v1/a2a/tasks/${taskId}:subscribe`, {
method: "GET",
headers: {
...getAuthHeaders(),
},
signal,
signal: signal || controller.signal,
});
if (!response.ok || !response.body) {
throw new Error(`Subscribe failed: ${response.status}`);
@@ -171,5 +266,140 @@ export const a2aApi = {
onEvent(event);
}
}
})();
return () => controller.abort();
},
subscribeTaskSSE(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void {
const controller = new AbortController();
void (async () => {
const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, {
method: "GET",
headers: {
...getAuthHeaders(),
},
signal: signal || controller.signal,
});
if (!response.ok || !response.body) {
throw new Error(`Subscribe failed: ${response.status}`);
}
const reader = response.body.getReader();
const decoder = new TextDecoder("utf-8");
let buffer = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
buffer += decoder.decode(value, { stream: true });
const splitIndex = buffer.lastIndexOf("\n\n");
if (splitIndex === -1) continue;
const complete = buffer.slice(0, splitIndex);
buffer = buffer.slice(splitIndex + 2);
const events = parseSseEvents(complete);
for (const event of events) {
onEvent(event);
}
}
if (buffer.trim()) {
const events = parseSseEvents(buffer);
for (const event of events) {
onEvent(event);
}
}
})();
return () => controller.abort();
},
};
export function renderPart(part: A2APart): string {
switch (part.kind) {
case "text":
return part.text;
case "url":
return `[URL: ${part.url}]`;
case "file":
if (part.mediaType?.startsWith("image/")) {
return `[Image: ${part.filename || "image"}]`;
}
if (part.mediaType?.includes("json")) {
try {
const decoded = atob(part.data);
return `[JSON File: ${part.filename || "data.json"}]\n${decoded}`;
} catch {
return `[Binary File: ${part.filename || "data"}]`;
}
}
return `[File: ${part.filename || "file"}]`;
default:
return "[Unknown Part]";
}
}
export function renderParts(parts: A2APart[]): string {
return parts.map(renderPart).join("\n");
}
export function extractTextFromParts(parts: A2APart[]): string {
return parts
.filter((p): p is A2APartText => p.kind === "text")
.map((p) => p.text)
.join("\n");
}
export function getArtifactPreview(artifact: A2AArtifact): { type: "text" | "image" | "html" | "json" | "unknown"; content: string } {
if (!artifact.parts || artifact.parts.length === 0) {
return { type: "unknown", content: "" };
}
const firstPart = artifact.parts[0];
if (firstPart.kind === "text") {
return { type: "text", content: firstPart.text };
}
if (firstPart.kind === "url") {
const url = firstPart.url.toLowerCase();
if (url.endsWith(".png") || url.endsWith(".jpg") || url.endsWith(".jpeg") || url.endsWith(".gif") || url.endsWith(".webp")) {
return { type: "image", content: firstPart.url };
}
if (url.endsWith(".html") || url.endsWith(".htm")) {
return { type: "html", content: firstPart.url };
}
return { type: "unknown", content: firstPart.url };
}
if (firstPart.kind === "file") {
const mediaType = firstPart.mediaType || "";
if (mediaType.startsWith("image/")) {
return { type: "image", content: `data:${mediaType};base64,${firstPart.data}` };
}
if (mediaType.includes("html")) {
try {
const decoded = atob(firstPart.data);
return { type: "html", content: decoded };
} catch {
return { type: "unknown", content: "[HTML content]" };
}
}
if (mediaType.includes("json")) {
try {
const decoded = atob(firstPart.data);
return { type: "json", content: decoded };
} catch {
return { type: "unknown", content: "[JSON content]" };
}
}
return { type: "unknown", content: `[File: ${firstPart.filename || "file"}]` };
}
return { type: "unknown", content: "" };
}
export function groupTasksByContextId(tasks: A2ATask[]): Map<string, A2ATask[]> {
const grouped = new Map<string, A2ATask[]>();
for (const task of tasks) {
const contextId = task.context_id || "no-context";
const existing = grouped.get(contextId) || [];
existing.push(task);
grouped.set(contextId, existing);
}
return grouped;
}
+340 -10
View File
@@ -9,7 +9,7 @@ import { Textarea } from "@/components/ui/textarea";
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table";
import { api } from "@/lib/api";
import { a2aApi, type A2ARemoteAgent, type A2ATask } from "@/api/a2a";
import { a2aApi, type A2ARemoteAgent, type A2ATask, type A2AArtifact, renderPart, renderParts, getArtifactPreview, groupTasksByContextId } from "@/api/a2a";
import { useProjectStore } from "@/store/projectStore";
import { useMcpHealthStore } from "@/store/mcpHealthStore";
import { useRef } from 'react';
@@ -118,6 +118,11 @@ export function Skills() {
auth_token: '',
});
const [isA2aRefreshingHealth, setIsA2aRefreshingHealth] = useState(false);
const [selectedA2aAgent, setSelectedA2aAgent] = useState<A2ARemoteAgent | null>(null);
const [selectedTask, setSelectedTask] = useState<A2ATask | null>(null);
const [taskArtifactPreview, setTaskArtifactPreview] = useState<{ type: string; content: string } | null>(null);
const [contextIdFilter, setContextIdFilter] = useState<string>('all');
const [groupedByContextId, setGroupedByContextId] = useState<Map<string, A2ATask[]>>(new Map());
const { currentProject } = useProjectStore();
const { hasMcpError, refresh: refreshMcpHealth } = useMcpHealthStore();
@@ -245,6 +250,7 @@ export function Skills() {
]);
setA2aAgents(agents || []);
setA2aTasks(tasks || []);
setGroupedByContextId(groupTasksByContextId(tasks || []));
} catch (error) {
console.error("Failed to fetch A2A data", error);
} finally {
@@ -252,6 +258,25 @@ export function Skills() {
}
};
const handlePreviewArtifact = (artifact: A2AArtifact) => {
const preview = getArtifactPreview(artifact);
setTaskArtifactPreview(preview);
};
const handleTaskClick = (task: A2ATask) => {
setSelectedTask(task);
if (task.artifacts && task.artifacts.length > 0) {
const preview = getArtifactPreview(task.artifacts[0]);
setTaskArtifactPreview(preview);
} else {
setTaskArtifactPreview(null);
}
};
const handleAgentCardClick = (agent: A2ARemoteAgent) => {
setSelectedA2aAgent(agent);
};
const handleRefreshA2aHealth = async () => {
if (!currentProject || a2aAgents.length === 0) return;
setIsA2aRefreshingHealth(true);
@@ -602,6 +627,20 @@ export function Skills() {
<RefreshCw className="h-4 w-4" />
{t('refresh')}
</Button>
<Select value={contextIdFilter} onValueChange={(val) => { if (val) setContextIdFilter(val); }}>
<SelectTrigger className="w-[180px] h-9">
<SelectValue placeholder={t('filterByContext')} />
</SelectTrigger>
<SelectContent>
<SelectItem value="all">{t('allContexts')}</SelectItem>
{Array.from(groupedByContextId.keys()).filter(k => k !== 'no-context').map(contextId => (
<SelectItem key={contextId} value={contextId}>{contextId.slice(0, 16)}...</SelectItem>
))}
{groupedByContextId.has('no-context') && (
<SelectItem value="no-context">{t('noContext')}</SelectItem>
)}
</SelectContent>
</Select>
<Button
className="h-9 bg-[#ff4d29] hover:bg-[#ff4d29]/90 text-white gap-2 rounded-md px-3"
onClick={handleOpenCreateA2a}
@@ -856,7 +895,7 @@ export function Skills() {
</TableRow>
) : (
a2aAgents.map((agent) => (
<TableRow key={agent.id} className="group hover:bg-muted/50/50 transition-colors border-border">
<TableRow key={agent.id} className="group hover:bg-muted/50/50 transition-colors border-border cursor-pointer" onClick={() => handleAgentCardClick(agent)}>
<TableCell className="py-4 px-4 text-sm font-medium">{agent.name}</TableCell>
<TableCell className="py-4 px-4 text-sm text-muted-foreground truncate" title={agent.base_url}>{agent.base_url}</TableCell>
<TableCell className="py-4 px-4 text-sm text-muted-foreground">{agent.protocol_version || '-'}</TableCell>
@@ -868,7 +907,7 @@ export function Skills() {
<span className="opacity-70">#{agent.failure_count}</span>
</div>
</TableCell>
<TableCell className="py-4 px-4 text-right">
<TableCell className="py-4 px-4 text-right" onClick={(e) => e.stopPropagation()}>
<div className="flex items-center justify-end gap-1">
<Button variant="ghost" size="icon" className="h-8 w-8 text-muted-foreground hover:text-indigo-600 hover:bg-indigo-50 rounded-md transition-all shrink-0" onClick={() => void handleRefreshA2aCard(agent.id)}>
<RefreshCw className="h-4 w-4" />
@@ -895,10 +934,11 @@ export function Skills() {
<Table className="table-fixed w-full">
<TableHeader className="bg-muted/50/50">
<TableRow className="hover:bg-transparent">
<TableHead className="w-[18%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskId')}</TableHead>
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskSource')}</TableHead>
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('status')}</TableHead>
<TableHead className="w-[38%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('content')}</TableHead>
<TableHead className="w-[16%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskId')}</TableHead>
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('contextId')}</TableHead>
<TableHead className="w-[10%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskSource')}</TableHead>
<TableHead className="w-[10%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('status')}</TableHead>
<TableHead className="w-[32%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('content')}</TableHead>
<TableHead className="w-[20%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('time')}</TableHead>
</TableRow>
</TableHeader>
@@ -915,14 +955,20 @@ export function Skills() {
</TableRow>
) : (
a2aTasks.map((task) => (
<TableRow key={task.id} className="group hover:bg-muted/50/50 transition-colors border-border">
<TableRow key={task.id} className="group hover:bg-muted/50/50 transition-colors border-border cursor-pointer" onClick={() => handleTaskClick(task)}>
<TableCell className="py-4 px-4 text-xs font-mono truncate" title={task.id}>{task.id}</TableCell>
<TableCell className="py-4 px-4 text-xs font-mono truncate text-muted-foreground" title={task.context_id || ''}>{task.context_id ? task.context_id.slice(0, 12) + '...' : '-'}</TableCell>
<TableCell className="py-4 px-4 text-sm text-muted-foreground">{task.source}</TableCell>
<TableCell className="py-4 px-4 text-sm">{task.state}</TableCell>
<TableCell className="py-4 px-4 text-xs text-muted-foreground">
<div className="line-clamp-2" title={task.error_message || task.output_text || task.input_text}>
{task.error_message || task.output_text || task.input_text}
<div className="line-clamp-2" title={task.error_message || task.output_text || task.input_text || (task.input_parts ? renderParts(task.input_parts) : '')}>
{task.error_message || task.output_text || task.input_text || (task.input_parts ? renderParts(task.input_parts) : '')}
</div>
{task.artifacts && task.artifacts.length > 0 && (
<div className="mt-1 text-[10px] text-indigo-600">
{task.artifacts.length} artifact(s)
</div>
)}
</TableCell>
<TableCell className="py-4 px-4 text-xs text-muted-foreground">{task.updated_at}</TableCell>
</TableRow>
@@ -1203,6 +1249,290 @@ export function Skills() {
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={!!selectedA2aAgent} onOpenChange={(open) => { if (!open) setSelectedA2aAgent(null); }}>
<DialogContent className="sm:max-w-[700px] max-h-[90vh] flex flex-col rounded-2xl p-0 overflow-hidden">
<DialogHeader className="p-6 pb-2">
<DialogTitle className="text-xl font-bold text-foreground">{selectedA2aAgent?.name} - Agent Card</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-y-auto px-6 py-2">
{selectedA2aAgent?.agent_card ? (
<div className="grid gap-4">
{selectedA2aAgent.agent_card.description && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('description')}</Label>
<p className="text-sm text-foreground">{selectedA2aAgent.agent_card.description}</p>
</div>
)}
{selectedA2aAgent.agent_card.url && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">URL</Label>
<a href={selectedA2aAgent.agent_card.url} target="_blank" rel="noopener noreferrer" className="text-sm text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.url}</a>
</div>
)}
{selectedA2aAgent.agent_card.provider && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('provider')}</Label>
<div className="text-sm text-foreground">
{selectedA2aAgent.agent_card.provider.organization && <span>{selectedA2aAgent.agent_card.provider.organization}</span>}
{selectedA2aAgent.agent_card.provider.url && <span> - <a href={selectedA2aAgent.agent_card.provider.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.provider.url}</a></span>}
</div>
</div>
)}
{selectedA2aAgent.agent_card.skills && selectedA2aAgent.agent_card.skills.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('skills')}</Label>
<div className="space-y-2">
{selectedA2aAgent.agent_card.skills.map((skill, idx) => (
<div key={idx} className="p-3 bg-muted/50 rounded-lg">
<div className="font-medium text-sm">{skill.name}</div>
{skill.description && <p className="text-xs text-muted-foreground mt-1">{skill.description}</p>}
{skill.tags && skill.tags.length > 0 && (
<div className="flex flex-wrap gap-1 mt-2">
{skill.tags.map((tag, i) => (
<span key={i} className="px-2 py-0.5 bg-indigo-100 text-indigo-700 text-[10px] rounded-full">{tag}</span>
))}
</div>
)}
{skill.inputModes && skill.inputModes.length > 0 && (
<div className="mt-1 text-xs text-muted-foreground">Input: {skill.inputModes.join(', ')}</div>
)}
{skill.outputModes && skill.outputModes.length > 0 && (
<div className="mt-1 text-xs text-muted-foreground">Output: {skill.outputModes.join(', ')}</div>
)}
</div>
))}
</div>
</div>
)}
{selectedA2aAgent.agent_card.supportedInterfaces && selectedA2aAgent.agent_card.supportedInterfaces.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('supportedInterfaces')}</Label>
<div className="space-y-1">
{selectedA2aAgent.agent_card.supportedInterfaces.map((iface, idx) => (
<div key={idx} className="p-2 bg-muted/50 rounded text-xs">
<span className="font-medium">{iface.type}</span>
{iface.url && <span className="text-muted-foreground ml-2">{iface.url}</span>}
{iface.protocolVersion && <span className="text-muted-foreground ml-2">v{iface.protocolVersion}</span>}
</div>
))}
</div>
</div>
)}
{selectedA2aAgent.agent_card.defaultInputModes && selectedA2aAgent.agent_card.defaultInputModes.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('defaultInputModes')}</Label>
<div className="flex flex-wrap gap-1">
{selectedA2aAgent.agent_card.defaultInputModes.map((mode, idx) => (
<span key={idx} className="px-2 py-0.5 bg-muted rounded text-xs">{mode}</span>
))}
</div>
</div>
)}
{selectedA2aAgent.agent_card.defaultOutputModes && selectedA2aAgent.agent_card.defaultOutputModes.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('defaultOutputModes')}</Label>
<div className="flex flex-wrap gap-1">
{selectedA2aAgent.agent_card.defaultOutputModes.map((mode, idx) => (
<span key={idx} className="px-2 py-0.5 bg-muted rounded text-xs">{mode}</span>
))}
</div>
</div>
)}
{selectedA2aAgent.agent_card.iconUrl && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('iconUrl')}</Label>
<img src={selectedA2aAgent.agent_card.iconUrl} alt="Agent Icon" className="h-16 w-16 rounded-lg object-contain" />
</div>
)}
{selectedA2aAgent.agent_card.documentationUrl && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('documentationUrl')}</Label>
<a href={selectedA2aAgent.agent_card.documentationUrl} target="_blank" rel="noopener noreferrer" className="text-sm text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.documentationUrl}</a>
</div>
)}
{selectedA2aAgent.agent_card.security && selectedA2aAgent.agent_card.security.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('security')}</Label>
<div className="text-xs text-muted-foreground font-mono bg-muted/50 p-2 rounded">
{JSON.stringify(selectedA2aAgent.agent_card.security, null, 2)}
</div>
</div>
)}
</div>
) : (
<div className="text-center py-8 text-muted-foreground">
<p>{t('noAgentCardAvailable')}</p>
<p className="text-xs mt-2">{t('tryRefreshingCard')}</p>
</div>
)}
</div>
<DialogFooter className="p-6 pt-2">
<Button variant="outline" onClick={() => setSelectedA2aAgent(null)}>{t('close')}</Button>
{selectedA2aAgent && (
<Button onClick={() => void handleRefreshA2aCard(selectedA2aAgent.id)} className="gap-2">
<RefreshCw className="h-4 w-4" />
{t('refreshCard')}
</Button>
)}
</DialogFooter>
</DialogContent>
</Dialog>
<Dialog open={!!selectedTask} onOpenChange={(open) => { if (!open) { setSelectedTask(null); setTaskArtifactPreview(null); } }}>
<DialogContent className="sm:max-w-[900px] max-h-[90vh] flex flex-col rounded-2xl p-0 overflow-hidden">
<DialogHeader className="p-6 pb-2">
<DialogTitle className="text-xl font-bold text-foreground">{t('taskDetails')}</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-y-auto px-6 py-2">
{selectedTask && (
<div className="grid gap-4">
<div className="grid grid-cols-2 gap-4">
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">Task ID</Label>
<p className="text-sm font-mono break-all">{selectedTask.id}</p>
</div>
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">Context ID</Label>
<p className="text-sm font-mono break-all">{selectedTask.context_id || '-'}</p>
</div>
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">State</Label>
<p className="text-sm">{selectedTask.state}</p>
</div>
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">Source</Label>
<p className="text-sm">{selectedTask.source}</p>
</div>
</div>
{selectedTask.input_parts && selectedTask.input_parts.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('inputParts')}</Label>
<div className="bg-muted/50 rounded-lg p-3 space-y-2">
{selectedTask.input_parts.map((part, idx) => (
<div key={idx} className="text-sm">
{part.kind === 'text' && <p className="whitespace-pre-wrap">{part.text}</p>}
{part.kind === 'url' && <a href={part.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{part.url}</a>}
{part.kind === 'file' && (
<div className="flex items-center gap-2">
<span className="text-muted-foreground">[{part.filename || 'file'}]</span>
{part.mediaType?.startsWith('image/') && <span className="text-xs text-muted-foreground">({part.mediaType})</span>}
</div>
)}
</div>
))}
</div>
</div>
)}
{selectedTask.output_parts && selectedTask.output_parts.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('outputParts')}</Label>
<div className="bg-muted/50 rounded-lg p-3 space-y-2">
{selectedTask.output_parts.map((part, idx) => (
<div key={idx} className="text-sm">
{part.kind === 'text' && <p className="whitespace-pre-wrap">{part.text}</p>}
{part.kind === 'url' && <a href={part.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{part.url}</a>}
{part.kind === 'file' && (
<div className="flex items-center gap-2">
<span className="text-muted-foreground">[{part.filename || 'file'}]</span>
{part.mediaType?.startsWith('image/') && <span className="text-xs text-muted-foreground">({part.mediaType})</span>}
</div>
)}
</div>
))}
</div>
</div>
)}
{selectedTask.artifacts && selectedTask.artifacts.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('artifacts')} ({selectedTask.artifacts.length})</Label>
<div className="space-y-2">
{selectedTask.artifacts.map((artifact, idx) => (
<div key={idx} className="border border-border rounded-lg p-3">
<div className="flex items-center justify-between mb-2">
<div className="font-medium text-sm">
{artifact.name || `Artifact ${idx + 1}`}
</div>
<Button variant="ghost" size="sm" className="h-7 text-xs gap-1" onClick={() => handlePreviewArtifact(artifact)}>
<Eye className="h-3 w-3" />
{t('preview')}
</Button>
</div>
{artifact.description && (
<p className="text-xs text-muted-foreground mb-2">{artifact.description}</p>
)}
<div className="text-xs text-muted-foreground">
{artifact.parts.length} part(s)
</div>
</div>
))}
</div>
</div>
)}
{taskArtifactPreview && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('artifactPreview')}</Label>
<div className="border border-border rounded-lg p-3 bg-muted/50">
{taskArtifactPreview.type === 'text' && (
<pre className="text-xs whitespace-pre-wrap break-all max-h-[300px] overflow-auto">{taskArtifactPreview.content}</pre>
)}
{taskArtifactPreview.type === 'image' && (
<img src={taskArtifactPreview.content} alt="Artifact Preview" className="max-w-full max-h-[300px] rounded-lg object-contain" />
)}
{taskArtifactPreview.type === 'html' && (
<div className="text-xs text-muted-foreground italic">[HTML Preview - rendered separately]</div>
)}
{taskArtifactPreview.type === 'json' && (
<pre className="text-xs whitespace-pre-wrap break-all max-h-[300px] overflow-auto">{taskArtifactPreview.content}</pre>
)}
{taskArtifactPreview.type === 'unknown' && (
<p className="text-xs text-muted-foreground">{taskArtifactPreview.content}</p>
)}
</div>
</div>
)}
{selectedTask.error_message && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('errorMessage')}</Label>
<div className="bg-rose-50 border border-rose-100 rounded-lg p-3 text-sm text-rose-700">
{selectedTask.error_message}
</div>
</div>
)}
{selectedTask.history && selectedTask.history.length > 0 && (
<div className="grid gap-1.5">
<Label className="text-muted-foreground font-medium text-sm">{t('messageHistory')} ({selectedTask.history.length})</Label>
<div className="space-y-2 max-h-[300px] overflow-auto">
{selectedTask.history.map((msg, idx) => (
<div key={idx} className={`p-3 rounded-lg ${msg.role === 'user' ? 'bg-blue-50 border border-blue-100' : 'bg-green-50 border border-green-100'}`}>
<div className="flex items-center justify-between mb-1">
<span className="text-xs font-medium">{msg.role}</span>
{msg.messageId && <span className="text-[10px] text-muted-foreground font-mono">{msg.messageId.slice(0, 8)}...</span>}
</div>
<div className="text-xs">
{msg.parts.map((part, pIdx) => (
<div key={pIdx}>{renderPart(part)}</div>
))}
</div>
</div>
))}
</div>
</div>
)}
</div>
)}
</div>
<DialogFooter className="p-6 pt-2">
<Button variant="outline" onClick={() => { setSelectedTask(null); setTaskArtifactPreview(null); }}>{t('close')}</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
);
}