
Stop subtensor merge bonanza #2098

Draft

ricardoV94 wants to merge 6 commits into pymc-devs:main from ricardoV94:stop_subtensor_merge

Conversation


ricardoV94 (Member) commented Apr 29, 2026

Summary

Closes #112
Closes #1283

local_subtensor_merge no longer expands into combinatorial switch/min/max trees on Scan outputs with symbolic shapes, and scan_save_mem still trims buffers all the way down to taps + 1 whenever a chain of constant-bound subtensors ends in a constant scalar, including when intermediate slices sit between the initial stripping slice and the terminal index.

What changed

  1. Gate local_subtensor_merge + add a shape-free fast path. Only call the symbolic merge_two_slices when slice bounds (and shapes, for slice+scalar) are constant. A new _safe_slice_slice_merge handles the common slice-on-slice cases without consulting shape: forward × forward (sign-aware bound combination, e.g. x[1:-1][1:-1] → x[2:-2]; see the sketch after this list), x[a:b][::-1], x[::-1][a:b], x[::-1][a:b:-1], and x[a:b:-1][::-1] (the last restricted to non-negative bounds). Aggressive scalar-into-slice merging moves to local_subtensor_merge_unsafe (shape_unsafe).
  2. Replace while_scan_merge_subtensor_last_element with scan_merge_subtensor_chain — walks arbitrary chains of constant-bound Subtensor clients (steps ∈ {None, 1, -1}) on a Scan output, folds them into a direct raw_out[final_idx] with an Assert(n_steps >= k) for safety. Concrete simulation for constant n_steps; affine bookkeeping over raw_length for symbolic.
  3. Add scan_reduce_nsteps — when every client of a Scan output is a constant scalar index, reduce n_steps to k + 1 - init_l (k being the largest constant index and init_l the number of initial taps in the trace) and convert clients to negative form so save_mem can finish the job.
  4. Refactor scan_save_mem — drops get_canonical_form_slice; reads buffer requirements straight off the (now-folded) negative indices. Caps prealloc extra_size at n_steps to avoid uninitialized slots.
  5. Fix infer_shape and local_subtensor_shape_constant for 0-d outputs that arise after save_mem + sit_sot_to_untraced collapse.
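
To make the forward × forward case in item 1 concrete: when both starts are non-negative and both stops are negative, the bounds simply add and no shape information is needed. A minimal sketch of that arithmetic on plain Python slices (merge_forward_slices is a hypothetical name, not the PR's _safe_slice_slice_merge):

```python
def merge_forward_slices(s1: slice, s2: slice) -> slice:
    # Only the pattern x[a1:-b1][a2:-b2] with a1, a2 >= 0 and b1, b2 > 0 is
    # covered here; it collapses to x[a1 + a2 : -(b1 + b2)] for any len(x).
    a1, a2 = s1.start or 0, s2.start or 0
    b1, b2 = s1.stop, s2.stop
    assert a1 >= 0 and a2 >= 0 and b1 < 0 and b2 < 0
    return slice(a1 + a2, b1 + b2)

x = list(range(10))
merged = merge_forward_slices(slice(1, -1), slice(1, -1))
assert x[1:-1][1:-1] == x[merged] == x[2:-2]
```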

Invariant honored

result[...][-1] (static [-1] through any chain of constant-bound, ±1-step slices) always reduces to a unit buffer, with both constant and symbolic n_steps.
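
For concreteness, a graph of the kind this invariant covers can be built with the public scan API; a minimal sketch (not taken from the PR's tests):

```python
import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
n_steps = pt.iscalar("n_steps")

# A trivial recurrence; `result` is the trace of all n_steps values.
result, _ = pytensor.scan(lambda prev: prev * 2, outputs_info=[x0], n_steps=n_steps)

# Constant-bound, +1-step slice followed by a static [-1]:
last = result[3:][-1]

# After the rewrites above, the compiled scan should keep only a unit-sized
# buffer for `result` (plus an Assert that n_steps is large enough).
f = pytensor.function([x0, n_steps], last)
```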

Verified

| Pattern | Static n=10 | Symbolic n |
| --- | --- | --- |
| result[-1] | 0-d scan output ✓ | buffer ≤ 2 + Assert(n≥1) ✓ |
| result[-3:][-1] | 0-d ✓ | buffer ≤ 2 + Assert(n≥3) ✓ |
| result[3:][-1] | 0-d ✓ | buffer ≤ 2 + Assert(n≥4) ✓ |
| result[::-1][0] | 0-d ✓ | buffer ≤ 2 ✓ |
| result[5::-1][-1] | 0-d ✓ | buffer ≤ 2 ✓ |
| result[:-1][::-1][-1] | 0-d ✓ | reduced ✓ |

tests/scan/test_rewriting.py: 59 passed / 2 skipped. tests/tensor/rewriting/test_subtensor.py::TestLocalSubtensorMerge: 13 passed / 2 skipped.

Benchmarks

Compile-time and post-rewrite node counts at the pre-gate baseline (d35fb51b1) vs this branch, compiled with mode excluding("fusion") so the rewriter output stays visible. Composite fusion otherwise lumps the symbolic switch tree into a single node, hiding the count but not the rewrite-time cost. constant_folding is not hiding anything either: the switches reference the symbolic Shape_i(x) and don't fold.
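
A sketch of how the measurement can be reproduced for the nested-slice rows (an approximation of the benchmark setup, not the benchmark file itself):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.vector("x")
y = x
for _ in range(5):  # nested x[1:-1] slices, as in the depth-5 row below
    y = y[1:-1]

mode = get_default_mode().excluding("fusion")  # keep the rewriter output visible
f = pytensor.function([x], y, mode=mode)
print(len(f.maker.fgraph.apply_nodes))  # post-rewrite apply-node count
```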

| Benchmark | Before | After | Speedup |
| --- | --- | --- | --- |
| x[1:-1] × 3 | 82 apply, 228 ms | 2 apply, 27 ms | |
| x[1:-1] × 5 | 166 apply, 490 ms | 2 apply, 45 ms | 11× |
| x[1:-1] × 8 | 292 apply, 835 ms | 2 apply, 74 ms | 11× |
| grad(xs[-1], x0) symbolic-n Scan (#112's worst case) | 348 apply, 1260 ms | 22 apply, 114 ms | 11× |

Covered by tests/benchmarks/test_subtensor.py::test_local_subtensor_merge_compile_benchmark and tests/benchmarks/test_scan.py::test_scan_grad_subtensor_compile_benchmark.

Limitations (deliberate)

  • While-scans only fold to final_idx == -1: the only number of steps we can be sure a while-scan takes is 1 (given n_steps > 0, which we assert).
  • Slice steps of magnitude > 1 are not handled.
  • concatenate([rev, zeros])[k] from the while-scan gradient path stays opaque to chain folding (would need #919, "Consider lifting Subtensor through Joins").

ricardoV94 (Member, Author) commented:

Possibly fixes #1288; need to check.

ricardoV94 force-pushed the stop_subtensor_merge branch 4 times, most recently from a2dd243 to 3126e9c on April 30, 2026 at 13:59.
Adds ``test_scan_grad_subtensor_compile_benchmark`` covering the most
extreme pathology in issue pymc-devs#112: ``grad(xs[-1], x0)`` over a symbolic-
``n_steps`` Scan used to blow up into hundreds of switch / min / max
nodes via the unguarded ``local_subtensor_merge``.
…lice chain)

``test_local_subtensor_merge_compile_benchmark[depth={3,5,8}]`` --
nested ``[1:-1]`` slices on a vector. The original
``local_subtensor_merge`` fired ``merge_two_slices`` per pair regardless
of whether the result simplified, generating switch/min/max trees that
constant-folding would later flatten -- final node count was small but
the rewriter still paid for the intermediate explosion. Compile time
scales with depth.

Warm caches once before timing so the recorded number is rewrite/compile
cost, not import or JIT init.
The unconditional merge produced large switch/min/max trees whenever any
component of the chain (slice bounds or shapes) was symbolic — most
visibly on Scan outputs whose stripping slice is rolled together with a
client index. Add ``_can_merge_simply``: only merge slice+slice when both
steps are constant, and slice+scalar when all components and shapes are
constant. Non-mergeable dimensions stay as a separate outer Subtensor.

Add ``local_subtensor_merge_unsafe`` (tag ``shape_unsafe``) that handles
slice+scalar with step ±1 more aggressively without bounds checks, for
the cases where the safe merge bails out.
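
The shape of that gate, paraphrased on a plain graph-level slice (bounds_are_constant is a hypothetical stand-in, not the actual ``_can_merge_simply``):

```python
from pytensor.graph.basic import Constant

def bounds_are_constant(sl: slice) -> bool:
    # Merge eagerly only when every component of the slice is absent, a plain
    # Python int, or a Constant variable in the graph.
    return all(
        v is None or isinstance(v, (int, Constant))
        for v in (sl.start, sl.stop, sl.step)
    )
```
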
…nsor_chain

The old rewrite only handled ``while_scan_out[init_l:][-1]``. Generalize:
walk arbitrary chains of constant-bound ``Subtensor`` clients rooted at a
scan output, fold the chain to a direct ``raw_out[final_idx]``, and
insert ``Assert(n_steps >= k)`` when the final index requires it.

Slice steps are restricted to ``None``, ``1``, ``-1`` (paired ``-1`` steps
cancel back to forward). Chains with symbolic bounds bail.

Two evaluation paths share a common helper:

  * Concrete simulation when ``n_steps`` is constant -- run the chain on
    a Python tracker list of indices (sketched below).
  * Affine bookkeeping ``(c, d)`` over ``raw_length`` when ``n_steps`` is
    symbolic -- track an iterator view ``raw[base + i*direction]`` and
    accumulate the minimum ``raw_length`` constraint as bounds are
    normalized.
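
The concrete-simulation path is easy to picture on plain Python lists; a sketch with a hypothetical helper, not the rewrite's code:

```python
def fold_chain(n_steps, chain):
    # Simulate a chain of constant-bound slices ending in a scalar index on a
    # stand-in trace [0, 1, ..., n_steps - 1]; the result is the raw index the
    # whole chain selects.
    idxs = list(range(n_steps))
    for sl in chain[:-1]:
        idxs = idxs[sl]
    return idxs[chain[-1]]

assert fold_chain(10, [slice(3, None), -1]) == 9        # out[3:][-1]   -> raw_out[9]
assert fold_chain(10, [slice(None, None, -1), 0]) == 9  # out[::-1][0]  -> raw_out[9]
```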

Registered in ``scan_eqopt1`` so the folded form reaches ``scan_save_mem``
in time. While-scans only fold to ``-1`` (the only post-early-exit-safe
target).

The old rewrite's ``"n_steps > 0"`` assertion message becomes
``"n_steps >= 1"`` (uniform with the new ``"n_steps >= k"`` form for
k > 1).
ricardoV94 force-pushed the stop_subtensor_merge branch from 3126e9c to 30c36f6 on April 30, 2026 at 14:54.
… indices

Two coupled changes that are easier to review together because the new
``scan_save_mem`` relies on ``scan_reduce_nsteps`` to have run first.

scan_reduce_nsteps (new, registered at position 1.60, before save_mem):
when every client of a Scan output is a ``Subtensor`` whose first index
is a constant scalar, reduce ``n_steps`` to the minimum that covers all
those scalars and rewrite each client to use a negative index against
the new (shorter) trace. This trims the iteration count itself, not just
the buffer.
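
The arithmetic of the reduction in isolation (a toy instantiation of the k + 1 - init_l formula from the summary, not the rewrite itself):

```python
def reduced_n_steps(k: int, init_l: int) -> int:
    # Smallest number of iterations whose trace still contains row k, given
    # that the first init_l rows of the trace hold the initial taps.
    return k + 1 - init_l

# A single client reading result[4] from a scan with one initial tap needs only
# 4 iterations; the client is then rewritten to index the shorter trace as [-1].
assert reduced_n_steps(k=4, init_l=1) == 4
```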

scan_save_mem refactor:
  * No longer calls ``get_canonical_form_slice`` on client slices --
    after ``scan_merge_subtensor_chain`` and ``scan_reduce_nsteps`` have
    run, every reachable client either has a constant negative index/
    slice-start (whose magnitude is the buffer length we need) or is
    something we can't trim. Walk those directly (a toy illustration
    follows this list).
  * The orphan-output branch is gone (handled by ``scan_remove_unused``).
  * For preallocated buffers, ``extra_size`` is capped at ``n_steps`` so
    a small ``n_steps`` doesn't hand the inner loop uninitialized slots.
  * The ``shape_of`` cache is no longer needed since negative-index buffer
    sizing doesn't probe shapes.
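
A toy illustration of the negative-index buffer sizing mentioned in the first bullet (not the actual scan_save_mem code):

```python
def required_buffer_len(client_indices):
    # Given the constant negative indices / slice starts of every client of one
    # scan output, the buffer only has to keep the last max(|idx|) entries.
    return max(abs(i) for i in client_indices)

assert required_buffer_len([-1]) == 1      # result[-1]               -> unit buffer
assert required_buffer_len([-3, -1]) == 3  # result[-3:] + result[-1] -> keep 3
```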

Net diff is mostly deletions in scan_save_mem; the new function itself
adds ~150 lines.

Test updates that come with the refactor:
  * ``test_save_mem_store_steps``: read buffer length via runtime
    ``shape[0]`` instead of searching for ``AllocEmpty`` in ancestors;
    the trimmed buffer can now be a direct slice of the user's init
    input when no extra slot is needed (the rewrite skips the
    ``expand_empty(AllocEmpty(...))`` round-trip).
  * ``test_while_scan_taps_and_map``: ys is now an untraced sit_sot
    (scalar input) once save_mem reduces buffer to 1; assert that
    structurally instead of probing ``ys_trace.shape[0]``.
  * ``test_inplace_taps[symbolic n_steps]``: under symbolic ``n_steps``
    buffers can stay symbolic and the inplace pass may or may not fire;
    don't assert on ``destroyed_inputs`` in that branch.
After the save_mem refactor, ``scan_save_mem`` + ``scan_sit_sot_to_untraced``
can shrink a sit_sot output all the way to a 0-d scalar (no leading
length dimension at all). Two downstream callsites assumed an n-d
output and crashed:

  * ``Scan.infer_shape`` for while-scans called ``Shape_i(0)(o)``
    unconditionally to model the unknown leading dimension. Skip it
    for 0-d outputs and just propagate the empty shape tuple.

  * ``local_subtensor_shape_constant`` indexed
    ``shape_arg.type.broadcastable[idx_val]`` without bounds checking,
    so user code like ``scalar.shape[0]`` (legitimate for some
    constructs, but ill-formed for a 0-d) raised ``IndexError`` from
    inside the rewriter and aborted the optimization pass. Catch and
    bail; the operation is still semantically wrong at runtime, but
    the rewriter shouldn't be the one to surface it.
ricardoV94 force-pushed the stop_subtensor_merge branch from 30c36f6 to 8456244 on April 30, 2026 at 21:05.


Linked issues this pull request may close:

  • Inplace on sit-sot / mit-sot when nsteps is symbolic
  • local_subtensor_merge can complicate graphs
