fix(sampling): correct f16/bf16 logprobs crash and corruption (#340) #342
…340)
`compute_logprobs()` extracted top-k log-probabilities via `array_to_raw_bytes(&top_lp)` with a hardcoded 4-byte stride. `top_lp = take_along_axis(log_probs, top_idx)` inherits the model logit dtype, which is f16/bf16 (2 bytes per element) for quantized models post-#289. `array_to_raw_bytes` dumps the buffer verbatim with no dtype conversion, so for a 2-byte-per-element buffer the read overruns: for `logprobs: 5` on bf16 the buffer is 10 bytes while the loop reads `lp_bytes[8..12]`, which panics with "range end index 12 out of range for slice of length 10" and aborts the server. Any `top_k >= 1` on a bf16 model crashed; even without overrun, reinterpreting two f16 elements as one f32 yields garbage.
The fix casts `top_lp` to f32 with `astype` before extracting raw bytes, so the 4-byte stride is valid and the decoded values are correct. This mirrors the dtype-aware selected-token path, which already uses `item_f32` (`array::item<float>()`). The int32 token-id read is left unchanged because indices really are 4 bytes, and the selected-token path is untouched.
Adds regression tests that build f16 and bf16 logit arrays from the same underlying f32 values and exercise both the top-k and selected-token paths. The existing tests only built f32 logits, where the 4-byte stride was correct by accident, so they never caught the bug. `compute_logprobs_top_k_bf16_no_panic_matches_f32` is the unit-level reproduction of the server crash.
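The overrun described in this commit can be reproduced without MLX. The following is a minimal standalone sketch, not the code in `sampling.rs`: the helpers `to_bf16_bytes`, `read_f32_stride`, and `bf16_bytes_to_f32_bytes` are hypothetical names introduced for illustration, bf16 is emulated by truncating the f32 bit pattern, and the bounds-checked read returns `None` where the real loop would panic.

```rust
/// Hypothetical helper: encode f32 values as little-endian bf16 bytes by
/// keeping the upper 16 bits of each f32 (truncation, for brevity).
fn to_bf16_bytes(values: &[f32]) -> Vec<u8> {
    values
        .iter()
        .flat_map(|v| ((v.to_bits() >> 16) as u16).to_le_bytes())
        .collect()
}

/// Hypothetical helper: decode `k` values with a fixed 4-byte stride.
/// Returns `None` where an unchecked slice like `bytes[8..12]` would panic.
fn read_f32_stride(bytes: &[u8], k: usize) -> Option<Vec<f32>> {
    (0..k)
        .map(|i| {
            let c = bytes.get(i * 4..i * 4 + 4)?;
            Some(f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        })
        .collect()
}

/// Hypothetical helper standing in for the dtype cast: widen each bf16
/// element to a real f32 before the bytes are read.
fn bf16_bytes_to_f32_bytes(bytes: &[u8]) -> Vec<u8> {
    bytes
        .chunks_exact(2)
        .flat_map(|c| {
            let bits = u16::from_le_bytes([c[0], c[1]]) as u32;
            f32::from_bits(bits << 16).to_le_bytes()
        })
        .collect()
}

fn main() {
    // Five log-probabilities that are exactly representable in bf16.
    let top_lp = [-0.5f32, -1.0, -2.0, -4.0, -8.0];
    let raw = to_bf16_bytes(&top_lp);

    // k = 5 elements of bf16 occupy 10 bytes, so a 4-byte stride overruns.
    assert_eq!(raw.len(), 10);
    assert!(read_f32_stride(&raw, 5).is_none());

    // Even when the read stays in bounds, two bf16 elements fused into one
    // f32 do not decode to the original value.
    let fused = read_f32_stride(&raw, 2).unwrap();
    assert_ne!(fused[0], top_lp[0]);

    // After widening to f32 the buffer is 20 bytes and the stride is valid.
    let promoted = bf16_bytes_to_f32_bytes(&raw);
    assert_eq!(promoted.len(), 20);
    assert_eq!(read_f32_stride(&promoted, 5).unwrap(), top_lp.to_vec());
}
```

The sketch widens the bytes on the host only to keep the example dependency-free; the commit itself performs the cast on the array with `astype` before extracting raw bytes.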
The selected-token logprob path read its value with `item_f32`, which calls MLX `item<float>()` and reinterprets the element's raw bytes without dtype conversion. On f16/bf16 logit models (most quantized checkpoints post-#289) that read a 2-byte element as a 4-byte f32 and returned garbage (for example `6.9e-41` instead of `-4.5028`), so the always-on selected-token logprob was wrong on the same models the top-k path crashed on. Issue #340 assumed this path was already correct via `item<float>()`; the mandated f16/bf16 regression tests proved otherwise. Cast the single selected-logprob array to f32 before `item_f32`, mirroring the top-k boundary fix. Only the 1-element array is promoted, so the full-vocab `log_probs` stays in its native dtype and the decode hot path is unaffected. The two previously-failing selected-token regression tests now pass.
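The difference between reinterpreting and converting a single bf16 element can be shown in plain Rust. This is an illustrative sketch under stated assumptions, not the MLX implementation: the function names are hypothetical, bf16 is emulated by truncating the f32 bit pattern, and the two bytes following the element are assumed to be zero (in a real buffer they could be anything).

```rust
/// Hypothetical helper: truncate an f32 to its bf16 bit pattern
/// (the upper 16 bits of the f32 encoding).
fn f32_to_bf16_bits(v: f32) -> u16 {
    (v.to_bits() >> 16) as u16
}

/// Reinterpretation: treat the 2-byte element as the low half of a
/// little-endian 4-byte f32. Upper bytes are ASSUMED zero here.
fn reinterpret_as_f32(bf16_bits: u16) -> f32 {
    let b = bf16_bits.to_le_bytes();
    f32::from_le_bytes([b[0], b[1], 0, 0])
}

/// Conversion: place the bf16 bits in the upper half of the f32 encoding,
/// which is what a dtype cast to f32 achieves for bf16.
fn convert_to_f32(bf16_bits: u16) -> f32 {
    f32::from_bits((bf16_bits as u32) << 16)
}

fn main() {
    let bits = f32_to_bf16_bits(-4.5028);
    let garbage = reinterpret_as_f32(bits);
    let converted = convert_to_f32(bits);
    println!("reinterpreted = {:e}, converted = {}", garbage, converted);

    // Reinterpretation lands in the positive subnormal range of f32.
    assert!(garbage > 0.0 && garbage < 1e-38);
    // Conversion recovers the value up to bf16 rounding error.
    assert!((converted - (-4.5028f32)).abs() < 0.1);
}
```

Under the zero-padding assumption the reinterpreted value is a tiny positive subnormal on the order of `1e-41`, the same magnitude as the garbage value quoted in the commit message, while the converted value is `-4.5`.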
The dtype regression tests added for #340 left chained `compute_logprobs(...).expect(...)` and `let x = logprobs_for_dtype(...)` calls split across lines, which `cargo fmt --all -- --check` (the ci.yml gate) collapses onto single lines. Apply rustfmt so the formatting gate passes. Test code only, no logic change.
Implementation Review Summary
Intent
Findings Addressed
Remaining Items
Verification
Notes
No CRITICAL or HIGH findings. The implementation is correct, complete, and properly integrated.
Security and performance review
Reviewed the f16/bf16 logprobs dtype fix against the decode hot path, the FFI raw-bytes boundary, and edge-value handling. No CRITICAL or HIGH findings. No blockers.
Security
The changed surface is internal sampling code: no untrusted input, no auth, no injection or serialization surface. The fix removes a denial-of-service vector, a normal
Performance (decode hot path)
The full-vocab
FFI / memory safety at the raw-bytes boundary
After the top-k cast,
Edge values
No new panics on the request path
The
Note (non-blocking, LOW)
PR Finalization Complete
Summary
The branch is up to date with origin and the working tree is clean. Ready for merge.
Summary
Fixes a server crash (`Abort trap: 6`) and a silent garbage-value bug when a `/v1/completions` or `/v1/chat/completions` request asks for `logprobs` against a model whose logits are f16 or bf16. After #289 keeps quantized models fully bf16, that is the common case: any `top_k >= 1` crashed the worker, and even `logprobs` without top-k returned a corrupt selected-token logprob.
Root cause
`compute_logprobs()` in `src/lib/mlxcel-core/src/sampling.rs` reads log-probabilities out of MLX arrays on the host through two byte-level accessors that do not convert dtype, while the arrays inherit the model logit dtype (f16/bf16, 2 bytes per element, for quantized checkpoints post-#289):
- Top-k path (crash): the alternatives were extracted with `ffi::array_to_raw_bytes(&top_lp)` plus a hardcoded 4-byte stride. `array_to_raw_bytes` dumps the buffer verbatim, so for f16/bf16 the buffer is `2*k` bytes while the loop reads it with a 4-byte stride bounded by the int32 index array (`k` elements). For `logprobs: 5` on bf16, `lp_bytes` is 10 bytes and the read `lp_bytes[8..12]` overruns, producing the reported panic `range end index 12 out of range for slice of length 10`. Even without an overrun, reinterpreting two f16 elements as one f32 yields garbage.
- Selected-token path (silent corruption): the value was read with `ffi::item_f32`, which calls MLX `array::item<float>()`. Contrary to the original issue's assumption, `item<float>()` is not dtype-aware: it returns `*data<float>()`, reinterpreting the element's raw bytes as f32 with no conversion. On bf16 it read a 2-byte element as a 4-byte f32 and returned garbage (for example `6.9e-41` instead of `-4.5028`). This path is always on whenever `logprobs` is enabled, so it was wrong on every bf16 model, not just the top-k case. The f16/bf16 regression tests added here caught it; reasoning alone (as in the issue) missed it.
Fix
Promote to f32 at the host-extraction boundary in both places, casting only the small arrays so the full-vocab `log_probs` stays in its native dtype and the decode hot path is unaffected:
- `ffi::astype(&top_lp, dtype::FLOAT32)` (then `eval`) before `array_to_raw_bytes`, so the 4-byte stride is valid (`k` elements of f32).
- `ffi::astype(&selected_lp_arr, dtype::FLOAT32)` (then `eval`) before `item_f32`, so the single value converts correctly.
This mirrors the established in-repo pattern (`generate.rs` and `rope_proportional.rs` both `astype` to f32 before raw-byte extraction). The int32 token-id read is left unchanged because indices really are 4 bytes, and no model weights or dtype policy elsewhere are changed. `ffi::item_f32` itself is left as-is to keep the change at the logprobs boundary; its byte-reinterpretation behavior on non-f32 arrays is a latent footgun worth a separate follow-up.
Tests
The existing tests only built f32 logits via `ffi::from_slice_f32`, where the 4-byte stride was correct by accident, so they never exercised either bug. The new tests build f16 and bf16 logit arrays from the same underlying f32 values and assert the results match an f32 reference run within a dtype-appropriate tolerance (bf16 ~0.1, f16 ~0.03, f32 exact), covering both paths:
- `compute_logprobs_top_k_bf16_no_panic_matches_f32`: unit-level reproduction of the server crash. `top_k = 5` on bf16 drives the identical top-k path the server hits. Asserts no panic, correct alternative count, descending sort, matching top-token id, and per-value agreement with the f32 reference.
- `compute_logprobs_top_k_f16_matches_f32`: same for f16 with a tighter tolerance.
- `compute_logprobs_top_k_f32_values_correct`: f32 reference path, exact values, guards the shared helper and sort.
- `compute_logprobs_selected_token_bf16_matches_f32` and `compute_logprobs_selected_token_f16_matches_f32`: `top_k = 0` on bf16/f16; these reproduced the selected-token corruption (pre-fix they returned `~6.9e-41` / `~7.0e-41`) and now assert the logprob matches the f32 reference.
Test plan
`cargo test -p mlxcel-core --release --lib sampling`: 37 passed, 0 failed (run by the orchestrator with warm release artifacts).
Closes #340