Skip to content

rewrite test helper#2103

Open
ricardoV94 wants to merge 2 commits intopymc-devs:mainfrom
ricardoV94:rewrite_test_helper
Open

rewrite test helper#2103
ricardoV94 wants to merge 2 commits intopymc-devs:mainfrom
ricardoV94:rewrite_test_helper

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented Apr 30, 2026

This PR adds a helper to streamline our rewrite tests, according to what I these days feel is the best strategy:

  1. Write a graph
  2. Rewrite
  3. Check the rewrite matches what we expect
  4. Eval both original and rewritten without any extra compilation

For instance, the test in #2101 would look like:

@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    X = pt.dmatrix("X")
    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")

    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    result = utt.rewrite_test([X, b], [out])
    result.assert_equivalent_computations([pt.dot(X, b)])

    rng = np.random.default_rng(42)
    X_val = rng.random((4, 4)) + np.eye(4) * 4
    b_val = rng.random((4,) if b_ndim == 1 else (4, 3))
    result.assert_numerical_close([X_val, b_val], rtol=1e-6)

I really don't like op counts, because it can miss stuff like AdvancedSubtensor become AdvancedSubtensor1, or Blockwise(Dot) become Dot, or Gemm, and it seems like we are optimizing stuff when we are not. Still I added the helpers to count ops...

The main goal though is to reduce friction / have a baseline for contributions

Provides RewriteTester class and rewrite_test factory that clone a graph,
apply rewrites to the clone, and offer structural/numerical assertions
without compiling full-mode functions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant