docs(benchmarks): Gemma3n decode profile, no fusion justified (#329) #345
Merged
Profile-first deliverable for #329. Measured Gemma3n decode on M1 Ultra (e2b-4bit 83 tok/s, e4b-4bit 63 tok/s) with mlxcel-bench-decode and the mlxcel-gpu-profiling hooks, and compared against mlx-vlm, the only Python runtime that loads these multimodal checkpoints (mlx-lm cannot).

Conclusion: no compiled fusion is justified now. Decode is ~92% GPU-bound, and the Rust graph-build step, where an FFI-crossing fusion would land, is ~7% and fully hidden behind async_eval, so a fusion that only collapses FFI crossings would regress, as MLXCEL_FUSED_QK_NORM did on Qwen3. The worthwhile Gemma3n fusions already shipped in #60 (fused MLP bridge, compiled gelu_topk/GeGLU, stacked AltUp) and are active on M1 Ultra. mlxcel already leads mlx-vlm by 1.20-1.23x on these checkpoints.

The decode-time overhead above pure weight streaming is dominated by command-buffer dispatch gaps, not small-kernel compute: raising MLX_MAX_OPS_PER_BUFFER toward 1000 recovers +11-13% with no code change. That scheduling knob must be hardware-gated, since M5 regresses with larger buffers; it is the recommended follow-up and belongs in its own issue, not in a fusion.

Adds docs/benchmark_results/gemma3n-decode-profile.md with the full numbers, per-layer op analysis, text-vs-VLM comparison, and reproduce commands; links it from docs/benchmarks.md.
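To illustrate why a fusion that only removes FFI crossings cannot speed up decode here, the following is a minimal sketch of a two-stage pipeline model. It is an assumption for illustration, not mlxcel's scheduler: it treats graph build for the next token as fully overlapped with GPU execution of the current one, and derives the stage costs from the 83 tok/s, ~92% and ~7% figures quoted above. The function name `step_time_ms` and the halving factor are hypothetical.

```python
# Simplified two-stage pipeline model of one decode step (an assumption,
# not mlxcel's actual scheduler): the CPU builds the graph for token t+1
# while the GPU executes token t, so the steady-state step time is set by
# the slower of the two stages.

def step_time_ms(gpu_ms: float, build_ms: float, overlapped: bool = True) -> float:
    """Steady-state time per token for the given stage costs."""
    return max(gpu_ms, build_ms) if overlapped else gpu_ms + build_ms

# Illustrative split derived from the reported e2b-4bit figures:
# 83 tok/s is about 12.05 ms/token, ~92% GPU and ~7% graph build.
token_ms = 1000.0 / 83.0
gpu_ms = 0.92 * token_ms
build_ms = 0.07 * token_ms

baseline = step_time_ms(gpu_ms, build_ms)
# Hypothetical fusion that halves graph-build cost and leaves GPU work unchanged.
fused = step_time_ms(gpu_ms, build_ms * 0.5)

print(f"baseline: {baseline:.2f} ms/token, fused build: {fused:.2f} ms/token")
print(f"gain from the fusion: {baseline - fused:.2f} ms/token")
```

Under this model the gain is zero because the GPU stage is the bottleneck; any extra GPU-side cost introduced by a fused kernel would then show up directly as a regression.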
The no-fusion decode-profile finding is unchanged; these are accuracy fixes to its supporting claims, verified against the code on this branch. The fused MLP bridge (gemma3n_mlp_forward) is gated on regular_weight(), so it is bf16-only and not active on the 4-bit e2b/e4b checkpoints profiled here; the op-histogram QuantizedMatmul nodes are the unfused gate/up/down. Extending the bridge to quantized weights would only collapse the FFI crossings the profile already shows are hidden, so the conclusion stands. The compiled gelu_topk and GeGLU activation kernels predate #60 (added 2026-04-02 and 2026-04-24) and are reused by it; #60 itself shipped the fused MLP bridge and stacked AltUp. Corrected the three spots that attributed the activation kernels to #60 or implied the bridge was active on the 4-bit path.
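The gating described above can be pictured with a small sketch. Everything here is hypothetical: the `Weight` class, the `mlp_path` helper, and the convention that `regular_weight()` returns nothing for quantized weights are assumptions chosen to mirror the description, not the signatures on this branch.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Weight:
    """Hypothetical stand-in for one linear layer's weight."""
    quantized: bool

    def regular_weight(self) -> Optional["Weight"]:
        # Assumed behaviour: only non-quantized (e.g. bf16) weights qualify.
        return None if self.quantized else self


def mlp_path(gate: Weight, up: Weight, down: Weight) -> str:
    """Return which MLP path a layer would take under the assumed gate."""
    if all(w.regular_weight() is not None for w in (gate, up, down)):
        return "fused_mlp_bridge"
    # Quantized checkpoints fall back to three separate matmuls.
    return "unfused_gate_up_down"


bf16 = Weight(quantized=False)
q4 = Weight(quantized=True)
print(mlp_path(bf16, bf16, bf16))
print(mlp_path(q4, q4, q4))
```

With this gate, a 4-bit checkpoint such as e2b or e4b always takes the unfused path, which is consistent with the op histogram showing separate QuantizedMatmul nodes.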
Summary
Profile-first investigation for #329: profile Gemma3n decode, then add a compiled fusion only if the profile justifies one. It does not, so this PR ships the profiling finding as documentation and adds no kernel. A documented "no fusion justified" outcome is the result the issue explicitly calls for.
Findings (Apple M1 Ultra, MLX 0.31.2, mlx-vlm 0.4.4)
- [...] forward), where an FFI-crossing fusion would land, is ~7% per token and fully hidden behind async_eval. A fusion that only collapses FFI crossings would regress here, the same outcome as MLXCEL_FUSED_QK_NORM on Qwen3 (1-3.4% slower on M1 Ultra).
- [...] use_fused_decode_path() is true on non-NA hardware).
- [...] MLX_MAX_OPS_PER_BUFFER toward 1000 recovers +11-13% with no code (e2b 82.7 -> 93.2, e4b 63.0 -> 70.0). That is a scheduling knob, not a kernel, and must be hardware-gated because M5 regresses with larger buffers, so it belongs in a separate follow-up issue.
What changed
- docs/benchmark_results/gemma3n-decode-profile.md: full numbers, pipeline split, one-decode-token op histogram, pure-GEMV streaming floor, command-buffer sweep, text-vs-VLM table, and reproduce commands.
- [...] docs/benchmarks.md.
Test plan
- [...] mlxcel-bench-decode.
- [...] MLXCEL_PROFILE_PIPELINE_DETAIL; op histogram via MLXCEL_EXPORT_DECODE_DOT.
- [...] mlx_vlm.stream_generate (wall-clock).
Closes #329
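As a sanity check on the command-buffer numbers in the Findings, and as a rough picture of what a hardware-gated default could look like in the follow-up, here is a minimal sketch. The throughput pairs are the ones reported above. The allow-list `LARGE_BUFFER_CHIPS`, the helper `apply_ops_per_buffer_default`, and the chip-name strings are hypothetical, and the sketch assumes the variable is set before MLX initializes its device.

```python
import os
from collections.abc import MutableMapping


def uplift_percent(before_tok_s: float, after_tok_s: float) -> float:
    """Relative throughput gain in percent."""
    return (after_tok_s / before_tok_s - 1.0) * 100.0


# Reported decode throughput (tok/s) before and after raising
# MLX_MAX_OPS_PER_BUFFER toward 1000 on M1 Ultra.
measured = {"e2b-4bit": (82.7, 93.2), "e4b-4bit": (63.0, 70.0)}
for name, (before, after) in measured.items():
    print(f"{name}: {before} -> {after} tok/s (+{uplift_percent(before, after):.1f}%)")

# Hypothetical allow-list: only chips where larger buffers were measured to help.
LARGE_BUFFER_CHIPS = {"Apple M1 Ultra"}


def apply_ops_per_buffer_default(
    chip: str,
    env: MutableMapping[str, str] = os.environ,
    value: int = 1000,
) -> bool:
    """Set MLX_MAX_OPS_PER_BUFFER only on allow-listed chips.

    An explicit user setting is never overridden. Returns True if the
    default was applied.
    """
    if chip not in LARGE_BUFFER_CHIPS or "MLX_MAX_OPS_PER_BUFFER" in env:
        return False
    env["MLX_MAX_OPS_PER_BUFFER"] = str(value)
    return True
```

The two measured pairs work out to +12.7% and +11.1%, matching the +11-13% range. Keeping the knob behind an allow-list rather than a deny-list means a chip that has not been measured, or one known to regress such as M5, keeps the stock buffer size.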