diff --git a/examples/a2a/basic_client_example.py b/examples/a2a/basic_client_example.py new file mode 100644 index 0000000000..971aa0c8c5 --- /dev/null +++ b/examples/a2a/basic_client_example.py @@ -0,0 +1,53 @@ +""" +Example: Calling an A2A agent as a tool from an OpenAI agent. + +This example demonstrates using ``A2AClientTool`` to call a remote A2A agent +as a function tool. It assumes you have an A2A agent running at the given URL. + +To try it out: + +1. Start a sample A2A server (e.g. the hello-world agent from a2a-sdk): + ``cd a2a-python/samples && python cli.py server`` + +2. Run this script: + ``python examples/a2a/basic_client_example.py`` +""" + +from __future__ import annotations + +import asyncio + +from agents import Agent, Runner +from agents.extensions.a2a import A2AClientTool + + +async def main() -> None: + # Create a tool that wraps a remote A2A agent. + # from_url() fetches the AgentCard at .well-known/agent-card.json + research_tool = await A2AClientTool.from_url( + url="http://localhost:10000", + tool_name="research_agent", + tool_description=( + "Ask the research agent to find and summarize information. " + "Use when you need external knowledge." + ), + ) + + orchestrator = Agent( + name="Orchestrator", + instructions="You are an orchestrator. Use the research_agent tool for external queries.", + tools=[research_tool.as_function_tool()], + ) + + result = await Runner.run( + orchestrator, + "What are the latest developments in quantum computing?", + ) + print(f"Final output: {result.final_output}") + + # Clean up + await research_tool.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index ad2b314ead..83b99c2a1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ modal = ["modal==1.3.5"] runloop = ["runloop_api_client>=1.16.0,<2.0.0"] vercel = ["vercel>=0.5.6,<0.6"] s3 = ["boto3>=1.34"] +a2a = ["a2a-sdk>=1.0.0,<2"] temporal = [ "temporalio==1.26.0", "textual>=8.2.3,<8.3", @@ -164,6 +165,10 @@ ignore_missing_imports = true module = ["vercel", "vercel.*"] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["a2a", "a2a.*", "a2a_pb2", "google.protobuf"] +ignore_missing_imports = true + [tool.coverage.run] source = ["src/agents"] omit = [ diff --git a/src/agents/extensions/a2a/__init__.py b/src/agents/extensions/a2a/__init__.py new file mode 100644 index 0000000000..a2cc5df9bd --- /dev/null +++ b/src/agents/extensions/a2a/__init__.py @@ -0,0 +1,96 @@ +""" +A2A (Agent-to-Agent) protocol integration for the OpenAI Agents SDK. + +This extension enables bidirectional interoperability between OpenAI Agents +and any A2A-compatible agent (built with any framework, in any language): + +- **A2A Client Tool**: Call external A2A agents as tools from your OpenAI agent. +- **A2A Server Agent**: Expose your OpenAI agent as an A2A service so other + agents can call it. + +The A2A protocol is defined by Google at https://github.com/google/A2A. + +Dependencies +------------ +This module requires the ``a2a-sdk`` package. Install it with:: + + pip install openai-agents[a2a] + +or directly:: + + pip install a2a-sdk +""" + +from __future__ import annotations + +from importlib import import_module +from typing import TYPE_CHECKING, Any + +from agents.extensions.memory._optional_imports import raise_optional_dependency_error + +if TYPE_CHECKING: + from ._agent_card import generate_agent_card + from ._client_tool import A2AClientTool + from ._server_executor import A2AServerAgent + from ._converter import ( + a2a_context_to_openai_input, + a2a_history_to_openai_input_items, + a2a_message_to_openai_input_items, + openai_error_to_failed_task, + openai_final_output_to_artifacts, + openai_items_to_a2a_messages, + openai_run_result_to_task, + openai_stream_event_to_task_status, + ) + +__all__ = [ + "A2AClientTool", + "A2AServerAgent", + "generate_agent_card", + "a2a_context_to_openai_input", + "a2a_history_to_openai_input_items", + "a2a_message_to_openai_input_items", + "openai_error_to_failed_task", + "openai_final_output_to_artifacts", + "openai_items_to_a2a_messages", + "openai_run_result_to_task", + "openai_stream_event_to_task_status", +] + +_LAZY_EXPORTS: dict[str, tuple[str, tuple[str, str] | None]] = { + "A2AClientTool": ("._client_tool", ("a2a-sdk", "a2a")), + "A2AServerAgent": ("._server_executor", ("a2a-sdk", "a2a")), + "generate_agent_card": ("._agent_card", ("a2a-sdk", "a2a")), + "a2a_context_to_openai_input": ("._converter", ("a2a-sdk", "a2a")), + "a2a_history_to_openai_input_items": ("._converter", ("a2a-sdk", "a2a")), + "a2a_message_to_openai_input_items": ("._converter", ("a2a-sdk", "a2a")), + "openai_error_to_failed_task": ("._converter", ("a2a-sdk", "a2a")), + "openai_final_output_to_artifacts": ("._converter", ("a2a-sdk", "a2a")), + "openai_items_to_a2a_messages": ("._converter", ("a2a-sdk", "a2a")), + "openai_run_result_to_task": ("._converter", ("a2a-sdk", "a2a")), + "openai_stream_event_to_task_status": ("._converter", ("a2a-sdk", "a2a")), +} + + +def __getattr__(name: str) -> Any: + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module_name, optional_dependency = _LAZY_EXPORTS[name] + try: + module = import_module(module_name, __name__) + except ModuleNotFoundError as e: + if optional_dependency is None: + raise ImportError(f"Failed to import {name}: {e}") from e + dependency_name, extra_name = optional_dependency + raise_optional_dependency_error( + name, + dependency_name=dependency_name, + extra_name=extra_name, + cause=e, + ) + + value = getattr(module, name) + # Cache for subsequent access. + globals()[name] = value + return value diff --git a/src/agents/extensions/a2a/_agent_card.py b/src/agents/extensions/a2a/_agent_card.py new file mode 100644 index 0000000000..e81dc5631b --- /dev/null +++ b/src/agents/extensions/a2a/_agent_card.py @@ -0,0 +1,159 @@ +""" +AgentCard generator — build an A2A ``AgentCard`` from OpenAI Agent metadata. + +This module inspects an OpenAI ``Agent`` instance and produces a compliant +A2A ``AgentCard`` that describes its name, description, skills (derived from +tools and handoffs), capabilities, and supported interfaces. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from a2a.types.a2a_pb2 import ( # type: ignore[import-untyped] + AgentCapabilities, + AgentCard, + AgentInterface, + AgentProvider, + AgentSkill, + ) + from agents.agent import Agent + + +def generate_agent_card( + agent: Agent[Any], + *, + url: str, + provider: AgentProvider | None = None, + capabilities: AgentCapabilities | None = None, + supported_interfaces: list[AgentInterface] | None = None, + version: str = "1.0.0", +) -> AgentCard: + """Generate an A2A ``AgentCard`` from an OpenAI ``Agent`` instance. + + The card includes: + + - The agent's ``name`` and ``instructions`` (or ``handoff_description``) + as the card's ``description``. + - One ``AgentSkill`` per ``FunctionTool`` (and per ``Agent.as_tool()`` + handoff) with its name, description, and tags. + - Default ``input_modes`` / ``output_modes`` set to ``["text"]``. + + Args: + agent: The OpenAI ``Agent`` to describe. + url: The base URL where the agent will be served. + provider: Optional ``AgentProvider`` (organisation metadata). + capabilities: Optional ``AgentCapabilities``; defaults to streaming + enabled with no push notifications. + supported_interfaces: Optional list of ``AgentInterface`` entries; + defaults to a single JSON-RPC interface at ``url``. + version: Version string for the agent card (default ``"1.0.0"``). + + Returns: + A populated A2A ``AgentCard`` protobuf message. + """ + from a2a.types.a2a_pb2 import ( # type: ignore[import-untyped] + AgentCapabilities, + AgentCard, + AgentInterface, + AgentSkill, + ) + + from agents.handoffs import Handoff + from agents.tool import FunctionTool + + # -- description -------------------------------------------------------- + description = agent.handoff_description or "" + if not description and agent.instructions: + if isinstance(agent.instructions, str): + # Truncate very long instructions for the card. + description = agent.instructions[:2000] + else: + description = "Dynamic instructions (callable)." + + # -- skills ------------------------------------------------------------- + skills: list[AgentSkill] = [] + + for tool in agent.tools: + if not isinstance(tool, FunctionTool): + continue + skill = AgentSkill( + id=tool.name, + name=tool.name, + description=tool.description[:2000], + tags=_tool_tags(tool), + input_modes=["text"], + output_modes=["text"], + ) + skills.append(skill) + + for handoff in agent.handoffs: + if isinstance(handoff, Handoff): + skill = AgentSkill( + id=handoff.tool_name, + name=handoff.agent_name, + description=handoff.handoff_description or handoff.tool_description or "", + tags=["handoff"], + input_modes=["text"], + output_modes=["text"], + ) + skills.append(skill) + + # -- capabilities ------------------------------------------------------- + if capabilities is None: + capabilities = AgentCapabilities( + streaming=True, + push_notifications=False, + ) + + # -- interfaces --------------------------------------------------------- + if supported_interfaces is None: + supported_interfaces = [ + AgentInterface( + url=url, + protocol_binding="a2a-json-rpc", + protocol_version="1.0", + ), + ] + + return AgentCard( + name=agent.name, + description=description, + version=version, + supported_interfaces=supported_interfaces, + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=capabilities, + skills=skills, + provider=provider, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _tool_tags(tool: Any) -> list[str]: + """Infer A2A skill tags from tool metadata.""" + tags = ["tool"] + + # Tag known hosted tools + tool_type = type(tool).__name__ + type_to_tag: dict[str, str] = { + "FileSearchTool": "file-search", + "WebSearchTool": "web-search", + "CodeInterpreterTool": "code-interpreter", + "HostedMCPTool": "mcp", + "ShellTool": "shell", + } + if tag := type_to_tag.get(tool_type): + tags.append(tag) + + # Include namespace if present + namespace = getattr(tool, "_tool_namespace", None) + if namespace: + tags.append(f"namespace:{namespace}") + + return tags diff --git a/src/agents/extensions/a2a/_client_tool.py b/src/agents/extensions/a2a/_client_tool.py new file mode 100644 index 0000000000..b180542a6b --- /dev/null +++ b/src/agents/extensions/a2a/_client_tool.py @@ -0,0 +1,521 @@ +""" +A2A Client Tool — call external A2A agents from your OpenAI agent. + +Wraps an A2A ``Client`` (or its configuration) as an OpenAI Agents SDK +``FunctionTool`` so that any A2A-compatible agent can be invoked like any +other tool in an agent's tool set. + +Usage:: + + from agents import Agent, Runner + from agents.extensions.a2a import A2AClientTool + + research_agent = A2AClientTool.from_url( + url="http://research-agent:8080", + tool_name="research_agent", + tool_description="Ask the research agent to find and summarize information.", + ) + + orchestrator = Agent( + name="Orchestrator", + tools=[research_agent], + ) + result = await Runner.run(orchestrator, "Research quantum computing") +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import uuid +from typing import TYPE_CHECKING, Any + +from agents.exceptions import ModelBehaviorError +from agents.logger import logger +from agents.run_context import RunContextWrapper +from agents.tool import ( + FunctionTool, + _build_handled_function_tool_error_handler, + _build_wrapped_function_tool, + _parse_function_tool_json_input, +) + +if TYPE_CHECKING: + from a2a.client.client import ( + Client as A2AClient, + ClientConfig, + ClientCallContext, + ) + from a2a.types.a2a_pb2 import ( # type: ignore[import-untyped] + AgentCard, + Message, + Part, + SendMessageRequest, + Task, + ) + +# --------------------------------------------------------------------------- +# Default schema for the tool parameters +# --------------------------------------------------------------------------- + +_A2A_TOOL_PARAMS_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": ( + "The message to send to the external agent. " + "Be clear and specific about what you need." + ), + }, + "context_id": { + "type": "string", + "description": ( + "Optional. An existing conversation context ID to continue " + "a previous conversation with this agent." + ), + }, + }, + "required": ["message"], + "additionalProperties": False, +} + + +# --------------------------------------------------------------------------- +# A2AClientTool +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class A2AClientTool: + """Wraps an A2A remote agent as an OpenAI Agents SDK ``FunctionTool``. + + The tool can be constructed from an existing ``AgentCard`` (when you + already have the card), or from a URL (the card is fetched at + construction time via ``.well-known/agent-card.json``). + + Once constructed, pass it directly to ``Agent(tools=[...])``. + + Parameters + ---------- + tool_name: + The name exposed to the LLM for this tool. + tool_description: + A human-readable description helping the LLM decide when to call + the external agent. + agent_card: + An A2A ``AgentCard`` describing the remote agent. Either + ``agent_card`` or ``agent_card_url`` must be provided. + agent_card_url: + The base URL of the remote A2A agent. The ``AgentCard`` will be + fetched from ``/.well-known/agent-card.json``. Either + ``agent_card`` or ``agent_card_url`` must be provided. + client_config: + Optional A2A ``ClientConfig``; a sensible default is used when + omitted. + timeout_seconds: + Maximum time (in seconds) to wait for the remote agent to + complete. Defaults to 300 (5 minutes). Set to ``None`` to + disable. + failure_error_function: + Optional formatter for tool-level errors. + """ + + tool_name: str + tool_description: str + agent_card: AgentCard | None = None + agent_card_url: str | None = None + client_config: ClientConfig | None = None + timeout_seconds: float | None = 300.0 + failure_error_function: Any = None # ToolErrorFunction | None + _client: A2AClient | None = dataclasses.field(default=None, repr=False) + _httpx_client: Any = dataclasses.field(default=None, repr=False) + + def __post_init__(self) -> None: + if self.agent_card is None and self.agent_card_url is None: + raise ValueError( + "Either 'agent_card' or 'agent_card_url' must be provided." + ) + + # ------------------------------------------------------------------ + # Factory constructors + # ------------------------------------------------------------------ + + @classmethod + async def from_url( + cls, + *, + url: str, + tool_name: str, + tool_description: str, + tool_params_schema: dict[str, Any] | None = None, + client_config: ClientConfig | None = None, + timeout_seconds: float | None = 300.0, + ) -> A2AClientTool: + """Asynchronously create an ``A2AClientTool`` by fetching the + ``AgentCard`` from the remote agent's well-known URL. + + Args: + url: Base URL of the remote A2A agent. + tool_name: Tool name exposed to the LLM. + tool_description: Tool description for the LLM. + client_config: Optional A2A ``ClientConfig``. + timeout_seconds: Per-invocation timeout in seconds. + + Returns: + A ready-to-use ``A2AClientTool`` instance. + """ + from a2a.client.card_resolver import A2ACardResolver + + config = client_config + if config is None: + from a2a.client.client import ClientConfig + + config = ClientConfig() + + httpx_client = getattr(config, "httpx_client", None) + own_httpx_client = httpx_client is None + if own_httpx_client: + import httpx + + httpx_client = httpx.AsyncClient() + + resolver = A2ACardResolver(httpx_client, url) + card = await resolver.get_agent_card() + + instance = cls( + tool_name=tool_name, + tool_description=tool_description, + agent_card=card, + agent_card_url=url, + client_config=config, + timeout_seconds=timeout_seconds, + ) + instance._client = await instance._get_or_create_client() + if own_httpx_client: + instance._httpx_client = httpx_client + return instance + + @classmethod + def from_card( + cls, + *, + card: AgentCard, + tool_name: str, + tool_description: str, + client_config: ClientConfig | None = None, + timeout_seconds: float | None = 300.0, + ) -> A2AClientTool: + """Synchronously create an ``A2AClientTool`` from an existing + ``AgentCard``. + + The underlying A2A ``Client`` is lazily created on first use. + + Args: + card: An A2A ``AgentCard`` describing the remote agent. + tool_name: Tool name exposed to the LLM. + tool_description: Tool description for the LLM. + client_config: Optional A2A ``ClientConfig``. + timeout_seconds: Per-invocation timeout in seconds. + + Returns: + A ready-to-use ``A2AClientTool`` instance. + """ + return cls( + tool_name=tool_name, + tool_description=tool_description, + agent_card=card, + client_config=client_config, + timeout_seconds=timeout_seconds, + ) + + # ------------------------------------------------------------------ + # FunctionTool conversion — the core integration point + # ------------------------------------------------------------------ + + def as_function_tool(self) -> FunctionTool: + """Return a ``FunctionTool`` that can be added to an ``Agent.tools``. + + The generated tool has a single ``message`` parameter (the text to + send to the remote agent). + """ + async def _invoke(ctx: RunContextWrapper[Any], input_json: str) -> Any: + return await self._invoke_impl(ctx, input_json) + + tool = _build_wrapped_function_tool( + name=self.tool_name, + description=self.tool_description, + params_json_schema=_A2A_TOOL_PARAMS_SCHEMA, + invoke_tool_impl=_invoke, + on_handled_error=_build_handled_function_tool_error_handler( + span_message=f"Error calling A2A agent '{self.tool_name}'", + span_message_for_json_decode_error="Error parsing arguments " + f"for A2A agent '{self.tool_name}'", + log_label="A2A", + ), + failure_error_function=self.failure_error_function, + strict_json_schema=True, + ) + return tool + + async def close(self) -> None: + """Release resources held by this tool (client connections, etc.).""" + if self._client is not None: + try: + await self._client.close() + except Exception: + logger.debug("Error closing A2A client for '%s'", self.tool_name) + self._client = None + if self._httpx_client is not None: + try: + await self._httpx_client.aclose() + except Exception: + logger.debug("Error closing httpx client for '%s'", self.tool_name) + self._httpx_client = None + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + async def _get_or_create_client(self) -> A2AClient: + """Create (or return a cached) A2A ``Client``.""" + if self._client is not None: + return self._client + + from a2a.client.client_factory import ClientFactory + + card = self.agent_card + if card is None: + raise RuntimeError("AgentCard not available — cannot create client.") + + factory = ClientFactory(self.client_config) + self._client = factory.create(card) + return self._client + + async def _invoke_impl( + self, ctx: RunContextWrapper[Any], input_json: str + ) -> str: + """Execute the A2A call when the tool is invoked by the LLM.""" + from a2a.client.client import ClientCallContext + from a2a.types.a2a_pb2 import Message, Part, SendMessageConfiguration, SendMessageRequest + + json_data = _parse_function_tool_json_input( + tool_name=self.tool_name, input_json=input_json + ) + + user_text = json_data.get("message", input_json) + context_id = json_data.get("context_id") + + # Build the A2A SendMessage request + text_part = Part(text=str(user_text)) + text_part.media_type = "text/plain" + + message = Message( + message_id=f"oai-msg-{uuid.uuid4().hex[:12]}", + role=1, # USER + parts=[text_part], + ) + if context_id: + message.context_id = str(context_id) + + config = SendMessageConfiguration( + accepted_output_modes=["text"], + ) + + request = SendMessageRequest( + message=message, + configuration=config, + ) + + # Send to the remote A2A agent + client = await self._get_or_create_client() + call_context = ClientCallContext() + + try: + task = await self._send_and_wait( + client, request, call_context, self.timeout_seconds + ) + except ModelBehaviorError: + raise + except asyncio.TimeoutError: + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' timed out after " + f"{self.timeout_seconds} seconds." + ) + except Exception as exc: + logger.warning( + "A2A agent '%s' call failed: %s", self.tool_name, exc + ) + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' returned an error: {exc}" + ) from exc + + # Extract text from the completed task + return self._extract_task_result(task) + + async def _send_and_wait( + self, + client: A2AClient, + request: SendMessageRequest, + call_context: ClientCallContext, + timeout_seconds: float | None, + ) -> Task: + """Send a message to the A2A agent and wait for the task to complete. + + Uses streaming if the client supports it; otherwise falls back to + polling ``get_task``. + """ + from a2a.types.a2a_pb2 import TaskState + + task: Task | None = None + task_id: str | None = None + + timeout = timeout_seconds + + async def _consume() -> None: + nonlocal task, task_id + # send_message returns an AsyncIterator[StreamResponse] + async for response in client.send_message( + request, context=call_context + ): + kind = response.WhichOneof("response") + if kind == "task": + task = response.task + task_id = task.id + if task.status.state in _TERMINAL_STATES: + return + elif kind == "task_status_update_event": + ev = response.task_status_update_event + task_id = ev.task_id + if ev.status.state in _TERMINAL_STATES: + # Fetch the complete task to get artifacts + break + + # If we have a task_id but the task wasn't fully populated + # by streaming, fetch it. + if task_id and ( + task is None + or task.status.state not in _TERMINAL_STATES + ): + from a2a.types.a2a_pb2 import GetTaskRequest + + task = await client.get_task( + GetTaskRequest(id=task_id), context=call_context + ) + + try: + await asyncio.wait_for(_consume(), timeout=timeout) + except asyncio.TimeoutError: + # Try to cancel the remote task + if task_id: + from a2a.types.a2a_pb2 import CancelTaskRequest + + try: + await asyncio.wait_for( + client.cancel_task( + CancelTaskRequest(id=task_id), + context=call_context, + ), + timeout=10.0, + ) + except Exception: + logger.debug( + "Failed to cancel A2A task %s after timeout", task_id + ) + raise + + if task is None: + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' returned no task." + ) + + if task.status.state == TaskState.TASK_STATE_FAILED: + error_text = self._extract_message_text(task.status.message) + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' task failed: {error_text}" + ) + + if task.status.state == TaskState.TASK_STATE_CANCELED: + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' task was canceled." + ) + + # TASK_STATE_REJECTED is terminal — treat as a tool error + if task.status.state == TaskState.TASK_STATE_REJECTED: + error_text = self._extract_message_text(task.status.message) + raise ModelBehaviorError( + f"A2A agent '{self.tool_name}' task was rejected: {error_text}" + ) + + return task + + def _extract_task_result(self, task: Task) -> str: + """Extract a text result from a completed A2A Task.""" + parts: list[str] = [] + + # 1. Extract from artifacts + for artifact in task.artifacts: + for part in artifact.parts: + text = self._part_to_text(part) + if text: + artifact_label = ( + f"[{artifact.name}] " if artifact.name else "" + ) + parts.append(f"{artifact_label}{text}") + + # 2. Extract from the last agent message in history + if not parts: + for msg in reversed(list(task.history)): + if msg.role == 2: # AGENT + for part in msg.parts: + text = self._part_to_text(part) + if text: + parts.append(text) + if parts: + break + + if parts: + return "\n\n".join(parts) + + # 3. Fall back to status message + return self._extract_message_text(task.status.message) + + @staticmethod + def _part_to_text(part: Part) -> str: + """Extract text content from an A2A Part.""" + kind = part.WhichOneof("content") + if kind == "text": + return part.text + if kind == "url": + return f"[URL: {part.url}]" + if kind == "data": + try: + from google.protobuf import json_format + + return json_format.MessageToJson(part.data) + except Exception: + return str(part.data) + return "" + + @staticmethod + def _extract_message_text(message: Message | None) -> str: + """Extract all text from a status/response Message.""" + if message is None: + return "" + texts: list[str] = [] + for part in message.parts: + if part.WhichOneof("content") == "text": + texts.append(part.text) + return "\n".join(texts) + + +# --------------------------------------------------------------------------- +# Terminal states +# --------------------------------------------------------------------------- + +_TERMINAL_STATES: frozenset[int] = frozenset({ + 3, # TASK_STATE_COMPLETED + 4, # TASK_STATE_FAILED + 5, # TASK_STATE_CANCELED + 7, # TASK_STATE_REJECTED +}) diff --git a/src/agents/extensions/a2a/_converter.py b/src/agents/extensions/a2a/_converter.py new file mode 100644 index 0000000000..e30afb7301 --- /dev/null +++ b/src/agents/extensions/a2a/_converter.py @@ -0,0 +1,656 @@ +""" +A2A ↔ OpenAI Agents SDK message format converter. + +This module handles bidirectional conversion between the A2A protocol types +(protobuf-based Message, Part, Task, Artifact) and the OpenAI Agents SDK types +(TResponseInputItem, RunResult, StreamEvent). + +The converter is designed as pure functions with no side effects, making it +easy to test and compose. All protobuf interactions are isolated to this module. +""" + +from __future__ import annotations + +import base64 +import dataclasses as _dataclasses +import json as json_module +import uuid +from typing import TYPE_CHECKING, Any + +from google.protobuf.timestamp_pb2 import Timestamp + +from agents.items import TResponseInputItem + +if TYPE_CHECKING: + from a2a.types.a2a_pb2 import ( + Artifact, + Message, + Part, + Task, + TaskStatus, + ) + from a2a.server.agent_execution.context import RequestContext # type: ignore[import-untyped] + from agents.result import RunResult + from agents.stream_events import StreamEvent + + +# --------------------------------------------------------------------------- +# A2A → OpenAI conversion +# --------------------------------------------------------------------------- + +# Role mapping enums are defined in the A2A proto (0=UNSPECIFIED, 1=USER, 2=AGENT). +# We handle the int values to avoid coupling to the generated enum. +_A2A_ROLE_USER = 1 +_A2A_ROLE_AGENT = 2 + + +def a2a_message_to_openai_input_items( + message: Message, + *, + include_role: bool = True, +) -> list[TResponseInputItem]: + """Convert a single A2A ``Message`` to OpenAI ``TResponseInputItem`` dicts. + + Each ``Part`` of the message becomes one input item. Text parts become + ``message``-role items; file/data parts become ``file`` or ``image`` items + when the MIME type is known. + + Args: + message: The A2A protobuf ``Message`` to convert. + include_role: When True (default), the message's ``role`` field is used + as the OpenAI item role. When False, the role is omitted so the + caller can assign it externally. + + Returns: + A list of ``TResponseInputItem`` dicts ready to pass to ``Runner.run()``. + """ + items: list[TResponseInputItem] = [] + + for part in message.parts: + item = _convert_single_part(part) + if item is None: + continue + if include_role: + item["role"] = _a2a_role_to_openai_role(message.role) + items.append(item) + + return items + + +def a2a_history_to_openai_input_items( + history: list[Message], +) -> list[TResponseInputItem]: + """Convert a list of A2A ``Message`` objects (e.g. ``Task.history``) + into a flat list of OpenAI input items, preserving role information + from each message. + + Args: + history: The message history from an A2A ``Task``. + + Returns: + A flat list of ``TResponseInputItem`` dicts. + """ + items: list[TResponseInputItem] = [] + for message in history: + items.extend(a2a_message_to_openai_input_items(message, include_role=True)) + return items + + +def a2a_context_to_openai_input( + context: RequestContext, +) -> list[TResponseInputItem]: + """Build the full OpenAI input from an A2A ``RequestContext``. + + This merges: + 1. The current incoming message (if any). + 2. The task history from the current task (if any). + 3. History from related tasks (if any), prefixed with a system note. + + Args: + context: The A2A server ``RequestContext``. + + Returns: + A flat list of ``TResponseInputItem`` dicts representing the full + conversation context for the agent to process. + """ + items: list[TResponseInputItem] = [] + + # 1. Current message + if context.message is not None: + items.extend( + a2a_message_to_openai_input_items(context.message, include_role=True) + ) + + # 2. Task history (existing conversation) + if context.current_task is not None and context.current_task.history: + history_items = a2a_history_to_openai_input_items( + list(context.current_task.history) + ) + # Avoid duplicating the current message if it's already in history + if items and history_items: + existing_content = {_item_content_hash(i) for i in items} + for hi in history_items: + if _item_content_hash(hi) not in existing_content: + items.append(hi) + else: + items.extend(history_items) + + # 3. Related tasks + for related_task in context.related_tasks: + if related_task.history: + note: TResponseInputItem = { + "role": "user", + "content": _format_related_task_note(related_task), + } + items.append(note) + items.extend(a2a_history_to_openai_input_items(list(related_task.history))) + + return items + + +# --------------------------------------------------------------------------- +# OpenAI → A2A conversion +# --------------------------------------------------------------------------- + + +def openai_final_output_to_artifacts( + final_output: Any, + *, + artifact_id: str | None = None, + artifact_name: str = "output", +) -> list[Artifact]: + """Convert an OpenAI agent's ``final_output`` into one or more A2A ``Artifact`` + objects. + + - ``str`` output → single ``Artifact`` with a text part. + - ``dict`` / Pydantic model → single ``Artifact`` with a JSON ``data`` part. + - ``list`` → one ``Artifact`` per element. + - ``None`` → empty list. + + Args: + final_output: The ``RunResult.final_output`` value. + artifact_id: Optional artifact ID; auto-generated if omitted. + artifact_name: Human-readable label for the artifact. + + Returns: + A list of A2A ``Artifact`` protobuf messages. + """ + from a2a.types.a2a_pb2 import Artifact, Part + + if final_output is None: + return [] + + if isinstance(final_output, list): + return [ + a + for item in final_output + for a in openai_final_output_to_artifacts( + item, + artifact_id=None, + artifact_name=artifact_name, + ) + ] + + artifact_id = artifact_id or f"artifact-{uuid.uuid4().hex[:12]}" + + if isinstance(final_output, str): + text_part = Part(text=final_output) + # Use media_type to help clients interpret the content. + text_part.media_type = "text/plain" + return [ + Artifact( + artifact_id=artifact_id, + name=artifact_name, + parts=[text_part], + ) + ] + + # dict, Pydantic model, dataclass, etc. + try: + json_str = _serialize_to_json(final_output) + data_part = Part() + data_part.media_type = "application/json" + # data is a google.protobuf.Value; we assign the raw JSON string as text + # because Part.data expects a Value protobuf, not a plain string. + # Text is the most compatible single-part representation. + text_part = Part(text=json_str) + text_part.media_type = "application/json" + return [ + Artifact( + artifact_id=artifact_id, + name=artifact_name, + parts=[text_part], + ) + ] + except Exception: + text_part = Part(text=str(final_output)) + text_part.media_type = "text/plain" + return [ + Artifact( + artifact_id=artifact_id, + name=artifact_name, + parts=[text_part], + ) + ] + + +def openai_items_to_a2a_messages( + items: list[TResponseInputItem], + *, + context_id: str | None = None, + task_id: str | None = None, +) -> list[Message]: + """Convert a list of OpenAI input/output items into A2A ``Message`` objects. + + Each item becomes a single ``Message``. The item's ``role`` determines the + A2A ``Role``. Tool call / tool output items are represented as data parts. + + Args: + items: The items to convert (e.g. ``RunResult.new_items``). + context_id: Optional context ID to set on every message. + task_id: Optional task ID to set on every message. + + Returns: + A list of A2A ``Message`` protobuf messages. + """ + from a2a.types.a2a_pb2 import Message, Part + + messages: list[Message] = [] + for item in items: + part = _openai_item_to_part(item) + if part is None: + continue + + role = _openai_role_to_a2a_role(item.get("role", "user")) + + message = Message( + message_id=f"msg-{uuid.uuid4().hex[:12]}", + role=role, + parts=[part], + ) + if context_id: + message.context_id = context_id + if task_id: + message.task_id = task_id + + messages.append(message) + + return messages + + +def openai_run_result_to_task( + result: RunResult, + *, + task_id: str, + context_id: str | None = None, +) -> Task: + """Build a complete A2A ``Task`` from an OpenAI ``RunResult``. + + The task will be in ``TASK_STATE_COMPLETED`` and contain: + - Full conversation history in ``history``. + - The final output as ``artifacts``. + + Args: + result: The completed ``RunResult`` from ``Runner.run()``. + task_id: The A2A task ID to assign. + context_id: Optional context ID. + + Returns: + An A2A ``Task`` protobuf message in completed state. + """ + from a2a.types.a2a_pb2 import ( + Artifact, + Task, + TaskState, + TaskStatus, + ) + + # Build history from both input and new items. + # RunResult.new_items contains RunItem objects; convert them to + # TResponseInputItem dicts so downstream converters can handle them. + from agents.items import RunItemBase + + input_items: list[TResponseInputItem] = [] + for item in result.new_items: + if isinstance(item, RunItemBase): + input_items.append(item.to_input_item()) + else: + input_items.append(item) # type: ignore[arg-type] + + history_messages = openai_items_to_a2a_messages( + input_items, + context_id=context_id, + task_id=task_id, + ) + + # Build artifacts from final output + artifacts = openai_final_output_to_artifacts(result.final_output) + + status_message = _make_status_message( + text="Task completed successfully.", + context_id=context_id, + task_id=task_id, + ) + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + status = TaskStatus( + state=TaskState.TASK_STATE_COMPLETED, + message=status_message, + timestamp=timestamp, + ) + + task = Task( + id=task_id, + status=status, + artifacts=artifacts, + history=history_messages, + ) + if context_id: + task.context_id = context_id + + return task + + +def openai_error_to_failed_task( + error: Exception, + *, + task_id: str, + context_id: str | None = None, + history: list[Message] | None = None, +) -> Task: + """Build an A2A ``Task`` in ``TASK_STATE_FAILED`` from an exception. + + Args: + error: The exception that caused the failure. + task_id: The A2A task ID. + context_id: Optional context ID. + history: Optional message history accumulated before the failure. + + Returns: + An A2A ``Task`` in failed state. + """ + from a2a.types.a2a_pb2 import Task, TaskState, TaskStatus + + error_text = f"Task failed: {type(error).__name__}: {error}" + status_message = _make_status_message( + text=error_text, + context_id=context_id, + task_id=task_id, + ) + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + status = TaskStatus( + state=TaskState.TASK_STATE_FAILED, + message=status_message, + timestamp=timestamp, + ) + + task = Task( + id=task_id, + status=status, + history=list(history) if history else [], + ) + if context_id: + task.context_id = context_id + + return task + + +def openai_stream_event_to_task_status( + event: StreamEvent, + *, + task_id: str, + context_id: str | None = None, +) -> TaskStatus | None: + """Convert a single OpenAI ``StreamEvent`` into an A2A ``TaskStatus``. + + Returns ``None`` for events that do not represent a task state change + (e.g. raw model deltas that should be aggregated). + + Args: + event: The streaming event from ``RunResultStreaming.stream_events()``. + task_id: The A2A task ID. + context_id: Optional context ID. + + Returns: + An ``TaskStatus`` or ``None``. + """ + from a2a.types.a2a_pb2 import TaskState, TaskStatus + + from agents.stream_events import ( + AgentUpdatedStreamEvent, + RawResponsesStreamEvent, + RunItemStreamEvent, + ) + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + if isinstance(event, RunItemStreamEvent): + status_message = _make_status_message( + text=f"Agent produced: {event.name}", + context_id=context_id, + task_id=task_id, + ) + return TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=status_message, + timestamp=timestamp, + ) + + if isinstance(event, AgentUpdatedStreamEvent): + status_message = _make_status_message( + text=f"Agent switched to: {event.new_agent.name}", + context_id=context_id, + task_id=task_id, + ) + return TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=status_message, + timestamp=timestamp, + ) + + if isinstance(event, RawResponsesStreamEvent): + # Raw model events don't map cleanly to task states; skip. + return None + + return None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _convert_single_part(part: Part) -> TResponseInputItem | None: + """Convert a single A2A ``Part`` to an OpenAI input item dict.""" + content_type = part.WhichOneof("content") + if content_type is None: + return None + + if content_type == "text": + return {"content": part.text} + + if content_type == "url": + return {"content": _format_url_content(part.url)} + + if content_type == "raw": + return _convert_raw_part(part) + + if content_type == "data": + return _convert_data_part(part) + + return None + + +def _convert_raw_part(part: Part) -> TResponseInputItem | None: + """Convert a raw bytes Part to an image or file item based on media_type.""" + media = (part.media_type or "").lower() + + if media.startswith("image/"): + b64 = base64.b64encode(part.raw).decode("ascii") + return { + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:{media};base64,{b64}"}, + } + ] + } + + # Generic file: attach as text description with metadata + return { + "content": ( + f"[Attached file: {part.filename or 'unnamed'} " + f"({media or 'application/octet-stream'}, " + f"{len(part.raw)} bytes)]" + ) + } + + +def _convert_data_part(part: Part) -> TResponseInputItem | None: + """Convert a structured data Part to an OpenAI input item.""" + try: + data_dict = json_format.MessageToDict(part.data) + json_str = json_module.dumps(data_dict, ensure_ascii=False) + return {"content": json_str} + except Exception: + return {"content": str(part.data)} + + +def _openai_item_to_part(item: TResponseInputItem) -> Part | None: + """Convert a single OpenAI item to an A2A ``Part``.""" + from a2a.types.a2a_pb2 import Part + + content = item.get("content") + + if content is None: + return None + + if isinstance(content, str): + text_part = Part(text=content) + text_part.media_type = "text/plain" + return text_part + + if isinstance(content, list): + # Multi-modal content: extract text portions; preserve the first + # image URL as a note since the A2A text part is the primary medium. + texts: list[str] = [] + image_urls: list[str] = [] + for block in content: + if isinstance(block, dict): + if block.get("type") == "text": + texts.append(str(block.get("text", ""))) + elif block.get("type") == "image_url": + image_urls.append( + block.get("image_url", {}).get("url", "") + ) + if image_urls: + texts.append(f"[Attached image(s): {', '.join(image_urls[:3])}]") + + if texts: + text_part = Part(text="\n".join(texts)) + text_part.media_type = "text/plain" + return text_part + + # Fallback: stringify + text_part = Part(text=str(content)) + text_part.media_type = "text/plain" + return text_part + + +def _a2a_role_to_openai_role(a2a_role: int) -> str: + """Map A2A Role enum value to OpenAI role string.""" + if a2a_role == _A2A_ROLE_USER: + return "user" + if a2a_role == _A2A_ROLE_AGENT: + return "assistant" + return "user" + + +def _openai_role_to_a2a_role(openai_role: object) -> int: + """Map OpenAI role string to A2A Role enum value.""" + role_str = str(openai_role).lower() + if role_str in ("assistant", "agent", "model"): + return _A2A_ROLE_AGENT + return _A2A_ROLE_USER + + +def _format_url_content(url: str) -> str: + """Format a URL part as text content for the model.""" + return f"[URL: {url}]" + + +def _serialize_to_json(obj: Any) -> str: + """Serialize an arbitrary object to a JSON string, handling Pydantic models, + dataclasses, and plain dicts.""" + try: + from pydantic import BaseModel + + if isinstance(obj, BaseModel): + return obj.model_dump_json() + except ImportError: + pass + + if _dataclasses.is_dataclass(obj) and not isinstance(obj, type): + return json_module.dumps(_dataclasses.asdict(obj), default=str) + + if isinstance(obj, dict): + return json_module.dumps(obj, default=str) + + return json_module.dumps({"value": obj}, default=str) + + +def _item_content_hash(item: TResponseInputItem) -> int: + """Fast content-based hash for deduplication.""" + content = item.get("content", "") + role = item.get("role", "") + return hash(f"{role}:{content}") + + +def _format_related_task_note(task: Task) -> str: + """Format a system note describing a related task's context.""" + task_id = getattr(task, "id", "unknown") + task_state = _task_state_name(task) + return ( + f"[Related task {task_id} (status: {task_state}) " + f"provides additional context below:]" + ) + + +def _task_state_name(task: Task) -> str: + """Human-readable task state string.""" + try: + from a2a.types.a2a_pb2 import TaskState + + state_value = task.status.state + return TaskState.Name(state_value) + except (AttributeError, ValueError): + return "unknown" + + +def _make_status_message( + text: str, + context_id: str | None = None, + task_id: str | None = None, +) -> Message: + """Create a small status ``Message`` for use inside ``TaskStatus``.""" + from a2a.types.a2a_pb2 import Message, Part + + text_part = Part(text=text) + text_part.media_type = "text/plain" + + message = Message( + message_id=f"status-{uuid.uuid4().hex[:12]}", + role=_A2A_ROLE_AGENT, + parts=[text_part], + ) + if context_id: + message.context_id = context_id + if task_id: + message.task_id = task_id + return message diff --git a/src/agents/extensions/a2a/_server_executor.py b/src/agents/extensions/a2a/_server_executor.py new file mode 100644 index 0000000000..f323d6ffa7 --- /dev/null +++ b/src/agents/extensions/a2a/_server_executor.py @@ -0,0 +1,333 @@ +from __future__ import annotations + +""" +A2A Server Agent — expose an OpenAI agent via the A2A protocol. + +Implements the ``AgentExecutor`` interface from ``a2a-sdk`` so that any +OpenAI Agents SDK ``Agent`` can be served as a standard A2A endpoint. + +Usage:: + + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server import A2AServer + from agents import Agent + from agents.extensions.a2a import A2AServerAgent + + my_agent = Agent(name="Assistant", instructions="You are helpful.") + + executor = A2AServerAgent(agent=my_agent) + handler = DefaultRequestHandler(executor=executor) + server = A2AServer(handler=handler) + + await server.start(host="0.0.0.0", port=8080) +""" + +import asyncio +import time +import uuid +from typing import TYPE_CHECKING, Any + +from agents.logger import logger + +from ._converter import ( + a2a_context_to_openai_input, + openai_error_to_failed_task, + openai_run_result_to_task, + openai_stream_event_to_task_status, +) + +if TYPE_CHECKING: + from a2a.server.agent_execution.context import RequestContext # type: ignore[import-untyped] + from a2a.server.events.event_queue_v2 import EventQueue # type: ignore[import-untyped] + from a2a.types.a2a_pb2 import Task, TaskStatus # type: ignore[import-untyped] + + from agents.agent import Agent + from agents.run import RunConfig + from agents.run_context import TContext + + +class A2AServerAgent: + """Expose an OpenAI ``Agent`` as an A2A-compatible agent. + + Implements the ``AgentExecutor`` interface required by the ``a2a-sdk`` + server framework. The executor translates incoming A2A ``SendMessage`` + requests into ``Runner.run()`` calls on the wrapped OpenAI agent, and + translates the results back into A2A task events. + + Parameters + ---------- + agent: + The OpenAI ``Agent`` to expose. + run_config: + Optional ``RunConfig`` applied to every ``Runner.run()`` invocation. + max_turns: + Maximum conversation turns per A2A task (defaults to 30). + session_ttl_seconds: + In-memory session entries are evicted after this many seconds of + inactivity. Set to ``None`` to disable expiry. Default: 3600 (1 h). + """ + + def __init__( + self, + agent: Agent[TContext], + *, + run_config: RunConfig | None = None, + max_turns: int | None = 30, + session_ttl_seconds: float | None = 3600.0, + ) -> None: + # Attempt to inherit from AgentExecutor for protocol compliance, + # but degrade gracefully when the ABC is not available at runtime. + try: + from a2a.server.agent_execution.agent_executor import AgentExecutor + + self.__class__ = type( + self.__class__.__name__, + (self.__class__, AgentExecutor), + {}, + ) + except ImportError: + pass + + self.agent = agent + self.run_config = run_config + self.max_turns = max_turns + self._session_ttl = session_ttl_seconds + + # In-memory session store: context_id → (items, last_access_timestamp) + self._sessions: dict[str, tuple[list[Any], float]] = {} + # Running tasks for cancellation support: task_id → asyncio.Task + self._running_tasks: dict[str, asyncio.Task[Any]] = {} + + # ------------------------------------------------------------------ + # AgentExecutor interface + # ------------------------------------------------------------------ + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Execute the agent for the given A2A request context. + + This method is called by the ``a2a-sdk`` server framework for each + incoming ``SendMessage`` / ``SendStreamingMessage`` request. + + Args: + context: The A2A request context containing the message and task. + event_queue: Queue to publish ``Task``, ``TaskStatusUpdateEvent``, + and ``TaskArtifactUpdateEvent`` messages. + """ + task_id = context.task_id or f"oai-task-{uuid.uuid4().hex[:12]}" + context_id = context.context_id + + # Register for cancellation support + current_task = asyncio.current_task() + if current_task is not None: + self._running_tasks[task_id] = current_task + + try: + # Retrieve or initialise the conversation session (with TTL eviction) + input_items = self._get_session(context_id) + + # Append the current message + new_items = a2a_context_to_openai_input(context) + input_items.extend(new_items) + + from agents.run import Runner + + # Publish working status + await self._publish_working(event_queue, task_id, context_id) + + if self.run_config is not None and getattr( + self.run_config, "streaming_enabled", False + ): + # Streaming path + streamed = Runner.run_streamed( + self.agent, + input=input_items, + max_turns=self.max_turns, + run_config=self.run_config, + ) + async for event in streamed.stream_events(): + status = openai_stream_event_to_task_status( + event, task_id=task_id, context_id=context_id + ) + if status is not None: + await self._publish_status_update( + event_queue, task_id, status + ) + result = streamed + else: + # Non-streaming path + result = await Runner.run( + self.agent, + input=input_items, + max_turns=self.max_turns, + run_config=self.run_config, + ) + + # Build the completed task + task = openai_run_result_to_task( + result, + task_id=task_id, + context_id=context_id, + ) + + # Persist conversation history for subsequent turns + self._update_session(context_id, result) + + # Publish the completed task + await self._publish_task(event_queue, task) + + except asyncio.CancelledError: + # Task was cancelled by the framework + failed_task = openai_error_to_failed_task( + asyncio.CancelledError("Task was cancelled."), + task_id=task_id, + context_id=context_id, + ) + await self._publish_task(event_queue, failed_task) + raise + + except Exception as exc: + logger.exception( + "A2A executor failed for task %s: %s", task_id, exc + ) + failed_task = openai_error_to_failed_task( + exc, task_id=task_id, context_id=context_id + ) + await self._publish_task(event_queue, failed_task) + + finally: + self._running_tasks.pop(task_id, None) + + async def cancel( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + """Cancel an in-progress task. + + Args: + context: The request context for the task to cancel. + event_queue: Queue to publish the cancellation status. + """ + task_id = context.task_id + if task_id and task_id in self._running_tasks: + self._running_tasks[task_id].cancel() + del self._running_tasks[task_id] + + # Publish cancellation status + from a2a.types.a2a_pb2 import Message, Part, TaskState, TaskStatus + + from google.protobuf.timestamp_pb2 import Timestamp + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + text_part = Part(text="Task cancelled by request.") + text_part.media_type = "text/plain" + + status = TaskStatus( + state=TaskState.TASK_STATE_CANCELED, + message=Message( + message_id=f"cancel-{uuid.uuid4().hex[:12]}", + role=2, # AGENT + parts=[text_part], + ), + timestamp=timestamp, + ) + + await self._publish_status_update(event_queue, task_id or "", status) + + # ------------------------------------------------------------------ + # Session helpers + # ------------------------------------------------------------------ + + def _get_session(self, context_id: str | None) -> list[Any]: + """Retrieve persisted conversation items for a context with TTL eviction.""" + if not context_id: + return [] + entry = self._sessions.get(context_id) + if entry is None: + return [] + + items, last_access = entry + if self._session_ttl is not None: + if time.monotonic() - last_access > self._session_ttl: + del self._sessions[context_id] + return [] + # Update last-access timestamp + self._sessions[context_id] = (items, time.monotonic()) + return list(items) + + def _update_session( + self, context_id: str | None, result: Any + ) -> None: + """Persist new conversation items for future turns.""" + if not context_id or not self._session_ttl: + return + + new_items = getattr(result, "new_items", []) + if new_items: + existing, _ = self._sessions.get(context_id, ([], time.monotonic())) + existing.extend(new_items) + self._sessions[context_id] = (existing, time.monotonic()) + + # ------------------------------------------------------------------ + # Event publishing helpers + # ------------------------------------------------------------------ + + async def _publish_working( + self, + event_queue: EventQueue, + task_id: str, + context_id: str | None, + ) -> None: + """Publish a TASK_STATE_WORKING status update.""" + from a2a.types.a2a_pb2 import Message, Part, TaskState, TaskStatus + + from google.protobuf.timestamp_pb2 import Timestamp + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + text_part = Part(text="Agent is working on the task.") + text_part.media_type = "text/plain" + + message = Message( + message_id=f"working-{uuid.uuid4().hex[:12]}", + role=2, + parts=[text_part], + ) + if context_id: + message.context_id = context_id + + status = TaskStatus( + state=TaskState.TASK_STATE_WORKING, + message=message, + timestamp=timestamp, + ) + + await self._publish_status_update(event_queue, task_id, status) + + async def _publish_status_update( + self, + event_queue: EventQueue, + task_id: str, + status: TaskStatus, + ) -> None: + """Enqueue a TaskStatusUpdateEvent.""" + from a2a.types.a2a_pb2 import TaskStatusUpdateEvent + + event = TaskStatusUpdateEvent( + task_id=task_id, + status=status, + ) + await event_queue.enqueue_event(event) + + async def _publish_task( + self, event_queue: EventQueue, task: Task + ) -> None: + """Enqueue a full Task object.""" + await event_queue.enqueue_event(task) diff --git a/tests/test_a2a_client_tool.py b/tests/test_a2a_client_tool.py new file mode 100644 index 0000000000..9d8ba049fc --- /dev/null +++ b/tests/test_a2a_client_tool.py @@ -0,0 +1,337 @@ +""" +Tests for the A2AClientTool class. + +Uses a mock/fake A2A client so tests run without a live server. The fake +client implements the exact ``Client`` interface, enabling deterministic +verification of request construction and result extraction. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import uuid +from collections.abc import AsyncIterator +from typing import Any + +import pytest + +from agents.run_context import RunContextWrapper + + +# --------------------------------------------------------------------------- +# Fake A2A Client — implements the same async interface +# --------------------------------------------------------------------------- + + +class _FakeStreamResponse: + """Minimal stream response wrapper for faking send_message.""" + + def __init__(self, task: Any = None, status_update: Any = None): + self._task = task + self._status_update = status_update + + def WhichOneof(self, name: str) -> str | None: # noqa: N802 + if self._task is not None: + return "task" + if self._status_update is not None: + return "task_status_update_event" + return None + + @property + def task(self) -> Any: + return self._task + + @property + def task_status_update_event(self) -> Any: + return self._status_update + + +class FakeA2AClient: + """A fake A2A Client that replays pre-configured responses.""" + + def __init__(self, responses: list[Any] | None = None): + self._responses: list[Any] = responses or [] + self.send_message_calls: list[Any] = [] + self.get_task_calls: list[Any] = [] + self.cancel_task_calls: list[Any] = [] + + async def send_message( + self, request: Any, *, context: Any = None + ) -> AsyncIterator[Any]: + self.send_message_calls.append(request) + for response in self._responses: + yield _FakeStreamResponse(task=response) + + async def get_task(self, request: Any, *, context: Any = None) -> Any: + self.get_task_calls.append(request) + if self._responses: + return self._responses[-1] + raise RuntimeError("No responses configured") + + async def cancel_task(self, request: Any, *, context: Any = None) -> Any: + self.cancel_task_calls.append(request) + return None + + async def close(self) -> None: + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *args: Any) -> None: + pass + + +# --------------------------------------------------------------------------- +# Helpers for building fake A2A protobuf messages +# --------------------------------------------------------------------------- + + +def _fake_completed_task(task_id: str, text: str) -> Any: + """Build a fake completed A2A Task with a text artifact.""" + from a2a.types.a2a_pb2 import Artifact, Message, Part, Task, TaskState, TaskStatus + + from google.protobuf.timestamp_pb2 import Timestamp + + artifact = Artifact( + artifact_id=f"art-{uuid.uuid4().hex[:8]}", + name="output", + parts=[_text_part(text)], + ) + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + status = TaskStatus( + state=TaskState.TASK_STATE_COMPLETED, + message=Message( + message_id=f"status-{uuid.uuid4().hex[:8]}", + role=2, + parts=[_text_part("Task completed.")], + ), + timestamp=timestamp, + ) + + return Task( + id=task_id, + status=status, + artifacts=[artifact], + history=[ + Message( + message_id=f"msg-{uuid.uuid4().hex[:8]}", + role=1, # USER + parts=[_text_part("user query")], + ), + Message( + message_id=f"msg-{uuid.uuid4().hex[:8]}", + role=2, # AGENT + parts=[_text_part(text)], + ), + ], + ) + + +def _fake_failed_task(task_id: str, error_msg: str) -> Any: + """Build a fake failed A2A Task.""" + from a2a.types.a2a_pb2 import Message, Task, TaskState, TaskStatus + + from google.protobuf.timestamp_pb2 import Timestamp + + timestamp = Timestamp() + timestamp.GetCurrentTime() + + status = TaskStatus( + state=TaskState.TASK_STATE_FAILED, + message=Message( + message_id=f"status-{uuid.uuid4().hex[:8]}", + role=2, + parts=[_text_part(error_msg)], + ), + timestamp=timestamp, + ) + + return Task(id=task_id, status=status) + + +def _text_part(text: str) -> Any: + """Create an A2A text Part.""" + from a2a.types.a2a_pb2 import Part + + p = Part(text=text) + p.media_type = "text/plain" + return p + + +def _minimal_card() -> Any: + """Create a minimal AgentCard for testing.""" + from a2a.types.a2a_pb2 import AgentCapabilities, AgentCard, AgentInterface + + return AgentCard( + name="test_agent", + description="A test agent", + version="1.0.0", + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=AgentCapabilities(streaming=True), + skills=[], + supported_interfaces=[ + AgentInterface( + protocol_binding="jsonrpc", + url="http://localhost:9999", + ) + ], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestA2AClientToolConstruction: + """Tests for constructing A2AClientTool instances.""" + + def test_from_card_sync(self): + from agents.extensions.a2a._client_tool import A2AClientTool + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="test_tool", + tool_description="Test tool description", + ) + assert tool.tool_name == "test_tool" + assert tool.tool_description == "Test tool description" + assert tool.agent_card is not None + + def test_missing_card_and_url_raises(self): + from agents.extensions.a2a._client_tool import A2AClientTool + + with pytest.raises(ValueError, match="agent_card.*agent_card_url"): + A2AClientTool( + tool_name="bad", + tool_description="bad", + ) + + def test_as_function_tool_returns_valid_tool(self): + from agents.extensions.a2a._client_tool import A2AClientTool + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="test_tool", + tool_description="Test tool", + ) + ft = tool.as_function_tool() + assert ft.name == "test_tool" + assert ft.description == "Test tool" + assert "message" in str(ft.params_json_schema) + + +class TestA2AClientToolInvocation: + """Tests for the tool invocation path with a fake client.""" + + async def test_successful_call_extracts_artifact_text(self): + from agents.extensions.a2a._client_tool import A2AClientTool + + task = _fake_completed_task("task-1", "The answer is 42.") + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="math_agent", + tool_description="Does math", + ) + fake_client = FakeA2AClient([task]) + tool._client = fake_client + + result = await tool._invoke_impl( + RunContextWrapper(context=None), + '{"message": "What is the answer?"}', + ) + + assert "The answer is 42." in result + assert len(fake_client.send_message_calls) == 1 + + async def test_failed_task_raises_model_behavior_error(self): + from agents.exceptions import ModelBehaviorError + from agents.extensions.a2a._client_tool import A2AClientTool + + task = _fake_failed_task("task-fail", "Something went wrong") + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="unreliable", + tool_description="Fails sometimes", + ) + fake_client = FakeA2AClient([task]) + tool._client = fake_client + + with pytest.raises(ModelBehaviorError, match="Something went wrong"): + await tool._invoke_impl( + RunContextWrapper(context=None), + '{"message": "Do something"}', + ) + + async def test_timeout_cancels_remote_task(self): + """ + When the request times out, the tool should attempt to cancel the + remote task before raising. + """ + from agents.extensions.a2a._client_tool import A2AClientTool + + # Build a fake client whose send_message never yields a completed task + class NeverFinishesClient: + async def send_message(self, request, *, context=None): + while True: + await asyncio.sleep(0.1) + yield _FakeStreamResponse() + return # pragma: no cover + + async def get_task(self, request, *, context=None): + raise RuntimeError("not reached") + + async def cancel_task(self, request, *, context=None): + pass + + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="slow_agent", + tool_description="Too slow", + timeout_seconds=0.5, + ) + tool._client = NeverFinishesClient() + + with pytest.raises(Exception): # ModelBehaviorError or TimeoutError + await tool._invoke_impl( + RunContextWrapper(context=None), + '{"message": "Do something slow"}', + ) + + async def test_context_id_is_propagated_to_request(self): + from agents.extensions.a2a._client_tool import A2AClientTool + + task = _fake_completed_task("task-ctx", "Done with context") + + tool = A2AClientTool.from_card( + card=_minimal_card(), + tool_name="context_agent", + tool_description="Uses context", + ) + fake_client = FakeA2AClient([task]) + tool._client = fake_client + + await tool._invoke_impl( + RunContextWrapper(context=None), + '{"message": "Continue conversation", "context_id": "ctx-42"}', + ) + + request = fake_client.send_message_calls[0] + assert request.message.context_id == "ctx-42" diff --git a/tests/test_a2a_converter.py b/tests/test_a2a_converter.py new file mode 100644 index 0000000000..1cd8a4fc5b --- /dev/null +++ b/tests/test_a2a_converter.py @@ -0,0 +1,440 @@ +""" +Tests for the A2A ↔ OpenAI Agents SDK converter module. + +These tests use inline-constructed protobuf objects to avoid requiring a live +A2A server. They verify the bidirectional conversion fidelity for messages, +task history, artifacts, and streaming events. +""" + +from __future__ import annotations + +import json +import uuid + +import pytest + + +def _make_part_text(text: str) -> "Part": + """Create an A2A text Part.""" + from a2a.types.a2a_pb2 import Part + + p = Part(text=text) + p.media_type = "text/plain" + return p + + +def _make_part_url(url: str) -> "Part": + """Create an A2A URL Part.""" + from a2a.types.a2a_pb2 import Part + + p = Part() + p.url = url + return p + + +def _make_part_data(data: dict) -> "Part": + """Create an A2A data Part.""" + from a2a.types.a2a_pb2 import Part + + from google.protobuf.struct_pb2 import Value + + p = Part() + json_str = json.dumps(data) + p.media_type = "application/json" + p.text = json_str + return p + + +def _make_message( + *, + role: int = 1, + text: str | None = None, + parts: list["Part"] | None = None, + message_id: str | None = None, +) -> "Message": + """Create an A2A Message with sensible defaults.""" + from a2a.types.a2a_pb2 import Message + + if parts is None and text is not None: + parts = [_make_part_text(text)] + if parts is None: + parts = [] + + return Message( + message_id=message_id or f"msg-{uuid.uuid4().hex[:12]}", + role=role, + parts=parts, + ) + + +def _make_run_result( + final_output: object = "hello world", + *, + new_items: list | None = None, + last_agent_name: str = "test_agent", +) -> "RunResult": + """Create a minimal RunResult for testing.""" + from unittest.mock import MagicMock + + from agents.result import RunResult + + result = MagicMock(spec=RunResult) + result.final_output = final_output + result.new_items = new_items or [] + result._last_agent = MagicMock() + result._last_agent.name = last_agent_name + return result + + +# --------------------------------------------------------------------------- +# A2A → OpenAI tests +# --------------------------------------------------------------------------- + + +class TestA2AMessageToOpenAIInput: + """Tests for a2a_message_to_openai_input_items.""" + + def test_text_message_user_role(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(role=1, text="Hello, agent!") + items = a2a_message_to_openai_input_items(msg) + + assert len(items) == 1 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, agent!" + + def test_text_message_agent_role(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(role=2, text="Here is the answer.") + items = a2a_message_to_openai_input_items(msg) + + assert len(items) == 1 + assert items[0]["role"] == "assistant" + assert items[0]["content"] == "Here is the answer." + + def test_message_without_role_inclusion(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(role=1, text="No role please") + items = a2a_message_to_openai_input_items(msg, include_role=False) + + assert len(items) == 1 + assert "role" not in items[0] + assert items[0]["content"] == "No role please" + + def test_message_with_url_part(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(parts=[_make_part_url("https://example.com/report.pdf")]) + items = a2a_message_to_openai_input_items(msg) + + assert len(items) == 1 + assert "URL" in items[0]["content"] + + def test_message_with_multiple_parts(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message( + parts=[ + _make_part_text("First part"), + _make_part_text("Second part"), + ] + ) + items = a2a_message_to_openai_input_items(msg) + + assert len(items) == 2 + assert items[0]["content"] == "First part" + assert items[1]["content"] == "Second part" + + def test_empty_message(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(parts=[]) + items = a2a_message_to_openai_input_items(msg) + + assert items == [] + + def test_unspecified_role_defaults_to_user(self): + from agents.extensions.a2a._converter import a2a_message_to_openai_input_items + + msg = _make_message(role=0, text="Unspecified role") + items = a2a_message_to_openai_input_items(msg) + + assert items[0]["role"] == "user" + + +class TestA2AHistoryToOpenAIInput: + """Tests for a2a_history_to_openai_input_items.""" + + def test_converts_history_preserving_roles(self): + from agents.extensions.a2a._converter import a2a_history_to_openai_input_items + + history = [ + _make_message(role=1, text="User query"), + _make_message(role=2, text="Agent response"), + _make_message(role=1, text="Follow-up"), + ] + items = a2a_history_to_openai_input_items(history) + + assert len(items) == 3 + assert items[0]["role"] == "user" + assert items[0]["content"] == "User query" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Agent response" + assert items[2]["role"] == "user" + assert items[2]["content"] == "Follow-up" + + def test_empty_history(self): + from agents.extensions.a2a._converter import a2a_history_to_openai_input_items + + assert a2a_history_to_openai_input_items([]) == [] + + +# --------------------------------------------------------------------------- +# OpenAI → A2A tests +# --------------------------------------------------------------------------- + + +class TestOpenAIFinalOutputToArtifacts: + """Tests for openai_final_output_to_artifacts.""" + + def test_string_output(self): + from agents.extensions.a2a._converter import openai_final_output_to_artifacts + + artifacts = openai_final_output_to_artifacts("result text") + + assert len(artifacts) == 1 + assert artifacts[0].name == "output" + assert len(artifacts[0].parts) == 1 + assert artifacts[0].parts[0].text == "result text" + assert artifacts[0].parts[0].media_type == "text/plain" + + def test_none_output(self): + from agents.extensions.a2a._converter import openai_final_output_to_artifacts + + assert openai_final_output_to_artifacts(None) == [] + + def test_dict_output(self): + from agents.extensions.a2a._converter import openai_final_output_to_artifacts + + artifacts = openai_final_output_to_artifacts({"key": "value"}) + + assert len(artifacts) == 1 + text = artifacts[0].parts[0].text + parsed = json.loads(text) + assert parsed == {"key": "value"} + + def test_custom_artifact_id(self): + from agents.extensions.a2a._converter import openai_final_output_to_artifacts + + artifacts = openai_final_output_to_artifacts( + "data", artifact_id="custom-id", artifact_name="report" + ) + + assert artifacts[0].artifact_id == "custom-id" + assert artifacts[0].name == "report" + + def test_int_output_stringified(self): + from agents.extensions.a2a._converter import openai_final_output_to_artifacts + + artifacts = openai_final_output_to_artifacts(42) + + assert len(artifacts) == 1 + assert "42" in artifacts[0].parts[0].text + + +class TestOpenAIItemsToA2AMessages: + """Tests for openai_items_to_a2a_messages.""" + + def test_user_item(self): + from agents.extensions.a2a._converter import openai_items_to_a2a_messages + + items = [{"role": "user", "content": "User message"}] + messages = openai_items_to_a2a_messages(items) + + assert len(messages) == 1 + assert messages[0].role == 1 # USER + assert messages[0].parts[0].text == "User message" + + def test_assistant_item(self): + from agents.extensions.a2a._converter import openai_items_to_a2a_messages + + items = [{"role": "assistant", "content": "Agent reply"}] + messages = openai_items_to_a2a_messages(items) + + assert len(messages) == 1 + assert messages[0].role == 2 # AGENT + + def test_with_context_and_task_ids(self): + from agents.extensions.a2a._converter import openai_items_to_a2a_messages + + items = [{"role": "user", "content": "Hi"}] + messages = openai_items_to_a2a_messages( + items, context_id="ctx-1", task_id="task-1" + ) + + assert messages[0].context_id == "ctx-1" + assert messages[0].task_id == "task-1" + + def test_empty_items(self): + from agents.extensions.a2a._converter import openai_items_to_a2a_messages + + assert openai_items_to_a2a_messages([]) == [] + + def test_item_with_none_content_skipped(self): + from agents.extensions.a2a._converter import openai_items_to_a2a_messages + + items = [{"role": "user", "content": None}] # type: ignore[dict-item] + messages = openai_items_to_a2a_messages(items) + + assert messages == [] + + +class TestOpenAIRunResultToTask: + """Tests for openai_run_result_to_task.""" + + def test_builds_completed_task(self): + from agents.extensions.a2a._converter import openai_run_result_to_task + + result = _make_run_result( + final_output="done", + new_items=[ + {"role": "user", "content": "query"}, + {"role": "assistant", "content": "response"}, + ], + ) + task = openai_run_result_to_task(result, task_id="task-abc") + + assert task.id == "task-abc" + assert task.status.state == 3 # TASK_STATE_COMPLETED + assert len(task.artifacts) == 1 + assert task.artifacts[0].parts[0].text == "done" + assert len(task.history) == 2 + + def test_task_with_context_id(self): + from agents.extensions.a2a._converter import openai_run_result_to_task + + result = _make_run_result(final_output="ok") + task = openai_run_result_to_task( + result, task_id="t1", context_id="ctx-42" + ) + + assert task.context_id == "ctx-42" + + +class TestOpenAIErrorToFailedTask: + """Tests for openai_error_to_failed_task.""" + + def test_builds_failed_task(self): + from agents.extensions.a2a._converter import openai_error_to_failed_task + + task = openai_error_to_failed_task( + ValueError("something went wrong"), task_id="task-fail" + ) + + assert task.id == "task-fail" + assert task.status.state == 4 # TASK_STATE_FAILED + assert "something went wrong" in task.status.message.parts[0].text + + +# --------------------------------------------------------------------------- +# Round-trip fidelity tests +# --------------------------------------------------------------------------- + + +class TestRoundTrip: + """End-to-end conversion fidelity tests.""" + + def test_text_message_round_trip(self): + """A2A text Message → OpenAI items → A2A Messages should preserve text.""" + from agents.extensions.a2a._converter import ( + a2a_message_to_openai_input_items, + openai_items_to_a2a_messages, + ) + + original = _make_message(role=1, text="Hello, world!") + items = a2a_message_to_openai_input_items(original) + restored = openai_items_to_a2a_messages(items) + + assert len(restored) == 1 + assert restored[0].parts[0].text == "Hello, world!" + + def test_multiturn_conversation_round_trip(self): + """A full conversation should survive the round-trip.""" + from agents.extensions.a2a._converter import ( + a2a_history_to_openai_input_items, + openai_items_to_a2a_messages, + ) + + history = [ + _make_message(role=1, text="What is 2+2?"), + _make_message(role=2, text="4"), + _make_message(role=1, text="Thanks!"), + ] + items = a2a_history_to_openai_input_items(history) + restored = openai_items_to_a2a_messages(items) + + assert len(restored) == 3 + assert restored[0].parts[0].text == "What is 2+2?" + assert restored[0].role == 1 # USER + assert restored[1].parts[0].text == "4" + assert restored[1].role == 2 # AGENT + assert restored[2].parts[0].text == "Thanks!" + assert restored[2].role == 1 # USER + + +# --------------------------------------------------------------------------- +# Streaming event tests +# --------------------------------------------------------------------------- + + +class TestStreamEventConversion: + """Tests for openai_stream_event_to_task_status.""" + + def test_run_item_event_returns_working_status(self): + import dataclasses + + from agents.stream_events import RunItemStreamEvent + + from agents.extensions.a2a._converter import openai_stream_event_to_task_status + + # Create a minimal RunItem mock + class FakeRunItem: + type = "message_output_item" + + event = RunItemStreamEvent( + name="message_output_created", + item=FakeRunItem(), # type: ignore[arg-type] + ) + status = openai_stream_event_to_task_status(event, task_id="task-s1") + + assert status is not None + assert status.state == 2 # TASK_STATE_WORKING + + def test_agent_updated_event_returns_working_status(self): + from unittest.mock import MagicMock + + from agents.stream_events import AgentUpdatedStreamEvent + + from agents.extensions.a2a._converter import openai_stream_event_to_task_status + + agent = MagicMock() + agent.name = "new_agent" + event = AgentUpdatedStreamEvent(new_agent=agent) + status = openai_stream_event_to_task_status(event, task_id="task-s2") + + assert status is not None + assert status.state == 2 # TASK_STATE_WORKING + assert "new_agent" in status.message.parts[0].text + + def test_raw_event_returns_none(self): + from agents.stream_events import RawResponsesStreamEvent + + from agents.extensions.a2a._converter import openai_stream_event_to_task_status + + event = RawResponsesStreamEvent(data={"type": "response.output_text.delta"}) # type: ignore[arg-type] + status = openai_stream_event_to_task_status(event, task_id="task-s3") + + assert status is None diff --git a/tests/test_a2a_server_executor.py b/tests/test_a2a_server_executor.py new file mode 100644 index 0000000000..aeb191831a --- /dev/null +++ b/tests/test_a2a_server_executor.py @@ -0,0 +1,245 @@ +""" +Smoke tests for the A2A server executor. + +Verifies that A2AServerAgent correctly: + - Processes a SendMessage request and produces a completed Task + - Persists and retrieves conversation sessions across turns + - Handles errors gracefully (produces a failed Task) +""" + +from __future__ import annotations + +import asyncio +import uuid +from typing import Any + +import pytest + +from agents.agent import Agent + + +# --------------------------------------------------------------------------- +# Fake RequestContext — minimal mock for testing +# --------------------------------------------------------------------------- + + +class _FakeRequestContext: + """Minimal RequestContext stub for testing A2AServerAgent.""" + + def __init__( + self, + message: Any = None, + task_id: str | None = None, + context_id: str | None = None, + current_task: Any = None, + related_tasks: list[Any] | None = None, + ) -> None: + self._message = message + self.task_id = task_id or f"test-task-{uuid.uuid4().hex[:8]}" + self.context_id = context_id or f"test-session-{uuid.uuid4().hex[:8]}" + self.current_task = current_task + self.related_tasks = related_tasks or [] + + @property + def message(self) -> Any: + return self._message + + +# --------------------------------------------------------------------------- +# Fake EventQueue — captures published events for assertions +# --------------------------------------------------------------------------- + + +class _FakeEventQueue: + """Minimal EventQueue stub that captures published events.""" + + def __init__(self) -> None: + self.tasks: list[Any] = [] + self.status_updates: list[Any] = [] + + async def enqueue_event(self, event: Any) -> None: + """Match the real EventQueue.enqueue_event signature.""" + from a2a.types.a2a_pb2 import Task, TaskStatusUpdateEvent + + if isinstance(event, Task): + self.tasks.append(event) + elif isinstance(event, TaskStatusUpdateEvent): + self.status_updates.append({ + "task_id": event.task_id, + "status": event.status, + }) + + async def enqueue_task(self, task: Any) -> None: + self.tasks.append(task) + + async def enqueue_task_status_update(self, task_id: str, status: Any, final: bool) -> None: + self.status_updates.append({"task_id": task_id, "status": status, "final": final}) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_fake_output(text: str) -> list[Any]: + """Create a fake model output containing a text message.""" + from openai.types.responses import ResponseOutputMessage, ResponseOutputText + + return [ + ResponseOutputMessage( + id=f"msg-{uuid.uuid4().hex[:8]}", + type="message", + role="assistant", + content=[ + ResponseOutputText( + text=text, + type="output_text", + annotations=[], + ) + ], + status="completed", + ) + ] + + +def _create_agent_with_fake_model( + agent: Agent, + output_text: str = "I am a helpful AI assistant.", + num_turns: int = 1, +) -> Agent: + """Set a FakeModel on the agent so tests don't need a real API key. + + Args: + agent: The agent to configure. + output_text: The text the fake model should respond with. + num_turns: How many turns of output to pre-load (default 1). + """ + from tests.fake_model import FakeModel + + fake_model = FakeModel(initial_output=_make_fake_output(output_text)) + # Pre-load additional turn outputs for multi-turn tests + for _ in range(1, num_turns): + fake_model.set_next_output(_make_fake_output(output_text)) + agent.model = fake_model + return agent + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +@pytest.mark.allow_call_model_methods +async def test_executor_produces_completed_task() -> None: + """Verify that a basic non-streaming execution produces a completed Task.""" + from agents.extensions.a2a._converter import ( + a2a_message_to_openai_input_items, + ) + + # Build an A2A message using the real converter + from a2a.types.a2a_pb2 import Message, Part + + part = Part() + part.text = "Hello" + part.media_type = "text/plain" + message = Message() + message.role = 1 # USER + message.parts.append(part) + + context = _FakeRequestContext(message=message) + event_queue = _FakeEventQueue() + + agent = _create_agent_with_fake_model( + Agent(name="TestAgent", instructions="Be helpful.") + ) + + from agents.extensions.a2a._server_executor import A2AServerAgent + + executor = A2AServerAgent(agent=agent, max_turns=5) + await executor.execute(context, event_queue) + + # Should have published at least one task (working + completed) + assert len(event_queue.tasks) >= 1, "expected at least a completed task" + completed = event_queue.tasks[-1] + assert completed.id == context.task_id + + +@pytest.mark.asyncio +@pytest.mark.allow_call_model_methods +async def test_executor_session_persistence() -> None: + """Verify that conversation history is persisted across turns.""" + from a2a.types.a2a_pb2 import Message, Part + + def make_message(text: str) -> Message: + part = Part() + part.text = text + part.media_type = "text/plain" + msg = Message() + msg.role = 1 # USER + msg.parts.append(part) + return msg + + agent = _create_agent_with_fake_model( + Agent(name="SessionAgent", instructions="You are helpful."), + num_turns=2, + ) + + from agents.extensions.a2a._server_executor import A2AServerAgent + + executor = A2AServerAgent(agent=agent, max_turns=5) + + # First turn + ctx1 = _FakeRequestContext( + message=make_message("First message"), + context_id="session-1", + ) + queue1 = _FakeEventQueue() + await executor.execute(ctx1, queue1) + assert len(queue1.tasks) >= 1 + + # Verify session was stored + session_items = executor._get_session("session-1") + assert len(session_items) > 0, "session should contain items after first turn" + + # Second turn: session should have accumulated history + ctx2 = _FakeRequestContext( + message=make_message("Second message"), + context_id="session-1", + ) + queue2 = _FakeEventQueue() + await executor.execute(ctx2, queue2) + assert len(queue2.tasks) >= 1 + + # Session should now contain more items + session_items_2 = executor._get_session("session-1") + assert len(session_items_2) > len(session_items), ( + "session should accumulate history across turns" + ) + + +@pytest.mark.asyncio +@pytest.mark.allow_call_model_methods +async def test_executor_cancel_cleans_up_running_task() -> None: + """Verify that a running task is removed from tracking after completion.""" + from a2a.types.a2a_pb2 import Message, Part + + part = Part() + part.text = "test" + part.media_type = "text/plain" + message = Message() + message.role = 1 + message.parts.append(part) + + context = _FakeRequestContext(message=message, task_id="cancel-test-task") + event_queue = _FakeEventQueue() + + agent = _create_agent_with_fake_model(Agent(name="CancelAgent")) + + from agents.extensions.a2a._server_executor import A2AServerAgent + + executor = A2AServerAgent(agent=agent, max_turns=3) + await executor.execute(context, event_queue) + + # After completion, the task should not be in the running tasks dict + assert "cancel-test-task" not in executor._running_tasks diff --git a/uv.lock b/uv.lock index 03d2dd0903..9120d61720 100644 --- a/uv.lock +++ b/uv.lock @@ -3,7 +3,8 @@ revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", - "python_full_version >= '3.12' and python_full_version < '3.14'", + "python_full_version == '3.13.*'", + "python_full_version == '3.12.*'", "python_full_version == '3.11.*'", "python_full_version < '3.11'", ] @@ -12,6 +13,26 @@ resolution-markers = [ exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer-span = "P7D" +[[package]] +name = "a2a-sdk" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "culsans", marker = "python_full_version < '3.13'" }, + { name = "google-api-core" }, + { name = "googleapis-common-protos" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "json-rpc" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/35/8b7ac94f405f57c591925fa0afc105a0f797151876fffa666b57722eefa9/a2a_sdk-1.0.3.tar.gz", hash = "sha256:c57ddd910aece4a426ae26b8f0d0e8e2f3271a6adde974078075e4f600aaf628", size = 367155, upload-time = "2026-05-13T06:52:33.929Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/6f/ae79f8210f1ecd70e1c37c310a523b26f1d6da458d4c1365914bf1ea58e0/a2a_sdk-1.0.3-py3-none-any.whl", hash = "sha256:068e5b2ceb4e962ac61d9e1fd43ca0c1016b64f0c80d901f6e23420bc8a31a93", size = 235705, upload-time = "2026-05-13T06:52:31.88Z" }, +] + [[package]] name = "aiofiles" version = "24.1.0" @@ -128,6 +149,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/99/84ba7273339d0f3dfa57901b846489d2e5c2cd731470167757f1935fffbd/aiohttp_retry-2.9.1-py3-none-any.whl", hash = "sha256:66d2759d1921838256a05a3f80ad7e724936f083e35be5abb5e16eed6be6dc54", size = 9981, upload-time = "2024-11-06T10:44:52.917Z" }, ] +[[package]] +name = "aiologic" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "sniffio", marker = "python_full_version < '3.14'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "wrapt", marker = "python_full_version < '3.14'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/13/50b91a3ea6b030d280d2654be97c48b6ed81753a50286ee43c646ba36d3c/aiologic-0.16.0.tar.gz", hash = "sha256:c267ccbd3ff417ec93e78d28d4d577ccca115d5797cdbd16785a551d9658858f", size = 225952, upload-time = "2025-11-27T23:48:41.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/27/206615942005471499f6fbc36621582e24d0686f33c74b2d018fcfd4fe67/aiologic-0.16.0-py3-none-any.whl", hash = "sha256:e00ce5f68c5607c864d26aec99c0a33a83bdf8237aa7312ffbb96805af67d8b6", size = 135193, upload-time = "2025-11-27T23:48:40.099Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -725,6 +760,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/23/87/7ce86f3fa14bc11a5a48c30d8103c26e09b6465f8d8e9d74cf7a0714f043/cryptography-45.0.7-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:1f3d56f73595376f4244646dd5c5870c14c196949807be39e79e7bd9bac3da63", size = 3332908, upload-time = "2025-09-01T11:14:58.78Z" }, ] +[[package]] +name = "culsans" +version = "0.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiologic", marker = "python_full_version < '3.14'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/e3/49afa1bc180e0d28008ec6bcdf82a4072d1c7a41032b5b759b60814ca4b0/culsans-0.11.0.tar.gz", hash = "sha256:0b43d0d05dce6106293d114c86e3fb4bfc63088cfe8ff08ed3fe36891447fe33", size = 107546, upload-time = "2025-12-31T23:15:38.196Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/5d/9fb19fb38f6d6120422064279ea5532e22b84aa2be8831d49607194feda3/culsans-0.11.0-py3-none-any.whl", hash = "sha256:278d118f63fc75b9db11b664b436a1b83cc30d9577127848ba41420e66eb5a47", size = 21811, upload-time = "2025-12-31T23:15:37.189Z" }, +] + [[package]] name = "dapr" version = "1.16.0" @@ -1196,6 +1244,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "google-api-core" +version = "2.30.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/16/ce/502a57fb0ec752026d24df1280b162294b22a0afb98a326084f9a979138b/google_api_core-2.30.3.tar.gz", hash = "sha256:e601a37f148585319b26db36e219df68c5d07b6382cff2d580e83404e44d641b", size = 177001, upload-time = "2026-04-10T00:41:28.035Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/15/e56f351cf6ef1cfea58e6ac226a7318ed1deb2218c4b3cc9bd9e4b786c5a/google_api_core-2.30.3-py3-none-any.whl", hash = "sha256:a85761ba72c444dad5d611c2220633480b2b6be2521eca69cca2dbb3ffd6bfe8", size = 173274, upload-time = "2026-04-09T22:57:16.198Z" }, +] + +[[package]] +name = "google-auth" +version = "2.53.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "pyasn1-modules" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/ad/ff781329bbbdc0974a098d996e89c9e1f7024262f9e3eec442fbb9ad1ac6/google_auth-2.53.0.tar.gz", hash = "sha256:e7e6aa16f6bee7b2b264830fd04f08087a1d5a836df516251a5d15327b246c9c", size = 335844, upload-time = "2026-05-15T20:53:07.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/c9/db44165ba7c581268c6d46017ef63339110378305062830104fc7fa144cb/google_auth-2.53.0-py3-none-any.whl", hash = "sha256:6e7449917c599b35126a99ec268ec6880301f2fea41dce198fe8fd83ff642b68", size = 246071, upload-time = "2026-05-15T20:53:05.609Z" }, +] + [[package]] name = "googleapis-common-protos" version = "1.70.0" @@ -1625,6 +1702,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, ] +[[package]] +name = "json-rpc" +version = "1.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/9e/59f4a5b7855ced7346ebf40a2e9a8942863f644378d956f68bcef2c88b90/json-rpc-1.15.0.tar.gz", hash = "sha256:e6441d56c1dcd54241c937d0a2dcd193bdf0bdc539b5316524713f554b7f85b9", size = 28854, upload-time = "2023-06-11T09:45:49.078Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/9e/820c4b086ad01ba7d77369fb8b11470a01fac9b4977f02e18659cf378b6b/json_rpc-1.15.0-py2.py3-none-any.whl", hash = "sha256:4a4668bbbe7116feb4abbd0f54e64a4adcf4b8f648f19ffa0848ad0f6606a9bf", size = 39450, upload-time = "2023-06-11T09:45:47.136Z" }, +] + [[package]] name = "jsonschema" version = "4.25.0" @@ -2244,7 +2330,8 @@ version = "2.3.2" source = { registry = "https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.14'", - "python_full_version >= '3.12' and python_full_version < '3.14'", + "python_full_version == '3.13.*'", + "python_full_version == '3.12.*'", "python_full_version == '3.11.*'", ] sdist = { url = "https://files.pythonhosted.org/packages/37/7d/3fec4199c5ffb892bed55cff901e4f39a58c81df9c44c280499e92cad264/numpy-2.3.2.tar.gz", hash = "sha256:e0486a11ec30cdecb53f184d496d1c6a20786c81e55e41640270130056f8ee48", size = 20489306, upload-time = "2025-07-24T21:32:07.553Z" } @@ -2446,6 +2533,9 @@ dependencies = [ ] [package.optional-dependencies] +a2a = [ + { name = "a2a-sdk" }, +] any-llm = [ { name = "any-llm-sdk", marker = "python_full_version >= '3.11'" }, ] @@ -2550,6 +2640,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "a2a-sdk", marker = "extra == 'a2a'", specifier = ">=1.0.0,<2" }, { name = "aiohttp", marker = "extra == 'blaxel'", specifier = ">=3.12,<4" }, { name = "aiohttp", marker = "extra == 'cloudflare'", specifier = ">=3.12,<4" }, { name = "any-llm-sdk", marker = "python_full_version >= '3.11' and extra == 'any-llm'", specifier = ">=1.11.0,<2" }, @@ -2585,7 +2676,7 @@ requires-dist = [ { name = "websockets", marker = "extra == 'realtime'", specifier = ">=15.0,<17" }, { name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<17" }, ] -provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "vercel", "s3", "temporal"] +provides-extras = ["voice", "viz", "litellm", "any-llm", "realtime", "sqlalchemy", "encrypt", "redis", "dapr", "mongodb", "docker", "blaxel", "daytona", "cloudflare", "e2b", "modal", "runloop", "vercel", "s3", "a2a", "temporal"] [package.metadata.requires-dev] dev = [ @@ -2908,6 +2999,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, ] +[[package]] +name = "proto-plus" +version = "1.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/56/e647b0c675392d2da368da7b6f158f7368b18542fd6f7d7400a2f39de000/proto_plus-1.28.0.tar.gz", hash = "sha256:38e5696342835b08fc116f30a25665b29531cda9d5d5643e9b81fc312385abd9", size = 57221, upload-time = "2026-05-07T08:04:50.811Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/20/b122d4626976acb81132036d2ad1bb35a1a8775fceb837ec30964622516a/proto_plus-1.28.0-py3-none-any.whl", hash = "sha256:a630604310899e73c59ec302e5765c058d412b2f090b9c79c8822589f14955b8", size = 50410, upload-time = "2026-05-07T08:03:31.962Z" }, +] + [[package]] name = "protobuf" version = "5.29.5" @@ -2922,6 +3025,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/cc/7e77861000a0691aeea8f4566e5d3aa716f2b1dece4a24439437e41d3d25/protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5", size = 172823, upload-time = "2025-05-28T23:51:58.157Z" }, ] +[[package]] +name = "pyasn1" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + [[package]] name = "pycparser" version = "2.22"