diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index ed5bd24dea01..97dd92a44a77 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -159,7 +159,6 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         )
         should_compute_attention = (
             self.state.cache is None
-            or self.state.iteration == 0
             or not is_within_timestep_range
             or self.state.iteration % self.block_skip_range == 0
         )
@@ -169,7 +168,7 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
         else:
             output = self.state.cache
 
-        self.state.cache = output
+        self.state.cache = output if is_within_timestep_range else None
         self.state.iteration += 1
         return output
 