Skip to content

perf(sampling): incremental per-sequence penalty state caches#344

Merged
inureyes merged 2 commits into
mainfrom
perf/328-sampler-state-caches
Jun 17, 2026
Merged

perf(sampling): incremental per-sequence penalty state caches#344
inureyes merged 2 commits into
mainfrom
perf/328-sampler-state-caches

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Maintain incremental per-sequence sampler state so the repetition and frequency/presence penalties stop rebuilding sorted sets, count maps, and full-vocabulary penalty vectors on every decode step in long generations, plus a selected-token-only logprob fast path for top_k == 0. The incremental state is purely an optimization: penalty-adjusted greedy sampling produces byte-identical token ids to the rebuild-every-token path, and the default no-penalty path is untouched and allocates nothing new.

What changed

  • mlxcel-core/src/sampling.rs: new SamplerState holds a sorted/deduped seen-token set (repetition), per-token counts (frequency/presence), and reusable sparse index/value scratch buffers, so frequency/presence never allocates a full-vocab vector per token. sync() absorbs only the appended tail on append-only growth and rebuilds on a shorter or diverged history (speculative rollback, KV trim/restore), so trim/restore needs no explicit reset. sample_token_optimized_with_state is the state-aware entry point; sample_token_optimized now delegates to a shared core with state == None, so its behavior is bit-for-bit unchanged.
  • Repetition reuses the extracted apply_repetition_penalty_sorted core; frequency/presence promotes to f32 first to match the rebuild path's broadcast-subtract dtype. Both are byte-identical to the rebuild path (verified for f32 and f16), so token selection is unchanged.
  • compute_logprobs: a top_k == 0 fast path computes logit - logsumexp instead of materializing the full-vocab log-softmax. It matches the full path within floating-point tolerance (the subtraction order differs by at most 1 ULP, which is OpenAI-compatible) and does not affect token selection. The top_k > 0 path (including the fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32) #340 dtype-aware top-k extraction) is unchanged, and every logprob caller funnels through compute_logprobs, so all paths stay mutually consistent.
  • server/batch/sequence.rs: an Option<SamplerState> field on SequenceInfo, created lazily and dropped with the sequence (no leak, no global). server/batch/scheduler.rs: the two per-row decode sample sites gate on needs_token_history() and pass &mut seq.sampler_state; no-penalty rows take the original path unchanged (and most never leave the perf(server): batched decode sampler fast path ([B, vocab] -> [B] fused sampling) #325 fused-batch fast path). The perf(server): batched decode sampler fast path ([B, vocab] -> [B] fused sampling) #325 batched fused sampler is untouched.
  • mlxcel-core/src/generate.rs: the four CLI decode loops keep a lazily-created local SamplerState, gated on the existing needs_history flag.

Deferred

DRY stays on the rebuild path (its sliding window would need fragile per-step position rebasing); behavior is unchanged. The speculative decoders and commands/generate.rs keep using the rebuild sample_token_optimized (correct, just not yet routed through the incremental state).

No baseline overhead when off

The no-penalty path calls the original sample_token_optimized (selected by the needs_token_history() / needs_history gate), allocates no SamplerState (the field/local stays None), and is byte-identical. Unit tests assert the state stays None for greedy and DRY-only configs.

Test plan

  • cargo test --release -p mlxcel-core sampling:: (43 pass, including 9 new: f32 and f16 byte-identical repetition and frequency/presence parity, 40-step greedy parity with all penalties active, no-state-when-off for greedy and DRY-only, sync append/shrink/divergence, and the top_k == 0 fast path versus full log-softmax)
  • cargo clippy -p mlxcel-core --lib --tests -- -D warnings
  • cargo check -p mlxcel --lib
  • cargo fmt --check
  • Full release build plus real-model validation: long-generation penalty parity (server and CLI generate with --repetition-penalty / --frequency-penalty) versus the pre-change binary, and the server no-overhead-when-off check

Closes #328

Maintain per-sequence sampler state so repetition and frequency/presence penalties stop rebuilding from scratch on every decode step in long generations, and add a selected-token-only logprob fast path for top_k == 0.

SamplerState (mlxcel-core/src/sampling.rs) keeps an incrementally maintained sorted, deduped seen-token set for the repetition penalty plus per-token counts and reusable sparse index/value scratch buffers for the frequency/presence penalty, so that penalty no longer allocates a full-vocabulary vector per token. It synchronizes to the token history on entry: append-only growth absorbs just the new tail, while a shorter or diverged history (speculative rollback, KV cache trim/restore) triggers a rebuild, so no explicit reset is needed. The repetition path reuses the shared apply_repetition_penalty_sorted core, and the frequency/presence path promotes to f32 first to match the rebuild path's broadcast-subtract dtype, so both are byte-identical to the rebuild-every-token path and penalty-adjusted greedy sampling selects identical token ids.

State is created lazily only when a repetition/frequency/presence penalty is active and lives on the owning sequence (Option<SamplerState> on SequenceInfo, a local in the CLI decode loops), so the default no-penalty path never allocates it and keeps calling the original sample_token_optimized unchanged. The server scheduler decode steps and the four CxxGenerator decode loops gate on the existing needs_token_history() flag.

compute_logprobs gains a top_k == 0 fast path computing logit minus logsumexp instead of materializing the full-vocabulary log-softmax. It matches the full path within floating-point tolerance (the subtraction order differs by at most 1 ULP, which is OpenAI-compatible) and does not affect token selection; the top_k > 0 path is unchanged.

DRY is intentionally left on the rebuild path (its sliding window would need fragile position rebasing) with behavior unchanged.

Tests: byte-identical repetition and frequency/presence parity over a multi-token sequence (f32 and f16), 40-step greedy parity with all penalties active, no-state-allocated assertions for the no-penalty and DRY-only paths, sync append/shrink/divergence, and the top_k == 0 fast path versus full log-softmax.
@inureyes inureyes added status:review Under review type:performance Performance improvements priority:low Low priority area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) labels Jun 17, 2026
@inureyes

