Skip to content

Rewrite solve(matrix_inverse(X), b) → X @ b#2101

Open
alessandrogentili001 wants to merge 3 commits intopymc-devs:mainfrom
alessandrogentili001:rewrite-solve-matrix-inverse-as-mutmul
Open

Rewrite solve(matrix_inverse(X), b) → X @ b#2101
alessandrogentili001 wants to merge 3 commits intopymc-devs:mainfrom
alessandrogentili001:rewrite-solve-matrix-inverse-as-mutmul

Conversation

@alessandrogentili001
Copy link
Copy Markdown
Contributor

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment on lines +462 to +468
# Get all nodes in the rewritten graph
all_nodes = io_toposort([], [rewritten_out])

assert not any(
isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse)
for node in all_nodes
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can replace this with assert_equal_computation

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #2103 to try and formalize a bit better

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this way should works

@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    X = pt.dmatrix("X")
    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    # Just include the rewrite we are testing
    rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
    rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)

    # Verify the rewrite
    expected = X @ b
    assert_equal_computations([rewritten_out], [expected])

    # Numerical check
    rng = np.random.default_rng(42)
    X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
    b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)

    f_opt = function([X, b], rewritten_out)
    res_opt = f_opt(X_val, b_val)
    res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)

    np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)

# Graph rewrite test
# We include 'stabilize' because solve_of_inv_to_matmul is registered there.
# This avoids dependency on the global config.mode (e.g. FAST_COMPILE).
rewritten_out = rewrite_graph(out, include=["stabilize"])
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can directly include just the rewrite you're testing , its a bit more clear that way

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

Comment on lines +462 to +471
# Numerical check
rng = np.random.default_rng(42)
X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)

f_opt = function([X, b], rewritten_out)
res_opt = f_opt(X_val, b_val)
res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)

np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need the numerical check if the structural check passes (you're just testing BLAS at that point -- i promise you BLAS works)

Comment on lines +455 to +456
rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love the import + custom_rewrite. The pattern I had in mind was to just use rewrite_graph(out, include=('your_rewrite_name', )). If that doesn't work I'd rather it be reverted to what you had before for simplicity. But also fine with it staying this way if you don't want to keep going back and forth.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again I think this PR is a great discussion for #2103 which is still open-ended and in ask of feedback. So we standardize how we want to test this sort of rewrites and don't need to waste future time discussing it

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Rewrite solve(matrix_inverse(X), b) → X @ b

3 participants