From 718713a0c65dd1a5c48d3ad3937c3b82cc5644b2 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Thu, 30 Apr 2026 23:32:31 +0200
Subject: [PATCH 1/2] Add rewrite_test helper for testing graph rewrites

Provides a RewriteTester class that clones a graph, applies rewrites to the
clone, and offers structural/numerical assertions without compiling
full-mode functions.
---
 doc/extending/graph_rewriting.rst |   8 ++
 tests/unittest_tools.py           | 197 +++++++++++++++++++++++++++++-
 2 files changed, 204 insertions(+), 1 deletion(-)

diff --git a/doc/extending/graph_rewriting.rst b/doc/extending/graph_rewriting.rst
index 0bb4c9fa7f..b3b6df4c05 100644
--- a/doc/extending/graph_rewriting.rst
+++ b/doc/extending/graph_rewriting.rst
@@ -548,6 +548,14 @@ high-level overview of miniKanren's use as a tool for symbolic computation see
 `"miniKanren as a Tool for Symbolic Computation in Python" `_.
 
 
+.. _testing_rewrites:
+
+Testing Rewrites
+================
+
+.. autoclass:: tests.unittest_tools.RewriteTester
+
+
 .. _optdb:
 
 The Optimization Database (:obj:`optdb`)
diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py
index c63da8eff3..5353945050 100644
--- a/tests/unittest_tools.py
+++ b/tests/unittest_tools.py
@@ -1,20 +1,31 @@
 import logging
 import sys
+from collections.abc import Sequence
 from copy import copy, deepcopy
 from functools import wraps
 from numbers import Number
+from typing import TYPE_CHECKING
 
 import numpy as np
 import pytest
 
 import pytensor
+from pytensor.compile import optdb
 from pytensor.compile.debug.debugmode import str_diagnostic
+from pytensor.compile.mode import Mode, predefined_optimizers
 from pytensor.gradient import verify_grad as orig_verify_grad
-from pytensor.graph.basic import equal_computations
+from pytensor.graph.basic import Variable, equal_computations
+from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter
+from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.tensor.math import _allclose
 from pytensor.tensor.math import add as pt_add
 
 
+if TYPE_CHECKING:
+    from pytensor.graph.rewriting.basic import GraphRewriter
+
+
 _logger = logging.getLogger("tests.unittest_tools")
 
 
@@ -402,3 +413,187 @@ def test_with_assert(*args, **kwargs):
         return test_with_assert
     else:
         return f
