From 354052021de053ab54ffd77fabbfb7f3da4874cd Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 31 May 2026 23:26:28 -0400 Subject: [PATCH 01/14] feat(app): parallel multi-GPU session execution Run one generation session per configured GPU concurrently, with a tiled progress preview. Multi-user isolation is unchanged. Backed by five seams: - Per-thread device context (TorchDevice.set/get/clear_session_device); choose_torch_device() consults it first, so all device-selecting call sites resolve to the calling worker's GPU with no per-node changes. - Per-device model caches: build_model_manager builds one ModelCache per generation device; ModelLoadService.ram_cache resolves by current thread device; ram_caches fans out clear/drop/shutdown. - Atomic concurrent dequeue: a dequeue lock makes select+claim atomic so concurrent workers never claim the same item (works on FIFO; round-robin from #9086 slots in later). - Worker pool: one _SessionWorker per device, each pinning torch.cuda.set_device and its session device, with its own runner and cancel event; cancellation routes via an {item_id -> worker} lookup. Single-device installs keep the exact legacy single-worker behavior. Profiling disabled when >1 worker. - New config `generation_devices`; unset = legacy single-worker mode. Frontend: the canvas staging area already tiles per queue item; the main ImageViewer now tracks progress per session and renders a tile grid (ProgressImageTiles) when more than one session is active. Also adds a lock to ObjectSerializerForwardCache for concurrent access. Co-Authored-By: Claude Opus 4.8 (1M context) --- docs/src/generated/settings.json | 11 + invokeai/app/api/routers/model_manager.py | 11 +- .../app/services/config/config_default.py | 14 ++ .../services/model_load/model_load_base.py | 16 +- .../services/model_load/model_load_default.py | 42 +++- .../model_manager/model_manager_default.py | 48 ++-- .../object_serializer_forward_cache.py | 23 +- .../session_processor_default.py | 208 +++++++++++++----- .../session_queue/session_queue_sqlite.py | 54 +++-- .../load/model_cache/model_cache.py | 5 + invokeai/backend/util/devices.py | 27 +++ .../ImageViewer/CurrentImagePreview.tsx | 19 +- .../ImageViewer/ProgressImageTiles.tsx | 39 ++++ .../components/ImageViewer/context.tsx | 42 +++- .../test_model_load_device_routing.py | 81 +++++++ .../test_session_queue_dequeue_concurrency.py | 70 ++++++ tests/backend/util/test_devices.py | 45 ++++ 17 files changed, 634 insertions(+), 121 deletions(-) create mode 100644 invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx create mode 100644 tests/app/services/model_load/test_model_load_device_routing.py create mode 100644 tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index 88a42f8fbcf..eb26d39960f 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -490,6 +490,17 @@ "type": "", "validation": {} }, + { + "category": "DEVICE", + "default": null, + "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", + "env_var": "INVOKEAI_GENERATION_DEVICES", + "literal_values": [], + "name": "generation_devices", + "required": false, + "type": "typing.Optional[list[str]]", + "validation": {} + }, { "category": "DEVICE", "default": "auto", diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index bdd2e406444..53c4c68981f 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -443,7 +443,11 @@ async def update_model_record( # nn.Module at load time, so toggling them on a cached model is otherwise silently a no-op until # the entry is evicted. Drop any unlocked cached entries for this model so the next load rebuilds. if _load_settings_changed(previous_config, config): - dropped = ApiDependencies.invoker.services.model_manager.load.ram_cache.drop_model(key) + # Drop the model from every per-device cache so the next load on any GPU rebuilds it. + dropped = sum( + cache.drop_model(key) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values() + ) if dropped: logger.info( f"Dropped {dropped} cached entr{'y' if dropped == 1 else 'ies'} for model {key} after settings change." @@ -1304,9 +1308,10 @@ async def get_stats() -> Optional[CacheStats]: ) async def empty_model_cache(current_admin: AdminUserOrDefault) -> None: """Drop all models from the model cache to free RAM/VRAM. 'Locked' models that are in active use will not be dropped.""" - # Request 1000GB of room in order to force the cache to drop all models. + # Request 1000GB of room in order to force each per-device cache to drop all models. ApiDependencies.invoker.services.logger.info("Emptying model cache.") - ApiDependencies.invoker.services.model_manager.load.ram_cache.make_room(1000 * 2**30) + for cache in ApiDependencies.invoker.services.model_manager.load.ram_caches.values(): + cache.make_room(1000 * 2**30) class HFTokenStatus(str, Enum): diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 57004efca39..a70f5f7e97c 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -205,6 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") + generation_devices: Optional[list[str]] = Field(default=None, description="List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -257,6 +258,19 @@ class InvokeAIAppConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="INVOKEAI_", env_ignore_empty=True) + @field_validator("generation_devices") + @classmethod + def validate_generation_devices(cls, v: Optional[list[str]]) -> Optional[list[str]]: + if v is None: + return v + pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") + for device in v: + if not pattern.match(device): + raise ValueError( + f"Invalid generation device '{device}'. Valid values are 'cpu', 'mps', 'cuda', or 'cuda:N'." + ) + return v + def update_config(self, config: dict[str, Any] | InvokeAIAppConfig, clobber: bool = True) -> None: """Updates the config, overwriting existing values. diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py index 87a405b4ea4..8fc9823328d 100644 --- a/invokeai/app/services/model_load/model_load_base.py +++ b/invokeai/app/services/model_load/model_load_base.py @@ -26,7 +26,21 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo @property @abstractmethod def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" + """Return the RAM cache for the current thread's execution device. + + In multi-GPU mode, each session-processor worker is pinned to a device and gets its own + cache; this resolves to the calling thread's cache. Outside a worker (e.g. API threads), + it resolves to the default device's cache. + """ + + @property + @abstractmethod + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string. + + Use this for maintenance operations that must apply to every device (clear cache, drop a + model from all devices, shutdown). + """ @abstractmethod def load_model_from_path( diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 2e2d2ae219d..45d0c354278 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -33,13 +33,25 @@ def __init__( app_config: InvokeAIAppConfig, ram_cache: ModelCache, registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, + ram_caches: Optional[dict[str, ModelCache]] = None, ): - """Initialize the model load service.""" + """Initialize the model load service. + + Args: + ram_cache: The default RAM cache, used when no per-device cache matches the calling + thread (e.g. single-device installs, or API threads). + ram_caches: Optional map of normalized device string -> ModelCache for multi-GPU mode. + One cache per generation device. The default `ram_cache` is always included. + """ logger = InvokeAILogger.get_logger(self.__class__.__name__) logger.setLevel(app_config.log_level.upper()) self._logger = logger self._app_config = app_config - self._ram_cache = ram_cache + self._default_ram_cache = ram_cache + # Map normalized device string -> cache. Always includes the default cache so that callers + # without a pinned device (API threads) resolve to a valid cache. + self._ram_caches: dict[str, ModelCache] = dict(ram_caches) if ram_caches else {} + self._ram_caches.setdefault(str(TorchDevice.normalize(ram_cache.execution_device)), ram_cache) self._registry = registry def start(self, invoker: Invoker) -> None: @@ -47,8 +59,18 @@ def start(self, invoker: Invoker) -> None: @property def ram_cache(self) -> ModelCache: - """Return the RAM cache used by this loader.""" - return self._ram_cache + """Return the RAM cache for the calling thread's execution device. + + `choose_torch_device()` is thread-local-aware: a session-processor worker pinned to a GPU + gets that GPU's cache; everything else falls back to the default cache. + """ + key = str(TorchDevice.choose_torch_device()) + return self._ram_caches.get(key, self._default_ram_cache) + + @property + def ram_caches(self) -> dict[str, ModelCache]: + """Return all per-device RAM caches, keyed by normalized device string.""" + return dict(self._ram_caches) def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel: """ @@ -67,7 +89,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo loaded_model: LoadedModel = implementation( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=self.ram_cache, ).load_model(model_config, submodel_type) if hasattr(self, "_invoker"): @@ -78,9 +100,11 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo def load_model_from_path( self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None ) -> LoadedModelWithoutConfig: + # Resolve the calling thread's cache once so the whole load uses a single device's cache. + ram_cache = self.ram_cache cache_key = str(model_path) try: - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) except IndexError: pass @@ -110,7 +134,7 @@ def diffusers_load_directory(directory: Path) -> AnyModel: load_class = GenericDiffusersLoader( app_config=self._app_config, logger=self._logger, - ram_cache=self._ram_cache, + ram_cache=ram_cache, convert_cache=self.convert_cache, ).get_hf_load_class(directory) return load_class.from_pretrained(model_path, torch_dtype=TorchDevice.choose_torch_dtype()) @@ -124,5 +148,5 @@ def diffusers_load_directory(directory: Path) -> AnyModel: ) assert loader is not None raw_model = loader(model_path) - self._ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=self._ram_cache.get(key=cache_key), cache=self._ram_cache) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index 6141a635f4d..eaeb5d4e612 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -60,9 +60,10 @@ def start(self, invoker: Invoker) -> None: service.start(invoker) def stop(self, invoker: Invoker) -> None: - # Shutdown the model cache to cancel any pending timers - if hasattr(self._load, "ram_cache"): - self._load.ram_cache.shutdown() + # Shutdown every per-device model cache to cancel any pending keep-alive timers. + if hasattr(self._load, "ram_caches"): + for cache in self._load.ram_caches.values(): + cache.shutdown() for service in [self._store, self._install, self._load]: if hasattr(service, "stop"): @@ -85,22 +86,39 @@ def build_model_manager( logger = InvokeAILogger.get_logger(cls.__name__) logger.setLevel(app_config.log_level.upper()) - ram_cache = ModelCache( - execution_device_working_mem_gb=app_config.device_working_mem_gb, - enable_partial_loading=app_config.enable_partial_loading, - keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, - max_ram_cache_size_gb=app_config.max_cache_ram_gb, - max_vram_cache_size_gb=app_config.max_cache_vram_gb, - execution_device=execution_device or TorchDevice.choose_torch_device(), - storage_device="cpu", - log_memory_usage=app_config.log_memory_usage, - logger=logger, - keep_alive_minutes=app_config.model_cache_keep_alive_min, - ) + def build_cache(device: torch.device) -> ModelCache: + return ModelCache( + execution_device_working_mem_gb=app_config.device_working_mem_gb, + enable_partial_loading=app_config.enable_partial_loading, + keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, + max_ram_cache_size_gb=app_config.max_cache_ram_gb, + max_vram_cache_size_gb=app_config.max_cache_vram_gb, + execution_device=device, + storage_device="cpu", + log_memory_usage=app_config.log_memory_usage, + logger=logger, + keep_alive_minutes=app_config.model_cache_keep_alive_min, + ) + + # The default cache for callers without a pinned device (API threads, single-device installs). + default_device = execution_device or TorchDevice.choose_torch_device() + ram_cache = build_cache(default_device) + + # In multi-GPU mode, build one independent cache per generation device. Each session-processor + # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own + # cache. The default cache is always included by ModelLoadService. + ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} + if app_config.generation_devices: + for device_str in app_config.generation_devices: + key = str(TorchDevice.normalize(device_str)) + if key not in ram_caches: + ram_caches[key] = build_cache(torch.device(key)) + loader = ModelLoadService( app_config=app_config, ram_cache=ram_cache, registry=ModelLoaderRegistry, + ram_caches=ram_caches, ) installer = ModelInstallService( app_config=app_config, diff --git a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py index b361259a4b1..ae00173e422 100644 --- a/invokeai/app/services/object_serializer/object_serializer_forward_cache.py +++ b/invokeai/app/services/object_serializer/object_serializer_forward_cache.py @@ -1,4 +1,5 @@ from queue import Queue +from threading import Lock from typing import TYPE_CHECKING, Optional, TypeVar from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase @@ -21,6 +22,9 @@ def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: self._cache: dict[str, T] = {} self._cache_ids = Queue[str]() self._max_cache_size = max_cache_size + # Guards the in-memory cache so concurrent session-processor workers (multi-GPU) can't race + # the check-then-evict in `_set_cache` (which could otherwise raise KeyError on eviction). + self._cache_lock = Lock() def start(self, invoker: "Invoker") -> None: self._invoker = invoker @@ -50,16 +54,19 @@ def save(self, obj: T) -> str: def delete(self, name: str) -> None: self._underlying_storage.delete(name) - if name in self._cache: - del self._cache[name] + with self._cache_lock: + if name in self._cache: + del self._cache[name] self._on_deleted(name) def _get_cache(self, name: str) -> Optional[T]: - return None if name not in self._cache else self._cache[name] + with self._cache_lock: + return None if name not in self._cache else self._cache[name] def _set_cache(self, name: str, data: T): - if name not in self._cache: - self._cache[name] = data - self._cache_ids.put(name) - if self._cache_ids.qsize() > self._max_cache_size: - self._cache.pop(self._cache_ids.get()) + with self._cache_lock: + if name not in self._cache: + self._cache[name] = data + self._cache_ids.put(name) + if self._cache_ids.qsize() > self._max_cache_size: + self._cache.pop(self._cache_ids.get()) diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7159c19e746..c6d566255b2 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -5,6 +5,8 @@ from threading import Event as ThreadEvent from typing import Optional +import torch + from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, @@ -31,6 +33,7 @@ from invokeai.app.services.shared.graph import NodeInputError from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context from invokeai.app.util.profiler import Profiler +from invokeai.backend.util.devices import TorchDevice class DefaultSessionRunner(SessionRunnerBase): @@ -305,6 +308,26 @@ def _on_node_error( ) +class _SessionWorker: + """A single generation worker: one thread, optionally pinned to one device. + + In single-device (legacy) mode there is exactly one worker with `device=None`. In multi-GPU + mode there is one worker per configured device, each with its own session runner and cancel + event so concurrent sessions can be canceled independently. + """ + + def __init__(self, device: Optional[torch.device], runner: SessionRunnerBase) -> None: + self.device = device + self.runner = runner + self.cancel_event = ThreadEvent() + self.queue_item: Optional[SessionQueueItem] = None + self.thread: Optional[Thread] = None + + @property + def label(self) -> str: + return str(self.device) if self.device is not None else "default device" + + class DefaultSessionProcessor(SessionProcessorBase): def __init__( self, @@ -319,57 +342,118 @@ def __init__( self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or [] self._thread_limit = thread_limit self._polling_interval = polling_interval + self._workers: list[_SessionWorker] = [] + + def _resolve_devices(self) -> list[Optional[torch.device]]: + """Determine the per-worker devices from config. + + Returns a single `None` (legacy single-worker, device chosen by the global config) unless + `generation_devices` is configured, in which case it returns one normalized device per + listed device (deduplicated, order preserved). + """ + generation_devices = self._invoker.services.configuration.generation_devices + if not generation_devices: + return [None] + devices: list[Optional[torch.device]] = [] + seen: set[str] = set() + for device_str in generation_devices: + device = TorchDevice.normalize(device_str) + if str(device) not in seen: + seen.add(str(device)) + devices.append(device) + return devices + + def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: + """Create an independent runner for an additional worker. + + Each worker needs its own runner because the runner stores its session's cancel event. + We carry over the template's callbacks so all workers behave identically. + """ + if isinstance(template, DefaultSessionRunner): + return DefaultSessionRunner( + on_before_run_session_callbacks=list(template._on_before_run_session_callbacks), + on_before_run_node_callbacks=list(template._on_before_run_node_callbacks), + on_after_run_node_callbacks=list(template._on_after_run_node_callbacks), + on_node_error_callbacks=list(template._on_node_error_callbacks), + on_after_run_session_callbacks=list(template._on_after_run_session_callbacks), + ) + # Unknown runner implementation — only safe to reuse in single-worker mode. + return template def start(self, invoker: Invoker) -> None: self._invoker: Invoker = invoker - self._queue_item: Optional[SessionQueueItem] = None - self._invocation: Optional[BaseInvocation] = None self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() self._poll_now_event = ThreadEvent() - self._cancel_event = ThreadEvent() register_events(QueueClearedEvent, self._on_queue_cleared) register_events(BatchEnqueuedEvent, self._on_batch_enqueued) register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed) - self._thread_semaphore = BoundedSemaphore(self._thread_limit) + devices = self._resolve_devices() # If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally, - # the profiler will create a new profile for each session. + # the profiler will create a new profile for each session. Profiling uses a process-global cProfile, which + # cannot cleanly attribute work when multiple sessions run concurrently, so it is disabled in multi-GPU mode. + profiler_enabled = self._invoker.services.configuration.profile_graphs + if profiler_enabled and len(devices) > 1: + self._invoker.services.logger.warning( + "Graph profiling is disabled because multiple generation devices are configured." + ) + profiler_enabled = False self._profiler = ( Profiler( logger=self._invoker.services.logger, output_dir=self._invoker.services.configuration.profiles_path, prefix=self._invoker.services.configuration.profile_prefix, ) - if self._invoker.services.configuration.profile_graphs + if profiler_enabled else None ) - self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler) - self._thread = Thread( - name="session_processor", - target=self._process, - daemon=True, - kwargs={ - "stop_event": self._stop_event, - "poll_now_event": self._poll_now_event, - "resume_event": self._resume_event, - "cancel_event": self._cancel_event, - }, - ) - self._thread.start() + self._thread_semaphore = BoundedSemaphore(len(devices)) + + # Start in the running (resumed) state. + self._stop_event.clear() + self._resume_event.set() + + self._workers = [] + for index, device in enumerate(devices): + runner = self.session_runner if index == 0 else self._clone_session_runner(self.session_runner) + worker = _SessionWorker(device=device, runner=runner) + runner.start(services=invoker.services, cancel_event=worker.cancel_event, profiler=self._profiler) + self._workers.append(worker) + + if len(self._workers) > 1: + self._invoker.services.logger.info( + f"Starting session processor with {len(self._workers)} parallel workers on devices: " + f"{', '.join(w.label for w in self._workers)}" + ) + + for index, worker in enumerate(self._workers): + worker.thread = Thread( + name=f"session_processor_{index}", + target=self._process, + daemon=True, + kwargs={ + "worker": worker, + "stop_event": self._stop_event, + "poll_now_event": self._poll_now_event, + "resume_event": self._resume_event, + }, + ) + worker.thread.start() def stop(self, *args, **kwargs) -> None: self._stop_event.set() # Cancel any in-progress generation so that long-running nodes (e.g. denoising) stop at - # the next step boundary instead of running to completion. Without this, the generation + # the next step boundary instead of running to completion. Without this, a generation # thread may still be executing CUDA operations when Python teardown begins, which can # cause a C++ std::terminate() crash ("terminate called without an active exception"). - self._cancel_event.set() - # Wake the thread if it is sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). + for worker in self._workers: + worker.cancel_event.set() + # Wake any worker sleeping in poll_now_event.wait() or blocked in resume_event.wait() (paused). self._poll_now_event.set() self._resume_event.set() @@ -377,28 +461,31 @@ def _poll_now(self) -> None: self._poll_now_event.set() async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: - if self._queue_item and self._queue_item.queue_id == event[1].queue_id: - self._cancel_event.set() + # Cancel every worker currently running an item from the cleared queue. + canceled = False + for worker in self._workers: + if worker.queue_item and worker.queue_item.queue_id == event[1].queue_id: + worker.cancel_event.set() + canceled = True + if canceled: self._poll_now() async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: self._poll_now() async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: - # Make sure the cancel event is for the currently processing queue item - if self._queue_item and self._queue_item.item_id != event[1].item_id: - return - if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: - # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is - # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel - # event, which the session runner checks between invocations. If set, the session runner loop is broken. - # - # Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such - # node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item - # is canceled, and if it is, raises a `CanceledException` to stop execution immediately. - if event[1].status == "canceled": - self._cancel_event.set() - self._poll_now() + # Find the worker (if any) currently running the item whose status changed. + for worker in self._workers: + if worker.queue_item and worker.queue_item.item_id == event[1].item_id: + if event[1].status in ["completed", "failed", "canceled"]: + # When the queue item is canceled via HTTP, the status is set to "canceled" and this event is + # emitted. We respond by setting that worker's cancel event, which its session runner checks + # between invocations (and which denoise_latents' step callback checks mid-node, raising + # CanceledException to stop immediately). + if event[1].status == "canceled": + worker.cancel_event.set() + self._poll_now() + return def resume(self) -> SessionProcessorStatus: if not self._resume_event.is_set(): @@ -413,22 +500,28 @@ def pause(self) -> SessionProcessorStatus: def get_status(self) -> SessionProcessorStatus: return SessionProcessorStatus( is_started=self._resume_event.is_set(), - is_processing=self._queue_item is not None, + is_processing=any(worker.queue_item is not None for worker in self._workers), ) def _process( self, + worker: _SessionWorker, stop_event: ThreadEvent, poll_now_event: ThreadEvent, resume_event: ThreadEvent, - cancel_event: ThreadEvent, ): try: - # Any unhandled exception in this block is a fatal processor error and will stop the processor. + # Any unhandled exception in this block is a fatal processor error and will stop this worker. self._thread_semaphore.acquire() - stop_event.clear() - resume_event.set() - cancel_event.clear() + + # Pin this worker thread to its device so all device-selecting code (TorchDevice.choose_torch_device, + # which nodes and the model loader consult) resolves to this GPU. CUDA's current device is per-thread. + if worker.device is not None: + TorchDevice.set_session_device(worker.device) + if worker.device.type == "cuda": + torch.cuda.set_device(worker.device) + + worker.cancel_event.clear() while not stop_event.is_set(): poll_now_event.clear() @@ -437,10 +530,14 @@ def _process( # If we are paused, wait for resume event resume_event.wait() - # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + if stop_event.is_set(): + break + + # Get the next session to process. dequeue() atomically claims the item, so concurrent + # workers never receive the same item. + worker.queue_item = self._invoker.services.session_queue.dequeue() - if self._queue_item is None: + if worker.queue_item is None: # The queue was empty, wait for next polling interval or event to try again self._invoker.services.logger.debug("Waiting for next polling interval or event") poll_now_event.wait(self._polling_interval) @@ -453,19 +550,20 @@ def _process( gc.collect() self._invoker.services.logger.info( - f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}" + f"Executing queue item {worker.queue_item.item_id}, session {worker.queue_item.session_id} " + f"on {worker.label}" ) - cancel_event.clear() + worker.cancel_event.clear() # Run the graph - self.session_runner.run(queue_item=self._queue_item) + worker.runner.run(queue_item=worker.queue_item) except Exception as e: error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() self._on_non_fatal_processor_error( - queue_item=self._queue_item, + queue_item=worker.queue_item, error_type=error_type, error_message=error_message, error_traceback=error_traceback, @@ -474,7 +572,7 @@ def _process( poll_now_event.wait(self._polling_interval) continue except Exception as e: - # Fatal error in processor, log and pass - we're done here + # Fatal error in this worker, log and pass - we're done here error_type = e.__class__.__name__ error_message = str(e) error_traceback = traceback.format_exc() @@ -482,9 +580,9 @@ def _process( self._invoker.services.logger.error(error_traceback) pass finally: - stop_event.clear() - poll_now_event.clear() - self._queue_item = None + worker.queue_item = None + if worker.device is not None: + TorchDevice.clear_session_device() self._thread_semaphore.release() def _on_non_fatal_processor_error( diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index a05ed468857..f1bcd8c7c5c 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -1,6 +1,7 @@ import asyncio import json import sqlite3 +import threading from typing import Optional, Union, cast from pydantic_core import to_jsonable_python @@ -42,6 +43,12 @@ class SqliteSessionQueue(SessionQueueBase): __invoker: Invoker + # Serializes the select-candidate-then-claim sequence in `dequeue()`. The DB connection's + # RLock serializes individual statements, but the gap between selecting the next pending item + # and marking it 'in_progress' is a race: with multiple session-processor workers (multi-GPU), + # two workers could select the same item. Holding this lock across the whole claim prevents it. + _dequeue_lock = threading.Lock() + def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() @@ -210,27 +217,32 @@ async def enqueue_batch( return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: - with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT - sq.*, - u.display_name as user_display_name, - u.email as user_email - FROM session_queue sq - LEFT JOIN users u ON sq.user_id = u.user_id - WHERE sq.status = 'pending' - ORDER BY - sq.priority DESC, - sq.item_id ASC - LIMIT 1 - """ - ) - result = cast(Union[sqlite3.Row, None], cursor.fetchone()) - if result is None: - return None - queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) - queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") + # Hold the dequeue lock across the select-then-claim so concurrent workers (multi-GPU) + # cannot select and claim the same pending item. `_set_queue_item_status` already no-ops + # if the item was concurrently moved to a terminal state (e.g. canceled), so we only need + # to guard against two dequeues racing for the same pending row. + with self._dequeue_lock: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT + sq.*, + u.display_name as user_display_name, + u.email as user_email + FROM session_queue sq + LEFT JOIN users u ON sq.user_id = u.user_id + WHERE sq.status = 'pending' + ORDER BY + sq.priority DESC, + sq.item_id ASC + LIMIT 1 + """ + ) + result = cast(Union[sqlite3.Row, None], cursor.fetchone()) + if result is None: + return None + queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) + queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index e3a0928e52b..1196a0f3885 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -229,6 +229,11 @@ def unsubscribe() -> None: return unsubscribe + @property + def execution_device(self) -> torch.device: + """Return the default execution device this cache loads models onto.""" + return self._execution_device + @property @synchronized def stats(self) -> Optional[CacheStats]: diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index 359ce45dc4f..d912f86a8a3 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -1,3 +1,4 @@ +import threading from typing import Dict, Literal, Optional, Union import torch @@ -46,9 +47,35 @@ class TorchDevice: CUDA_DEVICE = torch.device("cuda") MPS_DEVICE = torch.device("mps") + # Per-thread execution device. When set (by a session-processor worker thread bound to a + # specific GPU), `choose_torch_device()` returns it instead of consulting the global config. + # This is the lynchpin that makes the ~79 `choose_torch_device()` call sites (nodes, model + # patcher, etc.) resolve to the calling worker's GPU without per-call-site changes. + _session_device = threading.local() + + @classmethod + def set_session_device(cls, device: Union[str, torch.device]) -> None: + """Pin the calling thread's execution device. Used by multi-GPU session workers.""" + cls._session_device.device = cls.normalize(device) + + @classmethod + def get_session_device(cls) -> Optional[torch.device]: + """Return the calling thread's pinned execution device, or None if unset.""" + return getattr(cls._session_device, "device", None) + + @classmethod + def clear_session_device(cls) -> None: + """Remove the calling thread's pinned execution device, reverting to global config.""" + if hasattr(cls._session_device, "device"): + del cls._session_device.device + @classmethod def choose_torch_device(cls) -> torch.device: """Return the torch.device to use for accelerated inference.""" + # A worker thread pinned to a specific GPU takes precedence over the global config. + session_device = cls.get_session_device() + if session_device is not None: + return session_device app_config = get_config() if app_config.device != "auto": device = torch.device(app_config.device) diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index a39cf9be514..b22bd1b3aee 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -22,6 +22,7 @@ import type { ImageDTO } from 'services/api/types'; import { useImageViewerContext } from './context'; import { NoContentForViewer } from './NoContentForViewer'; import { ProgressImage } from './ProgressImage2'; +import { ProgressImageTiles } from './ProgressImageTiles'; import { ProgressIndicator } from './ProgressIndicator2'; export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => { @@ -30,9 +31,10 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu const shouldShowItemDetails = useAppSelector(selectShouldShowItemDetails); const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer); const { goToPreviousImage, goToNextImage, isFetching } = useNextPrevItemNavigation(); - const { onLoadImage, $progressEvent, $progressImage } = useImageViewerContext(); + const { onLoadImage, $progressEvent, $progressImage, $activeProgressData } = useImageViewerContext(); const progressEvent = useStore($progressEvent); const progressImage = useStore($progressImage); + const activeProgressData = useStore($activeProgressData); const [imageToRender, setImageToRender] = useState(null); useEffect(() => { @@ -134,6 +136,9 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu }); const withProgress = shouldShowProgressInViewer && progressImage !== null; + // When more than one session is generating concurrently (multi-GPU), tile their previews instead of + // showing only the most recent one. + const withTiledProgress = withProgress && activeProgressData.length > 1; return ( } {withProgress && ( - - {progressEvent && ( - + {withTiledProgress ? ( + + ) : ( + <> + + {progressEvent && ( + + )} + )} )} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx new file mode 100644 index 00000000000..6f66c02e929 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/ProgressImageTiles.tsx @@ -0,0 +1,39 @@ +import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; +import { memo, useMemo } from 'react'; + +import type { ViewerProgressDatum } from './context'; +import { ProgressImage } from './ProgressImage2'; +import { ProgressIndicator } from './ProgressIndicator2'; + +/** + * Renders one tile per concurrently-running session (multi-GPU). Each tile shows that session's live + * preview image plus a small progress indicator. Used by the viewer when more than one session is + * active; a single active session uses the full-size preview instead. + */ +export const ProgressImageTiles = memo(({ data }: { data: ViewerProgressDatum[] }) => { + // Lay the tiles out in a roughly-square grid that grows with the number of active sessions. + const columns = useMemo(() => Math.ceil(Math.sqrt(data.length)), [data.length]); + + return ( + + {data.map((datum) => ( + + + + + + + ))} + + ); +}); +ProgressImageTiles.displayName = 'ProgressImageTiles'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx index 1cb22d61463..145ab63ba6e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/context.tsx @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors'; import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common'; import { LRUCache } from 'lru-cache'; -import { type Atom, atom, computed } from 'nanostores'; +import { type Atom, atom, computed, map, type MapStore } from 'nanostores'; import type { PropsWithChildren } from 'react'; import { createContext, memo, useCallback, useContext, useEffect, useMemo, useState } from 'react'; import type { S } from 'services/api/types'; @@ -12,10 +12,24 @@ import { $socket } from 'services/events/stores'; import { assert } from 'tsafe'; import type { JsonObject } from 'type-fest'; +/** Live progress for a single in-flight session (queue item). Used to tile the viewer when several + * sessions run concurrently (multi-GPU). Only items that have produced a preview image are tracked. */ +export type ViewerProgressDatum = { + itemId: number; + progressEvent: S['InvocationProgressEvent']; + progressImage: ProgressImageType; +}; + +type ViewerProgressDataMap = Record; + type ImageViewerContextValue = { $progressEvent: Atom; $progressImage: Atom; $hasProgressImage: Atom; + /** Per-session progress, keyed by queue item id. Drives the tiled multi-session preview. */ + $progressData: MapStore; + /** Active sessions (those with a preview image), sorted by item id for a stable tile order. */ + $activeProgressData: Atom; onLoadImage: () => void; }; @@ -29,6 +43,15 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { const $progressEvent = useState(() => atom(null))[0]; const $progressImage = useState(() => atom(null))[0]; const $hasProgressImage = useState(() => computed($progressImage, (progressImage) => progressImage !== null))[0]; + // Per-session progress, keyed by queue item id, for the tiled multi-session preview (multi-GPU). + const $progressData = useState(() => map({}))[0]; + const $activeProgressData = useState(() => + computed($progressData, (progressData) => + Object.values(progressData) + .filter((datum): datum is ViewerProgressDatum => datum !== undefined) + .sort((a, b) => a.itemId - b.itemId) + ) + )[0]; // We can have race conditions where we receive a progress event for a queue item that has already finished. Easiest // way to handle this is to keep track of finished queue items in a cache and ignore progress events for those. const [finishedQueueItemIds] = useState(() => new LRUCache({ max: 200 })); @@ -49,6 +72,12 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { $progressEvent.set(data); if (data.image) { $progressImage.set(data.image); + // Track per-session so the viewer can tile concurrent sessions (multi-GPU). + $progressData.setKey(data.item_id, { + itemId: data.item_id, + progressEvent: data, + progressImage: data.image, + }); } }; @@ -57,7 +86,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('invocation_progress', onInvocationProgress); }; - }, [$progressEvent, $progressImage, finishedQueueItemIds, socket]); + }, [$progressData, $progressEvent, $progressImage, finishedQueueItemIds, socket]); useEffect(() => { if (!socket) { @@ -74,6 +103,9 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { } if (data.status === 'completed' || data.status === 'canceled' || data.status === 'failed') { finishedQueueItemIds.set(data.item_id, true); + // Remove this session's tile from the multi-session preview as soon as it reaches a terminal + // state. The single-image "resolve" illusion below is handled separately via onLoadImage. + $progressData.setKey(data.item_id, undefined); // Completed queue items have the progress event cleared by the onLoadImage callback. This allows the viewer to // create the illusion of the progress image "resolving" into the final image. If we cleared the progress image // now, there would be a flicker where the progress image disappears before the final image appears, and the @@ -103,7 +135,7 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { return () => { socket.off('queue_item_status_changed', onQueueItemStatusChanged); }; - }, [$progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); + }, [$progressData, $progressEvent, $progressImage, autoSwitch, finishedQueueItemIds, socket]); const onLoadImage = useCallback(() => { $progressEvent.set(null); @@ -111,8 +143,8 @@ export const ImageViewerContextProvider = memo((props: PropsWithChildren) => { }, [$progressEvent, $progressImage]); const value = useMemo( - () => ({ $progressEvent, $progressImage, $hasProgressImage, onLoadImage }), - [$hasProgressImage, $progressEvent, $progressImage, onLoadImage] + () => ({ $progressEvent, $progressImage, $hasProgressImage, $progressData, $activeProgressData, onLoadImage }), + [$hasProgressImage, $progressEvent, $progressImage, $progressData, $activeProgressData, onLoadImage] ); return {props.children}; diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py new file mode 100644 index 00000000000..c9bb107d809 --- /dev/null +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -0,0 +1,81 @@ +"""Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" + +import threading + +import torch + +from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.model_load.model_load_default import ModelLoadService +from invokeai.backend.util.devices import TorchDevice + + +class _FakeCache: + """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" + + def __init__(self, device: str): + self.execution_device = torch.device(device) + + +def _build_service() -> tuple[ModelLoadService, _FakeCache, _FakeCache]: + cache0 = _FakeCache("cuda:0") + cache1 = _FakeCache("cuda:1") + service = ModelLoadService( + app_config=InvokeAIAppConfig(), + ram_cache=cache0, # type: ignore[arg-type] + ram_caches={"cuda:0": cache0, "cuda:1": cache1}, # type: ignore[arg-type] + ) + return service, cache0, cache1 + + +def test_ram_cache_routes_to_pinned_device(): + """A thread pinned to cuda:1 resolves to that device's cache; the default thread to cuda:0.""" + service, cache0, cache1 = _build_service() + + # The default thread has no session device; point config.device at cuda:0 so it resolves there. + get_config().device = "cuda:0" + assert service.ram_cache is cache0 + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:1") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache1 + # Main thread is unaffected by the worker's pinning. + assert service.ram_cache is cache0 + + +def test_ram_caches_exposes_all_devices(): + service, cache0, cache1 = _build_service() + caches = service.ram_caches + assert set(caches.keys()) == {"cuda:0", "cuda:1"} + assert caches["cuda:0"] is cache0 + assert caches["cuda:1"] is cache1 + + +def test_unknown_device_falls_back_to_default(): + """A thread pinned to a device with no cache falls back to the default cache.""" + service, cache0, _ = _build_service() + + results: dict[str, object] = {} + + def worker(): + TorchDevice.set_session_device("cuda:7") + try: + results["cache"] = service.ram_cache + finally: + TorchDevice.clear_session_device() + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert results["cache"] is cache0 diff --git a/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py new file mode 100644 index 00000000000..8d55db941a5 --- /dev/null +++ b/tests/app/services/session_queue/test_session_queue_dequeue_concurrency.py @@ -0,0 +1,70 @@ +"""Tests that concurrent dequeue() calls (multi-GPU session workers) never claim the same item twice.""" + +import threading +import uuid + +import pytest + +from invokeai.app.services.invoker import Invoker +from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue +from invokeai.app.services.shared.graph import Graph, GraphExecutionState +from tests.test_nodes import PromptTestInvocation + + +@pytest.fixture +def session_queue(mock_invoker: Invoker) -> SqliteSessionQueue: + db = mock_invoker.services.board_records._db + queue = SqliteSessionQueue(db=db) + queue.start(mock_invoker) + return queue + + +def _insert_queue_item(session_queue: SqliteSessionQueue, user_id: str = "system") -> int: + graph = Graph() + graph.add_node(PromptTestInvocation(id="prompt", prompt="test")) + session = GraphExecutionState(graph=graph) + session_json = session.model_dump_json(warnings=False, exclude_none=True) + batch_id = str(uuid.uuid4()) + with session_queue._db.transaction() as cursor: + cursor.execute( + """--sql + INSERT INTO session_queue ( + queue_id, session, session_id, batch_id, field_values, priority, + workflow, origin, destination, retried_from_item_id, user_id + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ("default", session_json, session.id, batch_id, None, 0, None, None, None, None, user_id), + ) + return cursor.lastrowid + + +def test_concurrent_dequeue_never_claims_same_item_twice(session_queue: SqliteSessionQueue) -> None: + item_count = 50 + worker_count = 8 + for _ in range(item_count): + _insert_queue_item(session_queue) + + claimed_ids: list[int] = [] + claimed_lock = threading.Lock() + start_barrier = threading.Barrier(worker_count) + + def worker() -> None: + # Release all workers at once to maximize contention on the dequeue path. + start_barrier.wait() + while True: + item = session_queue.dequeue() + if item is None: + break + with claimed_lock: + claimed_ids.append(item.item_id) + + threads = [threading.Thread(target=worker) for _ in range(worker_count)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every item is claimed exactly once: no duplicates, none lost. + assert len(claimed_ids) == item_count + assert len(set(claimed_ids)) == item_count diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index 3f134e3c3da..39dee5cb618 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -2,6 +2,7 @@ Test abstract device class. """ +import threading from unittest.mock import patch import pytest @@ -24,6 +25,50 @@ def test_device_choice(device_name): assert torch_device == torch.device(device_name) +# ===== per-thread session device (multi-GPU worker pinning) ================ + + +def test_session_device_overrides_config(): + """A per-thread session device takes precedence over the global config.device.""" + config = get_config() + config.device = "cpu" + try: + TorchDevice.set_session_device("cuda:1") + assert TorchDevice.choose_torch_device() == torch.device("cuda:1") + finally: + TorchDevice.clear_session_device() + # Once cleared, we fall back to the global config. + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + +def test_session_device_is_thread_local(): + """Each thread sees only its own pinned device; the main thread is unaffected.""" + config = get_config() + config.device = "cpu" + results: dict[str, torch.device] = {} + barrier = threading.Barrier(2) + + def worker(name: str, device: str): + TorchDevice.set_session_device(device) + # Wait so both threads have set their device before either reads it, proving isolation. + barrier.wait() + results[name] = TorchDevice.choose_torch_device() + TorchDevice.clear_session_device() + + t0 = threading.Thread(target=worker, args=("a", "cuda:0")) + t1 = threading.Thread(target=worker, args=("b", "cuda:1")) + t0.start() + t1.start() + t0.join() + t1.join() + + assert results["a"] == torch.device("cuda:0") + assert results["b"] == torch.device("cuda:1") + # The main thread never set a session device, so it still uses the global config. + assert TorchDevice.get_session_device() is None + assert TorchDevice.choose_torch_device() == torch.device("cpu") + + @pytest.mark.parametrize("device_dtype_pair", device_types_cpu) def test_device_dtype_cpu(device_dtype_pair): with ( From 6bb89d6ba6ff7f06c1b5057efce84e757acfae8e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 14:49:43 -0400 Subject: [PATCH 02/14] fix(tests): restore global device after multi-GPU cache routing test test_model_load_device_routing mutated the process-wide get_config() singleton (device = "cuda:0") to exercise the per-thread cache routing, but never restored it. The leaked CUDA device was then picked up by a later test (test_model_load::test_loading) via choose_torch_device(), which crashed with "Torch not compiled with CUDA enabled" on the CUDA-less CI runner. Add an autouse fixture to save/restore device and clear any pinned session device. Co-Authored-By: Claude Opus 4.8 --- .../model_load/test_model_load_device_routing.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/app/services/model_load/test_model_load_device_routing.py b/tests/app/services/model_load/test_model_load_device_routing.py index c9bb107d809..85b3868b92f 100644 --- a/tests/app/services/model_load/test_model_load_device_routing.py +++ b/tests/app/services/model_load/test_model_load_device_routing.py @@ -1,7 +1,9 @@ """Tests that ModelLoadService routes to the per-device cache for the calling thread (multi-GPU).""" import threading +from collections.abc import Iterator +import pytest import torch from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config @@ -9,6 +11,19 @@ from invokeai.backend.util.devices import TorchDevice +@pytest.fixture(autouse=True) +def restore_global_device() -> Iterator[None]: + """`get_config()` is a process-wide singleton; restore `device` so we don't leak a CUDA device + into later CPU-only tests (e.g. the model-loading suite on the CUDA-less CI runner).""" + config = get_config() + original_device = config.device + try: + yield + finally: + config.device = original_device + TorchDevice.clear_session_device() + + class _FakeCache: """Stand-in for ModelCache; ModelLoadService only needs `.execution_device` for keying.""" From a3be44423a6dcb772c8ba359bf53830ca6695b30 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 14:51:14 -0400 Subject: [PATCH 03/14] chore(ui): regenerate openapi schema and frontend types for generation_devices Regenerate openapi.json (make frontend-openapi) and the frontend schema.ts types (make frontend-typegen) so they include the new generation_devices config field, fixing the openapi-checks and typegen-checks CI jobs. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/openapi.json | 18 +++++++++++++++++- .../frontend/web/src/services/api/schema.ts | 5 +++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 1287ee58865..b828412fb0c 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -14375,7 +14375,8 @@ } }, "type": "object", - "title": "CacheStats" + "title": "CacheStats", + "description": "Collect statistics on cache performance." }, "CalculateImageTilesEvenSplitInvocation": { "category": "tiles", @@ -41151,6 +41152,21 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Generation Devices", + "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index a80183476bd..ace2e30178e 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,6 +16506,11 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number) + */ + generation_devices?: string[] | null; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. From be54889811d95b6cf6aefd350f11ba76689bc64a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 15:07:44 -0400 Subject: [PATCH 04/14] fix(ui): regenerate openapi.json with uv to match CI generator `make frontend-openapi` used a bare `python` from a different environment that emitted the CacheStats @dataclass docstring as a schema description. CI generates the schema via `uv run`, which does not, so openapi-checks failed on the diff. Regenerate with the uv-locked environment to drop the stray description while keeping the generation_devices field. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/openapi.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index b828412fb0c..852dc866bce 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -14375,8 +14375,7 @@ } }, "type": "object", - "title": "CacheStats", - "description": "Collect statistics on cache performance." + "title": "CacheStats" }, "CalculateImageTilesEvenSplitInvocation": { "category": "tiles", From a119b50bc1f27eafcf3087c22b35f28b9ee2e398 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 1 Jun 2026 23:00:29 -0400 Subject: [PATCH 05/14] fix(model-manager): serialize model construction against VRAM moves to prevent meta-device corruption Parallel multi-GPU session workers could intermittently crash with "unrecognized device meta" (denoise) or "Cannot copy out of meta tensor; no data!" (l2i), because model loading relies on process-global, non-thread-safe monkey-patches. accelerate.init_empty_weights() (used directly by the loaders and implicitly by diffusers' default low_cpu_mem_usage=True in from_pretrained) swaps torch.nn.Module.register_parameter globally for the duration of a load, routing every newly-registered parameter to the meta device. The model cache's VRAM load/unload runs nn.Module.load_state_dict(assign=True), whose assign path does setattr -> __setattr__ -> register_parameter. When one worker's VRAM move overlapped another worker's from_pretrained, the move's real weights got hijacked onto meta and blew up on the next .to(device). Introduce MODEL_LOAD_LOCK, a write-preferring readers-writer lock: - write lock = model construction (_load_and_cache, load_model_from_path), exclusive. - read lock = VRAM load/unload (ModelCache.lock(), repair_required_tensors_on_device). VRAM transfers across GPUs still overlap each other; they only block while a construction holds the write lock. The lock is always acquired before any per-cache lock to keep a consistent order and avoid an AB-BA deadlock with the writer's make_room/put. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../services/model_load/model_load_default.py | 17 +++-- .../backend/model_manager/load/load_base.py | 20 ++++-- .../model_manager/load/load_default.py | 64 ++++++++++++----- .../load/model_cache/model_cache.py | 70 ++++++++++++++++++- 4 files changed, 143 insertions(+), 28 deletions(-) diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py index 45d0c354278..33c7ef6108c 100644 --- a/invokeai/app/services/model_load/model_load_default.py +++ b/invokeai/app/services/model_load/model_load_default.py @@ -18,7 +18,7 @@ ModelLoaderRegistry, ModelLoaderRegistryBase, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType from invokeai.backend.util.devices import TorchDevice @@ -147,6 +147,15 @@ def diffusers_load_directory(directory: Path) -> AnyModel: else lambda path: safetensors_load_file(path, device="cpu") ) assert loader is not None - raw_model = loader(model_path) - ram_cache.put(key=cache_key, model=raw_model) - return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + # Serialize construction (see MODEL_LOAD_LOCK): the diffusers loader path uses the same + # process-global, non-thread-safe monkey-patches as the main loader, so it takes the write + # lock to exclude concurrent VRAM moves. Re-check the cache after acquiring the lock in case + # a worker sharing this cache built it while we waited. + with MODEL_LOAD_LOCK.write_lock(): + try: + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) + except IndexError: + pass + raw_model = loader(model_path) + ram_cache.put(key=cache_key, model=raw_model) + return LoadedModelWithoutConfig(cache_record=ram_cache.get(key=cache_key), cache=ram_cache) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index 4609a2e92ab..984362f185d 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -17,7 +17,7 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache +from invokeai.backend.model_manager.load.model_cache.model_cache import MODEL_LOAD_LOCK, ModelCache from invokeai.backend.model_manager.taxonomy import AnyModel, SubModelType @@ -57,7 +57,12 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache): self._cache = cache def __enter__(self) -> AnyModel: - self._cache.lock(self._cache_record, None) + # Hold the MODEL_LOAD_LOCK read lock across the VRAM load (lock() runs + # load_state_dict(assign=True), which calls register_parameter) so it can't overlap a + # concurrent model construction that has the global register_parameter -> meta patch active. + # Acquired before the cache's own lock to keep a consistent lock order (see MODEL_LOAD_LOCK). + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, None) try: self.repair_required_tensors_on_device() return self.model @@ -77,7 +82,9 @@ def model_on_device( :param working_mem_bytes: The amount of working memory to keep available on the compute device when loading the model. """ - self._cache.lock(self._cache_record, working_mem_bytes) + # See __enter__ for why the VRAM load is wrapped in the read lock. + with MODEL_LOAD_LOCK.read_lock(): + self._cache.lock(self._cache_record, working_mem_bytes) try: self.repair_required_tensors_on_device() yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) @@ -94,7 +101,12 @@ def repair_required_tensors_on_device(self) -> int: cached_model = self._cache_record.cached_model if not isinstance(cached_model, CachedModelWithPartialLoad): return 0 - return cached_model.repair_required_tensors_on_compute_device() + # Repair runs load_state_dict(assign=True) -> register_parameter, so it must hold the read + # lock to avoid being hijacked onto the `meta` device by a concurrent construction. This is + # also called directly (outside __enter__/model_on_device) by some text-encoder invocations, + # so the guard lives here rather than only at the call sites. + with MODEL_LOAD_LOCK.read_lock(): + return cached_model.repair_required_tensors_on_compute_device() class LoadedModel(LoadedModelWithoutConfig): diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py index 040b55cb6ec..02929ff6132 100644 --- a/invokeai/backend/model_manager/load/load_default.py +++ b/invokeai/backend/model_manager/load/load_default.py @@ -13,7 +13,11 @@ from invokeai.backend.model_manager.configs.factory import AnyModelConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord -from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache, get_model_cache_key +from invokeai.backend.model_manager.load.model_cache.model_cache import ( + MODEL_LOAD_LOCK, + ModelCache, + get_model_cache_key, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init from invokeai.backend.model_manager.taxonomy import ( @@ -52,7 +56,9 @@ ) -# TO DO: The loader is not thread safe! +# The construction path is not thread-safe on its own; it monkey-patches process-global torch state +# (see MODEL_LOAD_LOCK). Concurrent callers must hold the MODEL_LOAD_LOCK write lock (see +# _load_and_cache). class ModelLoader(ModelLoaderBase): """Default implementation of ModelLoaderBase.""" @@ -85,8 +91,7 @@ def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubMo if not model_path.exists(): raise FileNotFoundError(f"Files for model '{model_config.name}' not found at {model_path}") - with skip_torch_weight_init(): - cache_record = self._load_and_cache(model_config, submodel_type) + cache_record = self._load_and_cache(model_config, submodel_type) return LoadedModel(config=model_config, cache_record=cache_record, cache=self._ram_cache) @property @@ -124,25 +129,46 @@ def _get_execution_device( def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> CacheRecord: stats_name = ":".join([config.base, config.type, config.name, (submodel_type or "")]) + cache_key = get_model_cache_key(config.key, submodel_type) try: - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + return self._ram_cache.get(key=cache_key, stats_name=stats_name) except IndexError: pass - config.path = str(self._get_model_path(config)) - self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) - loaded_model = self._load_model(config, submodel_type) - - # Determine execution device from model config, considering submodel type - execution_device = self._get_execution_device(config, submodel_type) - - self._ram_cache.put( - get_model_cache_key(config.key, submodel_type), - model=loaded_model, - execution_device=execution_device, - ) - - return self._ram_cache.get(key=get_model_cache_key(config.key, submodel_type), stats_name=stats_name) + # Cache miss: construct the model from disk. This path holds the MODEL_LOAD_LOCK *write* + # lock because it relies on process-global, non-thread-safe monkey-patches + # (skip_torch_weight_init and, inside the loaders, accelerate.init_empty_weights / diffusers + # low_cpu_mem_usage). The write lock excludes both other constructions AND concurrent VRAM + # load/unload on other workers (which take the read lock); without that, a concurrent move's + # load_state_dict(assign=True) -> register_parameter gets hijacked onto the `meta` device. + # See MODEL_LOAD_LOCK for the full explanation. + # + # Lock-ordering: the write lock is acquired before any ModelCache._lock taken below + # (get/make_room/put), matching the readers' order, so there is no AB-BA deadlock. + with MODEL_LOAD_LOCK.write_lock(): + # Double-checked locking: another worker sharing this cache may have loaded the same + # entry while we waited for the mutex. (Workers on other devices use a different cache, + # so they will still miss here and construct their own copy — which is intended.) + try: + return self._ram_cache.get(key=cache_key, stats_name=stats_name) + except IndexError: + pass + + config.path = str(self._get_model_path(config)) + self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type)) + with skip_torch_weight_init(): + loaded_model = self._load_model(config, submodel_type) + + # Determine execution device from model config, considering submodel type + execution_device = self._get_execution_device(config, submodel_type) + + self._ram_cache.put( + cache_key, + model=loaded_model, + execution_device=execution_device, + ) + + return self._ram_cache.get(key=cache_key, stats_name=stats_name) def get_size_fs( self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 1196a0f3885..2ca8dd44ba2 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -2,10 +2,11 @@ import logging import threading import time +from contextlib import contextmanager from dataclasses import dataclass from functools import wraps from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Protocol +from typing import Any, Callable, Dict, Generator, List, Optional, Protocol import psutil import torch @@ -35,6 +36,73 @@ MB = 2**20 +class _ModelLoadReadWriteLock: + """A write-preferring readers-writer lock that serializes model construction against VRAM moves. + + The model load machinery depends on PROCESS-GLOBAL monkey-patches that are not thread-safe: + model CONSTRUCTION (diffusers `from_pretrained` / `accelerate.init_empty_weights`) temporarily + replaces `torch.nn.Module.register_parameter` so that every newly-registered parameter is routed + to the `meta` device. While that patch is installed, ANY `register_parameter` call in ANY thread + is hijacked onto `meta`. VRAM load/unload uses `nn.Module.load_state_dict(assign=True)`, which + assigns `Parameter`s via `__setattr__` -> `register_parameter` — so if it runs concurrently with + a construction on another worker thread, its real weights get stranded on `meta`. That surfaces + later as "Cannot copy out of meta tensor; no data!" or "unrecognized device meta". + + - Construction takes the WRITE lock (exclusive — no reader and no other writer may run). + - VRAM load/unload takes the READ lock (shared, so concurrent moves on different GPUs still + overlap each other; they only block while a construction holds the write lock). + + Write-preferring: once a construction is waiting, new readers queue behind it, so a steady stream + of VRAM moves from busy workers can't starve a pending load. + + Lock-ordering contract: callers MUST acquire this lock *before* any `ModelCache._lock`, never + after. Readers do so by taking the read lock around the outer `ModelCache.lock()` call (see + `LoadedModelWithoutConfig`), and writers around the whole construction (see + `ModelLoader._load_and_cache`). Acquiring it in the other order — cache lock first, then this + lock — would risk an AB-BA deadlock with a writer that takes a cache lock during `put()`. + """ + + def __init__(self) -> None: + self._cond = threading.Condition(threading.Lock()) + self._readers = 0 + self._writers_waiting = 0 + self._writer_active = False + + @contextmanager + def read_lock(self) -> Generator[None, None, None]: + with self._cond: + # Defer to any active or waiting writer (write-preferring). + while self._writer_active or self._writers_waiting > 0: + self._cond.wait() + self._readers += 1 + try: + yield + finally: + with self._cond: + self._readers -= 1 + if self._readers == 0: + self._cond.notify_all() + + @contextmanager + def write_lock(self) -> Generator[None, None, None]: + with self._cond: + self._writers_waiting += 1 + while self._writer_active or self._readers > 0: + self._cond.wait() + self._writers_waiting -= 1 + self._writer_active = True + try: + yield + finally: + with self._cond: + self._writer_active = False + self._cond.notify_all() + + +# Process-global lock guarding the non-thread-safe model load machinery. See _ModelLoadReadWriteLock. +MODEL_LOAD_LOCK = _ModelLoadReadWriteLock() + + # TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels. def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str: """Get the cache key for a model based on the optional submodel type.""" From 70114464698543acbe78c6b6191d5c50faabe119 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 16:34:44 -0400 Subject: [PATCH 06/14] fix(backend): fix outpainting crash caused by model download collisions --- .../model_install/model_install_default.py | 67 +++++++++++++------ 1 file changed, 46 insertions(+), 21 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 49d3cfdf7f9..5f70fc53838 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -114,6 +114,11 @@ def __init__( self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[int, ModelInstallJob] = {} + # Per-source locks serializing download_and_cache_model() so parallel (multi-GPU) sessions + # that need the same remote model (e.g. the LaMa infill model) don't race to download into + # the same cache directory. _download_cache_locks_guard protects the dict itself. + self._download_cache_locks: Dict[str, threading.Lock] = {} + self._download_cache_locks_guard = threading.Lock() self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -711,27 +716,47 @@ def download_and_cache_model( if len(contents) > 0: return contents[0] - model_path.mkdir(parents=True, exist_ok=True) - model_source = self._guess_source(str(source)) - remote_files, _ = self._remote_files_from_source(model_source) - # Handle multiple subfolders for HFModelSource - subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] - job = self._multifile_download( - dest=model_path, - remote_files=remote_files, - subfolder=model_source.subfolder - if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 - else None, - subfolders=subfolders if len(subfolders) > 1 else None, - ) - files_string = "file" if len(remote_files) == 1 else "files" - self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") - self._download_queue.wait_for_job(job) - if job.complete: - assert job.download_path is not None - return job.download_path - else: - raise Exception(job.error) + # Serialize concurrent downloads of the same source. Parallel multi-GPU sessions can each + # request the same remote model (e.g. the LaMa infill model) at once; without this lock they + # both download into the same cache directory and collide on the final rename, which fails on + # Windows with "WinError 32: the file is being used by another process". The other waiters + # find the completed download on the post-lock re-check below and skip downloading. + with self._download_cache_lock(str(source)): + if model_path.exists(): + contents = list(model_path.iterdir()) + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + model_source = self._guess_source(str(source)) + remote_files, _ = self._remote_files_from_source(model_source) + # Handle multiple subfolders for HFModelSource + subfolders = model_source.subfolders if isinstance(model_source, HFModelSource) else [] + job = self._multifile_download( + dest=model_path, + remote_files=remote_files, + subfolder=model_source.subfolder + if isinstance(model_source, HFModelSource) and len(subfolders) <= 1 + else None, + subfolders=subfolders if len(subfolders) > 1 else None, + ) + files_string = "file" if len(remote_files) == 1 else "files" + self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") + self._download_queue.wait_for_job(job) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + + def _download_cache_lock(self, source: str) -> threading.Lock: + """Return the lock that serializes downloads for a given source, creating it on first use.""" + with self._download_cache_locks_guard: + lock = self._download_cache_locks.get(source) + if lock is None: + lock = threading.Lock() + self._download_cache_locks[source] = lock + return lock def _remote_files_from_source( self, source: ModelSource From a1fe3757f051d21893f53b3b3cddd0aca703f819 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 16:42:18 -0400 Subject: [PATCH 07/14] fix(backend): make DiskImageFileStorage thread-safe for parallel sessions Image.open() is lazy: it reads the header but defers pixel decoding (and holds the file handle open) until the first .load()/.copy()/.convert(). The opened object was cached and the same object handed to every caller, so in multi-GPU parallel mode two session-processor worker threads could call .copy() on it concurrently and race on the shared file handle and decoder state. This surfaced as "broken data stream when reading image file" and "AssertionError: self.png is not None" during inpainting with batch >1. Force the decode (image.load()) before the object enters the cache so the cached object is safe for concurrent reads, and guard the cache structures (__cache / __cache_ids) with a lock since they are now mutated from multiple threads. Co-Authored-By: Claude Opus 4.8 --- .../services/image_files/image_files_disk.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/invokeai/app/services/image_files/image_files_disk.py b/invokeai/app/services/image_files/image_files_disk.py index 12b737a7cf1..ec84439547a 100644 --- a/invokeai/app/services/image_files/image_files_disk.py +++ b/invokeai/app/services/image_files/image_files_disk.py @@ -1,4 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team +import threading from pathlib import Path from queue import Queue from typing import Optional, Union @@ -23,6 +24,9 @@ def __init__(self, output_folder: Union[str, Path]): self.__cache: dict[Path, PILImageType] = {} self.__cache_ids = Queue[Path]() self.__max_cache_size = 10 # TODO: get this from config + # Guards the cache structures (__cache / __cache_ids), which are read and mutated from + # multiple session-processor worker threads in multi-GPU parallel mode. + self.__cache_lock = threading.Lock() self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder) self.__thumbnails_folder = self.__output_folder / "thumbnails" @@ -41,6 +45,13 @@ def get(self, image_name: str, image_subfolder: str = "") -> PILImageType: return cache_item image = Image.open(image_path) + # Image.open() is lazy: it reads the header but defers pixel decoding (and holds the + # file handle open) until the first .load()/.copy()/.convert(). The opened object is + # cached and the SAME object is handed to every caller, so in multi-GPU parallel mode + # two worker threads can call .copy() on it concurrently and race on the shared file + # handle and decoder state, producing "broken data stream" / "self.png is not None" + # errors. Forcing the decode here makes the cached object safe for concurrent reads. + image.load() self.__set_cache(image_path, image) return image except FileNotFoundError as e: @@ -105,16 +116,18 @@ def delete(self, image_name: str, image_subfolder: str = "") -> None: if image_path.exists(): image_path.unlink() - if image_path in self.__cache: - del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) thumbnail_path = self.get_path(thumbnail_name, True, image_subfolder=image_subfolder) if thumbnail_path.exists(): thumbnail_path.unlink() - if thumbnail_path in self.__cache: - del self.__cache[thumbnail_path] + + with self.__cache_lock: + if image_path in self.__cache: + del self.__cache[image_path] + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] except Exception as e: raise ImageFileDeleteException from e @@ -185,13 +198,15 @@ def __validate_storage_folders(self) -> None: folder.mkdir(parents=True, exist_ok=True) def __get_cache(self, image_name: Path) -> Optional[PILImageType]: - return None if image_name not in self.__cache else self.__cache[image_name] + with self.__cache_lock: + return None if image_name not in self.__cache else self.__cache[image_name] def __set_cache(self, image_name: Path, image: PILImageType): - if image_name not in self.__cache: - self.__cache[image_name] = image - self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache - if len(self.__cache) > self.__max_cache_size: - cache_id = self.__cache_ids.get() - if cache_id in self.__cache: - del self.__cache[cache_id] + with self.__cache_lock: + if image_name not in self.__cache: + self.__cache[image_name] = image + self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache + if len(self.__cache) > self.__max_cache_size: + cache_id = self.__cache_ids.get() + if cache_id in self.__cache: + del self.__cache[cache_id] From 3db88d51ebcf2bbd07bc01ec434c574c0effc15c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:09:23 -0400 Subject: [PATCH 08/14] feat(ui): stack per-session progress bars during parallel generation The generation progress bars (under the Invoke button and the Viewer tab) both read a single global $lastProgressEvent atom, which every session overwrites. With parallel multi-GPU sessions this made the bar jump back and forth between sessions. Track progress per queue item id and render one bar per in-flight session, stacked vertically, each removed as its session reaches a terminal state. - stores.ts: add $progressEvents (map keyed by item_id), $activeProgressEvents (sorted), and set/clear helpers. - setEventListeners.tsx: populate per-item progress on invocation_progress; clear per item on terminal status; clear all on connect/disconnect/queue cleared. - ProgressBar.tsx: render a vertical stack of bars (one per active session) with a single-bar fallback for the idle / model-loading window; add containerProps so dockview tabs can position the stack. - Dockview tab call sites: move positioning into containerProps. Co-Authored-By: Claude Opus 4.8 --- .../system/components/ProgressBar.tsx | 90 ++++++++++--------- .../ui/layouts/DockviewTabCanvasViewer.tsx | 6 +- .../ui/layouts/DockviewTabCanvasWorkspace.tsx | 6 +- .../ui/layouts/DockviewTabProgress.tsx | 6 +- .../src/services/events/setEventListeners.tsx | 15 +++- .../web/src/services/events/stores.ts | 29 +++++- 6 files changed, 103 insertions(+), 49 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 5a4abdd4d28..6a305416bbf 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -1,62 +1,64 @@ -import type { ProgressProps } from '@invoke-ai/ui-library'; -import { Progress } from '@invoke-ai/ui-library'; +import type { FlexProps, ProgressProps } from '@invoke-ai/ui-library'; +import { Flex, Progress } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; -import { $isConnected, $lastProgressEvent, $loadingModelsCount } from 'services/events/stores'; +import { $activeProgressEvents, $isConnected, $loadingModelsCount } from 'services/events/stores'; -const ProgressBar = (props: ProgressProps) => { +type ProgressBarProps = ProgressProps & { + /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ + containerProps?: FlexProps; +}; + +type BarDescriptor = { + key: number | string; + value: number; + isIndeterminate: boolean; +}; + +const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); - const lastProgressEvent = useStore($lastProgressEvent); + const activeProgressEvents = useStore($activeProgressEvents); const loadingModelsCount = useStore($loadingModelsCount); - const value = useMemo(() => { - if (!lastProgressEvent) { - return 0; - } - return (lastProgressEvent.percentage ?? 0) * 100; - }, [lastProgressEvent]); - - const isIndeterminate = useMemo(() => { - if (!isConnected) { - return false; - } - if (loadingModelsCount > 0) { - return true; + const bars = useMemo(() => { + // One bar per in-flight session (multi-GPU). Each session's progress is tracked independently, so + // the bars no longer jump back and forth when several sessions render simultaneously. + if (activeProgressEvents.length > 0) { + return activeProgressEvents.map((event) => ({ + key: event.item_id, + value: (event.percentage ?? 0) * 100, + isIndeterminate: isConnected && (loadingModelsCount > 0 || event.percentage === null || event.percentage === 0), + })); } - if (!queueStatus?.queue.in_progress) { - return false; + // Fallback single bar: idle, or generation has started but no progress event has arrived yet (e.g. + // while models are loading). Mirrors the previous single-bar indeterminate behavior. + let isIndeterminate = false; + if (isConnected && (loadingModelsCount > 0 || Boolean(queueStatus?.queue.in_progress))) { + isIndeterminate = true; } - - if (!lastProgressEvent) { - return true; - } - - if (lastProgressEvent.percentage === null) { - return true; - } - - if (lastProgressEvent.percentage === 0) { - return true; - } - - return false; - }, [isConnected, lastProgressEvent, queueStatus?.queue.in_progress, loadingModelsCount]); + return [{ key: 'idle', value: 0, isIndeterminate }]; + }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); return ( - + + {bars.map((bar) => ( + + ))} + ); }; diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index 80f851ab7af..a53e0c3c4cb 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -34,7 +34,11 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( - + )}
); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index 440847d7451..f96381511fc 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -37,7 +37,11 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {t(props.params.i18nKey)} {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 1d997caaf78..180babf8191 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -32,7 +32,11 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( - + )} ); diff --git a/invokeai/frontend/web/src/services/events/setEventListeners.tsx b/invokeai/frontend/web/src/services/events/setEventListeners.tsx index 1e73abb2027..fa8e2895ba3 100644 --- a/invokeai/frontend/web/src/services/events/setEventListeners.tsx +++ b/invokeai/frontend/web/src/services/events/setEventListeners.tsx @@ -53,7 +53,13 @@ import type { ClientToServerEvents, ServerToClientEvents } from 'services/events import type { Socket } from 'socket.io-client'; import type { JsonObject } from 'type-fest'; -import { $lastProgressEvent, $loadingModelsCount } from './stores'; +import { + $lastProgressEvent, + $loadingModelsCount, + clearAllProgressEvents, + clearProgressEvent, + setProgressEvent, +} from './stores'; const log = logger('events'); @@ -84,6 +90,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.emit('subscribe_queue', { queue_id: 'default' }); socket.emit('subscribe_bulk_download', { bulk_download_id: 'default' }); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); }); @@ -91,6 +98,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.debug('Connect error'); setIsConnected(false); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); if (error && error.message) { const data: string | undefined = (error as unknown as { data: string | undefined }).data; @@ -108,6 +116,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis socket.on('disconnect', () => { log.debug('Disconnected'); $lastProgressEvent.set(null); + clearAllProgressEvents(); $loadingModelsCount.set(0); setIsConnected(false); }); @@ -148,6 +157,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis log.trace({ data } as JsonObject, _message); $lastProgressEvent.set(data); + setProgressEvent(data); if (origin === 'workflows') { const nes = $nodeExecutionStates.get()[invocation_source_id]; @@ -491,11 +501,14 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis } // If the queue item is completed, failed, or cancelled, we want to clear the last progress event $lastProgressEvent.set(null); + // Also remove this session's per-item progress so its stacked progress bar disappears. + clearProgressEvent(item_id); } }); socket.on('queue_cleared', (data) => { log.debug({ data }, 'Queue cleared'); + clearAllProgressEvents(); dispatch( queueApi.util.invalidateTags([ 'SessionQueueStatus', diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 180f4a3a636..95c88bc28cc 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -1,5 +1,5 @@ import { round } from 'es-toolkit/compat'; -import { atom, computed } from 'nanostores'; +import { atom, computed, map } from 'nanostores'; import type { S } from 'services/api/types'; import type { AppSocket } from 'services/events/types'; @@ -8,6 +8,33 @@ export const $isConnected = atom(false); export const $lastProgressEvent = atom(null); export const $loadingModelsCount = atom(0); +/** + * Live progress events keyed by queue item id. Unlike `$lastProgressEvent` (a single global value that + * is overwritten by whichever session reported last), this tracks each in-flight session separately so + * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress + * events arrive and removed when the session reaches a terminal state. + */ +export const $progressEvents = map>({}); + +/** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ +export const $activeProgressEvents = computed($progressEvents, (events) => + Object.values(events) + .filter((event): event is S['InvocationProgressEvent'] => event !== undefined) + .sort((a, b) => a.item_id - b.item_id) +); + +export const setProgressEvent = (event: S['InvocationProgressEvent']) => { + $progressEvents.setKey(event.item_id, event); +}; + +export const clearProgressEvent = (itemId: number) => { + $progressEvents.setKey(itemId, undefined); +}; + +export const clearAllProgressEvents = () => { + $progressEvents.set({}); +}; + export const $lastProgressMessage = computed($lastProgressEvent, (val) => { if (!val) { return null; From 351758c60643365e85d302b4b48d778341a83224 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:13:35 -0400 Subject: [PATCH 09/14] fix(ui): make $progressEvents module-local to satisfy knip $progressEvents is only referenced within stores.ts (via the $activeProgressEvents computed and the set/clear helpers), so exporting it tripped knip's unused-exports check. Co-Authored-By: Claude Opus 4.8 --- invokeai/frontend/web/src/services/events/stores.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/services/events/stores.ts b/invokeai/frontend/web/src/services/events/stores.ts index 95c88bc28cc..7c7630e2019 100644 --- a/invokeai/frontend/web/src/services/events/stores.ts +++ b/invokeai/frontend/web/src/services/events/stores.ts @@ -14,7 +14,7 @@ export const $loadingModelsCount = atom(0); * the UI can render one progress bar per concurrent session (multi-GPU). Entries are added as progress * events arrive and removed when the session reaches a terminal state. */ -export const $progressEvents = map>({}); +const $progressEvents = map>({}); /** In-flight sessions sorted by queue item id, for a stable top-to-bottom bar order. */ export const $activeProgressEvents = computed($progressEvents, (events) => From 4f6613f0002f38401b4cba4f5539d8864b6af44d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:23:44 -0400 Subject: [PATCH 10/14] fix(ui): cap stacked tab progress bars to fit below the tab label With 4 GPUs the stacked per-session progress bars grew past the bottom strip of the dockview tab and overlapped the "Viewer" label. Add a fitHeightPx prop: in fit mode the stack is capped to the available strip (10px below the ~40px tab's centered label) and the bars flex to share it, shrinking below their natural height only once they no longer fit. With 1-2 sessions the bars keep their familiar thin height; with 3+ they scale down to stay within the strip. The sidebar bar is unaffected and continues to stack at natural height (it has the vertical room). Co-Authored-By: Claude Opus 4.8 --- .../system/components/ProgressBar.tsx | 29 +++++++++++++++++-- .../ui/layouts/DockviewTabCanvasViewer.tsx | 2 +- .../ui/layouts/DockviewTabCanvasWorkspace.tsx | 2 +- .../ui/layouts/DockviewTabProgress.tsx | 2 +- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx index 6a305416bbf..a38b2603625 100644 --- a/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx +++ b/invokeai/frontend/web/src/features/system/components/ProgressBar.tsx @@ -6,9 +6,20 @@ import { useTranslation } from 'react-i18next'; import { useGetQueueStatusQuery } from 'services/api/endpoints/queue'; import { $activeProgressEvents, $isConnected, $loadingModelsCount } from 'services/events/stores'; +// In "fit" mode (e.g. the strip below a dockview tab label) the stack is constrained to a fixed height. +// Bars stay at FIT_BAR_HEIGHT_PX while they fit, then shrink to share the available space so they never +// overlap the label, no matter how many sessions are running. +const FIT_BAR_HEIGHT_PX = 4; +const FIT_BAR_GAP_PX = 1; + type ProgressBarProps = ProgressProps & { /** Applied to the Flex that stacks the per-session bars. Use for positioning (e.g. absolute). */ containerProps?: FlexProps; + /** + * When set, the stacked bars are constrained to this total height (in px) and shrink to share it, so + * they never grow past the available space (e.g. the strip below a dockview tab label). + */ + fitHeightPx?: number; }; type BarDescriptor = { @@ -17,7 +28,7 @@ type BarDescriptor = { isIndeterminate: boolean; }; -const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { +const ProgressBar = ({ containerProps, fitHeightPx, ...props }: ProgressBarProps) => { const { t } = useTranslation(); const { data: queueStatus } = useGetQueueStatusQuery(); const isConnected = useStore($isConnected); @@ -44,8 +55,21 @@ const ProgressBar = ({ containerProps, ...props }: ProgressBarProps) => { return [{ key: 'idle', value: 0, isIndeterminate }]; }, [activeProgressEvents, isConnected, loadingModelsCount, queueStatus?.queue.in_progress]); + // In fit mode, cap the whole stack to the available strip and let the bars flex to share it. When the + // bars fit at their natural height the stack is shorter than the cap; once they don't, they shrink. + const isFit = fitHeightPx !== undefined; + const fitContainerProps = useMemo(() => { + if (!isFit) { + return undefined; + } + const naturalHeight = bars.length * FIT_BAR_HEIGHT_PX + Math.max(0, bars.length - 1) * FIT_BAR_GAP_PX; + return { h: `${Math.min(naturalHeight, fitHeightPx)}px`, gap: `${FIT_BAR_GAP_PX}px` }; + }, [bars.length, fitHeightPx, isFit]); + + const fitBarProps: ProgressProps | undefined = isFit ? { flex: '1 1 0', minH: 0, h: 'auto' } : undefined; + return ( - + {bars.map((bar) => ( { w="full" colorScheme="invokeBlue" {...props} + {...fitBarProps} /> ))} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx index a53e0c3c4cb..62246faa0f8 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasViewer.tsx @@ -35,8 +35,8 @@ export const DockviewTabCanvasViewer = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === 'canvas' && isGenerationInProgress && ( )} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx index f96381511fc..285afa3a1b6 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabCanvasWorkspace.tsx @@ -38,8 +38,8 @@ export const DockviewTabCanvasWorkspace = memo((props: IDockviewPanelHeaderProps {currentQueueItemDestination === canvasSessionId && isGenerationInProgress && ( )} diff --git a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx index 180babf8191..c89f682e66a 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/DockviewTabProgress.tsx @@ -33,8 +33,8 @@ export const DockviewTabProgress = memo((props: IDockviewPanelHeaderProps {isGenerationInProgress && ( )} From 2a65c4aa6e381d104bfee69bcd89ecf185d41f51 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 17:34:26 -0400 Subject: [PATCH 11/14] feat(config): support "auto" generation_devices to use all GPUs by default MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit generation_devices now accepts "auto" (the new default), which expands to every visible CUDA device — so multi-GPU parallel generation works out of the box without manually listing devices. On GPU-less systems "auto" resolves to the single cpu/mps device, preserving serial behavior. - config_default.py: type is now Union[Literal["auto"], list[str]], default "auto"; validator accepts "auto" or a list of device strings. - devices.py: add TorchDevice.get_generation_devices(), the single resolver that expands "auto", normalizes, and deduplicates. - session_processor / model_manager: both consumers use the resolver instead of iterating the raw config value (which would have iterated the characters of the "auto" string). - Regenerated docs/src/generated/settings.json. - Tests for the resolver (auto-with/without-CUDA, dedup, empty). An explicit single-device list (e.g. [cuda:0]) or an empty list opts out of parallelism. Co-Authored-By: Claude Opus 4.8 --- docs/src/generated/settings.json | 6 +-- .../app/services/config/config_default.py | 10 ++--- .../model_manager/model_manager_default.py | 9 ++-- .../session_processor_default.py | 19 ++++----- invokeai/backend/util/devices.py | 28 +++++++++++++ tests/backend/util/test_devices.py | 41 +++++++++++++++++++ 6 files changed, 88 insertions(+), 25 deletions(-) diff --git a/docs/src/generated/settings.json b/docs/src/generated/settings.json index eb26d39960f..35cea553b96 100644 --- a/docs/src/generated/settings.json +++ b/docs/src/generated/settings.json @@ -492,13 +492,13 @@ }, { "category": "DEVICE", - "default": null, - "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", + "default": "auto", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", "env_var": "INVOKEAI_GENERATION_DEVICES", "literal_values": [], "name": "generation_devices", "required": false, - "type": "typing.Optional[list[str]]", + "type": "typing.Union[typing.Literal['auto'], list[str]]", "validation": {} }, { diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index a70f5f7e97c..4d9755654a3 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -11,7 +11,7 @@ import shutil from functools import lru_cache from pathlib import Path -from typing import Any, Literal, Optional +from typing import Any, Literal, Optional, Union import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator @@ -205,7 +205,7 @@ class InvokeAIAppConfig(BaseSettings): # DEVICE device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$") - generation_devices: Optional[list[str]] = Field(default=None, description="List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)") + generation_devices: Union[Literal["auto"], list[str]] = Field(default="auto", description="Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)") precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.") # GENERATION @@ -260,14 +260,14 @@ class InvokeAIAppConfig(BaseSettings): @field_validator("generation_devices") @classmethod - def validate_generation_devices(cls, v: Optional[list[str]]) -> Optional[list[str]]: - if v is None: + def validate_generation_devices(cls, v: Union[str, list[str]]) -> Union[str, list[str]]: + if v == "auto": return v pattern = re.compile(r"^(cpu|mps|cuda(:\d+)?)$") for device in v: if not pattern.match(device): raise ValueError( - f"Invalid generation device '{device}'. Valid values are 'cpu', 'mps', 'cuda', or 'cuda:N'." + f"Invalid generation device '{device}'. Valid values are 'auto', 'cpu', 'mps', 'cuda', or 'cuda:N'." ) return v diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index eaeb5d4e612..b7680524a34 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -108,11 +108,10 @@ def build_cache(device: torch.device) -> ModelCache: # worker is pinned to a device (see TorchDevice.set_session_device) and resolves to its own # cache. The default cache is always included by ModelLoadService. ram_caches: dict[str, ModelCache] = {str(TorchDevice.normalize(default_device)): ram_cache} - if app_config.generation_devices: - for device_str in app_config.generation_devices: - key = str(TorchDevice.normalize(device_str)) - if key not in ram_caches: - ram_caches[key] = build_cache(torch.device(key)) + for device in TorchDevice.get_generation_devices(app_config.generation_devices): + key = str(device) + if key not in ram_caches: + ram_caches[key] = build_cache(device) loader = ModelLoadService( app_config=app_config, diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index c6d566255b2..c6edb5069f8 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -347,21 +347,16 @@ def __init__( def _resolve_devices(self) -> list[Optional[torch.device]]: """Determine the per-worker devices from config. - Returns a single `None` (legacy single-worker, device chosen by the global config) unless - `generation_devices` is configured, in which case it returns one normalized device per - listed device (deduplicated, order preserved). + Resolves `generation_devices` (which defaults to `"auto"` — every available GPU) into one + normalized device per worker. Returns a single `None` (legacy single-worker, device chosen by + the global config) only if the resolution is empty (e.g. `generation_devices` set to an empty + list). """ generation_devices = self._invoker.services.configuration.generation_devices - if not generation_devices: + devices = TorchDevice.get_generation_devices(generation_devices) + if not devices: return [None] - devices: list[Optional[torch.device]] = [] - seen: set[str] = set() - for device_str in generation_devices: - device = TorchDevice.normalize(device_str) - if str(device) not in seen: - seen.add(str(device)) - devices.append(device) - return devices + return list(devices) def _clone_session_runner(self, template: SessionRunnerBase) -> SessionRunnerBase: """Create an independent runner for an additional worker. diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index d912f86a8a3..0511601b557 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -120,6 +120,34 @@ def get_torch_device_name(cls) -> str: device = cls.choose_torch_device() return torch.cuda.get_device_name(device) if device.type == "cuda" else device.type.upper() + @classmethod + def get_generation_devices(cls, generation_devices: Union[str, list[str], None]) -> list[torch.device]: + """Resolve the configured `generation_devices` into a concrete, deduplicated device list. + + - ``"auto"`` (the default) expands to every visible CUDA device, or the single best available + device (mps/cpu) when CUDA is unavailable. + - An explicit list is normalized and deduplicated, with order preserved. + - ``None`` or an empty list yields an empty list; the caller decides the single-device fallback. + """ + if generation_devices == "auto": + if torch.cuda.is_available(): + device_strs: list[str] = [f"cuda:{index}" for index in range(torch.cuda.device_count())] + else: + device_strs = [str(cls.choose_torch_device())] + elif not generation_devices: + return [] + else: + device_strs = list(generation_devices) + + devices: list[torch.device] = [] + seen: set[str] = set() + for device_str in device_strs: + device = cls.normalize(device_str) + if str(device) not in seen: + seen.add(str(device)) + devices.append(device) + return devices + @classmethod def normalize(cls, device: Union[str, torch.device]) -> torch.device: """Add the device index to CUDA devices.""" diff --git a/tests/backend/util/test_devices.py b/tests/backend/util/test_devices.py index 39dee5cb618..aa8433c632e 100644 --- a/tests/backend/util/test_devices.py +++ b/tests/backend/util/test_devices.py @@ -69,6 +69,47 @@ def worker(name: str, device: str): assert TorchDevice.choose_torch_device() == torch.device("cpu") +# ===== generation_devices resolution (config -> concrete device list) ======= + + +def test_get_generation_devices_auto_expands_to_all_cuda(): + """`auto` enumerates every visible CUDA device.""" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=True), + patch("invokeai.backend.util.devices.torch.cuda.device_count", return_value=3), + ): + assert TorchDevice.get_generation_devices("auto") == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + + +def test_get_generation_devices_auto_without_cuda(): + """`auto` falls back to the single best device when CUDA is unavailable.""" + config = get_config() + config.device = "cpu" + with ( + patch("invokeai.backend.util.devices.torch.cuda.is_available", return_value=False), + patch("invokeai.backend.util.devices.torch.backends.mps.is_available", return_value=False), + ): + assert TorchDevice.get_generation_devices("auto") == [torch.device("cpu")] + + +def test_get_generation_devices_explicit_list_is_deduplicated(): + """An explicit list is normalized and deduplicated, preserving order.""" + assert TorchDevice.get_generation_devices(["cuda:0", "cuda:0", "cuda:1"]) == [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + +@pytest.mark.parametrize("value", [None, []]) +def test_get_generation_devices_empty(value): + """`None` or an empty list resolves to an empty list (caller handles the single-device fallback).""" + assert TorchDevice.get_generation_devices(value) == [] + + @pytest.mark.parametrize("device_dtype_pair", device_types_cpu) def test_device_dtype_cpu(device_dtype_pair): with ( From 420978083c9889067aa67902c6f2e9ba86609ca0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:26:18 -0400 Subject: [PATCH 12/14] chore(frontend): typegen+openapi --- invokeai/frontend/web/openapi.json | 15 --------------- invokeai/frontend/web/src/services/api/schema.ts | 5 ----- 2 files changed, 20 deletions(-) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 852dc866bce..1287ee58865 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -41151,21 +41151,6 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, - "generation_devices": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Generation Devices", - "description": "List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)" - }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index ace2e30178e..a80183476bd 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,11 +16506,6 @@ export type components = { * @default auto */ device?: string; - /** - * Generation Devices - * @description List of devices to use for parallel generation, e.g. `[cuda:0, cuda:1]`. When set, the session processor runs one generation session per listed device concurrently, distributing jobs fairly across users. When unset (the default), generation runs serially on the single `device`.
Valid values for each entry: `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number) - */ - generation_devices?: string[] | null; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system. From 914c577679f217c8f70abe3a4c241bc5d61eca58 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:43:11 -0400 Subject: [PATCH 13/14] docs(multi-gpu): add configuration information --- .../docs/configuration/invokeai-yaml.mdx | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/src/content/docs/configuration/invokeai-yaml.mdx b/docs/src/content/docs/configuration/invokeai-yaml.mdx index 987c8eb98a2..6ac56053928 100644 --- a/docs/src/content/docs/configuration/invokeai-yaml.mdx +++ b/docs/src/content/docs/configuration/invokeai-yaml.mdx @@ -114,6 +114,39 @@ Most common algorithms are supported, like `md5`, `sha256`, and `sha512`. These These options set the paths of various directories and files used by InvokeAI. Any user-defined paths should be absolute paths. +#### Multi-GPU Generation + +On a machine with more than one GPU, InvokeAI can run several generation sessions at the same time — one per GPU — instead of processing the queue one job at a time. Jobs are distributed fairly across users, so a single user's large batch cannot monopolize every GPU while others wait. + +This is controlled by the `generation_devices` setting: + +```yaml +generation_devices: auto # default value +``` + +| Value | Behavior | +| -------------------------- | ----------------------------------------------------------------------------------------------------------------------- | +| `auto` | Use every available CUDA GPU, running one generation session per GPU concurrently. This is the default. | +| `[cuda:0,cuda:1]` | Use the specific devices listed, one session per device. Useful for reserving a GPU for other work. | +| `[cuda:0]` | Use a single specific device. Generation runs serially, as it did before multi-GPU support. | +| `[]` | Use the first detected device. Generation runs serially, as it did before multi-GPU support. | + +Each entry in the list must be one of `cpu`, `cuda`, `mps`, or `cuda:N`, where `N` is a zero-based device number (`cuda:0` is the first GPU, `cuda:1` the second, and so on). + +```yaml +# Use the first and third GPUs, leaving the second free for other tasks +generation_devices: [cuda:0, cuda:2] +``` + +Notes: + +- On a system without a CUDA GPU, `auto` resolves to the single best available device (`mps` on Apple Silicon, otherwise `cpu`), so generation runs serially. +- Each active GPU gets its own model cache, and model weights are duplicated in system RAM for every device. Running many GPUs in parallel therefore increases RAM usage — ensure you have ample system memory before enabling a large device list. +- Duplicate entries are ignored; `[cuda:0, cuda:0]` is treated as `[cuda:0]`. +- You can restrict which physical GPUs InvokeAI sees with the `CUDA_VISIBLE_DEVICES` environment variable. When set, `auto` only enumerates the visible subset, and `cuda:N` indices refer to positions within that subset. + +During parallel generation, the progress display shows one progress bar per active session, stacked vertically, each disappearing as its session completes. + #### Image Subfolder Strategy By default, generated images are stored in a single flat directory under `outputs/images/`. The `image_subfolder_strategy` setting lets you organize newly-created images into subfolders automatically. You can edit this setting in `invokeai.yaml` or, as an admin user, in the Settings panel. From a928a756f90902e65e00904062256eb761104d3a Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 2 Jun 2026 21:55:25 -0400 Subject: [PATCH 14/14] chore(frontend): typegen + openapi again --- invokeai/frontend/web/openapi.json | 17 +++++++++++++++++ .../frontend/web/src/services/api/schema.ts | 6 ++++++ 2 files changed, 23 insertions(+) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index 1287ee58865..dbbe40d5ad4 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -41151,6 +41151,23 @@ "description": "Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", "default": "auto" }, + "generation_devices": { + "anyOf": [ + { + "type": "string", + "const": "auto" + }, + { + "items": { + "type": "string" + }, + "type": "array" + } + ], + "title": "Generation Devices", + "description": "Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number)", + "default": "auto" + }, "precision": { "type": "string", "enum": ["auto", "float16", "bfloat16", "float32"], diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index a80183476bd..91393f53a9c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -16506,6 +16506,12 @@ export type components = { * @default auto */ device?: string; + /** + * Generation Devices + * @description Devices to use for parallel generation. `auto` (the default) uses every available GPU, running one generation session per GPU concurrently and distributing jobs fairly across users. Provide an explicit list (e.g. `[cuda:0, cuda:1]`) to use specific devices, or a single-device list (e.g. `[cuda:0]`) to run serially. On systems without a GPU, `auto` resolves to the single `cpu`/`mps` device.
Valid values: `auto`, or a list whose entries are each `cpu`, `cuda`, `mps`, or `cuda:N` (where N is a device number) + * @default auto + */ + generation_devices?: "auto" | string[]; /** * Precision * @description Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.