From c26542491756b8bc6d053cfc41ffe406931126ee Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 16:52:49 -0400 Subject: [PATCH 1/5] improved websocket connection management --- .../src/WebSocketManager.ts | 8 + @plotly/dash-websocket-worker/src/types.ts | 1 + @plotly/dash-websocket-worker/src/worker.ts | 7 + dash/backends/_fastapi.py | 32 +- dash/backends/_quart.py | 32 +- dash/backends/_ws_registry.py | 166 +++++++ dash/backends/ws.py | 157 ++++-- .../src/observers/websocketObserver.ts | 60 +++ dash/dash-renderer/src/utils/workerClient.ts | 14 + tests/websocket/test_ws_reconnect.py | 449 ++++++++++++++++++ 10 files changed, 879 insertions(+), 47 deletions(-) create mode 100644 dash/backends/_ws_registry.py create mode 100644 tests/websocket/test_ws_reconnect.py diff --git a/@plotly/dash-websocket-worker/src/WebSocketManager.ts b/@plotly/dash-websocket-worker/src/WebSocketManager.ts index d96a0d8e68..942917f075 100644 --- a/@plotly/dash-websocket-worker/src/WebSocketManager.ts +++ b/@plotly/dash-websocket-worker/src/WebSocketManager.ts @@ -144,6 +144,14 @@ export class WebSocketManager { return this.ws !== null && this.ws.readyState === WebSocket.OPEN; } + /** + * Reset the activity timer. + * Call this when a tab becomes visible to prevent inactivity timeout. + */ + public resetActivity(): void { + this.lastActivityTime = Date.now(); + } + private createConnection(): void { if (!this.serverUrl) { return; diff --git a/@plotly/dash-websocket-worker/src/types.ts b/@plotly/dash-websocket-worker/src/types.ts index 5d1ff80bf0..3d88af068a 100644 --- a/@plotly/dash-websocket-worker/src/types.ts +++ b/@plotly/dash-websocket-worker/src/types.ts @@ -7,6 +7,7 @@ export enum WorkerMessageType { DISCONNECT = 'disconnect', CALLBACK_REQUEST = 'callback_request', GET_PROPS_RESPONSE = 'get_props_response', + TAB_VISIBLE = 'tab_visible', // Worker -> Renderer CONNECTED = 'connected', diff --git a/@plotly/dash-websocket-worker/src/worker.ts b/@plotly/dash-websocket-worker/src/worker.ts index 0e68f0b09a..e28d0c1583 100644 --- a/@plotly/dash-websocket-worker/src/worker.ts +++ b/@plotly/dash-websocket-worker/src/worker.ts @@ -122,6 +122,13 @@ self.onconnect = (event: MessageEvent) => { break; } + case WorkerMessageType.TAB_VISIBLE: { + // Reset activity timer when tab becomes visible + // This prevents inactivity timeout while user is viewing the tab + wsManager.resetActivity(); + break; + } + default: // Forward other messages through the router router.handleRendererMessage(message.rendererId, message); diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index c46fb4ffc5..b619c040c9 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -50,6 +50,7 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) +from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -677,6 +678,12 @@ def serve_websocket_callback(self, dash_app: "Dash"): dash_app, "_websocket_allowed_origins", [] ) # pylint: disable=protected-access + # Initialize registry on dash_app if not present + # pylint: disable=protected-access + if not hasattr(dash_app, "_ws_callback_registry"): + dash_app._ws_callback_registry = ActiveCallbackRegistry() + registry: ActiveCallbackRegistry = dash_app._ws_callback_registry + def validate_origin(origin: str | None, host: str | None) -> str | None: """Validate WebSocket origin. Returns error message or None if valid.""" if not origin: @@ -723,6 +730,8 @@ async def websocket_handler(websocket: WebSocket): executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track current renderer ID for this connection + current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -753,6 +762,22 @@ async def websocket_handler(websocket: WebSocket): renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) + # Update current renderer ID for cleanup + current_renderer_id = renderer_id + + # Adopt connection for this renderer (allows reconnection) + # Called for every callback to ensure registry entry exists + # (entry may have been cleaned up after previous callback) + registry.adopt_connection( + renderer_id, + outbound_queue, + pending_get_props, + shutdown_event, + ) + + # Register this callback with the registry + registry.register_callback(renderer_id) + # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -761,12 +786,13 @@ async def websocket_handler(websocket: WebSocket): dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with outbound queue + # Create WebSocket callback instance with registry ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, shutdown_event, + registry=registry, ) # Submit callback to executor @@ -786,6 +812,7 @@ async def websocket_handler(websocket: WebSocket): request_id, renderer_id, shutdown_event, + registry=registry, ) ) pending_callbacks[request_id] = future @@ -821,6 +848,9 @@ async def websocket_handler(websocket: WebSocket): # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Cleanup registry entry if no active callbacks + if current_renderer_id is not None: + registry.cleanup_renderer(current_renderer_id) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 881fd6466f..31d9668d36 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -55,6 +55,7 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) +from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: @@ -521,6 +522,12 @@ def serve_websocket_callback(self, dash_app: "Dash"): # pylint: disable=protected-access allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) + # Initialize registry on dash_app if not present + # pylint: disable=protected-access + if not hasattr(dash_app, "_ws_callback_registry"): + dash_app._ws_callback_registry = ActiveCallbackRegistry() + registry: ActiveCallbackRegistry = dash_app._ws_callback_registry + @self.server.websocket(ws_path) async def websocket_handler(): # pylint: disable=too-many-branches ws = websocket @@ -564,6 +571,8 @@ async def websocket_handler(): # pylint: disable=too-many-branches executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} + # Track current renderer ID for this connection + current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -601,6 +610,22 @@ async def websocket_handler(): # pylint: disable=too-many-branches renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) + # Update current renderer ID for cleanup + current_renderer_id = renderer_id + + # Adopt connection for this renderer (allows reconnection) + # Called for every callback to ensure registry entry exists + # (entry may have been cleaned up after previous callback) + registry.adopt_connection( + renderer_id, + outbound_queue, + pending_get_props, + connection_shutdown_event, + ) + + # Register this callback with the registry + registry.register_callback(renderer_id) + # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -609,12 +634,13 @@ async def websocket_handler(): # pylint: disable=too-many-branches dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with outbound queue + # Create WebSocket callback instance with registry ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, connection_shutdown_event, + registry=registry, ) # Submit callback to executor @@ -634,6 +660,7 @@ async def websocket_handler(): # pylint: disable=too-many-branches request_id, renderer_id, connection_shutdown_event, + registry=registry, ) ) pending_callbacks[request_id] = future @@ -672,6 +699,9 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() + # Cleanup registry entry if no active callbacks + if current_renderer_id is not None: + registry.cleanup_renderer(current_renderer_id) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/_ws_registry.py b/dash/backends/_ws_registry.py new file mode 100644 index 0000000000..4ab1132d6a --- /dev/null +++ b/dash/backends/_ws_registry.py @@ -0,0 +1,166 @@ +"""WebSocket callback registry for handling reconnections. + +This module provides a registry that tracks active callbacks per renderer, +allowing callbacks to persist across WebSocket reconnections. +""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Optional + +if TYPE_CHECKING: + import queue + import janus + + +@dataclass +class RendererState: + """State for a single renderer's WebSocket connection.""" + + outbound_queue: "janus.Queue[str]" + pending_get_props: Dict[str, "queue.Queue[Any]"] + shutdown_event: threading.Event + active_callback_count: int = 0 + lock: threading.Lock = field(default_factory=threading.Lock) + + +class ActiveCallbackRegistry: + """Registry for active WebSocket callbacks that persists across reconnections. + + When a WebSocket disconnects and reconnects, callbacks that are still running + can "adopt" the new connection's queues to continue sending updates. + + Thread-safe for access from both the main event loop and worker threads. + """ + + def __init__(self) -> None: + self._renderers: Dict[str, RendererState] = {} + self._lock = threading.Lock() + + def adopt_connection( + self, + renderer_id: str, + outbound_queue: "janus.Queue[str]", + pending_get_props: Dict[str, "queue.Queue[Any]"], + shutdown_event: threading.Event, + ) -> None: + """Associate new connection with existing callbacks for this renderer. + + When a WebSocket reconnects, this method updates the queues and shutdown + event so that running callbacks can use the new connection. + + Args: + renderer_id: The renderer ID for this connection + outbound_queue: janus.Queue for sending messages + pending_get_props: Dict to track pending get_props requests + shutdown_event: Event signaling connection closure + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.outbound_queue = outbound_queue + state.pending_get_props = pending_get_props + state.shutdown_event = shutdown_event + else: + self._renderers[renderer_id] = RendererState( + outbound_queue=outbound_queue, + pending_get_props=pending_get_props, + shutdown_event=shutdown_event, + active_callback_count=0, + ) + + def register_callback(self, renderer_id: str) -> None: + """Register a new active callback for this renderer. + + Args: + renderer_id: The renderer ID + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.active_callback_count += 1 + + def unregister_callback(self, renderer_id: str) -> None: + """Unregister a completed callback for this renderer. + + If no active callbacks remain, the renderer state is cleaned up. + + Args: + renderer_id: The renderer ID + """ + with self._lock: + if renderer_id in self._renderers: + state = self._renderers[renderer_id] + with state.lock: + state.active_callback_count -= 1 + if state.active_callback_count <= 0: + del self._renderers[renderer_id] + + def get_queue(self, renderer_id: str) -> Optional["janus.Queue[str]"]: + """Get current outbound queue for renderer (thread-safe). + + Args: + renderer_id: The renderer ID + + Returns: + The current outbound queue, or None if renderer not found + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return None + with state.lock: + return state.outbound_queue + + def get_pending_get_props( + self, renderer_id: str + ) -> Optional[Dict[str, "queue.Queue[Any]"]]: + """Get current pending_get_props dict for renderer (thread-safe). + + Args: + renderer_id: The renderer ID + + Returns: + The current pending_get_props dict, or None if renderer not found + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return None + with state.lock: + return state.pending_get_props + + def is_shutdown(self, renderer_id: str) -> bool: + """Check if current connection is shutdown. + + Args: + renderer_id: The renderer ID + + Returns: + True if shutdown event is set or renderer not found, False otherwise + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is None: + return True + with state.lock: + return state.shutdown_event.is_set() + + def cleanup_renderer(self, renderer_id: str) -> None: + """Clean up renderer state when connection closes. + + Only removes if no active callbacks remain. + + Args: + renderer_id: The renderer ID to clean up + """ + with self._lock: + state = self._renderers.get(renderer_id) + if state is not None: + with state.lock: + if state.active_callback_count <= 0: + del self._renderers[renderer_id] diff --git a/dash/backends/ws.py b/dash/backends/ws.py index f44913d873..adebdf2c4a 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: import dash from .base_server import ResponseAdapter + from ._ws_registry import ActiveCallbackRegistry SHUTDOWN_SIGNAL = "__shutdown__" @@ -41,6 +42,10 @@ class DashWebsocketCallback: Uses janus.Queue for outbound messages (serialized with to_json) and queue.Queue for get_props responses, enabling thread-safe communication between worker threads and the main event loop. + + Supports two modes: + 1. Registry mode: Uses ActiveCallbackRegistry to allow queue adoption on reconnect + 2. Direct mode: Uses direct queue references (legacy, for backwards compatibility) """ def __init__( @@ -49,6 +54,7 @@ def __init__( renderer_id: str, outbound_queue: janus.Queue[str], shutdown_event: "threading.Event", + registry: "ActiveCallbackRegistry | None" = None, ): """Initialize the WebSocket callback interface. @@ -58,26 +64,46 @@ def __init__( renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. shutdown_event: Event signaling the websocket connection has closed. + registry: Optional registry for handling reconnections. If provided, + the callback will use the registry to get current queues, allowing + it to survive reconnections. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue self._shutdown_event = shutdown_event + self._registry = registry @property def is_shutdown(self) -> bool: """Check if the websocket connection has been shut down.""" + if self._registry is not None: + return self._registry.is_shutdown(self._renderer_id) return self._shutdown_event.is_set() + def _get_outbound_queue(self) -> janus.Queue[str] | None: + """Get the current outbound queue (may be updated on reconnect).""" + if self._registry is not None: + return self._registry.get_queue(self._renderer_id) + return self._outbound_queue + + def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None: + """Get the current pending_get_props dict (may be updated on reconnect).""" + if self._registry is not None: + return self._registry.get_pending_get_props(self._renderer_id) + return self._pending_get_props + def _queue_message(self, msg: dict) -> None: """Serialize and queue message for sending (thread-safe, non-blocking). Uses to_json for proper serialization of Dash components. Does nothing if the connection has been shut down. """ - if self._shutdown_event.is_set(): + if self.is_shutdown: return - self._outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) + outbound_queue = self._get_outbound_queue() + if outbound_queue is not None: + outbound_queue.sync_q.put_nowait(cast(str, to_json(msg))) async def set_prop(self, component_id: str, prop_name: str, value: Any) -> None: """Send immediate prop update to the client via WebSocket. @@ -115,7 +141,11 @@ async def get_prop( WebsocketDisconnected: If the websocket connection has been closed. TimeoutError: If the response doesn't arrive within the timeout. """ - if self._shutdown_event.is_set(): + if self.is_shutdown: + raise WebsocketDisconnected() + + pending_get_props = self._get_pending_get_props() + if pending_get_props is None: raise WebsocketDisconnected() request_id = str(uuid.uuid4()) @@ -128,7 +158,7 @@ async def get_prop( # Use standard queue.Queue for response response_queue: queue.Queue = queue.Queue() - self._pending_get_props[request_id] = response_queue + pending_get_props[request_id] = response_queue # Queue the outbound request via janus sync interface self._queue_message(msg) @@ -146,7 +176,10 @@ async def get_prop( f"Timeout waiting for {component_id}.{prop_name}" ) from exc finally: - self._pending_get_props.pop(request_id, None) + # Get fresh reference in case of reconnection + current_pending = self._get_pending_get_props() + if current_pending is not None: + current_pending.pop(request_id, None) def create_ws_context( @@ -219,21 +252,26 @@ async def run_ws_sender( return if msg == FLUSH_SIGNAL: if messages: - await _send_batched(send_text, messages) + if not await _send_batched(send_text, messages): + return # Connection closed messages = [] continue if not batch_delay: - await send_text(msg) + try: + await send_text(msg) + except Exception: # WebSocketDisconnect, RuntimeError, etc. + return # Connection closed else: messages.append(msg) except asyncio.TimeoutError: - await _send_batched(send_text, messages) + if not await _send_batched(send_text, messages): + return # Connection closed messages = [] except asyncio.CancelledError: pass -async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None: +async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool: """Send messages as a batch. Single messages are sent as-is. Multiple messages are wrapped @@ -242,12 +280,19 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> None Args: send_text: Async function to send text data over WebSocket messages: List of pre-serialized JSON message strings + + Returns: + True if send succeeded, False if connection was closed """ - if len(messages) == 1: - await send_text(messages[0]) - else: - # Wrap in array: "[msg1,msg2,msg3]" - await send_text("[" + ",".join(messages) + "]") + try: + if len(messages) == 1: + await send_text(messages[0]) + else: + # Wrap in array: "[msg1,msg2,msg3]" + await send_text("[" + ",".join(messages) + "]") + return True + except Exception: # WebSocketDisconnect, RuntimeError, etc. + return False # Connection closed, cleanup handled by main loop def make_callback_done_handler( @@ -256,6 +301,7 @@ def make_callback_done_handler( request_id: str, renderer_id: str, shutdown_event: threading.Event, + registry: "ActiveCallbackRegistry | None" = None, ) -> Callable[[concurrent.futures.Future], None]: """Create a done callback handler for executor futures. @@ -268,52 +314,73 @@ def make_callback_done_handler( request_id: The request ID for the callback response renderer_id: The renderer ID for routing the response shutdown_event: Event signaling the websocket connection has closed. + registry: Optional registry for managing callback lifecycle. Returns: A callback function suitable for Future.add_done_callback() """ + def _is_shutdown() -> bool: + """Check if connection is shutdown (registry-aware).""" + if registry is not None: + return registry.is_shutdown(renderer_id) + return shutdown_event.is_set() + + def _get_queue() -> janus.Queue[str] | None: + """Get current outbound queue (may change on reconnect).""" + if registry is not None: + return registry.get_queue(renderer_id) + return outbound_queue + def on_done(f: concurrent.futures.Future) -> None: try: - if shutdown_event.is_set(): + if _is_shutdown(): return result = f.result() - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": result, - } - ), + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), + ) ) - ) except Exception as e: # pylint: disable=broad-exception-caught - if shutdown_event.is_set(): + if _is_shutdown(): return - outbound_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": { - "status": "error", - "message": str(e), - }, - } - ), + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), + ) ) - ) finally: pending_callbacks.pop(request_id, None) - if not shutdown_event.is_set(): - outbound_queue.sync_q.put_nowait(FLUSH_SIGNAL) + if registry is not None: + registry.unregister_callback(renderer_id) + if not _is_shutdown(): + current_queue = _get_queue() + if current_queue is not None: + current_queue.sync_q.put_nowait(FLUSH_SIGNAL) return on_done diff --git a/dash/dash-renderer/src/observers/websocketObserver.ts b/dash/dash-renderer/src/observers/websocketObserver.ts index daa3238773..24dc5a39d4 100644 --- a/dash/dash-renderer/src/observers/websocketObserver.ts +++ b/dash/dash-renderer/src/observers/websocketObserver.ts @@ -18,6 +18,8 @@ import { GetPropsRequestPayload } from '../utils/workerClient'; import {DashConfig} from '../config'; +import {addRequestedCallbacks} from '../actions/callbacks'; +import {makeResolvedCallback, resolveDeps} from '../actions/dependencies_ts'; /** * Parse a component ID that may be a stringified JSON object. @@ -175,13 +177,53 @@ export async function initializeWebSocket( workerClient.sendGetPropsResponse(requestId, result); }; + // Track connection state for reconnection handling + let wasDisconnected = false; + // Handle connection events workerClient.onConnected = () => { console.log('[Dash] WebSocket connected'); + + // On reconnect (not initial connect), re-trigger persistent callbacks + if (wasDisconnected) { + console.log( + '[Dash] Reconnected - re-triggering persistent callbacks' + ); + const state = store.getState(); + const {graphs} = state; + + if (graphs?.callbacks) { + const persistentCallbacks = graphs.callbacks.reduce( + (acc: any[], cb: any) => { + // Only re-trigger no-output callbacks with no inputs + // These are the "persistent" callbacks that should restart + if (cb.noOutput && cb.inputs.length === 0) { + const resolved = makeResolvedCallback( + cb, + resolveDeps(), + '' + ); + resolved.initialCall = true; + acc.push(resolved); + } + return acc; + }, + [] + ); + + if (persistentCallbacks.length > 0) { + console.log( + `[Dash] Re-triggering ${persistentCallbacks.length} persistent callback(s)` + ); + store.dispatch(addRequestedCallbacks(persistentCallbacks)); + } + } + } }; workerClient.onDisconnected = (reason?: string) => { console.log(`[Dash] WebSocket disconnected: ${reason}`); + wasDisconnected = true; }; workerClient.onError = (message: string, code?: string) => { @@ -201,6 +243,24 @@ export async function initializeWebSocket( } catch (error) { console.error('[Dash] Failed to connect to WebSocket worker:', error); } + + // Handle tab visibility changes + document.addEventListener('visibilitychange', () => { + if (document.visibilityState === 'visible') { + if (workerClient.connected) { + // Tab visible and connected - reset inactivity timer + workerClient.notifyTabVisible(); + } else { + // Tab visible but disconnected - reconnect + console.log('[Dash] Tab visible, reconnecting WebSocket...'); + workerClient + .ensureConnected(config) + .catch(err => + console.error('[Dash] Failed to reconnect:', err) + ); + } + } + }); } /** diff --git a/dash/dash-renderer/src/utils/workerClient.ts b/dash/dash-renderer/src/utils/workerClient.ts index 01584bf20c..ebfb1223eb 100644 --- a/dash/dash-renderer/src/utils/workerClient.ts +++ b/dash/dash-renderer/src/utils/workerClient.ts @@ -10,6 +10,7 @@ export enum WorkerMessageType { DISCONNECT = 'disconnect', CALLBACK_REQUEST = 'callback_request', GET_PROPS_RESPONSE = 'get_props_response', + TAB_VISIBLE = 'tab_visible', CONNECTED = 'connected', DISCONNECTED = 'disconnected', CALLBACK_RESPONSE = 'callback_response', @@ -251,6 +252,19 @@ class WorkerClient { return this.isConnected; } + /** + * Notify the worker that the tab is now visible. + * This resets the inactivity timer to prevent timeout while user is viewing. + */ + public notifyTabVisible(): void { + if (this.worker && this.isConnected) { + this.worker.port.postMessage({ + type: WorkerMessageType.TAB_VISIBLE, + rendererId: this.rendererId + }); + } + } + private handleMessage(event: MessageEvent): void { const message = event.data; diff --git a/tests/websocket/test_ws_reconnect.py b/tests/websocket/test_ws_reconnect.py new file mode 100644 index 0000000000..64826ab1e6 --- /dev/null +++ b/tests/websocket/test_ws_reconnect.py @@ -0,0 +1,449 @@ +""" +WebSocket reconnection and disconnect handling tests. + +Tests: +- Callback continuity after WebSocket reconnection +- Registry tracks active callbacks correctly +- Disconnect handling doesn't cause error spam +- Long-running callbacks survive reconnection +""" + +import asyncio +import time +import threading + +from dash import Dash, html, Input, Output, set_props +from dash.backends._ws_registry import ActiveCallbackRegistry + + +class TestActiveCallbackRegistry: + """Unit tests for the ActiveCallbackRegistry class.""" + + def test_registry_adopt_creates_entry(self): + """Test that adopt_connection creates a new registry entry.""" + registry = ActiveCallbackRegistry() + + # Mock queue-like object + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + pending_get_props = {} + shutdown_event = threading.Event() + + registry.adopt_connection( + "renderer1", outbound_queue, pending_get_props, shutdown_event + ) + + assert registry.get_queue("renderer1") == outbound_queue + assert registry.get_pending_get_props("renderer1") == pending_get_props + assert not registry.is_shutdown("renderer1") + + def test_registry_callback_lifecycle(self): + """Test register/unregister callback with cleanup.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + # Register callback + registry.register_callback("renderer1") + assert not registry.is_shutdown("renderer1") + + # Unregister - should clean up entry since count becomes 0 + registry.unregister_callback("renderer1") + assert registry.is_shutdown("renderer1") # Returns True when not found + + def test_registry_multiple_callbacks(self): + """Test that multiple callbacks keep entry alive.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + # Register two callbacks + registry.register_callback("renderer1") + registry.register_callback("renderer1") + + # Unregister one - entry should still exist + registry.unregister_callback("renderer1") + assert not registry.is_shutdown("renderer1") + + # Unregister second - now should be cleaned up + registry.unregister_callback("renderer1") + assert registry.is_shutdown("renderer1") + + def test_registry_adopt_after_cleanup(self): + """Test that adopt_connection works after cleanup.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + # First connection + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + registry.register_callback("renderer1") + registry.unregister_callback("renderer1") # Cleans up + + # Re-adopt after cleanup + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + assert not registry.is_shutdown("renderer1") + + def test_registry_adopt_updates_existing(self): + """Test that adopt_connection updates queues for existing entry.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self, name): + self.name = name + self.sync_q = None + + old_queue = MockQueue("old") + new_queue = MockQueue("new") + old_shutdown = threading.Event() + new_shutdown = threading.Event() + + registry.adopt_connection("renderer1", old_queue, {}, old_shutdown) + registry.register_callback("renderer1") # Keep entry alive + + assert registry.get_queue("renderer1").name == "old" + + # Simulate reconnection + registry.adopt_connection("renderer1", new_queue, {}, new_shutdown) + + assert registry.get_queue("renderer1").name == "new" + + def test_registry_shutdown_event_respected(self): + """Test that shutdown event is checked correctly.""" + registry = ActiveCallbackRegistry() + + class MockQueue: + def __init__(self): + self.sync_q = None + + outbound_queue = MockQueue() + shutdown_event = threading.Event() + + registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) + + assert not registry.is_shutdown("renderer1") + + shutdown_event.set() + + assert registry.is_shutdown("renderer1") + + def test_registry_unknown_renderer_is_shutdown(self): + """Test that unknown renderer IDs report as shutdown.""" + registry = ActiveCallbackRegistry() + + assert registry.is_shutdown("unknown_renderer") + assert registry.get_queue("unknown_renderer") is None + assert registry.get_pending_get_props("unknown_renderer") is None + + +def test_ws030_multiple_callbacks_same_connection(dash_duo): + """Test multiple sequential callbacks on the same WebSocket connection.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Multiple clicks - each should work via the same connection + for i in range(1, 6): + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", str(i)) + + assert dash_duo.get_logs() == [] + + +def test_ws031_rapid_callbacks_registry_handling(dash_duo): + """Test that rapid callbacks are handled correctly by registry.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Rapid Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # Rapid clicks without waiting + for _ in range(10): + dash_duo.find_element("#btn").click() + time.sleep(0.05) # 50ms between clicks + + # Should eventually reach 10 + dash_duo.wait_for_text_to_equal("#output", "10", timeout=10) + + assert dash_duo.get_logs() == [] + + +def test_ws032_long_callback_with_set_props(dash_duo): + """Test long-running callback with intermediate set_props updates.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="status"), + html.Div("0", id="progress"), + ] + ) + + @app.callback( + Output("status", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def long_task(n_clicks): + set_props("status", {"children": "running"}) + + # Simulate progress updates + for i in range(1, 6): + set_props("progress", {"children": str(i * 20)}) + await asyncio.sleep(0.1) + + return "done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#status", "ready") + + dash_duo.find_element("#btn").click() + + # Should see intermediate updates + dash_duo.wait_for_text_to_equal("#status", "done", timeout=10) + dash_duo.wait_for_text_to_equal("#progress", "100") + + assert dash_duo.get_logs() == [] + + +def test_ws033_callback_after_reconnect(dash_duo): + """Test that callbacks work after WebSocket reconnection.""" + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Click", id="btn", n_clicks=0), + html.Div("0", id="output"), + ] + ) + + @app.callback(Output("output", "children"), Input("btn", "n_clicks")) + def on_click(n_clicks): + return str(n_clicks or 0) + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "0") + + # First click + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "1") + + # Wait for connection to timeout + time.sleep(3) + + # Click after reconnection - should still work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "2") + + # Multiple clicks after reconnection + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "3") + + assert dash_duo.get_logs() == [] + + +def test_ws034_concurrent_callbacks(dash_duo): + """Test multiple concurrent callbacks from different inputs.""" + app = Dash(__name__, backend="fastapi", websocket_callbacks=True) + + app.layout = html.Div( + [ + html.Button("Button A", id="btn-a", n_clicks=0), + html.Button("Button B", id="btn-b", n_clicks=0), + html.Div("a:0", id="output-a"), + html.Div("b:0", id="output-b"), + ] + ) + + @app.callback(Output("output-a", "children"), Input("btn-a", "n_clicks")) + async def on_click_a(n_clicks): + await asyncio.sleep(0.1) # Small delay to ensure overlap + return f"a:{n_clicks or 0}" + + @app.callback(Output("output-b", "children"), Input("btn-b", "n_clicks")) + async def on_click_b(n_clicks): + await asyncio.sleep(0.1) + return f"b:{n_clicks or 0}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output-a", "a:0") + dash_duo.wait_for_text_to_equal("#output-b", "b:0") + + # Click both buttons rapidly + dash_duo.find_element("#btn-a").click() + dash_duo.find_element("#btn-b").click() + + dash_duo.wait_for_text_to_equal("#output-a", "a:1") + dash_duo.wait_for_text_to_equal("#output-b", "b:1") + + # More concurrent clicks + dash_duo.find_element("#btn-a").click() + dash_duo.find_element("#btn-b").click() + dash_duo.find_element("#btn-a").click() + + dash_duo.wait_for_text_to_equal("#output-a", "a:3") + dash_duo.wait_for_text_to_equal("#output-b", "b:2") + + assert dash_duo.get_logs() == [] + + +def test_ws035_callback_survives_inactivity_timeout(dash_duo): + """Test that long callback completes even when inactivity timeout triggers mid-execution. + + This is the key test for Issue #3788: when a callback runs longer than the + inactivity timeout without sending updates, the WebSocket disconnects and + reconnects. The callback should still complete and send its result via the + new connection. + """ + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def silent_long_task(n_clicks): + # Wait longer than inactivity timeout WITHOUT sending any updates + # This will trigger WebSocket disconnect/reconnect mid-callback + await asyncio.sleep(5) + return f"completed:{n_clicks}" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#output", "ready") + + # Start the long task + dash_duo.find_element("#btn").click() + + # Should complete despite inactivity timeout triggering during execution + dash_duo.wait_for_text_to_equal("#output", "completed:1", timeout=15) + + # Verify subsequent callbacks also work + dash_duo.find_element("#btn").click() + dash_duo.wait_for_text_to_equal("#output", "completed:2", timeout=15) + + assert dash_duo.get_logs() == [] + + +def test_ws036_set_props_after_reconnect(dash_duo): + """Test that set_props works after WebSocket reconnects mid-callback. + + This tests the registry's ability to adopt new queues so that + set_props calls use the new connection after reconnection. + """ + app = Dash( + __name__, + backend="fastapi", + websocket_callbacks=True, + websocket_inactivity_timeout=2000, # 2 seconds + ) + + app.layout = html.Div( + [ + html.Button("Start", id="btn", n_clicks=0), + html.Div("ready", id="status"), + html.Div("0", id="progress"), + ] + ) + + @app.callback( + Output("status", "children"), + Input("btn", "n_clicks"), + prevent_initial_call=True, + ) + async def task_with_late_set_props(n_clicks): + set_props("status", {"children": "started"}) + set_props("progress", {"children": "10"}) + + # Wait long enough for inactivity timeout to trigger + await asyncio.sleep(5) + + # These set_props calls happen AFTER reconnection + # They should still work via the adopted queue + set_props("progress", {"children": "100"}) + + return "done" + + dash_duo.start_server(app) + + dash_duo.wait_for_text_to_equal("#status", "ready") + + dash_duo.find_element("#btn").click() + + # Should see initial updates + dash_duo.wait_for_text_to_equal("#status", "started", timeout=5) + dash_duo.wait_for_text_to_equal("#progress", "10", timeout=5) + + # Should see final update after reconnection + dash_duo.wait_for_text_to_equal("#progress", "100", timeout=15) + dash_duo.wait_for_text_to_equal("#status", "done", timeout=5) + + assert dash_duo.get_logs() == [] From 21278e0ffd2a05d6c9586427d33fc40bf832a261 Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 17:05:55 -0400 Subject: [PATCH 2/5] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 62d49513c5..a262be18b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ This project adheres to [Semantic Versioning](https://semver.org/). - [#3669](https://github.com/plotly/dash/pull/3669) Selection for DataTable cleared with custom action settings - [#3680](https://github.com/plotly/dash/pull/3680) Added `search_order` prop to `Dropdown` to allow users to preserve original option order during search - Added `csrf_token_name` and `csrf_header_name` config options to allow configuring the CSRF cookie and header names. Fixes [#729](https://github.com/plotly/dash/issues/729) +- [#3797](https://github.com/plotly/dash/pull/3797) Improved websocket callback management. ## Added - [#3523](https://github.com/plotly/dash/pull/3523) Fall back to background callback function names if source cannot be found From 6f6442f0b185f32d876900b4d55f8ce89e06c62a Mon Sep 17 00:00:00 2001 From: philippe Date: Tue, 26 May 2026 18:52:22 -0400 Subject: [PATCH 3/5] lint fix --- dash/backends/ws.py | 81 ++++++++++++++++++++++++++++++--------------- 1 file changed, 54 insertions(+), 27 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index adebdf2c4a..807a167ab7 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -242,35 +242,62 @@ async def run_ws_sender( messages: list[str] = [] try: while True: - # Wait indefinitely for first message, then use timeout for batching - timeout = batch_delay if messages else None - try: - msg = await asyncio.wait_for(q.get(), timeout=timeout) - if msg == SHUTDOWN_SIGNAL: - if messages: - await _send_batched(send_text, messages) - return - if msg == FLUSH_SIGNAL: - if messages: - if not await _send_batched(send_text, messages): - return # Connection closed - messages = [] - continue - if not batch_delay: - try: - await send_text(msg) - except Exception: # WebSocketDisconnect, RuntimeError, etc. - return # Connection closed - else: - messages.append(msg) - except asyncio.TimeoutError: - if not await _send_batched(send_text, messages): - return # Connection closed - messages = [] + result = await _process_ws_message(q, send_text, messages, batch_delay) + if result is False: + return except asyncio.CancelledError: pass +async def _process_ws_message( + q: "janus._AsyncQueueProxy[str]", + send_text: Callable[[str], Any], + messages: list[str], + batch_delay: float, +) -> bool | None: + """Process a single WebSocket message from the queue. + + Args: + q: The async queue to read from + send_text: Async function to send text data over WebSocket + messages: List to accumulate messages for batching (mutated in place) + batch_delay: Batch delay in seconds + + Returns: + True to continue processing, False to stop the sender loop, + None to continue (same as True but used for continue semantics). + """ + timeout = batch_delay if messages else None + try: + msg = await asyncio.wait_for(q.get(), timeout=timeout) + except asyncio.TimeoutError: + if not await _send_batched(send_text, messages): + return False + messages.clear() + return True + + if msg == SHUTDOWN_SIGNAL: + if messages: + await _send_batched(send_text, messages) + return False + + if msg == FLUSH_SIGNAL: + if messages and not await _send_batched(send_text, messages): + return False + messages.clear() + return None + + if not batch_delay: + try: + await send_text(msg) + except Exception: # pylint: disable=broad-exception-caught + return False # WebSocketDisconnect, RuntimeError, etc. + else: + messages.append(msg) + + return True + + async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool: """Send messages as a batch. @@ -291,8 +318,8 @@ async def _send_batched(send_text: Callable[[str], Any], messages: list) -> bool # Wrap in array: "[msg1,msg2,msg3]" await send_text("[" + ",".join(messages) + "]") return True - except Exception: # WebSocketDisconnect, RuntimeError, etc. - return False # Connection closed, cleanup handled by main loop + except Exception: # pylint: disable=broad-exception-caught + return False # WebSocketDisconnect, RuntimeError, etc. def make_callback_done_handler( From 2f42201532cb2d5892197f18f0703d6f42c4c2a6 Mon Sep 17 00:00:00 2001 From: philippe Date: Wed, 27 May 2026 16:39:25 -0400 Subject: [PATCH 4/5] fix lint --- dash/backends/ws.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 807a167ab7..8ee2cf11c1 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -254,7 +254,7 @@ async def _process_ws_message( send_text: Callable[[str], Any], messages: list[str], batch_delay: float, -) -> bool | None: +) -> bool: """Process a single WebSocket message from the queue. Args: @@ -264,17 +264,15 @@ async def _process_ws_message( batch_delay: Batch delay in seconds Returns: - True to continue processing, False to stop the sender loop, - None to continue (same as True but used for continue semantics). + True to continue processing, False to stop the sender loop. """ timeout = batch_delay if messages else None try: msg = await asyncio.wait_for(q.get(), timeout=timeout) except asyncio.TimeoutError: - if not await _send_batched(send_text, messages): - return False + success = await _send_batched(send_text, messages) messages.clear() - return True + return success if msg == SHUTDOWN_SIGNAL: if messages: @@ -282,10 +280,9 @@ async def _process_ws_message( return False if msg == FLUSH_SIGNAL: - if messages and not await _send_batched(send_text, messages): - return False + success = not messages or await _send_batched(send_text, messages) messages.clear() - return None + return success if not batch_delay: try: From 3892cd4484515c837869a3b27e9b14d9635cefb5 Mon Sep 17 00:00:00 2001 From: philippe Date: Thu, 28 May 2026 19:28:34 -0400 Subject: [PATCH 5/5] remove ActiveCallbackRegistry & callback adoption --- .ai/ARCHITECTURE.md | 14 ++ dash/backends/_fastapi.py | 32 +--- dash/backends/_quart.py | 32 +--- dash/backends/_ws_registry.py | 166 ------------------ dash/backends/ws.py | 121 ++++++-------- tests/websocket/test_ws_reconnect.py | 242 +++------------------------ 6 files changed, 90 insertions(+), 517 deletions(-) delete mode 100644 dash/backends/_ws_registry.py diff --git a/.ai/ARCHITECTURE.md b/.ai/ARCHITECTURE.md index 84f553978c..da9913a13f 100644 --- a/.ai/ARCHITECTURE.md +++ b/.ai/ARCHITECTURE.md @@ -968,6 +968,7 @@ WebSocket callbacks can stream updates to the client during execution using `set ```python import asyncio from dash import callback, Output, Input, set_props, ctx +from dash.exceptions import PreventUpdate @callback( Output('result', 'children'), @@ -981,6 +982,9 @@ async def long_running_task(n_clicks): # Stream progress updates to the client for i in range(100): + # IMPORTANT: Check is_shutdown in loops to detect disconnections + if ws.is_shutdown: + raise PreventUpdate # Exit gracefully on disconnect await asyncio.sleep(0.1) set_props('progress-bar', {'value': i + 1}) set_props('status', {'children': f'Processing step {i + 1}/100...'}) @@ -991,9 +995,19 @@ async def long_running_task(n_clicks): return f"Completed! Input was: {current_value}" ``` +**IMPORTANT - Checking `is_shutdown` in Loops:** + +Long-running callbacks that use loops **must** check `ws.is_shutdown` to detect when the WebSocket connection has closed. Without this check: +- Callbacks continue running after the client disconnects, wasting server resources +- `set_props` calls go to a closed connection and are lost +- The callback result is never delivered to the client + +Only "persistent callbacks" (callbacks with no Output and no Input that use only `set_props`) are automatically restarted when the WebSocket reconnects. Regular callbacks with outputs are not restarted. + **API:** - `set_props(component_id, props_dict)` - Stream prop updates immediately to client - `ctx.websocket` - Get WebSocket interface (returns `None` if not in WS context) +- `ws.is_shutdown` - Check if the WebSocket connection has been closed - `await ws.get_prop(component_id, prop_name)` - Read current prop value from client - `await ws.set_prop(component_id, prop_name, value)` - Set single prop (async version) - `await ws.close(code, reason)` - Close the WebSocket connection diff --git a/dash/backends/_fastapi.py b/dash/backends/_fastapi.py index b619c040c9..027b94bed4 100644 --- a/dash/backends/_fastapi.py +++ b/dash/backends/_fastapi.py @@ -50,7 +50,6 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) -from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: # pragma: no cover - typing only @@ -678,12 +677,6 @@ def serve_websocket_callback(self, dash_app: "Dash"): dash_app, "_websocket_allowed_origins", [] ) # pylint: disable=protected-access - # Initialize registry on dash_app if not present - # pylint: disable=protected-access - if not hasattr(dash_app, "_ws_callback_registry"): - dash_app._ws_callback_registry = ActiveCallbackRegistry() - registry: ActiveCallbackRegistry = dash_app._ws_callback_registry - def validate_origin(origin: str | None, host: str | None) -> str | None: """Validate WebSocket origin. Returns error message or None if valid.""" if not origin: @@ -730,8 +723,6 @@ async def websocket_handler(websocket: WebSocket): executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} - # Track current renderer ID for this connection - current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -762,22 +753,6 @@ async def websocket_handler(websocket: WebSocket): renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) - # Update current renderer ID for cleanup - current_renderer_id = renderer_id - - # Adopt connection for this renderer (allows reconnection) - # Called for every callback to ensure registry entry exists - # (entry may have been cleaned up after previous callback) - registry.adopt_connection( - renderer_id, - outbound_queue, - pending_get_props, - shutdown_event, - ) - - # Register this callback with the registry - registry.register_callback(renderer_id) - # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -786,13 +761,12 @@ async def websocket_handler(websocket: WebSocket): dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with registry + # Create WebSocket callback instance ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, shutdown_event, - registry=registry, ) # Submit callback to executor @@ -812,7 +786,6 @@ async def websocket_handler(websocket: WebSocket): request_id, renderer_id, shutdown_event, - registry=registry, ) ) pending_callbacks[request_id] = future @@ -848,9 +821,6 @@ async def websocket_handler(websocket: WebSocket): # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() - # Cleanup registry entry if no active callbacks - if current_renderer_id is not None: - registry.cleanup_renderer(current_renderer_id) self.server.add_api_websocket_route(ws_path, websocket_handler) diff --git a/dash/backends/_quart.py b/dash/backends/_quart.py index 31d9668d36..daac2dd3f6 100644 --- a/dash/backends/_quart.py +++ b/dash/backends/_quart.py @@ -55,7 +55,6 @@ SHUTDOWN_SIGNAL, DISCONNECTED, ) -from ._ws_registry import ActiveCallbackRegistry from ._utils import format_traceback_html if TYPE_CHECKING: @@ -522,12 +521,6 @@ def serve_websocket_callback(self, dash_app: "Dash"): # pylint: disable=protected-access allowed_origins = getattr(dash_app, "_websocket_allowed_origins", []) - # Initialize registry on dash_app if not present - # pylint: disable=protected-access - if not hasattr(dash_app, "_ws_callback_registry"): - dash_app._ws_callback_registry = ActiveCallbackRegistry() - registry: ActiveCallbackRegistry = dash_app._ws_callback_registry - @self.server.websocket(ws_path) async def websocket_handler(): # pylint: disable=too-many-branches ws = websocket @@ -571,8 +564,6 @@ async def websocket_handler(): # pylint: disable=too-many-branches executor = self.get_callback_executor() # Track pending callback futures pending_callbacks: Dict[str, concurrent.futures.Future] = {} - # Track current renderer ID for this connection - current_renderer_id: str | None = None # Start sender task to drain outbound queue (sends pre-serialized text) # pylint: disable=protected-access @@ -610,22 +601,6 @@ async def websocket_handler(): # pylint: disable=too-many-branches renderer_id = message.get("rendererId", "") payload = message.get("payload", {}) - # Update current renderer ID for cleanup - current_renderer_id = renderer_id - - # Adopt connection for this renderer (allows reconnection) - # Called for every callback to ensure registry entry exists - # (entry may have been cleaned up after previous callback) - registry.adopt_connection( - renderer_id, - outbound_queue, - pending_get_props, - connection_shutdown_event, - ) - - # Register this callback with the registry - registry.register_callback(renderer_id) - # Validate that the callback is allowed to use WebSocket transport # pylint: disable=protected-access _validate.validate_websocket_callback_request( @@ -634,13 +609,12 @@ async def websocket_handler(): # pylint: disable=too-many-branches dash_app._websocket_callbacks, ) - # Create WebSocket callback instance with registry + # Create WebSocket callback instance ws_cb = DashWebsocketCallback( pending_get_props, renderer_id, outbound_queue, connection_shutdown_event, - registry=registry, ) # Submit callback to executor @@ -660,7 +634,6 @@ async def websocket_handler(): # pylint: disable=too-many-branches request_id, renderer_id, connection_shutdown_event, - registry=registry, ) ) pending_callbacks[request_id] = future @@ -699,9 +672,6 @@ async def websocket_handler(): # pylint: disable=too-many-branches # Cancel any pending futures for f in pending_callbacks.values(): f.cancel() - # Cleanup registry entry if no active callbacks - if current_renderer_id is not None: - registry.cleanup_renderer(current_renderer_id) class QuartRequestAdapter(RequestAdapter): diff --git a/dash/backends/_ws_registry.py b/dash/backends/_ws_registry.py deleted file mode 100644 index 4ab1132d6a..0000000000 --- a/dash/backends/_ws_registry.py +++ /dev/null @@ -1,166 +0,0 @@ -"""WebSocket callback registry for handling reconnections. - -This module provides a registry that tracks active callbacks per renderer, -allowing callbacks to persist across WebSocket reconnections. -""" - -from __future__ import annotations - -import threading -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, Optional - -if TYPE_CHECKING: - import queue - import janus - - -@dataclass -class RendererState: - """State for a single renderer's WebSocket connection.""" - - outbound_queue: "janus.Queue[str]" - pending_get_props: Dict[str, "queue.Queue[Any]"] - shutdown_event: threading.Event - active_callback_count: int = 0 - lock: threading.Lock = field(default_factory=threading.Lock) - - -class ActiveCallbackRegistry: - """Registry for active WebSocket callbacks that persists across reconnections. - - When a WebSocket disconnects and reconnects, callbacks that are still running - can "adopt" the new connection's queues to continue sending updates. - - Thread-safe for access from both the main event loop and worker threads. - """ - - def __init__(self) -> None: - self._renderers: Dict[str, RendererState] = {} - self._lock = threading.Lock() - - def adopt_connection( - self, - renderer_id: str, - outbound_queue: "janus.Queue[str]", - pending_get_props: Dict[str, "queue.Queue[Any]"], - shutdown_event: threading.Event, - ) -> None: - """Associate new connection with existing callbacks for this renderer. - - When a WebSocket reconnects, this method updates the queues and shutdown - event so that running callbacks can use the new connection. - - Args: - renderer_id: The renderer ID for this connection - outbound_queue: janus.Queue for sending messages - pending_get_props: Dict to track pending get_props requests - shutdown_event: Event signaling connection closure - """ - with self._lock: - if renderer_id in self._renderers: - state = self._renderers[renderer_id] - with state.lock: - state.outbound_queue = outbound_queue - state.pending_get_props = pending_get_props - state.shutdown_event = shutdown_event - else: - self._renderers[renderer_id] = RendererState( - outbound_queue=outbound_queue, - pending_get_props=pending_get_props, - shutdown_event=shutdown_event, - active_callback_count=0, - ) - - def register_callback(self, renderer_id: str) -> None: - """Register a new active callback for this renderer. - - Args: - renderer_id: The renderer ID - """ - with self._lock: - if renderer_id in self._renderers: - state = self._renderers[renderer_id] - with state.lock: - state.active_callback_count += 1 - - def unregister_callback(self, renderer_id: str) -> None: - """Unregister a completed callback for this renderer. - - If no active callbacks remain, the renderer state is cleaned up. - - Args: - renderer_id: The renderer ID - """ - with self._lock: - if renderer_id in self._renderers: - state = self._renderers[renderer_id] - with state.lock: - state.active_callback_count -= 1 - if state.active_callback_count <= 0: - del self._renderers[renderer_id] - - def get_queue(self, renderer_id: str) -> Optional["janus.Queue[str]"]: - """Get current outbound queue for renderer (thread-safe). - - Args: - renderer_id: The renderer ID - - Returns: - The current outbound queue, or None if renderer not found - """ - with self._lock: - state = self._renderers.get(renderer_id) - if state is None: - return None - with state.lock: - return state.outbound_queue - - def get_pending_get_props( - self, renderer_id: str - ) -> Optional[Dict[str, "queue.Queue[Any]"]]: - """Get current pending_get_props dict for renderer (thread-safe). - - Args: - renderer_id: The renderer ID - - Returns: - The current pending_get_props dict, or None if renderer not found - """ - with self._lock: - state = self._renderers.get(renderer_id) - if state is None: - return None - with state.lock: - return state.pending_get_props - - def is_shutdown(self, renderer_id: str) -> bool: - """Check if current connection is shutdown. - - Args: - renderer_id: The renderer ID - - Returns: - True if shutdown event is set or renderer not found, False otherwise - """ - with self._lock: - state = self._renderers.get(renderer_id) - if state is None: - return True - with state.lock: - return state.shutdown_event.is_set() - - def cleanup_renderer(self, renderer_id: str) -> None: - """Clean up renderer state when connection closes. - - Only removes if no active callbacks remain. - - Args: - renderer_id: The renderer ID to clean up - """ - with self._lock: - state = self._renderers.get(renderer_id) - if state is not None: - with state.lock: - if state.active_callback_count <= 0: - del self._renderers[renderer_id] diff --git a/dash/backends/ws.py b/dash/backends/ws.py index 8ee2cf11c1..a4b302f215 100644 --- a/dash/backends/ws.py +++ b/dash/backends/ws.py @@ -25,7 +25,6 @@ if TYPE_CHECKING: import dash from .base_server import ResponseAdapter - from ._ws_registry import ActiveCallbackRegistry SHUTDOWN_SIGNAL = "__shutdown__" @@ -43,9 +42,24 @@ class DashWebsocketCallback: queue.Queue for get_props responses, enabling thread-safe communication between worker threads and the main event loop. - Supports two modes: - 1. Registry mode: Uses ActiveCallbackRegistry to allow queue adoption on reconnect - 2. Direct mode: Uses direct queue references (legacy, for backwards compatibility) + IMPORTANT: For long-running callbacks that use loops (e.g., streaming updates), + you MUST check `ws.is_shutdown` in your loop to detect disconnections: + + @callback(Input('btn', 'n_clicks')) # No Output - uses set_props only + async def long_running(n_clicks): + ws = ctx.websocket + while True: + if ws and ws.is_shutdown: + raise PreventUpdate # Exit gracefully on disconnect + set_props('progress', {'value': get_data()}) + await asyncio.sleep(0.1) + + Without this check, callbacks will continue running after the client disconnects, + wasting server resources. + + Note: Only "persistent callbacks" (callbacks with no Output and no Input that use + only set_props) are automatically restarted when the WebSocket reconnects. Regular + callbacks with outputs are not restarted. """ def __init__( @@ -54,7 +68,6 @@ def __init__( renderer_id: str, outbound_queue: janus.Queue[str], shutdown_event: "threading.Event", - registry: "ActiveCallbackRegistry | None" = None, ): """Initialize the WebSocket callback interface. @@ -64,33 +77,23 @@ def __init__( renderer_id: The renderer ID for routing messages back to the correct client outbound_queue: janus.Queue for thread-safe outbound messaging. shutdown_event: Event signaling the websocket connection has closed. - registry: Optional registry for handling reconnections. If provided, - the callback will use the registry to get current queues, allowing - it to survive reconnections. """ self._pending_get_props = pending_get_props self._renderer_id = renderer_id self._outbound_queue = outbound_queue self._shutdown_event = shutdown_event - self._registry = registry @property def is_shutdown(self) -> bool: """Check if the websocket connection has been shut down.""" - if self._registry is not None: - return self._registry.is_shutdown(self._renderer_id) return self._shutdown_event.is_set() def _get_outbound_queue(self) -> janus.Queue[str] | None: - """Get the current outbound queue (may be updated on reconnect).""" - if self._registry is not None: - return self._registry.get_queue(self._renderer_id) + """Get the outbound queue.""" return self._outbound_queue def _get_pending_get_props(self) -> Dict[str, queue.Queue[Any]] | None: - """Get the current pending_get_props dict (may be updated on reconnect).""" - if self._registry is not None: - return self._registry.get_pending_get_props(self._renderer_id) + """Get the pending_get_props dict.""" return self._pending_get_props def _queue_message(self, msg: dict) -> None: @@ -325,7 +328,6 @@ def make_callback_done_handler( request_id: str, renderer_id: str, shutdown_event: threading.Event, - registry: "ActiveCallbackRegistry | None" = None, ) -> Callable[[concurrent.futures.Future], None]: """Create a done callback handler for executor futures. @@ -338,73 +340,52 @@ def make_callback_done_handler( request_id: The request ID for the callback response renderer_id: The renderer ID for routing the response shutdown_event: Event signaling the websocket connection has closed. - registry: Optional registry for managing callback lifecycle. Returns: A callback function suitable for Future.add_done_callback() """ - def _is_shutdown() -> bool: - """Check if connection is shutdown (registry-aware).""" - if registry is not None: - return registry.is_shutdown(renderer_id) - return shutdown_event.is_set() - - def _get_queue() -> janus.Queue[str] | None: - """Get current outbound queue (may change on reconnect).""" - if registry is not None: - return registry.get_queue(renderer_id) - return outbound_queue - def on_done(f: concurrent.futures.Future) -> None: try: - if _is_shutdown(): + if shutdown_event.is_set(): return result = f.result() - current_queue = _get_queue() - if current_queue is not None: - current_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": result, - } - ), - ) + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": result, + } + ), ) + ) except Exception as e: # pylint: disable=broad-exception-caught - if _is_shutdown(): + if shutdown_event.is_set(): return - current_queue = _get_queue() - if current_queue is not None: - current_queue.sync_q.put_nowait( - cast( - str, - to_json( - { - "type": "callback_response", - "rendererId": renderer_id, - "requestId": request_id, - "payload": { - "status": "error", - "message": str(e), - }, - } - ), - ) + outbound_queue.sync_q.put_nowait( + cast( + str, + to_json( + { + "type": "callback_response", + "rendererId": renderer_id, + "requestId": request_id, + "payload": { + "status": "error", + "message": str(e), + }, + } + ), ) + ) finally: pending_callbacks.pop(request_id, None) - if registry is not None: - registry.unregister_callback(renderer_id) - if not _is_shutdown(): - current_queue = _get_queue() - if current_queue is not None: - current_queue.sync_q.put_nowait(FLUSH_SIGNAL) + if not shutdown_event.is_set(): + outbound_queue.sync_q.put_nowait(FLUSH_SIGNAL) return on_done diff --git a/tests/websocket/test_ws_reconnect.py b/tests/websocket/test_ws_reconnect.py index 64826ab1e6..51eeb38ae3 100644 --- a/tests/websocket/test_ws_reconnect.py +++ b/tests/websocket/test_ws_reconnect.py @@ -3,159 +3,15 @@ Tests: - Callback continuity after WebSocket reconnection -- Registry tracks active callbacks correctly - Disconnect handling doesn't cause error spam -- Long-running callbacks survive reconnection +- Long-running callbacks with is_shutdown check """ import asyncio import time -import threading -from dash import Dash, html, Input, Output, set_props -from dash.backends._ws_registry import ActiveCallbackRegistry - - -class TestActiveCallbackRegistry: - """Unit tests for the ActiveCallbackRegistry class.""" - - def test_registry_adopt_creates_entry(self): - """Test that adopt_connection creates a new registry entry.""" - registry = ActiveCallbackRegistry() - - # Mock queue-like object - class MockQueue: - def __init__(self): - self.sync_q = None - - outbound_queue = MockQueue() - pending_get_props = {} - shutdown_event = threading.Event() - - registry.adopt_connection( - "renderer1", outbound_queue, pending_get_props, shutdown_event - ) - - assert registry.get_queue("renderer1") == outbound_queue - assert registry.get_pending_get_props("renderer1") == pending_get_props - assert not registry.is_shutdown("renderer1") - - def test_registry_callback_lifecycle(self): - """Test register/unregister callback with cleanup.""" - registry = ActiveCallbackRegistry() - - class MockQueue: - def __init__(self): - self.sync_q = None - - outbound_queue = MockQueue() - shutdown_event = threading.Event() - - registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) - - # Register callback - registry.register_callback("renderer1") - assert not registry.is_shutdown("renderer1") - - # Unregister - should clean up entry since count becomes 0 - registry.unregister_callback("renderer1") - assert registry.is_shutdown("renderer1") # Returns True when not found - - def test_registry_multiple_callbacks(self): - """Test that multiple callbacks keep entry alive.""" - registry = ActiveCallbackRegistry() - - class MockQueue: - def __init__(self): - self.sync_q = None - - outbound_queue = MockQueue() - shutdown_event = threading.Event() - - registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) - - # Register two callbacks - registry.register_callback("renderer1") - registry.register_callback("renderer1") - - # Unregister one - entry should still exist - registry.unregister_callback("renderer1") - assert not registry.is_shutdown("renderer1") - - # Unregister second - now should be cleaned up - registry.unregister_callback("renderer1") - assert registry.is_shutdown("renderer1") - - def test_registry_adopt_after_cleanup(self): - """Test that adopt_connection works after cleanup.""" - registry = ActiveCallbackRegistry() - - class MockQueue: - def __init__(self): - self.sync_q = None - - outbound_queue = MockQueue() - shutdown_event = threading.Event() - - # First connection - registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) - registry.register_callback("renderer1") - registry.unregister_callback("renderer1") # Cleans up - - # Re-adopt after cleanup - registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) - assert not registry.is_shutdown("renderer1") - - def test_registry_adopt_updates_existing(self): - """Test that adopt_connection updates queues for existing entry.""" - registry = ActiveCallbackRegistry() - - class MockQueue: - def __init__(self, name): - self.name = name - self.sync_q = None - - old_queue = MockQueue("old") - new_queue = MockQueue("new") - old_shutdown = threading.Event() - new_shutdown = threading.Event() - - registry.adopt_connection("renderer1", old_queue, {}, old_shutdown) - registry.register_callback("renderer1") # Keep entry alive - - assert registry.get_queue("renderer1").name == "old" - - # Simulate reconnection - registry.adopt_connection("renderer1", new_queue, {}, new_shutdown) - - assert registry.get_queue("renderer1").name == "new" - - def test_registry_shutdown_event_respected(self): - """Test that shutdown event is checked correctly.""" - registry = ActiveCallbackRegistry() - - class MockQueue: - def __init__(self): - self.sync_q = None - - outbound_queue = MockQueue() - shutdown_event = threading.Event() - - registry.adopt_connection("renderer1", outbound_queue, {}, shutdown_event) - - assert not registry.is_shutdown("renderer1") - - shutdown_event.set() - - assert registry.is_shutdown("renderer1") - - def test_registry_unknown_renderer_is_shutdown(self): - """Test that unknown renderer IDs report as shutdown.""" - registry = ActiveCallbackRegistry() - - assert registry.is_shutdown("unknown_renderer") - assert registry.get_queue("unknown_renderer") is None - assert registry.get_pending_get_props("unknown_renderer") is None +from dash import Dash, html, Input, Output, set_props, ctx +from dash.exceptions import PreventUpdate def test_ws030_multiple_callbacks_same_connection(dash_duo): @@ -185,8 +41,8 @@ def on_click(n_clicks): assert dash_duo.get_logs() == [] -def test_ws031_rapid_callbacks_registry_handling(dash_duo): - """Test that rapid callbacks are handled correctly by registry.""" +def test_ws031_rapid_callbacks(dash_duo): + """Test that rapid callbacks are handled correctly.""" app = Dash(__name__, backend="fastapi", websocket_callbacks=True) app.layout = html.Div( @@ -343,61 +199,12 @@ async def on_click_b(n_clicks): assert dash_duo.get_logs() == [] -def test_ws035_callback_survives_inactivity_timeout(dash_duo): - """Test that long callback completes even when inactivity timeout triggers mid-execution. - - This is the key test for Issue #3788: when a callback runs longer than the - inactivity timeout without sending updates, the WebSocket disconnects and - reconnects. The callback should still complete and send its result via the - new connection. - """ - app = Dash( - __name__, - backend="fastapi", - websocket_callbacks=True, - websocket_inactivity_timeout=2000, # 2 seconds - ) - - app.layout = html.Div( - [ - html.Button("Start", id="btn", n_clicks=0), - html.Div("ready", id="output"), - ] - ) - - @app.callback( - Output("output", "children"), - Input("btn", "n_clicks"), - prevent_initial_call=True, - ) - async def silent_long_task(n_clicks): - # Wait longer than inactivity timeout WITHOUT sending any updates - # This will trigger WebSocket disconnect/reconnect mid-callback - await asyncio.sleep(5) - return f"completed:{n_clicks}" - - dash_duo.start_server(app) - - dash_duo.wait_for_text_to_equal("#output", "ready") - - # Start the long task - dash_duo.find_element("#btn").click() - - # Should complete despite inactivity timeout triggering during execution - dash_duo.wait_for_text_to_equal("#output", "completed:1", timeout=15) - - # Verify subsequent callbacks also work - dash_duo.find_element("#btn").click() - dash_duo.wait_for_text_to_equal("#output", "completed:2", timeout=15) - - assert dash_duo.get_logs() == [] - - -def test_ws036_set_props_after_reconnect(dash_duo): - """Test that set_props works after WebSocket reconnects mid-callback. +def test_ws035_long_callback_with_shutdown_check(dash_duo): + """Test long-running callback that properly checks is_shutdown. - This tests the registry's ability to adopt new queues so that - set_props calls use the new connection after reconnection. + Long-running callbacks should check ws.is_shutdown in their loops to + detect disconnections and exit gracefully. This prevents wasted server + resources when the client disconnects. """ app = Dash( __name__, @@ -419,16 +226,17 @@ def test_ws036_set_props_after_reconnect(dash_duo): Input("btn", "n_clicks"), prevent_initial_call=True, ) - async def task_with_late_set_props(n_clicks): - set_props("status", {"children": "started"}) - set_props("progress", {"children": "10"}) - - # Wait long enough for inactivity timeout to trigger - await asyncio.sleep(5) + async def long_task_with_shutdown_check(n_clicks): + ws = ctx.websocket + set_props("status", {"children": "running"}) - # These set_props calls happen AFTER reconnection - # They should still work via the adopted queue - set_props("progress", {"children": "100"}) + # Properly check is_shutdown in the loop + for i in range(1, 11): + if ws and ws.is_shutdown: + # Exit gracefully on disconnect + raise PreventUpdate + set_props("progress", {"children": str(i * 10)}) + await asyncio.sleep(0.2) return "done" @@ -438,12 +246,8 @@ async def task_with_late_set_props(n_clicks): dash_duo.find_element("#btn").click() - # Should see initial updates - dash_duo.wait_for_text_to_equal("#status", "started", timeout=5) - dash_duo.wait_for_text_to_equal("#progress", "10", timeout=5) - - # Should see final update after reconnection - dash_duo.wait_for_text_to_equal("#progress", "100", timeout=15) - dash_duo.wait_for_text_to_equal("#status", "done", timeout=5) + # Should see progress updates and complete + dash_duo.wait_for_text_to_equal("#status", "done", timeout=10) + dash_duo.wait_for_text_to_equal("#progress", "100") assert dash_duo.get_logs() == []