Support dynamic shape-derived bounds in MLX ARange #2260
Open
cetagostini wants to merge 3 commits into
The MLX `arange` dispatcher rejected any non-constant bound at funcify time, so `pt.arange(x.shape[0])` failed for tensors with a dynamic (None) static shape, even though shape-derived bounds are concrete under `mx.compile`. This broke advanced-indexing/gather patterns such as `logp[pt.arange(targets.shape[0]), targets]`. Bake constant-foldable bounds and resolve the rest at runtime by converting the MLX scalar to a Python scalar, mirroring the JAX dispatch. Genuinely data-dependent bounds raise a clear NotImplementedError. Co-authored-by: Cursor <cursoragent@cursor.com>
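The "bake constants, resolve the rest at runtime" idea can be sketched without MLX or PyTensor. This is a minimal pure-Python illustration, not the PR's code: `FakeScalar` and `funcify_arange` are hypothetical names, and `FakeScalar` only stands in for a backend 0-d array that exposes `.item()`.

```python
from typing import Callable, List, Optional, Sequence, Union

Number = Union[int, float]


class FakeScalar:
    """Hypothetical stand-in for a backend 0-d array exposing `.item()`."""

    def __init__(self, value: Number) -> None:
        self._value = value

    def item(self) -> Number:
        return self._value


def funcify_arange(const_bounds: Sequence[Optional[Number]]) -> Callable[..., List[int]]:
    """Build an arange callable for (start, stop, step).

    `const_bounds` holds the constant-folded value of each bound, or None
    when that bound is only known at runtime.
    """

    def arange(*runtime_bounds) -> List[int]:
        resolved = []
        for const, runtime in zip(const_bounds, runtime_bounds):
            if const is not None:
                resolved.append(const)  # baked in when the callable was built
            elif hasattr(runtime, "item"):
                resolved.append(runtime.item())  # backend scalar -> Python scalar
            else:
                resolved.append(runtime)  # already a Python scalar
        start, stop, step = resolved
        return list(range(start, stop, step))

    return arange


# start=0 and step=1 are baked; only stop is resolved at call time.
arange_fn = funcify_arange([0, None, 1])
result = arange_fn(None, FakeScalar(4), None)
```

Here `result` is `[0, 1, 2, 3]`. In the real dispatcher, the runtime conversion only succeeds when the bound is concrete at trace time, which is why data-dependent bounds need separate handling.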
ricardoV94
reviewed
Jun 26, 2026
Address review: replace the fragile runtime match on MLX's eval error string with a static, funcify-time check that walks the bound's graph treating Shape ops as barriers. A bound is resolvable iff it derives only from input shapes and constants; genuinely data-dependent bounds now raise NotImplementedError up front, consistently across MLX modes. This is more general than the JAX dispatch (which only recognizes a bare Shape_i). Helpers moved above the dispatcher that uses them. Co-authored-by: Cursor <cursoragent@cursor.com>
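A static check of this kind can be sketched as a graph walk that stops at shape nodes. The snippet below is an illustration under stated assumptions, not the PR's implementation: `Node` and `is_shape_resolvable` are hypothetical, and ops are identified by plain strings rather than PyTensor `Op` instances.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical graph node; `op` is None for root inputs and constants."""

    name: str
    op: Optional[str] = None
    inputs: List["Node"] = field(default_factory=list)
    is_constant: bool = False


def is_shape_resolvable(bound: Node) -> bool:
    """Return True iff `bound` derives only from input shapes and constants."""
    stack = [bound]
    while stack:
        node = stack.pop()
        if node.is_constant:
            continue  # constants are always fine
        if node.op == "Shape":
            continue  # barrier: the shape is known, so do not descend into its input
        if node.op is None:
            return False  # reached the data of a non-constant input
        stack.extend(node.inputs)
    return True


x = Node("x")
one = Node("1", is_constant=True)

# Squeeze(Shape(x)) + 1: derives only from a shape and a constant.
shape_bound = Node(
    "add", op="Add",
    inputs=[Node("squeeze", op="Squeeze", inputs=[Node("shape", op="Shape", inputs=[x])]), one],
)

# Sum(x): depends on the values stored in x.
data_bound = Node("sum", op="Sum", inputs=[x])

shape_ok = is_shape_resolvable(shape_bound)
data_ok = is_shape_resolvable(data_bound)
```

With these inputs `shape_ok` is `True` and `data_ok` is `False`, so a dispatcher built this way could raise `NotImplementedError` for `data_bound` before any function is compiled.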
@ricardoV94 @jessegrabowski should be ready!
…shape Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Problem
`pt.arange` on the MLX backend works for compile-time-constant bounds (including static-shape lengths that constant-fold), but fails for a genuinely dynamic (runtime) length.

This surfaces indirectly: a vectorized gather like `logp[pt.arange(targets.shape[0]), targets]` lowers to advanced indexing that internally builds `arange(targets.shape[0])`, which is symbolic when the token tensors have dynamic shapes.

Root cause

`mlx_funcify_ARange` called `get_scalar_constant_value` on each bound at funcify time and raised `NotImplementedError` if any wasn't constant-foldable. For a dynamic-shape input the bound lowers to `Squeeze(Shape(x))` (not a graph constant), so it was rejected, even though under `mx.compile` the shape value is a concrete Python `int` at trace time (MLX retraces per shape). `mx.arange` itself only accepts Python `int`/`float`, not arrays; that part of the limitation is real.

- Constant bound (`arange(5)`): supported.
- Static shape (`arange(xs.shape[0])`, `xs` shape `(5, 3)`): supported, since the length constant-folds.
- Dynamic shape (`arange(x.shape[0])`, `x` shape `(None,)`): previously `NotImplementedError`; resolved at runtime with this fix.
- Data-dependent (`arange(x.sum())`): `NotImplementedError`.

Fix

Mirror the JAX dispatch (and the sibling `Eye`/`Alloc` idiom in the same file): bake constant-foldable bounds, and resolve the rest at runtime by converting the MLX scalar to a Python scalar via `.item()`. Shape-derived bounds are concrete under `mx.compile`, so the conversion returns the value without a real eval. Genuinely data-dependent bounds trigger MLX's `[eval] Attempting to eval an array during function transformations` error, which is caught and re-raised as a clear `NotImplementedError`, reusing the same error-string idiom already used for `Alloc` in this file.

Tests

RED → GREEN verified (new tests fail on unpatched code, pass with the fix):

- `test_arange_dynamic_shape`: every position (start/stop/step) shape-derived, an offset expression, and an empty (start > stop) result.
- `test_arange_dynamic_advanced_index`: the motivating gather pattern `logp[arange(targets.shape[0]), targets]`.
- `test_arange_data_dependent_raises`: data-dependent length fails loudly under `mx.compile`.

Full MLX suite: `163 passed, 4 xfailed` (pre-existing, unrelated). `ruff` clean.

🤖 Implemented with AI assistance (Cursor).
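For readers unfamiliar with the gather pattern exercised by the tests, its reference semantics can be written in plain Python. This is a sketch for orientation only, not the PR's test code: `gather_per_row` and the sample values are made up for illustration.

```python
from typing import List, Sequence


def gather_per_row(logp: Sequence[Sequence[float]], targets: Sequence[int]) -> List[float]:
    """Pure-Python equivalent of logp[arange(len(targets)), targets]."""
    rows = range(len(targets))  # plays the role of arange(targets.shape[0])
    return [logp[i][t] for i, t in zip(rows, targets)]


# Two rows of per-class log-probabilities and one target class per row.
logp = [[-0.1, -2.3, -4.0], [-1.2, -0.4, -3.3]]
targets = [2, 1]
picked = gather_per_row(logp, targets)

# The empty case covered by the tests: start > stop with a positive step.
empty = list(range(5, 2, 1))
```

`picked` is `[-4.0, -0.4]` (one entry per row) and `empty` is `[]`. The row index is the part that needs `arange` over a runtime length when `targets` has a dynamic shape.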