JAX: fix Alloc failing under jax.jit when shape inputs are constants #2050

Open
jaj42 wants to merge 3 commits into pymc-devs:main from jaj42:jaj42

Conversation

@jaj42 jaj42 commented Apr 13, 2026

jax_funcify_Alloc returns a closure that receives shape dimensions as
JAX arrays at runtime. When the compiled pytensor function is
re-traced by an outer jax.jit or jax.value_and_grad call, as done by
downstream libraries that extract f.vm.jit_fn and differentiate through
it, those JAX arrays are promoted to JitTracers. jnp.broadcast_to
requires a concrete tuple of Python ints for its shape argument and
raised:

TypeError: Shapes must be 1D sequences of concrete values of integer
type, got (JitTracer(int64[]),).

Fix: at compile time in jax_funcify_Alloc, attempt to resolve each
shape input to a concrete Python int using get_scalar_constant_value.
When all shape dimensions are constant (the common case), bake the
resulting tuple directly into the closure so that no JAX array is ever
passed as a shape argument at runtime. Dynamic shape dimensions fall
back to the previous behaviour.
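
A minimal sketch of what the patched dispatcher looks like (illustrative only: the
exact change is the version refined in the review thread below, and the import
paths and registration shown here are assumptions rather than copied code):

import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.basic import Alloc, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError

@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op, node, **kwargs):
    # Try to resolve every shape input to a Python int at compile time.
    static_shapes = []
    for shape_input in node.inputs[1:]:
        try:
            static_shapes.append(int(get_scalar_constant_value(shape_input)))
        except NotScalarConstantError:
            concrete_shape = None  # at least one dimension is dynamic
            break
    else:
        concrete_shape = tuple(static_shapes)

    def alloc(x, *shape):
        # Use the baked-in constant shape when available; otherwise fall back
        # to the runtime shape arguments (previous behaviour).
        res = jnp.broadcast_to(
            x, concrete_shape if concrete_shape is not None else shape
        )
        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
        return res

    return alloc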

I came across this issue while using DADVI with @wrap_jax in PyMC.

Please note this was analyzed and patched using AI.

Minimal example showing the problem:

import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt

jax.config.update("jax_enable_x64", True)

# Direct JAX reproduction of the core issue.
# (Reproduces the mechanism without going through pytensor.)
def unpatched_alloc(x, *shape):
    """Mimics the runtime closure returned by unpatched jax_funcify_Alloc."""
    return jnp.broadcast_to(x, shape)

@jax.jit
def objective_bug1(x):
    n = jnp.array(5, dtype=jnp.int64)  # JAX constant → JitTracer under jit
    return unpatched_alloc(x, n).sum()

print("Direct JAX demonstration:")
try:
    val, grad = jax.value_and_grad(objective_bug1)(np.float64(2.0))
    print(f"Result: val={val}, grad={grad}   (patch applied, no error)")
except TypeError as e:
    print(f"BUG — TypeError: {e}")

# pytensor round-trip: compile a function with Alloc and differentiate it.
# On unpatched pytensor this raises the same TypeError via f.vm.jit_fn.
# On patched pytensor, jax_funcify_Alloc bakes in a concrete Python int
# at compile time (using get_scalar_constant_value), so the closure
# never contains JAX arrays.
x_pt = pt.scalar("x")
loss = pt.sum(pt.alloc(x_pt, 5))  # constant shape (5,)
f_jax = pytensor.function([x_pt], loss, mode="JAX")
jit_fn = f_jax.vm.jit_fn

print("\nPytensor round-trip (wraps f.vm.jit_fn in jax.value_and_grad):")
try:
    val, grad = jax.jit(lambda v: jax.value_and_grad(lambda v: jit_fn(v)[0])(v))(
        np.float64(2.0)
    )
    print(f"Result: val={val:.1f}, grad={grad:.1f}   (patch applied, no error)")
except TypeError as e:
    print(f"BUG — TypeError: {e}")

Related Issue

This mirrors the existing fix already applied to jax_funcify_ARange in
the same file for the same class of problem.
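
The shared idea in both dispatchers is to fold symbolic integer inputs to plain
Python ints at dispatch time, before the JAX closure is built. A hypothetical
helper illustrating that pattern (names are made up for illustration; this is
not the repository's ARange code):

from pytensor.tensor.basic import get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError

def try_fold_static_int(var):
    # Return a Python int if `var` is a constant scalar, else None.
    try:
        return int(get_scalar_constant_value(var))
    except NotScalarConstantError:
        return None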

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@jaj42 jaj42 marked this pull request as ready for review April 13, 2026 12:50
try:
    static_shapes.append(int(get_scalar_constant_value(shape_input)))
except NotScalarConstantError:
    static_shapes.append(None)
Member

As soon as this fails you can stop iterating and go to the fallback branch? Then you don't need the if all(s is not None) check again either

Author

Yes indeed, this can be simplified.

What do you think about this:

static_shapes = []
for shape_input in node.inputs[1:]:
    try:
        static_shapes.append(int(get_scalar_constant_value(shape_input)))
    except NotScalarConstantError:
        concrete_shape = None
        break
else:
    concrete_shape = tuple(static_shapes)

def alloc(x, *shape):
    res = jnp.broadcast_to(x, concrete_shape if concrete_shape is not None else shape)
    Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
    return res

return alloc

Member

Give it a try

Author

It works for me. I updated the diff.

@ricardoV94
Member

pre-commit failing


jaj42 commented Apr 14, 2026

Sorry, I fixed the linting. There is still a mypy failure, but that is caused by code in another module, so I'm not sure I should fix that.

@ricardoV94
Member

mypy, not your fault so don't worry
