Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
TaskPushNotificationConfig,
TaskQueryParams,
TaskState,
TaskStatus,
UnsupportedOperationError,
)
from a2a.utils.errors import ServerError
Expand Down Expand Up @@ -286,6 +287,27 @@ async def _send_push_notification_if_needed(
if isinstance(latest_task, Task):
await self._push_sender.send_notification(latest_task)

async def _consume_and_notify_in_background(
self,
task_id: str,
result_aggregator: ResultAggregator,
consumer: EventConsumer,
producer_task: asyncio.Task,
) -> None:
"""Consume executor events and send push notifications in background."""
try:
async for _event in result_aggregator.consume_and_emit(consumer):
await self._send_push_notification_if_needed(
task_id, result_aggregator
)
except Exception:
logger.exception(
'Background event consumption failed for task %s',
task_id,
)
finally:
await self._cleanup_producer(producer_task, task_id)

async def on_message_send(
self,
params: MessageSendParams,
Expand All @@ -295,6 +317,11 @@ async def on_message_send(

Starts the agent execution for the message and waits for the final
result (Task or Message).

When ``blocking`` is ``False``, the handler returns the task
immediately without waiting for executor events and processes
everything in the background. Results are delivered via push
notifications.
"""
(
_task_manager,
Expand All @@ -311,6 +338,34 @@ async def on_message_send(
if params.configuration and params.configuration.blocking is False:
blocking = False

# Non-blocking fast path: return the task immediately and process
# events entirely in the background via push notifications.
if not blocking:
task = await _task_manager.get_task()
if not task:
task = Task(
id=task_id,
context_id=params.message.context_id,
status=TaskStatus(state=TaskState.submitted),
history=[params.message],
)

bg_task = asyncio.create_task(
self._consume_and_notify_in_background(
task_id, result_aggregator, consumer, producer_task
)
)
bg_task.set_name(f'non_blocking_consume:{task_id}')
self._track_background_task(bg_task)

if params.configuration:
task = apply_history_length(
task, params.configuration.history_length
)

return task

# Blocking path: wait for completion or interruption.
interrupted_or_non_blocking = False
try:
# Create async callback for push notifications
Expand All @@ -325,7 +380,7 @@ async def push_notification_callback() -> None:
bg_consume_task,
) = await result_aggregator.consume_and_break_on_interrupt(
consumer,
blocking=blocking,
blocking=True,
event_callback=push_notification_callback,
)

Expand Down
150 changes: 36 additions & 114 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,135 +460,57 @@ async def get_current_result():

@pytest.mark.asyncio
async def test_on_message_send_with_push_notification_in_non_blocking_request():
"""Test that push notification callback is called during background event processing for non-blocking requests."""
mock_task_store = AsyncMock(spec=TaskStore)
mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore)
mock_agent_executor = AsyncMock(spec=AgentExecutor)
mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)
"""Test that non-blocking requests return immediately and process push notifications in background."""
task_store = InMemoryTaskStore()
push_store = InMemoryPushNotificationConfigStore()
mock_push_sender = AsyncMock()

task_id = 'non_blocking_task_1'
context_id = 'non_blocking_ctx_1'

# Create a task that will be returned after the first event
initial_task = create_sample_task(
task_id=task_id, context_id=context_id, status_state=TaskState.working
)

# Create a final task that will be available during background processing
final_task = create_sample_task(
task_id=task_id, context_id=context_id, status_state=TaskState.completed
)

mock_task_store.get.return_value = None

# Mock request context
mock_request_context = MagicMock(spec=RequestContext)
mock_request_context.task_id = task_id
mock_request_context.context_id = context_id
mock_request_context_builder.build.return_value = mock_request_context

request_handler = DefaultRequestHandler(
agent_executor=mock_agent_executor,
task_store=mock_task_store,
push_config_store=mock_push_notification_store,
request_context_builder=mock_request_context_builder,
agent_executor=HelloAgentExecutor(),
task_store=task_store,
push_config_store=push_store,
push_sender=mock_push_sender,
)

# Configure push notification
push_config = PushNotificationConfig(url='http://callback.com/push')
message_config = MessageSendConfiguration(
push_notification_config=push_config,
accepted_output_modes=['text/plain'],
blocking=False, # Non-blocking request
)
params = MessageSendParams(
message=Message(
role=Role.user,
message_id='msg_non_blocking',
parts=[],
task_id=task_id,
context_id=context_id,
parts=[Part(root=TextPart(text='Hi'))],
),
configuration=MessageSendConfiguration(
push_notification_config=push_config,
accepted_output_modes=['text/plain'],
blocking=False,
),
configuration=message_config,
)

# Mock ResultAggregator with custom behavior
mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator)

# First call returns the initial task and indicates interruption (non-blocking)
mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = (
initial_task,
True, # interrupted = True for non-blocking
MagicMock(spec=asyncio.Task), # background task
)

# Mock the current_result property to return the final task
async def get_current_result():
return final_task

type(mock_result_aggregator_instance).current_result = PropertyMock(
return_value=get_current_result()
)

# Track if the event_callback was passed to consume_and_break_on_interrupt
event_callback_passed = False
event_callback_received = None

async def mock_consume_and_break_on_interrupt(
consumer, blocking=True, event_callback=None
):
nonlocal event_callback_passed, event_callback_received
event_callback_passed = event_callback is not None
event_callback_received = event_callback
return (
initial_task,
True,
MagicMock(spec=asyncio.Task),
) # interrupted = True for non-blocking

mock_result_aggregator_instance.consume_and_break_on_interrupt = (
mock_consume_and_break_on_interrupt
result = await request_handler.on_message_send(
params, create_server_call_context()
)

with (
patch(
'a2a.server.request_handlers.default_request_handler.ResultAggregator',
return_value=mock_result_aggregator_instance,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
return_value=initial_task,
),
patch(
'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message',
return_value=initial_task,
),
):
# Execute the non-blocking request
result = await request_handler.on_message_send(
params, create_server_call_context()
)
# Non-blocking: should return immediately with submitted state
assert result is not None
assert isinstance(result, Task)
assert result.status.state == TaskState.submitted

# Verify the result is the initial task (non-blocking behavior)
assert result == initial_task
# Wait for background processing to complete
for _ in range(10):
await asyncio.sleep(0.1)
task = await task_store.get(result.id)
if task and task.status.state == TaskState.completed:
break

# Verify that the event_callback was passed to consume_and_break_on_interrupt
assert event_callback_passed, (
'event_callback should have been passed to consume_and_break_on_interrupt'
)
assert event_callback_received is not None, (
'event_callback should not be None'
)
assert task is not None
assert task.status.state == TaskState.completed

# Verify that the push notification was sent with the final task
mock_push_sender.send_notification.assert_called_with(final_task)
# Verify push notification config was stored by checking the store directly
stored_configs = await push_store.get_info(result.id)
assert stored_configs and len(stored_configs) >= 1

# Verify that the push notification config was stored
mock_push_notification_store.set_info.assert_awaited_once_with(
task_id, push_config
)
# Verify push notifications were sent during background processing
assert mock_push_sender.send_notification.call_count >= 1


@pytest.mark.asyncio
Expand Down Expand Up @@ -843,11 +765,11 @@ async def test_on_message_send_non_blocking():

assert task is not None
assert task.status.state == TaskState.completed
assert (
result.history
and task.history
and len(result.history) == len(task.history)
)
# The immediately returned result has the initial history (user message),
# while the completed task may have additional history entries from the
# executor. The initial result should have at least the user message.
assert result.history and len(result.history) >= 1
assert task.history and len(task.history) >= 1


@pytest.mark.asyncio
Expand Down
Loading