Skip to content

More rewrites for ExtractDiag#2045

Open
jessegrabowski wants to merge 1 commit intopymc-devs:mainfrom
jessegrabowski:extract-diag-rewrites
Open

More rewrites for ExtractDiag#2045
jessegrabowski wants to merge 1 commit intopymc-devs:mainfrom
jessegrabowski:extract-diag-rewrites

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Apr 12, 2026

These came up when I was working on #2032. They're not related to that perse so I split them off. We're missing a lot of simple rewrites for ExtractDiag. I added:

  • ExtractDiag(Eye) -> Ones or Zeros, depending on k
  • ExtractDiag(Eye * x) -> Alloc(x) (new shape)
  • ExtractDiag(Zeros / Ones) -> Zeros / Ones of new shape
  • ExtractDiag(Alloc) -> Alloc (with new shape)
  • ExtractDiag(Elemwise(a, b)) -> Elemwise(ExtractDiag(a), ExtractDiag(b)) (plus some broadcasting logic)
  • ExtractDiag(Transpose(x), offset=k) -> ExtractDiag(X, offset=-k) (I understand transpose is just a free view but it removes a blocker to further rewrites that no longer need to see through the transpose)

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 14, 2026

I've been playing with AdvancedSubtensor rewrites from analysis of the Wishart PR and we should talk.

I feel AdvancedSubtensor rewrites are more general for these cases and not necessarily harder. We may want to rewrite ExtractDiagonal (and AllocDiagona) as the AdvancedSubtensor/AdvancedSetSubtensor version, then immediately call on these rewrites when we know they'll help.

@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_alloc_diag(fgraph, node):
"""ExtractDiag(AllocDiag(x, offset=k), offset=k) -> x
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't AllocDiag be a scalar/vector that is broadcasted?

The diagonal of an eye matrix is a vector of ones.
"""
op = node.op
if op.axis1 != 0 or op.axis2 != 1:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about poor batched Eye?

@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_eye_mul(fgraph, node):
"""ExtractDiag(eye * x) -> extract the non-zero diagonal values.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can Eye * x be written as AllocDiag?

@jessegrabowski
Copy link
Copy Markdown
Member Author

I'll pause on this pending that discussion then. I will say though that some of these have big savings, specifically the elemwise lift.

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 23, 2026

ExtractDiag(Transpose(x), offset=k) -> ExtractDiag(X, offset=-k) (I understand transpose is just a free view but it removes a blocker to further rewrites that no longer need to see through the transpose)

No reason not to do this. This is actually something that generalizes to other Ops somewhat. Sum of Transpose -> Sum with transposed axes, and the like. Always good to remove cruft that doesn't affect computation

Otherwise all rewrites you mention would work out of the box if we materialize ExtractDiag (and AllocDia) as the equivalent advanced subtensor / advanced set_subtensor and let the rewrites from #2061 act.

Except for

ExtractDiag(Elemwise(a, b)) -> Elemwise(ExtractDiag(a), ExtractDiag(b)) (plus some broadcasting logic)

Which I think we should tackle now. I had been worried about duplicate indices, but if they are constant (or provably unique like created from symbolic arange) we don't need to worry about. We should always reduce before we compute.


We need to think about this. Maybe ExtractDiag AllocDiag could be an OFG so it's trivial for rewrites to materialize the low lever IR and reuse the general read-write rewrite logic we have?

@jessegrabowski
Copy link
Copy Markdown
Member Author

@ricardoV94 how does this PR need to change now that #2061 is merged

@ricardoV94
Copy link
Copy Markdown
Member

ricardoV94 commented Apr 29, 2026

@jessegrabowski what I would suggest trying (may prove not the best solution) is to lower all ExtractDiag of Diag | Eye | Alloc | Eye * x -> AdvancedSubtensor(arange(n), arange(n)) of AdvancedSetSubtensor(zeros(n, n), 1 | x, arange(n), arange(n)), and see if our existing rewrites simplify all. This lowering would be done in the rewrite, when we think it's going to work as is done in:

@register_stabilize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([ExtractDiag])
def local_extract_diag_of_write(fgraph, node):
"""Delegate ``extract_diag(advanced_inc_subtensor(...))`` to the constant-indices rewrite.
Rewrites ``extract_diag(x, offset=k)`` as the equivalent
``x[..., arange(d) + max(0, -k), arange(d) + max(0, k), ...]`` and
calls ``local_advanced_read_of_write_constant_indices`` to do the
work. Since ``extract_diag`` is a zero-copy view, we only commit the
replacement when the downstream rewrite eliminates the gather.
Requires statically-known sizes on the two diagonal axes.
"""
op = node.op
inner = node.inputs[0]
# AdvancedIncSubtensor1 is intentionally not accepted: it writes whole
# rows/slices on a single axis, not specific (i, j) positions, so it
# can't express "write the diagonal" the way two paired index arrays can.
if not (inner.owner and isinstance(inner.owner.op, AdvancedIncSubtensor)):
return None
# Need static sizes on the two diagonal axes to build constant indices.
dim_a = inner.type.shape[op.axis1]
dim_b = inner.type.shape[op.axis2]
if dim_a is None or dim_b is None:
return None
k = op.offset
row_offset = max(0, -k)
col_offset = max(0, k)
d = min(dim_a - row_offset, dim_b - col_offset)
if d <= 0:
return None
# Build equivalent AdvancedSubtensor: inner[..., arange(d) + row_offset, ..., arange(d) + col_offset, ...]
base_arange = np.arange(d, dtype=np.int64)
rows = pytensor.tensor.as_tensor_variable(base_arange + row_offset)
cols = pytensor.tensor.as_tensor_variable(base_arange + col_offset)
idxs = [slice(None)] * inner.type.ndim
idxs[op.axis1] = rows
idxs[op.axis2] = cols
equiv = inner[tuple(idxs)]
if not (equiv.owner and isinstance(equiv.owner.op, AdvancedSubtensor)):
return None
# Delegate to the general read-after-write rewrite.
result = local_advanced_read_of_write_constant_indices.fn(fgraph, equiv.owner)
if not result:
return None
# Stay zero-copy where possible: when the simplification reduced to a
# gather of the inner write's base at our diagonal-arange pattern (i.e.
# the no-coverage case where the write is irrelevant for this read),
# re-emit as ExtractDiag so we keep the view semantics of the original.
base = inner.owner.inputs[0]
[result_var] = result
if (
result_var.owner
and isinstance(result_var.owner.op, AdvancedSubtensor)
and result_var.owner.inputs[0] is base
):
out = base.diagonal(offset=k, axis1=op.axis1, axis2=op.axis2)
copy_stack_trace(node.outputs[0], out)
return [out]
copy_stack_trace(node.outputs[0], result)
return result

(the k changes the arange bit)

Make sure to reuse the same arange(n) between AdvancedSubtensor and AdvancedSetSubtensor in case it's not constant for our rewrites that work with symbolic x. If n is constant use constant indices for the more general rewrites.

I think this will cover all ExtractDiag of Diag writes.

Separately keep the transpose -> k=-k rewrite and the ExtractDiag(Elemwise). Just be careful about broadcasted inputs, if you weren't already: ExtractDiag(mat + v) -> ExtractDiag(mat) + v (possibly with some squeezing/dimshuffle.

We don't have that rewrite for Elemwise, and it would be nice to keep having an ExtractDiag in other end anyway, because ExtractDiag is always simpler to reason about.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants