Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 63 additions & 31 deletions src/apify/events/_apify_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, Annotated, Self

import websockets.asyncio.client
import websockets.exceptions
from pydantic import Discriminator, TypeAdapter
from typing_extensions import Unpack, override

Expand Down Expand Up @@ -91,49 +92,80 @@ async def __aexit__(
exc_value: BaseException | None,
exc_traceback: TracebackType | None,
) -> None:
if self._platform_events_websocket:
await self._platform_events_websocket.close()

# Cancel the task before closing the websocket so that the closed connection is not treated as a drop
# and followed by a reconnect attempt.
if self._process_platform_messages_task and not self._process_platform_messages_task.done():
self._process_platform_messages_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._process_platform_messages_task

if self._platform_events_websocket:
await self._platform_events_websocket.close()

await super().__aexit__(exc_type, exc_value, exc_traceback)

def _process_connection_exception(self, exc: Exception) -> Exception | None:
"""Decide whether a failed connection attempt to the platform websocket should be retried.

Before the first successful connection, every error is fatal so that `__aenter__` fails fast. After that,
the default `websockets` behavior decides which errors are transient and retried with exponential backoff.
"""
if self._connected_to_platform_websocket and self._connected_to_platform_websocket.done():
return websockets.asyncio.client.process_exception(exc)
return exc

async def _process_platform_messages(self, ws_url: str) -> None:
try:
async with websockets.asyncio.client.connect(ws_url) as websocket:
# Used as an async iterator, `connect` reconnects with exponential backoff whenever a connection
# attempt fails with a transient error.
async for websocket in websockets.asyncio.client.connect(
ws_url, process_exception=self._process_connection_exception
):
self._platform_events_websocket = websocket
if self._connected_to_platform_websocket is not None:
if self._connected_to_platform_websocket and not self._connected_to_platform_websocket.done():
self._connected_to_platform_websocket.set_result(True)

async for message in websocket:
try:
parsed_message = event_data_adapter.validate_json(message)

if isinstance(parsed_message, DeprecatedEvent):
continue

if isinstance(parsed_message, UnknownEvent):
logger.info(
f'Unknown message received: event_name={parsed_message.name}, '
f'event_data={parsed_message.data}'
else:
logger.info('Reconnected to the platform events websocket.')

try:
async for message in websocket:
try:
parsed_message = event_data_adapter.validate_json(message)

if isinstance(parsed_message, DeprecatedEvent):
continue

if isinstance(parsed_message, UnknownEvent):
logger.info(
f'Unknown message received: event_name={parsed_message.name}, '
f'event_data={parsed_message.data}'
)
continue

self.emit(
event=parsed_message.name,
event_data=parsed_message.data
if not isinstance(parsed_message.data, SystemInfoEventData)
else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1),
)
continue

self.emit(
event=parsed_message.name,
event_data=parsed_message.data
if not isinstance(parsed_message.data, SystemInfoEventData)
else parsed_message.data.to_crawlee_format(self._configuration.dedicated_cpus or 1),
)

if parsed_message.name == Event.MIGRATING:
await self._emit_persist_state_event_rec_task.stop()
self.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True))
except Exception:
logger.exception('Cannot parse Actor event', extra={'raw_message': message})

if parsed_message.name == Event.MIGRATING:
await self._emit_persist_state_event_rec_task.stop()
self.emit(
event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=True)
)
except Exception:
logger.exception('Cannot parse Actor event', extra={'raw_message': message})
except websockets.exceptions.ConnectionClosed:
logger.warning(
f'Connection to platform events websocket was lost '
f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...'
)
else:
logger.info(
f'Connection to platform events websocket was closed '
f'(code={websocket.close_code}, reason={websocket.close_reason!r}), reconnecting...'
)
except Exception:
logger.exception('Error in websocket connection')
if self._connected_to_platform_websocket is not None and not self._connected_to_platform_websocket.done():
Expand Down
77 changes: 47 additions & 30 deletions tests/unit/events/test_apify_event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pytest
import websockets
import websockets.asyncio.server
import websockets.exceptions

from crawlee.events._types import Event

Expand All @@ -26,6 +25,18 @@
from collections.abc import AsyncGenerator, Callable


DUMMY_SYSTEM_INFO = {
'memAvgBytes': 19328860.328293584,
'memCurrentBytes': 65171456,
'memMaxBytes': 65171456,
'cpuAvgUsage': 2.0761105633130397,
'cpuMaxUsage': 53.941134593993326,
'cpuCurrentUsage': 8.45549815498155,
'isCpuOverloaded': False,
'createdAt': '2024-08-09T16:04:16.161Z',
}


@contextlib.asynccontextmanager
async def _platform_ws_server(
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -189,17 +200,7 @@ async def send_platform_event(event_name: Event, data: Any = None) -> None:

websockets.broadcast(connected_ws_clients, json.dumps(message))

dummy_system_info = {
'memAvgBytes': 19328860.328293584,
'memCurrentBytes': 65171456,
'memMaxBytes': 65171456,
'cpuAvgUsage': 2.0761105633130397,
'cpuMaxUsage': 53.941134593993326,
'cpuCurrentUsage': 8.45549815498155,
'isCpuOverloaded': False,
'createdAt': '2024-08-09T16:04:16.161Z',
}
SystemInfoEventData.model_validate(dummy_system_info)
SystemInfoEventData.model_validate(DUMMY_SYSTEM_INFO)

async with ApifyEventManager(Configuration.get_global_configuration()) as event_manager:
await client_connected.wait()
Expand All @@ -211,7 +212,7 @@ def listener(data: Any) -> None:
event_manager.on(event=Event.SYSTEM_INFO, listener=listener)

# Test sending event with data
await send_platform_event(Event.SYSTEM_INFO, dummy_system_info)
await send_platform_event(Event.SYSTEM_INFO, DUMMY_SYSTEM_INFO)
await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05)
assert len(event_calls) == 1
assert event_calls[0] is not None
Expand Down Expand Up @@ -320,38 +321,54 @@ def migrating_listener(data: Any) -> None:
assert len(migration_persist_events) >= 1


async def test_websocket_mid_stream_disconnect_does_not_raise_invalid_state_error(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
@pytest.mark.parametrize(
('close_code', 'expected_log'),
[
pytest.param(1000, 'Connection to platform events websocket was closed (code=1000', id='graceful_close'),
pytest.param(1011, 'Connection to platform events websocket was lost (code=1011', id='abnormal_close'),
],
)
async def test_websocket_reconnects_after_connection_drop(
monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture, close_code: int, expected_log: str
) -> None:
"""Regression: a mid-stream websocket disconnect after a successful connect must not raise InvalidStateError.
"""Test that the event manager logs a websocket drop, reconnects, and keeps receiving platform events.

The `_connected_to_platform_websocket` future is resolved to `True` on successful connect. If the websocket
later drops, the outer `except` in `_process_platform_messages` must not call `set_result(False)` on the
already-resolved future.
Also a regression test for the resolved `_connected_to_platform_websocket` future: a mid-stream disconnect
must not kill the message-processing task with `InvalidStateError`.
"""
caplog.set_level(logging.INFO, logger='apify')
async with (
_platform_ws_server(monkeypatch) as (connected_ws_clients, client_connected),
ApifyEventManager(Configuration.get_global_configuration()) as event_manager,
):
await client_connected.wait()
assert len(connected_ws_clients) == 1

event_calls: list[Any] = []
event_manager.on(event=Event.SYSTEM_INFO, listener=event_calls.append)

# Force an abnormal close from the server so the client's `async for` raises ConnectionClosedError.
# Drop the connection from the server side and wait for the client to reconnect.
client_connected.clear()
for ws in list(connected_ws_clients):
await ws.close(code=1011, reason='Simulated server error')
await ws.close(code=close_code, reason='Simulated connection drop')
await asyncio.wait_for(client_connected.wait(), timeout=10)
# Poll because the old server-side handler may not have deregistered its connection yet.
await poll_until_condition(lambda: len(connected_ws_clients) == 1, poll_interval=0.05)
assert len(connected_ws_clients) == 1

# The message-processing task must have survived the drop.
task = event_manager._process_platform_messages_task
assert task is not None
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
assert not task.done()

exc = task.exception()
assert not isinstance(exc, asyncio.InvalidStateError), f'Task raised InvalidStateError: {exc}'
# Events sent over the new connection must still be emitted.
websockets.broadcast(connected_ws_clients, json.dumps({'name': 'systemInfo', 'data': DUMMY_SYSTEM_INFO}))
await poll_until_condition(lambda: len(event_calls) == 1, poll_interval=0.05)
assert len(event_calls) == 1

# Confirm the test actually exercised the disconnect path — the outer `except` in
# `_process_platform_messages` should have logged a `ConnectionClosedError`.
logged_exc_types = [
record.exc_info[0] for record in caplog.records if record.exc_info and record.exc_info[0] is not None
]
assert any(issubclass(exc_type, websockets.exceptions.ConnectionClosedError) for exc_type in logged_exc_types)
# Both the drop and the successful reconnect must be logged.
assert expected_log in caplog.text
assert 'Reconnected to the platform events websocket.' in caplog.text


async def test_malformed_message_logs_exception(
Expand Down
Loading