Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 187 additions & 49 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numpy.lib.stride_tricks import as_strided

from pytensor import config
from pytensor.graph.fg import FrozenFunctionGraph
from pytensor.graph.op import Op
from pytensor.link.numba.cache import (
compile_numba_function_src,
Expand Down Expand Up @@ -51,6 +52,12 @@
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros
from pytensor.tensor.rewriting.indexed_elemwise import IndexedElemwise
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
IncSubtensor,
Subtensor,
)


@singledispatch
Expand Down Expand Up @@ -715,103 +722,234 @@ def impl(*inputs):
return elemwise, elemwise_key


@register_funcify_and_cache_key(IndexedElemwise)
def numba_funcify_IndexedElemwise(op, node, **kwargs):
"""Generate fused Elemwise Numba code with indexed reads and updates.
def _basic_write_slice_fgraph(
op, basic_writes, basic_order, adv_write_targets, adv_write_order, nin, n_indices
):
"""FunctionGraph that slices each basic write target into a view and emits the
inputs in the order ``_vectorized`` expects.

Outputs are ``[elemwise inputs, basic views, index arrays, advanced targets]``;
only the views are computed (via ``Subtensor``), the rest pass through.
Compiling this graph through the numba dispatch yields the input-marshalling
function, reusing the ``Subtensor`` slice codegen.
"""
inputs = op.fgraph.inputs
outputs = [inputs[i] for i in range(nin)]
for oi in basic_order:
wn = basic_writes[oi]
outputs.append(Subtensor(idx_list=wn.op.idx_list)(wn.inputs[0], *wn.inputs[2:]))
outputs += [inputs[nin + k] for k in range(n_indices)]
outputs += [adv_write_targets[oi] for oi in adv_write_order]
# clone (the default) copies the subgraph and truncates at ``inputs`` so the
# op's own fgraph is untouched.
return FrozenFunctionGraph(inputs, outputs)

Reads indexed_inputs/indexed_outputs specs stored on the Op by the
rewriting pass, and generates a single vectorized loop with indirect
indexing.

fgraph inputs are ordered as::
@register_funcify_and_cache_key(IndexedElemwise)
def numba_funcify_IndexedElemwise(op, node, **kwargs):
"""Funcify IndexedElemwise into one vectorized Elemwise loop.

