Clean up ShapeFeature caches and remove tracks_shape #2104

Draft
ricardoV94 wants to merge 5 commits into pymc-devs:main from ricardoV94:shape_feature_cleanup

Conversation

@ricardoV94
Member

  • Drop fgraph parameter from Op.infer_shape signature
  • Don't constant-fold Alloc consumed by Subtensor
  • Rewrite ShapeFeature as lazy kernel-based feature
  • WIP
  • Clean up ShapeFeature caches and remove tracks_shape

Breaking API change: the `fgraph` argument was unused by every
in-tree `infer_shape` implementation. Removing it makes
`infer_shape` a pure function of `(node, input_shapes)`, which is
simpler to call outside an fgraph context (e.g. ShapeFeature's lazy
kernel build) and a tighter contract.

External Ops with custom `infer_shape(self, fgraph, node, input_shapes)`
must drop the `fgraph` parameter.
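A minimal sketch of what the migration looks like for an external Op, with a hypothetical compatibility shim (the class and shim names are illustrative, not PyTensor API):

```python
import inspect

class MyOp:
    """Stand-in for an external Op using the new-style signature."""
    def infer_shape(self, node, input_shapes):
        # Illustrative: the output shape equals the first input's shape.
        return [input_shapes[0]]

def call_infer_shape(op, node, input_shapes):
    """Hypothetical shim: detect old-style
    infer_shape(self, fgraph, node, input_shapes) and adapt the call."""
    params = list(inspect.signature(op.infer_shape).parameters)
    if params and params[0] == "fgraph":
        # Old-style external Op: fgraph was unused in-tree, so pass None.
        return op.infer_shape(None, node, input_shapes)
    return op.infer_shape(node, input_shapes)

shapes = call_infer_shape(MyOp(), node=None, input_shapes=[(3, 4)])
```

Detecting the old signature by parameter name is only a sketch of one possible transition aid; the PR itself simply removes the parameter.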

`Alloc.do_constant_folding` listed `Elemwise | DimShuffle | Alloc | Join`
and batched-`Blockwise` as protected client ops, but not `Subtensor`.
`local_subtensor_of_alloc` rewrites `alloc(val, *shape)[idx]` into
`alloc(val[...], *new_shape)` — preserving the Alloc structure that
downstream rewrites like `local_blockwise_alloc_inputs` depend on.
Folding the Alloc here short-circuited that lift and produced
broadcast-equivalent `Constant` matrices whose batch dim was no longer
type-broadcastable, so `local_blockwise_reshape` couldn't unwrap the
surrounding `Blockwise(Reshape)`.

Surfaced by the lazy-kernel `ShapeFeature` (which resolves
`Subtensor(Shape(out), const)` to a scalar `Constant` earlier and
makes more upstream Allocs constant-foldable), but the fix belongs
here — the protection was too narrow.
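The widened protection can be sketched in plain Python, with stand-in classes rather than PyTensor's actual Op types or the real `do_constant_folding` signature:

```python
# Illustrative stand-ins for the client Op classes named above.
class Elemwise: ...
class DimShuffle: ...
class Alloc: ...
class Join: ...
class Subtensor: ...

# Subtensor joins the set of clients that block constant folding,
# so local_subtensor_of_alloc still sees an Alloc to lift through.
PROTECTED_CLIENT_OPS = (Elemwise, DimShuffle, Alloc, Join, Subtensor)

def alloc_do_constant_folding(client_ops):
    """Sketch: fold the Alloc only if no client is a protected Op."""
    return not any(isinstance(op, PROTECTED_CLIENT_OPS) for op in client_ops)
```

The real method also special-cases batched `Blockwise` clients; this sketch only shows the class-based check that the PR widens.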

Build one `FrozenFunctionGraph` "kernel" per Apply, rooted in
`NominalVariable` clones of `node.inputs`. Each kernel is cached in
`self._cache` and materialized on demand against today's live inputs
via a custom walker (`graph_replace` would mutate the globally-interned
`FrozenApply`'s inputs). The kernel never holds live variables, so
stale references can't leak into shape expressions across rewrites.
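The cache-then-materialize shape of this design can be sketched with tuples standing in for frozen Applies and a toy positional placeholder in place of `NominalVariable` (all names here are illustrative, not the PR's actual classes):

```python
class Nominal:
    """Positional placeholder standing in for node.inputs[pos] in a kernel."""
    def __init__(self, pos):
        self.pos = pos

def materialize(expr, live_inputs):
    """Walk a frozen kernel, rebuilding it against today's live inputs
    instead of mutating the interned kernel in place."""
    if isinstance(expr, Nominal):
        return live_inputs[expr.pos]
    if isinstance(expr, tuple):  # ("op_name", *args) as a frozen-Apply stand-in
        return (expr[0], *(materialize(a, live_inputs) for a in expr[1:]))
    return expr  # constants pass through unchanged

_kernel_cache = {}

def get_kernel(node_key, build):
    """One kernel per Apply, built lazily and cached forever after."""
    if node_key not in _kernel_cache:
        _kernel_cache[node_key] = build()
    return _kernel_cache[node_key]

# Kernel for "dim 0 of the output is dim 0 of input 0":
kernel = get_kernel("apply_1", lambda: ("shape_i", Nominal(0), 0))
dim = materialize(kernel, ["x_live"])
```

Because the kernel only ever contains `Nominal` placeholders and constants, it can be reused across rewrites without holding references to variables that a rewrite may have replaced.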

`local_track_shape_i` rewrites `Shape_i(v, i)` with the kernel-inferred
expression directly. `on_change_input` installs `r`'s inferred shape as
an override on `new_r` when `new_r`'s Op has no `infer_shape`.
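Reusing the tuple stand-ins above, the `local_track_shape_i` step reduces to a lookup-and-substitute (again a sketch, not the rewrite's real interface):

```python
def rewrite_shape_i(expr, inferred):
    """Sketch: replace a ('shape_i', var, i) node with the kernel-inferred
    expression when one exists; otherwise leave the node untouched."""
    if isinstance(expr, tuple) and expr and expr[0] == "shape_i":
        _, var, i = expr
        return inferred.get((var, i), expr)
    return expr

inferred = {("v", 0): 3}  # dim 0 of v was inferred to be the constant 3
known = rewrite_shape_i(("shape_i", "v", 0), inferred)
unknown = rewrite_shape_i(("shape_i", "v", 1), inferred)
```

The point the sketch captures is that the rewrite substitutes the inferred expression eagerly rather than leaving a `Shape_i` node for a later pass to resolve.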

Also includes:
- break_aliasing_cycles (graph/replace.py) for sub-graphs where a
  single Apply reads an inplace-destroyed input and has another input
  that depends on the destroyer's output
- Hash-cons materialized get_shape(v, i) results
- Canonicalize Subtensor(Shape(x), const) / Shape(x) patterns into
  Shape_i / MakeVector post-materialization
- Drop set_shape; route overrides through borrowed kernels
- Drop fallback_out role (redundant with layout-None branch)
- Updated builders.infer_shape: leaf-rebinding approach
- Rename _cache -> _shape_kernel_cache, _materialized -> _materialized_dim_cache
- Key _materialized_dim_cache by node instead of (id(v), i) for cheap invalidation
- Remove tracks_shape (broken with lazy design), replace with
  _inferred_shape_or_fallback helper in scan rewriting
- Simplify on_prune/on_change_input cache invalidation
- Clean up the `getattr(out.type, "ndim", 0) or 0` pattern
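Of the items above, the hash-consing of materialized `get_shape(v, i)` results is the one with a self-contained shape; a minimal sketch, assuming shape expressions are represented as hashable values:

```python
_interned = {}

def hash_cons(expr):
    """Return one canonical instance per structurally-equal shape
    expression, so repeated get_shape(v, i) results share an object
    and downstream identity checks and caches can rely on `is`."""
    if expr not in _interned:
        _interned[expr] = expr
    return _interned[expr]

a = hash_cons(("add", ("shape_i", "x", 0), 1))
b = hash_cons(("add", ("shape_i", "x", 0), 1))
```

In the real feature the interned values are graph variables rather than tuples; the sketch only illustrates the dedup-by-structure idea.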