perf(sampling): incremental per-sequence penalty state caches #344
Maintain per-sequence sampler state so repetition and frequency/presence penalties stop rebuilding from scratch on every decode step in long generations, and add a selected-token-only logprob fast path for top_k == 0.
SamplerState (mlxcel-core/src/sampling.rs) keeps an incrementally maintained sorted, deduped seen-token set for the repetition penalty plus per-token counts and reusable sparse index/value scratch buffers for the frequency/presence penalty, so that penalty no longer allocates a full-vocabulary vector per token. It synchronizes to the token history on entry: append-only growth absorbs just the new tail, while a shorter or diverged history (speculative rollback, KV cache trim/restore) triggers a rebuild, so no explicit reset is needed. The repetition path reuses the shared apply_repetition_penalty_sorted core, and the frequency/presence path promotes to f32 first to match the rebuild path's broadcast-subtract dtype, so both are byte-identical to the rebuild-every-token path and penalty-adjusted greedy sampling selects identical token ids.
State is created lazily only when a repetition/frequency/presence penalty is active and lives on the owning sequence (Option<SamplerState> on SequenceInfo, a local in the CLI decode loops), so the default no-penalty path never allocates it and keeps calling the original sample_token_optimized unchanged. The server scheduler decode steps and the four CxxGenerator decode loops gate on the existing needs_token_history() flag.
compute_logprobs gains a top_k == 0 fast path computing logit minus logsumexp instead of materializing the full-vocabulary log-softmax. It matches the full path within floating-point tolerance (the subtraction order differs by at most 1 ULP, which is OpenAI-compatible) and does not affect token selection; the top_k > 0 path is unchanged.
DRY is intentionally left on the rebuild path (its sliding window would need fragile position rebasing) with behavior unchanged.
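The synchronization rule described above (absorb only the appended tail, rebuild on a shorter or diverged history) can be sketched roughly as follows. This is a simplified illustration, not the PR's code: SeenState, absorb, and the rebuilds counter are hypothetical names, and the full prefix comparison used here to detect divergence is an assumption, since the source does not show how the real SamplerState detects it.

```rust
/// Hypothetical, simplified stand-in for the repetition half of SamplerState.
#[derive(Default, Debug)]
pub struct SeenState {
    /// Copy of the history already absorbed, used to detect divergence.
    synced: Vec<u32>,
    /// Sorted, deduplicated token ids (input to the repetition penalty).
    seen_sorted: Vec<u32>,
    /// Number of full rebuilds; kept only so the sketch can be observed.
    pub rebuilds: usize,
}

impl SeenState {
    /// Insert one token into the sorted set (no-op if already present).
    fn absorb(&mut self, token: u32) {
        if let Err(pos) = self.seen_sorted.binary_search(&token) {
            self.seen_sorted.insert(pos, token);
        }
        self.synced.push(token);
    }

    /// Bring the state in line with `history`.
    /// Append-only growth absorbs just the new tail; a shorter or
    /// diverged history clears the state and rebuilds from scratch.
    pub fn sync(&mut self, history: &[u32]) {
        let n = self.synced.len();
        // Assumption: divergence is detected by comparing the whole prefix.
        let append_only = history.len() >= n && history[..n] == self.synced[..];
        if !append_only {
            self.synced.clear();
            self.seen_sorted.clear();
            self.rebuilds += 1;
        }
        let start = self.synced.len();
        for &t in &history[start..] {
            self.absorb(t);
        }
    }

    /// Sorted, deduplicated view of every token seen so far.
    pub fn seen(&self) -> &[u32] {
        &self.seen_sorted
    }
}
```

Calling sync with a growing history never bumps rebuilds, while a truncated or altered history (as after a speculative rollback) does, without the caller issuing any explicit reset.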
Tests: byte-identical repetition and frequency/presence parity over a multi-token sequence (f32 and f16), 40-step greedy parity with all penalties active, no-state-allocated assertions for the no-penalty and DRY-only paths, sync append/shrink/divergence, and the top_k == 0 fast path versus full log-softmax.
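To illustrate what a byte-identical parity check between the rebuild path and the incremental path might look like for frequency/presence, here is a minimal sketch on plain f32 slices. Everything in it is an assumption made for illustration: the names penalize_dense, CountState, absorb_tail, and penalize_sparse are hypothetical, the penalty formula is the conventional one (frequency times count, plus presence for any seen token), and the real implementation operates on MLX arrays rather than slices.

```rust
use std::collections::HashMap;

/// Rebuild-every-step reference: allocates a full-vocabulary penalty vector.
pub fn penalize_dense(logits: &mut [f32], history: &[u32], freq: f32, presence: f32) {
    let mut counts = vec![0u32; logits.len()];
    for &t in history {
        counts[t as usize] += 1;
    }
    let penalty: Vec<f32> = counts
        .iter()
        .map(|&c| if c > 0 { freq * c as f32 + presence } else { 0.0 })
        .collect();
    for (l, p) in logits.iter_mut().zip(&penalty) {
        *l -= *p;
    }
}

/// Hypothetical incremental state: per-token counts plus reusable
/// sparse index/value scratch buffers.
#[derive(Default)]
pub struct CountState {
    counts: HashMap<u32, u32>,
    absorbed: usize,
    idx: Vec<u32>,
    val: Vec<f32>,
}

impl CountState {
    /// Absorb only the tokens appended since the last call.
    /// Simplification: a shorter history resets the counts; divergence
    /// at equal or greater length is not detected in this sketch.
    pub fn absorb_tail(&mut self, history: &[u32]) {
        if history.len() < self.absorbed {
            self.counts.clear();
            self.absorbed = 0;
        }
        for &t in &history[self.absorbed..] {
            *self.counts.entry(t).or_insert(0) += 1;
        }
        self.absorbed = history.len();
    }

    /// Touch only the logits of tokens that actually occurred.
    pub fn penalize_sparse(&mut self, logits: &mut [f32], freq: f32, presence: f32) {
        self.idx.clear();
        self.val.clear();
        for (&t, &c) in &self.counts {
            self.idx.push(t);
            self.val.push(freq * c as f32 + presence);
        }
        for (&t, &p) in self.idx.iter().zip(&self.val) {
            logits[t as usize] -= p;
        }
    }
}
```

Because unseen tokens receive a penalty of exactly 0.0 on the dense path and are simply skipped on the sparse path, comparing the two outputs with to_bits() is a reasonable way to assert bit-level parity in a test.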
Implementation Review Summary
Intent
Findings Addressed
Remaining Items (LOW, informational, non-blocking)
Verification
The existing dtype coverage tests (compute_logprobs_selected_token_f16_matches_f32 and bf16 variant) compare the top_k==0 fast path on f16/bf16 input against the f32 reference value. The new test mirrors compute_logprobs_top_k_zero_fast_path_matches_full_softmax but with f16 logits, verifying that the fast path (gather + logsumexp) agrees with the full log_softmax on the same f16 input within f16 precision (~0.01). Closes the gap flagged in the PR review.
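The comparison this test makes can be sketched on plain f32 slices. The function names below are hypothetical and the code is not the PR's compute_logprobs; it only shows why gathering one logit and subtracting logsumexp agrees with the full log-softmax up to rounding when the subtraction is ordered differently.

```rust
/// Numerically stable log-sum-exp pieces: (max, ln(sum(exp(x - max)))).
fn lse_parts(logits: &[f32]) -> (f32, f32) {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum: f32 = logits.iter().map(|&x| (x - max).exp()).sum();
    (max, sum.ln())
}

/// Full path: materializes a log-probability for every vocabulary entry,
/// computed as (x - max) - ln(sum).
pub fn log_softmax(logits: &[f32]) -> Vec<f32> {
    let (max, log_sum) = lse_parts(logits);
    logits.iter().map(|&x| (x - max) - log_sum).collect()
}

/// Fast path: only the selected token, computed as x - (max + ln(sum)).
/// Same value mathematically; the different subtraction order can change
/// the last bits of the result.
pub fn selected_logprob(logits: &[f32], token: usize) -> f32 {
    let (max, log_sum) = lse_parts(logits);
    logits[token] - (max + log_sum)
}
```

A test in this style compares selected_logprob against the matching entry of log_softmax with a tolerance chosen for the input dtype, rather than asserting exact equality.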
PR Finalization
Tests Added
Docs
No documentation changes. The optimization has no user-facing surface, and there is no changelog or architecture doc that this change naturally belongs in.
Lint / Format
Commit
Real-model validation (orchestrator)
Server validation on M1 Ultra (
Combined with the 44 sampling unit tests (byte-identical incremental-vs-rebuild parity for repetition and frequency/presence in both f32 and f16, 40-step greedy parity with all penalties, no-state-when-off, sync append/shrink/divergence, and the top_k==0 logprob fast path vs full log-softmax in f32 and f16), the acceptance criteria are met. DRY incrementalization is deferred (documented in the PR); the speculative path still uses the rebuild sampler (clean follow-up). Reviewer + security-checker: 0 unresolved CRITICAL/HIGH/MEDIUM.
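The "no-state-when-off" property mentioned above comes down to creating the state lazily behind a gate. A minimal sketch of that pattern, under loud assumptions: Params, PenaltyState, wants_incremental_state, and decode_step are all hypothetical names, and the gate shown here is not the PR's needs_token_history(), whose definition the source does not give.

```rust
/// Hypothetical per-sequence penalty state.
#[derive(Default, Debug)]
pub struct PenaltyState {
    /// Number of decode steps that went through the state-aware path.
    pub steps: usize,
}

/// Hypothetical sampling parameters; neutral values mean "penalty off".
pub struct Params {
    pub repetition_penalty: f32,
    pub frequency_penalty: f32,
    pub presence_penalty: f32,
}

impl Params {
    /// Assumption: the incremental state is wanted only when at least one
    /// of the three penalties deviates from its neutral value.
    pub fn wants_incremental_state(&self) -> bool {
        self.repetition_penalty != 1.0
            || self.frequency_penalty != 0.0
            || self.presence_penalty != 0.0
    }
}

/// One decode step: the state is created on first use and only when the
/// gate is on, so the no-penalty path never allocates it.
pub fn decode_step(params: &Params, state: &mut Option<PenaltyState>) {
    if params.wants_incremental_state() {
        let st = state.get_or_insert_with(PenaltyState::default);
        st.steps += 1;
    }
    // The no-penalty path would call the original sampler here.
}
```

Keeping the Option on the owning sequence means the state is dropped with it, and a test can assert is_none() after running the no-penalty configuration.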
Summary
Maintain incremental per-sequence sampler state so the repetition and frequency/presence penalties stop rebuilding sorted sets, count maps, and full-vocabulary penalty vectors on every decode step in long generations, plus a selected-token-only logprob fast path for top_k == 0. The incremental state is purely an optimization: penalty-adjusted greedy sampling produces byte-identical token ids to the rebuild-every-token path, and the default no-penalty path is untouched and allocates nothing new.
What changed
- mlxcel-core/src/sampling.rs: new SamplerState holds a sorted/deduped seen-token set (repetition), per-token counts (frequency/presence), and reusable sparse index/value scratch buffers, so frequency/presence never allocates a full-vocab vector per token.
- sync() absorbs only the appended tail on append-only growth and rebuilds on a shorter or diverged history (speculative rollback, KV trim/restore), so trim/restore needs no explicit reset.
- sample_token_optimized_with_state is the state-aware entry point; sample_token_optimized now delegates to a shared core with state == None, so its behavior is bit-for-bit unchanged.
- The repetition path reuses the shared apply_repetition_penalty_sorted core; frequency/presence promotes to f32 first to match the rebuild path's broadcast-subtract dtype. Both are byte-identical to the rebuild path (verified for f32 and f16), so token selection is unchanged.
- compute_logprobs: a top_k == 0 fast path computes logit - logsumexp instead of materializing the full-vocab log-softmax. It matches the full path within floating-point tolerance (the subtraction order differs by at most 1 ULP, which is OpenAI-compatible) and does not affect token selection. The top_k > 0 path (including the #340 dtype-aware top-k extraction) is unchanged, and every logprob caller funnels through compute_logprobs, so all paths stay mutually consistent.
- server/batch/sequence.rs: an Option<SamplerState> field on SequenceInfo, created lazily and dropped with the sequence (no leak, no global).
- server/batch/scheduler.rs: the two per-row decode sample sites gate on needs_token_history() and pass &mut seq.sampler_state; no-penalty rows take the original path unchanged (and most never leave the #325 fused-batch fast path). The #325 batched fused sampler is untouched.
- mlxcel-core/src/generate.rs: the four CLI decode loops keep a lazily-created local SamplerState, gated on the existing needs_history flag.
Deferred
DRY stays on the rebuild path (its sliding window would need fragile per-step position rebasing); behavior is unchanged. The speculative decoders and commands/generate.rs keep using the rebuild sample_token_optimized (correct, just not yet routed through the incremental state).
No baseline overhead when off
The no-penalty path calls the original sample_token_optimized (selected by the needs_token_history() / needs_history gate), allocates no SamplerState (the field/local stays None), and is byte-identical. Unit tests assert the state stays None for greedy and DRY-only configs.
Test plan
- cargo test --release -p mlxcel-core sampling:: (43 pass, including 9 new: f32 and f16 byte-identical repetition and frequency/presence parity, 40-step greedy parity with all penalties active, no-state-when-off for greedy and DRY-only, sync append/shrink/divergence, and the top_k == 0 fast path versus full log-softmax)
- cargo clippy -p mlxcel-core --lib --tests -- -D warnings
- cargo check -p mlxcel --lib
- cargo fmt --check
- generate with --repetition-penalty / --frequency-penalty versus the pre-change binary, and the server no-overhead-when-off check
Closes #328