JAX: fix Alloc failing under jax.jit when shape inputs are constants#2050
Open
jaj42 wants to merge 3 commits into pymc-devs:main
Conversation
jax_funcify_Alloc returned a closure that received shape dimensions as JAX arrays at runtime. When the compiled pytensor function was re-traced by an outer jax.jit or jax.value_and_grad call (as done by downstream libraries that extract f.vm.jit_fn and differentiate through it), those JAX arrays were promoted to JitTracers. jnp.broadcast_to requires a concrete tuple of Python ints for its shape argument and raised:

    TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer(int64[]),).

Fix: at compile time in jax_funcify_Alloc, attempt to resolve each shape input to a concrete Python int using get_scalar_constant_value. When all shape dimensions are constant (the common case), bake the resulting tuple directly into the closure so that no JAX array is ever passed as a shape argument at runtime. Dynamic shape dimensions fall back to the previous behaviour.

This mirrors the existing fix already applied to jax_funcify_ARange in the same file for the same class of problem.
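The compile-time/runtime distinction described above can be sketched in plain JAX. This is a hedged illustration of the pattern, not the actual pytensor code; alloc_dynamic, alloc_static, and concrete_shape are hypothetical names, and the hard-coded tuple stands in for what get_scalar_constant_value would resolve:

```python
import jax
import jax.numpy as jnp

# Old behaviour: shape dims arrive as runtime arrays. Under an outer jit
# they become tracers, and jnp.broadcast_to rejects a traced shape.
def alloc_dynamic(x, *shape):
    return jnp.broadcast_to(x, shape)

# Fixed behaviour when every dim is constant: the concrete tuple is baked
# into the closure at compile time, so no tracer ever reaches broadcast_to.
concrete_shape = (2, 3)  # stand-in for the constants resolved at compile time

def alloc_static(x):
    return jnp.broadcast_to(x, concrete_shape)

print(jax.jit(alloc_static)(jnp.zeros(())).shape)  # (2, 3)
```

Tracing alloc_dynamic with the shape passed as ordinary arguments, e.g. jax.jit(alloc_dynamic)(jnp.zeros(()), 2, 3), raises the TypeError quoted above, because 2 and 3 are traced.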
ricardoV94
reviewed
Apr 13, 2026
    try:
        static_shapes.append(int(get_scalar_constant_value(shape_input)))
    except NotScalarConstantError:
        static_shapes.append(None)
Member
As soon as this fails you can stop iterating and go to the fallback branch? Then you don't need the if all(s is not None) check again either
Author
Yes indeed, this can be simplified.
What do you think about this:
    static_shapes = []
    for shape_input in node.inputs[1:]:
        try:
            static_shapes.append(int(get_scalar_constant_value(shape_input)))
        except NotScalarConstantError:
            concrete_shape = None
            break
    else:
        concrete_shape = tuple(static_shapes)

    def alloc(x, *shape):
        res = jnp.broadcast_to(x, concrete_shape if concrete_shape is not None else shape)
        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
        return res

    return alloc
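The suggestion leans on Python's for/else: the else block runs only when the loop finishes without hitting break. A standalone sketch of that control flow (plain Python; resolve_concrete_shape is an illustrative name, and a plain int stands in for a constant shape input):

```python
def resolve_concrete_shape(shape_inputs):
    """Return a tuple when every dimension is constant, else None."""
    static_shapes = []
    for shape_input in shape_inputs:
        if not isinstance(shape_input, int):  # stand-in for NotScalarConstantError
            concrete_shape = None
            break
        static_shapes.append(shape_input)
    else:  # no break: every dimension was constant
        concrete_shape = tuple(static_shapes)
    return concrete_shape

print(resolve_concrete_shape([2, 3]))    # (2, 3)
print(resolve_concrete_shape([2, "n"]))  # None -> fall back to runtime shapes
```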
Author
It works for me. I updated the diff
…t, only define the closure once.
Member
pre-commit failing
Author
Sorry, I fixed the linting. There is still a mypy failure, but it is caused by code in another module, so I'm not sure I should fix it here.
Member
mypy, not your fault so don't worry
I came across this issue while using DADVI with @wrap_jax in PyMC.
Please note this was analyzed and patched using AI.
Minimal example showing the problem:
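The example itself did not survive in this capture. A pure-JAX sketch of the same failure mode (an assumption standing in for re-tracing the compiled pytensor function; alloc_like is a hypothetical name):

```python
import jax
import jax.numpy as jnp

def alloc_like(x, n):
    # Under jax.jit, n is a tracer, but broadcast_to demands concrete ints.
    return jnp.broadcast_to(x, (n,))

try:
    jax.jit(alloc_like)(jnp.zeros(()), 3)
except TypeError as err:
    print(type(err).__name__)  # TypeError
```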
Related Issue
This mirrors the existing fix already applied to jax_funcify_ARange in
the same file for the same class of problem.