Support dynamic shape-derived bounds in MLX ARange #2260

Open
cetagostini wants to merge 3 commits into pymc-devs:main from cetagostini:mlx-arange-dynamic-shape
@cetagostini

Contributor

Problem

pt.arange on the MLX backend works for compile-time-constant bounds (including static-shape lengths that constant-fold), but fails for a genuinely dynamic (runtime) length:

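import numpy as np
import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

x = pt.vector("x")  # static shape (None,): the length is only known at runtime
# Default mode: compiles and runs.
pytensor.function([x], pt.arange(x.shape[0]), mode=None)(np.zeros(5, "float32"))  # OK
# MLX mode: fails while compiling, before the function is ever called.
pytensor.function([x], pt.arange(x.shape[0]), mode="MLX")  # NotImplementedError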

This surfaces indirectly: a vectorized gather like logp[pt.arange(targets.shape[0]), targets] lowers to advanced indexing that internally builds arange(targets.shape[0]), which is symbolic when the token tensors have dynamic shapes.
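
To make the gather concrete, here is a minimal sketch in plain Python lists (no PyTensor or MLX needed) of what `logp[arange(n), targets]` selects: one entry per row, at the column named by `targets`. The helper name `gather_rows` and the sample values are illustrative assumptions, not taken from this PR.

```python
def gather_rows(logp, targets):
    """Pick logp[i][targets[i]] for every row i.

    Equivalent in spirit to logp[arange(len(targets)), targets] with
    array-style advanced indexing. The row index is the arange that has
    to be built from the (possibly dynamic) length of `targets`.
    """
    n = len(targets)   # plays the role of targets.shape[0]
    rows = range(n)    # plays the role of arange(n)
    return [logp[i][t] for i, t in zip(rows, targets)]


# Hypothetical 3x4 table of log-probabilities, one target column per row.
logp = [
    [-0.1, -2.0, -3.0, -4.0],
    [-1.5, -0.2, -2.5, -3.5],
    [-2.2, -1.1, -0.3, -4.4],
]
targets = [0, 1, 2]
picked = gather_rows(logp, targets)
```

The row index has no fixed length when the number of targets is not known ahead of time, which is why the `arange` bound ends up symbolic in the compiled graph.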

Root cause

mlx_funcify_ARange called get_scalar_constant_value on each bound at funcify time and raised NotImplementedError if any wasn't constant-foldable. For a dynamic-shape input the bound lowers to Squeeze(Shape(x)) (not a graph constant), so it was rejected — even though under mx.compile the shape value is a concrete Python int at trace time (MLX retraces per shape).

mx.arange itself only accepts Python int/float, not arrays — that part of the limitation is real.

| bound | reference (NumPy/JAX) | MLX before | MLX after |
| --- | --- | --- | --- |
| constant (arange(5)) | ✅ | ✅ | ✅ |
| static shape (arange(xs.shape[0]), xs shape (5,3)) | ✅ | ✅ | ✅ |
| dynamic shape (arange(x.shape[0]), x shape (None,)) | ✅ | ❌ NotImplementedError | ✅ |
| data-dependent (arange(x.sum())) | ✅ (eager) | ❌ | ❌ clear NotImplementedError |

Fix

Mirror the JAX dispatch (and the sibling Eye/Alloc idiom in the same file): bake constant-foldable bounds, and resolve the rest at runtime by converting the MLX scalar to a Python scalar via .item(). Shape-derived bounds are concrete under mx.compile, so the conversion returns the value without a real eval. Genuinely data-dependent bounds trigger MLX's [eval] Attempting to eval an array during function transformations error, which is caught and re-raised as a clear NotImplementedError — reusing the same error-string idiom already used for Alloc in this file.
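
The bake-or-resolve idiom described above can be sketched without MLX. This is a toy model under stated assumptions: `FakeScalar` stands in for a backend scalar that exposes `.item()`, `make_arange` is a hypothetical name, and Python's built-in `range` stands in for `mx.arange`. None of it is the PR's actual code.

```python
class FakeScalar:
    """Stand-in for a backend scalar array that exposes .item()."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


def make_arange(const_start=None, const_stop=None, const_step=None):
    """Build an arange function with constant bounds baked in.

    A bound passed here (not None) is treated as constant-foldable and is
    captured in the closure. A bound left as None is taken from the
    runtime argument and converted to a Python scalar with .item().
    """
    baked = (const_start, const_stop, const_step)

    def to_python(x):
        # Runtime values may be backend scalars; plain numbers pass through.
        return x.item() if hasattr(x, "item") else x

    def arange(start, stop, step):
        runtime = (start, stop, step)
        resolved = [
            b if b is not None else to_python(r)
            for b, r in zip(baked, runtime)
        ]
        return list(range(*resolved))

    return arange


# start and step are known when the function is built; stop arrives at runtime.
arange_fn = make_arange(const_start=0, const_step=1)
result = arange_fn(None, FakeScalar(5), None)
```

The point of the split is that only the bounds that cannot be folded pay for a runtime conversion. In the real backend that conversion is where a data-dependent value would fail under compilation.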

Tests

RED → GREEN verified (new tests fail on unpatched code, pass with the fix):

  • test_arange_dynamic_shape — every position (start/stop/step) shape-derived, an offset expression, and an empty (start > stop) result.
  • test_arange_dynamic_advanced_index — the motivating gather pattern logp[arange(targets.shape[0]), targets].
  • test_arange_data_dependent_raises — data-dependent length fails loudly under mx.compile.

Full MLX suite: 163 passed, 4 xfailed (pre-existing, unrelated). ruff clean.

🤖 Implemented with AI assistance (Cursor).

Made with Cursor

The MLX `arange` dispatcher rejected any non-constant bound at funcify
time, so `pt.arange(x.shape[0])` failed for tensors with a dynamic
(None) static shape, even though shape-derived bounds are concrete
under `mx.compile`. This broke advanced-indexing/gather patterns such
as `logp[pt.arange(targets.shape[0]), targets]`.

Bake constant-foldable bounds and resolve the rest at runtime by
converting the MLX scalar to a Python scalar, mirroring the JAX
dispatch. Genuinely data-dependent bounds raise a clear
NotImplementedError.

Co-authored-by: Cursor <cursoragent@cursor.com>
Comment thread pytensor/link/mlx/dispatch/tensor_basic.py Outdated
Address review: replace the fragile runtime match on MLX's eval error
string with a static, funcify-time check that walks the bound's graph
treating Shape ops as barriers. A bound is resolvable iff it derives
only from input shapes and constants; genuinely data-dependent bounds
now raise NotImplementedError up front, consistently across MLX modes.
This is more general than the JAX dispatch (which only recognizes a bare
Shape_i). Helpers moved above the dispatcher that uses them.

Co-authored-by: Cursor <cursoragent@cursor.com>
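
A minimal sketch of the static check this commit message describes, written against a toy graph rather than PyTensor's real classes. `Node`, its `op` strings, and `is_shape_derived` are hypothetical names chosen for illustration; the PR's helpers may differ in structure and naming.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Toy graph node: an op name plus the nodes it consumes."""

    op: str            # e.g. "constant", "input", "shape", "add"
    inputs: tuple = ()


def is_shape_derived(node):
    """Return True if `node` depends only on constants and input shapes.

    Shape ops act as barriers: whatever feeds a Shape is ignored, because
    only the shape (not the data) of that value is read. Reaching a raw
    graph input any other way means the bound is data-dependent.
    """
    stack = [node]
    seen = set()
    while stack:
        current = stack.pop()
        if id(current) in seen:
            continue
        seen.add(id(current))
        if current.op in ("constant", "shape"):
            continue       # resolvable leaf or barrier: do not look further
        if current.op == "input":
            return False   # reads runtime data directly
        stack.extend(current.inputs)
    return True


x = Node("input")
n = Node("squeeze", (Node("shape", (x,)),))    # like x.shape[0]
offset = Node("add", (n, Node("constant")))    # like x.shape[0] + 1
total = Node("sum", (x,))                      # like x.sum()
```

In this sketch `is_shape_derived(n)` and `is_shape_derived(offset)` hold, while `is_shape_derived(total)` does not. Because the check only inspects the graph, it can run before anything is compiled, which is what lets the error be raised up front.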
@cetagostini

Contributor Author

@ricardoV94 @jessegrabowski should be ready!

…shape

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>