Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions src/project_x_py/realtime/connection_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,13 @@ async def setup_connections(self: "ProjectXRealtimeClientProtocol") -> None:
self.user_connection.on(
"GatewayUserTrade", self._forward_trade_execution
)
self.user_connection.on("GatewayLogout", self._on_gateway_logout)

# Market Hub Events
self.market_connection.on("GatewayQuote", self._forward_quote_update)
self.market_connection.on("GatewayTrade", self._forward_market_trade)
self.market_connection.on("GatewayDepth", self._forward_market_depth)
self.market_connection.on("GatewayLogout", self._on_gateway_logout)

logger.debug(
LogMessages.WS_CONNECTED, extra={"phase": "setup_complete"}
Expand Down Expand Up @@ -272,16 +274,16 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool:
operation="connect",
account_id=self.account_id,
):
if not self.setup_complete:
await self.setup_connections()

# Store the event loop for cross-thread task scheduling
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
logger.error("No running event loop found.")
return False

if not self.setup_complete:
await self.setup_connections()

logger.debug(LogMessages.WS_CONNECT)

async with self._connection_lock:
Expand All @@ -304,15 +306,27 @@ async def connect(self: "ProjectXRealtimeClientProtocol") -> bool:
)
return False

# Wait for connections to establish
# Wait for connections to establish. Keep explicit task handles so
# timeout and shutdown paths can drain cancellations cleanly.
wait_tasks = [
asyncio.create_task(self.user_hub_ready.wait()),
asyncio.create_task(self.market_hub_ready.wait()),
]
try:
await asyncio.wait_for(
asyncio.gather(
self.user_hub_ready.wait(), self.market_hub_ready.wait()
),
_, pending = await asyncio.wait(
wait_tasks,
timeout=10.0,
)
except TimeoutError:
except asyncio.CancelledError:
for task in wait_tasks:
task.cancel()
await asyncio.gather(*wait_tasks, return_exceptions=True)
raise

if pending:
for task in pending:
task.cancel()
await asyncio.gather(*wait_tasks, return_exceptions=True)
logger.error(
LogMessages.WS_ERROR,
extra={
Expand Down Expand Up @@ -468,6 +482,10 @@ def _on_market_hub_close(self: "ProjectXRealtimeClientProtocol") -> None:
self.market_hub_ready.clear()
self.logger.warning("❌ Market hub disconnected")

def _on_gateway_logout(self: "ProjectXRealtimeClientProtocol", *args: Any) -> None:
"""Handle GatewayLogout events so SignalRCore does not log them as unhandled."""
self.logger.debug("Gateway logout event received", extra={"payload": args})

def _on_connection_error(
self: "ProjectXRealtimeClientProtocol", hub: str, error: Any
) -> None:
Expand Down
82 changes: 57 additions & 25 deletions src/project_x_py/realtime/event_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async def on_quote_update(data):

import asyncio
from collections.abc import Callable, Coroutine
from concurrent.futures import CancelledError as FutureCancelledError
from datetime import datetime
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -349,10 +350,9 @@ def _forward_quote_update(self, *args: Any) -> None:
"""
if self._use_batching and self._batched_handler and args:
# Use batched processing for high-frequency quotes
self._create_task(
self._schedule_coroutine_threadsafe(
self._batched_handler.handle_quote(args[0]),
name="handle_quote",
persistent=False,
)
else:
self._schedule_async_task("quote_update", args)
Expand All @@ -372,10 +372,9 @@ def _forward_market_trade(self, *args: Any) -> None:
"""
if self._use_batching and self._batched_handler and args:
# Use batched processing for trades
self._create_task(
self._schedule_coroutine_threadsafe(
self._batched_handler.handle_trade(args[0]),
name="handle_trade",
persistent=False,
)
else:
self._schedule_async_task("market_trade", args)
Expand All @@ -395,14 +394,62 @@ def _forward_market_depth(self, *args: Any) -> None:
"""
if self._use_batching and self._batched_handler and args:
# Use batched processing for depth updates
self._create_task(
self._schedule_coroutine_threadsafe(
self._batched_handler.handle_depth(args[0]),
name="handle_depth",
persistent=False,
)
else:
self._schedule_async_task("market_depth", args)

def _active_event_loop(self) -> asyncio.AbstractEventLoop | None:
"""Return the captured asyncio loop, or capture the current running loop."""
if self._loop and not self._loop.is_closed():
return self._loop

try:
loop = asyncio.get_running_loop()
except RuntimeError:
return None

if loop.is_closed():
return None

self._loop = loop
return loop

def _schedule_coroutine_threadsafe(
self,
coro: Coroutine[Any, Any, Any],
*,
name: str,
) -> bool:
"""Schedule a coroutine on the captured loop from the SignalR thread."""
loop = self._active_event_loop()
if loop is None:
coro.close()
self.logger.debug(f"Dropping {name}; no active asyncio event loop")
return False

try:
future = asyncio.run_coroutine_threadsafe(coro, loop)
except Exception as e:
coro.close()
self.logger.error(f"Error scheduling async task {name}: {e}")
return False

def _log_task_error(task_future: Any) -> None:
if task_future.cancelled():
return
try:
task_future.result()
except (asyncio.CancelledError, FutureCancelledError):
return
except Exception as exc:
self.logger.error(f"Async task {name} failed: {exc}", exc_info=True)

future.add_done_callback(_log_task_error)
return True

def _schedule_async_task(self, event_type: str, data: Any) -> None:
"""
Schedule async task in the main event loop from any thread.
Expand All @@ -428,25 +475,10 @@ def _schedule_async_task(self, event_type: str, data: Any) -> None:
Note:
Critical for thread safety - ensures callbacks run in proper context.
"""
if self._loop and not self._loop.is_closed():
try:
asyncio.run_coroutine_threadsafe(
self._forward_event_async(event_type, data), self._loop
)
except Exception as e:
# Fallback for logging - avoid recursion
self.logger.error(f"Error scheduling async task: {e}")
else:
# Fallback - try to create task in current loop context
try:
self._create_task(
self._forward_event_async(event_type, data),
name=f"forward_{event_type}",
persistent=False,
)
except RuntimeError:
# No event loop available, log and continue
self.logger.error(f"No event loop available for {event_type} event")
self._schedule_coroutine_threadsafe(
self._forward_event_async(event_type, data),
name=f"forward_{event_type}",
)

async def _forward_event_async(self, event_type: str, args: Any) -> None:
"""
Expand Down
7 changes: 6 additions & 1 deletion src/project_x_py/trading_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from project_x_py.client import ProjectX
from project_x_py.client.base import ProjectXBase
from project_x_py.event_bus import EventBus, EventType
from project_x_py.exceptions import ProjectXConnectionError
from project_x_py.models import Instrument
from project_x_py.order_manager import OrderManager
from project_x_py.order_tracker import OrderChainBuilder, OrderTracker
Expand Down Expand Up @@ -838,7 +839,11 @@ async def _initialize(self) -> None:
try:
# Connect to realtime feeds
logger.info("Connecting to real-time feeds...")
await self.realtime.connect()
connected = await self.realtime.connect()
if not connected:
raise ProjectXConnectionError(
"Failed to establish ProjectX realtime connections"
)
await self.realtime.subscribe_user_updates()

if self._instruments:
Expand Down
18 changes: 14 additions & 4 deletions tests/realtime/test_connection_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,15 @@ async def test_setup_connections_registers_event_handlers(self, mock_builder_cla
"GatewayUserAccount",
"GatewayUserPosition",
"GatewayUserOrder",
"GatewayUserTrade"
"GatewayUserTrade",
"GatewayLogout",
]

expected_market_events = [
"GatewayQuote",
"GatewayTrade",
"GatewayDepth"
"GatewayDepth",
"GatewayLogout",
]

# Check that all event handlers were registered
Expand Down Expand Up @@ -385,8 +387,16 @@ async def test_connect_returns_false_on_timeout(self, mock_client):
mock_client.user_connected = False
mock_client.market_connected = False

# Use a very short timeout for testing
with patch('asyncio.wait_for', side_effect=TimeoutError()):
async def timeout_wait(tasks, timeout):
del timeout
return set(), set(tasks)

# Simulate a timeout without waiting for the real 10 second timeout.
with patch(
'asyncio.wait',
new_callable=AsyncMock,
side_effect=timeout_wait,
):
result = await mock_client.connect()

assert result is False
Expand Down
96 changes: 95 additions & 1 deletion tests/realtime/test_event_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,25 @@ async def test_batched_handler_cleanup(self, event_handler):
class TestCrossThreadEventScheduling:
"""Test cross-thread event scheduling for asyncio compatibility."""

@pytest.mark.asyncio
async def test_schedule_async_task_captures_running_loop(self, event_handler):
"""Test scheduling captures the active event loop when none was stored."""
event_data = {"test": "data"}
received = asyncio.Event()

async def callback(data):
assert data == event_data
received.set()

await event_handler.add_callback('test_event', callback)

assert event_handler._loop is None

event_handler._schedule_async_task('test_event', (event_data,))

await asyncio.wait_for(received.wait(), timeout=1.0)
assert event_handler._loop == asyncio.get_running_loop()

@pytest.mark.asyncio
async def test_schedule_event_from_different_thread(self, event_handler):
"""Test scheduling event from a different thread."""
Expand Down Expand Up @@ -358,6 +377,81 @@ def thread_func():

callback.assert_called_once_with(event_data)

@pytest.mark.asyncio
async def test_schedule_async_task_from_signalr_thread(self, event_handler):
"""Test SignalR thread events use the captured asyncio loop."""
import threading

event_data = {"test": "data"}
received = asyncio.Event()

async def callback(data):
assert data == event_data
received.set()

await event_handler.add_callback('test_event', callback)

event_handler._loop = asyncio.get_running_loop()

thread = threading.Thread(
target=lambda: event_handler._schedule_async_task(
'test_event', (event_data,)
)
)
thread.start()
thread.join(timeout=1.0)

await asyncio.wait_for(received.wait(), timeout=1.0)

def test_schedule_async_task_without_loop_drops_event(self, event_handler):
"""Test late SignalR events are dropped quietly when no loop exists."""
import threading

thread = threading.Thread(
target=lambda: event_handler._schedule_async_task(
'test_event', ({"test": "data"},)
)
)
thread.start()
thread.join(timeout=1.0)

event_handler.logger.error.assert_not_called()
event_handler.logger.debug.assert_called_with(
"Dropping forward_test_event; no active asyncio event loop"
)

@pytest.mark.asyncio
async def test_cancelled_threadsafe_task_does_not_log_error(self, event_handler):
"""Test cancelled SignalR callback tasks are treated as normal shutdown."""

class CancelledFuture:
def cancelled(self):
return True

def add_done_callback(self, callback):
callback(self)

async def noop():
pass

def run_threadsafe(coro, loop):
del loop
coro.close()
return CancelledFuture()

event_handler._loop = asyncio.get_running_loop()

with patch(
"asyncio.run_coroutine_threadsafe",
side_effect=run_threadsafe,
):
assert event_handler._schedule_coroutine_threadsafe(
noop(),
name="cancelled_task",
)

event_handler.logger.error.assert_not_called()

@pytest.mark.asyncio
async def test_event_loop_detection(self, event_handler):
"""Test that event handler detects and uses correct event loop."""
Expand Down Expand Up @@ -410,7 +504,7 @@ async def test_get_batching_stats(self, event_handler):
# Stats contain handler stats, not just an "enabled" flag
assert len(stats) > 0
# Each handler should have stats with expected keys
for handler_name, handler_stats in stats.items():
for _handler_name, handler_stats in stats.items():
assert isinstance(handler_stats, dict)
assert "batches_processed" in handler_stats

Expand Down
Loading