Describe the issue
Apologies in advance for a partly vibe-coded issue: I am optimizing an existing TTS model and do not have deep expertise in ORT's DiT-related code paths.
Description
ONNX Runtime v1.24.4 with Flash Attention compiled for sm_120 (Blackwell/RTX 5090) cannot dispatch Flash Attention kernels for Diffusion Transformer (DiT) models because the attention pattern is not recognized by the attention fusion optimizer.
This affects F5-TTS and likely other diffusion transformer models (Stable Diffusion 3, FLUX, etc.) that use the same attention structure.
Impact
FP16 DiT models require 44 explicit Cast operations (FP16↔FP32) per inference around Softmax for numerical stability. These Casts operate on the full attention matrix [batch, heads, seq, seq] and add ~200ms of pure memory bandwidth overhead per inference call (measured on RTX 5090).
Eliminating these Casts via Flash Attention (which performs the softmax in FP32 internally) would cut inference time by ~42%: pure FP16 attention with 0 Casts runs in 915 ms vs 1574 ms for the FP32-softmax path. Audio quality, however, requires the FP32 softmax that Flash Attention provides, so simply dropping the Casts is not an option.
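A rough back-of-envelope for the Cast traffic, to make the overhead concrete. The batch size (2), head count (16), and Cast count (44) come from the issue; the sequence length is a hypothetical placeholder, since the issue does not state it:

```python
def cast_traffic_bytes(batch, heads, seq, num_casts):
    """Bytes moved by FP16<->FP32 Casts over the [batch, heads, seq, seq]
    attention matrix: each Cast reads one precision and writes the other
    (2 + 4 = 6 bytes per element)."""
    elements = batch * heads * seq * seq
    return num_casts * elements * 6

# batch=2 (CFG), 16 heads, 44 Casts per forward pass; seq=2048 is illustrative.
traffic = cast_traffic_bytes(batch=2, heads=16, seq=2048, num_casts=44)
print(f"{traffic / 1e9:.1f} GB of pure Cast traffic per forward pass")
```

At tens of GB of traffic per forward pass, a purely bandwidth-bound overhead on the order of the measured ~200 ms is plausible once multiple diffusion steps are accounted for.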
Current Behavior
The model's attention pattern in the ONNX graph:
```
MatMul(Q, K^T)         # FP16, Q and K have RoPE applied
  → Cast(FP16 → FP32)
  → Mul(* 100.0)       # attention logit rescaling
  → Softmax            # FP32
  → Cast(FP32 → FP16)
  → MatMul(attn, V)    # FP16
```
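The unfused pattern can be reproduced numerically with a small NumPy sketch. `head_dim=64`, `heads=16`, and `batch=2` match the model; the sequence length is illustrative:

```python
import numpy as np

batch, heads, seq, head_dim = 2, 16, 128, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float16)
# K is stored pre-transposed as [batch, heads, head_dim, seq], as in the graph
k_t = rng.standard_normal((batch, heads, head_dim, seq)).astype(np.float16)
v = rng.standard_normal((batch, heads, seq, head_dim)).astype(np.float16)

logits = q @ k_t                               # MatMul(Q, K^T), FP16
logits = logits.astype(np.float32)             # Cast FP16 -> FP32
logits *= 100.0                                # custom logit rescaling
logits -= logits.max(axis=-1, keepdims=True)   # stable softmax, FP32
p = np.exp(logits)
p /= p.sum(axis=-1, keepdims=True)             # Softmax, FP32
attn = p.astype(np.float16) @ v                # Cast FP32 -> FP16, MatMul(attn, V)
print(attn.shape)  # (2, 16, 128, 64)
```

The two `astype` calls correspond exactly to the two Cast nodes per attention layer that a fused kernel would absorb.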
With GraphOptimizationLevel::All, ORT produces FusedMatMul and Gemm ops but does not fuse the attention pattern into MultiHeadAttention. The Softmax and Cast ops remain unfused. Flash Attention kernels are compiled and available but never dispatched.
Expected Behavior
ORT's attention fusion should recognize this pattern and replace it with a fused attention op (MultiHeadAttention or similar) that dispatches to Flash Attention, handling the FP32 softmax internally.
Why Fusion Fails
The DiT attention pattern differs from BERT/GPT-style attention that ORT's fusion recognizes:
- Q and K have Rotary Position Embeddings (RoPE) applied between projection and attention, so it is not the standard linear projection → attention pattern
- K is pre-transposed to `[batch, heads, head_dim, seq]` (an optimization for the manual `MatMul(Q, K_transposed)`)
- Custom `* 100.0` scaling before softmax (compensates for pre-scaled Q/K weights, not the standard `1/sqrt(d_k)`)
- Batch dimension of 2 for classifier-free guidance (conditional + unconditional)
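The pre-transposed K layout in particular means the exported graph carries no Transpose node between the K projection and the first MatMul, which is one of the anchors BERT-style pattern matchers look for, even though the two forms are numerically identical. A small illustration (shapes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
q = rng.standard_normal((2, 16, 8, 64)).astype(np.float32)
k = rng.standard_normal((2, 16, 8, 64)).astype(np.float32)

# BERT/GPT-style export: an explicit Transpose node, then MatMul
logits_std = q @ k.transpose(0, 1, 3, 2)

# This DiT export: K already materialized as [batch, heads, head_dim, seq]
k_pre = k.transpose(0, 1, 3, 2)
logits_dit = q @ k_pre

assert np.array_equal(logits_std, logits_dit)  # same math, different graph
```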
Environment
- ORT: v1.24.4 (built from source)
- GPU: RTX 5090 (sm_120, Blackwell)
- CUDA: 13.2
- cuDNN: 9.20
- Build flags: `CMAKE_CUDA_ARCHITECTURES=120`, `onnxruntime_USE_FLASH_ATTENTION=ON`
- Model: F5-TTS DiT transformer, 22 layers, 16 heads, head_dim=64, FP16
Suggested Solution
Either:
- Extend the attention fusion pattern matcher to recognize DiT-style attention (post-RoPE Q/K, custom scaling, pre-transposed K)
- Add a "raw attention" fused op that accepts pre-computed, head-split Q/K/V tensors (after RoPE) and performs fused matmul+softmax+matmul — dispatching to Flash Attention internally
- Support a `scale` parameter on the fused attention op to handle the `* 100.0` (or arbitrary) logit rescaling
Option 2 would be the most broadly useful, as it would work for any model that computes Q/K/V externally (with RoPE, custom projections, etc.) and just needs efficient attention dispatch.
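To make the proposal concrete, here is a sketch of the contract options 2 and 3 together would imply. This is hypothetical reference semantics in NumPy, not an existing ORT API; all names are invented for illustration:

```python
import numpy as np

def raw_attention(q, k_t, v, scale):
    """Hypothetical 'raw attention' fused op: takes pre-computed, head-split
    Q [B,H,S,D], pre-transposed K^T [B,H,D,S], and V [B,H,S,D], runs the
    softmax in FP32 internally, and returns the result in the input dtype."""
    logits = (q.astype(np.float32) @ k_t.astype(np.float32)) * scale
    logits -= logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(logits)
    p /= p.sum(axis=-1, keepdims=True)
    return (p @ v.astype(np.float32)).astype(q.dtype)

# Usage with the shapes and scale from this model (seq is illustrative)
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 16, 128, 64)).astype(np.float16)
k_t = rng.standard_normal((2, 16, 64, 128)).astype(np.float16)
v = rng.standard_normal((2, 16, 128, 64)).astype(np.float16)
out = raw_attention(q, k_t, v, scale=100.0)
print(out.shape, out.dtype)  # (2, 16, 128, 64) float16
```

A Flash Attention dispatch behind this signature would absorb both Casts and the Mul while keeping the FP32 softmax the model needs.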
To reproduce
```python
import onnxruntime as ort

opts = ort.SessionOptions()
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
opts.optimized_model_filepath = "optimized.onnx"
sess = ort.InferenceSession("F5_Transformer.onnx", sess_options=opts,
                            providers=["CUDAExecutionProvider"])

# Inspect the optimized model -- no MultiHeadAttention ops appear
import onnx
m = onnx.load("optimized.onnx")
ops = {n.op_type for n in m.graph.node}
print("MultiHeadAttention" in ops)  # False
print(sum(1 for n in m.graph.node if n.op_type == "Cast"))  # 44 (unchanged)
```
Urgency
No response
Platform
Windows
OS Version
10
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
v1.24.4
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
13.2