Materialize negative-stride MLX Subtensor results #2259
Draft
cetagostini wants to merge 3 commits into
Port the JAX scan dispatch to MLX. As MLX has no native general-scan primitive, the inner fgraph is driven by a Python carry loop that `mx.compile` unrolls. Because scalar values are not readable while MLX traces, the (full-sized) recurring buffers are used to infer the number of steps, falling back to a constant `n_steps` or the sequence length. Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan gradients) and non-sequences, mirroring the JAX semantics of recreating the trace and prepending/truncating to the buffer size. Gradients over sequences reverse the trace and currently trip a separate MLX bug (an elementwise op fed by a negative-stride array is miscompiled under `mx.compile`); this is captured as a strict xfail under the full `mode="MLX"` and is addressed by a follow-up. Co-authored-by: Cursor <cursoragent@cursor.com>
Clarify in the dispatch that the Python carry loop is unrolled into the graph at trace time (not by `mx.compile`) and that this is a workaround until MLX exposes a native scan/while primitive (ml-explore/mlx#1441): it needs a static step count and the graph grows as O(n_steps * inner_ops), where a real primitive would keep it O(inner_ops). Co-authored-by: Cursor <cursoragent@cursor.com>
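To make the O(n_steps * inner_ops) growth concrete, here is a minimal pure-Python sketch. It uses neither MLX nor PyTensor: `Node`, `inner_step`, `unrolled_scan`, and `graph_size` are hypothetical names invented for illustration, and the two-op inner step is an assumption. It only shows that a Python carry loop over symbolic values records one copy of the inner ops per step, which is why the step count has to be static at trace time.

```python
class Node:
    """A symbolic value in a toy traced graph (hypothetical, not an MLX type)."""

    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)

    def __add__(self, other):
        return Node("add", (self, other))

    def __mul__(self, other):
        return Node("mul", (self, other))


def count_ops(root):
    """Count the non-leaf nodes reachable from `root`."""
    seen, stack, n_ops = set(), [root], 0
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        if node.inputs:  # leaves (graph inputs) have no inputs and are not ops
            n_ops += 1
        stack.extend(node.inputs)
    return n_ops


def inner_step(carry, x_t, w):
    """Toy inner function with exactly two ops: one mul and one add."""
    return carry * w + x_t


def unrolled_scan(init, xs, w):
    """Python carry loop: every iteration appends new nodes to the graph."""
    carry = init
    for x_t in xs:  # len(xs) must be known while tracing
        carry = inner_step(carry, x_t, w)
    return carry


def graph_size(n_steps):
    """Number of ops recorded when the loop is unrolled for `n_steps` steps."""
    init, w = Node("input"), Node("input")
    xs = [Node("input") for _ in range(n_steps)]
    return count_ops(unrolled_scan(init, xs, w))
```

With two inner ops, `graph_size(n)` is `2 * n`; a native scan primitive would instead keep a single copy of the inner graph regardless of `n`.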
`mx.compile` miscompiles an elementwise op fed by a negative-stride view (`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the trailing entries; eager is correct). The MLX Subtensor dispatch now copies reversed slices into a contiguous array. This unblocks Scan gradients over sequences (which reverse the trace) under the full `mode="MLX"`, and resolves the existing strict xfail `test_mlx_IncSubtensor_negative_step_slice_grad` — whose failure was the same negative-stride read feeding the elementwise gradient term, not the IncSubtensor write it was attributed to (ml-explore/mlx#3716). Co-authored-by: Cursor <cursoragent@cursor.com>
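The idea of materializing a reversed slice can be illustrated with a NumPy analogy. This is not the MLX dispatch code (which, per the PR description, uses `mx.contiguous`), and it does not reproduce the `mx.compile` bug; `has_negative_stride` and `materialize_if_reversed` are hypothetical helper names. The sketch only shows what a negative-stride view is and how copying it yields ordinary contiguous memory for the following elementwise op.

```python
import numpy as np


def has_negative_stride(arr: np.ndarray) -> bool:
    """Return True if any axis of `arr` is walked backwards in memory."""
    return any(stride < 0 for stride in arr.strides)


def materialize_if_reversed(arr: np.ndarray) -> np.ndarray:
    """Copy negative-stride views into fresh C-contiguous memory.

    Hypothetical helper: arrays without a negative stride are returned
    unchanged, so ordinary slices pay no copy cost.
    """
    if has_negative_stride(arr):
        return np.ascontiguousarray(arr)
    return arr


x = np.arange(6, dtype=np.float32)
reversed_view = x[::-1]  # shares memory with x; its stride is negative
materialized = materialize_if_reversed(reversed_view)  # independent copy
doubled = 2.0 * materialized  # elementwise op now reads plain contiguous memory

grid = np.arange(6, dtype=np.float32).reshape(2, 3)
both_axes = materialize_if_reversed(grid[::-1, ::-1])  # reversal on both axes
```

The copy costs one extra pass over the reversed data, which is the trade-off the change accepts for negative steps only.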
273b91c to 85e3f6d
ricardoV94 (Member) reviewed on Jun 26, 2026 and left a comment:
Some Scan edge cases are packaged as reusable tests that accept a `mode` argument; you should try them.
Contributor, Author:
I'll redo this if the MLX patch gets merged: ml-explore/mlx#3772
Summary

`mx.compile` miscompiles an elementwise op fed by a negative-stride array: trailing entries are zeroed (eager / `use_compile=False` are correct).

The MLX `Subtensor` dispatch now materializes reversed slices into a contiguous array (`mx.contiguous`) so downstream compiled kernels see correct data. Only negative steps are affected (positive steps / basic slices are contiguous and untouched).

Why it matters
- Unblocks Scan gradients over sequences under the full `mode="MLX"`: the backward pass reverses the trace and applies elementwise terms, which previously tripped this bug. The strict xfail added in #2258, "Add MLX backend dispatch for Scan" (`test_scan_grad_over_sequence_default_mode`), becomes a normal passing test here.
- Resolves the existing strict xfail (`test_mlx_IncSubtensor_negative_step_slice_grad`). Its failure, attributed to ml-explore/mlx#3716 ("mx.compile: assigning an elementwise expression to a negative-strided slice writes only one element"), was actually this same negative-stride read feeding the elementwise gradient term; the IncSubtensor write to a reversed slice is independently correct (verified). Un-skipped per the "sweep for mis-attributed xfails" convention.

Dependency
Builds on #2258 (MLX Scan dispatch) and should be reviewed/merged after it. The diff will reduce to just the Subtensor change once #2258 lands on `main`.

Test plan
- `tests/link/mlx/test_subtensor.py::test_mlx_negative_step_slice_elemwise` (reversed slice + elementwise under `mode="MLX"`, both axes).
- `test_mlx_IncSubtensor_negative_step_slice_grad` and `test_scan_grad_over_sequence_default_mode` now pass.
- MLX test suite (`pytest tests/link/mlx/`) and ruff/pre-commit clean.

🤖 Generated with assistance from Cursor.