Skip to content

fix(sampling): correct f16/bf16 logprobs crash and corruption (#340)#342

Merged
inureyes merged 3 commits into
mainfrom
fix/issue-340-topk-logprobs-bf16-dtype
Jun 17, 2026
Merged

fix(sampling): correct f16/bf16 logprobs crash and corruption (#340)#342
inureyes merged 3 commits into
mainfrom
fix/issue-340-topk-logprobs-bf16-dtype

Conversation

@inureyes

@inureyes inureyes commented Jun 17, 2026

Copy link
Copy Markdown
Member

Summary

Fixes a server crash (Abort trap: 6) and a silent garbage-value bug when a /v1/completions or /v1/chat/completions request asks for logprobs against a model whose logits are f16 or bf16. After #289 keeps quantized models fully bf16, that is the common case: any top_k >= 1 crashed the worker, and even logprobs without top-k returned a corrupt selected-token logprob.

Root cause

compute_logprobs() in src/lib/mlxcel-core/src/sampling.rs reads log-probabilities out of MLX arrays on the host through two byte-level accessors that do not convert dtype, while the arrays inherit the model logit dtype (f16/bf16, 2 bytes per element, for quantized checkpoints post-#289):

  1. Top-k path (crash): the alternatives were extracted with ffi::array_to_raw_bytes(&top_lp) plus a hardcoded 4-byte stride. array_to_raw_bytes dumps the buffer verbatim, so for f16/bf16 the buffer is 2*k bytes while the loop reads it with a 4-byte stride bounded by the int32 index array (k elements). For logprobs: 5 on bf16, lp_bytes is 10 bytes and the read lp_bytes[8..12] overruns, producing the reported panic range end index 12 out of range for slice of length 10. Even without an overrun, reinterpreting two f16 elements as one f32 yields garbage.

  2. Selected-token path (silent corruption): the value was read with ffi::item_f32, which calls MLX array::item<float>(). Contrary to the original issue's assumption, item<float>() is not dtype-aware: it returns *data<float>(), reinterpreting the element's raw bytes as f32 with no conversion. On bf16 it read a 2-byte element as a 4-byte f32 and returned garbage (for example 6.9e-41 instead of -4.5028). This path is always on whenever logprobs is enabled, so it was wrong on every bf16 model, not just the top-k case. The f16/bf16 regression tests added here caught it; reasoning alone (as in the issue) missed it.

Fix

Promote to f32 at the host-extraction boundary in both places, casting only the small arrays so the full-vocab log_probs stays in its native dtype and the decode hot path is unaffected:

  • Top-k: ffi::astype(&top_lp, dtype::FLOAT32) (then eval) before array_to_raw_bytes, so the 4-byte stride is valid (k elements of f32).
  • Selected token: ffi::astype(&selected_lp_arr, dtype::FLOAT32) (then eval) before item_f32, so the single value converts correctly.

This mirrors the established in-repo pattern (generate.rs and rope_proportional.rs both astype to f32 before raw-byte extraction). The int32 token-id read is left unchanged because indices really are 4 bytes, and no model weights or dtype policy elsewhere are changed. ffi::item_f32 itself is left as-is to keep the change at the logprobs boundary; its byte-reinterpretation behavior on non-f32 arrays is a latent footgun worth a separate follow-up.

Tests

The existing tests only built f32 logits via ffi::from_slice_f32, where the 4-byte stride was correct by accident, so they never exercised either bug. The new tests build f16 and bf16 logit arrays from the same underlying f32 values and assert the results match an f32 reference run within a dtype-appropriate tolerance (bf16 ~0.1, f16 ~0.03, f32 exact), covering both paths:

  • compute_logprobs_top_k_bf16_no_panic_matches_f32 — unit-level reproduction of the server crash: top_k = 5 on bf16 drives the identical top-k path the server hits. Asserts no panic, correct alternative count, descending sort, matching top-token id, and per-value agreement with the f32 reference.
  • compute_logprobs_top_k_f16_matches_f32 — same for f16 with a tighter tolerance.
  • compute_logprobs_top_k_f32_values_correct — f32 reference path, exact values, guards the shared helper and sort.
  • compute_logprobs_selected_token_bf16_matches_f32 and compute_logprobs_selected_token_f16_matches_f32top_k = 0 on bf16/f16; these reproduced the selected-token corruption (pre-fix they returned ~6.9e-41 / ~7.0e-41) and now assert the logprob matches the f32 reference.

Test plan

  • cargo test -p mlxcel-core --release --lib sampling — 37 passed, 0 failed (run by the orchestrator with warm release artifacts).

Closes #340

…340)

`compute_logprobs()` extracted top-k log-probabilities via `array_to_raw_bytes(&top_lp)` with a hardcoded 4-byte stride. `top_lp = take_along_axis(log_probs, top_idx)` inherits the model logit dtype, which is f16/bf16 (2 bytes per element) for quantized models post-#289. `array_to_raw_bytes` dumps the buffer verbatim with no dtype conversion, so for a 2-byte-per-element buffer the read overruns: for `logprobs: 5` on bf16 the buffer is 10 bytes while the loop reads `lp_bytes[8..12]`, which panics with "range end index 12 out of range for slice of length 10" and aborts the server. Any `top_k >= 1` on a bf16 model crashed; even without overrun, reinterpreting two f16 elements as one f32 yields garbage.

The fix casts `top_lp` to f32 with `astype` before extracting raw bytes, so the 4-byte stride is valid and the decoded values are correct. This mirrors the dtype-aware selected-token path, which already uses `item_f32` (`array::item<float>()`). The int32 token-id read is left unchanged because indices really are 4 bytes, and the selected-token path is untouched.

Adds regression tests that build f16 and bf16 logit arrays from the same underlying f32 values and exercise both the top-k and selected-token paths. The existing tests only built f32 logits, where the 4-byte stride was correct by accident, so they never caught the bug. `compute_logprobs_top_k_bf16_no_panic_matches_f32` is the unit-level reproduction of the server crash.
@inureyes inureyes added type:bug Bug fixes, error corrections, or issue resolutions priority:high High priority area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) status:review Under review labels Jun 17, 2026
The selected-token logprob path read its value with `item_f32`, which calls MLX `item<float>()` and reinterprets the element's raw bytes without dtype conversion. On f16/bf16 logit models (most quantized checkpoints post-#289) that read a 2-byte element as a 4-byte f32 and returned garbage (for example `6.9e-41` instead of `-4.5028`), so the always-on selected-token logprob was wrong on the same models the top-k path crashed on. Issue #340 assumed this path was already correct via `item<float>()`; the mandated f16/bf16 regression tests proved otherwise.

Cast the single selected-logprob array to f32 before `item_f32`, mirroring the top-k boundary fix. Only the 1-element array is promoted, so the full-vocab `log_probs` stays in its native dtype and the decode hot path is unaffected. The two previously-failing selected-token regression tests now pass.
@inureyes inureyes changed the title fix(sampling): cast top-k logprobs to f32 before raw-byte extraction (#340) fix(sampling): correct f16/bf16 logprobs crash and corruption (#340) Jun 17, 2026
The dtype regression tests added for #340 left chained `compute_logprobs(...).expect(...)` and `let x = logprobs_for_dtype(...)` calls split across lines, which `cargo fmt --all -- --check` (the ci.yml gate) collapses onto single lines. Apply rustfmt so the formatting gate passes. Test code only, no logic change.
@inureyes

Copy link
Copy Markdown
Member Author

Implementation Review Summary

Intent

Fix the f16/bf16 logprobs crash (Abort trap: 6) and silent garbage-value bug in compute_logprobs() for issue #340: read 2-byte f16/bf16 logprobs correctly instead of with a hardcoded 4-byte stride.

Findings Addressed

  • New f16/bf16 regression tests failed the cargo fmt --all -- --check CI gate (ci.yml:86): chained compute_logprobs(...).expect(...) and let x = logprobs_for_dtype(...) calls were split across lines that rustfmt collapses (MEDIUM). Fixed by running cargo fmt on the test block (10 insertions, 20 deletions, test code only, no logic change).

Remaining Items

  • ffi::item_f32 still byte-reinterprets array::item<float>() on non-f32 arrays (LOW) — the PR scopes this out as a documented follow-up footgun. Acceptable: every current caller now casts to f32 first, so it is latent, not live.
  • No standalone test for size-1 vocab or top_k-clamped-on-bf16 combination (LOW) — pre-existing coverage gap, dtype-independent code path, not introduced by this PR.

Verification

  • All stated requirements implemented (both top-k crash and selected-token corruption fixed; the second bug, which issue fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32) #340 assumed was safe, was found by the regression tests and fixed)
  • No placeholder/mock code remaining
  • Integrated into project code flow (callers in scheduler.rs, speculative_burst.rs, gemma4_mtp_target.rs, dflash round_loop.rs consume corrected values unchanged; return shape stable)
  • Project conventions followed (mirrors the established astype to f32 before raw-byte extraction pattern in generate.rs and rope_proportional.rs)
  • Existing modules reused where applicable (ffi::astype, dtype::FLOAT32)
  • No unintended structural changes (fix confined to the logprobs host-extraction boundary; int32 token-id read, model-weight dtype policy, and the full-vocab log_probs dtype all unchanged)
  • Tests pass (orchestrator: 37 passed, 0 failed in release; the 2 selected-token tests failed pre-second-fix and pass after)

Notes

  • Dtype cast ordering verified correct in both paths: cast -> eval -> host read.
  • Tolerances verified sound against an independent f32 reference: garbage gap (~4.5) is 45x the bf16 tolerance (0.1) and 150x the f16 tolerance (0.03), so they cannot pass on the known pre-fix garbage; worst-case legitimate bf16 rounding (~0.018) and f16 (~0.002) sit well inside, so they will not flake.
  • Both bug claims confirmed against the C++ bridge: item_f32 is array::item<float>() (line 285, no dtype conversion) and array_to_raw_bytes dumps nbytes() verbatim (line 301).

No CRITICAL or HIGH findings. The implementation is correct, complete, and properly integrated.

@inureyes

Copy link
Copy Markdown
Member Author

Security and performance review

Reviewed the f16/bf16 logprobs dtype fix against the decode hot path, the FFI raw-bytes boundary, and edge-value handling. No CRITICAL or HIGH findings. No blockers.

Security

The changed surface is internal sampling code: no untrusted input, no auth, no injection or serialization surface. The fix removes a denial-of-service vector, a normal logprobs request against a bf16 model aborted the worker (range end index 12 out of range for slice of length 10). Confirmed against the C++ bridge that the two accessors behave exactly as the PR claims: item_f32 is array::item<float>() (raw reinterpret, no dtype conversion) and array_to_raw_bytes dumps nbytes() verbatim. Casting the extracted arrays to f32 before the host read is correct.

Performance (decode hot path)

compute_logprobs runs per generated token when logprobs is enabled. Both casts operate only on the small extracted arrays, not the full vocab:

  • Selected token: astype runs on selected_lp_arr, shape [1, 1] from take_along_axis(log_probs, idx[1,1]). 1 element.
  • Top-k: astype runs on top_lp, shape [1, k] (top_idx is sliced to k at line 712 before the gather). k elements.

The full-vocab log_softmax + eval at lines 677-678 is pre-existing and unchanged. Neither cast widens the full vocab, so no full-vocabulary allocation or eval is added per token.

FFI / memory safety at the raw-bytes boundary

After the top-k cast, lp_bytes is 4*k bytes (k f32 elements). The loop is bounded by idx_bytes.len() / 4 == k and reads lp_bytes[i*4..(i+1)*4] for i in 0..k, so the last read is [(k-1)*4 .. k*4], exactly in bounds. No remaining overrun. eval precedes the host read in both paths (line 690 before item_f32, line 730 before array_to_raw_bytes); the array_to_raw_bytes C++ also re-evals defensively.

Edge values

  • top_k == 0: selected-token-only path, top-k block skipped, top_alternatives empty. Covered by the two selected-token tests.
  • top_k clamped to vocab: k = min(top_k, vocab_size) (line 697) feeds both argpartition(k-1) and the slice stops, so the cast array and the loop bound stay consistent at the clamp. The pre-existing capped_at_vocab test still guards this.
  • Ties: alternatives are matched by token id, not position, and sorted descending with a total-order-safe comparator.
  • from_ne_bytes / item_f32 now receive genuine 4-byte f32 data after the cast.

No new panics on the request path

The filter_map + try_into on fixed [u8; 4] slices cannot panic (out-of-range yields None). The fix is the established in-repo pattern (generate.rs, rope_proportional.rs both astype to f32 before raw-byte extraction). cargo check -p mlxcel-core --release --lib is clean.

Note (non-blocking, LOW)

ffi::item_f32 reinterpreting raw bytes on non-f32 arrays remains a latent footgun for any future caller that forgets the cast. The PR already calls this out as a separate follow-up; leaving the FFI as-is and fixing at the logprobs boundary is the right scope for this change.

@inureyes

Copy link
Copy Markdown
Member Author

PR Finalization Complete

Summary

  • Tests: No new tests added. Coverage is already comprehensive: 5 regression tests covering f16/bf16/f32 on both the top-k path and the selected-token path, with sound tolerances (0.1 for bf16, 0.03 for f16). The two known LOW gaps (size-1 vocab, top_k clamped exactly on bf16) are pre-existing edge cases not worth adding here. All 37 tests pass.
  • Documentation: No changes needed. This fix restores correct behavior with no API or behavior change. The project has no per-PR changelog convention (changelogs are written at release time). The one logprobs mention in docs is a feature-support table row that remains accurate.
  • Lint/Format: cargo fmt --all -- --check is clean. One pre-existing clippy warning exists in build.rs (needless borrow on line 90), but build.rs was not touched by this PR and the warning is present on main as well. No warnings in the touched code (sampling.rs).

The branch is up to date with origin and the working tree is clean. Ready for merge.

@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 17, 2026
@inureyes inureyes merged commit e056634 into main Jun 17, 2026
5 checks passed
@inureyes inureyes deleted the fix/issue-340-topk-logprobs-bf16-dtype branch June 17, 2026 18:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) priority:high High priority status:done Completed type:bug Bug fixes, error corrections, or issue resolutions

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32)

1 participant