[elemwise_inputs..., idx_0, idx_1, ..., update_target_0, ...]
Inputs are ordered ``[elemwise inputs, basic views, index arrays, targets]``.
Advanced reads and scatter-writes use the ``indexed_inputs``/``indexed_outputs``
specs for indirect indexing. Basic slice writes are inner ``IncSubtensor``
nodes: the target is sliced into a view (via the Subtensor funcify) and the
loop writes into it in place.
"""
[elemwise_node] = [n for n in op.fgraph.apply_nodes if isinstance(n.op, Elemwise)]
basic_write_nodes = [
n for n in op.fgraph.apply_nodes if isinstance(n.op, IncSubtensor)
]

scalar_node = elemwise_node.op.make_scalar_node(*elemwise_node.inputs)
scalar_op_fn, scalar_cache_key = numba_funcify_and_cache_key(
elemwise_node.op.scalar_op, node=scalar_node, **kwargs
)

nin_elemwise = len(elemwise_node.inputs)
nout = len(elemwise_node.outputs)
fgraph_inputs = op.fgraph.inputs
fgraph_outputs = op.fgraph.outputs

# ``build_outputs`` (below) indexes the outer-input tuple ``args``.
def _arg(var):
return f"args[{fgraph_inputs.index(var)}]"

indexed_inputs = op.indexed_inputs
indexed_outputs = op.indexed_outputs
n_indices = len(indexed_inputs)
nin_elemwise = len(elemwise_node.inputs)
nout = len(elemwise_node.outputs)

# Basic writes: output index -> inner IncSubtensor node.
basic_writes = {fgraph_outputs.index(wn.outputs[0]): wn for wn in basic_write_nodes}
basic_order = sorted(basic_writes)
n_basic = len(basic_order)

# Advanced scatter writes (mixed op): output index -> target buffer, in the
# ascending-output order the vectorized write targets are expected.
adv_write_targets = {
fgraph_outputs.index(n.outputs[0]): n.inputs[0]
for n in op.fgraph.apply_nodes
if isinstance(n.op, AdvancedIncSubtensor1 | AdvancedIncSubtensor)
}
adv_write_order = sorted(adv_write_targets)
all_write_outs = set(basic_order) | set(adv_write_order)

inc_outputs = frozenset(
out_idx
for entry in indexed_outputs
if entry is not None
for out_idx in entry[0]
if entry[2] == "inc"
)
) | {oi for oi in basic_order if not basic_writes[oi].op.set_instead_of_inc}

# Each basic write writes in place into a view of its buffer. An in-place
# output's buffer must be a core input, so ``_vectorized`` feeds the view in
# as an extra input the scalar op never reads -- ``store_core_outputs`` accepts
# and drops these trailing buffer inputs.
core_op_fn = store_core_outputs(
scalar_op_fn, nin=nin_elemwise, nout=nout, inc_outputs=inc_outputs
scalar_op_fn,
nin=nin_elemwise,
nout=nout,
inc_outputs=inc_outputs,
n_inplace_buffer_inputs=n_basic,
)

input_bc_patterns = tuple(inp.type.broadcastable for inp in elemwise_node.inputs)
output_bc_patterns = tuple(out.type.broadcastable for out in elemwise_node.outputs)
input_bc_patterns = tuple(
[inp.type.broadcastable for inp in elemwise_node.inputs]
+ [output_bc_patterns[oi] for oi in basic_order]
)
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(elemwise_node.op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))

idx_broadcastable = tuple(
node.inputs[nin_elemwise + k].type.broadcastable for k in range(n_indices)
)

# Keep the Elemwise's own (non-write) in-place pattern; add the basic views.
inplace = {
oi: ii
for oi, ii in elemwise_node.op.inplace_pattern.items()
if oi not in all_write_outs
}
for j, oi in enumerate(basic_order):
inplace[oi] = nin_elemwise + j
inplace_pattern = tuple(inplace.items())

core_output_shapes = tuple(() for _ in range(nout))

input_bc_patterns_enc = encode_literals(input_bc_patterns)
output_bc_patterns_enc = encode_literals(output_bc_patterns)
output_dtypes_enc = encode_literals(output_dtypes)
inplace_pattern_enc = encode_literals(inplace_pattern)
indexed_inputs_enc = encode_literals((indexed_inputs, idx_broadcastable))
indexed_outputs_enc = encode_literals(indexed_outputs)

def indexed_elemwise_fn(*outer_inputs):
# Basic writes are marshalled around the loop by two composed helpers:
# ``prepare_inputs`` (a FunctionGraph of Subtensor views, dispatched like any
# other graph) slices each basic target into a view and emits the loop inputs
# as ``[elemwise inputs, basic views, index arrays, advanced targets]``;
# ``build_outputs`` returns the (updated) buffer for basic-written outputs and
# the loop result for the rest. With no basic writes the loop runs directly.
sub_keys = []
build_src = None
prepare_inputs = build_outputs = None
if basic_order:
prepare_inputs, prepare_key = numba_funcify_and_cache_key(
_basic_write_slice_fgraph(
op,
basic_writes,
basic_order,
adv_write_targets,
adv_write_order,
nin_elemwise,
n_indices,
),
**kwargs,
)
sub_keys.append(prepare_key)

def _out(i):
if i in basic_writes:
return _arg(basic_writes[i].inputs[0])
return f"result[{i}]" if nout > 1 else "result"

out_expr = (
_out(0)
if nout == 1
else "(" + ", ".join(_out(i) for i in range(nout)) + ",)"
)
build_src = f"def build_outputs(result, args):\n return {out_expr}\n"
build_outputs = numba_basic.numba_njit(
compile_numba_function_src(build_src, "build_outputs")
)

def indexed_elemwise_fn(*args):
raise NotImplementedError(
"IndexedElemwise cannot be evaluated in Python (non-JIT) mode."
)

@overload(indexed_elemwise_fn, jit_options=_jit_options)
def ov_indexed_elemwise_fn(*outer_inputs):
def impl(*outer_inputs):
return _vectorized(
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
outer_inputs,
core_output_shapes,
NO_SIZE,
indexed_inputs_enc,
indexed_outputs_enc,
)
def ov_indexed_elemwise_fn(*args):
if basic_order:

def impl(*args):
return build_outputs(
_vectorized(
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
prepare_inputs(*args),
core_output_shapes,
NO_SIZE,
indexed_inputs_enc,
indexed_outputs_enc,
),
args,
)
else:

def impl(*args):
return _vectorized(
core_op_fn,
input_bc_patterns_enc,
output_bc_patterns_enc,
output_dtypes_enc,
inplace_pattern_enc,
True, # allow_core_scalar
(), # constant_inputs
args,
core_output_shapes,
NO_SIZE,
indexed_inputs_enc,
indexed_outputs_enc,
)

return impl

cache_version = 2
if scalar_cache_key is None:
if scalar_cache_key is None or any(k is None for k in sub_keys):
key = None
else:
key = str(
(
type(op),
"IndexedElemwise",
cache_version,
inplace_pattern,
input_bc_patterns,
indexed_inputs,
idx_broadcastable,
indexed_outputs,
scalar_cache_key,
)
)
key = sha256(key.encode()).hexdigest()
key = sha256(
str(
(
type(op),
"IndexedElemwise",
3, # cache version
input_bc_patterns,
output_dtypes,
inplace_pattern,
indexed_inputs,
idx_broadcastable,
indexed_outputs,
tuple(inc_outputs),
build_src,
scalar_cache_key,
tuple(sub_keys),
)
).encode()
).hexdigest()

return indexed_elemwise_fn, key

Expand Down
16 changes: 13 additions & 3 deletions pytensor/link/numba/dispatch/vectorize_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def encode_literals(literals: Sequence) -> str:


def store_core_outputs(
core_op_fn: Callable, nin: int, nout: int, inc_outputs: frozenset = frozenset()
core_op_fn: Callable,
nin: int,
nout: int,
inc_outputs: frozenset = frozenset(),
n_inplace_buffer_inputs: int = 0,
) -> Callable:
"""Create a Numba function that wraps a core function and stores its vectorized outputs.

