From 70234a77ac6c3b8cdd8a80536c53f4eb38ff7ed4 Mon Sep 17 00:00:00 2001
From: GitGlimpse895
Date: Fri, 17 Apr 2026 11:58:50 +0530
Subject: [PATCH] hooks/pyramid_attention_broadcast: remove redundant
 iteration==0 guard and fix stale cache VRAM leak

---
 src/diffusers/hooks/pyramid_attention_broadcast.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index ed5bd24dea01..97dd92a44a77 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -159,7 +159,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         )
         should_compute_attention = (
             self.state.cache is None
-            or self.state.iteration == 0
             or not is_within_timestep_range
             or self.state.iteration % self.block_skip_range == 0
         )
@@ -169,7 +168,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         else:
             output = self.state.cache
 
-        self.state.cache = output
+        self.state.cache = output if is_within_timestep_range else None
         self.state.iteration += 1
         return output
 
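
Note for reviewers: below is a minimal standalone sketch of the caching
rule after this patch (not the diffusers code; `DummyState` and
`pab_step` are invented names for illustration). It shows why the
`iteration == 0` guard is redundant (a fresh state starts with
`cache is None`, which already forces the first computation) and how
resetting the cache to None outside the timestep skip range drops the
reference to the cached attention tensor instead of pinning it in VRAM
for the rest of sampling:

    import torch

    class DummyState:
        def __init__(self):
            self.cache = None    # last computed attention output
            self.iteration = 0   # per-layer step counter

    def pab_step(state, compute_fn, is_within_timestep_range, block_skip_range=2):
        # `state.cache is None` already covers the first call, so a separate
        # `iteration == 0` check adds nothing.
        should_compute = (
            state.cache is None
            or not is_within_timestep_range
            or state.iteration % block_skip_range == 0
        )
        output = compute_fn() if should_compute else state.cache
        # Keep the cache only while broadcasting may still reuse it; clearing
        # it outside the skip range lets the tensor be freed (the VRAM leak
        # this patch targets).
        state.cache = output if is_within_timestep_range else None
        state.iteration += 1
        return output

    if __name__ == "__main__":
        state = DummyState()
        for step in range(6):
            within = 1 <= step <= 3  # pretend PAB is active on steps 1-3
            pab_step(state, lambda: torch.randn(2, 8), within)
            print(step, "cache held" if state.cache is not None else "cache freed")

Running the sketch prints "cache freed" on steps 0, 4, and 5: once the
timestep leaves the skip range, the stale output is no longer retained.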