
Rewrite ShapeFeature to not hold live variables #2056

Draft

ricardoV94 wants to merge 4 commits into pymc-devs:main from ricardoV94:shape_feature

Conversation

@ricardoV94
Member

ricardoV94 commented Apr 17, 2026

Closes pymc-devs/pymc-extras#673

ShapeFeature reintroducing variables we have already lowered/rewritten away is no bueno

Replace the eager per-variable dicts (shape_of, shape_of_reverse_index, scheduled) with a lazy FrozenFunctionGraph-based shape-kernel cache. For each Apply, a kernel built from dummy clones of node.inputs is stored in self._cache[node] and materialized against the current live inputs on demand via a custom frozen-graph walker (graph_replace would mutate the globally-interned FrozenApply inputs).

The kernel holds only NominalVariables and Constants, so no live
variable can leak between tests or across rewrites, eliminating by
construction the stale-XRV class of bugs.
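
A minimal sketch of the idea, with an invented name (`LazyShapeKernelCache`) and `graph_replace` standing in for the PR's custom frozen-graph walker, and assuming this PR's `infer_shape(node, input_shapes)` signature:

```python
from pytensor.graph.replace import graph_replace


class LazyShapeKernelCache:
    """Toy version of the cache: one shape kernel per Apply, built
    against dummy inputs and rebound to the live inputs at query time."""

    def __init__(self):
        # Apply -> (dummy inputs, per-output shape tuples)
        self._cache = {}

    def get_shape(self, node, out_idx, dim):
        if node not in self._cache:
            # Build the kernel from dummy clones so the cache never
            # captures a live variable from the fgraph under rewrite.
            dummies = [inp.type() for inp in node.inputs]
            dummy_shapes = [
                tuple(d.shape[i] for i in range(d.type.ndim)) for d in dummies
            ]
            # infer_shape(node, input_shapes): signature per this PR.
            kernel = node.op.infer_shape(node, dummy_shapes)
            self._cache[node] = (dummies, kernel)
        dummies, kernel = self._cache[node]
        # Materialize on demand: rebind the dummies to today's inputs.
        replacements = dict(zip(dummies, node.inputs))
        [live_dim] = graph_replace(
            [kernel[out_idx][dim]], replacements, strict=False
        )
        return live_dim
```

Since the kernel graphs reference only the dummies (NominalVariables in the real implementation) and Constants, discarding or rewriting the host fgraph cannot leave stale references in the cache.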

Back-compat surface (_LazyShapeTuple, _ShapeOfProxy, update_shape,
shape_ir, init_r) is retained and marked as temporary. A regression
test for the stale-XRV scenario replaces the prior xfail.

shape_of_variables switches to builders.infer_shape, so it goes back to
taking scalar dims as inputs instead of allocating a full array per input.
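
For reference, a usage sketch of `shape_of_variables` (API as in current PyTensor; the exact formatting of the returned dims may change with this PR):

```python
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.tensor.utils import shape_of_variables

x = pt.matrix("x")
y = pt.dot(x.T, x)
# clone=False so the dict below can be keyed by the original variables
fg = FunctionGraph([x], [y], clone=False)
# Maps every variable in fg to its numeric shape, given the input shapes.
shapes = shape_of_variables(fg, {x: (5, 3)})
print(shapes[y])  # -> (3, 3)
```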

local_track_shape_i no longer depends on the deleted scheduled dict;
it rewrites Shape_i(v, i) to get_shape(v, i) whenever the kernel
produces something other than the trivial fallback.
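
Roughly what the rewrite looks like after this change (a simplified sketch, not the in-tree code; `get_shape` lazily materializes the cached kernel per this PR):

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.shape import Shape_i


@node_rewriter([Shape_i])
def local_track_shape_i(fgraph, node):
    shape_feature = getattr(fgraph, "shape_feature", None)
    if shape_feature is None:
        return None
    [var] = node.inputs
    # Lazily materialize the cached kernel for this dimension.
    dim = shape_feature.get_shape(var, node.op.i)
    # The trivial fallback is Shape_i(var, i) itself; only rewrite when
    # the kernel produced something better.
    if dim.owner is not None and dim.owner.op == node.op and dim.owner.inputs == [var]:
        return None
    return [dim]
```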

on_change_input carries r's inferred shape onto new_r as an override
when new_r's Op has no infer_shape, preserving the legacy behavior
where a well-inferred shape survives through a replacement with an
opaque op.
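
A sketch of that override logic, assuming a hypothetical `_overrides` store (the PR's actual bookkeeping may differ; `shape_tuple` is the existing ShapeFeature helper):

```python
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
    # If new_r's Op cannot infer shapes itself, pin r's already-inferred
    # shape onto new_r so the replacement does not degrade shape info.
    # `_overrides` is a hypothetical store, not the PR's actual attribute.
    if new_r.owner is not None and not hasattr(new_r.owner.op, "infer_shape"):
        self._overrides[new_r] = self.shape_tuple(r)
```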

Benchmarks (cxx enabled):

  • radon_repeat       0.78s -> 0.55s (-30%)
  • radon_variants (8) 7.9s  -> 7.2s  (-9%)
  • fusion_large       0.22s -> 0.22s (noise)
  • fusion_deep        13ms  -> 13ms  (noise)

ricardoV94 added 4 commits May 1, 2026 18:00
`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join`
and batched-`Blockwise` as protected client ops, but not `Subtensor`.
`local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into
`alloc(val[...], *new_shape)` — preserving the Alloc structure that
downstream rewrites like `local_blockwise_alloc_inputs` depend on.
Folding the Alloc here short-circuited that lift and produced
broadcast-equivalent `Constant` matrices whose batch dim was no longer
type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the
surrounding `Blockwise(Reshape)`.

Surfaced by the lazy-kernel `ShapeFeature` (which resolves
`Subtensor(Shape(out), const)` to a scalar `Constant` earlier and
makes more upstream Allocs constant-foldable), but the fix belongs
here — the protection was too narrow.
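
The shape of the fix, as a simplified sketch (the in-tree method also protects inplace IncSubtensor-family clients, elided here):

```python
from pytensor.tensor.basic import Alloc, Join
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.subtensor import Subtensor


def do_constant_folding(self, fgraph, node):
    for client, _ in fgraph.clients[node.outputs[0]]:
        if client == "output":
            # Folded output constants get deep-copied on every call.
            return False
        # Clients through which the Alloc can still be lifted; folding
        # it into a Constant here would block those rewrites.
        # Subtensor is the newly added entry.
        if isinstance(client.op, (Elemwise, DimShuffle, Alloc, Join, Subtensor)):
            return False
        if isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
            return False
    return True
```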
Breaking API change: the `fgraph` argument was unused by every
in-tree `infer_shape` implementation. Removing it makes
`infer_shape` a pure function of `(node, input_shapes)`: simpler
to call from outside an fgraph context (e.g. ShapeFeature's lazy
kernel build) and a tighter contract.

External Ops with custom `infer_shape(self, fgraph, node, input_shapes)`
must drop the `fgraph` parameter.
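
A minimal migration example (the Op body is illustrative):

```python
from pytensor.graph.op import Op


class MyOp(Op):
    # Before: def infer_shape(self, fgraph, node, input_shapes): ...
    # After, a pure function of the node and its input shapes:
    def infer_shape(self, node, input_shapes):
        # Illustrative: output shape matches the first input's shape.
        return [input_shapes[0]]
```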
Add `break_aliasing_cycles` to `pytensor.graph.replace`. When an inplace
Op overwrites input `x` and a single Apply ends up reading both `x` and
a transitive dependent of the destroyer's output, no valid schedule
exists. The helper re-routes such inputs through `deep_copy_op` to lift
the conflict.

Expose it via a `ShapeFeature.get_shape_no_cycle` convenience method,
and use it from `introduce_explicit_core_shape_rv` and
`introduce_explicit_core_shape_blockwise`, where lazy shape
materialization can otherwise produce that pattern.
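
A hedged usage sketch; the call shape is assumed from this description, since the helper is new in this PR:

```python
import pytensor.tensor as pt
from pytensor.graph.replace import break_aliasing_cycles  # new in this PR

x = pt.vector("x")
outputs = [x + 1]  # placeholder graph; a real conflict needs an inplace destroyer of x
# Re-routes any input that reads both a destroyed variable and a
# dependent of the destroyer's output through deep_copy_op; a graph
# with no such cycle should come back unchanged.
safe_outputs = break_aliasing_cycles(outputs)
```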


Development

Successfully merging this pull request may close these issues.

Prior.create_variable(xdist=True) fails compile_logp for centered priors with nested Prior parameters that have dims
