Rewrite solve(matrix_inverse(X), b) → X @ b#2101
Rewrite solve(matrix_inverse(X), b) → X @ b#2101alessandrogentili001 wants to merge 3 commits intopymc-devs:mainfrom
Conversation
| # Get all nodes in the rewritten graph | ||
| all_nodes = io_toposort([], [rewritten_out]) | ||
|
|
||
| assert not any( | ||
| isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse) | ||
| for node in all_nodes | ||
| ) |
There was a problem hiding this comment.
Can replace this with assert_equal_computation
There was a problem hiding this comment.
this way should works
@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
X = pt.dmatrix("X")
b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)
# Just include the rewrite we are testing
rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)
# Verify the rewrite
expected = X @ b
assert_equal_computations([rewritten_out], [expected])
# Numerical check
rng = np.random.default_rng(42)
X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)
f_opt = function([X, b], rewritten_out)
res_opt = f_opt(X_val, b_val)
res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)
np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)| # Graph rewrite test | ||
| # We include 'stabilize' because solve_of_inv_to_matmul is registered there. | ||
| # This avoids dependency on the global config.mode (e.g. FAST_COMPILE). | ||
| rewritten_out = rewrite_graph(out, include=["stabilize"]) |
There was a problem hiding this comment.
You can directly include just the rewrite you're testing , its a bit more clear that way
| # Numerical check | ||
| rng = np.random.default_rng(42) | ||
| X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype) | ||
| b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype) | ||
|
|
||
| f_opt = function([X, b], rewritten_out) | ||
| res_opt = f_opt(X_val, b_val) | ||
| res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val) | ||
|
|
||
| np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7) |
There was a problem hiding this comment.
We don't need the numerical check if the structural check passes (you're just testing BLAS at that point -- i promise you BLAS works)
| rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul) | ||
| rewritten_out = rewrite_graph(out, custom_rewrite=rewriter) |
There was a problem hiding this comment.
I don't love the import + custom_rewrite. The pattern I had in mind was to just use rewrite_graph(out, include=('your_rewrite_name', )). If that doesn't work I'd rather it be reverted to what you had before for simplicity. But also fine with it staying this way if you don't want to keep going back and forth.
There was a problem hiding this comment.
Again I think this PR is a great discussion for #2103 which is still open-ended and in ask of feedback. So we standardize how we want to test this sort of rewrites and don't need to waste future time discussing it
Description
Related Issue
Checklist
Type of change