Copy link
Copy Markdown
Member Author

Implementation Review Summary

Intent

Maintain incremental per-sequence SamplerState (sorted/deduped seen-set for repetition, incremental counts + reusable sparse buffers for frequency/presence) and a top_k == 0 selected-token logprob fast path, as a pure optimization that is byte-identical to the rebuild-every-token path and adds zero baseline overhead when penalties/logprobs are off (closes #328).

Findings Addressed

  • No CRITICAL/HIGH/MEDIUM findings. No auto-fixes were required.

Remaining Items (LOW, informational, non-blocking)

  • (LOW) SamplerState::apply_frequency_presence returns copy(logits) (original dtype) in the idx.is_empty() branch, whereas the rebuild path would return a promoted f32 array there. This branch is only reachable when counts is non-empty yet every token id is out of [0, vocab) range, which cannot happen for real token histories, so it is value- and dtype-identical for every realistic input. Optional: promote in that branch to make the comment ("matches the rebuild path's empty early return") strictly accurate.
  • (LOW) The top_k == 0 fast-path parity unit test exercises f32 logits; f16/bf16 are covered for the full/selected-token path (the fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32) #340 tests) but not specifically for the new fast path. Acceptable since logprobs are approximate and token selection is unaffected; a f16/bf16 fast-path parity test would close the gap.

Verification

  • No baseline overhead when off: gate is exactly repetition_penalty != 1.0 || frequency_penalty != 0.0 || presence_penalty != 0.0; no-penalty call sites take the original sample_token_optimized via needs_token_history(); Option<SamplerState> stays None (DRY-only included). sample_token_optimized delegates to the shared core with state == None, bit-for-bit unchanged.
  • Deterministic parity (by inspection): repetition feeds the same sorted/deduped set into the shared apply_repetition_penalty_sorted; frequency/presence computes the identical freq*count + presence and promotes to f32 the same way, so x - 0.0 == x keeps untouched tokens byte-identical. HashMap order is irrelevant (unique keys, non-overlapping put_along_axis).
  • State lifecycle / sync(): sampler_state is mutated only by sample_token_optimized_with_state, which sync()s to token_history first. All token_history transitions are append-one, set-at-prefill, or clear-on-preempt; preemption leaves a stale state but the re-prefilled history is always <= absorbed_len, so sync() rebuilds (shrink) or the tip check catches divergence. Speculative burst uses the rebuild path and never touches sampler_state. Resume builds a fresh SequenceInfo (sampler_state: None).
  • top_k == 0 fast path: logit - logsumexp == log_softmax[s], same dtype regime, returns empty alternatives (matches prior top_k == 0 behavior); the top_k > 0 path (incl. fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32) #340 extraction) is untouched; selection is upstream and unaffected.
  • Scope: perf(server): batched decode sampler fast path ([B, vocab] -> [B] fused sampling) #325 batched_fused_sample / row_supports_fused_batch / FusedSampleParams untouched; only an added import and the 2 per-row fallback sample sites changed. All 10 SequenceInfo literals set sampler_state: None.
  • Conventions: no bare cross-repo refs (#328/fix(sampling): top-k logprobs extraction crashes the server on f16/bf16 models (reads 2-byte logprobs as f32) #340 are valid lablup/mlxcel refs; checker exits 0); no em dashes; no AI attribution; // Used by: present on new shared fns and no existing shared signature changed.
  • Full release build + real-model long-generation penalty parity + no-overhead-when-off check: owned by the orchestrator (pending).

The existing dtype coverage tests (compute_logprobs_selected_token_f16_matches_f32
and bf16 variant) compare the top_k==0 fast path on f16/bf16 input against the
f32 reference value. The new test mirrors compute_logprobs_top_k_zero_fast_path_matches_full_softmax
but with f16 logits, verifying that the fast path (gather + logsumexp) agrees with
the full log_softmax on the same f16 input within f16 precision (~0.01). Closes
the gap flagged in the PR review.
@inureyes

Copy link
Copy Markdown
Member Author

PR Finalization

Tests

Added compute_logprobs_top_k_zero_fast_path_f16_matches_full_softmax to close the gap noted in review. The existing dtype tests compare the top_k==0 fast path on f16/bf16 against the f32 reference value; the new test mirrors compute_logprobs_top_k_zero_fast_path_matches_full_softmax but with f16 logits, verifying the fast path (gather + logsumexp) agrees with the full log_softmax on the same f16 input within f16 precision (tolerance 0.01). 44 sampling tests pass (43 + 1 new).

Docs

No documentation changes. The optimization has no user-facing surface, and there is no changelog or architecture doc that this change naturally belongs in.

Lint / Format

cargo clippy -p mlxcel-core --lib --tests -- -D warnings and cargo fmt --check both pass clean. The Clone derive on SamplerState flagged as LOW by the security review is not exercised by any caller (no containing type derives Clone; no .clone() call site found), but clippy does not warn about it either way so it was left in place to avoid unnecessary churn.

Commit 24f2ffa56 on perf/328-sampler-state-caches.

@inureyes

Copy link
Copy Markdown
Member Author

Real-model validation (orchestrator)

Server validation on M1 Ultra (qwen3-0.6b-4bit, temp=0), exercising both incremental paths:

  • no-overhead-when-off: no-penalty decode 203.8 tok/s, matching the baseline (no regression). The no-penalty path takes the original sample_token_optimized unchanged (SamplerState stays None).
  • repetition penalty (1.3): 201.3 tok/s, deterministic across two runs, output differs from the no-penalty baseline (penalty active). Incremental seen-set path.
  • frequency/presence penalty (0.6): 190.9 tok/s, deterministic across two runs, active. Incremental counts + reusable sparse-buffer path.

Combined with the 44 sampling unit tests (byte-identical incremental-vs-rebuild parity for repetition and frequency/presence in both f32 and f16, 40-step greedy parity with all penalties, no-state-when-off, sync append/shrink/divergence, and the top_k==0 logprob fast path vs full log-softmax in f32 and f16), the acceptance criteria are met. DRY incrementalization is deferred (documented in the PR); the speculative path still uses the rebuild sampler (clean follow-up). Reviewer + security-checker: 0 unresolved CRITICAL/HIGH/MEDIUM.

cargo check -p mlxcel --tests passes, confirming all SequenceInfo literal updates across the 6 test files are complete.

@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 17, 2026
@inureyes inureyes merged commit f360063 into main Jun 17, 2026
5 checks passed
@inureyes inureyes deleted the perf/328-sampler-state-caches branch June 17, 2026 21:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) priority:low Low priority status:done Completed type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

perf(inference): incremental per-sequence state caches for repetition/frequency/DRY sampling

1 participant