Skip to content

Create helper to standardize rewrite tests#2103

Open
ricardoV94 wants to merge 2 commits into
pymc-devs:mainfrom
ricardoV94:rewrite_test_helper
Open

Create helper to standardize rewrite tests#2103
ricardoV94 wants to merge 2 commits into
pymc-devs:mainfrom
ricardoV94:rewrite_test_helper

Conversation

@ricardoV94
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 commented Apr 30, 2026

This PR adds a helper to streamline our rewrite tests, according to what I these days feel is the best strategy:

  1. Write a graph
  2. Rewrite
  3. Check the rewrite matches what we expect
  4. Eval both original and rewritten without any extra compilation

For instance, the test in #2101 would look like:

@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    X = pt.dmatrix("X")
    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")

    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    result = utt.rewrite_test([X, b], [out])
    result.assert_equivalent_computations([pt.dot(X, b)])

    rng = np.random.default_rng(42)
    X_val = rng.random((4, 4)) + np.eye(4) * 4
    b_val = rng.random((4,) if b_ndim == 1 else (4, 3))
    result.assert_numerical_close([X_val, b_val], rtol=1e-6)

I really don't like op counts, because it can miss stuff like AdvancedSubtensor become AdvancedSubtensor1, or Blockwise(Dot) become Dot, or Gemm, and it seems like we are optimizing stuff when we are not. Still I added the helpers to count ops...

The main goal though is to reduce friction / have a baseline for contributions

@jessegrabowski
Copy link
Copy Markdown
Member

Why do you think a numerical test should be standard after a structural comparison? It seems like the structural should be sufficient in most cases.

@ricardoV94
Copy link
Copy Markdown
Member Author

Why do you think a numerical test should be standard after a structural comparison? It seems like the structural should be sufficient in most cases.

I also like structural check but we've been bit by that in the past. We got the structure we expected but there was a flaw in it. The numbers are less likely to be fooled

@ricardoV94
Copy link
Copy Markdown
Member Author

ricardoV94 commented May 1, 2026

Example the square(sqrt) rewrite we got the wrong order recently. Needed nan in the test but you get the idea

@jessegrabowski
Copy link
Copy Markdown
Member

can we have this? :D

@ricardoV94
Copy link
Copy Markdown
Member Author

can we have this? :D

Yes, may want to trim away stuff we're not sure to need, and also think a bit about the default rewrites we want, maybe nothing by default?

And we can help and if a node_rewriter is passed, wrap it in dfs ourselves?

@ricardoV94 ricardoV94 changed the title rewrite test helper Create test helper for rewrites May 27, 2026
@ricardoV94 ricardoV94 changed the title Create test helper for rewrites Create helper to standardize rewrite tests May 27, 2026
@ricardoV94 ricardoV94 force-pushed the rewrite_test_helper branch 2 times, most recently from ace9287 to e3e3475 Compare May 27, 2026 13:50
@ricardoV94 ricardoV94 requested a review from jessegrabowski May 27, 2026 13:50
@ricardoV94
Copy link
Copy Markdown
Member Author

@jessegrabowski cleaned up and ready for review

Provides RewriteTester class and rewrite_test factory that clone a graph,
apply rewrites to the clone, and offer structural/numerical assertions
without compiling full-mode functions.
@ricardoV94 ricardoV94 force-pushed the rewrite_test_helper branch from e3e3475 to 600a8d6 Compare May 27, 2026 13:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants