Skip to content

Optimizations for MXFP8/NVFP4 dequantize kernels#2865

Open
YigongQin wants to merge 6 commits intoNVIDIA:mainfrom
YigongQin:yigongq/bwd-dequantize-optim
Open

Optimizations for MXFP8/NVFP4 dequantize kernels#2865
YigongQin wants to merge 6 commits intoNVIDIA:mainfrom
YigongQin:yigongq/bwd-dequantize-optim

Conversation

@YigongQin
Copy link
Copy Markdown

@YigongQin YigongQin commented Apr 10, 2026

Description

  • Handle empty tensors in dequantize for CUDA graph compatibility
  • Add swizzled scale support to the NVFP4 dequantize kernel, reusing the existing MXFP8 swizzle index computation
  • Add C++ unit tests for both NVFP4 and MXFP8 dequantization (including swizzled scale variants)
  • Fix to_cpu() and set_scale() in test infrastructure to correctly sync amax/scale for NVTE_NVFP4_1D_SCALING mode

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Handle empty tensors in dequantize for CUDA graph compatibility — Early return when input has zero elements, avoiding kernel launches on empty tensors.
  • Add GEMM-swizzled scale support to NVFP4 dequantize kernel — Template the kernel with WITH_GEMM_SWIZZLED_SCALES to support reading scales from swizzled layout, reusing the MXFP8 swizzle index computation.
  • Add GEMM-swizzled scale support to MXFP8 dequantize kernel — Extend the MXFP8 dequantize kernel to handle swizzled scale inputs.
  • Add C++ unit tests for NVFP4 dequantization — 21 tests for compact scales + 21 tests for swizzled scales, covering multiple sizes and output dtypes (fp32, bf16, fp16).
  • Add C++ unit tests for MXFP8 dequantization with swizzled scales — New swizzled test suite for MXFP8.
  • Fix to_cpu() to sync amax/scale for NVFP4 tensors — Previously only synced for NVTE_DELAYED_TENSOR_SCALING, causing the CPU reference to use stale amax=0.
  • Fix set_scale() to work for NVFP4 tensors — Same condition fix, enabling the scale to be properly uploaded to GPU before quantization.
  • Fix swizzled test ordering — Move from_cpu() before the FP4 data copy to prevent from_cpu() from overwriting the copied data with zeros.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
@YigongQin YigongQin force-pushed the yigongq/bwd-dequantize-optim branch from f5e7375 to 39c0fb1 Compare April 10, 2026 22:04
@zianglih
Copy link
Copy Markdown
Contributor

zianglih commented Apr 14, 2026

The following relevant unit tests passed on SM100 (with the drop optimize_for_gemm = False changes):

python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih force-pushed the yigongq/bwd-dequantize-optim branch from ddab15d to 3a4afdd Compare April 14, 2026 18:46
Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih
Copy link
Copy Markdown
Contributor

After this PR, fwd is around 3%-4% faster for DeepSeek shape MoE:

# With the optimization
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.272829
1  32768  7168  2048  mxfp8          4             0.509788
2  65536  7168  2048  mxfp8          4             0.948633
3  98304  7168  2048  mxfp8          4             1.391146
0  16384  7168  2048  mxfp8          8             0.303238
1  32768  7168  2048  mxfp8          8             0.533896
2  65536  7168  2048  mxfp8          8             1.003446
3  98304  7168  2048  mxfp8          8             1.470030

# Without the optimization
git restore --source 77b8681de5cf -- transformer_engine/pytorch/module
NVTE_BACKWARD_OVERRIDE=dequantized python benchmarks/linear/benchmark_grouped_linear.py --recipe mxfp8 --fwd-only
       m     k     n recipe  num_gemms  grouped_fwd_time_ms
0  16384  7168  2048  mxfp8          4             0.282720
1  32768  7168  2048  mxfp8          4             0.526736
2  65536  7168  2048  mxfp8          4             0.982166
3  98304  7168  2048  mxfp8          4             1.451485
0  16384  7168  2048  mxfp8          8             0.313753
1  32768  7168  2048  mxfp8          8             0.551043
2  65536  7168  2048  mxfp8          8             1.040773
3  98304  7168  2048  mxfp8          8             1.527951

@YigongQin YigongQin marked this pull request as ready for review April 15, 2026 16:49
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR adds GEMM-swizzled scale support to the NVFP4 and MXFP8 dequantize kernels (templated on WITH_GEMM_SWIZZLED_SCALES), adds early-return handling for empty tensors in the dequantize dispatcher for CUDA graph compatibility, fixes to_cpu()/set_scale() in the C++ test infrastructure to sync amax/scale for NVTE_NVFP4_1D_SCALING, and removes the now-unnecessary optimize_for_gemm = False override from the Python linear modules.

  • The get_scales() function in tests/cpp/test_common.cu contains an unreachable duplicate NVTE_MXFP8_1D_SCALING block (~line 196) that appears to be a copy-paste artifact from inserting the NVTE_NVFP4_1D_SCALING case; it should be removed.

Confidence Score: 5/5

Safe to merge; all production kernel changes are correct and the only notable issue is unreachable dead code in test infrastructure.

All kernel logic (swizzle index computation, scale stride handling, empty-tensor guard) is correct. The Python optimize_for_gemm removal is consistent with the new kernel capability. The only findings are P2: dead code in a test helper and a redundant function call in a test.

tests/cpp/test_common.cu — duplicate NVTE_MXFP8_1D_SCALING block in get_scales() should be cleaned up.

Important Files Changed

Filename Overview
transformer_engine/common/cast/dispatch/dequantize.cuh Adds early-return guard for empty tensors (numel == 0) enabling CUDA graph compatibility; dispatches correctly to nvfp4::dequantize for NVTE_NVFP4_1D_SCALING.
transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh Adds WITH_GEMM_SWIZZLED_SCALES template parameter to kernel; swizzle index computation reuses mxfp8::swizzle::gemm_swizzled_scale_idx with correct row/col argument ordering for both rowwise and colwise cases.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh Adds WITH_GEMM_SWIZZLED_SCALES template parameter and num_scale_tiles_X argument; swizzle index via gemm_swizzled_scale_idx(y, x, num_scale_tiles_X) is correct for the row-major scale layout (1 scale per 16 FP4 elements).
tests/cpp/test_common.cu to_cpu() and set_scale() correctly extended to sync amax/scale for NVTE_NVFP4_1D_SCALING, but get_scales() contains an unreachable duplicate NVTE_MXFP8_1D_SCALING block (~line 196) that is dead code.
tests/cpp/operator/test_dequantize_nvfp4.cu Comprehensive compact and swizzled-scale dequantize tests with correct empty-tensor guarding, but contains a redundant from_cpu() call after set_scale() (which now calls from_cpu() internally).
tests/cpp/operator/test_dequantize_mxfp8.cu New swizzled-scale test suite mirrors the existing compact test suite; correctly copies data via device-to-device memcpy and calls nvte_swizzle_scaling_factors before dequantizing.
transformer_engine/pytorch/ops/basic/basic_linear.py Removes the backward_override conditional block that was disabling optimize_for_gemm for MXFP8/NVFP4, now that dequantize supports swizzled scales.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_dequantize] --> B{input.numel == 0?}
    B -- yes --> C[Early return CUDA graph safe]
    B -- no --> D{scaling_mode}
    D -- DELAYED_TENSOR_SCALING --> E[fp8::dequantize]
    D -- MXFP8_1D_SCALING --> F{with_gemm_swizzled_scales?}
    D -- NVFP4_1D_SCALING --> G{with_gemm_swizzled_scales?}
    F -- true --> H[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=true]
    F -- false --> I[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=false]
    G -- true --> J[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=true]
    G -- false --> K[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=false]
Loading

Comments Outside Diff (1)

  1. tests/cpp/test_common.cu, line 196-222 (link)

    P2 Unreachable duplicate NVTE_MXFP8_1D_SCALING block

    The NVTE_MXFP8_1D_SCALING check at this line is dead code. The identical block at lines 141–167 already handles this mode and always exits via return {ret_rowwise, ret_colwise}. This second block will never be reached. It appears to be a copy-paste artifact from inserting the NVTE_NVFP4_1D_SCALING case (lines 168–195) between two copies of the MXFP8 block. The second copy should be removed.

Reviews (1): Last reviewed commit: "Drop `optimize_for_gemm` in basic linear" | Re-trigger Greptile

Comment on lines +156 to +159
quantized_compact.to_cpu();
quantized_swizzled.set_amax(quantized_compact.amax());
quantized_swizzled.set_scale(quantized_compact.scale());
quantized_swizzled.from_cpu();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant from_cpu() after set_scale()

set_scale() for NVTE_NVFP4_1D_SCALING now calls from_cpu() internally (test_common.cu line 518). The explicit quantized_swizzled.from_cpu() call on line 159 is therefore redundant and uploads zero-initialised CPU buffers a second time. Consider removing the explicit call, or adding a comment explaining it's intentional.

Suggested change
quantized_compact.to_cpu();
quantized_swizzled.set_amax(quantized_compact.amax());
quantized_swizzled.set_scale(quantized_compact.scale());
quantized_swizzled.from_cpu();
quantized_compact.to_cpu();
quantized_swizzled.set_amax(quantized_compact.amax());
quantized_swizzled.set_scale(quantized_compact.scale());
// set_scale() internally calls from_cpu(); FP4 data is copied after this point.

}
}

std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
Copy link
Copy Markdown
Collaborator

@zhongbozhu zhongbozhu Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one edge case:

For MXFP8, When the input shape is like 64x64, it will produce scaling factor shape 64x2, but then zero padded to 128x4. We should be able to inject some very large random values in the padded region during malloc (because we don't use torch.zeros to malloc but torch.empty), and detect whether dequantize results is affected. If things work as expected, this line will be triggered

// Zero out swizzled scales if padding is needed
and the dequantize numerics won't be affected.

For NVFP4, I think we optimize for GEMM (or swizzle fusion) is actually not enabled, same for the zero-out edge case handling logic?

NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format.");
So there shouldn't be any unswizzle logic needed here?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle fusion feature, so the scaling factor zero-out is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants