Skip to content

Add MLX backend dispatch for Scan#2258

Draft
cetagostini wants to merge 2 commits into
pymc-devs:mainfrom
cetagostini:mlx-scan-dispatch
Draft

Add MLX backend dispatch for Scan#2258
cetagostini wants to merge 2 commits into
pymc-devs:mainfrom
cetagostini:mlx-scan-dispatch

Conversation

@cetagostini

@cetagostini cetagostini commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Summary

Adds pytensor/link/mlx/dispatch/scan.py, registering mlx_funcify(Scan) so compiled loops (RNNs/recurrences, autoregressive generation, cumulative scans) can lower to the MLX backend. Previously any Scan raised NotImplementedError: No MLX conversion for the given Op: Scan{...} (JAX and Numba ship a scan dispatch; MLX and PyTorch did not).

It ports the JAX dispatch (link/jax/dispatch/scan.py). MLX has no native general-scan/while primitive to target, so JAX's lax.scan is replaced by a plain Python carry loop that is unrolled into the graph at trace time (since n_steps is statically known); mx.compile then compiles the resulting flat graph. The same trace-recreation semantics are kept: per-step values are stacked, then initial states are prepended / traces truncated to the buffer size.

This is a deliberate trade-off, not a primitive we're missing in PyTensor — it differs from both other backends:

  • JAX emits a single lax.scan (one XLA While); the inner graph appears once.
  • Numba code-generates a real while i < n_steps loop and JIT-compiles it; again the inner graph appears once and n_steps is a runtime value.
  • MLX has no loop op to emit and builds its graph by tracing, so the host-side Python loop unrolls — the inner graph is emitted n_steps times. Consequences: n_steps must be statically known, and graph size / compile time grow linearly with it. (The no-compile path runs the same unrolled ops eagerly, confirming the unrolling happens at trace time rather than inside mx.compile.)

MLX-specific design notes

  • Static step count. Under mx.compile, array shapes are concrete but scalar values are not readable. n_steps is therefore taken from a constant n_steps, else inferred from the (full-sized) recurring buffer shapes — which is sound because the MLX linker already excludes scan_reduce_trace_prealloc — else from the sequence length.
  • No .at[].set() in MLX, and in-place item assignment aliases buffers under mx.compile, so MIT-MOT writes use a functional scatter-add of the delta (buffer.at[idx].add(vals - buffer[idx])), mirroring the AdvancedIncSubtensor dispatch in the same backend.

Reproducer (now works)

import numpy as np, pytensor, pytensor.tensor as pt
pytensor.config.floatX = "float32"
w = pt.matrix("w"); init = pt.matrix("init")
out, _ = pytensor.scan(lambda p, w: pt.tanh(p @ w),
                       outputs_info=[init], non_sequences=[w], n_steps=3, return_updates=True)
pytensor.function([init, w], out, mode="MLX")  # previously raised NotImplementedError

Known limitation (separate bug, follow-up)

Gradients over sequences reverse the trace, which trips a pre-existing bare-MLX bug: mx.compile miscompiles an elementwise op fed by a negative-stride array, e.g.

import mlx.core as mx, numpy as np
mx.compile(lambda x: 2.0 * x[::-1])(mx.array(np.random.randn(5, 3).astype("float32")))  # tail rows zeroed

This is unrelated to the Scan dispatch (the dispatch logic is correct under the base MLX optimizer query, verified by test_scan_grad_over_sequence). It is captured here as a strict xfail under the full mode="MLX" (test_scan_grad_over_sequence_default_mode) and fixed by a follow-up that materializes negative-stride Subtensor results in the MLX backend.

Test plan

  • tests/link/mlx/test_scan.py — SIT-SOT/MIT-SOT (incl. views), sequences + non-seq RNN, multiple/combined recurring states, NIT-SOT-only, dtype preservation, n_steps=0, while-scan raises, gradients; plus the strict xfail above.
  • Full MLX suite green (pytest tests/link/mlx/), ruff/pre-commit clean.

🤖 Generated with assistance from Cursor.

Port the JAX scan dispatch to MLX. As MLX has no native general-scan
primitive, the inner fgraph is driven by a Python carry loop that
`mx.compile` unrolls. Because scalar values are not readable while MLX
traces, the (full-sized) recurring buffers are used to infer the number
of steps, falling back to a constant `n_steps` or the sequence length.

Covers seqs, MIT-SOT, SIT-SOT, NIT-SOT, untraced SIT-SOT, MIT-MOT (scan
gradients) and non-sequences, mirroring the JAX semantics of recreating
the trace and prepending/truncating to the buffer size.

Gradients over sequences reverse the trace and currently trip a separate
MLX bug (an elementwise op fed by a negative-stride array is miscompiled
under `mx.compile`); this is captured as a strict xfail under the full
`mode="MLX"` and is addressed by a follow-up.

Co-authored-by: Cursor <cursoragent@cursor.com>
Clarify in the dispatch that the Python carry loop is unrolled into the
graph at trace time (not by `mx.compile`) and that this is a workaround
until MLX exposes a native scan/while primitive (ml-explore/mlx#1441):
it needs a static step count and the graph grows as O(n_steps * inner_ops),
where a real primitive would keep it O(inner_ops).

Co-authored-by: Cursor <cursoragent@cursor.com>
@cetagostini cetagostini self-assigned this Jun 26, 2026
@cetagostini cetagostini marked this pull request as draft June 26, 2026 10:35
@cetagostini

Copy link
Copy Markdown
Contributor Author

@ricardoV94 what do you think about this?

@cetagostini

Copy link
Copy Markdown
Contributor Author

If this get merge, I'll redo this: ml-explore/mlx#3772

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant