Fix fp_qmv_impl small-output-dim branch using raw fp8 scale byte#3763

Open
jax-0n-git wants to merge 1 commit into
ml-explore:mainfrom
jax-0n-git:fix/fp-qmv-small-n-raw-scale
@jax-0n-git jax-0n-git commented Jun 24, 2026

Fixes #3762.

Problem

In fp_qmv_impl (mlx/backend/metal/kernels/fp_quantized.h), the out_vec_size < 8 branch's full-block loop loads the fp8 scale as a raw byte and passes it straight to qdot, skipping dequantize_scale:

uint8_t s = sl[0];                                                   // raw e8m0/e4m3 byte
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s);

Every other fp scale-load site in the file decodes the byte first — including the remainder loop of this same branch (:464) and the >= 8 branch loops (:306/:368/:495/:514):

U s = dequantize_scale<U, group_size>(sl[0]);

qdot applies the scale directly (scale * accum); it does not decode internally. So fp-quantized (mxfp4 / mxfp8 / nvfp4) mat-vec with output dim < 8 multiplies by the literal 0–255 exponent/mantissa byte instead of the real fp8 scale, producing grossly wrong output.

The line looks copy-pasted from qmv_impl in the integer quantized.h, where U s = sl[0]; is correct because scales there is a const device T* array of actual floating-point values. In fp_quantized.h, scales is const device uint8_t* (packed e8m0/e4m3 bytes) and must be decoded before use.
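
To give a feel for how far the raw byte is from the real scale, here is a minimal pure-Python sketch of the two byte formats. It is not MLX's dequantize_scale: the helper names are hypothetical, and it assumes the standard layouts (E8M0 as a power of two with bias 127, E4M3 as 1 sign / 4 exponent / 3 mantissa bits with bias 7). The kernel's handling of special values may differ.

```python
import math


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale byte: a pure power of two with bias 127 (assumed layout)."""
    if byte == 0xFF:
        return math.nan  # 0xFF is reserved for NaN in E8M0
    return 2.0 ** (byte - 127)


def decode_e4m3(byte: int) -> float:
    """Decode an FP8 E4M3 byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits (assumed layout)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan  # the only NaN encoding in the finite-only variant
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def raw_over_decoded(byte: int, decode) -> float:
    """Factor by which a result is off when the raw byte is used as the scale."""
    return byte / decode(byte)


# An E8M0 byte of 120 encodes 2**-7; using 120 itself is off by 120 * 128.
print(raw_over_decoded(120, decode_e8m0))
# An E4M3 byte of 0x38 encodes 1.0; using 56 itself is off by a factor of 56.
print(raw_over_decoded(0x38, decode_e4m3))
```

The error factor depends on the byte, which is consistent with the error being large but varying across modes and inputs.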

When it bites

Both conditions are required:

  1. output dim N < 8 → enters the out_vec_size < (num_simdgroups * results_per_simdgroup) = 8 branch, and
  2. K > block_size (values_per_thread * SIMD_SIZE; 256 for 4-bit) → the full-block loop at the buggy line actually runs. (For K ≤ block_size only the correct remainder loop runs.)

This is why it slipped past CI: test_fp_qmv uses output dim ≥ 8, so it never enters the small-output branch, and test_qmv_small_non_multiples uses K = 32, which is below block_size, so only the correct remainder loop executes; it also doesn't cover mxfp4.
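
The two conditions can be written down as a small predicate. This is only an illustrative sketch: the function name is hypothetical, and the defaults (threshold 8, block_size 256) are the values quoted above for the 4-bit case, so other bit widths may need a different block_size.

```python
def runs_buggy_full_block_loop(
    n: int,
    k: int,
    block_size: int = 256,      # values_per_thread * SIMD_SIZE for 4-bit, per the text above
    small_n_threshold: int = 8,  # num_simdgroups * results_per_simdgroup, per the text above
) -> bool:
    """True when a mat-vec of shape (K,) x (K, N) reaches the full-block loop of the small-N branch."""
    in_small_n_branch = n < small_n_threshold
    has_full_block = k > block_size
    return in_small_n_branch and has_full_block


# Shapes discussed in this PR, with the output dims for the existing tests chosen as examples.
cases = {
    "test_fp_qmv style (N >= 8)": (8, 512),
    "test_qmv_small_non_multiples style (K = 32)": (5, 32),
    "new test (N = 5, K = 512)": (5, 512),
}
for name, (n, k) in cases.items():
    print(name, runs_buggy_full_block_loop(n, k))
```

Only the last shape satisfies both conditions, which matches the gap in the existing coverage.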

Repro on 0.31.2 (M5 Max) — relative error vs dequantize+matmul, K=512:

mode N=5 N=7 N=12 N=8/16 (fast) affine N<8 (int control)
mxfp4 2.2e2 1.8e2 7.8e-8 clean clean
mxfp8 2.2e4 1.8e4 clean clean
nvfp4 7.9e1 7.8e1 clean clean

The integer affine path is clean at every N, isolating the defect to the fp scale-decode.

Fix

Decode the scale, matching the four sibling loops:

-        uint8_t s = sl[0];
+        U s = dequantize_scale<U, group_size>(sl[0]);

Test

Adds test_fp_qmv_small_non_multiples in python/tests/test_quantized.py: fp modes {mxfp4, mxfp8, nvfp4} × M∈{1,2} × N∈{1,2,3,5,7} at K=512 (forcing at least one full block), inputs normalized by 1/sqrt(K) like the other large-dim tests, asserting (y_q - y_hat).abs().max() < 1e-3.

Verified locally on an M5 Max (Metal) by building MLX from source at the current main:

  • On the current kernel the new test fails every subtest by a wide margin (relative error ~1e1–1e4).
  • With the fix it passes, and the small-output-dim path's error matches the always-correct >= 8 path (~5e-4 relative).
  • The full test_quantized.py suite (32 tests) passes with the fix — no regressions.
  • pre-commit (clang-format, black, isort) is clean on both files.

The out_vec_size < 8 branch's full-block loop loaded the fp8 scale as a
raw byte and passed it to qdot without dequantize_scale, so fp-quantized
(mxfp4/mxfp8/nvfp4) matvec with output dim < 8 multiplied by the raw
e8m0/e4m3 byte instead of the decoded scale (gross error, ~1e2-1e4
relative). Every other fp scale-load site decodes it, including the
remainder loop of this same branch.

Add test_fp_qmv_small_non_multiples covering the fp modes at output
dim < 8 with K large enough to run the full-block loop (the existing
fp tests use output dim >= 8 or K below block_size).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@jax-0n-git jax-0n-git force-pushed the fix/fp-qmv-small-n-raw-scale branch from b8079fd to ce93101 Compare June 24, 2026 21:24
Successfully merging this pull request may close these issues.

Metal fp_qmv_impl: out_vec_size < 8 branch uses raw scale byte instead of dequantize_scale → wrong mxfp4 matvec for output dim < 8