Optimizations for MXFP8/NVFP4 dequantize kernels #2865
YigongQin wants to merge 6 commits into NVIDIA:main
Conversation
Signed-off-by: YigongQin <qqqyyy1233@outlook.com>
Force-pushed f5e7375 to 39c0fb1
The following relevant unit tests passed on SM100 (with the drop …
Signed-off-by: Ziang Li <ziangli@umich.edu>
Force-pushed ddab15d to 3a4afdd
After this PR, fwd is around 3%-4% faster for DeepSeek-shape MoE:
Greptile Summary

This PR adds GEMM-swizzled scale support to the NVFP4 and MXFP8 dequantize kernels (templated on WITH_GEMM_SWIZZLED_SCALES).

Confidence Score: 5/5. Safe to merge; all production kernel changes are correct, and the only notable issue is unreachable dead code in test infrastructure. All kernel logic (swizzle index computation, scale stride handling, empty-tensor guard) is correct. The Python optimize_for_gemm removal is consistent with the new kernel capability. The only findings are P2: dead code in a test helper and a redundant function call in a test. tests/cpp/test_common.cu — the duplicate NVTE_MXFP8_1D_SCALING block in get_scales() should be cleaned up.
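For context on what "GEMM-swizzled scales" means here: the PR text does not spell out the layout, but a common convention (an assumption on my part, not taken from this PR's diff) is the 128x4 scale-factor tile layout used for Blackwell block-scaled GEMMs, where entry (row, col) of a tile lands at offset (row % 32) * 16 + (row / 32) * 4 + col. A quick host-side sanity check that such a mapping permutes one tile:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical sketch (not the PR's actual kernel code): swizzled offset of
// scale (row, col) within one 128x4 scale-factor tile, following the layout
// commonly documented for Blackwell block-scaled GEMMs.
size_t swizzled_offset(size_t row, size_t col) {
  return (row % 32) * 16 + (row / 32) * 4 + col;  // 0 .. 511
}

// The mapping must visit each of the 128*4 = 512 tile slots exactly once;
// otherwise a dequantize kernel indexing through it would read wrong scales.
bool tile_mapping_is_bijective() {
  std::vector<int> hits(128 * 4, 0);
  for (size_t r = 0; r < 128; ++r)
    for (size_t c = 0; c < 4; ++c)
      hits[swizzled_offset(r, c)]++;
  for (int h : hits)
    if (h != 1) return false;
  return true;
}
```

The bijection property is what makes "swizzle index computation" in the summary above checkable in isolation, independent of the surrounding kernel.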
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_dequantize] --> B{input.numel == 0?}
B -- yes --> C[Early return CUDA graph safe]
B -- no --> D{scaling_mode}
D -- DELAYED_TENSOR_SCALING --> E[fp8::dequantize]
D -- MXFP8_1D_SCALING --> F{with_gemm_swizzled_scales?}
D -- NVFP4_1D_SCALING --> G{with_gemm_swizzled_scales?}
F -- true --> H[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=true]
F -- false --> I[dequantize_mxfp8_kernel WITH_GEMM_SWIZZLED_SCALES=false]
G -- true --> J[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=true]
G -- false --> K[dequantize_fp4_kernel WITH_GEMM_SWIZZLED_SCALES=false]
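The flowchart above can be sketched as plain control flow. This is an illustrative model, not the PR's actual dispatcher: the real nvte_dequantize launches CUDA kernels and WITH_GEMM_SWIZZLED_SCALES is a compile-time template parameter, modeled here as a runtime bool returning a label.

```cpp
#include <cstddef>
#include <string>

enum class ScalingMode { DELAYED_TENSOR_SCALING, MXFP8_1D_SCALING, NVFP4_1D_SCALING };

// Toy model of the dispatch in nvte_dequantize after this PR: returns the
// name of the branch the flowchart would take instead of launching a kernel.
std::string dispatch(size_t numel, ScalingMode mode, bool with_gemm_swizzled_scales) {
  if (numel == 0) return "early_return";  // CUDA-graph-safe empty-tensor guard
  switch (mode) {
    case ScalingMode::DELAYED_TENSOR_SCALING:
      return "fp8::dequantize";
    case ScalingMode::MXFP8_1D_SCALING:
      return with_gemm_swizzled_scales ? "dequantize_mxfp8_kernel<true>"
                                       : "dequantize_mxfp8_kernel<false>";
    case ScalingMode::NVFP4_1D_SCALING:
      return with_gemm_swizzled_scales ? "dequantize_fp4_kernel<true>"
                                       : "dequantize_fp4_kernel<false>";
  }
  return "unreachable";
}
```

Note that the empty-tensor guard runs before any scaling-mode inspection, which is what makes the early return safe to capture in a CUDA graph.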
quantized_compact.to_cpu();
quantized_swizzled.set_amax(quantized_compact.amax());
quantized_swizzled.set_scale(quantized_compact.scale());
quantized_swizzled.from_cpu();
Redundant from_cpu() after set_scale()

set_scale() for NVTE_NVFP4_1D_SCALING now calls from_cpu() internally (test_common.cu line 518). The explicit quantized_swizzled.from_cpu() call on line 159 is therefore redundant and uploads zero-initialised CPU buffers a second time. Consider removing the explicit call, or adding a comment explaining that it is intentional.
Suggested change:

  quantized_compact.to_cpu();
  quantized_swizzled.set_amax(quantized_compact.amax());
  quantized_swizzled.set_scale(quantized_compact.scale());
- quantized_swizzled.from_cpu();
+ // set_scale() internally calls from_cpu(); FP4 data is copied after this point.
  }
}
std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
There is one edge case: for MXFP8, when the input shape is 64x64, it produces a scaling-factor shape of 64x2, which is then zero-padded to 128x4. We should be able to inject some very large random values into the padded region during allocation (because we allocate with torch.empty rather than torch.zeros) and detect whether the dequantize result is affected. If things work as expected, this line will be triggered and the dequantize numerics won't be affected.

For NVFP4, I think optimize-for-GEMM (i.e. the swizzle fusion) is actually not enabled, and the same goes for the zero-out edge-case handling logic? So there shouldn't be any unswizzle logic needed here?

Reply: For NVFP4, I believe currently only device-init grouped quantize with RHT has the swizzle-fusion feature, so zeroing out the scaling factors is the job of the dedicated swizzle kernel. So if we dequantize + unswizzle for NVFP4, the unswizzle logic might not be correct.
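The garbage-in-padding check described in the thread above could be prototyped on the host before wiring it into the C++ tests. Everything below is a toy model: the 64x64 data shape, the 128x4 padded scale buffer, and the one-scale-per-32-elements blocking come from the comment, while the function and buffer names are mine. The point is that a dequantize indexing scales by logical coordinates must be insensitive to whatever sits in the padded tail.

```cpp
#include <cstddef>
#include <vector>

// Toy host-side dequantize: `data` is rows x cols, with one scale per
// 32-element block along a row. Scales live in a padded [128 x 4] buffer
// (row stride 4); only the logical [rows x cols/32] corner is meaningful,
// so values in the padded region must never be read.
std::vector<float> dequantize_rows(const std::vector<float>& data, size_t rows,
                                   size_t cols,
                                   const std::vector<float>& scales_padded) {
  const size_t padded_cols = 4;  // padded scale-row width after 128x4 padding
  std::vector<float> out(rows * cols);
  for (size_t r = 0; r < rows; ++r)
    for (size_t c = 0; c < cols; ++c)
      out[r * cols + c] =
          data[r * cols + c] * scales_padded[r * padded_cols + c / 32];
  return out;
}
```

A test along these lines would fill the whole 128x4 buffer with huge garbage values, overwrite only the logical 64x2 region with real scales, and assert the output matches the clean-padding result.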
Description
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: