perf(server): fuse batched decode sampling into one [B] dispatch #339
Add a [B, vocab] -> [B] fused batched sampler so a batched decode step samples all active rows in one fused_sample dispatch and one host copy, instead of B per-row slice + sample + eval + item_i32 round trips (one eval/sync per row).

mlxcel-core/sampling.rs gains batched_fused_sample plus the FusedSampleParams scalar-param carrier and two gate predicates (config_supports_fused_batch, row_supports_fused_batch). fused_sample already reduces over the last axis (argmax/categorical), so [B, vocab] yields [B] with no C++ change. Greedy (temperature 0 or top_k 1) is byte-identical to the per-row path because argmax over the last axis is independent per row; stochastic sampling differs only in RNG sequencing (the documented batched-vs-B=1 jitter class), not in the sampled distribution. The [B] ids are read from raw bytes, mirroring the existing compute_logprobs extraction, so no astype node is added to the sampling path.

BatchScheduler::execute_batched_decode gains a fast-path gate (batched_decode_fused_params) and dispatch that trigger only when every active row shares the same scalar sampling params and none needs a structured-output mask, a thinking-budget override, a per-token logprobs payload, a history-based penalty, or a token bias. When the gate passes, apply_fused_decode_tokens runs the shared bookkeeping tail (EOS, history, streaming decode, length limit, periodic cache clear, cache offset). The existing per-row loop is preserved exactly as the fallback for structured output, row-specific logprobs, token-bias observability, thinking budgets, and mixed sampling configs. ActiveBatch gains an immutable get accessor for the read-only gate.

Tests (mlxcel-core sampling::): greedy parity vs the per-row batched_sample on synthetic [B, 1, vocab] logits, B=1 no-regression, 2-D input acceptance, and the gate predicates (config compatibility, scalar-param matching, and fallback when a row needs structured output, a thinking override, or per-token logprobs).
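The fast-path gate described above can be illustrated with a minimal sketch. This is not the mlxcel code: `SketchParams`, `SketchRow`, and `fused_params` are hypothetical names, and the field set is an assumption based only on the conditions listed in the commit message. It shows one plausible shape for "return shared params only if every row is plain and bitwise-identical in its scalar sampling config".

```rust
// Hypothetical sketch: type and field names are assumptions, not the
// mlxcel-core definitions of FusedSampleParams or the gate predicates.

/// Scalar sampling parameters that a fused batch must share.
#[derive(Clone, Copy, Debug)]
pub struct SketchParams {
    pub temperature: f32,
    pub top_k: u32,
    pub top_p: f32,
    pub min_p: f32,
}

impl SketchParams {
    /// Bitwise comparison: floats are compared by bit pattern, so the
    /// check is exact and never depends on float equality rules.
    pub fn matches(&self, other: &SketchParams) -> bool {
        self.temperature.to_bits() == other.temperature.to_bits()
            && self.top_k == other.top_k
            && self.top_p.to_bits() == other.top_p.to_bits()
            && self.min_p.to_bits() == other.min_p.to_bits()
    }
}

/// Per-row state that the gate inspects (illustrative flags only).
#[derive(Clone, Copy, Debug)]
pub struct SketchRow {
    pub params: SketchParams,
    pub needs_structured_mask: bool,
    pub has_thinking_override: bool,
    pub wants_logprobs: bool,
    pub has_history_penalty: bool,
    pub has_token_bias: bool,
}

impl SketchRow {
    /// A row with no per-row features enabled.
    pub fn plain(params: SketchParams) -> Self {
        SketchRow {
            params,
            needs_structured_mask: false,
            has_thinking_override: false,
            wants_logprobs: false,
            has_history_penalty: false,
            has_token_bias: false,
        }
    }

    /// True when nothing on this row forces the per-row fallback.
    pub fn supports_fused(&self) -> bool {
        !(self.needs_structured_mask
            || self.has_thinking_override
            || self.wants_logprobs
            || self.has_history_penalty
            || self.has_token_bias)
    }
}

/// Returns the shared params only when every row qualifies and all rows
/// match the first row's params bit for bit; otherwise None (fallback).
pub fn fused_params(rows: &[SketchRow]) -> Option<SketchParams> {
    let first = rows.first()?;
    let all_ok = rows
        .iter()
        .all(|r| r.supports_fused() && r.params.matches(&first.params));
    all_ok.then_some(first.params)
}
```

A read-only gate of this shape leaves the fallback untouched: any single disqualifying row yields `None`, and the caller runs the existing per-row loop.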
Security and performance review: fused batched decode sampler

Scope: resource-safety and performance (internal numeric compute, no network/parsing/auth surface). Reviewed all three changed files.

Result: no CRITICAL, HIGH, or MEDIUM findings. No code changes required.

Findings against the five risk areas:

Routing confirmed:

Optional (LOW, non-blocking):
PR Finalization

Tests

Reviewed all 29 tests in

No genuine coverage gaps found. The

Documentation

No existing doc has a natural home for an internal sampler fast-path note.

Lint / Format

No commits needed. PR is production-ready pending the release build and real-model benchmark.
Real-model validation (orchestrator)

Validated on M1 Ultra (Metal) against

Note: a separate within-binary A/B run (
Summary
Add a `[B, vocab] -> [B]` fused batched decode sampler so the batched decode loop samples all active rows in a single `fused_sample` dispatch and a single host copy, replacing the per-row `slice + sample + eval + item_i32` round trips that incurred one eval/sync per row. The payoff lands on fast small models and B>1 serving, where per-row sync and host scalar extraction are a visible fraction of decode.

What changed
- `src/lib/mlxcel-core/src/sampling.rs`: new `batched_fused_sample(&MlxArray, &FusedSampleParams) -> Vec<i32>` that calls `fused_sample` once over `[B, vocab]` (or `[B, 1, vocab]`) and copies the `[B]` token array to host in one eval. `fused_sample` already reduces over the last axis (argmax/categorical), so `[B, vocab]` yields `[B]` with no C++ change. Added `FusedSampleParams` (a `Copy` scalar-param carrier with bitwise `matches`) and the gate predicates `config_supports_fused_batch` and `row_supports_fused_batch`.
- `src/server/batch/scheduler.rs`: `execute_batched_decode` gains a fast-path gate (`batched_decode_fused_params`) plus dispatch. The gate returns `Some(params)` only when every active row shares the same scalar sampling params and none needs a structured-output mask, a thinking-budget override, a per-token logprobs payload, a history-based penalty, or a token bias. On a hit, `apply_fused_decode_tokens` runs the shared bookkeeping tail (EOS, history, streaming decode, length limit, periodic cache clear, cache offset). The existing per-row loop is preserved exactly as the fallback.
- `src/server/batch/active.rs`: immutable `ActiveBatch::get` accessor for the read-only gate.

Correctness
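The greedy parity claim rests on argmax over the last axis being computed independently for each row. The sketch below shows that property on plain `f32` slices rather than on `MlxArray`; `argmax_row`, `per_row_greedy`, and `batched_greedy` are hypothetical helpers written for illustration, and the lowest-index tie-break is an assumption of this sketch, not a statement about `fused_sample`.

```rust
// Illustrative only: operates on a flat row-major [B, vocab] buffer.

/// Index of the largest value in one row; ties resolve to the lowest index.
pub fn argmax_row(row: &[f32]) -> u32 {
    let mut best = 0usize;
    for (i, &v) in row.iter().enumerate() {
        if v > row[best] {
            best = i;
        }
    }
    best as u32
}

/// Per-row path: slice out each row and reduce it on its own.
pub fn per_row_greedy(logits: &[f32], vocab: usize) -> Vec<u32> {
    assert!(vocab > 0, "vocab must be non-zero");
    (0..logits.len() / vocab)
        .map(|b| argmax_row(&logits[b * vocab..(b + 1) * vocab]))
        .collect()
}

/// Batched path: a single pass over the whole buffer that reduces the
/// last axis, keeping one running best index per row.
pub fn batched_greedy(logits: &[f32], vocab: usize) -> Vec<u32> {
    assert!(vocab > 0, "vocab must be non-zero");
    let batch = logits.len() / vocab;
    let mut best = vec![0u32; batch];
    for (flat, &v) in logits.iter().enumerate().take(batch * vocab) {
        let (b, j) = (flat / vocab, flat % vocab);
        // Only values from row b are ever compared against row b's best.
        if v > logits[b * vocab + best[b] as usize] {
            best[b] = j as u32;
        }
    }
    best
}
```

Because no comparison crosses a row boundary, the batched pass and the per-row loop return the same indices for any input, which is the property the parity unit test asserts on the real sampler.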
- Greedy (`temperature == 0` or `top_k == 1`) output is byte-identical to the per-row path: `argmax` over the last axis is independent per row. A unit test compares `batched_fused_sample` against the per-row `batched_sample` on synthetic `[B, 1, vocab]` logits and asserts equality.
- Stochastic sampling (same `temperature`/`top_k`/`top_p`/`min_p` across rows) keeps the correct per-row distribution; it differs from the per-row path only in random-number sequencing, the documented batched-vs-B=1 jitter class. Any per-row penalty, token bias, structured-output mask, thinking budget, or logprobs request forces the unchanged per-row fallback.
- No `astype` node is added to the sampling path: the `[B]` token ids are read from raw bytes (`fused_sample` returns `uint32`, exact as `i32` for any token id), mirroring the existing `compute_logprobs` extraction of argpartition indices.

Test plan
- `cargo check -p mlxcel-core --lib --tests` (clean)
- `cargo check -p mlxcel --lib` (clean)
- `cargo test --release -p mlxcel-core sampling::` (29 passed, incl. greedy parity, B=1 no-regression, 2-D input, and all gate predicates)
- `cargo clippy -p mlxcel-core --lib --tests -- -D warnings` and `cargo clippy -p mlxcel --lib -- -D warnings` (clean)
- `cargo fmt --check` (clean)
- Release build with `metal,accelerate` (pending: cold MLX C++ link exceeds the implementation watchdog)
- Real-model benchmark on `qwen3-0.6b-4bit` and `qwen3-4b-4bit` at B=2/4/8: greedy output parity, decode tok/s, per-token latency, and no regression at B=1, plus one logprobs/structured fallback run (pending)

Closes #325
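The raw-byte token-id read mentioned under Correctness can be sketched as follows. This is a standalone illustration, not the mlxcel extraction code: `token_ids_from_bytes` is a hypothetical helper, it assumes the host buffer holds native-endian `u32` values, and it uses a checked conversion where the PR relies on token ids always fitting in `i32`.

```rust
/// Decode a host buffer of native-endian u32 token ids into i32 values
/// without any dtype-cast step on the array side.
/// Returns None if the length is not a multiple of 4 or an id exceeds i32::MAX.
pub fn token_ids_from_bytes(bytes: &[u8]) -> Option<Vec<i32>> {
    if bytes.len() % 4 != 0 {
        return None;
    }
    bytes
        .chunks_exact(4)
        .map(|c| {
            let id = u32::from_ne_bytes([c[0], c[1], c[2], c[3]]);
            // Checked conversion: any realistic vocabulary index fits in i32.
            i32::try_from(id).ok()
        })
        .collect()
}
```

Reading all `[B]` ids from one buffer matches the single-host-copy design: one copy out of the array, then a cheap host-side decode.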