Stop subtensor merge bonanza #2098
Draft
ricardoV94 wants to merge 6 commits into pymc-devs:main
Conversation
ricardoV94 (Member, Author):
Possibly fixes #1288 (need to check).
Force-pushed from a2dd243 to 3126e9c.
Adds ``test_scan_grad_subtensor_compile_benchmark`` covering the most extreme pathology in issue pymc-devs#112: ``grad(xs[-1], x0)`` over a symbolic-``n_steps`` Scan used to blow up into hundreds of switch/min/max nodes via the unguarded ``local_subtensor_merge``.
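For reference, the benchmarked graph is roughly the following (a sketch; variable names are illustrative):

```python
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
n = pt.iscalar("n")  # symbolic n_steps
xs, _ = pytensor.scan(lambda x: x + 1, outputs_info=[x0], n_steps=n)
g = pytensor.grad(xs[-1], x0)
# compiling this previously paid for the switch/min/max explosion
f = pytensor.function([x0, n], g)
```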
…lice chain)
``test_local_subtensor_merge_compile_benchmark[depth={3,5,8}]`` --
nested ``[1:-1]`` slices on a vector. The original
``local_subtensor_merge`` fired ``merge_two_slices`` per pair regardless
of whether the result simplified, generating switch/min/max trees that
constant-folding would later flatten -- final node count was small but
the rewriter still paid for the intermediate explosion. Compile time
scales with depth.
Warm caches once before timing so the recorded number is rewrite/compile
cost, not import or JIT init.
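The benchmarked graph at depth 8, roughly (a sketch; names illustrative):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = x
for _ in range(8):  # depth parameter: 3, 5, or 8
    y = y[1:-1]
# compile time is dominated by the rewriter: each pair-merge used to
# emit switch/min/max nodes that constant folding later flattened
f = pytensor.function([x], y)
```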
The unconditional merge produced large switch/min/max trees whenever any component of the chain (slice bounds or shapes) was symbolic — most visibly on Scan outputs whose stripping slice is rolled together with a client index.

Add ``_can_merge_simply``: only merge slice+slice when both steps are constant, and slice+scalar when all components and shapes are constant. Non-mergeable dimensions stay as a separate outer Subtensor. Add ``local_subtensor_merge_unsafe`` (tag ``shape_unsafe``) that handles slice+scalar with step ±1 more aggressively without bounds checks, for the cases where the safe merge bails out.
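A plain-Python sketch of that guard (a hypothetical mirror of ``_can_merge_simply``, not its real signature; ``None``/``int`` stand in for absent/constant graph components, anything else for symbolic ones):

```python
def can_merge_simply(first, second, shape_is_constant):
    """Decide whether x[first][second] can merge without emitting
    symbolic switch/min/max bookkeeping."""
    def is_const(c):
        return c is None or isinstance(c, int)

    if isinstance(second, slice):
        # slice + slice: merge only when both steps are constant
        return is_const(first.step) and is_const(second.step)
    # slice + scalar index: every slice component and the dimension's
    # shape must be constant
    return (
        is_const(first.start)
        and is_const(first.stop)
        and is_const(first.step)
        and shape_is_constant
    )
```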
…nsor_chain
The old rewrite only handled ``while_scan_out[init_l:][-1]``. Generalize:
walk arbitrary chains of constant-bound ``Subtensor`` clients rooted at a
scan output, fold the chain to a direct ``raw_out[final_idx]``, and
insert ``Assert(n_steps >= k)`` when the final index requires it.
Slice steps are restricted to ``None``, ``1``, ``-1`` (paired ``-1`` steps
cancel back to forward). Chains with symbolic bounds bail.
Two evaluation paths share a common helper (see the sketch below):
* Concrete simulation when ``n_steps`` is constant -- run the chain on
a Python tracker list of indices.
* Affine bookkeeping ``(c, d)`` over ``raw_length`` when ``n_steps`` is
symbolic -- track an iterator view ``raw[base + i*direction]`` and
accumulate the minimum ``raw_length`` constraint as bounds are
normalized.
Registered in ``scan_eqopt1`` so the folded form reaches ``scan_save_mem``
in time. While-scans only fold to ``-1`` (the only post-early-exit-safe
target).
The old rewrite's ``"n_steps > 0"`` assertion message becomes
``"n_steps >= 1"`` (uniform with the new ``"n_steps >= k"`` form for
k > 1).
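The concrete path is easy to picture as plain list indexing (a sketch with a hypothetical helper; the real rewrite tracks graph-side indices):

```python
def fold_chain_concrete(n_steps, chain):
    """Simulate a constant-bound subtensor chain on a Python tracker
    list and read off the final raw index."""
    tracker = list(range(n_steps))
    for idx in chain:
        tracker = tracker[idx]
    assert isinstance(tracker, int)  # chain must end in a scalar index
    return tracker

# xs[3:][::-1][0] on a 10-step scan folds to raw_out[9], i.e. raw_out[-1]:
assert fold_chain_concrete(10, [slice(3, None), slice(None, None, -1), 0]) == 9
```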
Force-pushed from 3126e9c to 30c36f6.
… indices
Two coupled changes that are easier to review together because the new
``scan_save_mem`` relies on ``scan_reduce_nsteps`` to have run first.
scan_reduce_nsteps (new, registered at position 1.60, before save_mem):
when every client of a Scan output is a ``Subtensor`` whose first index
is a constant scalar, reduce ``n_steps`` to the minimum that covers all
those scalars and rewrite each client to use a negative index against
the new (shorter) trace. This trims the iteration count itself, not just
the buffer.
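In isolation, the index arithmetic looks like this (a sketch that ignores the initial-tap offset ``init_l``):

```python
def reduce_nsteps(client_indices):
    """Shrink n_steps to cover the largest constant scalar index and
    convert clients to negative indices on the shortened trace."""
    new_n_steps = max(client_indices) + 1
    return new_n_steps, [k - new_n_steps for k in client_indices]

# clients xs[2] and xs[5] on a 1000-step scan: 6 steps suffice
assert reduce_nsteps([2, 5]) == (6, [-4, -1])
```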
scan_save_mem refactor:
* No longer calls ``get_canonical_form_slice`` on client slices --
after ``scan_merge_subtensor_chain`` and ``scan_reduce_nsteps`` have
run, every reachable client either has a constant negative index or
slice-start (whose magnitude is the buffer length we need) or is
something we can't trim. Walk those directly (see the sketch below).
* The orphan-output branch is gone (handled by ``scan_remove_unused``).
* For preallocated buffers, ``extra_size`` is capped at ``n_steps`` so
a small ``n_steps`` doesn't hand the inner loop uninitialized slots.
* The ``shape_of`` cache is no longer needed since negative-index buffer
sizing doesn't probe shapes.
Net diff is mostly deletions in scan_save_mem; the new function itself
adds ~150 lines.
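The negative-index buffer sizing from the first bullet, as plain arithmetic (sketch):

```python
def required_buffer_len(idx):
    """After chain folding, the trailing-buffer length a client needs
    is just the magnitude of its constant negative index or start."""
    start = idx if isinstance(idx, int) else idx.start
    assert isinstance(start, int) and start < 0
    return -start

assert required_buffer_len(-1) == 1               # xs[-1]
assert required_buffer_len(slice(-3, None)) == 3  # xs[-3:]
```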
Test updates that come with the refactor:
* ``test_save_mem_store_steps``: read buffer length via runtime
``shape[0]`` instead of searching for ``AllocEmpty`` in ancestors;
the trimmed buffer can now be a direct slice of the user's init
input when no extra slot is needed (the rewrite skips the
``expand_empty(AllocEmpty(...))`` round-trip).
* ``test_while_scan_taps_and_map``: ``ys`` is now an untraced sit_sot
(scalar input) once save_mem reduces its buffer to 1; assert that
structurally instead of probing ``ys_trace.shape[0]``.
* ``test_inplace_taps[symbolic n_steps]``: under symbolic ``n_steps``
buffers can stay symbolic and the inplace pass may or may not fire;
don't assert on ``destroyed_inputs`` in that branch.
After the save_mem refactor, ``scan_save_mem`` + ``scan_sit_sot_to_untraced``
can shrink a sit_sot output all the way to a 0-d scalar (no leading
length dimension at all). Two downstream callsites assumed an n-d
output and crashed:
* ``Scan.infer_shape`` for while-scans called ``Shape_i(0)(o)``
unconditionally to model the unknown leading dimension. Skip it
for 0-d outputs and just propagate the empty shape tuple.
* ``local_subtensor_shape_constant`` indexed
``shape_arg.type.broadcastable[idx_val]`` without bounds checking,
so user code like ``scalar.shape[0]`` (legitimate for some
constructs, but ill-formed for a 0-d) raised ``IndexError`` from
inside the rewriter and aborted the optimization pass. Catch and
bail; the operation is still semantically wrong at runtime, but
the rewriter shouldn't be the one to surface it.
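The second fix boils down to a bounds check before touching ``broadcastable`` (a sketch of the pattern, not the literal diff):

```python
def constant_shape_entry(broadcastable, idx_val):
    """Return 1 when the indexed dim is statically broadcastable,
    None to make the rewrite bail (including out-of-range indices)."""
    if not (0 <= idx_val < len(broadcastable)):
        # e.g. scalar.shape[0] on a 0-d output: don't raise inside
        # the rewriter; leave the error to runtime
        return None
    return 1 if broadcastable[idx_val] else None
```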
Force-pushed from 30c36f6 to 8456244.
Summary
Closes #112
Closes #1283
`local_subtensor_merge` no longer expands into combinatorial `switch`/`min`/`max` trees on Scan outputs with symbolic shapes, and `scan_save_mem` still trims buffers all the way down to `taps + 1` whenever a chain of constant-bound subtensors ends in a constant scalar — including when intermediate slices sit between the strip and the terminal index.

What changed
- Gate `local_subtensor_merge` and add a shape-free fast path. Only call the symbolic `merge_two_slices` when slice bounds (and shapes, for slice+scalar) are constant. A new `_safe_slice_slice_merge` handles the common slice-on-slice cases without consulting shape: forward × forward (sign-aware bound combination, e.g. `x[1:-1][1:-1]` → `x[2:-2]`), `x[a:b][::-1]`, `x[::-1][a:b]`, `x[::-1][a:b:-1]`, and `x[a:b:-1][::-1]` (the last restricted to non-negative bounds). Aggressive scalar-into-slice merging moves to `local_subtensor_merge_unsafe` (`shape_unsafe`).
- Replace `while_scan_merge_subtensor_last_element` with `scan_merge_subtensor_chain` — walks arbitrary chains of constant-bound `Subtensor` clients (steps ∈ {None, 1, -1}) on a Scan output, folds them into a direct `raw_out[final_idx]` with an `Assert(n_steps >= k)` for safety. Concrete simulation for constant `n_steps`; affine bookkeeping over `raw_length` for symbolic.
- New `scan_reduce_nsteps` — when every client of a Scan output is a constant scalar index, reduce `n_steps` to `k + 1 - init_l` and convert clients to negative form so save_mem can finish the job.
- Refactor `scan_save_mem` — drops `get_canonical_form_slice`; reads buffer requirements straight off the (now-folded) negative indices. Caps prealloc `extra_size` at `n_steps` to avoid uninitialized slots.
- Fix `infer_shape` and `local_subtensor_shape_constant` for 0-d outputs that arise after save_mem + sit_sot_to_untraced collapse.

Invariant honored
`result[...][-1]` (a static `[-1]` through any chain of constant-bound, ±1-step slices) always reduces to a unit buffer, with both constant and symbolic `n_steps`.
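One concrete instance of the invariant (an illustrative construction; names are ours):

```python
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
n = pt.iscalar("n")  # symbolic n_steps
xs, _ = pytensor.scan(lambda x: x + 1, outputs_info=[x0], n_steps=n)
# a constant-bound, ±1-step chain ending in a static index (== xs[-1]):
out = xs[3:][::-1][0]
f = pytensor.function([x0, n], out)  # should compile to a unit buffer
```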
Verified

Checked under both constant (`n = 10`) and symbolic `n`: `result[-1]`, `result[-3:][-1]`, `result[3:][-1]`, `result[::-1][0]`, `result[5::-1][-1]`, `result[:-1][::-1][-1]`.

`tests/scan/test_rewriting.py`: 59 passed / 2 skipped. `tests/tensor/rewriting/test_subtensor.py::TestLocalSubtensorMerge`: 13 passed / 2 skipped.

Benchmarks
Compile-time and post-rewrite node counts at the pre-gate baseline (d35fb51b1) vs this branch, with mode `excluding("fusion")` so the rewriter output is visible. `Composite` fusion otherwise lumps the symbolic switch tree into a single node, hiding the count but not the rewrite-time cost. `constant_folding` is not hiding anything — the switches reference symbolic `Shape_i(x)` and don't fold.

Cases: nested `x[1:-1]` at depths 3, 5, and 8, plus `grad(xs[-1], x0)` on a symbolic-`n` Scan (#112's worst case).

Covered by
`tests/benchmarks/test_subtensor.py::test_local_subtensor_merge_compile_benchmark` and `tests/benchmarks/test_scan.py::test_scan_grad_subtensor_compile_benchmark`.

Limitations (deliberate)
- While scans only fold to `final_idx == -1`: the only number of steps we can even be sure a while scan takes is 1 (if `n_steps > 0`, which we assert).
- `concatenate([rev, zeros])[k]` from the while-scan gradient path stays opaque to chain folding (would need "Consider lifting Subtensor through Joins", #919).