Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/extending/graph_rewriting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,14 @@ high-level overview of miniKanren's use as a tool for symbolic computation see
`"miniKanren as a Tool for Symbolic Computation in Python" <https://arxiv.org/abs/2005.11644>`_.


.. _testing_rewrites:

Testing Rewrites
================

.. autoclass:: tests.unittest_tools.RewriteTester


.. _optdb:

The Optimization Database (:obj:`optdb`)
Expand Down
45 changes: 14 additions & 31 deletions tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,53 +179,36 @@ def test_local_useless_inc_subtensor_increment_zeros():
s = pt.zeros((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))

mode = get_default_mode().including("local_useless_inc_subtensor")
f_shape = function([y], o_shape, mode=mode)

topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo)
result = utt.RewriteTester([y], [o_shape])
result.assert_graph(specify_shape(y, (2, 2)))


def test_local_useless_inc_subtensor_no_opt():
    r"""Make sure we don't remove `IncSubtensor`\s that involve slices with steps that skip elements and non-zero increments."""
    # NOTE(review): every case below carries two styles of check back to back:
    # a compile-and-inspect check (`function(..., mode=mode)` + toposort scan)
    # and a `utt.RewriteTester` check.  This looks like the pre- and
    # post-change lines of a diff rendered together — confirm which set is
    # meant to remain before editing the statements.
    x = matrix("x")
    y = matrix("y")

    # Stepped slice — can't be removed.
    s = x[:, ::2]
    o_shape = set_subtensor(s, specify_shape(y, s.shape))

    # Default mode with the rewrite under test explicitly included; this
    # `mode` is reused by the later `function(...)` calls in this test.
    mode = get_default_mode().including("local_useless_inc_subtensor")
    f_shape = function([x, y], o_shape, mode=mode)

    # The compiled graph must still contain an `IncSubtensor` node.
    topo = f_shape.maker.fgraph.toposort()
    assert any(isinstance(n.op, IncSubtensor) for n in topo)

    # Only columns 0 and 2 are overwritten by the values of `y`.
    out = f_shape([[2, 3, 6, 7]], [[8, 9]])
    assert np.array_equal(out, np.asarray([[8, 3, 9, 7]]))
    # Same case via `RewriteTester`: the rewritten graph is expected to equal
    # the unrewritten expression, and both must evaluate to the same values.
    result = utt.RewriteTester([x, y], [o_shape])
    result.assert_graph(o_shape)
    result.assert_eval([[2, 3, 6, 7]], [[8, 9]])

    # Increment with a non-constant target array, full slices collapse to x + y.
    s = x[:, :]
    o_shape = inc_subtensor(s, specify_shape(y, s.shape))

    f_shape = function([x, y], o_shape, mode=mode)

    # Here no `IncSubtensor` node may remain in the compiled graph.
    topo = f_shape.maker.fgraph.toposort()
    assert not any(isinstance(n.op, IncSubtensor) for n in topo)

    out = f_shape([[1, 2], [3, 4]], [[10, 20], [30, 40]])
    assert np.array_equal(out, np.asarray([[11, 22], [33, 44]]))
    # Expected rewritten graph is a plain addition of `x` and the
    # shape-specified `y`.
    result = utt.RewriteTester([x, y], [o_shape])
    result.assert_graph(x + specify_shape(y, x.shape))
    result.assert_eval([[1, 2], [3, 4]], [[10, 20], [30, 40]])

    # Increment with a non-zero constant target array, same collapse to x + y.
    # NOTE(review): the next assignment to `s` is overwritten two lines below
    # before it is read.
    s = pt.ones((2, 2))[:, :]
    # `ones` is kept as a named variable so the expected graph below can
    # refer to it.
    ones = pt.ones((2, 2))
    s = ones[:, :]
    o_shape = inc_subtensor(s, specify_shape(y, s.shape))

    f_shape = function([y], o_shape, mode=mode)

    topo = f_shape.maker.fgraph.toposort()
    assert not any(isinstance(n.op, IncSubtensor) for n in topo)

    # Every element of `y` is incremented by the constant 1.
    out = f_shape([[10, 20], [30, 40]])
    assert np.array_equal(out, np.asarray([[11, 21], [31, 41]]))
    result = utt.RewriteTester([y], [o_shape])
    result.assert_graph(ones + specify_shape(y, ones.shape))
    result.assert_eval([[10, 20], [30, 40]])


def test_local_add_of_sparse_write():
Expand Down
182 changes: 181 additions & 1 deletion tests/unittest_tools.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
import logging
import sys
from collections.abc import Sequence
from copy import copy, deepcopy
from functools import wraps
from numbers import Number
from typing import TYPE_CHECKING

import numpy as np
import pytest

import pytensor
from pytensor.compile import optdb
from pytensor.compile.debug.debugmode import str_diagnostic
from pytensor.compile.mode import Mode
from pytensor.gradient import verify_grad as orig_verify_grad
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import NodeRewriter, dfs_rewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import add as pt_add


if TYPE_CHECKING:
from pytensor.graph.rewriting.basic import GraphRewriter


_logger = logging.getLogger("tests.unittest_tools")


Expand Down Expand Up @@ -402,3 +413,172 @@ def test_with_assert(*args, **kwargs):
return test_with_assert
else:
return f


class RewriteTester:
"""Apply rewrites to a graph and provide assertions for correctness.

Rewrite tests should verify **two independent properties**:

1. **Graph structure** — the rewritten graph matches the expected symbolic
expression. This is the primary check. Prefer this over Op-count
assertions, which can silently pass when one Op is replaced by a
specialised variant of the same kind (e.g. ``AdvancedSubtensor`` →
``AdvancedSubtensor1``, or ``Blockwise(Dot)`` → ``Dot``). Writing
out the expected graph is harder, but it makes the test explicit about
*exactly* what the rewrite should produce.

2. **Numerical equivalence** — the original and rewritten graphs produce
the same values when evaluated. This is an independent safety net: a
graph can look structurally correct yet compute something subtly
different from the original (e.g. operand order swapped in a
non-commutative Op). Both graphs are compiled with ``optimizer=None``
so no further rewrites can mask a problem. If you are confident the
structural check is sufficient, the numerical check can be skipped.

Example — canonicalize removes redundant transposes (``x.T.T → x``)::

x = pt.dmatrix("x")
out = x.T.T

result = RewriteTester([x], [out])
result.assert_graph(x)
result.assert_eval(np.eye(3))

**include / exclude** control which passes from the rewrite database are
applied. They are orthogonal to ``custom_rewrite`` — both can be used
together or independently. The default ``include=("canonicalize",)``
applies only the canonicalization pass, which is usually enough for
targeted rewrite tests. To test against the full compilation pipeline
(without inplace rewrites), use ``include=("o3",)`` — this is equivalent
to ``fast_run`` minus ``inplace``. Pass ``include=()`` to disable
database rewrites entirely.

**custom_rewrite** applies a specific rewriter before the database
passes. ``NodeRewriter`` instances are automatically wrapped in a DFS
``WalkingGraphRewriter``. For example, testing ``log(1 + x) → log1p(x)``
with the ``local_log1p`` node rewriter::

from pytensor.tensor.math import log, log1p
from pytensor.tensor.rewriting.math import local_log1p

x = pt.dscalar("x")
out = log(1 + x)

result = RewriteTester(
[x],
[out],
include=(),
custom_rewrite=local_log1p,
)
result.assert_graph(log1p(x))
result.assert_eval(0.5)

**linker** controls the backend used when evaluating both graphs
(default ``"py"``). Use ``"c"`` or ``"numba"`` to test rewrites that
target a specific backend.

Parameters
----------
inputs
Input variables of the graph.
outputs
Output variables of the graph.
include
Rewrite database tags to include (default: ``("canonicalize",)``).
Pass ``()`` to disable database rewrites entirely.
exclude
Rewrite database tags to exclude.
custom_rewrite
A ``GraphRewriter`` or ``NodeRewriter`` to apply before the database
rewrites. ``NodeRewriter`` instances are automatically wrapped in a
DFS ``WalkingGraphRewriter``.
linker
Linker used when evaluating both graphs (default: ``"py"``).
**kwargs
Extra arguments forwarded to ``RewriteDatabaseQuery``.
"""

def __init__(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
*,
include: Sequence[str] = ("canonicalize",),
exclude: Sequence[str] = (),
custom_rewrite: "GraphRewriter | NodeRewriter | None" = None,
linker: str = "py",
**kwargs,
):
inputs = list(inputs)
outputs = list(outputs)
self.orig_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
self.rewr_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
self._orig_inputs = inputs
self._no_opt = Mode(linker=linker, optimizer=None)
self._orig_fn = None
self._rewr_fn = None

if custom_rewrite is not None:
if isinstance(custom_rewrite, NodeRewriter):
custom_rewrite = dfs_rewriter(custom_rewrite)
custom_rewrite.rewrite(self.rewr_fg)

if include:
query_kwargs = {**kwargs}
if exclude:
query_kwargs["exclude"] = list(exclude)
optdb.query(
RewriteDatabaseQuery(include=list(include), **query_kwargs)
).rewrite(self.rewr_fg)

@property
def orig_fn(self):
if self._orig_fn is None:
self._orig_fn = pytensor.function(
self.orig_fg.inputs,
self.orig_fg.outputs,
mode=self._no_opt,
on_unused_input="ignore",
)
return self._orig_fn

@property
def rewr_fn(self):
if self._rewr_fn is None:
self._rewr_fn = pytensor.function(
self.rewr_fg.inputs,
self.rewr_fg.outputs,
mode=self._no_opt,
on_unused_input="ignore",
)
return self._rewr_fn

def assert_eval(self, *test_values, rtol=None, atol=None):
__tracebackhide__ = True
orig_out = self.orig_fn(*test_values)
rewr_out = self.rewr_fn(*test_values)
if not isinstance(orig_out, list | tuple):
orig_out = [orig_out]
if not isinstance(rewr_out, list | tuple):
rewr_out = [rewr_out]
for i, (a, b) in enumerate(zip(orig_out, rewr_out, strict=True)):
np.testing.assert_allclose(
a,
b,
rtol=rtol or 1e-7,
atol=atol or 0,
err_msg=f"Output {i} mismatch between original and rewritten graph",
)

def assert_graph(self, *expected_outputs, strict_dtype=False, **kwargs):
__tracebackhide__ = True
assert_equal_computations(
self.rewr_fg.outputs,
list(expected_outputs),
in_xs=list(self.rewr_fg.inputs),
in_ys=self._orig_inputs,
original=self.orig_fg.outputs,
strict_dtype=strict_dtype,
**kwargs,
)
Loading