diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 3bd6a0dc2..52e7c01bf 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -41,6 +41,7 @@ TaskPushNotificationConfig, TaskQueryParams, TaskState, + TaskStatus, UnsupportedOperationError, ) from a2a.utils.errors import ServerError @@ -286,6 +287,27 @@ async def _send_push_notification_if_needed( if isinstance(latest_task, Task): await self._push_sender.send_notification(latest_task) + async def _consume_and_notify_in_background( + self, + task_id: str, + result_aggregator: ResultAggregator, + consumer: EventConsumer, + producer_task: asyncio.Task, + ) -> None: + """Consume executor events and send push notifications in background.""" + try: + async for _event in result_aggregator.consume_and_emit(consumer): + await self._send_push_notification_if_needed( + task_id, result_aggregator + ) + except Exception: + logger.exception( + 'Background event consumption failed for task %s', + task_id, + ) + finally: + await self._cleanup_producer(producer_task, task_id) + async def on_message_send( self, params: MessageSendParams, @@ -295,6 +317,11 @@ async def on_message_send( Starts the agent execution for the message and waits for the final result (Task or Message). + + When ``blocking`` is ``False``, the handler returns the task + immediately without waiting for executor events and processes + everything in the background. Results are delivered via push + notifications. """ ( _task_manager, @@ -311,6 +338,34 @@ async def on_message_send( if params.configuration and params.configuration.blocking is False: blocking = False + # Non-blocking fast path: return the task immediately and process + # events entirely in the background via push notifications. + if not blocking: + task = await _task_manager.get_task() + if not task: + task = Task( + id=task_id, + context_id=params.message.context_id, + status=TaskStatus(state=TaskState.submitted), + history=[params.message], + ) + + bg_task = asyncio.create_task( + self._consume_and_notify_in_background( + task_id, result_aggregator, consumer, producer_task + ) + ) + bg_task.set_name(f'non_blocking_consume:{task_id}') + self._track_background_task(bg_task) + + if params.configuration: + task = apply_history_length( + task, params.configuration.history_length + ) + + return task + + # Blocking path: wait for completion or interruption. interrupted_or_non_blocking = False try: # Create async callback for push notifications @@ -325,7 +380,7 @@ async def push_notification_callback() -> None: bg_consume_task, ) = await result_aggregator.consume_and_break_on_interrupt( consumer, - blocking=blocking, + blocking=True, event_callback=push_notification_callback, ) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index ec2956fa2..6a33ddfe5 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -460,135 +460,57 @@ async def get_current_result(): @pytest.mark.asyncio async def test_on_message_send_with_push_notification_in_non_blocking_request(): - """Test that push notification callback is called during background event processing for non-blocking requests.""" - mock_task_store = AsyncMock(spec=TaskStore) - mock_push_notification_store = AsyncMock(spec=PushNotificationConfigStore) - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + """Test that non-blocking requests return immediately and process push notifications in background.""" + task_store = InMemoryTaskStore() + push_store = InMemoryPushNotificationConfigStore() mock_push_sender = AsyncMock() - task_id = 'non_blocking_task_1' - context_id = 'non_blocking_ctx_1' - - # Create a task that will be returned after the first event - initial_task = create_sample_task( - task_id=task_id, context_id=context_id, status_state=TaskState.working - ) - - # Create a final task that will be available during background processing - final_task = create_sample_task( - task_id=task_id, context_id=context_id, status_state=TaskState.completed - ) - - mock_task_store.get.return_value = None - - # Mock request context - mock_request_context = MagicMock(spec=RequestContext) - mock_request_context.task_id = task_id - mock_request_context.context_id = context_id - mock_request_context_builder.build.return_value = mock_request_context - request_handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - push_config_store=mock_push_notification_store, - request_context_builder=mock_request_context_builder, + agent_executor=HelloAgentExecutor(), + task_store=task_store, + push_config_store=push_store, push_sender=mock_push_sender, ) - # Configure push notification push_config = PushNotificationConfig(url='http://callback.com/push') - message_config = MessageSendConfiguration( - push_notification_config=push_config, - accepted_output_modes=['text/plain'], - blocking=False, # Non-blocking request - ) params = MessageSendParams( message=Message( role=Role.user, message_id='msg_non_blocking', - parts=[], - task_id=task_id, - context_id=context_id, + parts=[Part(root=TextPart(text='Hi'))], + ), + configuration=MessageSendConfiguration( + push_notification_config=push_config, + accepted_output_modes=['text/plain'], + blocking=False, ), - configuration=message_config, - ) - - # Mock ResultAggregator with custom behavior - mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) - - # First call returns the initial task and indicates interruption (non-blocking) - mock_result_aggregator_instance.consume_and_break_on_interrupt.return_value = ( - initial_task, - True, # interrupted = True for non-blocking - MagicMock(spec=asyncio.Task), # background task - ) - - # Mock the current_result property to return the final task - async def get_current_result(): - return final_task - - type(mock_result_aggregator_instance).current_result = PropertyMock( - return_value=get_current_result() ) - # Track if the event_callback was passed to consume_and_break_on_interrupt - event_callback_passed = False - event_callback_received = None - - async def mock_consume_and_break_on_interrupt( - consumer, blocking=True, event_callback=None - ): - nonlocal event_callback_passed, event_callback_received - event_callback_passed = event_callback is not None - event_callback_received = event_callback - return ( - initial_task, - True, - MagicMock(spec=asyncio.Task), - ) # interrupted = True for non-blocking - - mock_result_aggregator_instance.consume_and_break_on_interrupt = ( - mock_consume_and_break_on_interrupt + result = await request_handler.on_message_send( + params, create_server_call_context() ) - with ( - patch( - 'a2a.server.request_handlers.default_request_handler.ResultAggregator', - return_value=mock_result_aggregator_instance, - ), - patch( - 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', - return_value=initial_task, - ), - patch( - 'a2a.server.request_handlers.default_request_handler.TaskManager.update_with_message', - return_value=initial_task, - ), - ): - # Execute the non-blocking request - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + # Non-blocking: should return immediately with submitted state + assert result is not None + assert isinstance(result, Task) + assert result.status.state == TaskState.submitted - # Verify the result is the initial task (non-blocking behavior) - assert result == initial_task + # Wait for background processing to complete + for _ in range(10): + await asyncio.sleep(0.1) + task = await task_store.get(result.id) + if task and task.status.state == TaskState.completed: + break - # Verify that the event_callback was passed to consume_and_break_on_interrupt - assert event_callback_passed, ( - 'event_callback should have been passed to consume_and_break_on_interrupt' - ) - assert event_callback_received is not None, ( - 'event_callback should not be None' - ) + assert task is not None + assert task.status.state == TaskState.completed - # Verify that the push notification was sent with the final task - mock_push_sender.send_notification.assert_called_with(final_task) + # Verify push notification config was stored by checking the store directly + stored_configs = await push_store.get_info(result.id) + assert stored_configs and len(stored_configs) >= 1 - # Verify that the push notification config was stored - mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config - ) + # Verify push notifications were sent during background processing + assert mock_push_sender.send_notification.call_count >= 1 @pytest.mark.asyncio @@ -843,11 +765,11 @@ async def test_on_message_send_non_blocking(): assert task is not None assert task.status.state == TaskState.completed - assert ( - result.history - and task.history - and len(result.history) == len(task.history) - ) + # The immediately returned result has the initial history (user message), + # while the completed task may have additional history entries from the + # executor. The initial result should have at least the user message. + assert result.history and len(result.history) >= 1 + assert task.history and len(task.history) >= 1 @pytest.mark.asyncio