Skip to content

Check numerics in MXFP8 C++ tests by dequantizing to FP32#2881

Open
timmoon10 wants to merge 6 commits intoNVIDIA:mainfrom
timmoon10:tmoon/mxfp8-cpp-tests-with-dequant
Open

Check numerics in MXFP8 C++ tests by dequantizing to FP32#2881
timmoon10 wants to merge 6 commits intoNVIDIA:mainfrom
timmoon10:tmoon/mxfp8-cpp-tests-with-dequant

Conversation

@timmoon10
Copy link
Copy Markdown
Collaborator

Description

We have experienced some failures in the MXFP8 C++ tests because of expected numerical errors. In particular, FP32 numerical errors will sometimes cause the MXFP8 scales to differ by one. However, our current tests assume that MXFP8 scales are bit-wise exact between the TE implementation and reference implementation. This PR changes the MXFP8 tests so that we dequantize MXFP8 values and compare against an FP32 reference implementation.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add MXFP8 dequantization function to test utilities
  • Change FP8 tolerances to match machine epsilon
  • Change MXFP8 tests to use FP32 reference implementation
  • Change MXFP8 tests to check that dequantized MXFP8 values and FP32 reference values are within FP8 tolerances

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 2 commits April 15, 2026 00:20
Replace the brittle raw-FP8 comparison (which failed when GPU and CPU
round a near-power-of-2 amax to adjacent E8M0 exponents) with a
dequantize-then-compare approach: dequantize the GPU output using the
GPU's own scales and compare against the pre-quantization float
reference. Tolerances are derived from the half-ULP of each FP8 format
(1/16 for E4M3, 1/8 for E5M2). Also removes the ad-hoc mismatch budget
workaround and dead `compare_scaled_elts` / `ref_amax` code.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
No test passes kFloat8E8M0 to getTolerances — it is not in all_fp_types
and no INSTANTIATE_TEST_SUITE_P block includes it. The default branch
now errors out on unexpected types.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the testing Improvements to tests or testing infrastructure label Apr 15, 2026
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Apr 15, 2026

Greptile Summary

This PR replaces bit-exact scale comparisons in the MXFP8 C++ tests with a dequantize-then-compare-as-float approach, fixing flaky failures caused by FP32 rounding differences between the TE and reference scale selections. The new test flow adds dequantize_mxfp8_rowwise/colwise helpers in test_common.h, simplifies compute_ref to only produce pre-quantization float values, and widens the FP8 tolerances to match the true machine epsilon of each format (E4M3: rtol 0.125, E5M2: rtol 0.25).

Confidence Score: 5/5

Safe to merge — logic is correct throughout and the only finding is a latent P2 trap in getTolerances that has no live impact today.

All five files reviewed in full. The dequantization math (exp2f(biased_exp − 127)) is consistent with the existing exp2f_rcp quantization step. compute_ref correctly handles every ProcessingMethod case including the previously-buggy CAST_DBIAS. The per-tensor grouped loop passes each tensor its own scale stride and offset. The only open concern is the removal of kFloat8E8M0 from getTolerances, which is P2 with no current caller.

tests/cpp/test_common.cu — kFloat8E8M0 removal from getTolerances is a latent risk worth addressing.

Important Files Changed

Filename Overview
tests/cpp/test_common.h Adds dequantize_mxfp8_rowwise/colwise helpers and a new compareResults(float*, float*) overload; scale formula exp2f(biased_exp - 127) is consistent with exp2f_rcp's quantization step.
tests/cpp/test_common.cu FP8 tolerances updated to machine epsilon; kFloat8E8M0 removed from getTolerances switch — safe today but creates a latent trap if any test ever calls getTolerances(kFloat8E8M0) in the future.
tests/cpp/operator/test_cast_mxfp8.cu compute_ref simplified to pre-quantization floats; CAST_DBIAS correctly uses grad; performTest_x1 gains NVTE_CHECK(rowwise != colwise) to guard single-direction invariant.
tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu compute_ref correctly separates act/gate gradients; after_dgate = silu(x)*grad (no gate_elt) matches the SwiGLU backward formula; namespace closing moved past the test class definition which is valid in C++.
tests/cpp/operator/test_cast_mxfp8_grouped.cu Per-tensor dequantization loop correctly uses each tensor's own scale offset and stride; trailing padding positions are zero-initialized on both sides so the full-elts_num comparison is safe.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Fill input / grad tensors] --> B[Run GPU kernel\nnvte_quantize / nvte_swiglu / etc.]
    B --> C[Copy FP8 data + E8M0 scales\nto host]
    C --> D[dequantize_mxfp8_rowwise\nor dequantize_mxfp8_colwise\nusing GPU's own scales]
    D --> E[dequant_output float32 array]
    A --> F[compute_ref\npre-quantization float values\nno scale computation]
    F --> G[ref_output float32 array]
    E --> H[compareResults\natol/rtol based on FP8 machine epsilon]
    G --> H
    H --> I{Pass / Fail}
Loading

Reviews (4): Last reviewed commit: "Merge branch 'main' into tmoon/mxfp8-cpp..." | Re-trigger Greptile

@@ -194,39 +76,26 @@ void performTest_x1(const size_t rows,
InputsFillCase fill_case,
const bool IS_DGATED) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Unused fill_case parameter

fill_case is accepted by both performTest_x1 and performTest_x2 but never referenced — both functions always call fillUniform, ignoring the argument. If future fill cases are meant to be exercisable here, the parameter should be plumbed through; otherwise it can be removed to avoid confusion.

Suggested change
const bool IS_DGATED) {
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
const bool IS_DGATED) {

@timmoon10 timmoon10 added 2.14.0 and removed 2.14.0 labels Apr 15, 2026
@timmoon10
Copy link
Copy Markdown
Collaborator Author

/te-ci core

Comment thread tests/cpp/operator/test_cast_mxfp8.cu
@timmoon10
Copy link
Copy Markdown
Collaborator Author

/te-ci core

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

2.15.0 testing Improvements to tests or testing infrastructure

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants