Add MLX backend dispatch for Scan#2258
Draft
cetagostini wants to merge 2 commits into
Draft
Conversation
Port the JAX scan dispatch to MLX. As MLX has no native general-scan primitive, the inner fgraph is driven by a Python carry loop that `mx.compile` unrolls. Because scalar values are not readable while MLX traces, the (full-sized) recurring buffers are used to infer the number of steps, falling back to a constant `n_steps` or the sequence length. Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan gradients) and non-sequences, mirroring the JAX semantics of recreating the trace and prepending/truncating to the buffer size. Gradients over sequences reverse the trace and currently trip a separate MLX bug (an elementwise op fed by a negative-stride array is miscompiled under `mx.compile`); this is captured as a strict xfail under the full `mode="MLX"` and is addressed by a follow-up. Co-authored-by: Cursor <cursoragent@cursor.com>
3 tasks
Clarify in the dispatch that the Python carry loop is unrolled into the graph at trace time (not by `mx.compile`) and that this is a workaround until MLX exposes a native scan/while primitive (ml-explore/mlx#1441): it needs a static step count and the graph grows as O(n_steps * inner_ops), where a real primitive would keep it O(inner_ops). Co-authored-by: Cursor <cursoragent@cursor.com>
Contributor
Author
|
@ricardoV94 what do you think about this? |
Contributor
Author
|
If this get merge, I'll redo this: ml-explore/mlx#3772 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds
pytensor/link/mlx/dispatch/scan.py, registeringmlx_funcify(Scan)so compiled loops (RNNs/recurrences, autoregressive generation, cumulative scans) can lower to the MLX backend. Previously anyScanraisedNotImplementedError: No MLX conversion for the given Op: Scan{...}(JAX and Numba ship a scan dispatch; MLX and PyTorch did not).It ports the JAX dispatch (
link/jax/dispatch/scan.py). MLX has no native general-scan/while primitive to target, so JAX'slax.scanis replaced by a plain Python carry loop that is unrolled into the graph at trace time (sincen_stepsis statically known);mx.compilethen compiles the resulting flat graph. The same trace-recreation semantics are kept: per-step values are stacked, then initial states are prepended / traces truncated to the buffer size.This is a deliberate trade-off, not a primitive we're missing in PyTensor — it differs from both other backends:
lax.scan(one XLAWhile); the inner graph appears once.while i < n_stepsloop and JIT-compiles it; again the inner graph appears once andn_stepsis a runtime value.n_stepstimes. Consequences:n_stepsmust be statically known, and graph size / compile time grow linearly with it. (The no-compile path runs the same unrolled ops eagerly, confirming the unrolling happens at trace time rather than insidemx.compile.)MLX-specific design notes
mx.compile, array shapes are concrete but scalar values are not readable.n_stepsis therefore taken from a constantn_steps, else inferred from the (full-sized) recurring buffer shapes — which is sound because the MLX linker already excludesscan_reduce_trace_prealloc— else from the sequence length..at[].set()in MLX, and in-place item assignment aliases buffers undermx.compile, so MIT-MOT writes use a functional scatter-add of the delta (buffer.at[idx].add(vals - buffer[idx])), mirroring theAdvancedIncSubtensordispatch in the same backend.Reproducer (now works)
Known limitation (separate bug, follow-up)
Gradients over sequences reverse the trace, which trips a pre-existing bare-MLX bug:
mx.compilemiscompiles an elementwise op fed by a negative-stride array, e.g.This is unrelated to the Scan dispatch (the dispatch logic is correct under the base MLX optimizer query, verified by
test_scan_grad_over_sequence). It is captured here as astrictxfail under the fullmode="MLX"(test_scan_grad_over_sequence_default_mode) and fixed by a follow-up that materializes negative-stride Subtensor results in the MLX backend.Test plan
tests/link/mlx/test_scan.py— SIT-SOT/MIT-SOT (incl. views), sequences + non-seq RNN, multiple/combined recurring states, NIT-SOT-only, dtype preservation,n_steps=0, while-scan raises, gradients; plus the strict xfail above.pytest tests/link/mlx/), ruff/pre-commit clean.🤖 Generated with assistance from Cursor.