Skip to content

Materialize negative-stride MLX Subtensor results#2259

Draft
cetagostini wants to merge 3 commits into
pymc-devs:mainfrom
cetagostini:mlx-subtensor-negative-stride
Draft

Materialize negative-stride MLX Subtensor results#2259
cetagostini wants to merge 3 commits into
pymc-devs:mainfrom
cetagostini:mlx-subtensor-negative-stride

Conversation

@cetagostini

Copy link
Copy Markdown
Contributor

Summary

mx.compile miscompiles an elementwise op fed by a negative-stride array — trailing entries are zeroed (eager / use_compile=False are correct):

import mlx.core as mx, numpy as np
mx.compile(lambda x: 2.0 * x[::-1])(mx.array(np.arange(15, dtype="float32").reshape(5, 3)))
# tail rows are 0.0 instead of the reversed values

The MLX Subtensor dispatch now materializes reversed slices into a contiguous array (mx.contiguous) so downstream compiled kernels see correct data. Only negative steps are affected (positive steps / basic slices are contiguous and untouched).

Why it matters

  • Unblocks Scan gradients over sequences under the full mode="MLX": the backward pass reverses the trace and applies elementwise terms, which previously tripped this bug. The strict xfail added in Add MLX backend dispatch for Scan #2258 (test_scan_grad_over_sequence_default_mode) becomes a normal passing test here.
  • Resolves an existing strict xfail (test_mlx_IncSubtensor_negative_step_slice_grad). Its failure — attributed to mx.compile: assigning an elementwise expression to a negative-strided slice writes only one element ml-explore/mlx#3716 (assigning to a negative-strided slice) — was actually this same negative-stride read feeding the elementwise gradient term; the IncSubtensor write to a reversed slice is independently correct (verified). Un-skipped per the "sweep for mis-attributed xfails" convention.

Dependency

Builds on #2258 (MLX Scan dispatch) and should be reviewed/merged after it. The diff will reduce to just the Subtensor change once #2258 lands on main.

Test plan

  • New tests/link/mlx/test_subtensor.py::test_mlx_negative_step_slice_elemwise (reversed slice + elementwise under mode="MLX", both axes).
  • Un-xfailed test_mlx_IncSubtensor_negative_step_slice_grad and test_scan_grad_over_sequence_default_mode now pass.
  • Full MLX suite green (pytest tests/link/mlx/), ruff/pre-commit clean.

🤖 Generated with assistance from Cursor.

Made with Cursor

Port the JAX scan dispatch to MLX. As MLX has no native general-scan
primitive, the inner fgraph is driven by a Python carry loop that
`mx.compile` unrolls. Because scalar values are not readable while MLX
traces, the (full-sized) recurring buffers are used to infer the number
of steps, falling back to a constant `n_steps` or the sequence length.

Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan
gradients) and non-sequences, mirroring the JAX semantics of recreating
the trace and prepending/truncating to the buffer size.

Gradients over sequences reverse the trace and currently trip a separate
MLX bug (an elementwise op fed by a negative-stride array is miscompiled
under `mx.compile`); this is captured as a strict xfail under the full
`mode="MLX"` and is addressed by a follow-up.

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini marked this pull request as draft June 26, 2026 09:55
@cetagostini cetagostini self-assigned this Jun 26, 2026
cetagostini and others added 2 commits June 26, 2026 13:33
Clarify in the dispatch that the Python carry loop is unrolled into the
graph at trace time (not by `mx.compile`) and that this is a workaround
until MLX exposes a native scan/while primitive (ml-explore/mlx#1441):
it needs a static step count and the graph grows as O(n_steps * inner_ops),
where a real primitive would keep it O(inner_ops).

Co-authored-by: Cursor <cursoragent@cursor.com>
`mx.compile` miscompiles an elementwise op fed by a negative-stride view
(`mx.compile(lambda x: 2.0 * x[::-1])` zeroes the trailing entries; eager
is correct). The MLX Subtensor dispatch now copies reversed slices into a
contiguous array.

This unblocks Scan gradients over sequences (which reverse the trace) under
the full `mode="MLX"`, and resolves the existing strict xfail
`test_mlx_IncSubtensor_negative_step_slice_grad` — whose failure was the same
negative-stride read feeding the elementwise gradient term, not the
IncSubtensor write it was attributed to (ml-explore/mlx#3716).

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini force-pushed the mlx-subtensor-negative-stride branch from 273b91c to 85e3f6d Compare June 26, 2026 10:33

@ricardoV94 ricardoV94 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some edge cases of Scan packaged as reusable tests that you can pass mode to, you should try them

@cetagostini

Copy link
Copy Markdown
Contributor Author

I'll redo this if mlx patch get merge: ml-explore/mlx#3772

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants