From 2dd19eddc4fa82ce63da8a0d687a59869eaa0887 Mon Sep 17 00:00:00 2001
From: chang-zhijie <609212560@qq.com>
Date: Mon, 13 Apr 2026 10:50:06 +0800
Subject: [PATCH 1/2] Fix attention_mask broadcasting for NPU compatibility

---
 .../models/transformers/transformer_ernie_image.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 09682a218d91..f9951e328ab6 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
         query, key = query.to(dtype), key.to(dtype)

         # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
-        if attention_mask is not None and attention_mask.ndim == 2:
-            attention_mask = attention_mask[:, None, None, :]
+        if attention_mask is not None:
+            if attention_mask.ndim == 2:
+                attention_mask = attention_mask[:, None, None, :]
+
+            if attention_mask.ndim == 4:
+                # NPU does not support automatic broadcasting for this type; the mask must be expanded.
+                if attention_mask.device.type == 'npu' and attention_mask.shape[1:3] == (1, 1):
+                    attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1)

         # Compute joint attention
         hidden_states = dispatch_attention_fn(

From 2623481f13f35aeab2784aff8ad69f1718635671 Mon Sep 17 00:00:00 2001
From: chang-zhijie <609212560@qq.com>
Date: Wed, 15 Apr 2026 17:57:55 +0800
Subject: [PATCH 2/2] Fix _native_npu_attention: add inversion for 4D attn_mask
 and expand when dim2/dim3 == 1

---
 src/diffusers/models/attention_dispatch.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 837d573d8c4d..9443dc4440c3 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1523,15 +1523,18 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas

     # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
     # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if (
+            attn_mask.ndim == 2
+            and attn_mask.shape[0] == query.shape[0]
+            and attn_mask.shape[1] == key.shape[1]
+        ):
+            B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1)
+        attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

     return attn_mask