diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index d283ba3e93..6bfc2ee936 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -183,13 +183,23 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
         try:
             client_idx = client.inputs.index(blockdiag_result)
         except ValueError:
-            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
-            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
-            # blockdiag result.
+            # If the blockdiag result is not an input to the dot, there is at least one Op between them.
+            # We allow a left expand_dims (DimShuffle), which is introduced automatically by Blockwise to
+            # equalize the number of batch dims, but does not change the semantics of the graph.
             for ancestor in client.inputs:
-                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
+                if (
+                    ancestor.owner is not None
+                    and (
+                        isinstance(ancestor.owner.op, DimShuffle)
+                        and ancestor.owner.op.is_left_expand_dims
+                    )
+                    and blockdiag_result in ancestor.owner.inputs
+                ):
                     client_idx = client.inputs.index(ancestor)
                     break
+            else:  # no-break
+                # Not a simple left expand_dims between the dot and the block_diag
+                return None

         other_input = client.inputs[1 - client_idx]
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 881fb59214..e056b8e30f 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -148,6 +148,7 @@
 )
 from pytensor.tensor.variable import TensorConstant
 from tests import unittest_tools as utt
+from tests.unittest_tools import assert_equal_computations

 rewrite_mode = config.mode

@@ -4857,119 +4858,146 @@ def test_local_dot_to_mul_unspecified_length_1():
     )


-@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
-@pytest.mark.parametrize(
-    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
-)
-@pytest.mark.parametrize(
-    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
-)
-def test_local_block_diag_dot_to_dot_block_diag(
-    left_multiply, batch_blockdiag, batch_other
-):
-    """
-    Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
-    """
+class TestBlockDiagDotToDotBlockDiag:
+    @pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+    @pytest.mark.parametrize(
+        "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+    )
+    @pytest.mark.parametrize(
+        "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+    )
+    def test_rewrite_applies(self, left_multiply, batch_blockdiag, batch_other):
+        """
+        Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+        """

-    def has_blockdiag(graph):
-        return any(
-            (
-                var.owner
-                and (
-                    isinstance(var.owner.op, BlockDiagonal)
-                    or (
-                        isinstance(var.owner.op, Blockwise)
-                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+        def has_blockdiag(graph):
+            return any(
+                (
+                    var.owner
+                    and (
+                        isinstance(var.owner.op, BlockDiagonal)
+                        or (
+                            isinstance(var.owner.op, Blockwise)
+                            and isinstance(var.owner.op.core_op, BlockDiagonal)
+                        )
                     )
                 )
+                for var in ancestors([graph])
             )
-            for var in ancestors([graph])
-        )

-    a = tensor("a", shape=(4, 2))
-    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
-    c = tensor("c", shape=(4, 4))
-    x = pt.linalg.block_diag(a, b, c)
+        a = tensor("a", shape=(4, 2))
+        b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+        c = tensor("c", shape=(4, 4))
+        x = pt.linalg.block_diag(a, b, c)

-    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+        d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))

-    # Test multiple clients are all rewritten
-    if left_multiply:
-        out = x @ d
-    else:
-        out = d @ x
+        # Test multiple clients are all rewritten
+        if left_multiply:
+            out = x @ d
+        else:
+            out = d @ x

-    assert has_blockdiag(out)
-    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
-    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+        assert has_blockdiag(out)
+        fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+        assert not has_blockdiag(fn.maker.fgraph.outputs[0])

-    n_dots_rewrite = sum(
-        isinstance(node.op, Dot | Dot22)
-        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
-        for node in fn.maker.fgraph.apply_nodes
-    )
-    assert n_dots_rewrite == 3
+        n_dots_rewrite = sum(
+            isinstance(node.op, Dot | Dot22)
+            or (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Dot | Dot22)
+            )
+            for node in fn.maker.fgraph.apply_nodes
+        )
+        assert n_dots_rewrite == 3

-    fn_expected = pytensor.function(
-        [a, b, c, d],
-        out,
-        mode=Mode(linker="py", optimizer=None),
-    )
-    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+        fn_expected = pytensor.function(
+            [a, b, c, d],
+            out,
+            mode=Mode(linker="py", optimizer=None),
+        )
+        assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])

-    n_dots_no_rewrite = sum(
-        isinstance(node.op, Dot | Dot22)
-        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
-        for node in fn_expected.maker.fgraph.apply_nodes
-    )
-    assert n_dots_no_rewrite == 1
+        n_dots_no_rewrite = sum(
+            isinstance(node.op, Dot | Dot22)
+            or (
+                isinstance(node.op, Blockwise)
+                and isinstance(node.op.core_op, Dot | Dot22)
+            )
+            for node in fn_expected.maker.fgraph.apply_nodes
+        )
+        assert n_dots_no_rewrite == 1

-    rng = np.random.default_rng()
-    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
-    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
-    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
-    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+        rng = np.random.default_rng()
+        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

-    rewrite_out = fn(a_val, b_val, c_val, d_val)
-    expected_out = fn_expected(a_val, b_val, c_val, d_val)
-    np.testing.assert_allclose(
-        rewrite_out,
-        expected_out,
-        atol=1e-6 if config.floatX == "float32" else 1e-12,
-        rtol=1e-6 if config.floatX == "float32" else 1e-12,
-    )
+        rewrite_out = fn(a_val, b_val, c_val, d_val)
+        expected_out = fn_expected(a_val, b_val, c_val, d_val)
+        np.testing.assert_allclose(
+            rewrite_out,
+            expected_out,
+            atol=1e-6 if config.floatX == "float32" else 1e-12,
+            rtol=1e-6 if config.floatX == "float32" else 1e-12,
+        )

+    def test_rewrite_does_not_apply(self):
+        # Regression test for https://github.com/pymc-devs/pytensor/issues/1836

-@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
-@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
-def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
-    rng = np.random.default_rng()
-    a_size = int(rng.uniform(1, int(0.8 * size)))
-    b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
-    c_size = size - a_size - b_size
+        # Shapes match if either R is transposed or y is, but not by default
+        y = pt.tensor("y", shape=(7, 9))
+        R1 = pt.tensor("R1", shape=(2, 3))
+        R2 = pt.tensor("R2", shape=(5, 6))
+        R = pt.linalg.block_diag(R1, R2)

-    a = tensor("a", shape=(a_size, a_size))
-    b = tensor("b", shape=(b_size, b_size))
-    c = tensor("c", shape=(c_size, c_size))
-    d = tensor("d", shape=(size,))
+        # This could be rewritten in the future; if that's the case, remove this check
+        original = dot(R.mT, y)
+        rewritten = rewrite_graph(
+            original, include=("canonicalize", "stabilize", "specialize")
+        )
+        assert_equal_computations([rewritten], [original])

-    x = pt.linalg.block_diag(a, b, c)
-    out = x @ d
+        # This is unlikely to ever be rewritten
+        original = dot(R.exp(), y.mT)
+        rewritten = rewrite_graph(
+            original, include=("canonicalize", "stabilize", "specialize")
+        )
+        assert_equal_computations([rewritten], [original])
+
+    @pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+    @pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+    def test_benchmark(self, benchmark, size, rewrite):
+        rng = np.random.default_rng()
+        a_size = int(rng.uniform(1, int(0.8 * size)))
+        b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
+        c_size = size - a_size - b_size
+
+        a = tensor("a", shape=(a_size, a_size))
+        b = tensor("b", shape=(b_size, b_size))
+        c = tensor("c", shape=(c_size, c_size))
+        d = tensor("d", shape=(size,))
+
+        x = pt.linalg.block_diag(a, b, c)
+        out = x @ d

-    mode = get_default_mode()
-    if not rewrite:
-        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
-    fn = pytensor.function([a, b, c, d], out, mode=mode)
-
-    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
-    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
-    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
-    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
-
-    benchmark(
-        fn,
-        a_val,
-        b_val,
-        c_val,
-        d_val,
-    )
+        mode = get_default_mode()
+        if not rewrite:
+            mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+        fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+        a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+        b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+        c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+        d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+        benchmark(
+            fn,
+            a_val,
+            b_val,
+            c_val,
+            d_val,
+        )
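
Editorial note, not part of the patch: a minimal sketch of the behaviour the tests above exercise, assuming the rewrite is registered in PyTensor's default rewrite set (as the tests imply); the names and shapes below are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt

# Build dot(block_diag(a, b), d); the rewrite targeted by this PR should split it into
# one dot per block, so the block-diagonal matrix is never materialized.
a = pt.tensor("a", shape=(4, 2))
b = pt.tensor("b", shape=(2, 4))
d = pt.tensor("d", shape=(6, 3))
out = pt.linalg.block_diag(a, b) @ d
fn = pytensor.function([a, b, d], out)

rng = np.random.default_rng(0)
a_val = rng.normal(size=(4, 2)).astype(a.type.dtype)
b_val = rng.normal(size=(2, 4)).astype(b.type.dtype)
d_val = rng.normal(size=(6, 3)).astype(d.type.dtype)

# Whether or not the rewrite fires, the result must match the naive block-diagonal product.
x_val = np.block(
    [
        [a_val, np.zeros((4, 4), dtype=a_val.dtype)],
        [np.zeros((2, 2), dtype=b_val.dtype), b_val],
    ]
)
np.testing.assert_allclose(fn(a_val, b_val, d_val), x_val @ d_val, atol=1e-6, rtol=1e-6)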