Skip to content

perf(server): fuse batched decode sampling into one [B] dispatch#339

Merged
inureyes merged 1 commit into
mainfrom
perf/325-batched-decode-sampler
Jun 17, 2026
Merged

perf(server): fuse batched decode sampling into one [B] dispatch#339
inureyes merged 1 commit into
mainfrom
perf/325-batched-decode-sampler

Conversation

@inureyes

Copy link
Copy Markdown
Member

Summary

Add a [B, vocab] -> [B] fused batched decode sampler so the batched decode loop samples all active rows in a single fused_sample dispatch and a single host copy, replacing the per-row slice + sample + eval + item_i32 round trips that incurred one eval/sync per row. The payoff lands on fast small models and B>1 serving, where per-row sync and host scalar extraction are a visible fraction of decode.

What changed

  • src/lib/mlxcel-core/src/sampling.rs: new batched_fused_sample(&MlxArray, &FusedSampleParams) -> Vec<i32> that calls fused_sample once over [B, vocab] (or [B, 1, vocab]) and copies the [B] token array to host in one eval. fused_sample already reduces over the last axis (argmax/categorical), so [B, vocab] yields [B] with no C++ change. Added FusedSampleParams (a Copy scalar-param carrier with bitwise matches) and the gate predicates config_supports_fused_batch and row_supports_fused_batch.
  • src/server/batch/scheduler.rs: execute_batched_decode gains a fast-path gate (batched_decode_fused_params) plus dispatch. The gate returns Some(params) only when every active row shares the same scalar sampling params and none needs a structured-output mask, a thinking-budget override, a per-token logprobs payload, a history-based penalty, or a token bias. On a hit, apply_fused_decode_tokens runs the shared bookkeeping tail (EOS, history, streaming decode, length limit, periodic cache clear, cache offset). The existing per-row loop is preserved exactly as the fallback.
  • src/server/batch/active.rs: immutable ActiveBatch::get accessor for the read-only gate.

Correctness

  • Greedy (temperature == 0 or top_k == 1) output is byte-identical to the per-row path: argmax over the last axis is independent per row. A unit test compares batched_fused_sample against the per-row batched_sample on synthetic [B, 1, vocab] logits and asserts equality.
  • Stochastic sampling (shared temperature/top_k/top_p/min_p across rows) keeps the correct per-row distribution; it differs from the per-row path only in random-number sequencing, the documented batched-vs-B=1 jitter class. Any per-row penalty, token bias, structured-output mask, thinking budget, or logprobs request forces the unchanged per-row fallback.
  • No astype node is added to the sampling path: the [B] token ids are read from raw bytes (fused_sample returns uint32, exact as i32 for any token id), mirroring the existing compute_logprobs extraction of argpartition indices.

Test plan

  • cargo check -p mlxcel-core --lib --tests (clean)
  • cargo check -p mlxcel --lib (clean)
  • cargo test --release -p mlxcel-core sampling:: (29 passed, incl. greedy parity, B=1 no-regression, 2-D input, and all gate predicates)
  • cargo clippy -p mlxcel-core --lib --tests -- -D warnings and cargo clippy -p mlxcel --lib -- -D warnings (clean)
  • cargo fmt --check (clean)
  • Release build of both binaries with metal,accelerate (pending: cold MLX C++ link exceeds the implementation watchdog)
  • Real-model server benchmark on qwen3-0.6b-4bit and qwen3-4b-4bit at B=2/4/8: greedy output parity, decode tok/s, per-token latency, and no regression at B=1, plus one logprobs/structured fallback run (pending)

Closes #325

Add a [B, vocab] -> [B] fused batched sampler so a batched decode step samples all active rows in one fused_sample dispatch and one host copy, instead of B per-row slice + sample + eval + item_i32 round trips (one eval/sync per row).

