Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,15 +1523,18 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas

# Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
if attn_mask is not None:
if (
attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1)

attn_mask = ~attn_mask.to(torch.bool)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()

return attn_mask

Expand Down
10 changes: 8 additions & 2 deletions src/diffusers/models/transformers/transformer_ernie_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,14 @@ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tenso
query, key = query.to(dtype), key.to(dtype)

# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
if attention_mask is not None:
if attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]

if attention_mask.ndim == 4:
# NPU does not support automatic broadcasting for this type; the mask must be expanded.
if attention_mask.device.type == 'npu' and attention_mask.shape[1:3] == (1, 1):
attention_mask = attention_mask.expand(-1, attn.heads, query.shape[1], -1)

# Compute joint attention
hidden_states = dispatch_attention_fn(
Expand Down
Loading