diff --git a/src/apify/events/_apify_event_manager.py b/src/apify/events/_apify_event_manager.py index 21707875..8ded8334 100644 --- a/src/apify/events/_apify_event_manager.py +++ b/src/apify/events/_apify_event_manager.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Annotated, Self import websockets.asyncio.client +import websockets.exceptions from pydantic import Discriminator, TypeAdapter from typing_extensions import Unpack, override @@ -91,49 +92,80 @@ async def __aexit__( exc_value: BaseException | None, exc_traceback: TracebackType | None, ) -> None: - if self._platform_events_websocket: - await self._platform_events_websocket.close() - + # Cancel the task before closing the websocket so that the closed connection is not treated as a drop + # and followed by a reconnect attempt. if self._process_platform_messages_task and not self._process_platform_messages_task.done(): self._process_platform_messages_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._process_platform_messages_task + if self._platform_events_websocket: + await self._platform_events_websocket.close() + await super().__aexit__(exc_type, exc_value, exc_traceback) + def _process_connection_exception(self, exc: Exception) -> Exception | None: + """Decide whether a failed connection attempt to the platform websocket should be retried. + + Before the first successful connection, every error is fatal so that `__aenter__` fails fast. After that, + the default `websockets` behavior decides which errors are transient and retried with exponential backoff. + """ + if self._connected_to_platform_websocket and self._connected_to_platform_websocket.done(): + return websockets.asyncio.client.process_exception(exc) + return exc + async def _process_platform_messages(self, ws_url: str) -> None: try: - async with websockets.asyncio.client.connect(ws_url) as websocket: + # Used as an async iterator, `connect` reconnects with exponential backoff whenever a connection + # attempt fails with a transient error. + async for websocket in websockets.asyncio.client.connect( + ws_url, process_exception=self._process_connection_exception + ): self._platform_events_websocket = websocket - if self._connected_to_platform_websocket is not None: + if self._connected_to_platform_websocket and not self._connected_to_platform_websocket.done(): self._connected_to_platform_websocket.set_result(True) - - async for message in websocket: - try: - parsed_message = event_data_adapter.validate_json(message) - - if isinstance(parsed_message, DeprecatedEvent): - continue - - if isinstance(parsed_message, UnknownEvent): - logger.info( - f'Unknown message received: event_name={parsed_message.name}, ' - f'event_data={parsed_message.data}' + else: + logger.info('Reconnected to the platform events websocket.') + + try: + async for message in websocket: + try: + parsed_message = event_data_adapter.validate_json(message) + + if isinstance(parsed_message, DeprecatedEvent): + continue + + if isinstance(parsed_message, UnknownEvent): + logger.info( + f'Unknown message received: event_name={parsed_message.name}, ' + f'event_data={parsed_message.data}' + ) + continue + + self.emit( + event=parsed_message.name, + event_data=parsed_message.data + if not isinstance(parsed_message.data, SystemInfoEventData) + else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), ) - continue - - self.emit( - event=parsed_message.name, - event_data=parsed_message.data - if not isinstance(parsed_message.data, SystemInfoEventData) - else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1), - ) - - if parsed_message.name == Event.MIGRATING: - await self._emit_persist_state_event_rec_task.stop() - self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)) - except Exception: - logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + + if parsed_message.name == Event.MIGRATING: + await self._emit_persist_state_event_rec_task.stop() + self.emit( + event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True) + ) + except Exception: + logger.exception('Cannot parse Actor event', extra={'raw_message': message}) + except websockets.exceptions.ConnectionClosed: + logger.warning( + f'Connection to platform events websocket was lost ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) + else: + logger.info( + f'Connection to platform events websocket was closed ' + f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...' + ) except Exception: logger.exception('Error in websocket connection') if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done(): diff --git a/tests/unit/events/test_apify_event_manager.py b/tests/unit/events/test_apify_event_manager.py index 13568f43..3e80eecb 100644 --- a/tests/unit/events/test_apify_event_manager.py +++ b/tests/unit/events/test_apify_event_manager.py @@ -12,7 +12,6 @@ import pytest import websockets import websockets.asyncio.server -import websockets.exceptions from crawlee.events._types import Event @@ -26,6 +25,18 @@ from collections.abc import AsyncGenerator, Callable +DUMMY_SYSTEM_INFO = { + 'memAvgBytes': 19328860.328293584, + 'memCurrentBytes': 65171456, + 'memMaxBytes': 65171456, + 'cpuAvgUsage': 2.0761105633130397, + 'cpuMaxUsage': 53.941134593993326, + 'cpuCurrentUsage': 8.45549815498155, + 'isCpuOverloaded': False, + 'createdAt': '2024-08-09T16:04:16.161Z', +} + + @contextlib.asynccontextmanager async def _platform_ws_server( monkeypatch: pytest.MonkeyPatch, @@ -189,17 +200,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None: websockets.broadcast(connected_ws_clients, json.dumps(message)) - dummy_system_info = { - 'memAvgBytes': 19328860.328293584, - 'memCurrentBytes': 65171456, - 'memMaxBytes': 65171456, - 'cpuAvgUsage': 2.0761105633130397, - 'cpuMaxUsage': 53.941134593993326, - 'cpuCurrentUsage': 8.45549815498155, - 'isCpuOverloaded': False, - 'createdAt': '2024-08-09T16:04:16.161Z', - } - SystemInfoEventData.model_validate(dummy_system_info) + SystemInfoEventData.model_validate(DUMMY_SYSTEM_INFO) async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager: await client_connected.wait() @@ -211,7 +212,7 @@ def listener(data: Any) -> None: event_manager.on(event=Event.SYSTEM_INFO, listener=listener) # Test sending event with data - await send_platform_event(Event.SYSTEM_INFO, dummy_system_info) + await send_platform_event(Event.SYSTEM_INFO, DUMMY_SYSTEM_INFO) await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) assert len(event_calls) == 1 assert event_calls[0] is not None @@ -320,38 +321,54 @@ def migrating_listener(data: Any) -> None: assert len(migration_persist_events) >= 1 -async def test_websocket_mid_stream_disconnect_does_not_raise_invalid_state_error( - monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +@pytest.mark.parametrize( + ('close_code', 'expected_log'), + [ + pytest.param(1000, 'Connection to platform events websocket was closed (code=1000', id='graceful_close'), + pytest.param(1011, 'Connection to platform events websocket was lost (code=1011', id='abnormal_close'), + ], +) +async def test_websocket_reconnects_after_connection_drop( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, close_code: int, expected_log: str ) -> None: - """Regression: a mid-stream websocket disconnect after a successful connect must not raise InvalidStateError. + """Test that the event manager logs a websocket drop, reconnects, and keeps receiving platform events. - The `_connected_to_platform_websocket` future is resolved to `True` on successful connect. If the websocket - later drops, the outer `except` in `_process_platform_messages` must not call `set_result(False)` on the - already-resolved future. + Also a regression test for the resolved `_connected_to_platform_websocket` future: a mid-stream disconnect + must not kill the message-processing task with `InvalidStateError`. """ + caplog.set_level(logging.INFO, logger='apify') async with ( _platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected), ApifyEventManager(Configuration.get_global_configuration()) as event_manager, ): await client_connected.wait() + assert len(connected_ws_clients) == 1 + + event_calls: list[Any] = [] + event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append) - # Force an abnormal close from the server so the client's `async for` raises ConnectionClosedError. + # Drop the connection from the server side and wait for the client to reconnect. + client_connected.clear() for ws in list(connected_ws_clients): - await ws.close(code=1011, reason='Simulated server error') + await ws.close(code=close_code, reason='Simulated connection drop') + await asyncio.wait_for(client_connected.wait(), timeout=10) + # Poll because the old server-side handler may not have deregistered its connection yet. + await poll_until_condition(lambda: len(connected_ws_clients) == 1, poll_interval=0.05) + assert len(connected_ws_clients) == 1 + # The message-processing task must have survived the drop. task = event_manager._process_platform_messages_task assert task is not None - await asyncio.wait_for(asyncio.shield(task), timeout=2.0) + assert not task.done() - exc = task.exception() - assert not isinstance(exc, asyncio.InvalidStateError), f'Task raised InvalidStateError: {exc}' + # Events sent over the new connection must still be emitted. + websockets.broadcast(connected_ws_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO})) + await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05) + assert len(event_calls) == 1 - # Confirm the test actually exercised the disconnect path — the outer `except` in - # `_process_platform_messages` should have logged a `ConnectionClosedError`. - logged_exc_types = [ - record.exc_info[0] for record in caplog.records if record.exc_info and record.exc_info[0] is not None - ] - assert any(issubclass(exc_type, websockets.exceptions.ConnectionClosedError) for exc_type in logged_exc_types) + # Both the drop and the successful reconnect must be logged. + assert expected_log in caplog.text + assert 'Reconnected to the platform events websocket.' in caplog.text async def test_malformed_message_logs_exception(