mlxcel-core/sampling.rs gains batched_fused_sample plus the FusedSampleParams scalar-param carrier and two gate predicates (config_supports_fused_batch, row_supports_fused_batch). fused_sample already reduces over the last axis (argmax/categorical), so [B, vocab] yields [B] with no C++ change. Greedy (temperature 0 or top_k 1) is byte-identical to the per-row path because argmax over the last axis is independent per row; stochastic sampling differs only in RNG sequencing (the documented batched-vs-B=1 jitter class), not in the sampled distribution. The [B] ids are read from raw bytes, mirroring the existing compute_logprobs extraction, so no astype node is added to the sampling path.

BatchScheduler::execute_batched_decode gains a fast-path gate (batched_decode_fused_params) and dispatch that trigger only when every active row shares the same scalar sampling params and none needs a structured-output mask, a thinking-budget override, a per-token logprobs payload, a history-based penalty, or a token bias. When the gate passes, apply_fused_decode_tokens runs the shared bookkeeping tail (EOS, history, streaming decode, length limit, periodic cache clear, cache offset). The existing per-row loop is preserved exactly as the fallback for structured output, row-specific logprobs, token-bias observability, thinking budgets, and mixed sampling configs.

ActiveBatch gains an immutable get accessor for the read-only gate.

Tests (mlxcel-core sampling::): greedy parity vs the per-row batched_sample on synthetic [B, 1, vocab] logits, B=1 no-regression, 2-D input acceptance, and the gate predicates (config compatibility, scalar-param matching, and fallback when a row needs structured output, a thinking override, or per-token logprobs).
@inureyes inureyes added status:review Under review type:performance Performance improvements priority:high High priority area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) labels Jun 17, 2026
@inureyes

Copy link
Copy Markdown
Member Author

Security and performance review: fused batched decode sampler

Scope: resource-safety and performance (internal numeric compute, no network/parsing/auth surface). Reviewed all three changed files (sampling.rs, scheduler.rs, active.rs) plus the fused_sample / slice_last_logits / array_to_raw_bytes FFI they depend on.

Result: no CRITICAL, HIGH, or MEDIUM findings. No code changes required.

Findings against the five risk areas:

  1. One eval/sync per step (B>=2): the fast path builds the [B, vocab] -> [B] graph lazily (slice_last_logits then fused_sample) and forces exactly one evaluation in array_to_raw_bytes. The gate batched_decode_fused_params is host-only HashMap reads with no eval. This collapses B per-row eval + item_i32 syncs into one. No extra evals or host round trips are added.
  2. Allocation: the hot path allocates one [B] token array, a 4*B-byte Vec<u8>, and a Vec<i32> of length B. No full-vocab host vector. The gate carries Copy FusedSampleParams rather than cloning SamplingConfig and its penalty Vecs. This is strictly less per-step allocation than the per-row path it replaces, which did a vocab-sized copy per row.
  3. Safety: chunks_exact(4) is panic-safe; fused_sample returns uint32, so the byte buffer is exactly 4*B with no remainder, and i32::from_ne_bytes([c0..c3]) is always in bounds. No unwrap/expect on the path; a vanished sequence uses ? in the gate and match -> continue in the apply phase. tokens[i] is provably in bounds because tokens.len() == B == seq_ids.len() (logits batch dim equals the input batch dim). The debug_assert_eq! documents that invariant and is compiled out in release.
  4. Concurrency: ActiveBatch::get (&self) returns Copy params and the borrow ends before the separate &mut self apply_fused_decode_tokens call. The scheduler decode step is single-threaded between check and use, and the apply phase re-guards with get_mut -> None -> continue. No aliasing and no TOCTOU.
  5. Output parity: greedy (temperature == 0 or top_k == 1) is byte-identical (per-row-independent argmax over the last axis; covered by a unit test). Stochastic sampling differs only in RNG sequencing, the documented batched-vs-B=1 jitter class. seed_rng_if_needed runs once at prefill, not per decode step, so the per-row path does not reseed per step either. The gate excludes every per-row obligation (structured-output mask, thinking-budget override, logprobs payload, token bias, history-based penalties), and apply_fused_decode_tokens matches the per-row loop tail line for line.

Routing confirmed: execute_decode_step sends B <= 1 and non-batching models to decode_single_step, so B=1 never enters the new path.

Optional (LOW, non-blocking): apply_fused_decode_tokens could index with tokens.get(i) for defense in depth, though the tokens.len() == seq_ids.len() invariant already rules out a panic.

@inureyes

Copy link
Copy Markdown
Member Author

PR Finalization

Tests

Reviewed all 29 tests in sampling::tests. The 10 tests added in this PR cover the new public surface directly:

  • batched_fused_sample: B=4 with [B, 1, V] (greedy parity against per-row path), B=1 no-regression, and 2D [B, V] input.
  • config_supports_fused_batch: gate-on for greedy/default/temperature configs; gate-off for repetition, frequency, presence, DRY penalties, and token bias separately.
  • FusedSampleParams::matches: self-identity plus per-field mismatch for all four scalar fields (temperature, top_k, top_p, min_p).
  • row_supports_fused_batch: gate-on for a plain greedy row; gate-off for each per-row obligation (structured-output mask, thinking-budget override, logprobs payload); gate-off for an incompatible penalty config.

No genuine coverage gaps found. The token_ids_to_host and FusedSampleParams::from_config helpers are exercised through the batched_fused_sample tests. The batched_decode_fused_params and apply_fused_decode_tokens scheduler methods are not unit-testable without a full BatchScheduler and are covered by the pure-function tests above.

Documentation

No existing doc has a natural home for an internal sampler fast-path note. CONTINUOUS_BATCHING.md covers scheduler behavior at the user-configuration level; architecture.md is higher-level. Skipping doc changes.

Lint / Format

  • cargo fmt --check: clean
  • cargo clippy -p mlxcel-core --lib --tests -- -D warnings: clean
  • cargo clippy -p mlxcel --lib -- -D warnings: clean
  • cargo test --release -p mlxcel-core sampling::: 29 passed

No commits needed. PR is production-ready pending the release build and real-model benchmark.

@inureyes

Copy link
Copy Markdown
Member Author

Real-model validation (orchestrator)

Validated on M1 Ultra (Metal) against qwen3-0.6b-4bit with --parallel 8, n_predict=96, temperature=0 (greedy):

  • Greedy parity (PASS): B=4 and B=8 fused fast-path output is byte-identical to each other and to the B=1 single-row reference. B=1 routes to decode_single_step and never enters the fused path, so this is a true fused-vs-reference comparison end to end.
  • Throughput: B=1 206.0 tok/s (recheck 192.8, no regression); B=4 aggregate 322.2 tok/s; B=8 aggregate 614.6 tok/s. Scales cleanly with batch size.
  • Reviewer, security-checker, and finalizer: 0 CRITICAL/HIGH/MEDIUM, no fixes required.

Note: a separate within-binary A/B run (logprobs enabled, forcing the per-row fallback) surfaced a pre-existing server crash in compute_logprobs top-k extraction (sampling.rs:722, commit 892b24288a, 2026-04-02), unrelated to this PR. It reads f16/bf16 logprobs as 4-byte f32 and overruns the byte slice (range end index 12 out of range for slice of length 10) on bf16 models. This PR's gate correctly routes logprobs requests to that fallback; #325 does not touch compute_logprobs. Filed as a separate issue.

@inureyes inureyes added status:done Completed and removed status:review Under review labels Jun 17, 2026
@inureyes inureyes merged commit 45a71d1 into main Jun 17, 2026
5 checks passed
@inureyes inureyes deleted the perf/325-batched-decode-sampler branch June 17, 2026 17:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:core mlxcel-core: MLX FFI, primitives, KV cache, layers area:inference Generation, sampling, decoding (incl. speculative, DRY) priority:high High priority status:done Completed type:performance Performance improvements

Projects

None yet

Development

Successfully merging this pull request may close these issues.

perf(server): batched decode sampler fast path ([B, vocab] -> [B] fused sampling)

1 participant