
fix(ernie-image): pass attn_mask=None when text is unpadded #13802

Open: Ace3Z wants to merge 1 commit into huggingface:main from Ace3Z:fix/ernie-flash-attention

Ace3Z commented May 24, 2026

What does this PR do?

Fixes #13801.

If you call pipeline.transformer.set_attention_backend("flash") on an ErnieImagePipeline, it crashes:

ValueError: attn_mask is not supported for flash-attn 2.

The same code works fine on ZImagePipeline and Flux2KleinPipeline.

Root cause

ErnieImageTransformer2DModel.forward always builds a boolean attention mask from text_lens, even when every sample already has the full text length. In the common single-prompt case, _pad_text sets Tmax = lens.max(), so every position is valid and the mask is all True. flash-attn 2 rejects any non-None attn_mask, so the call fails before any attention runs.
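To make the failure mode concrete, here is a toy sketch of a dispatcher that rejects a mask based only on its presence. The function name dispatch_attention and the backend strings are hypothetical stand-ins, not the diffusers API; only the error message is taken from the traceback above.

```python
from typing import Optional, Sequence

Mask = Sequence[Sequence[bool]]


def dispatch_attention(backend: str, attn_mask: Optional[Mask]) -> str:
    """Toy stand-in for a backend dispatcher (hypothetical, not the diffusers API)."""
    if backend == "flash" and attn_mask is not None:
        # The check only asks whether a mask was passed; it never inspects
        # the contents, so an all-True mask is rejected like any other.
        raise ValueError("attn_mask is not supported for flash-attn 2.")
    return f"ran {backend} attention"


# Two samples, 16 valid text tokens each: nothing is actually masked.
all_true_mask = [[True] * 16 for _ in range(2)]

try:
    dispatch_attention("flash", all_true_mask)
    flash_error = None
except ValueError as exc:
    flash_error = str(exc)
```

In this sketch the same all-True mask is accepted by a non-flash backend name and rejected by "flash", which is the asymmetry the issue reports.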

Z-Image already handles this. In _prepare_for_attention (transformer_z_image.py:792-797) it does:

if all(seq == max_seqlen for seq in item_seqlens):
    attn_mask = None
else:
    attn_mask = ...  # build it

Fix

Do the same thing in Ernie's forward. If at least one sample is padded, build the mask. Otherwise pass None.
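The rule can be sketched without any framework as follows. This is a minimal illustration using plain Python lists, not the actual change to transformer_ernie_image.py; the helper name build_text_mask is hypothetical, and the real code presumably works on tensors.

```python
from typing import List, Optional, Sequence


def build_text_mask(
    text_lens: Sequence[int], max_len: int
) -> Optional[List[List[bool]]]:
    """Return None for a uniform batch, else a per-sample validity mask.

    Hypothetical helper for illustration only.
    """
    if all(n == max_len for n in text_lens):
        # Every sample fills the full text length, so no position needs masking.
        return None
    # True marks a real token, False marks padding.
    return [[pos < n for pos in range(max_len)] for n in text_lens]


uniform = build_text_mask([16, 16], max_len=16)  # None: no mask needed
padded = build_text_mask([3, 2], max_len=3)      # second sample has one padded slot
```

Returning None instead of an all-True mask keeps the padded path unchanged while letting mask-free backends handle the uniform case.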

The previous all-True bool mask was a no-op on sdpa, cudnn and native paths (no positions get masked, no -inf rows), so behavior on those backends doesn't change. I verified that numerically: with the same seed and inputs, baseline and fix produce bit-identical output in both the uniform (text_lens=[16,16]) and padded (text_lens=[16,12]) cases. Norm, mean, std, and per-element samples match to 7 decimal places.
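The no-op claim can be checked on a single row of attention scores. The sketch below is a plain-Python softmax, not the sdpa kernel, and the score values are made up for illustration: masked positions are set to -inf before the softmax, so a mask with no False entries changes nothing.

```python
import math
from typing import List, Optional, Sequence


def attention_weights(
    scores: Sequence[float], mask: Optional[Sequence[bool]] = None
) -> List[float]:
    """Softmax over one row of scores; masked-out positions get -inf first."""
    if mask is not None:
        scores = [s if keep else -math.inf for s, keep in zip(scores, mask)]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.5, -1.2, 2.0, 0.1]  # illustrative values only
unmasked = attention_weights(scores)
all_true = attention_weights(scores, [True] * 4)
padded = attention_weights(scores, [True, True, True, False])
```

Here unmasked and all_true are identical element for element, while the padded case assigns zero weight to the last position and renormalizes the rest.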

flash-attn and its variants are now usable when the batch is uniform.

Test

Wires the existing AttentionBackendTesterMixin into the Ernie test class (mirrors the Flux pattern). Native_cudnn backend tests pass; flash variants skip cleanly when the kernels package isn't installed, and will catch the regression when it is.

Before submitting

  • This PR fixes a typo or improves the docs.
  • Read the contributor guideline.
  • Read the philosophy doc.
  • Discussed via GitHub issue (Incompatibility between FlashAttention and ERNIE Image #13801).
  • Documentation: mask construction is internal to the forward and not referenced in docs/source/en/.
  • New tests: wired in AttentionBackendTesterMixin for Ernie.

Who can review?

@yiyixuxu @sayakpaul

@Ace3Z Ace3Z force-pushed the fix/ernie-flash-attention branch from 4bf7bcc to 5f37368 Compare May 25, 2026 08:57
github-actions bot added the size/M (PR with diff < 200 LOC) and size/S (PR with diff < 50 LOC) labels and removed the size/M (PR with diff < 200 LOC) and tests labels May 25, 2026
@Ace3Z Ace3Z force-pushed the fix/ernie-flash-attention branch from 5f37368 to fb49e96 Compare May 25, 2026 09:19
github-actions bot added the tests and size/M (PR with diff < 200 LOC) labels and removed the size/M (PR with diff < 200 LOC) label May 25, 2026
ErnieImageTransformer2DModel.forward built a bool attention mask from
text_lens on every call, including the common case where every sample
already has full-length text. flash-attn 2 rejects any non-None
attn_mask, so set_attention_backend('flash') crashed even though the
all-True mask was effectively a no-op. Z-Image's _prepare_for_attention
takes the same shortcut.

Closes huggingface#13801
@Ace3Z Ace3Z force-pushed the fix/ernie-flash-attention branch from fb49e96 to 84ae4b4 Compare May 25, 2026 10:27
@github-actions github-actions Bot removed the size/M PR with diff < 200 LOC label May 25, 2026
Ace3Z commented May 29, 2026

Friendly ping on this one. @yiyixuxu @sayakpaul, when you have a moment, would you mind taking a look? It's a small 10-line change to transformer_ernie_image.py plus a 5-line test that wires in the existing AttentionBackendTesterMixin, mirroring the Z-Image pattern. Closes #13801.