Expand All @@ -39,16 +43,22 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
...

``inc_outputs`` lists output indices that use ``+=`` instead of ``=``.

``n_inplace_buffer_inputs`` trailing inputs are accepted but not passed to
``core_op_fn``. An ``inplace_pattern`` target must be a core input, so
``_vectorized`` feeds such buffers in even when the scalar op doesn't read
them (e.g. a basic slice write writes into a buffer view it never reads).
"""
if getattr(core_op_fn, "handles_out", False):
return core_op_fn

inputs = [f"i{i}" for i in range(nin)]
buffer_inputs = [f"b{i}" for i in range(n_inplace_buffer_inputs)]
outputs = [f"o{i}" for i in range(nout)]
inner_outputs = [f"t{output}" for output in outputs]

inp_signature = ", ".join(inputs)
out_signature = ", ".join(outputs)
full_signature = ", ".join([*inputs, *buffer_inputs, *outputs])
inner_out_signature = ", ".join(inner_outputs)
store_outputs = "\n".join(
f"{output} += {inner_output}"
Expand All @@ -59,7 +69,7 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
)
)
func_src = f"""
def store_core_outputs({inp_signature}, {out_signature}):
def store_core_outputs({full_signature}):
{inner_out_signature} = core_op_fn({inp_signature})
{indent(store_outputs, " " * 4)}
"""
Expand Down
Loading
Loading