From 33de554133936ebf782ad413d0fa8d0e8045a5e3 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Mon, 22 Jun 2026 00:09:29 +0200 Subject: [PATCH 1/3] Subtensor rewrites: reason about advanced indices jointly when gating --- pytensor/tensor/rewriting/subtensor_lift.py | 106 +++++++++++++++--- tests/tensor/rewriting/test_subtensor.py | 29 +++-- tests/tensor/rewriting/test_subtensor_lift.py | 36 ++++++ 3 files changed, 148 insertions(+), 23 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor_lift.py b/pytensor/tensor/rewriting/subtensor_lift.py index d3a4bf5bf6..0f242577a9 100644 --- a/pytensor/tensor/rewriting/subtensor_lift.py +++ b/pytensor/tensor/rewriting/subtensor_lift.py @@ -27,6 +27,7 @@ Eye, Join, MakeVector, + Nonzero, alloc, arange, as_tensor, @@ -244,6 +245,77 @@ def _index_provably_not_larger(idx, val_static_dim, fgraph=None) -> bool: return bool(np.prod(idx_static_shape) < val_static_dim) +def _constants_jointly_unique(consts) -> bool: + """Whether stacked constant indices have no duplicate coordinate tuples. + + The stacked ``np.unique`` can be expensive on large indices, so the result + is cached on the first constant's tag. Uniqueness is a property of the whole + group, and a constant may belong to several groups (constants are shared + across the graph), so the cache is keyed by the group's identities rather + than a single flag. + """ + key = tuple(id(c) for c in consts) + cache = getattr(consts[0].tag, "jointly_unique_indices", None) + if cache is None: + cache = consts[0].tag.jointly_unique_indices = {} + if key not in cache: + datas = [np.asarray(c.data) for c in consts] + # A coordinate axis that mixes positive and negative values may alias + # (``0`` and ``-dim`` are the same position), so distinctness of the raw + # values no longer proves distinctness of the coordinates. + if any((data >= 0).any() and (data < 0).any() for data in datas): + cache[key] = False + else: + coords = np.broadcast_arrays(*datas) + stacked = np.stack([coord.ravel() for coord in coords]) + cache[key] = bool(np.unique(stacked, axis=1).shape[1] == stacked.shape[1]) + return bool(cache[key]) + + +def _indices_provably_not_larger(idxs_and_dims, fgraph) -> bool: + """Whether advanced-indexing some consecutive axes selects no more elements + than those axes already hold, so lifting a Subtensor through the indexing + can't increase computation. + + ``idxs_and_dims`` pairs each advanced index (``ndim > 0``) with the static + size of the axis it indexes. + """ + if not idxs_and_dims: + return True + + idxs = [idx for idx, _ in idxs_and_dims] + dims = [dim for _, dim in idxs_and_dims] + idx_shapes = [idx.type.shape for idx in idxs] + + # With static shapes the result size is known exactly, so just compare it + # against the number of elements the indexed axes hold. + if all(d is not None for d in dims) and all( + None not in shape for shape in idx_shapes + ): + return bool(np.prod(np.broadcast_shapes(*idx_shapes)) <= np.prod(dims)) + + # Otherwise fall back to proving the indices are duplicate-free, which on its + # own bounds the result by the axes' size, even when the sizes are unknown: + # - each index repeats no position on its own axis, or + if all(_index_provably_not_larger(idx, dim, fgraph) for idx, dim in idxs_and_dims): + return True + if len(idxs) > 1: + # - the indices are all the coordinates of one Nonzero, distinct by + # construction (e.g. symbolic tril_indices), or + owners = {idx.owner for idx in idxs} + if ( + len(owners) == 1 + and (owner := next(iter(owners))) is not None + and isinstance(owner.op, Nonzero) + and set(idxs) == set(owner.outputs) + ): + return True + # - the constant coordinate tuples have no duplicates. + if all(isinstance(idx, Constant) for idx in idxs): + return _constants_jointly_unique(idxs) + return False + + @register_canonicalize @register_stabilize @register_specialize @@ -345,17 +417,19 @@ def local_subtensor_of_batch_dims(fgraph, node): if _non_consecutive_adv_indexing(idx_tuple): return None - # Skip when lifting would expand a gather past a non-broadcast input's size. + # Skip when indexing each input would select more elements than it holds, + # making the lifted Elemwise do more work. The advanced indices are weighed + # together, over the consecutive axes they jointly index. for inp in elem.owner.inputs: - for axis, idx in enumerate(idx_tuple): - if axis >= inp.type.ndim: - break - if not isinstance(idx, TensorVariable) or idx.type.ndim == 0: - continue - if inp.type.broadcastable[axis]: - continue - if not _index_provably_not_larger(idx, inp.type.shape[axis], fgraph): - return None + adv_indices = [ + (idx, inp.type.shape[axis]) + for axis, idx in enumerate(idx_tuple[: inp.type.ndim]) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not inp.type.broadcastable[axis] + ] + if not _indices_provably_not_larger(adv_indices, fgraph): + return None batch_ndim = ( elem.owner.op.batch_ndim(elem.owner) @@ -742,11 +816,15 @@ def lift_subtensor_through_alloc(fgraph, node): # Indices on Alloc-added dims don't reach val; the rest line up with val's dims. val_indexer = indices[n_added_dims:] - dangerous_index_reaches_val = any( - not val.type.broadcastable[axis] - # Per-axis check; doesn't account for net effect across all axes. - and not _index_provably_not_larger(idx, val.type.shape[axis], fgraph) + val_adv_indices = [ + (idx, val.type.shape[axis]) for axis, idx in enumerate(val_indexer) + if isinstance(idx, TensorVariable) + and idx.type.ndim > 0 + and not val.type.broadcastable[axis] + ] + dangerous_index_reaches_val = not _indices_provably_not_larger( + val_adv_indices, fgraph ) # On broadcast val dims the index is neutralized (advanced indices dropped, diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 1c1cbb7bb0..1d2032e914 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -2753,8 +2753,11 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): packed = pt.vector("packed") if exp_before_materialize: - # We test the same optimized result regardless of whether - # the diagonals are updated before or after materialization + # Two equivalent ways to build the same ``L``: exponentiate the diagonal + # in the packed vector before scattering it into the matrix, or scatter + # first and exponentiate the matrix diagonal afterwards (the ``else`` + # branch below). They are the same computation but optimize to slightly + # different graphs under ``BlasOpt`` (see the index-count assertion). packed_diag_indices = pt.arange(n + 1).cumsum()[1:] - 1 log_diag = packed[packed_diag_indices] packed_update = packed[packed_diag_indices].set(pt.exp(log_diag)) @@ -2778,7 +2781,6 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): mode = get_default_mode().excluding("fuse_indexed_into_elemwise") f = function([packed], [loss, grad], mode=mode) - f.dprint(print_shape=True) idx_types = ( Subtensor, @@ -2790,13 +2792,22 @@ def test_cholesky_unconstrain_grad(exp_before_materialize): ExtractDiag, ) n_idx = sum(1 for n in f.maker.fgraph.toposort() if isinstance(n.op, idx_types)) - # The ``BlasOpt`` rewrites lower ``L @ L.T`` to ``Gemm``; the gradient then - # fuses the diagonal-gradient term into a ``Gemm`` operand, materializing one - # extra set-subtensor. A linker that cannot use them lists ``BlasOpt`` in - # ``incompatible_rewrites`` (e.g. the numba linker), keeping the plain ``Dot`` - # lowering with that term as a vector. Both lowerings are correct. + # The gradient w.r.t. ``L`` adds the ``sum(L@L.T)`` term ``2*(ones@L)`` to the + # log-det term ``diag(1/diag(L))`` (a diagonal matrix, ``1/diag(L) == + # exp(-diag)``). In the post-materialization formulation this addition happens + # at the matrix level, so when ``BlasOpt`` runs its ``GemmOptimizer`` fuses + # ``add(dot, C)`` into a single ``Gemm`` with the log-det term as the additive + # ``C`` operand; materializing that diagonal ``C`` matrix is one extra + # set-subtensor, giving 7 indexing ops. The pre-materialization formulation + # keeps the log-det term in the packed vector's index space (added after the + # tril gather), so there is no matrix-level ``add(dot, C)`` to fuse and it + # stays 6. Without ``Gemm`` (a linker that lists ``BlasOpt`` in + # ``incompatible_rewrites``, e.g. numba) the additive term is never raised to + # a matrix -- it is folded into an Elemwise on the diagonal entries -- so both + # formulations collapse to 6. All lowerings are correct. blas_rewrites_run = "BlasOpt" not in f.maker.mode.linker.incompatible_rewrites - assert n_idx == (7 if blas_rewrites_run else 6) + expected_n_idx = 7 if (blas_rewrites_run and not exp_before_materialize) else 6 + assert n_idx == expected_n_idx x = np.array([1.0, 0.5, 2.0, 0.3, 0.1, 1.5]) # Expected values were computed once by running ``f(x)``. diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 3314ba6952..664e877efc 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -235,6 +235,32 @@ def test_elemwise_adv_index_assumed_unique_lifts(self): ) result.assert_graph(x[idx] + y[idx]) + def test_elemwise_jointly_unique_adv_indices_lift(self): + """A group of adv indices that each repeat but pair up to distinct + coordinates (tril_indices) can't select more elements than the indexed + axes hold, so it lifts.""" + # Symbolic indices: the outputs of a single Nonzero. + n = pt.scalar("n", dtype="int64") + x = pt.matrix("x") + rows, cols = pt.tril_indices(n) + out = pt.exp(x)[rows, cols] + rewritten = rewrite_graph(out) + assert_equal_computations([rewritten], [pt.exp(x[rows, cols])]) + + # Constant indices, static array shape: proved through the exact size. + x = pt.matrix("x", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + out = pt.exp(x)[rows, cols] + rewritten = rewrite_graph(out) + assert_equal_computations([rewritten], [pt.exp(x[rows, cols])]) + + # Constant indices, unknown array shape: proved through joint uniqueness. + x = pt.matrix("x") + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + out = pt.exp(x)[rows, cols] + rewritten = rewrite_graph(out) + assert_equal_computations([rewritten], [pt.exp(x[rows, cols])]) + def test_blockwise(self): class CoreTestOp(Op): itypes = [dvector, dvector] @@ -756,6 +782,16 @@ def test_const_idx_with_duplicates_bails(self): rewritten = rewrite_graph(out, **self.rewrite_kw) assert_equal_computations([rewritten], [out], strict_dtype=False) + def test_jointly_unique_adv_indices_lift(self): + """Indices that each repeat but pair up to distinct coordinates + (tril_indices) don't enlarge val, so the read lifts through Alloc.""" + val = pt.matrix("val", shape=(5, 5)) + rows, cols = (pt.constant(i) for i in np.tril_indices(5)) + + out = pt.alloc(val, 5, 5)[rows, cols] + rewritten = rewrite_graph(out, **self.rewrite_kw) + assert_equal_computations([rewritten], [val[rows, cols]], strict_dtype=False) + def test_negative_step_idx_to_slice(self): """Negative-step constant arange ``[7, 5, 3, 1]`` rewrites to ``x[7::-2]``.""" x = pt.vector("x", shape=(10,)) From 252d29c2bddb9f2f7f816bfc07ef04025585e533 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 25 Jun 2026 14:05:42 +0200 Subject: [PATCH 2/3] Fuse safe read-modify-write into IndexedElemwise A write whose buffer is also read no longer always bails: when every aliasing read is through the same variable as the write target, uses the write's index, and that index is duplicate-free, fuse it and set destroyhandler_tolerate_aliased instead. --- pytensor/tensor/rewriting/indexed_elemwise.py | 68 +++++++++++++++---- tests/link/numba/test_indexed_elemwise.py | 63 ++++++++++++----- 2 files changed, 101 insertions(+), 30 deletions(-) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..46e9880f3e 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -17,6 +17,7 @@ from pytensor.scalar.basic import Composite from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.rewriting.elemwise import InplaceElemwiseOptimizer +from pytensor.tensor.rewriting.subtensor import _has_unique_indices from pytensor.tensor.shape import Reshape, shape_padright from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -474,9 +475,12 @@ def apply(self, fgraph): idx_groups = {} # (idx_var, axis) -> (reads: list[int], writes: list[int]) - # Roots of the buffers we read through, used below to skip fusing a + # Roots of the buffers we read through, used below to gate fusing a # write whose target buffer aliases one of them (see aliasing check). read_source_roots = set() + # Per fused read: (input position, source variable, source root, + # frozenset of index pairs), used by the aliasing check below. + read_info = [] # Find indexed reads to fuse: single client AdvancedSubtensor(1) for i, inp in enumerate(node.inputs): @@ -488,7 +492,10 @@ def apply(self, fgraph): idx_axis_pairs = self._extract_idx_axis_pairs(inp_node) if idx_axis_pairs is None: continue - read_source_roots.add(_view_root(view_i, inp_node.inputs[0])) + source = inp_node.inputs[0] + root = _view_root(view_i, source) + read_source_roots.add(root) + read_info.append((i, source, root, frozenset(idx_axis_pairs))) for idx_axis_pair in idx_axis_pairs: if idx_axis_pair not in idx_groups: idx_groups[idx_axis_pair] = ([], []) @@ -500,6 +507,9 @@ def apply(self, fgraph): # Our current vectorize codegen can't produce write only loops that don't force # the recomputation of the core function in every step. write_targets = {} # out_idx -> update_node + # out_idx -> read input positions whose buffer the write aliases and + # which the destroy handler must be told to tolerate. + aliased_read_positions = {} must_transpose_write_axes = False for out_idx, out in enumerate(node.outputs): clients = fgraph.clients[out] @@ -533,15 +543,37 @@ def apply(self, fgraph): target, _, *idx_vars = client_node.inputs - # Fusing an in-place write whose target is a buffer we also read - # through (same root) makes it two distinct inputs of one node, one - # read and one destroyed: the destroy handler rejects that aliasing. - # Leave such writes unfused (the reads still fuse, no added copy). - # Non-in-place writes are copied below, so their copy breaks the alias. - if client_node.op.inplace and _view_root(view_i, target) in ( - read_source_roots + # An in-place write whose target buffer is also read aliases a + # destroyed input with a live read, which the destroy handler + # rejects. It is safe only when every aliasing read is through the + # *same variable* as the write target (same root is not enough: a + # view like ``x[::-1]`` remaps positions), uses the write's index, + # and that index is duplicate-free -- then each position is read + # once before being overwritten. We then keep the write fused and + # tell the destroy handler to tolerate the alias (set below); + # otherwise leave it unfused (the reads still fuse, no added copy). + if ( + client_node.op.inplace + and (target_root := _view_root(view_i, target)) in read_source_roots ): - continue + write_pairs = frozenset(idx_axis_pairs) + root_aliasing = [ + (pos, source, pairs) + for pos, source, root, pairs in read_info + if root == target_root + ] + rmw_is_safe = all( + source is target and pairs == write_pairs + for _pos, source, pairs in root_aliasing + ) and all( + _has_unique_indices(fgraph, idx) + for idx, _axis in idx_axis_pairs + ) + if not rmw_is_safe: + continue + aliased_read_positions[out_idx] = [ + pos for pos, _source, _pairs in root_aliasing + ] write_bcast = AdvancedSubtensor(idx_list=client_node.op.idx_list)( target, *idx_vars @@ -680,6 +712,8 @@ def _has_non_write_clients(out_idx): # fgraph uses a fresh variable there (see below), so the actual buffer # (or its copy) is bound only at the outer call. outer_write_targets = {} + # (destroyed write-target position, aliased read position) pairs. + tolerate_aliased = [] # Inner fgraph outputs: Elemwise outputs, with write targets # replaced by their AdvancedIncSubtensor result @@ -715,6 +749,11 @@ def _has_non_write_clients(out_idx): fgraph_outputs[out_idx] = write_out fgraph_destroy_map[out_idx] = [target_pos] + tolerate_aliased.extend( + (target_pos, read_pos) + for read_pos in aliased_read_positions.get(out_idx, ()) + ) + # indexed_inputs_spec: ((read_positions, axis) | None, ...) # indexed_outputs_spec: ((write_positions, axis, "inc"|"set") | None, ...) indexed_inputs_spec = tuple( @@ -737,13 +776,18 @@ def _has_non_write_clients(out_idx): val = outer_write_targets.get(i, inp) outer_inputs.append(val.copy() if i in copy_positions else val) - new_outs = IndexedElemwise( + indexed_elemwise_op = IndexedElemwise( fgraph_inputs, fgraph_outputs, destroy_map=fgraph_destroy_map, indexed_inputs=indexed_inputs_spec, indexed_outputs=indexed_outputs_spec, - )(*outer_inputs, return_list=True) + ) + if tolerate_aliased: + indexed_elemwise_op.destroyhandler_tolerate_aliased = tuple( + tolerate_aliased + ) + new_outs = indexed_elemwise_op(*outer_inputs, return_list=True) replacements = [] for out_idx in range(len(node.outputs)): diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 22ffbe5927..2c26ec0b54 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -433,31 +433,32 @@ def test_multiple_write_targets_different_lengths(self): np.testing.assert_allclose(fused_o, unfused_o, rtol=1e-10) @pytest.mark.parametrize( - "read_idx, write_idx", + "read_idx, write_idx, write_fuses", # Non-contiguous so the indices stay Advanced(Inc)Subtensor1 rather than # being canonicalised into basic slices. [ - ([0, 2, 5], [1, 3, 7]), - ([0, 2, 5], [0, 2, 5]), - ([0, 2, 5], [5, 0, 2]), + ([0, 2, 5], [1, 3, 7], False), + ([0, 2, 5], [0, 2, 5], True), + ([0, 2, 5], [5, 0, 2], False), ], ids=["write_out_of_read_range", "write_equals_read", "write_permutes_read"], ) - def test_write_target_aliases_read_source(self, read_idx, write_idx): + def test_write_target_aliases_read_source(self, read_idx, write_idx, write_fuses): """Indexed write into a buffer that is also read through the same Elemwise. ``set_subtensor(b[write_idx], b[read_idx] * 2)`` reads and writes the same - buffer ``b``. Fusing the in-place write would alias the destroyed write - target with the live read input, so the write must stay external while the - read still fuses -- without raising an aliasing error or aborting the pass. - - The aliasing is only genuinely unsafe when read and write indices overlap - in a *different order* (``write_permutes_read``): then an in-loop write - could clobber a position another iteration still has to read. When the - indices don't overlap (``write_out_of_read_range``) or overlap in the same - order (``write_equals_read``) the alias is harmless and could be fused in - the future via a ``tolerated_aliased`` flag. For now we conservatively skip - the write in all cases; this test pins the correctness of that behaviour. + buffer ``b``. Fusing the in-place write aliases the destroyed write target + with the live read input, which the destroy handler rejects by default. + + It is safe only when read and write hit the *same* positions in the *same* + order (``write_equals_read``): each position is then read once, before being + overwritten, and never revisited, so the write fuses behind a + ``destroyhandler_tolerate_aliased`` promise. When the positions overlap in a + *different order* (``write_permutes_read``) an in-loop write could clobber a + position another iteration still has to read, so the write must stay + external. The disjoint case (``write_out_of_read_range``) is also safe in + principle but not yet fused (it would need per-index non-overlap reasoning). + In every case the read still fuses and the result stays correct. """ rng = np.random.default_rng(42) x = pt.vector("x", shape=(9,)) @@ -466,12 +467,38 @@ def test_write_target_aliases_read_source(self, read_idx, write_idx): write_idx = np.array(write_idx, dtype=np.int64) out = b[write_idx].set(b[read_idx] * 2.0) fn, fn_u = fused_and_unfused([x], out) - # The read fuses into an IndexedElemwise; the aliasing write stays external. + # The read always fuses into an IndexedElemwise; the write fuses only when + # the alias is provably safe, otherwise it stays an external scatter. + assert_fused(fn) + nodes = fn.maker.fgraph.toposort() + has_external_write = any(isinstance(n.op, AdvancedIncSubtensor1) for n in nodes) + assert has_external_write == (not write_fuses) + if write_fuses: + [ie] = [n for n in nodes if isinstance(n.op, IndexedElemwise)] + assert ie.op.destroyhandler_tolerate_aliased + xv = rng.normal(size=9) + np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) + + def test_write_aliases_read_through_view_not_fused(self): + """Read and write share a root but through different variables. + + ``b[idx].set(exp(b[::-1][idx]))`` reads the reversed view ``b[::-1]`` and + writes ``b``: same root and index, but the view remaps positions, so fusing + the in-place write would clobber positions still to be read. The write must + stay external (read source is not the same variable as the write target). + ``b`` is an intermediate so no protective input-copy masks the alias. + """ + rng = np.random.default_rng(42) + x = pt.vector("x", shape=(6,)) + b = x + 1.0 + idx = np.array([0, 2, 5], dtype=np.int64) + out = b[idx].set(pt.exp(b[::-1][idx])) + fn, fn_u = fused_and_unfused([x], out) assert_fused(fn) assert any( isinstance(n.op, AdvancedIncSubtensor1) for n in fn.maker.fgraph.toposort() ) - xv = rng.normal(size=9) + xv = rng.normal(size=6) np.testing.assert_allclose(fn(xv), fn_u(xv), rtol=1e-10) def test_non_inplace_aliasing_write_preserves_input(self): From f4c7a0e622388cfe1904d515e430c90c202df376 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 25 Jun 2026 16:43:12 +0200 Subject: [PATCH 3/3] Fuse basic slice writes into IndexedElemwise Generalize the IndexedElemwise fusion to absorb basic IncSubtensor/SetSubtensor slice writes (gh #2192), not just advanced indexing. An Elemwise whose result is written into buffer[slices] now writes straight into a view of the destination, eliminating the intermediate temp + copy. Basic and advanced writes can also be fused into a single loop. Rewrite (FuseIndexedElemwise): detect single-client basic IncSubtensor writes of an Elemwise output, with a coverage check; encode them via the inner IncSubtensor node + destroy_map (no new spec). Slice-bound vars become inner inputs. Drop the gate against mixing basic and advanced writes in one op. Numba dispatch: one unified funcify for advanced-only, basic-only and mixed ops. prepare_inputs is a FrozenFunctionGraph of Subtensor views compiled through the normal dispatch (reusing the Subtensor slice codegen); the views are passed as in-place core inputs to a single _vectorized call alongside any advanced index specs. store_core_outputs gains n_inplace_buffer_inputs to accept (and drop) the buffer inputs an inplace_pattern target requires. Pure-advanced ops are unchanged. The op stays a portable OpFromGraph (inner Elemwise + IncSubtensor), so non-Numba backends evaluate it correctly; only the fast slice-write path is Numba-specific. --- pytensor/link/numba/dispatch/elemwise.py | 236 ++++++++++++++---- .../link/numba/dispatch/vectorize_codegen.py | 16 +- pytensor/tensor/rewriting/indexed_elemwise.py | 74 +++++- tests/link/numba/test_indexed_elemwise.py | 157 ++++++++++++ 4 files changed, 424 insertions(+), 59 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 1383ee1df1..ab5a4cf055 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -11,6 +11,7 @@ from numpy.lib.stride_tricks import as_strided from pytensor import config +from pytensor.graph.fg import FrozenFunctionGraph from pytensor.graph.op import Op from pytensor.link.numba.cache import ( compile_numba_function_src, @@ -51,6 +52,12 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise +from pytensor.tensor.subtensor import ( + AdvancedIncSubtensor, + AdvancedIncSubtensor1, + IncSubtensor, + Subtensor, +) @singledispatch @@ -715,30 +722,76 @@ def impl(*inputs): return elemwise, elemwise_key -@register_funcify_and_cache_key(IndexedElemwise) -def numba_funcify_IndexedElemwise(op, node, **kwargs): - """Generate fused Elemwise Numba code with indexed reads and updates. +def _basic_write_slice_fgraph( + op, basic_writes, basic_order, adv_write_targets, adv_write_order, nin, n_indices +): + """FunctionGraph that slices each basic write target into a view and emits the + inputs in the order ``_vectorized`` expects. + + Outputs are ``[elemwise inputs, basic views, index arrays, advanced targets]``; + only the views are computed (via ``Subtensor``), the rest pass through. + Compiling this graph through the numba dispatch yields the input-marshalling + function, reusing the ``Subtensor`` slice codegen. + """ + inputs = op.fgraph.inputs + outputs = [inputs[i] for i in range(nin)] + for oi in basic_order: + wn = basic_writes[oi] + outputs.append(Subtensor(idx_list=wn.op.idx_list)(wn.inputs[0], *wn.inputs[2:])) + outputs += [inputs[nin + k] for k in range(n_indices)] + outputs += [adv_write_targets[oi] for oi in adv_write_order] + # clone (the default) copies the subgraph and truncates at ``inputs`` so the + # op's own fgraph is untouched. + return FrozenFunctionGraph(inputs, outputs) - Reads indexed_inputs/indexed_outputs specs stored on the Op by the - rewriting pass, and generates a single vectorized loop with indirect - indexing. - fgraph inputs are ordered as:: +@register_funcify_and_cache_key(IndexedElemwise) +def numba_funcify_IndexedElemwise(op, node, **kwargs): + """Funcify IndexedElemwise into one vectorized Elemwise loop. - [elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...] + Inputs are ordered ``[elemwise inputs, basic views, index arrays, targets]``. + Advanced reads and scatter-writes use the ``indexed_inputs``/``indexed_outputs`` + specs for indirect indexing. Basic slice writes are inner ``IncSubtensor`` + nodes: the target is sliced into a view (via the Subtensor funcify) and the + loop writes into it in place. """ [elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)] + basic_write_nodes = [ + n for n in op.fgraph.apply_nodes if isinstance(n.op, IncSubtensor) + ] scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs) scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key( elemwise_node.op.scalar_op, node=scalar_node, **kwargs ) + nin_elemwise = len(elemwise_node.inputs) + nout = len(elemwise_node.outputs) + fgraph_inputs = op.fgraph.inputs + fgraph_outputs = op.fgraph.outputs + + # ``build_outputs`` (below) indexes the outer-input tuple ``args``. + def _arg(var): + return f"args[{fgraph_inputs.index(var)}]" + indexed_inputs = op.indexed_inputs indexed_outputs = op.indexed_outputs n_indices = len(indexed_inputs) - nin_elemwise = len(elemwise_node.inputs) - nout = len(elemwise_node.outputs) + + # Basic writes: output index -> inner IncSubtensor node. + basic_writes = {fgraph_outputs.index(wn.outputs[0]): wn for wn in basic_write_nodes} + basic_order = sorted(basic_writes) + n_basic = len(basic_order) + + # Advanced scatter writes (mixed op): output index -> target buffer, in the + # ascending-output order the vectorized write targets are expected. + adv_write_targets = { + fgraph_outputs.index(n.outputs[0]): n.inputs[0] + for n in op.fgraph.apply_nodes + if isinstance(n.op, AdvancedIncSubtensor1 | AdvancedIncSubtensor) + } + adv_write_order = sorted(adv_write_targets) + all_write_outs = set(basic_order) | set(adv_write_order) inc_outputs = frozenset( out_idx @@ -746,22 +799,42 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): if entry is not None for out_idx in entry[0] if entry[2] == "inc" - ) + ) | {oi for oi in basic_order if not basic_writes[oi].op.set_instead_of_inc} + # Each basic write writes in place into a view of its buffer. An in-place + # output's buffer must be a core input, so ``_vectorized`` feeds the view in + # as an extra input the scalar op never reads -- ``store_core_outputs`` accepts + # and drops these trailing buffer inputs. core_op_fn = store_core_outputs( - scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs + scalar_op_fn, + nin=nin_elemwise, + nout=nout, + inc_outputs=inc_outputs, + n_inplace_buffer_inputs=n_basic, ) - input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs) output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs) + input_bc_patterns = tuple( + [inp.type.broadcastable for inp in elemwise_node.inputs] + + [output_bc_patterns[oi] for oi in basic_order] + ) output_dtypes = tuple(out.type.dtype for out in node.outputs) - inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items()) - core_output_shapes = tuple(() for _ in range(nout)) - idx_broadcastable = tuple( node.inputs[nin_elemwise + k].type.broadcastable for k in range(n_indices) ) + # Keep the Elemwise's own (non-write) in-place pattern; add the basic views. + inplace = { + oi: ii + for oi, ii in elemwise_node.op.inplace_pattern.items() + if oi not in all_write_outs + } + for j, oi in enumerate(basic_order): + inplace[oi] = nin_elemwise + j + inplace_pattern = tuple(inplace.items()) + + core_output_shapes = tuple(() for _ in range(nout)) + input_bc_patterns_enc = encode_literals(input_bc_patterns) output_bc_patterns_enc = encode_literals(output_bc_patterns) output_dtypes_enc = encode_literals(output_dtypes) @@ -769,49 +842,114 @@ def numba_funcify_IndexedElemwise(op, node, **kwargs): indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable)) indexed_outputs_enc = encode_literals(indexed_outputs) - def indexed_elemwise_fn(*outer_inputs): + # Basic writes are marshalled around the loop by two composed helpers: + # ``prepare_inputs`` (a FunctionGraph of Subtensor views, dispatched like any + # other graph) slices each basic target into a view and emits the loop inputs + # as ``[elemwise inputs, basic views, index arrays, advanced targets]``; + # ``build_outputs`` returns the (updated) buffer for basic-written outputs and + # the loop result for the rest. With no basic writes the loop runs directly. + sub_keys = [] + build_src = None + prepare_inputs = build_outputs = None + if basic_order: + prepare_inputs, prepare_key = numba_funcify_and_cache_key( + _basic_write_slice_fgraph( + op, + basic_writes, + basic_order, + adv_write_targets, + adv_write_order, + nin_elemwise, + n_indices, + ), + **kwargs, + ) + sub_keys.append(prepare_key) + + def _out(i): + if i in basic_writes: + return _arg(basic_writes[i].inputs[0]) + return f"result[{i}]" if nout > 1 else "result" + + out_expr = ( + _out(0) + if nout == 1 + else "(" + ", ".join(_out(i) for i in range(nout)) + ",)" + ) + build_src = f"def build_outputs(result, args):\n return {out_expr}\n" + build_outputs = numba_basic.numba_njit( + compile_numba_function_src(build_src, "build_outputs") + ) + + def indexed_elemwise_fn(*args): raise NotImplementedError( "IndexedElemwise cannot be evaluated in Python (non-JIT) mode." ) @overload(indexed_elemwise_fn, jit_options=_jit_options) - def ov_indexed_elemwise_fn(*outer_inputs): - def impl(*outer_inputs): - return _vectorized( - core_op_fn, - input_bc_patterns_enc, - output_bc_patterns_enc, - output_dtypes_enc, - inplace_pattern_enc, - True, # allow_core_scalar - (), # constant_inputs - outer_inputs, - core_output_shapes, - NO_SIZE, - indexed_inputs_enc, - indexed_outputs_enc, - ) + def ov_indexed_elemwise_fn(*args): + if basic_order: + + def impl(*args): + return build_outputs( + _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + prepare_inputs(*args), + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + ), + args, + ) + else: + + def impl(*args): + return _vectorized( + core_op_fn, + input_bc_patterns_enc, + output_bc_patterns_enc, + output_dtypes_enc, + inplace_pattern_enc, + True, # allow_core_scalar + (), # constant_inputs + args, + core_output_shapes, + NO_SIZE, + indexed_inputs_enc, + indexed_outputs_enc, + ) return impl - cache_version = 2 - if scalar_cache_key is None: + if scalar_cache_key is None or any(k is None for k in sub_keys): key = None else: - key = str( - ( - type(op), - "IndexedElemwise", - cache_version, - inplace_pattern, - input_bc_patterns, - indexed_inputs, - idx_broadcastable, - indexed_outputs, - scalar_cache_key, - ) - ) - key = sha256(key.encode()).hexdigest() + key = sha256( + str( + ( + type(op), + "IndexedElemwise", + 3, # cache version + input_bc_patterns, + output_dtypes, + inplace_pattern, + indexed_inputs, + idx_broadcastable, + indexed_outputs, + tuple(inc_outputs), + build_src, + scalar_cache_key, + tuple(sub_keys), + ) + ).encode() + ).hexdigest() return indexed_elemwise_fn, key diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index d1efdf037d..9733c12b54 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -24,7 +24,11 @@ def encode_literals(literals: Sequence) -> str: def store_core_outputs( - core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset() + core_op_fn: Callable, + nin: int, + nout: int, + inc_outputs: frozenset = frozenset(), + n_inplace_buffer_inputs: int = 0, ) -> Callable: """Create a Numba function that wraps a core function and stores its vectorized outputs. @@ -39,16 +43,22 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): ... ``inc_outputs`` lists output indices that use ``+=`` instead of ``=``. + + ``n_inplace_buffer_inputs`` trailing inputs are accepted but not passed to + ``core_op_fn``. An ``inplace_pattern`` target must be a core input, so + ``_vectorized`` feeds such buffers in even when the scalar op doesn't read + them (e.g. a basic slice write writes into a buffer view it never reads). """ if getattr(core_op_fn, "handles_out", False): return core_op_fn inputs = [f"i{i}" for i in range(nin)] + buffer_inputs = [f"b{i}" for i in range(n_inplace_buffer_inputs)] outputs = [f"o{i}" for i in range(nout)] inner_outputs = [f"t{output}" for output in outputs] inp_signature = ", ".join(inputs) - out_signature = ", ".join(outputs) + full_signature = ", ".join([*inputs, *buffer_inputs, *outputs]) inner_out_signature = ", ".join(inner_outputs) store_outputs = "\n".join( f"{output} += {inner_output}" @@ -59,7 +69,7 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on): ) ) func_src = f""" -def store_core_outputs({inp_signature}, {out_signature}): +def store_core_outputs({full_signature}): {inner_out_signature} = core_op_fn({inp_signature}) {indent(store_outputs, " " * 4)} """ diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index 46e9880f3e..f870efc3a7 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -9,6 +9,7 @@ from pytensor.compile import optdb from pytensor.compile.builders import OpFromGraph from pytensor.graph import node_rewriter +from pytensor.graph.basic import Constant from pytensor.graph.rewriting.basic import GraphRewriter, dfs_rewriter from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.unify import OpPattern @@ -24,6 +25,8 @@ AdvancedIncSubtensor1, AdvancedSubtensor, AdvancedSubtensor1, + IncSubtensor, + Subtensor, ) from pytensor.tensor.variable import TensorVariable @@ -418,6 +421,9 @@ def transpose_non_indexed_write_axes(node, write_targets): elemwise_batch_ndim = len(node.outputs[0].type.broadcastable) for update_node in write_targets.values(): op = update_node.op + if isinstance(op, IncSubtensor): + # Basic slice writes preserve dim order; never need transposing. + continue target, val, *idx_vars = update_node.inputs idx_axes = [i for i, e in enumerate(op.idx_list) if e != slice(None)] @@ -506,7 +512,7 @@ def apply(self, fgraph): # All indexed write axes have to overlap and not broadcast the core Elemwise loop # Our current vectorize codegen can't produce write only loops that don't force # the recomputation of the core function in every step. - write_targets = {} # out_idx -> update_node + write_targets = {} # out_idx -> update_node (advanced or basic IncSubtensor) # out_idx -> read input positions whose buffer the write aliases and # which the destroy handler must be told to tolerate. aliased_read_positions = {} @@ -531,12 +537,44 @@ def apply(self, fgraph): (c, ci) for c, ci in clients if ci == 1 - and isinstance(c.op, AdvancedIncSubtensor1 | AdvancedIncSubtensor) + and isinstance( + c.op, + AdvancedIncSubtensor1 | AdvancedIncSubtensor | IncSubtensor, + ) ] if len(inc_clients) != 1: # TODO: support multiple writes from the same Elemwise output via Composite duplication continue [(client_node, _)] = inc_clients + + # Basic-slice writes (``buffer[slices].set/inc(out)``) target a + # contiguous *view* of the buffer, so they need no indirect + # indexing: the Numba funcify slices the buffer once and runs the + # Elemwise loop into that view. The write is encoded purely by the + # inner IncSubtensor node + destroy_map (no indexed_outputs entry), + # so it stays out of ``idx_groups``. + if isinstance(client_node.op, IncSubtensor): + if right_pad: + # expand_dims between elemwise and a basic write: not yet handled + continue + target, _, *slice_idx_vars = client_node.inputs + # The Elemwise output must exactly fill the written slice region + # (no implicit broadcast, which would force recomputation). + write_region_bcast = Subtensor(idx_list=client_node.op.idx_list)( + target, *slice_idx_vars + ).type.broadcastable + if out.type.ndim != len(write_region_bcast): + continue + if any( + ob and not wb + for ob, wb in zip( + out.type.broadcastable, write_region_bcast, strict=True + ) + ): + continue + write_targets[out_idx] = client_node + continue + idx_axis_pairs = self._extract_idx_axis_pairs(client_node, write=True) if idx_axis_pairs is None: continue @@ -606,9 +644,27 @@ def apply(self, fgraph): idx_groups[idx_axis_pair][1].append(out_idx) write_targets[out_idx] = client_node - if not idx_groups: + if not idx_groups and not write_targets: continue + # Basic slice writes reuse the generic write-target machinery below + # (inner IncSubtensor + destroy_map, no indexed_outputs entry). Their + # slice-index variables must become inner-fgraph inputs, like the + # advanced index arrays. They can coexist with advanced indexing in a + # single op; the Numba dispatch marshals both into one vectorized loop. + basic_slice_idx_inputs = [] + _seen_slice_idx = set() + for write_node in write_targets.values(): + if not isinstance(write_node.op, IncSubtensor): + continue + for idx_var in write_node.inputs[2:]: + if ( + not isinstance(idx_var, Constant) + and idx_var not in _seen_slice_idx + ): + _seen_slice_idx.add(idx_var) + basic_slice_idx_inputs.append(idx_var) + if must_transpose_write_axes: replacements = self.transpose_non_indexed_write_axes( node, write_targets @@ -700,10 +756,14 @@ def _has_non_write_clients(out_idx): # Fgraph inputs: substitute indexed sources back to their # pre-subtensor arrays, append index arrays and update targets. - fgraph_inputs = [ - inp.owner.inputs[0] if i in indexed_reads else inp - for i, inp in enumerate(node.inputs) - ] + idx_vars + fgraph_inputs = ( + [ + inp.owner.inputs[0] if i in indexed_reads else inp + for i, inp in enumerate(node.inputs) + ] + + idx_vars + + basic_slice_idx_inputs + ) # Non-inplace write targets need a copy so the original isn't destroyed # Elemwise will always destroy the write buffers inplace afterwards. diff --git a/tests/link/numba/test_indexed_elemwise.py b/tests/link/numba/test_indexed_elemwise.py index 2c26ec0b54..ea56ac44ae 100644 --- a/tests/link/numba/test_indexed_elemwise.py +++ b/tests/link/numba/test_indexed_elemwise.py @@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import ( AdvancedIncSubtensor1, AdvancedSubtensor, + IncSubtensor, ) @@ -552,6 +553,162 @@ def test_non_inplace_aliasing_write_preserves_input(self): np.testing.assert_array_equal(tv_in, tv) +class TestBasicSliceWriteFusion: + """Test basic slice writes (``IncSubtensor``) fused into the Elemwise loop. + + These target a contiguous view of the buffer, so the fused funcify slices + the buffer once and writes the Elemwise result straight into the view + (no temp, no per-iteration index arithmetic). See gh #2192. + """ + + def test_set_subtensor_slice(self): + """o[1:].set(exp(x)) fuses; result lands directly in the buffer.""" + x = pt.vector("x") + o0 = pt.vector("o0") + o = o0 + 1.0 + out = o[1:].set(pt.exp(x)) + fn, fn_u = fused_and_unfused([x, o0], out) + assert_fused(fn) + rng = np.random.default_rng(0) + xv = rng.normal(size=(4,)) + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, ov), fn_u(xv, ov), rtol=1e-10) + + def test_inc_subtensor_slice(self): + """o[1:].inc(exp(x)) accumulates onto the existing buffer values.""" + x = pt.vector("x") + o0 = pt.vector("o0") + o = o0 + 1.0 + out = o[1:].inc(pt.exp(x)) + fn, fn_u = fused_and_unfused([x, o0], out) + assert_fused(fn) + rng = np.random.default_rng(1) + xv = rng.normal(size=(4,)) + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, ov), fn_u(xv, ov), rtol=1e-10) + + def test_dynamic_slice_bound(self): + """A symbolic slice bound becomes an inner input of the fused op.""" + x = pt.vector("x") + o0 = pt.vector("o0") + st = pt.lscalar("st") + o = o0 + 1.0 + out = o[st:].set(pt.exp(x)) + fn, fn_u = fused_and_unfused([x, o0, st], out) + assert_fused(fn) + rng = np.random.default_rng(2) + xv = rng.normal(size=(3,)) + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, ov, 2), fn_u(xv, ov, 2), rtol=1e-10) + + def test_step_slice(self): + """A strided slice writes into a non-contiguous view.""" + x = pt.vector("x") + o0 = pt.vector("o0") + o = o0 + 1.0 + out = o[::2].set(pt.exp(x)) + fn, fn_u = fused_and_unfused([x, o0], out) + assert_fused(fn) + rng = np.random.default_rng(3) + xv = rng.normal(size=(3,)) + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, ov), fn_u(xv, ov), rtol=1e-10) + + def test_integer_index_drops_dim(self): + """m[2].set(exp(row)) writes a row (the indexed axis drops).""" + row = pt.vector("row") + m0 = pt.matrix("m0") + m = m0 + 1.0 + out = m[2].set(pt.exp(row)) + fn, fn_u = fused_and_unfused([row, m0], out) + assert_fused(fn) + rng = np.random.default_rng(4) + rv = rng.normal(size=(3,)) + mv = rng.normal(size=(4, 3)) + np.testing.assert_allclose(fn(rv, mv), fn_u(rv, mv), rtol=1e-10) + + def test_multi_axis_slice(self): + """A 2-D slice m[1:, 1:] fuses.""" + sub = pt.matrix("sub") + m0 = pt.matrix("m0") + m = m0 + 1.0 + out = m[1:, 1:].set(pt.exp(sub)) + fn, fn_u = fused_and_unfused([sub, m0], out) + assert_fused(fn) + rng = np.random.default_rng(5) + sv = rng.normal(size=(3, 2)) + mv = rng.normal(size=(4, 3)) + np.testing.assert_allclose(fn(sv, mv), fn_u(sv, mv), rtol=1e-10) + + def test_composite_scalar_op(self): + """A fused Composite inner Elemwise writes into the slice.""" + x = pt.vector("x") + y = pt.vector("y") + o0 = pt.vector("o0") + o = o0 + 1.0 + out = o[1:].set(pt.exp(x) + pt.log(y)) + fn, fn_u = fused_and_unfused([x, y, o0], out) + assert_fused(fn) + rng = np.random.default_rng(6) + xv = rng.normal(size=(4,)) + yv = np.abs(rng.normal(size=(4,))) + 0.1 + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(xv, yv, ov), fn_u(xv, yv, ov), rtol=1e-10) + + def test_non_inplace_target_preserves_input(self): + """Writing into a slice of a non-destroyable input copies first.""" + x = pt.vector("x") + oin = pt.vector("oin") + out = oin[1:].set(pt.exp(x)) + fn, fn_u = fused_and_unfused([x, oin], out) + assert_fused(fn) + rng = np.random.default_rng(7) + xv = rng.normal(size=(4,)) + ov = rng.normal(size=(5,)) + ov_keep = ov.copy() + np.testing.assert_allclose(fn(xv, ov), fn_u(xv, ov), rtol=1e-10) + # The input buffer must not be mutated by the fused op. + np.testing.assert_array_equal(ov, ov_keep) + + def test_mixed_advanced_read_and_basic_write(self): + """An advanced read and a basic slice write fuse into one op. + + ``o[1:].set(exp(x[idx]))`` has an advanced read (``x[idx]``, fused via + the index specs) and a basic slice write (``o[1:]``, written into a view) + in the same Elemwise loop -- a single IndexedElemwise, no outer write. + """ + x = pt.vector("x") + o0 = pt.vector("o0") + idx = pt.lvector("idx") + o = o0 + 1.0 + out = o[1:].set(pt.exp(x[idx])) + fn, fn_u = fused_and_unfused([x, o0, idx], out) + assert_fused(fn) + nodes = fn.maker.fgraph.toposort() + # Both the read and the write are absorbed: no leftover scatter/slice op. + assert sum(isinstance(n.op, IndexedElemwise) for n in nodes) == 1 + assert not any( + isinstance(n.op, AdvancedIncSubtensor1 | IncSubtensor) for n in nodes + ) + rng = np.random.default_rng(9) + xv = rng.normal(size=(6,)) + ov = rng.normal(size=(5,)) + iv = np.array([0, 2, 4, 1], dtype=np.int64) + np.testing.assert_allclose(fn(xv, ov, iv), fn_u(xv, ov, iv), rtol=1e-10) + + def test_read_modify_write_same_slice_not_fused(self): + """o[1:].set(o[1:] * 2) aliases the write target; leave it unfused.""" + o0 = pt.vector("o0") + o = o0 + 1.0 + out = o[1:].set(o[1:] * 2.0) + fn, fn_u = fused_and_unfused([o0], out) + # The write target is read through the same view, so fusing would make + # the Elemwise destroy a buffer it also reads -> rejected, stays correct. + rng = np.random.default_rng(8) + ov = rng.normal(size=(5,)) + np.testing.assert_allclose(fn(ov), fn_u(ov), rtol=1e-10) + + class TestRepeatedAccumulationIndices: """Test inc_subtensor with repeated indices (same position accumulated multiple times)."""