refactor: a2a
This commit is contained in:
+1091
-129
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,34 @@
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text, JSON, Enum as SQLEnum, func
|
||||
from sqlalchemy.orm import relationship
|
||||
import enum
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class A2ATaskState(str, enum.Enum):
|
||||
SUBMITTED = "SUBMITTED"
|
||||
WORKING = "WORKING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
INPUT_REQUIRED = "INPUT_REQUIRED"
|
||||
AUTH_REQUIRED = "AUTH_REQUIRED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class A2APartType(str, enum.Enum):
|
||||
TEXT = "text"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
DATA = "data"
|
||||
|
||||
|
||||
class A2AMessageRole(str, enum.Enum):
|
||||
USER = "user"
|
||||
AGENT = "agent"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class A2ARemoteAgent(Base):
|
||||
__tablename__ = "a2a_remote_agents"
|
||||
|
||||
@@ -13,6 +38,17 @@ class A2ARemoteAgent(Base):
|
||||
base_url = Column(String, nullable=False)
|
||||
auth_scheme = Column(String, nullable=False, default="none")
|
||||
auth_token = Column(String, nullable=True)
|
||||
shared_secret = Column(String, nullable=True)
|
||||
mtls_ca_cert = Column(Text, nullable=True)
|
||||
mtls_client_cert = Column(Text, nullable=True)
|
||||
mtls_client_key = Column(Text, nullable=True)
|
||||
oauth2_client_id = Column(String, nullable=True)
|
||||
oauth2_client_secret = Column(String, nullable=True)
|
||||
oauth2_token_url = Column(String, nullable=True)
|
||||
oauth2_scopes = Column(String, nullable=True)
|
||||
oidc_issuer_url = Column(String, nullable=True)
|
||||
oidc_client_id = Column(String, nullable=True)
|
||||
oidc_client_secret = Column(String, nullable=True)
|
||||
protocol_version = Column(String, nullable=True)
|
||||
capabilities_json = Column(Text, nullable=False, default="[]")
|
||||
card_json = Column(Text, nullable=True)
|
||||
@@ -27,27 +63,84 @@ class A2ARemoteAgent(Base):
|
||||
project = relationship("Project")
|
||||
|
||||
|
||||
class A2APart(Base):
|
||||
__tablename__ = "a2a_parts"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(Integer, ForeignKey("a2a_messages.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
artifact_id = Column(Integer, ForeignKey("a2a_artifacts.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
part_type = Column(SQLEnum(A2APartType), nullable=False)
|
||||
text_content = Column(Text, nullable=True)
|
||||
raw_content = Column(Text, nullable=True)
|
||||
url_content = Column(String, nullable=True)
|
||||
data_content = Column(Text, nullable=True)
|
||||
media_type = Column(String, nullable=True)
|
||||
filename = Column(String, nullable=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
|
||||
message = relationship("A2AMessage", back_populates="parts", foreign_keys=[message_id])
|
||||
artifact = relationship("A2AArtifact", back_populates="parts", foreign_keys=[artifact_id])
|
||||
|
||||
|
||||
class A2AMessage(Base):
|
||||
__tablename__ = "a2a_messages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
message_id = Column(String, nullable=False, unique=True, index=True)
|
||||
context_id = Column(String, nullable=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=True, index=True)
|
||||
role = Column(SQLEnum(A2AMessageRole), nullable=False)
|
||||
extensions_json = Column(Text, nullable=False, default="{}")
|
||||
reference_task_ids_json = Column(Text, nullable=False, default="[]")
|
||||
created_at = Column(DateTime, default=func.now(), index=True)
|
||||
|
||||
task = relationship("A2ATask", back_populates="messages", foreign_keys=[task_id])
|
||||
parts = relationship("A2APart", back_populates="message", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2AArtifact(Base):
|
||||
__tablename__ = "a2a_artifacts"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
artifact_id = Column(String, nullable=False, unique=True, index=True)
|
||||
task_id = Column(String, ForeignKey("a2a_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
name = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
extensions_json = Column(Text, nullable=False, default="{}")
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
task = relationship("A2ATask", back_populates="artifacts")
|
||||
parts = relationship("A2APart", back_populates="artifact", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2ATask(Base):
|
||||
__tablename__ = "a2a_tasks"
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
project_id = Column(Integer, ForeignKey("projects.id"), nullable=False, index=True)
|
||||
tenant_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
context_id = Column(String, nullable=True, index=True)
|
||||
source = Column(String, nullable=False, default="local")
|
||||
remote_agent_id = Column(Integer, ForeignKey("a2a_remote_agents.id"), nullable=True, index=True)
|
||||
idempotency_key = Column(String, nullable=True, index=True)
|
||||
state = Column(String, nullable=False, index=True, default="SUBMITTED")
|
||||
state = Column(SQLEnum(A2ATaskState), nullable=False, index=True, default=A2ATaskState.SUBMITTED)
|
||||
input_text = Column(Text, nullable=False, default="")
|
||||
output_text = Column(Text, nullable=True)
|
||||
error_message = Column(Text, nullable=True)
|
||||
compatibility_mode = Column(Boolean, nullable=False, default=True)
|
||||
metadata_json = Column(Text, nullable=False, default="{}")
|
||||
history_length = Column(Integer, nullable=False, default=0)
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
finished_at = Column(DateTime, nullable=True)
|
||||
|
||||
project = relationship("Project")
|
||||
remote_agent = relationship("A2ARemoteAgent")
|
||||
messages = relationship("A2AMessage", back_populates="task", cascade="all, delete-orphan", foreign_keys=[A2AMessage.task_id])
|
||||
artifacts = relationship("A2AArtifact", back_populates="task", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class A2ATaskEvent(Base):
|
||||
|
||||
@@ -0,0 +1,361 @@
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing import Optional, List, Dict, Any, Literal, Union
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class A2ATaskState(str, Enum):
|
||||
SUBMITTED = "SUBMITTED"
|
||||
WORKING = "WORKING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
INPUT_REQUIRED = "INPUT_REQUIRED"
|
||||
AUTH_REQUIRED = "AUTH_REQUIRED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class A2APartType(str, Enum):
|
||||
TEXT = "text"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
DATA = "data"
|
||||
|
||||
|
||||
class A2AMessageRole(str, Enum):
|
||||
USER = "user"
|
||||
AGENT = "agent"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class A2APartSchema(BaseModel):
|
||||
part_type: A2APartType
|
||||
text: Optional[str] = None
|
||||
raw: Optional[bytes] = None
|
||||
url: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
mediaType: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2APartCreateSchema(BaseModel):
|
||||
part_type: A2APartType
|
||||
text: Optional[str] = None
|
||||
raw: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
data: Optional[Any] = None
|
||||
mediaType: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2AMessageSchema(BaseModel):
|
||||
messageId: str
|
||||
contextId: Optional[str] = None
|
||||
taskId: Optional[str] = None
|
||||
role: A2AMessageRole
|
||||
parts: List[A2APartSchema] = Field(default_factory=list)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
|
||||
createdAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2AMessageCreateSchema(BaseModel):
|
||||
messageId: str
|
||||
contextId: Optional[str] = None
|
||||
taskId: Optional[str] = None
|
||||
role: A2AMessageRole
|
||||
parts: List[A2APartCreateSchema] = Field(default_factory=list)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
referenceTaskIds: Optional[List[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class A2AArtifactSchema(BaseModel):
|
||||
artifactId: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[A2APartSchema] = Field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
createdAt: Optional[datetime] = None
|
||||
updatedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2AArtifactCreateSchema(BaseModel):
|
||||
artifactId: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
parts: List[A2APartCreateSchema] = Field(default_factory=list)
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
extensions: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class A2ATaskStatusSchema(BaseModel):
|
||||
state: A2ATaskState
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class A2ATaskSchema(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
projectId: int
|
||||
tenantId: int
|
||||
source: str
|
||||
remoteAgentId: Optional[int] = None
|
||||
idempotencyKey: Optional[str] = None
|
||||
state: A2ATaskState
|
||||
inputText: str
|
||||
outputText: Optional[str] = None
|
||||
errorMessage: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
historyLength: int = 0
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
finishedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class A2ATaskWithMessagesSchema(A2ATaskSchema):
|
||||
messages: List[A2AMessageSchema] = Field(default_factory=list)
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
|
||||
|
||||
class A2ATaskWithHistorySchema(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
projectId: int
|
||||
tenantId: int
|
||||
state: A2ATaskState
|
||||
history: List[A2AMessageSchema] = Field(default_factory=list)
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
createdAt: datetime
|
||||
updatedAt: datetime
|
||||
finishedAt: Optional[datetime] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class TaskStatusUpdateEvent(BaseModel):
|
||||
taskId: str
|
||||
contextId: Optional[str] = None
|
||||
status: A2ATaskStatusSchema
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TaskArtifactUpdateEvent(BaseModel):
|
||||
taskId: str
|
||||
contextId: Optional[str] = None
|
||||
artifact: A2AArtifactSchema
|
||||
append: bool = False
|
||||
lastChunk: bool = True
|
||||
|
||||
|
||||
class TaskMessageEvent(BaseModel):
|
||||
message: A2AMessageSchema
|
||||
|
||||
|
||||
class StreamResponseTask(BaseModel):
|
||||
id: str
|
||||
contextId: Optional[str] = None
|
||||
state: A2ATaskState
|
||||
artifacts: List[A2AArtifactSchema] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
task: Optional[StreamResponseTask] = None
|
||||
message: Optional[A2AMessageSchema] = None
|
||||
statusUpdate: Optional[TaskStatusUpdateEvent] = None
|
||||
artifactUpdate: Optional[TaskArtifactUpdateEvent] = None
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
message: A2AMessageCreateSchema
|
||||
taskId: Optional[str] = None
|
||||
contextId: Optional[str] = None
|
||||
|
||||
|
||||
class SendStreamingMessageRequest(BaseModel):
|
||||
message: A2AMessageCreateSchema
|
||||
taskId: Optional[str] = None
|
||||
contextId: Optional[str] = None
|
||||
|
||||
|
||||
class GetTaskRequest(BaseModel):
|
||||
historyLength: Optional[int] = None
|
||||
|
||||
|
||||
class TaskListRequest(BaseModel):
|
||||
contextId: Optional[str] = None
|
||||
status: Optional[A2ATaskState] = None
|
||||
pageSize: int = 20
|
||||
pageToken: Optional[str] = None
|
||||
|
||||
|
||||
class CancelTaskRequest(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class PushNotificationConfigCreate(BaseModel):
|
||||
targetUrl: str
|
||||
secret: Optional[str] = None
|
||||
authHeader: Optional[str] = None
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
class PushNotificationConfig(BaseModel):
|
||||
id: int
|
||||
taskId: str
|
||||
targetUrl: str
|
||||
secret: Optional[str] = None
|
||||
authHeader: Optional[str] = None
|
||||
enabled: bool
|
||||
createdBy: int
|
||||
createdAt: datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class VersionNotSupportedError(BaseModel):
|
||||
code: int = -32009
|
||||
message: str = "Version not supported"
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class AgentSkillInputMode(str, Enum):
|
||||
TEXT = "text"
|
||||
DATA = "data"
|
||||
RAW = "raw"
|
||||
URL = "url"
|
||||
|
||||
|
||||
class AgentSkillOutputMode(str, Enum):
|
||||
TEXT = "text"
|
||||
DATA = "data"
|
||||
ARTIFACT = "artifact"
|
||||
STREAM = "stream"
|
||||
|
||||
|
||||
class AgentSkillSecurityRequirement(BaseModel):
|
||||
scheme: str
|
||||
scopes: Optional[List[str]] = None
|
||||
|
||||
|
||||
class AgentSkillExample(BaseModel):
|
||||
input: Dict[str, Any]
|
||||
output: Dict[str, Any]
|
||||
|
||||
|
||||
class AgentSkill(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
examples: List[AgentSkillExample] = Field(default_factory=list)
|
||||
inputModes: List[AgentSkillInputMode] = Field(default_factory=list)
|
||||
outputModes: List[AgentSkillOutputMode] = Field(default_factory=list)
|
||||
securityRequirements: List[AgentSkillSecurityRequirement] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentProvider(BaseModel):
|
||||
organization: str
|
||||
url: Optional[str] = None
|
||||
|
||||
|
||||
class AgentSupportedInterface(BaseModel):
|
||||
url: str
|
||||
protocolBinding: str
|
||||
protocolVersion: str
|
||||
tenant: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeApiKey(BaseModel):
|
||||
type: Literal["apiKey"] = "apiKey"
|
||||
name: str
|
||||
in_: str = Field(alias="in")
|
||||
description: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
|
||||
class SecuritySchemeHttpAuth(BaseModel):
|
||||
type: Literal["http"] = "http"
|
||||
scheme: str
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2AuthorizationCodeFlow(BaseModel):
|
||||
authorizationUrl: str
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
refreshUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2ClientCredentialsFlow(BaseModel):
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
refreshUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2DeviceCodeFlow(BaseModel):
|
||||
authorizationUrl: str
|
||||
tokenUrl: str
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
deviceAuthorizationUrl: Optional[str] = None
|
||||
|
||||
|
||||
class OAuth2Flows(BaseModel):
|
||||
authorizationCode: Optional[OAuth2AuthorizationCodeFlow] = None
|
||||
clientCredentials: Optional[OAuth2ClientCredentialsFlow] = None
|
||||
deviceCode: Optional[OAuth2DeviceCodeFlow] = None
|
||||
implicit: Optional[Dict[str, Any]] = None
|
||||
password: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class SecuritySchemeOAuth2(BaseModel):
|
||||
type: Literal["oauth2"] = "oauth2"
|
||||
flows: OAuth2Flows
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class SecuritySchemeOpenIdConnect(BaseModel):
|
||||
type: Literal["openIdConnect"] = "openIdConnect"
|
||||
openIdConnectUrl: str
|
||||
description: Optional[str] = None
|
||||
scopes: Dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SecuritySchemeMtls(BaseModel):
|
||||
type: Literal["mutualTLS"] = "mutualTLS"
|
||||
description: Optional[str] = None
|
||||
caCerts: Optional[List[str]] = None
|
||||
clientCert: Optional[str] = None
|
||||
clientKey: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCardPublicSchema(BaseModel):
|
||||
name: str
|
||||
protocol_version: str = "1.0"
|
||||
capabilities: List[str]
|
||||
endpoints: Dict[str, str]
|
||||
auth: List[str]
|
||||
skills: List[AgentSkill] = Field(default_factory=list)
|
||||
provider: Optional[AgentProvider] = None
|
||||
supportedInterfaces: List[AgentSupportedInterface] = Field(default_factory=list)
|
||||
defaultInputModes: List[str] = Field(default_factory=list)
|
||||
defaultOutputModes: List[str] = Field(default_factory=list)
|
||||
iconUrl: Optional[str] = None
|
||||
documentationUrl: Optional[str] = None
|
||||
|
||||
|
||||
class AgentCardExtendedSchema(AgentCardPublicSchema):
|
||||
securitySchemes: Optional[Dict[str, Union[SecuritySchemeApiKey, SecuritySchemeHttpAuth, SecuritySchemeOAuth2, SecuritySchemeOpenIdConnect, SecuritySchemeMtls]]] = None
|
||||
security: List[Dict[str, List[str]]] = Field(default_factory=list)
|
||||
signatures: List[str] = Field(default_factory=list)
|
||||
tenantId: Optional[int] = None
|
||||
isAdmin: Optional[bool] = None
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
@@ -24,20 +25,15 @@ from app.models.a2a import (
|
||||
A2AWebhookDelivery,
|
||||
)
|
||||
from app.models.project import Project
|
||||
from app.schemas.a2a import (
|
||||
A2AArtifactSchema,
|
||||
A2APartSchema,
|
||||
A2ATaskStatusSchema,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from app.trace import build_error_attributes, trace_service
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
def _json_loads(raw: Optional[str], default: Any) -> Any:
|
||||
if not raw:
|
||||
@@ -62,6 +58,292 @@ def _mask_error(message: str) -> str:
|
||||
return "request_failed"
|
||||
|
||||
|
||||
class SharedSecretAuth:
|
||||
@staticmethod
|
||||
def generate_signature(secret: str, payload: bytes, timestamp: Optional[int] = None) -> Tuple[str, int]:
|
||||
if timestamp is None:
|
||||
timestamp = int(time.time())
|
||||
message = f"{timestamp}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return f"sha256={signature}", timestamp
|
||||
|
||||
@staticmethod
|
||||
def verify_signature(secret: str, payload: bytes, signature: str, timestamp: int, max_age_seconds: int = 300) -> bool:
|
||||
if abs(time.time() - timestamp) > max_age_seconds:
|
||||
return False
|
||||
expected_sig, _ = SharedSecretAuth.generate_signature(secret, payload, timestamp)
|
||||
return hmac.compare_digest(signature, expected_sig)
|
||||
|
||||
@staticmethod
|
||||
def sign_request(secret: str, method: str, path: str, body: Optional[bytes] = None) -> Dict[str, str]:
|
||||
timestamp = int(time.time())
|
||||
payload = body or b""
|
||||
message = f"{timestamp}.{method.upper()}.{path}".encode() + payload
|
||||
signature = hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
|
||||
return {
|
||||
"X-A2A-Signature": f"sha256={signature}",
|
||||
"X-A2A-Timestamp": str(timestamp),
|
||||
}
|
||||
|
||||
|
||||
class MtlsConfig:
|
||||
def __init__(
|
||||
self,
|
||||
ca_cert: Optional[str] = None,
|
||||
client_cert: Optional[str] = None,
|
||||
client_key: Optional[str] = None,
|
||||
):
|
||||
self.ca_cert = ca_cert
|
||||
self.client_cert = client_cert
|
||||
self.client_key = client_key
|
||||
|
||||
def create_ssl_context(self) -> Optional[ssl.SSLContext]:
|
||||
if not self.client_cert or not self.client_key:
|
||||
return None
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.load_cert_chain(self.client_cert, self.client_key)
|
||||
if self.ca_cert:
|
||||
ctx.load_verify_locations(self.ca_cert)
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
else:
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
||||
|
||||
class OAuth2TokenStore:
|
||||
def __init__(self):
|
||||
self._tokens: Dict[str, Tuple[str, datetime]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_token(self, key: str) -> Optional[str]:
|
||||
async with self._lock:
|
||||
if key in self._tokens:
|
||||
token, expires_at = self._tokens[key]
|
||||
if expires_at > _utc_now() + timedelta(minutes=1):
|
||||
return token
|
||||
return None
|
||||
|
||||
async def set_token(self, key: str, token: str, expires_in: int = 3600) -> None:
|
||||
async with self._lock:
|
||||
self._tokens[key] = (token, _utc_now() + timedelta(seconds=expires_in))
|
||||
|
||||
|
||||
class OAuth2Auth:
|
||||
def __init__(
|
||||
self,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
token_url: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.token_url = token_url
|
||||
self.scopes = scopes or []
|
||||
self._token_store = OAuth2TokenStore()
|
||||
|
||||
def _get_cache_key(self) -> str:
|
||||
return f"{self.client_id}:{self.token_url}:{':'.join(self.scopes)}"
|
||||
|
||||
async def get_access_token(self, grant_type: str = "client_credentials") -> str:
|
||||
cache_key = self._get_cache_key()
|
||||
cached = await self._token_store.get_token(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"grant_type": grant_type,
|
||||
}
|
||||
if self.scopes:
|
||||
data["scope"] = " ".join(self.scopes)
|
||||
resp = await client.post(self.token_url, data=data)
|
||||
resp.raise_for_status()
|
||||
token_data = resp.json()
|
||||
token = token_data["access_token"]
|
||||
expires_in = token_data.get("expires_in", 3600)
|
||||
await self._token_store.set_token(cache_key, token, expires_in)
|
||||
return token
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class OIDCAuth:
|
||||
def __init__(
|
||||
self,
|
||||
issuer_url: str,
|
||||
client_id: str,
|
||||
client_secret: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
):
|
||||
self.issuer_url = issuer_url.rstrip("/")
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.scopes = scopes or ["openid", "profile"]
|
||||
self._oauth2: Optional[OAuth2Auth] = None
|
||||
self._discovery_cache: Optional[Dict[str, Any]] = None
|
||||
|
||||
async def _get_discovery(self) -> Dict[str, Any]:
|
||||
if self._discovery_cache:
|
||||
return self._discovery_cache
|
||||
discovery_url = f"{self.issuer_url}/.well-known/openid-configuration"
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.get(discovery_url)
|
||||
resp.raise_for_status()
|
||||
self._discovery_cache = resp.json()
|
||||
return self._discovery_cache
|
||||
|
||||
async def get_access_token(self) -> str:
|
||||
discovery = await self._get_discovery()
|
||||
token_url = discovery.get("token_endpoint")
|
||||
if not token_url:
|
||||
raise RuntimeError("OIDC discovery missing token_endpoint")
|
||||
if not self._oauth2:
|
||||
self._oauth2 = OAuth2Auth(
|
||||
client_id=self.client_id,
|
||||
client_secret=self.client_secret or "",
|
||||
token_url=token_url,
|
||||
scopes=self.scopes,
|
||||
)
|
||||
return await self._oauth2.get_access_token()
|
||||
|
||||
async def authorize_request(self, method: str, url: str, **kwargs) -> Dict[str, str]:
|
||||
token = await self.get_access_token()
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
class RemoteAgentSecuritySelector:
|
||||
def __init__(self, agent: A2ARemoteAgent):
|
||||
self.agent = agent
|
||||
self._card_security_schemes: Optional[Dict[str, Any]] = None
|
||||
|
||||
def load_security_from_card(self) -> None:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
if card:
|
||||
self._card_security_schemes = card.get("securitySchemes", {})
|
||||
|
||||
def get_preferred_auth_scheme(self) -> str:
|
||||
card = _json_loads(self.agent.card_json, {})
|
||||
security_reqs = card.get("security", [])
|
||||
if security_reqs:
|
||||
first_req = security_reqs[0]
|
||||
if isinstance(first_req, dict):
|
||||
for scheme_name in first_req.keys():
|
||||
return scheme_name
|
||||
return self.agent.auth_scheme or "bearer"
|
||||
|
||||
def get_auth_headers(self, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "bearer" or self.agent.auth_scheme == "bearer":
|
||||
if self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
elif user_token:
|
||||
headers["Authorization"] = f"Bearer {user_token}"
|
||||
|
||||
elif preferred == "shared_secret" or self.agent.auth_scheme == "shared_secret":
|
||||
pass
|
||||
|
||||
elif preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
pass
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
pass
|
||||
|
||||
elif preferred == "mutualTLS":
|
||||
pass
|
||||
|
||||
return headers
|
||||
|
||||
def get_mtls_context(self) -> Optional[ssl.SSLContext]:
|
||||
if self.agent.auth_scheme == "mutualTLS" or self.get_preferred_auth_scheme() == "mutualTLS":
|
||||
if self.agent.mtls_client_cert and self.agent.mtls_client_key:
|
||||
config = MtlsConfig(
|
||||
ca_cert=self.agent.mtls_ca_cert,
|
||||
client_cert=self.agent.mtls_client_cert,
|
||||
client_key=self.agent.mtls_client_key,
|
||||
)
|
||||
return config.create_ssl_context()
|
||||
return None
|
||||
|
||||
def create_signed_request_headers(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
body: Optional[bytes] = None,
|
||||
) -> Dict[str, str]:
|
||||
headers: Dict[str, str] = {}
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred == "shared_secret" and self.agent.shared_secret:
|
||||
sig_headers = SharedSecretAuth.sign_request(
|
||||
self.agent.shared_secret,
|
||||
method,
|
||||
path,
|
||||
body,
|
||||
)
|
||||
headers.update(sig_headers)
|
||||
|
||||
elif self.agent.auth_scheme == "bearer" and self.agent.auth_token:
|
||||
headers["Authorization"] = f"Bearer {self.agent.auth_token}"
|
||||
|
||||
return headers
|
||||
|
||||
async def get_oauth2_auth(self) -> Optional[OAuth2Auth]:
|
||||
if self.agent.oauth2_client_id and self.agent.oauth2_token_url:
|
||||
scopes = self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else []
|
||||
return OAuth2Auth(
|
||||
client_id=self.agent.oauth2_client_id,
|
||||
client_secret=self.agent.oauth2_client_secret or "",
|
||||
token_url=self.agent.oauth2_token_url,
|
||||
scopes=scopes,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_oidc_auth(self) -> Optional[OIDCAuth]:
|
||||
if self.agent.oidc_issuer_url:
|
||||
return OIDCAuth(
|
||||
issuer_url=self.agent.oidc_issuer_url,
|
||||
client_id=self.agent.oidc_client_id or "",
|
||||
client_secret=self.agent.oidc_client_secret,
|
||||
scopes=self.agent.oauth2_scopes.split() if self.agent.oauth2_scopes else [],
|
||||
)
|
||||
return None
|
||||
|
||||
async def authorize_request(self, method: str, url: str, user_token: Optional[str] = None) -> Dict[str, str]:
|
||||
headers = self.get_auth_headers(user_token)
|
||||
preferred = self.get_preferred_auth_scheme()
|
||||
|
||||
if preferred in ("oauth2", "oauth2_authorizationcode", "oauth2_clientcredentials"):
|
||||
oauth2_auth = await self.get_oauth2_auth()
|
||||
if oauth2_auth:
|
||||
headers.update(await oauth2_auth.authorize_request(method, url))
|
||||
|
||||
elif preferred == "openIdConnect":
|
||||
oidc_auth = await self.get_oidc_auth()
|
||||
if oidc_auth:
|
||||
headers.update(await oidc_auth.authorize_request(method, url))
|
||||
|
||||
return headers
|
||||
|
||||
_STATE_TRANSITIONS = {
|
||||
"SUBMITTED": {"WORKING", "FAILED", "CANCELED", "REJECTED", "AUTH_REQUIRED", "INPUT_REQUIRED", "COMPLETED"},
|
||||
"WORKING": {"COMPLETED", "FAILED", "CANCELED", "INPUT_REQUIRED", "AUTH_REQUIRED"},
|
||||
"INPUT_REQUIRED": {"WORKING", "FAILED", "CANCELED"},
|
||||
"AUTH_REQUIRED": {"WORKING", "FAILED", "CANCELED", "REJECTED"},
|
||||
"REJECTED": set(),
|
||||
"FAILED": set(),
|
||||
"COMPLETED": set(),
|
||||
"CANCELED": set(),
|
||||
}
|
||||
_TERMINAL_STATES = {"COMPLETED", "FAILED", "CANCELED", "REJECTED"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class A2AResolvedRoute:
|
||||
selected: str
|
||||
@@ -185,6 +467,7 @@ class A2ARuntime:
|
||||
remote_agent_id: Optional[int],
|
||||
compatibility_mode: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
context_id: Optional[str] = None,
|
||||
) -> A2ATask:
|
||||
if idempotency_key:
|
||||
existing = (
|
||||
@@ -209,6 +492,7 @@ class A2ARuntime:
|
||||
idempotency_key=idempotency_key,
|
||||
compatibility_mode=compatibility_mode,
|
||||
metadata_json=_json_dumps(metadata or {}),
|
||||
context_id=context_id,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
@@ -335,23 +619,21 @@ class A2ARuntime:
|
||||
|
||||
async def _deliver_once(self, db: Session, hook: A2ATaskWebhook, event: A2ATaskEvent, delivery: A2AWebhookDelivery) -> None:
|
||||
event_payload = _json_loads(event.payload_json, {})
|
||||
request_payload = {
|
||||
"task_id": event.task_id,
|
||||
"event_type": event.event_type,
|
||||
"event_id": event.id,
|
||||
"payload": event_payload,
|
||||
}
|
||||
body = _json_dumps(request_payload).encode("utf-8")
|
||||
stream_response_payload = self._build_stream_response_payload(event, event_payload)
|
||||
body = _json_dumps(stream_response_payload).encode("utf-8")
|
||||
|
||||
for attempt in range(1, 5):
|
||||
delivery.attempt = attempt
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
|
||||
headers = {"Content-Type": "application/json", "X-A2A-Event-Id": str(event.id)}
|
||||
if hook.secret:
|
||||
digest = hmac.new(hook.secret.encode("utf-8"), body, hashlib.sha256).hexdigest()
|
||||
headers["X-A2A-Signature"] = f"sha256={digest}"
|
||||
if hook.auth_header:
|
||||
headers["Authorization"] = hook.auth_header
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0, verify=True) as client:
|
||||
resp = await client.post(hook.target_url, content=body, headers=headers)
|
||||
@@ -368,11 +650,12 @@ class A2ARuntime:
|
||||
except Exception as exc:
|
||||
delivery.error_message = str(exc)[:500]
|
||||
if attempt < 4:
|
||||
backoff_seconds = 2 ** attempt
|
||||
delivery.status = "RETRYING"
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=2 ** attempt)
|
||||
delivery.next_retry_at = _utc_now() + timedelta(seconds=backoff_seconds)
|
||||
db.add(delivery)
|
||||
db.commit()
|
||||
await asyncio.sleep(2 ** attempt)
|
||||
await asyncio.sleep(backoff_seconds)
|
||||
continue
|
||||
delivery.status = "FAILED"
|
||||
delivery.dead_letter = True
|
||||
@@ -380,5 +663,42 @@ class A2ARuntime:
|
||||
db.commit()
|
||||
return
|
||||
|
||||
def _build_stream_response_payload(self, event: A2ATaskEvent, event_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
event_type = event.event_type
|
||||
task_id = event_payload.get("task_id", event.task_id)
|
||||
|
||||
if event_type == "TaskStatusUpdateEvent":
|
||||
status_state = event_payload.get("task_status", "WORKING")
|
||||
status_timestamp = event_payload.get("timestamp", _utc_now().isoformat())
|
||||
status_schema = A2ATaskStatusSchema(
|
||||
state=status_state,
|
||||
timestamp=datetime.fromisoformat(status_timestamp) if isinstance(status_timestamp, str) else status_timestamp,
|
||||
)
|
||||
return {
|
||||
"statusUpdate": TaskStatusUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
status=status_schema,
|
||||
metadata=event_payload.get("metadata", {}),
|
||||
).model_dump()
|
||||
}
|
||||
elif event_type == "TaskArtifactUpdateEvent":
|
||||
artifact_content = event_payload.get("artifact", {}).get("content", "")
|
||||
artifact_schema = A2AArtifactSchema(
|
||||
artifactId=f"artifact-{event.id}",
|
||||
parts=[A2APartSchema(part_type="text", text=artifact_content)],
|
||||
)
|
||||
return {
|
||||
"artifactUpdate": TaskArtifactUpdateEvent(
|
||||
taskId=task_id,
|
||||
contextId=event_payload.get("context_id"),
|
||||
artifact=artifact_schema,
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
).model_dump()
|
||||
}
|
||||
else:
|
||||
return {"message": event_payload}
|
||||
|
||||
|
||||
a2a_runtime = A2ARuntime()
|
||||
|
||||
+698
-145
@@ -1,7 +1,13 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -20,10 +26,11 @@ if str(NANOBOT_ROOT) not in sys.path:
|
||||
|
||||
from app.core.security import CurrentUser, get_current_user
|
||||
from app.database import Base, get_db
|
||||
from app.models.a2a import A2ARemoteAgent
|
||||
from app.models.a2a import A2ARemoteAgent, A2ATask, A2ATaskState
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.services.a2a_service import a2a_runtime
|
||||
from app.schemas.a2a import A2AMessageRole, A2APartType, AgentSkillOutputMode, AgentSkillInputMode
|
||||
from app.services.a2a_service import a2a_runtime, SharedSecretAuth
|
||||
from main import app
|
||||
|
||||
|
||||
@@ -42,144 +49,83 @@ def _seed(db: Session) -> tuple[int, str, int, str, int]:
|
||||
return owner.id, owner.username, other.id, other.username, project.id
|
||||
|
||||
|
||||
def test_a2a_send_list_cancel_and_rollout() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
send_resp = client.post(
|
||||
"/api/v1/a2a/messages/send",
|
||||
json={
|
||||
def _make_message_payload(project_id: int, text: str, session_id: str = "test-session", route_mode: str = "local_first", idempotency_key: Optional[str] = None) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {
|
||||
"message": {
|
||||
"messageId": f"msg-{int(time.time()*1000)}",
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{
|
||||
"part_type": "data",
|
||||
"data": {
|
||||
"project_id": project_id,
|
||||
"message": "hello a2a",
|
||||
"session_id": "test-a2a-session",
|
||||
"route_mode": "local_first",
|
||||
"route_mode": route_mode,
|
||||
"session_id": session_id,
|
||||
**( {"idempotency_key": idempotency_key} if idempotency_key else {} )
|
||||
},
|
||||
"mediaType": "application/json",
|
||||
},
|
||||
{
|
||||
"part_type": "text",
|
||||
"text": text,
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
class TestPartSerialization:
|
||||
def test_part_text_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.TEXT,
|
||||
text="Hello world",
|
||||
mediaType="text/plain",
|
||||
)
|
||||
assert send_resp.status_code == 200
|
||||
task_id = send_resp.json()["task"]["id"]
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "text"
|
||||
assert data["text"] == "Hello world"
|
||||
assert data["mediaType"] == "text/plain"
|
||||
|
||||
get_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
|
||||
assert get_resp.status_code == 200
|
||||
assert get_resp.json()["project_id"] == project_id
|
||||
|
||||
list_resp = client.get("/api/v1/a2a/tasks", params={"project_id": project_id})
|
||||
assert list_resp.status_code == 200
|
||||
assert any(item["id"] == task_id for item in list_resp.json())
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/a2a/tasks/{task_id}/cancel")
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["state"] in {"CANCELED", "COMPLETED", "FAILED"}
|
||||
|
||||
rollout_resp = client.put(
|
||||
f"/api/v1/a2a/projects/{project_id}/rollout",
|
||||
json={"canary_enabled": True, "canary_percent": 30, "route_mode_default": "a2a_first", "fallback_chain": ["a2a", "local"]},
|
||||
def test_part_data_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.DATA,
|
||||
data={"project_id": 123, "route_mode": "local"},
|
||||
mediaType="application/json",
|
||||
)
|
||||
assert rollout_resp.status_code == 200
|
||||
assert rollout_resp.json()["canary_enabled"] is True
|
||||
assert rollout_resp.json()["canary_percent"] == 30
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "data"
|
||||
assert data["data"]["project_id"] == 123
|
||||
|
||||
|
||||
def test_a2a_task_tenant_isolation() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
send_resp = client.post(
|
||||
"/api/v1/a2a/messages/send",
|
||||
json={"project_id": project_id, "message": "tenant isolation", "session_id": "tenant-isolation", "route_mode": "local"},
|
||||
def test_part_url_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.URL,
|
||||
url="https://example.com/file.pdf",
|
||||
mediaType="application/pdf",
|
||||
filename="file.pdf",
|
||||
)
|
||||
assert send_resp.status_code == 200
|
||||
task_id = send_resp.json()["task"]["id"]
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "url"
|
||||
assert data["url"] == "https://example.com/file.pdf"
|
||||
assert data["filename"] == "file.pdf"
|
||||
|
||||
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
|
||||
forbidden_resp = client.get(f"/api/v1/a2a/tasks/{task_id}")
|
||||
assert forbidden_resp.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
def test_part_raw_serialization(self):
|
||||
from app.schemas.a2a import A2APartCreateSchema
|
||||
part = A2APartCreateSchema(
|
||||
part_type=A2APartType.RAW,
|
||||
raw="\x00\x01\x02\x03",
|
||||
mediaType="application/octet-stream",
|
||||
)
|
||||
data = part.model_dump()
|
||||
assert data["part_type"] == "raw"
|
||||
assert data["raw"] == "\x00\x01\x02\x03"
|
||||
|
||||
|
||||
def test_a2a_metrics_admin_only() -> None:
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, _ = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
denied = client.get("/api/v1/a2a/metrics")
|
||||
assert denied.status_code == 403
|
||||
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
|
||||
ok = client.get("/api/v1/a2a/metrics")
|
||||
assert ok.status_code == 200
|
||||
assert "counters" in ok.json()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_send_idempotency_key_deduplicates_task() -> None:
|
||||
class TestStateMachine:
|
||||
def test_state_transitions_submit_to_complete(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -203,25 +149,518 @@ def test_a2a_send_idempotency_key_deduplicates_task() -> None:
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
payload = {
|
||||
"project_id": project_id,
|
||||
"message": "dedupe-task",
|
||||
"session_id": "idempotency-session",
|
||||
"route_mode": "local_first",
|
||||
"idempotency_key": "same-key-1",
|
||||
}
|
||||
first_resp = client.post("/api/v1/a2a/messages/send", json=payload)
|
||||
second_resp = client.post("/api/v1/a2a/messages/send", json=payload)
|
||||
assert first_resp.status_code == 200
|
||||
assert second_resp.status_code == 200
|
||||
assert first_resp.json()["task"]["id"] == second_resp.json()["task"]["id"]
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "test state machine"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
assert task.state == A2ATaskState.SUBMITTED
|
||||
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.WORKING)
|
||||
assert task.state == A2ATaskState.WORKING
|
||||
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.COMPLETED)
|
||||
assert task.state == A2ATaskState.COMPLETED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_cancel(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["state"] == "CANCELED"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_failed(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "fail test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.FAILED, error_message='{"message": "test error"}')
|
||||
assert task.state == A2ATaskState.FAILED
|
||||
assert task.error_message is not None
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_rejected(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "reject test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.REJECTED)
|
||||
assert task.state == A2ATaskState.REJECTED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_input_required(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "input required test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.INPUT_REQUIRED)
|
||||
assert task.state == A2ATaskState.INPUT_REQUIRED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_state_auth_required(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "auth required test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
task = db.query(A2ATask).filter(A2ATask.id == task_id).first()
|
||||
task = a2a_runtime.transition_task(db, task, to_state=A2ATaskState.AUTH_REQUIRED)
|
||||
assert task.state == A2ATaskState.AUTH_REQUIRED
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) -> None:
|
||||
class TestA2APathNormalization:
|
||||
def test_message_send_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "path test"))
|
||||
assert resp.status_code == 200
|
||||
assert "task" in resp.json()
|
||||
assert "id" in resp.json()["task"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_message_stream_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:stream", json=_make_message_payload(project_id, "stream path test"))
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["content-type"].startswith("text/event-stream")
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_tasks_cancel_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "cancel path test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
cancel_resp = client.post(f"/api/v1/tasks/{task_id}:cancel", json={})
|
||||
assert cancel_resp.status_code == 200
|
||||
assert cancel_resp.json()["task_id"] == task_id
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_agent_card_public_path(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
db.close()
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.get("/api/v1/.well-known/agent-card.json")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "name" in data
|
||||
assert "protocol_version" in data
|
||||
assert "endpoints" in data
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestVersionNegotiation:
|
||||
def test_version_not_supported_error(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post(
|
||||
"/api/v1/message:send",
|
||||
json=_make_message_payload(project_id, "version test"),
|
||||
headers={"A2A-Version": "2.0"}
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
detail = json.loads(resp.json()["detail"])
|
||||
assert detail["code"] == -32009
|
||||
assert "not supported" in detail["message"].lower()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_version_response_header(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post(
|
||||
"/api/v1/message:send",
|
||||
json=_make_message_payload(project_id, "version header test"),
|
||||
headers={"A2A-Version": "1.0"}
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers.get("A2A-Version") == "1.0"
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestWebhookStreamResponse:
|
||||
def test_webhook_payload_format(self):
|
||||
from app.schemas.a2a import StreamResponse, TaskStatusUpdateEvent, TaskArtifactUpdateEvent, TaskMessageEvent, A2ATaskStatusSchema, A2ATaskState, A2AArtifactSchema
|
||||
|
||||
status_event = TaskStatusUpdateEvent(
|
||||
taskId="task-123",
|
||||
contextId="ctx-456",
|
||||
status=A2ATaskStatusSchema(
|
||||
state=A2ATaskState.SUBMITTED,
|
||||
timestamp=datetime.utcnow(),
|
||||
),
|
||||
metadata={},
|
||||
)
|
||||
status_dump = status_event.model_dump()
|
||||
assert "taskId" in status_dump
|
||||
assert status_dump["taskId"] == "task-123"
|
||||
assert status_dump["status"]["state"] == "SUBMITTED"
|
||||
|
||||
artifact_event = TaskArtifactUpdateEvent(
|
||||
taskId="task-123",
|
||||
contextId="ctx-456",
|
||||
artifact=A2AArtifactSchema(
|
||||
artifactId="art-789",
|
||||
parts=[],
|
||||
),
|
||||
append=False,
|
||||
lastChunk=True,
|
||||
)
|
||||
artifact_dump = artifact_event.model_dump()
|
||||
assert "taskId" in artifact_dump
|
||||
assert artifact_dump["artifact"]["artifactId"] == "art-789"
|
||||
|
||||
def test_stream_response_task_field(self):
|
||||
from app.schemas.a2a import StreamResponse, StreamResponseTask, A2ATaskState
|
||||
resp = StreamResponse(
|
||||
task=StreamResponseTask(
|
||||
id="task-123",
|
||||
contextId="ctx-456",
|
||||
state=A2ATaskState.WORKING,
|
||||
artifacts=[],
|
||||
)
|
||||
)
|
||||
data = resp.model_dump()
|
||||
assert "task" in data
|
||||
assert data["task"]["id"] == "task-123"
|
||||
|
||||
|
||||
class TestSSEFIFO:
|
||||
def test_sse_event_order(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
with client.stream("POST", "/api/v1/message:stream", json=_make_message_payload(project_id, "fifo test")) as resp:
|
||||
assert resp.status_code == 200
|
||||
chunks = []
|
||||
for line in resp.iter_lines():
|
||||
if line.startswith("data: "):
|
||||
chunks.append(json.loads(line[6:]))
|
||||
|
||||
event_types = [c.get("type") for c in chunks if "type" in c]
|
||||
status_idx = next((i for i, t in enumerate(event_types) if t == "TaskStatusUpdateEvent"), -1)
|
||||
assert status_idx >= 0
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestAuthSchemes:
|
||||
def test_shared_secret_auth(self):
|
||||
secret = "test-secret-key-12345"
|
||||
timestamp = int(time.time())
|
||||
body = b'{"test":"data"}'
|
||||
sig, _ = SharedSecretAuth.generate_signature(secret, body, timestamp)
|
||||
assert sig.startswith("sha256=")
|
||||
|
||||
assert SharedSecretAuth.verify_signature(secret, body, sig, timestamp) is True
|
||||
|
||||
def test_auth_scheme_none(self):
|
||||
from app.schemas.a2a import SecuritySchemeHttpAuth
|
||||
scheme = SecuritySchemeHttpAuth(scheme="bearer", description="Bearer auth")
|
||||
assert scheme.scheme == "bearer"
|
||||
|
||||
|
||||
class TestExceptionPaths:
|
||||
def test_auth_failure_marks_agent_unhealthy(self, monkeypatch):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -272,8 +711,7 @@ def test_a2a_fetch_agent_card_auth_failure_marks_agent_unhealthy(monkeypatch) ->
|
||||
db.close()
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> None:
|
||||
def test_remote_unavailable_opens_circuit(self, monkeypatch):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -317,3 +755,118 @@ def test_a2a_fetch_agent_card_remote_unavailable_opens_circuit(monkeypatch) -> N
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
db.close()
|
||||
engine.dispose()
|
||||
|
||||
def test_idempotency_key_deduplicates_task(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
idempotency_key = f"idem-key-{int(time.time())}"
|
||||
|
||||
payload1 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
|
||||
resp1 = client.post("/api/v1/message:send", json=payload1)
|
||||
assert resp1.status_code == 200
|
||||
|
||||
payload2 = _make_message_payload(project_id, "dedupe test", idempotency_key=idempotency_key)
|
||||
payload2["message"]["messageId"] = f"msg-{int(time.time()*1000) + 1}"
|
||||
resp2 = client.post("/api/v1/message:send", json=payload2)
|
||||
|
||||
assert resp2.status_code == 200
|
||||
assert resp1.json()["task"]["id"] == resp2.json()["task"]["id"]
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
def test_tenant_isolation(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, other_id, other_username, project_id = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
resp = client.post("/api/v1/message:send", json=_make_message_payload(project_id, "isolation test"))
|
||||
assert resp.status_code == 200
|
||||
task_id = resp.json()["task"]["id"]
|
||||
|
||||
state["user"] = CurrentUser(id=other_id, username=other_username, is_admin=False)
|
||||
get_resp = client.get(f"/api/v1/tasks/{task_id}")
|
||||
assert get_resp.status_code == 404
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
|
||||
class TestMetricsAdminOnly:
|
||||
def test_metrics_admin_only(self):
|
||||
engine = create_engine("sqlite://", connect_args={"check_same_thread": False}, poolclass=StaticPool)
|
||||
testing_session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = testing_session_local()
|
||||
owner_id, owner_username, _, _, _ = _seed(db)
|
||||
db.close()
|
||||
|
||||
state = {"user": CurrentUser(id=owner_id, username=owner_username, is_admin=False)}
|
||||
|
||||
def override_get_db() -> Generator[Session, None, None]:
|
||||
override_db = testing_session_local()
|
||||
try:
|
||||
yield override_db
|
||||
finally:
|
||||
override_db.close()
|
||||
|
||||
def override_current_user() -> CurrentUser:
|
||||
return state["user"]
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
app.dependency_overrides[get_current_user] = override_current_user
|
||||
try:
|
||||
client = TestClient(app)
|
||||
denied = client.get("/api/v1/a2a/metrics")
|
||||
assert denied.status_code == 403
|
||||
|
||||
state["user"] = CurrentUser(id=owner_id, username=owner_username, is_admin=True)
|
||||
ok = client.get("/api/v1/a2a/metrics")
|
||||
assert ok.status_code == 200
|
||||
assert "counters" in ok.json()
|
||||
finally:
|
||||
app.dependency_overrides.clear()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
engine.dispose()
|
||||
|
||||
+258
-28
@@ -1,5 +1,100 @@
|
||||
import { api } from "@/lib/api";
|
||||
|
||||
export interface A2APartText {
|
||||
kind: "text";
|
||||
text: string;
|
||||
}
|
||||
|
||||
export interface A2APartUrl {
|
||||
kind: "url";
|
||||
url: string;
|
||||
}
|
||||
|
||||
export interface A2APartFile {
|
||||
kind: "file";
|
||||
data: string;
|
||||
mediaType?: string;
|
||||
filename?: string;
|
||||
}
|
||||
|
||||
export type A2APart = A2APartText | A2APartUrl | A2APartFile;
|
||||
|
||||
export interface A2AMessage {
|
||||
messageId?: string;
|
||||
contextId?: string;
|
||||
taskId?: string;
|
||||
role: "user" | "agent" | "system";
|
||||
parts: A2APart[];
|
||||
extensions?: Record<string, unknown>[];
|
||||
referenceTaskIds?: string[];
|
||||
}
|
||||
|
||||
export interface A2AArtifact {
|
||||
artifactId?: string;
|
||||
name?: string;
|
||||
description?: string;
|
||||
parts: A2APart[];
|
||||
metadata?: Record<string, unknown>;
|
||||
extensions?: Record<string, unknown>[];
|
||||
}
|
||||
|
||||
export interface A2ATask {
|
||||
id: string;
|
||||
project_id?: number;
|
||||
context_id?: string;
|
||||
source: string;
|
||||
state: string;
|
||||
remote_agent_id?: number | null;
|
||||
input_text: string;
|
||||
input_parts?: A2APart[];
|
||||
output_text?: string | null;
|
||||
output_parts?: A2APart[];
|
||||
error_message?: string | null;
|
||||
compatibility_mode: boolean;
|
||||
metadata: Record<string, unknown>;
|
||||
artifacts?: A2AArtifact[];
|
||||
history?: A2AMessage[];
|
||||
history_length?: number;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
finished_at?: string | null;
|
||||
}
|
||||
|
||||
export interface A2AAgentCard {
|
||||
id?: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
url?: string;
|
||||
provider?: {
|
||||
organization?: string;
|
||||
url?: string;
|
||||
};
|
||||
skills?: Array<{
|
||||
id?: string;
|
||||
name: string;
|
||||
description?: string;
|
||||
tags?: string[];
|
||||
examples?: string[];
|
||||
inputModes?: string[];
|
||||
outputModes?: string[];
|
||||
securityRequirements?: Array<Record<string, unknown>>;
|
||||
}>;
|
||||
supportedInterfaces?: Array<{
|
||||
type: string;
|
||||
url?: string;
|
||||
protocolBinding?: string;
|
||||
protocolVersion?: string;
|
||||
tenant?: string;
|
||||
}>;
|
||||
defaultInputModes?: string[];
|
||||
defaultOutputModes?: string[];
|
||||
securitySchemes?: Record<string, unknown>;
|
||||
security?: Array<Record<string, unknown>>;
|
||||
signatures?: string[];
|
||||
iconUrl?: string;
|
||||
documentationUrl?: string;
|
||||
}
|
||||
|
||||
export interface A2ARemoteAgent {
|
||||
id: number;
|
||||
project_id: number;
|
||||
@@ -12,27 +107,12 @@ export interface A2ARemoteAgent {
|
||||
failure_count: number;
|
||||
circuit_open_until?: string | null;
|
||||
card_fetched_at?: string | null;
|
||||
}
|
||||
|
||||
export interface A2ATask {
|
||||
id: string;
|
||||
project_id: number;
|
||||
source: string;
|
||||
state: string;
|
||||
remote_agent_id?: number | null;
|
||||
input_text: string;
|
||||
output_text?: string | null;
|
||||
error_message?: string | null;
|
||||
compatibility_mode: boolean;
|
||||
metadata: Record<string, unknown>;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
finished_at?: string | null;
|
||||
agent_card?: A2AAgentCard;
|
||||
}
|
||||
|
||||
export interface A2ASendMessagePayload {
|
||||
project_id: number;
|
||||
message: string;
|
||||
message: A2AMessage;
|
||||
session_id?: string;
|
||||
remote_agent_id?: number;
|
||||
route_mode?: "auto" | "local" | "a2a" | "a2a_first" | "local_first" | "mcp_first";
|
||||
@@ -55,11 +135,13 @@ export interface A2ASubscribeEvent {
|
||||
type?: string;
|
||||
event?: string;
|
||||
task_id?: string;
|
||||
context_id?: string;
|
||||
task_status?: string;
|
||||
status?: string;
|
||||
artifact?: {
|
||||
content?: string;
|
||||
};
|
||||
artifact?: A2AArtifact;
|
||||
append?: boolean;
|
||||
last_chunk?: boolean;
|
||||
message?: A2AMessage;
|
||||
output?: string;
|
||||
source?: string;
|
||||
timestamp?: string;
|
||||
@@ -122,29 +204,42 @@ export const a2aApi = {
|
||||
healthCheckRemoteAgent(agentId: number) {
|
||||
return api.post<{ healthy: boolean; failure_count: number }>(`/api/v1/a2a/remote-agents/${agentId}/health-check`, {});
|
||||
},
|
||||
listTasks(projectId: number, state?: string) {
|
||||
listTasks(projectId: number, state?: string, contextId?: string) {
|
||||
const params = new URLSearchParams({ project_id: String(projectId), limit: "100" });
|
||||
if (state && state !== "all") {
|
||||
params.set("state", state);
|
||||
}
|
||||
if (contextId) {
|
||||
params.set("context_id", contextId);
|
||||
}
|
||||
return api.get<A2ATask[]>(`/api/v1/a2a/tasks?${params.toString()}`);
|
||||
},
|
||||
getTask(taskId: string) {
|
||||
return api.get<A2ATask>(`/api/v1/a2a/tasks/${taskId}`);
|
||||
getTask(taskId: string, historyLength?: number) {
|
||||
const params = new URLSearchParams();
|
||||
if (historyLength !== undefined) {
|
||||
params.set("historyLength", String(historyLength));
|
||||
}
|
||||
const queryString = params.toString();
|
||||
return api.get<A2ATask>(`/api/v1/a2a/tasks/${taskId}${queryString ? `?${queryString}` : ""}`);
|
||||
},
|
||||
cancelTask(taskId: string) {
|
||||
return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}/cancel`, {});
|
||||
return api.post<{ task_id: string; state: string }>(`/api/v1/a2a/tasks/${taskId}:cancel`, {});
|
||||
},
|
||||
sendMessage(payload: A2ASendMessagePayload) {
|
||||
return api.post<A2ASendMessageResponse>("/api/v1/a2a/messages/send", payload);
|
||||
return api.post<A2ASendMessageResponse>("/api/v1/a2a/message:send", payload);
|
||||
},
|
||||
async subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): Promise<void> {
|
||||
const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, {
|
||||
streamMessage(payload: A2ASendMessagePayload) {
|
||||
return api.post<A2ASendMessageResponse>("/api/v1/a2a/message:stream", payload);
|
||||
},
|
||||
subscribeTask(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void {
|
||||
const controller = new AbortController();
|
||||
void (async () => {
|
||||
const response = await fetch(`/api/v1/a2a/tasks/${taskId}:subscribe`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
...getAuthHeaders(),
|
||||
},
|
||||
signal,
|
||||
signal: signal || controller.signal,
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(`Subscribe failed: ${response.status}`);
|
||||
@@ -171,5 +266,140 @@ export const a2aApi = {
|
||||
onEvent(event);
|
||||
}
|
||||
}
|
||||
})();
|
||||
return () => controller.abort();
|
||||
},
|
||||
subscribeTaskSSE(taskId: string, onEvent: SubscribeHandler, signal?: AbortSignal): () => void {
|
||||
const controller = new AbortController();
|
||||
void (async () => {
|
||||
const response = await fetch(`/api/v1/a2a/tasks/${taskId}/subscribe`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
...getAuthHeaders(),
|
||||
},
|
||||
signal: signal || controller.signal,
|
||||
});
|
||||
if (!response.ok || !response.body) {
|
||||
throw new Error(`Subscribe failed: ${response.status}`);
|
||||
}
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder("utf-8");
|
||||
let buffer = "";
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const splitIndex = buffer.lastIndexOf("\n\n");
|
||||
if (splitIndex === -1) continue;
|
||||
const complete = buffer.slice(0, splitIndex);
|
||||
buffer = buffer.slice(splitIndex + 2);
|
||||
const events = parseSseEvents(complete);
|
||||
for (const event of events) {
|
||||
onEvent(event);
|
||||
}
|
||||
}
|
||||
if (buffer.trim()) {
|
||||
const events = parseSseEvents(buffer);
|
||||
for (const event of events) {
|
||||
onEvent(event);
|
||||
}
|
||||
}
|
||||
})();
|
||||
return () => controller.abort();
|
||||
},
|
||||
};
|
||||
|
||||
export function renderPart(part: A2APart): string {
|
||||
switch (part.kind) {
|
||||
case "text":
|
||||
return part.text;
|
||||
case "url":
|
||||
return `[URL: ${part.url}]`;
|
||||
case "file":
|
||||
if (part.mediaType?.startsWith("image/")) {
|
||||
return `[Image: ${part.filename || "image"}]`;
|
||||
}
|
||||
if (part.mediaType?.includes("json")) {
|
||||
try {
|
||||
const decoded = atob(part.data);
|
||||
return `[JSON File: ${part.filename || "data.json"}]\n${decoded}`;
|
||||
} catch {
|
||||
return `[Binary File: ${part.filename || "data"}]`;
|
||||
}
|
||||
}
|
||||
return `[File: ${part.filename || "file"}]`;
|
||||
default:
|
||||
return "[Unknown Part]";
|
||||
}
|
||||
}
|
||||
|
||||
export function renderParts(parts: A2APart[]): string {
|
||||
return parts.map(renderPart).join("\n");
|
||||
}
|
||||
|
||||
export function extractTextFromParts(parts: A2APart[]): string {
|
||||
return parts
|
||||
.filter((p): p is A2APartText => p.kind === "text")
|
||||
.map((p) => p.text)
|
||||
.join("\n");
|
||||
}
|
||||
|
||||
export function getArtifactPreview(artifact: A2AArtifact): { type: "text" | "image" | "html" | "json" | "unknown"; content: string } {
|
||||
if (!artifact.parts || artifact.parts.length === 0) {
|
||||
return { type: "unknown", content: "" };
|
||||
}
|
||||
|
||||
const firstPart = artifact.parts[0];
|
||||
|
||||
if (firstPart.kind === "text") {
|
||||
return { type: "text", content: firstPart.text };
|
||||
}
|
||||
|
||||
if (firstPart.kind === "url") {
|
||||
const url = firstPart.url.toLowerCase();
|
||||
if (url.endsWith(".png") || url.endsWith(".jpg") || url.endsWith(".jpeg") || url.endsWith(".gif") || url.endsWith(".webp")) {
|
||||
return { type: "image", content: firstPart.url };
|
||||
}
|
||||
if (url.endsWith(".html") || url.endsWith(".htm")) {
|
||||
return { type: "html", content: firstPart.url };
|
||||
}
|
||||
return { type: "unknown", content: firstPart.url };
|
||||
}
|
||||
|
||||
if (firstPart.kind === "file") {
|
||||
const mediaType = firstPart.mediaType || "";
|
||||
if (mediaType.startsWith("image/")) {
|
||||
return { type: "image", content: `data:${mediaType};base64,${firstPart.data}` };
|
||||
}
|
||||
if (mediaType.includes("html")) {
|
||||
try {
|
||||
const decoded = atob(firstPart.data);
|
||||
return { type: "html", content: decoded };
|
||||
} catch {
|
||||
return { type: "unknown", content: "[HTML content]" };
|
||||
}
|
||||
}
|
||||
if (mediaType.includes("json")) {
|
||||
try {
|
||||
const decoded = atob(firstPart.data);
|
||||
return { type: "json", content: decoded };
|
||||
} catch {
|
||||
return { type: "unknown", content: "[JSON content]" };
|
||||
}
|
||||
}
|
||||
return { type: "unknown", content: `[File: ${firstPart.filename || "file"}]` };
|
||||
}
|
||||
|
||||
return { type: "unknown", content: "" };
|
||||
}
|
||||
|
||||
export function groupTasksByContextId(tasks: A2ATask[]): Map<string, A2ATask[]> {
|
||||
const grouped = new Map<string, A2ATask[]>();
|
||||
for (const task of tasks) {
|
||||
const contextId = task.context_id || "no-context";
|
||||
const existing = grouped.get(contextId) || [];
|
||||
existing.push(task);
|
||||
grouped.set(contextId, existing);
|
||||
}
|
||||
return grouped;
|
||||
}
|
||||
+340
-10
@@ -9,7 +9,7 @@ import { Textarea } from "@/components/ui/textarea";
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@/components/ui/select";
|
||||
import { Table, TableBody, TableCell, TableHead, TableHeader, TableRow } from "@/components/ui/table";
|
||||
import { api } from "@/lib/api";
|
||||
import { a2aApi, type A2ARemoteAgent, type A2ATask } from "@/api/a2a";
|
||||
import { a2aApi, type A2ARemoteAgent, type A2ATask, type A2AArtifact, renderPart, renderParts, getArtifactPreview, groupTasksByContextId } from "@/api/a2a";
|
||||
import { useProjectStore } from "@/store/projectStore";
|
||||
import { useMcpHealthStore } from "@/store/mcpHealthStore";
|
||||
import { useRef } from 'react';
|
||||
@@ -118,6 +118,11 @@ export function Skills() {
|
||||
auth_token: '',
|
||||
});
|
||||
const [isA2aRefreshingHealth, setIsA2aRefreshingHealth] = useState(false);
|
||||
const [selectedA2aAgent, setSelectedA2aAgent] = useState<A2ARemoteAgent | null>(null);
|
||||
const [selectedTask, setSelectedTask] = useState<A2ATask | null>(null);
|
||||
const [taskArtifactPreview, setTaskArtifactPreview] = useState<{ type: string; content: string } | null>(null);
|
||||
const [contextIdFilter, setContextIdFilter] = useState<string>('all');
|
||||
const [groupedByContextId, setGroupedByContextId] = useState<Map<string, A2ATask[]>>(new Map());
|
||||
|
||||
const { currentProject } = useProjectStore();
|
||||
const { hasMcpError, refresh: refreshMcpHealth } = useMcpHealthStore();
|
||||
@@ -245,6 +250,7 @@ export function Skills() {
|
||||
]);
|
||||
setA2aAgents(agents || []);
|
||||
setA2aTasks(tasks || []);
|
||||
setGroupedByContextId(groupTasksByContextId(tasks || []));
|
||||
} catch (error) {
|
||||
console.error("Failed to fetch A2A data", error);
|
||||
} finally {
|
||||
@@ -252,6 +258,25 @@ export function Skills() {
|
||||
}
|
||||
};
|
||||
|
||||
const handlePreviewArtifact = (artifact: A2AArtifact) => {
|
||||
const preview = getArtifactPreview(artifact);
|
||||
setTaskArtifactPreview(preview);
|
||||
};
|
||||
|
||||
const handleTaskClick = (task: A2ATask) => {
|
||||
setSelectedTask(task);
|
||||
if (task.artifacts && task.artifacts.length > 0) {
|
||||
const preview = getArtifactPreview(task.artifacts[0]);
|
||||
setTaskArtifactPreview(preview);
|
||||
} else {
|
||||
setTaskArtifactPreview(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAgentCardClick = (agent: A2ARemoteAgent) => {
|
||||
setSelectedA2aAgent(agent);
|
||||
};
|
||||
|
||||
const handleRefreshA2aHealth = async () => {
|
||||
if (!currentProject || a2aAgents.length === 0) return;
|
||||
setIsA2aRefreshingHealth(true);
|
||||
@@ -602,6 +627,20 @@ export function Skills() {
|
||||
<RefreshCw className="h-4 w-4" />
|
||||
{t('refresh')}
|
||||
</Button>
|
||||
<Select value={contextIdFilter} onValueChange={(val) => { if (val) setContextIdFilter(val); }}>
|
||||
<SelectTrigger className="w-[180px] h-9">
|
||||
<SelectValue placeholder={t('filterByContext')} />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectItem value="all">{t('allContexts')}</SelectItem>
|
||||
{Array.from(groupedByContextId.keys()).filter(k => k !== 'no-context').map(contextId => (
|
||||
<SelectItem key={contextId} value={contextId}>{contextId.slice(0, 16)}...</SelectItem>
|
||||
))}
|
||||
{groupedByContextId.has('no-context') && (
|
||||
<SelectItem value="no-context">{t('noContext')}</SelectItem>
|
||||
)}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<Button
|
||||
className="h-9 bg-[#ff4d29] hover:bg-[#ff4d29]/90 text-white gap-2 rounded-md px-3"
|
||||
onClick={handleOpenCreateA2a}
|
||||
@@ -856,7 +895,7 @@ export function Skills() {
|
||||
</TableRow>
|
||||
) : (
|
||||
a2aAgents.map((agent) => (
|
||||
<TableRow key={agent.id} className="group hover:bg-muted/50/50 transition-colors border-border">
|
||||
<TableRow key={agent.id} className="group hover:bg-muted/50/50 transition-colors border-border cursor-pointer" onClick={() => handleAgentCardClick(agent)}>
|
||||
<TableCell className="py-4 px-4 text-sm font-medium">{agent.name}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-sm text-muted-foreground truncate" title={agent.base_url}>{agent.base_url}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-sm text-muted-foreground">{agent.protocol_version || '-'}</TableCell>
|
||||
@@ -868,7 +907,7 @@ export function Skills() {
|
||||
<span className="opacity-70">#{agent.failure_count}</span>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="py-4 px-4 text-right">
|
||||
<TableCell className="py-4 px-4 text-right" onClick={(e) => e.stopPropagation()}>
|
||||
<div className="flex items-center justify-end gap-1">
|
||||
<Button variant="ghost" size="icon" className="h-8 w-8 text-muted-foreground hover:text-indigo-600 hover:bg-indigo-50 rounded-md transition-all shrink-0" onClick={() => void handleRefreshA2aCard(agent.id)}>
|
||||
<RefreshCw className="h-4 w-4" />
|
||||
@@ -895,10 +934,11 @@ export function Skills() {
|
||||
<Table className="table-fixed w-full">
|
||||
<TableHeader className="bg-muted/50/50">
|
||||
<TableRow className="hover:bg-transparent">
|
||||
<TableHead className="w-[18%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskId')}</TableHead>
|
||||
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskSource')}</TableHead>
|
||||
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('status')}</TableHead>
|
||||
<TableHead className="w-[38%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('content')}</TableHead>
|
||||
<TableHead className="w-[16%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskId')}</TableHead>
|
||||
<TableHead className="w-[12%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('contextId')}</TableHead>
|
||||
<TableHead className="w-[10%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('taskSource')}</TableHead>
|
||||
<TableHead className="w-[10%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('status')}</TableHead>
|
||||
<TableHead className="w-[32%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('content')}</TableHead>
|
||||
<TableHead className="w-[20%] font-semibold text-foreground/80 py-3 px-4 text-sm">{t('time')}</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
@@ -915,14 +955,20 @@ export function Skills() {
|
||||
</TableRow>
|
||||
) : (
|
||||
a2aTasks.map((task) => (
|
||||
<TableRow key={task.id} className="group hover:bg-muted/50/50 transition-colors border-border">
|
||||
<TableRow key={task.id} className="group hover:bg-muted/50/50 transition-colors border-border cursor-pointer" onClick={() => handleTaskClick(task)}>
|
||||
<TableCell className="py-4 px-4 text-xs font-mono truncate" title={task.id}>{task.id}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-xs font-mono truncate text-muted-foreground" title={task.context_id || ''}>{task.context_id ? task.context_id.slice(0, 12) + '...' : '-'}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-sm text-muted-foreground">{task.source}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-sm">{task.state}</TableCell>
|
||||
<TableCell className="py-4 px-4 text-xs text-muted-foreground">
|
||||
<div className="line-clamp-2" title={task.error_message || task.output_text || task.input_text}>
|
||||
{task.error_message || task.output_text || task.input_text}
|
||||
<div className="line-clamp-2" title={task.error_message || task.output_text || task.input_text || (task.input_parts ? renderParts(task.input_parts) : '')}>
|
||||
{task.error_message || task.output_text || task.input_text || (task.input_parts ? renderParts(task.input_parts) : '')}
|
||||
</div>
|
||||
{task.artifacts && task.artifacts.length > 0 && (
|
||||
<div className="mt-1 text-[10px] text-indigo-600">
|
||||
{task.artifacts.length} artifact(s)
|
||||
</div>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="py-4 px-4 text-xs text-muted-foreground">{task.updated_at}</TableCell>
|
||||
</TableRow>
|
||||
@@ -1203,6 +1249,290 @@ export function Skills() {
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={!!selectedA2aAgent} onOpenChange={(open) => { if (!open) setSelectedA2aAgent(null); }}>
|
||||
<DialogContent className="sm:max-w-[700px] max-h-[90vh] flex flex-col rounded-2xl p-0 overflow-hidden">
|
||||
<DialogHeader className="p-6 pb-2">
|
||||
<DialogTitle className="text-xl font-bold text-foreground">{selectedA2aAgent?.name} - Agent Card</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-2">
|
||||
{selectedA2aAgent?.agent_card ? (
|
||||
<div className="grid gap-4">
|
||||
{selectedA2aAgent.agent_card.description && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('description')}</Label>
|
||||
<p className="text-sm text-foreground">{selectedA2aAgent.agent_card.description}</p>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.url && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">URL</Label>
|
||||
<a href={selectedA2aAgent.agent_card.url} target="_blank" rel="noopener noreferrer" className="text-sm text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.url}</a>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.provider && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('provider')}</Label>
|
||||
<div className="text-sm text-foreground">
|
||||
{selectedA2aAgent.agent_card.provider.organization && <span>{selectedA2aAgent.agent_card.provider.organization}</span>}
|
||||
{selectedA2aAgent.agent_card.provider.url && <span> - <a href={selectedA2aAgent.agent_card.provider.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.provider.url}</a></span>}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.skills && selectedA2aAgent.agent_card.skills.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('skills')}</Label>
|
||||
<div className="space-y-2">
|
||||
{selectedA2aAgent.agent_card.skills.map((skill, idx) => (
|
||||
<div key={idx} className="p-3 bg-muted/50 rounded-lg">
|
||||
<div className="font-medium text-sm">{skill.name}</div>
|
||||
{skill.description && <p className="text-xs text-muted-foreground mt-1">{skill.description}</p>}
|
||||
{skill.tags && skill.tags.length > 0 && (
|
||||
<div className="flex flex-wrap gap-1 mt-2">
|
||||
{skill.tags.map((tag, i) => (
|
||||
<span key={i} className="px-2 py-0.5 bg-indigo-100 text-indigo-700 text-[10px] rounded-full">{tag}</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{skill.inputModes && skill.inputModes.length > 0 && (
|
||||
<div className="mt-1 text-xs text-muted-foreground">Input: {skill.inputModes.join(', ')}</div>
|
||||
)}
|
||||
{skill.outputModes && skill.outputModes.length > 0 && (
|
||||
<div className="mt-1 text-xs text-muted-foreground">Output: {skill.outputModes.join(', ')}</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.supportedInterfaces && selectedA2aAgent.agent_card.supportedInterfaces.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('supportedInterfaces')}</Label>
|
||||
<div className="space-y-1">
|
||||
{selectedA2aAgent.agent_card.supportedInterfaces.map((iface, idx) => (
|
||||
<div key={idx} className="p-2 bg-muted/50 rounded text-xs">
|
||||
<span className="font-medium">{iface.type}</span>
|
||||
{iface.url && <span className="text-muted-foreground ml-2">{iface.url}</span>}
|
||||
{iface.protocolVersion && <span className="text-muted-foreground ml-2">v{iface.protocolVersion}</span>}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.defaultInputModes && selectedA2aAgent.agent_card.defaultInputModes.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('defaultInputModes')}</Label>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedA2aAgent.agent_card.defaultInputModes.map((mode, idx) => (
|
||||
<span key={idx} className="px-2 py-0.5 bg-muted rounded text-xs">{mode}</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.defaultOutputModes && selectedA2aAgent.agent_card.defaultOutputModes.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('defaultOutputModes')}</Label>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{selectedA2aAgent.agent_card.defaultOutputModes.map((mode, idx) => (
|
||||
<span key={idx} className="px-2 py-0.5 bg-muted rounded text-xs">{mode}</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.iconUrl && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('iconUrl')}</Label>
|
||||
<img src={selectedA2aAgent.agent_card.iconUrl} alt="Agent Icon" className="h-16 w-16 rounded-lg object-contain" />
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.documentationUrl && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('documentationUrl')}</Label>
|
||||
<a href={selectedA2aAgent.agent_card.documentationUrl} target="_blank" rel="noopener noreferrer" className="text-sm text-indigo-600 hover:underline">{selectedA2aAgent.agent_card.documentationUrl}</a>
|
||||
</div>
|
||||
)}
|
||||
{selectedA2aAgent.agent_card.security && selectedA2aAgent.agent_card.security.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('security')}</Label>
|
||||
<div className="text-xs text-muted-foreground font-mono bg-muted/50 p-2 rounded">
|
||||
{JSON.stringify(selectedA2aAgent.agent_card.security, null, 2)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-center py-8 text-muted-foreground">
|
||||
<p>{t('noAgentCardAvailable')}</p>
|
||||
<p className="text-xs mt-2">{t('tryRefreshingCard')}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<DialogFooter className="p-6 pt-2">
|
||||
<Button variant="outline" onClick={() => setSelectedA2aAgent(null)}>{t('close')}</Button>
|
||||
{selectedA2aAgent && (
|
||||
<Button onClick={() => void handleRefreshA2aCard(selectedA2aAgent.id)} className="gap-2">
|
||||
<RefreshCw className="h-4 w-4" />
|
||||
{t('refreshCard')}
|
||||
</Button>
|
||||
)}
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
|
||||
<Dialog open={!!selectedTask} onOpenChange={(open) => { if (!open) { setSelectedTask(null); setTaskArtifactPreview(null); } }}>
|
||||
<DialogContent className="sm:max-w-[900px] max-h-[90vh] flex flex-col rounded-2xl p-0 overflow-hidden">
|
||||
<DialogHeader className="p-6 pb-2">
|
||||
<DialogTitle className="text-xl font-bold text-foreground">{t('taskDetails')}</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="flex-1 overflow-y-auto px-6 py-2">
|
||||
{selectedTask && (
|
||||
<div className="grid gap-4">
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">Task ID</Label>
|
||||
<p className="text-sm font-mono break-all">{selectedTask.id}</p>
|
||||
</div>
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">Context ID</Label>
|
||||
<p className="text-sm font-mono break-all">{selectedTask.context_id || '-'}</p>
|
||||
</div>
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">State</Label>
|
||||
<p className="text-sm">{selectedTask.state}</p>
|
||||
</div>
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">Source</Label>
|
||||
<p className="text-sm">{selectedTask.source}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{selectedTask.input_parts && selectedTask.input_parts.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('inputParts')}</Label>
|
||||
<div className="bg-muted/50 rounded-lg p-3 space-y-2">
|
||||
{selectedTask.input_parts.map((part, idx) => (
|
||||
<div key={idx} className="text-sm">
|
||||
{part.kind === 'text' && <p className="whitespace-pre-wrap">{part.text}</p>}
|
||||
{part.kind === 'url' && <a href={part.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{part.url}</a>}
|
||||
{part.kind === 'file' && (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-muted-foreground">[{part.filename || 'file'}]</span>
|
||||
{part.mediaType?.startsWith('image/') && <span className="text-xs text-muted-foreground">({part.mediaType})</span>}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedTask.output_parts && selectedTask.output_parts.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('outputParts')}</Label>
|
||||
<div className="bg-muted/50 rounded-lg p-3 space-y-2">
|
||||
{selectedTask.output_parts.map((part, idx) => (
|
||||
<div key={idx} className="text-sm">
|
||||
{part.kind === 'text' && <p className="whitespace-pre-wrap">{part.text}</p>}
|
||||
{part.kind === 'url' && <a href={part.url} target="_blank" rel="noopener noreferrer" className="text-indigo-600 hover:underline">{part.url}</a>}
|
||||
{part.kind === 'file' && (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-muted-foreground">[{part.filename || 'file'}]</span>
|
||||
{part.mediaType?.startsWith('image/') && <span className="text-xs text-muted-foreground">({part.mediaType})</span>}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedTask.artifacts && selectedTask.artifacts.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('artifacts')} ({selectedTask.artifacts.length})</Label>
|
||||
<div className="space-y-2">
|
||||
{selectedTask.artifacts.map((artifact, idx) => (
|
||||
<div key={idx} className="border border-border rounded-lg p-3">
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
<div className="font-medium text-sm">
|
||||
{artifact.name || `Artifact ${idx + 1}`}
|
||||
</div>
|
||||
<Button variant="ghost" size="sm" className="h-7 text-xs gap-1" onClick={() => handlePreviewArtifact(artifact)}>
|
||||
<Eye className="h-3 w-3" />
|
||||
{t('preview')}
|
||||
</Button>
|
||||
</div>
|
||||
{artifact.description && (
|
||||
<p className="text-xs text-muted-foreground mb-2">{artifact.description}</p>
|
||||
)}
|
||||
<div className="text-xs text-muted-foreground">
|
||||
{artifact.parts.length} part(s)
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{taskArtifactPreview && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('artifactPreview')}</Label>
|
||||
<div className="border border-border rounded-lg p-3 bg-muted/50">
|
||||
{taskArtifactPreview.type === 'text' && (
|
||||
<pre className="text-xs whitespace-pre-wrap break-all max-h-[300px] overflow-auto">{taskArtifactPreview.content}</pre>
|
||||
)}
|
||||
{taskArtifactPreview.type === 'image' && (
|
||||
<img src={taskArtifactPreview.content} alt="Artifact Preview" className="max-w-full max-h-[300px] rounded-lg object-contain" />
|
||||
)}
|
||||
{taskArtifactPreview.type === 'html' && (
|
||||
<div className="text-xs text-muted-foreground italic">[HTML Preview - rendered separately]</div>
|
||||
)}
|
||||
{taskArtifactPreview.type === 'json' && (
|
||||
<pre className="text-xs whitespace-pre-wrap break-all max-h-[300px] overflow-auto">{taskArtifactPreview.content}</pre>
|
||||
)}
|
||||
{taskArtifactPreview.type === 'unknown' && (
|
||||
<p className="text-xs text-muted-foreground">{taskArtifactPreview.content}</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedTask.error_message && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('errorMessage')}</Label>
|
||||
<div className="bg-rose-50 border border-rose-100 rounded-lg p-3 text-sm text-rose-700">
|
||||
{selectedTask.error_message}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{selectedTask.history && selectedTask.history.length > 0 && (
|
||||
<div className="grid gap-1.5">
|
||||
<Label className="text-muted-foreground font-medium text-sm">{t('messageHistory')} ({selectedTask.history.length})</Label>
|
||||
<div className="space-y-2 max-h-[300px] overflow-auto">
|
||||
{selectedTask.history.map((msg, idx) => (
|
||||
<div key={idx} className={`p-3 rounded-lg ${msg.role === 'user' ? 'bg-blue-50 border border-blue-100' : 'bg-green-50 border border-green-100'}`}>
|
||||
<div className="flex items-center justify-between mb-1">
|
||||
<span className="text-xs font-medium">{msg.role}</span>
|
||||
{msg.messageId && <span className="text-[10px] text-muted-foreground font-mono">{msg.messageId.slice(0, 8)}...</span>}
|
||||
</div>
|
||||
<div className="text-xs">
|
||||
{msg.parts.map((part, pIdx) => (
|
||||
<div key={pIdx}>{renderPart(part)}</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<DialogFooter className="p-6 pt-2">
|
||||
<Button variant="outline" onClick={() => { setSelectedTask(null); setTaskArtifactPreview(null); }}>{t('close')}</Button>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user