Rewrite ShapeFeature to not hold live variables #2056
Draft
ricardoV94 wants to merge 4 commits into
`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join` and batched-`Blockwise` as protected client ops, but not `Subtensor`. `local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into `alloc(val[...], *new_shape)` — preserving the Alloc structure that downstream rewrites like `local_blockwise_alloc_inputs` depend on. Folding the Alloc here short-circuited that lift and produced broadcast-equivalent `Constant` matrices whose batch dim was no longer type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the surrounding `Blockwise(Reshape)`. Surfaced by the lazy-kernel `ShapeFeature` (which resolves `Subtensor(Shape(out), const)` to a scalar `Constant` earlier and makes more upstream Allocs constant-foldable), but the fix belongs here — the protection was too narrow.
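For illustration, a minimal sketch of the pattern the widened protection now covers (variable names are mine, not the rewrite's):

```python
import pytensor.tensor as pt

val = pt.scalar("val")
rows, cols = pt.iscalars("rows", "cols")

x = pt.alloc(val, rows, cols)  # Alloc with symbolic shape
y = x[0]                       # alloc(val, *shape)[idx]
# local_subtensor_of_alloc lifts the index inside the Alloc:
#   y -> alloc(val, cols)
# so Alloc-aware rewrites downstream still see an Alloc node.
# Constant-folding the Alloc first (once its inputs become constant)
# destroys that handle, which is what protecting Subtensor prevents.
```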
Breaking API change: the `fgraph` argument was unused by every in-tree `infer_shape` implementation. Removing it makes `infer_shape` a pure function of `(node, input_shapes)`, simpler to call from outside an fgraph context (e.g. ShapeFeature's lazy kernel build) and tighter as a contract. External Ops with custom `infer_shape(self, fgraph, node, input_shapes)` must drop the `fgraph` parameter.
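For external code the migration is mechanical; a sketch with a hypothetical Op:

```python
from pytensor.graph.op import Op


class MyOp(Op):  # hypothetical external Op, for illustration only
    # Before: def infer_shape(self, fgraph, node, input_shapes): ...
    def infer_shape(self, node, input_shapes):
        # e.g. the single output's shape mirrors the first input's shape
        return [input_shapes[0]]
```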
Add `break_aliasing_cycles` to `pytensor.graph.replace`. When an inplace Op overwrites input `x` and a single Apply ends up reading both `x` and a transitive dependent of the destroyer's output, no valid schedule exists. The helper re-routes such inputs through `deep_copy_op` to lift the conflict. Expose it via a `ShapeFeature.get_shape_no_cycle` convenience method, and use it from `introduce_explicit_core_shape_rv` and `introduce_explicit_core_shape_blockwise`, where lazy shape materialization can otherwise produce that pattern.
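A hedged sketch of the conflict and the re-routing (the stand-in graph below is illustrative; a real destructive Op only appears after inplace rewrites):

```python
import pytensor.tensor as pt
from pytensor.compile.ops import deep_copy_op

x = pt.vector("x")
y = x + 1          # imagine this Op destroys x in place after rewrites
g = pt.exp(y)      # transitive dependent of the destroyer's output
# If one Apply reads both x and g, no order works once y destroys x:
# run it before y and g is unavailable; after, and x is clobbered.
z = pt.add(x, g)   # the problematic read
# break_aliasing_cycles re-routes the conflicting input through a copy:
z_fixed = pt.add(deep_copy_op(x), g)
```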
Rewrite time on the asv experiment is down: https://ricardov94.github.io/pymc-model-catalogue/experiments.html#base=shape_feature_pr2056_base&compare=shape_feature_pr2056

Closes pymc-devs/pymc-extras#673
ShapeFeature reintroducing variables we have lowered/rewritten away is no-bueno
Details
Replace the eager per-variable dicts (`shape_of`, `shape_of_reverse_index`, `scheduled`) with a lazy `FrozenFunctionGraph`-based shape-kernel cache. For each Apply, a kernel built from dummy clones of `node.inputs` is stored in `self._cache[node]` and materialized against today's live inputs on demand via a custom frozen-graph walker (`graph_replace` would mutate globally-interned `FrozenApply` inputs). The kernel holds only `NominalVariable`s and `Constant`s, so no live variable can leak between tests or across rewrites, eliminating the stale-XRV class of bugs by construction.
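The core idea, as a minimal sketch (the builder name is hypothetical; the real kernel construction, the `FrozenFunctionGraph` walker, and the fallbacks for Ops without `infer_shape` or non-tensor inputs are omitted):

```python
from pytensor.graph.basic import NominalVariable


def build_shape_kernel(node):
    """Infer output shapes against nominal stand-ins for node.inputs,
    so the cached shape graphs reference no live variable."""
    dummies = [NominalVariable(i, inp.type) for i, inp in enumerate(node.inputs)]
    dummy_node = node.op.make_node(*dummies)
    # tuples of symbolic scalar dims, one per (tensor) input
    input_shapes = [
        tuple(d.shape[k] for k in range(d.type.ndim)) for d in dummies
    ]
    # new contract: infer_shape(node, input_shapes) -- no fgraph argument
    return node.op.infer_shape(dummy_node, input_shapes)
```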
The back-compat surface (`_LazyShapeTuple`, `_ShapeOfProxy`, `update_shape`, `shape_ir`, `init_r`) is retained and marked as temporary. A regression test for the stale-XRV scenario replaces the prior xfail.
`shape_of_variables` switches to `builders.infer_shape`, so it returns to taking scalar-dim inputs instead of allocating per-input arrays.
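Its call pattern is unchanged; a usage sketch, assuming the existing `pytensor.tensor.utils.shape_of_variables` entry point:

```python
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from pytensor.tensor.utils import shape_of_variables

x = pt.matrix("x")
y = x[512:]
fgraph = FunctionGraph([x], [y], clone=False)
shapes = shape_of_variables(fgraph, {x: (1024, 1024)})
# shapes[y] is expected to be roughly (array(512), array(1024))
```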
`local_track_shape_i` no longer depends on the deleted `scheduled` dict; it rewrites `Shape_i(v, i)` to `get_shape(v, i)` whenever the kernel produces something other than the trivial fallback.
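The observable effect, sketched (the exact printed graph depends on the active rewrite set):

```python
import pytensor
import pytensor.tensor as pt

rows = pt.iscalar("rows")
x = pt.alloc(0.0, rows, 3)
f = pytensor.function([rows], x.shape[0])  # shape read off an Alloc
pytensor.dprint(f)
# After shape tracking, the compiled graph should return `rows` directly
# instead of materializing the Alloc and asking for its shape.
```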
`on_change_input` carries `r`'s inferred shape onto `new_r` as an override when `new_r`'s Op has no `infer_shape`, preserving the legacy behavior where a well-inferred shape survives a replacement with an opaque op.
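In outline (a simplified sketch; `_shape_overrides` and `_inferred_shape` are hypothetical names standing in for the PR's internals):

```python
class ShapeFeatureSketch:  # illustrative stand-in for the real ShapeFeature
    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
        op = new_r.owner.op if new_r.owner is not None else None
        if op is None or not hasattr(op, "infer_shape"):
            # new_r's Op is opaque to shape inference: pin the shape we
            # already inferred for r onto new_r so it survives the swap
            self._shape_overrides[new_r] = self._inferred_shape(r)
```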
Benchmarks (cxx enabled):