From 7f2a8da8b3549e1badfc2ea39cdced0560b79170 Mon Sep 17 00:00:00 2001
From: GitGlimpse895
Date: Tue, 14 Apr 2026 09:57:58 +0530
Subject: [PATCH] hooks/pyramid_attention_broadcast: fix redundant recompute at
 iteration 0 and free stale cache when outside timestep range

---
 src/diffusers/hooks/pyramid_attention_broadcast.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index ed5bd24dea01..0c3e2079a108 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -159,17 +159,21 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         )
         should_compute_attention = (
             self.state.cache is None
-            or self.state.iteration == 0
             or not is_within_timestep_range
             or self.state.iteration % self.block_skip_range == 0
         )
 
         if should_compute_attention:
             output = self.fn_ref.original_forward(*args, **kwargs)
+            # When outside the active timestep window, release the cached tensor
+            # immediately so GPU memory is not held until the next reset_state().
+            if not is_within_timestep_range:
+                self.state.cache = None
+            else:
+                self.state.cache = output
         else:
             output = self.state.cache
-        self.state.cache = output
 
         self.state.iteration += 1
         return output

@@ -177,7 +181,6 @@ def reset_state(self, module: torch.nn.Module) -> None:
         self.state.reset()
         return module
 
-
 def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
     r"""
     Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.