
Flash Attention not dispatched for DiT-style attention pattern (diffusion transformers) #27983

@Nickuru


Describe the issue

Sorry for the half-vibe issue. I am working on optimizing an existing TTS model and do not have deep expertise in ORT's handling of DiT models.


Description

ONNX Runtime v1.24.4 with Flash Attention compiled for sm_120 (Blackwell/RTX 5090) cannot dispatch Flash Attention kernels for Diffusion Transformer (DiT) models because the attention pattern is not recognized by the attention fusion optimizer.

This affects F5-TTS and likely other diffusion transformer models (Stable Diffusion 3, FLUX, etc.) that use the same attention structure.

Impact

FP16 DiT models require 44 explicit Cast operations (FP16↔FP32) around their Softmax nodes for numerical stability. These Casts operate on the full attention matrix [batch, heads, seq, seq] and add ~200 ms of pure memory-bandwidth overhead per inference call (measured on an RTX 5090).

Eliminating these Casts via Flash Attention (which handles FP32 softmax internally) would provide a 42% speedup — verified by testing with pure FP16 attention (0 Casts, 915ms vs 1574ms FP32), though audio quality requires the FP32 softmax that Flash Attention provides.
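A rough back-of-the-envelope check of the Cast overhead. The 44 Casts match 22 layers × 2 Casts (one up-cast, one down-cast) per attention block; batch and head counts come from the Environment section below, but the sequence length here is a hypothetical placeholder, not F5-TTS's actual value.

```python
# Estimate the extra memory traffic from the FP16<->FP32 Casts around Softmax.
# batch/heads/layers are from the issue; seq=2048 is a hypothetical value.
batch, heads, seq = 2, 16, 2048
layers, casts_per_layer = 22, 2       # 22 * 2 = 44 Casts, as reported

elems = batch * heads * seq * seq     # elements in the attention matrix
bytes_per_cast = elems * (2 + 4)      # read FP16 (2 B) + write FP32 (4 B),
                                      # or the reverse for the down-cast
total_gb = layers * casts_per_layer * bytes_per_cast / 1e9
print(f"~{total_gb:.1f} GB of extra memory traffic per forward pass")
```

At RTX 5090-class memory bandwidth, tens of gigabytes of pure Cast traffic per forward pass is consistent with the ~200 ms overhead reported above.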

Current Behavior

The model's attention pattern in the ONNX graph:

MatMul(Q, K^T)           # FP16, Q and K have RoPE applied
  → Cast(FP16 → FP32)
  → Mul(* 100.0)          # attention logit rescaling
  → Softmax               # FP32
  → Cast(FP32 → FP16)
  → MatMul(attn, V)       # FP16

With GraphOptimizationLevel::All, ORT produces FusedMatMul and Gemm ops but does not fuse the attention pattern into MultiHeadAttention. The Softmax and Cast ops remain unfused. Flash Attention kernels are compiled and available but never dispatched.
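The pattern above can be written out as a NumPy reference, which makes explicit what a fused kernel would have to reproduce: FP16 matmuls, an FP32 softmax, and the custom *100.0 logit rescaling. Shapes follow the model described below (batch=2, heads=16, head_dim=64); seq=8 is a toy value.

```python
import numpy as np

# NumPy reference of the unfused graph pattern: MatMul -> Cast -> Mul ->
# Softmax -> Cast -> MatMul. K is stored pre-transposed as [b, h, d, s].
b, h, s, d = 2, 16, 8, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((b, h, s, d)).astype(np.float16)
k_t = rng.standard_normal((b, h, d, s)).astype(np.float16)  # K pre-transposed
v = rng.standard_normal((b, h, s, d)).astype(np.float16)

logits = (q @ k_t).astype(np.float32) * 100.0           # Cast up + Mul
probs = np.exp(logits - logits.max(-1, keepdims=True))  # FP32 Softmax
probs /= probs.sum(-1, keepdims=True)
out = probs.astype(np.float16) @ v                      # Cast down + MatMul
```

Flash Attention computes exactly this numerically (FP32 softmax accumulation over FP16 inputs) without ever materializing the [b, h, s, s] matrix, which is why it would remove the Cast overhead.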

Expected Behavior

ORT's attention fusion should recognize this pattern and replace it with a fused attention op (MultiHeadAttention or similar) that dispatches to Flash Attention, handling the FP32 softmax internally.

Why Fusion Fails

The DiT attention pattern differs from BERT/GPT-style attention that ORT's fusion recognizes:

  1. Q and K have Rotary Position Embeddings (RoPE) applied between projection and attention — not a standard linear projection → attention pattern
  2. K is pre-transposed to [batch, heads, head_dim, seq], so the graph computes MatMul(Q, K_transposed) directly with no Transpose node in front of the attention MatMul
  3. Custom * 100.0 scaling before softmax (compensates for pre-scaled Q/K weights, not standard 1/sqrt(d_k))
  4. Batch dimension of 2 for classifier-free guidance (conditional + unconditional)
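Point 1 is the structural one: RoPE inserts elementwise Mul/Add/rotation ops between the projection MatMul and the attention MatMul, so a matcher that expects projection → attention directly cannot match. A minimal sketch of interleaved RoPE (an assumption; F5-TTS's exact variant may differ) illustrates what sits in between:

```python
import numpy as np

# Minimal interleaved rotary position embedding, applied to Q/K after the
# projection and before the attention MatMul. This is a generic sketch,
# not F5-TTS's exact implementation.
def rope(x, base=10000.0):
    s, d = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    angles = np.arange(s)[:, None] * inv_freq[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin   # pairwise 2-D rotation
    out[..., 1::2] = x1 * sin + x2 * cos
    return out
```

Because RoPE is a pure rotation, it preserves vector norms; it changes only the graph structure, not the numerical properties the fused kernel relies on.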

Environment

  • ORT: v1.24.4 (built from source)
  • GPU: RTX 5090 (sm_120, Blackwell)
  • CUDA: 13.2
  • cuDNN: 9.20
  • Build flags: CMAKE_CUDA_ARCHITECTURES=120, onnxruntime_USE_FLASH_ATTENTION=ON
  • Model: F5-TTS DiT transformer, 22 layers, 16 heads, head_dim=64, FP16

Suggested Solution

Either:

  1. Extend the attention fusion pattern matcher to recognize DiT-style attention (post-RoPE Q/K, custom scaling, pre-transposed K)
  2. Add a "raw attention" fused op that accepts pre-computed, head-split Q/K/V tensors (after RoPE) and performs fused matmul+softmax+matmul — dispatching to Flash Attention internally
  3. Support a scale parameter on the fused attention op to handle the * 100.0 (or arbitrary) logit rescaling

Option 2 would be the most broadly useful, as it would work for any model that computes Q/K/V externally (with RoPE, custom projections, etc.) and just needs efficient attention dispatch.

To reproduce


import onnxruntime as ort

opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.optimized_model_filepath = "optimized.onnx"

sess = ort.InferenceSession("F5_Transformer.onnx", sess_options=opts,
                            providers=["CUDAExecutionProvider"])

# Check optimized model — no MultiHeadAttention ops
import onnx
m = onnx.load("optimized.onnx")
ops = {n.op_type for n in m.graph.node}
print("MultiHeadAttention" in ops)  # False
print(sum(1 for n in m.graph.node if n.op_type == "Cast"))  # 44 (unchanged)

Urgency

No response

Platform

Windows

OS Version

10

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

v1.24.4

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

13.2


    Labels

    model:transformer (issues related to a transformer model: BERT, GPT2, Hugging Face, Longformer, T5, etc.)
