
Rewrite det(inv(X)) → 1/det(X) #2102

Open
alessandrogentili001 wants to merge 2 commits into pymc-devs:main from alessandrogentili001:rewrite-determinant-of-inverse-matrix

Conversation

@alessandrogentili001
Contributor

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@register_stabilize
@register_specialize
@node_rewriter([Elemwise])
def local_reciprocal_linalg_special_cases(fgraph, node):
Member


These aren't related to linalg; the name is wrong.

Each of these should be a separate rewrite, and we should be using the new scalar Op properties. We have monotonic_increasing and monotonic_decreasing, which imply sign(f(x)) -> sign(x) and sign(f(x)) -> sign(-x) if the op is also zero_preserving.

We just have to be careful about strict monotonicity vs non-strict. I can't remember if ceil(x) is marked as monotonic, for example. We might need a separate flag for the strict variety and check it in this rewrite.
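The correctness condition described here can be checked numerically. A minimal sketch in plain Python (not the PyTensor rewrite machinery; `sign` is a local helper): for a strictly increasing, zero-preserving f we have sign(f(x)) == sign(x), while ceil, which is only non-strictly monotonic, breaks the rewrite:

```python
import math

def sign(x):
    # -1, 0, or +1, matching the usual sign convention
    return (x > 0) - (x < 0)

# tanh is strictly increasing and zero-preserving, so sign(tanh(x)) == sign(x)
for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    assert sign(math.tanh(x)) == sign(x)

# ceil is monotonic only non-strictly: it flattens (-1, 0] onto 0,
# so sign(ceil(x)) != sign(x) there and the rewrite would be invalid.
x = -0.5
print(sign(math.ceil(x)), sign(x))  # 0 -1
```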

def det_of_inv(fgraph, node):
    """Replace det(matrix_inverse(X)) with reciprocal(det(X)).

    Since det(inv(X)) = 1/det(X), we avoid computing the inverse.
    """
Member


Suggested change
Since det(inv(X)) = 1/det(X), we avoid computing the inverse.
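As a quick sanity check of the identity the rewrite relies on (NumPy only, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 4))  # a random dense matrix is almost surely invertible

# det(inv(X)) == 1/det(X), since det(A B) = det(A) det(B) and det(I) = 1
lhs = np.linalg.det(np.linalg.inv(X))
rhs = 1.0 / np.linalg.det(X)
assert np.isclose(lhs, rhs)
```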

@alessandrogentili001
Contributor Author

Hi Jesse, thanks for the feedback! I've refactored the implementation to address your points about modularity and property-based rewriting.

  1. Decoupling Linalg from Math: I've moved the non-linalg rewrites (like log(reciprocal(x))) out of pytensor/tensor/rewriting/linalg.py and into pytensor/tensor/rewriting/math.py.
  2. Property-Based sign(f(x)) Rewrite: I implemented a generic local_sign_of_monotonic rewriter in math.py. Instead of hardcoding special cases, it now queries ScalarOp properties.
  3. Strict Monotonicity Flags: To handle the edge cases you mentioned (like ceil), I've added strictly_monotonic_increasing and strictly_monotonic_decreasing flags to the UnaryScalarOp base class and applied them to all relevant strictly monotonic Ops across basic.py and math.py (including Log, Exp, Tanh, Sigmoid, Erf, etc.). The sign rewrite now specifically checks for strict monotonicity to ensure correctness.
  4. Atomic Rewrites: I've separated the logic into atomic rewriters (local_log_reciprocal, local_sign_reciprocal, and local_sign_of_monotonic) and registered them across canonicalize, stabilize, and specialize to ensure they trigger reliably.
  5. Docstring Cleanup: Simplified the det_of_inv docstring as suggested.
  6. Expanded Testing: Added comprehensive tests in tests/tensor/rewriting/test_math.py to verify these generic rewrites, including a negative test case for ceil to confirm it is not incorrectly optimized.

Let me know if you have any further suggestions!
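The property-based dispatch described in point 2 can be sketched in plain Python. Here `PROPS` and `rewrite_sign_of` are hypothetical stand-ins for the ScalarOp flags and the `local_sign_of_monotonic` rewriter, not PyTensor code:

```python
import math

# Hypothetical property table standing in for ScalarOp flags such as
# strictly_monotonic_increasing and zero_preserving.
PROPS = {
    math.tanh: dict(strictly_increasing=True, zero_preserving=True),
    math.exp:  dict(strictly_increasing=True, zero_preserving=False),
    math.ceil: dict(strictly_increasing=False, zero_preserving=True),
}

def rewrite_sign_of(f):
    """Return 'sign(x)' if sign(f(x)) may be replaced by sign(x), else None."""
    p = PROPS.get(f, {})
    # Both properties are required: strictness rules out ceil-like flattening,
    # zero preservation rules out exp-like offsets (sign(exp(x)) is always 1).
    if p.get("strictly_increasing") and p.get("zero_preserving"):
        return "sign(x)"
    return None  # rewrite does not apply

print(rewrite_sign_of(math.tanh))  # sign(x)
print(rewrite_sign_of(math.ceil))  # None
```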



Development

Successfully merging this pull request may close these issues.

Rewrite det(inv(X)) → 1/det(X)