+
+
+class RewriteTester:
+    """Apply rewrites to a graph and provide assertions for correctness.
+
+    Rewrite tests should verify **two independent properties**:
+
+    1. **Graph structure** — the rewritten graph matches the expected symbolic
+       expression. This is the primary check. Prefer this over Op-count
+       assertions, which can silently pass when one Op is replaced by a
+       specialised variant of the same kind (e.g. ``AdvancedSubtensor`` →
+       ``AdvancedSubtensor1``, or ``Blockwise(Dot)`` → ``Dot``). Writing
+       out the expected graph is harder, but it makes the test explicit about
+       *exactly* what the rewrite should produce.
+
+    2. **Numerical equivalence** — the original and rewritten graphs produce
+       the same values when evaluated. This is an independent safety net: a
+       graph can look structurally correct yet compute something subtly
+       different from the original (e.g. operand order swapped in a
+       non-commutative Op). Both graphs are compiled with ``optimizer=None``
+       so no further rewrites can mask a problem. If you are confident the
+       structural check is sufficient, the numerical check can be skipped.
+
+    Example — canonicalize removes redundant transposes (``x.T.T → x``)::
+
+        x = pt.dmatrix("x")
+        out = x.T.T
+
+        result = RewriteTester([x], [out])
+        result.assert_graph(x)
+        result.assert_eval(np.eye(3))
+
+    **include / exclude** control which passes from the rewrite database are
+    applied. They are orthogonal to ``custom_rewrite`` — both can be used
+    together or independently. The default ``include="canonicalize"``
+    applies only the canonicalization pass, which is usually enough for
+    targeted rewrite tests. ``include`` also accepts a predefined optimizer
+    name as a single string — e.g. ``include="o3"`` applies ``fast_run``
+    minus ``inplace``. Pass ``include=None`` to disable database rewrites
+    entirely.
+
+    **custom_rewrite** applies a specific rewriter before the database
+    passes. ``NodeRewriter`` instances are automatically wrapped in a DFS
+    ``WalkingGraphRewriter``. For example, testing ``log(1 + x) → log1p(x)``
+    with the ``local_log1p`` node rewriter::
+
+        from pytensor.tensor.math import log, log1p
+        from pytensor.tensor.rewriting.math import local_log1p
+
+        x = pt.dscalar("x")
+        out = log(1 + x)
+
+        result = RewriteTester(
+            [x],
+            [out],
+            include=None,
+            custom_rewrite=local_log1p,
+        )
+        result.assert_graph(log1p(x))
+        result.assert_eval(0.5)
+
+    **linker** controls the backend used when evaluating both graphs
+    (default ``"py"``). Use ``"c"`` or ``"numba"`` to test rewrites that
+    target a specific backend.
+
+    Parameters
+    ----------
+    inputs
+        Input variables of the graph.
+    outputs
+        Output variables of the graph.
+    include
+        Rewrite database tags to include (default: ``"canonicalize"``),
+        or a predefined optimizer name as a string (e.g. ``"o3"``,
+        ``"fast_run"``). Pass ``None`` to disable database rewrites.
+    exclude
+        Rewrite database tags to exclude.
+    custom_rewrite
+        A ``GraphRewriter`` or ``NodeRewriter`` to apply before the database
+        rewrites. ``NodeRewriter`` instances are automatically wrapped in a
+        DFS ``WalkingGraphRewriter``.
+    linker
+        Linker used when evaluating both graphs (default: ``"py"``).
+    **kwargs
+        Extra arguments forwarded to ``RewriteDatabaseQuery``.
+    """
+
+    def __init__(
+        self,
+        inputs: Sequence[Variable],
+        outputs: Sequence[Variable],
+        *,
+        include: str | Sequence[str] | None = "canonicalize",
+        exclude: Sequence[str] | None = (),
+        custom_rewrite: "GraphRewriter | NodeRewriter | None" = None,
+        linker: str = "py",
+        **kwargs,
+    ):
+        inputs = list(inputs)
+        outputs = list(outputs)
+        self.orig_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
+        self.rewr_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
+        self._orig_inputs = inputs
+        self._no_opt = Mode(linker=linker, optimizer=None)
+        self._orig_fn = None
+        self._rewr_fn = None
+
+        if custom_rewrite is not None:
+            if isinstance(custom_rewrite, NodeRewriter):
+                custom_rewrite = dfs_rewriter(custom_rewrite)
+            custom_rewrite.rewrite(self.rewr_fg)
+
+        if include:
+            if (
+                isinstance(include, str)
+                and (query := predefined_optimizers.get(include)) is not None
+            ):
+                pass
+            else:
+                if isinstance(include, str):
+                    include = [include]
+                query = RewriteDatabaseQuery(include=list(include), **kwargs)
+            if exclude:
+                query = query.excluding(*exclude)
+            optdb.query(query).rewrite(self.rewr_fg)
+
+    @property
+    def orig_fn(self):
+        """Function for the original graph, compiled on first access."""
+        if self._orig_fn is None:
+            self._orig_fn = pytensor.function(
+                self.orig_fg.inputs,
+                self.orig_fg.outputs,
+                mode=self._no_opt,
+                on_unused_input="ignore",
+            )
+        return self._orig_fn
+
+    @property
+    def rewr_fn(self):
+        """Function for the rewritten graph, compiled on first access."""
+        if self._rewr_fn is None:
+            self._rewr_fn = pytensor.function(
+                self.rewr_fg.inputs,
+                self.rewr_fg.outputs,
+                mode=self._no_opt,
+                on_unused_input="ignore",
+            )
+        return self._rewr_fn
+
+    def assert_eval(self, *test_values, rtol=None, atol=None):
+        """Assert both graphs evaluate to the same values on ``test_values``.
+
+        ``rtol``/``atol`` default to ``1e-7``/``0`` when ``None``; an explicit
+        ``0`` is honoured rather than replaced by the default.
+        """
+        __tracebackhide__ = True
+        orig_out = self.orig_fn(*test_values)
+        rewr_out = self.rewr_fn(*test_values)
+        if not isinstance(orig_out, list | tuple):
+            orig_out = [orig_out]
+        if not isinstance(rewr_out, list | tuple):
+            rewr_out = [rewr_out]
+        for i, (a, b) in enumerate(zip(orig_out, rewr_out, strict=True)):
+            np.testing.assert_allclose(
+                a,
+                b,
+                rtol=1e-7 if rtol is None else rtol,
+                atol=0 if atol is None else atol,
+                err_msg=f"Output {i} mismatch between original and rewritten graph",
+            )
+
+    def assert_graph(self, *expected_outputs, strict_dtype=False, **kwargs):
+        """Check the rewritten outputs against ``expected_outputs``."""
+        __tracebackhide__ = True
+        assert_equal_computations(
+            self.rewr_fg.outputs,
+            list(expected_outputs),
+            in_xs=list(self.rewr_fg.inputs),
+            in_ys=self._orig_inputs,
+            original=self.orig_fg.outputs,
+            strict_dtype=strict_dtype,
+            **kwargs,
+        )
From 5e92cf9904e1efc3bd207bc8713caf4682d1b5cc Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Thu, 30 Apr 2026 23:40:50 +0200
Subject: [PATCH 2/2] Apply rewrite_test as POC

---
 tests/tensor/rewriting/test_subtensor.py | 45 ++++++++----------------
 1 file changed, 14 insertions(+), 31 deletions(-)

diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index 430b01e6ca..119f8206d2 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -179,11 +179,8 @@ def test_local_useless_inc_subtensor_increment_zeros():
     s = pt.zeros((2, 2))[:, :]
     o_shape = inc_subtensor(s, specify_shape(y, s.shape))
 
-    mode = get_default_mode().including("local_useless_inc_subtensor")
-    f_shape = function([y], o_shape, mode=mode)
-
-    topo = f_shape.maker.fgraph.toposort()
-    assert not any(isinstance(n.op, IncSubtensor) for n in topo)
+    result = utt.RewriteTester([y], [o_shape])
+    result.assert_graph(specify_shape(y, (2, 2)))
 
 
 def test_local_useless_inc_subtensor_no_opt():
@@ -191,41 +188,27 @@ def test_local_useless_inc_subtensor_no_opt():
     x = matrix("x")
     y = matrix("y")
 
+    # Stepped slice — can't be removed.
     s = x[:, ::2]
     o_shape = set_subtensor(s, specify_shape(y, s.shape))
-
-    mode = get_default_mode().including("local_useless_inc_subtensor")
-    f_shape = function([x, y], o_shape, mode=mode)
-
-    topo = f_shape.maker.fgraph.toposort()
-    assert any(isinstance(n.op, IncSubtensor) for n in topo)
-
-    out = f_shape([[2, 3, 6, 7]], [[8, 9]])
-    assert np.array_equal(out, np.asarray([[8, 3, 9, 7]]))
+    result = utt.RewriteTester([x, y], [o_shape])
+    result.assert_graph(o_shape)
+    result.assert_eval([[2, 3, 6, 7]], [[8, 9]])
 
     # Increment with a non-constant target array, full slices collapse to x + y.
     s = x[:, :]
     o_shape = inc_subtensor(s, specify_shape(y, s.shape))
-
-    f_shape = function([x, y], o_shape, mode=mode)
-
-    topo = f_shape.maker.fgraph.toposort()
-    assert not any(isinstance(n.op, IncSubtensor) for n in topo)
-
-    out = f_shape([[1, 2], [3, 4]], [[10, 20], [30, 40]])
-    assert np.array_equal(out, np.asarray([[11, 22], [33, 44]]))
+    result = utt.RewriteTester([x, y], [o_shape])
+    result.assert_graph(x + specify_shape(y, x.shape))
+    result.assert_eval([[1, 2], [3, 4]], [[10, 20], [30, 40]])
 
     # Increment with a non-zero constant target array, same collapse to x + y.
-    s = pt.ones((2, 2))[:, :]
+    ones = pt.ones((2, 2))
+    s = ones[:, :]
     o_shape = inc_subtensor(s, specify_shape(y, s.shape))
-
-    f_shape = function([y], o_shape, mode=mode)
-
-    topo = f_shape.maker.fgraph.toposort()
-    assert not any(isinstance(n.op, IncSubtensor) for n in topo)
-
-    out = f_shape([[10, 20], [30, 40]])
-    assert np.array_equal(out, np.asarray([[11, 21], [31, 41]]))
+    result = utt.RewriteTester([y], [o_shape])
+    result.assert_graph(ones + specify_shape(y, ones.shape))
+    result.assert_eval([[10, 20], [30, 40]])
 
 
 def test_local_add_of_sparse_write():