From 0bf9ea882d8c4ee196f30d3004042527564a5b7f Mon Sep 17 00:00:00 2001 From: Alex Kwiatkowski Date: Sun, 28 Jun 2026 20:30:09 -0700 Subject: [PATCH 1/2] dispatch realtime events on captured loop --- .../realtime/connection_management.py | 6 +- src/project_x_py/realtime/event_handling.py | 77 +++++++++++++------ tests/realtime/test_event_handling.py | 64 ++++++++++++++- 3 files changed, 118 insertions(+), 29 deletions(-) diff --git a/src/project_x_py/realtime/connection_management.py b/src/project_x_py/realtime/connection_management.py index f4c11a4..b6e465a 100644 --- a/src/project_x_py/realtime/connection_management.py +++ b/src/project_x_py/realtime/connection_management.py @@ -272,9 +272,6 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool: operation="connect", account_id=self.account_id, ): - if not self.setup_complete: - await self.setup_connections() - # Store the event loop for cross-thread task scheduling try: self._loop = asyncio.get_running_loop() @@ -282,6 +279,9 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool: logger.error("No running event loop found.") return False + if not self.setup_complete: + await self.setup_connections() + logger.debug(LogMessages.WS_CONNECT) async with self._connection_lock: diff --git a/src/project_x_py/realtime/event_handling.py b/src/project_x_py/realtime/event_handling.py index bf7d6b9..9aedd02 100644 --- a/src/project_x_py/realtime/event_handling.py +++ b/src/project_x_py/realtime/event_handling.py @@ -349,10 +349,9 @@ def _forward_quote_update(self, *args: Any) -> None: """ if self._use_batching and self._batched_handler and args: # Use batched processing for high-frequency quotes - self._create_task( + self._schedule_coroutine_threadsafe( self._batched_handler.handle_quote(args[0]), name="handle_quote", - persistent=False, ) else: self._schedule_async_task("quote_update", args) @@ -372,10 +371,9 @@ def _forward_market_trade(self, *args: Any) -> None: """ if self._use_batching and self._batched_handler and args: # Use batched processing for trades - self._create_task( + self._schedule_coroutine_threadsafe( self._batched_handler.handle_trade(args[0]), name="handle_trade", - persistent=False, ) else: self._schedule_async_task("market_trade", args) @@ -395,14 +393,58 @@ def _forward_market_depth(self, *args: Any) -> None: """ if self._use_batching and self._batched_handler and args: # Use batched processing for depth updates - self._create_task( + self._schedule_coroutine_threadsafe( self._batched_handler.handle_depth(args[0]), name="handle_depth", - persistent=False, ) else: self._schedule_async_task("market_depth", args) + def _active_event_loop(self) -> asyncio.AbstractEventLoop | None: + """Return the captured asyncio loop, or capture the current running loop.""" + if self._loop and not self._loop.is_closed(): + return self._loop + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return None + + if loop.is_closed(): + return None + + self._loop = loop + return loop + + def _schedule_coroutine_threadsafe( + self, + coro: Coroutine[Any, Any, Any], + *, + name: str, + ) -> bool: + """Schedule a coroutine on the captured loop from the SignalR thread.""" + loop = self._active_event_loop() + if loop is None: + coro.close() + self.logger.debug(f"Dropping {name}; no active asyncio event loop") + return False + + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + except Exception as e: + coro.close() + self.logger.error(f"Error scheduling async task {name}: {e}") + return False + + def _log_task_error(task_future: Any) -> None: + try: + task_future.result() + except Exception as exc: + self.logger.error(f"Async task {name} failed: {exc}", exc_info=True) + + future.add_done_callback(_log_task_error) + return True + def _schedule_async_task(self, event_type: str, data: Any) -> None: """ Schedule async task in the main event loop from any thread. @@ -428,25 +470,10 @@ def _schedule_async_task(self, event_type: str, data: Any) -> None: Note: Critical for thread safety - ensures callbacks run in proper context. """ - if self._loop and not self._loop.is_closed(): - try: - asyncio.run_coroutine_threadsafe( - self._forward_event_async(event_type, data), self._loop - ) - except Exception as e: - # Fallback for logging - avoid recursion - self.logger.error(f"Error scheduling async task: {e}") - else: - # Fallback - try to create task in current loop context - try: - self._create_task( - self._forward_event_async(event_type, data), - name=f"forward_{event_type}", - persistent=False, - ) - except RuntimeError: - # No event loop available, log and continue - self.logger.error(f"No event loop available for {event_type} event") + self._schedule_coroutine_threadsafe( + self._forward_event_async(event_type, data), + name=f"forward_{event_type}", + ) async def _forward_event_async(self, event_type: str, args: Any) -> None: """ diff --git a/tests/realtime/test_event_handling.py b/tests/realtime/test_event_handling.py index 5249642..8219274 100644 --- a/tests/realtime/test_event_handling.py +++ b/tests/realtime/test_event_handling.py @@ -321,6 +321,25 @@ async def test_batched_handler_cleanup(self, event_handler): class TestCrossThreadEventScheduling: """Test cross-thread event scheduling for asyncio compatibility.""" + @pytest.mark.asyncio + async def test_schedule_async_task_captures_running_loop(self, event_handler): + """Test scheduling captures the active event loop when none was stored.""" + event_data = {"test": "data"} + received = asyncio.Event() + + async def callback(data): + assert data == event_data + received.set() + + await event_handler.add_callback('test_event', callback) + + assert event_handler._loop is None + + event_handler._schedule_async_task('test_event', (event_data,)) + + await asyncio.wait_for(received.wait(), timeout=1.0) + assert event_handler._loop == asyncio.get_running_loop() + @pytest.mark.asyncio async def test_schedule_event_from_different_thread(self, event_handler): """Test scheduling event from a different thread.""" @@ -358,6 +377,49 @@ def thread_func(): callback.assert_called_once_with(event_data) + @pytest.mark.asyncio + async def test_schedule_async_task_from_signalr_thread(self, event_handler): + """Test SignalR thread events use the captured asyncio loop.""" + import threading + + event_data = {"test": "data"} + received = asyncio.Event() + + async def callback(data): + assert data == event_data + received.set() + + await event_handler.add_callback('test_event', callback) + + event_handler._loop = asyncio.get_running_loop() + + thread = threading.Thread( + target=lambda: event_handler._schedule_async_task( + 'test_event', (event_data,) + ) + ) + thread.start() + thread.join(timeout=1.0) + + await asyncio.wait_for(received.wait(), timeout=1.0) + + def test_schedule_async_task_without_loop_drops_event(self, event_handler): + """Test late SignalR events are dropped quietly when no loop exists.""" + import threading + + thread = threading.Thread( + target=lambda: event_handler._schedule_async_task( + 'test_event', ({"test": "data"},) + ) + ) + thread.start() + thread.join(timeout=1.0) + + event_handler.logger.error.assert_not_called() + event_handler.logger.debug.assert_called_with( + "Dropping forward_test_event; no active asyncio event loop" + ) + @pytest.mark.asyncio async def test_event_loop_detection(self, event_handler): """Test that event handler detects and uses correct event loop.""" @@ -410,7 +472,7 @@ async def test_get_batching_stats(self, event_handler): # Stats contain handler stats, not just an "enabled" flag assert len(stats) > 0 # Each handler should have stats with expected keys - for handler_name, handler_stats in stats.items(): + for _handler_name, handler_stats in stats.items(): assert isinstance(handler_stats, dict) assert "batches_processed" in handler_stats From 23e2fd0cec66f4b2b3bb438daa030cd1cb10ab42 Mon Sep 17 00:00:00 2001 From: Alex Kwiatkowski Date: Sun, 28 Jun 2026 21:00:27 -0700 Subject: [PATCH 2/2] harden realtime connection shutdown --- .../realtime/connection_management.py | 30 ++++++++-- src/project_x_py/realtime/event_handling.py | 5 ++ src/project_x_py/trading_suite.py | 7 ++- tests/realtime/test_connection_management.py | 18 ++++-- tests/realtime/test_event_handling.py | 32 +++++++++++ tests/trading_suite/test_core.py | 57 +++++++++++++++++++ 6 files changed, 138 insertions(+), 11 deletions(-) diff --git a/src/project_x_py/realtime/connection_management.py b/src/project_x_py/realtime/connection_management.py index b6e465a..cabf964 100644 --- a/src/project_x_py/realtime/connection_management.py +++ b/src/project_x_py/realtime/connection_management.py @@ -218,11 +218,13 @@ async def setup_connections(self: "ProjectXRealtimeClientProtocol") -> None: self.user_connection.on( "GatewayUserTrade", self._forward_trade_execution ) + self.user_connection.on("GatewayLogout", self._on_gateway_logout) # Market Hub Events self.market_connection.on("GatewayQuote", self._forward_quote_update) self.market_connection.on("GatewayTrade", self._forward_market_trade) self.market_connection.on("GatewayDepth", self._forward_market_depth) + self.market_connection.on("GatewayLogout", self._on_gateway_logout) logger.debug( LogMessages.WS_CONNECTED, extra={"phase": "setup_complete"} @@ -304,15 +306,27 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool: ) return False - # Wait for connections to establish + # Wait for connections to establish. Keep explicit task handles so + # timeout and shutdown paths can drain cancellations cleanly. + wait_tasks = [ + asyncio.create_task(self.user_hub_ready.wait()), + asyncio.create_task(self.market_hub_ready.wait()), + ] try: - await asyncio.wait_for( - asyncio.gather( - self.user_hub_ready.wait(), self.market_hub_ready.wait() - ), + _, pending = await asyncio.wait( + wait_tasks, timeout=10.0, ) - except TimeoutError: + except asyncio.CancelledError: + for task in wait_tasks: + task.cancel() + await asyncio.gather(*wait_tasks, return_exceptions=True) + raise + + if pending: + for task in pending: + task.cancel() + await asyncio.gather(*wait_tasks, return_exceptions=True) logger.error( LogMessages.WS_ERROR, extra={ @@ -468,6 +482,10 @@ def _on_market_hub_close(self: "ProjectXRealtimeClientProtocol") -> None: self.market_hub_ready.clear() self.logger.warning("❌ Market hub disconnected") + def _on_gateway_logout(self: "ProjectXRealtimeClientProtocol", *args: Any) -> None: + """Handle GatewayLogout events so SignalRCore does not log them as unhandled.""" + self.logger.debug("Gateway logout event received", extra={"payload": args}) + def _on_connection_error( self: "ProjectXRealtimeClientProtocol", hub: str, error: Any ) -> None: diff --git a/src/project_x_py/realtime/event_handling.py b/src/project_x_py/realtime/event_handling.py index 9aedd02..43b0a7c 100644 --- a/src/project_x_py/realtime/event_handling.py +++ b/src/project_x_py/realtime/event_handling.py @@ -69,6 +69,7 @@ async def on_quote_update(data): import asyncio from collections.abc import Callable, Coroutine +from concurrent.futures import CancelledError as FutureCancelledError from datetime import datetime from typing import TYPE_CHECKING, Any @@ -437,8 +438,12 @@ def _schedule_coroutine_threadsafe( return False def _log_task_error(task_future: Any) -> None: + if task_future.cancelled(): + return try: task_future.result() + except (asyncio.CancelledError, FutureCancelledError): + return except Exception as exc: self.logger.error(f"Async task {name} failed: {exc}", exc_info=True) diff --git a/src/project_x_py/trading_suite.py b/src/project_x_py/trading_suite.py index f1bf151..40306dd 100644 --- a/src/project_x_py/trading_suite.py +++ b/src/project_x_py/trading_suite.py @@ -51,6 +51,7 @@ from project_x_py.client import ProjectX from project_x_py.client.base import ProjectXBase from project_x_py.event_bus import EventBus, EventType +from project_x_py.exceptions import ProjectXConnectionError from project_x_py.models import Instrument from project_x_py.order_manager import OrderManager from project_x_py.order_tracker import OrderChainBuilder, OrderTracker @@ -838,7 +839,11 @@ async def _initialize(self) -> None: try: # Connect to realtime feeds logger.info("Connecting to real-time feeds...") - await self.realtime.connect() + connected = await self.realtime.connect() + if not connected: + raise ProjectXConnectionError( + "Failed to establish ProjectX realtime connections" + ) await self.realtime.subscribe_user_updates() if self._instruments: diff --git a/tests/realtime/test_connection_management.py b/tests/realtime/test_connection_management.py index c7b8432..9ac69d8 100644 --- a/tests/realtime/test_connection_management.py +++ b/tests/realtime/test_connection_management.py @@ -226,13 +226,15 @@ async def test_setup_connections_registers_event_handlers(self, mock_builder_cla "GatewayUserAccount", "GatewayUserPosition", "GatewayUserOrder", - "GatewayUserTrade" + "GatewayUserTrade", + "GatewayLogout", ] expected_market_events = [ "GatewayQuote", "GatewayTrade", - "GatewayDepth" + "GatewayDepth", + "GatewayLogout", ] # Check that all event handlers were registered @@ -385,8 +387,16 @@ async def test_connect_returns_false_on_timeout(self, mock_client): mock_client.user_connected = False mock_client.market_connected = False - # Use a very short timeout for testing - with patch('asyncio.wait_for', side_effect=TimeoutError()): + async def timeout_wait(tasks, timeout): + del timeout + return set(), set(tasks) + + # Simulate a timeout without waiting for the real 10 second timeout. + with patch( + 'asyncio.wait', + new_callable=AsyncMock, + side_effect=timeout_wait, + ): result = await mock_client.connect() assert result is False diff --git a/tests/realtime/test_event_handling.py b/tests/realtime/test_event_handling.py index 8219274..f151476 100644 --- a/tests/realtime/test_event_handling.py +++ b/tests/realtime/test_event_handling.py @@ -420,6 +420,38 @@ def test_schedule_async_task_without_loop_drops_event(self, event_handler): "Dropping forward_test_event; no active asyncio event loop" ) + @pytest.mark.asyncio + async def test_cancelled_threadsafe_task_does_not_log_error(self, event_handler): + """Test cancelled SignalR callback tasks are treated as normal shutdown.""" + + class CancelledFuture: + def cancelled(self): + return True + + def add_done_callback(self, callback): + callback(self) + + async def noop(): + pass + + def run_threadsafe(coro, loop): + del loop + coro.close() + return CancelledFuture() + + event_handler._loop = asyncio.get_running_loop() + + with patch( + "asyncio.run_coroutine_threadsafe", + side_effect=run_threadsafe, + ): + assert event_handler._schedule_coroutine_threadsafe( + noop(), + name="cancelled_task", + ) + + event_handler.logger.error.assert_not_called() + @pytest.mark.asyncio async def test_event_loop_detection(self, event_handler): """Test that event handler detects and uses correct event loop.""" diff --git a/tests/trading_suite/test_core.py b/tests/trading_suite/test_core.py index 07fa799..fcfd520 100644 --- a/tests/trading_suite/test_core.py +++ b/tests/trading_suite/test_core.py @@ -9,6 +9,7 @@ import pytest from project_x_py import Features, TradingSuite, TradingSuiteConfig +from project_x_py.exceptions import ProjectXConnectionError from project_x_py.models import Account @@ -125,6 +126,62 @@ async def test_trading_suite_create(): assert suite._initialized is False +@pytest.mark.asyncio +async def test_trading_suite_create_fails_when_realtime_connect_returns_false(): + """Test creation fails when realtime connections do not establish.""" + + mock_client = MagicMock() + mock_client.account_info = Account( + id=12345, + name="TEST_ACCOUNT", + balance=100000.0, + canTrade=True, + isVisible=True, + simulated=True, + ) + mock_client.session_token = "mock_jwt_token" + mock_client.config = MagicMock() + mock_client.authenticate = AsyncMock() + mock_client.get_instrument = AsyncMock(return_value=MagicMock(id="MNQ_CONTRACT_ID")) + mock_client.search_all_orders = AsyncMock(return_value=[]) + mock_client.search_open_positions = AsyncMock(return_value=[]) + + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_client + mock_context.__aexit__.return_value = None + + mock_realtime = MagicMock() + mock_realtime.connect = AsyncMock(return_value=False) + mock_realtime.disconnect = AsyncMock(return_value=None) + + mock_data_manager = MagicMock() + mock_data_manager.stop_realtime_feed = AsyncMock(return_value=None) + mock_data_manager.cleanup = AsyncMock(return_value=None) + + mock_position_manager = MagicMock() + + with patch( + "project_x_py.trading_suite.ProjectX.from_env", return_value=mock_context + ): + with patch( + "project_x_py.trading_suite.ProjectXRealtimeClient", + return_value=mock_realtime, + ): + with patch( + "project_x_py.trading_suite.RealtimeDataManager", + return_value=mock_data_manager, + ): + with patch( + "project_x_py.trading_suite.PositionManager", + return_value=mock_position_manager, + ): + with pytest.raises(ProjectXConnectionError): + await TradingSuite.create("MNQ") + + mock_realtime.disconnect.assert_awaited_once() + assert mock_context.__aexit__.await_count >= 1 + + @pytest.mark.asyncio async def test_trading_suite_with_features(): """Test TradingSuite creation with optional features."""