More rewrites for ExtractDiag#2045
Conversation
|
I've been playing with AdvancedSubtensor rewrites from analysis of the Wishart PR and we should talk. I feel AdvancedSubtensor rewrites are more general for these cases and not necessarily harder. We may want to rewrite ExtractDiagonal (and AllocDiagona) as the AdvancedSubtensor/AdvancedSetSubtensor version, then immediately call on these rewrites when we know they'll help. |
| @register_specialize | ||
| @node_rewriter([ExtractDiag]) | ||
| def extract_diag_of_alloc_diag(fgraph, node): | ||
| """ExtractDiag(AllocDiag(x, offset=k), offset=k) -> x |
There was a problem hiding this comment.
Can't AllocDiag be a scalar/vector that is broadcasted?
| The diagonal of an eye matrix is a vector of ones. | ||
| """ | ||
| op = node.op | ||
| if op.axis1 != 0 or op.axis2 != 1: |
There was a problem hiding this comment.
what about poor batched Eye?
| @register_specialize | ||
| @node_rewriter([ExtractDiag]) | ||
| def extract_diag_of_eye_mul(fgraph, node): | ||
| """ExtractDiag(eye * x) -> extract the non-zero diagonal values. |
There was a problem hiding this comment.
Can Eye * x be written as AllocDiag?
|
I'll pause on this pending that discussion then. I will say though that some of these have big savings, specifically the elemwise lift. |
No reason not to do this. This is actually something that generalizes to other Ops somewhat. Sum of Transpose -> Sum with transposed axes, and the like. Always good to remove cruft that doesn't affect computation Otherwise all rewrites you mention would work out of the box if we materialize ExtractDiag (and AllocDia) as the equivalent advanced subtensor / advanced set_subtensor and let the rewrites from #2061 act. Except for
Which I think we should tackle now. I had been worried about duplicate indices, but if they are constant (or provably unique like created from symbolic arange) we don't need to worry about. We should always reduce before we compute. We need to think about this. Maybe ExtractDiag AllocDiag could be an OFG so it's trivial for rewrites to materialize the low lever IR and reuse the general read-write rewrite logic we have? |
|
@ricardoV94 how does this PR need to change now that #2061 is merged |
|
@jessegrabowski what I would suggest trying (may prove not the best solution) is to lower all ExtractDiag of Diag | Eye | Alloc | Eye * x -> AdvancedSubtensor(arange(n), arange(n)) of AdvancedSetSubtensor(zeros(n, n), 1 | x, arange(n), arange(n)), and see if our existing rewrites simplify all. This lowering would be done in the rewrite, when we think it's going to work as is done in: pytensor/pytensor/tensor/rewriting/subtensor.py Lines 2151 to 2220 in d35fb51 (the k changes the arange bit) Make sure to reuse the same arange(n) between AdvancedSubtensor and AdvancedSetSubtensor in case it's not constant for our rewrites that work with symbolic x. If n is constant use constant indices for the more general rewrites. I think this will cover all ExtractDiag of Diag writes. Separately keep the transpose -> k=-k rewrite and the ExtractDiag(Elemwise). Just be careful about broadcasted inputs, if you weren't already: ExtractDiag(mat + v) -> ExtractDiag(mat) + v (possibly with some squeezing/dimshuffle. We don't have that rewrite for Elemwise, and it would be nice to keep having an ExtractDiag in other end anyway, because ExtractDiag is always simpler to reason about. |
These came up when I was working on #2032. They're not related to that perse so I split them off. We're missing a lot of simple rewrites for ExtractDiag. I added: