18 changes: 14 additions & 4 deletions pytensor/tensor/rewriting/math.py
@@ -183,13 +183,23 @@ def local_block_diag_dot_to_dot_block_diag(fgraph, node):
        try:
            client_idx = client.inputs.index(blockdiag_result)
        except ValueError:
            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
            # blockdiag result.
            # If the blockdiag result is not an input to the dot, there is at least one Op between them.
            # We allow a left expand_dims (DimShuffle), which is introduced automatically by Blockwise
            # to equalize the number of batch dims, but does not change the semantics of the graph.
            for ancestor in client.inputs:
                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
                if (
                    ancestor.owner is not None
                    and (
                        isinstance(ancestor.owner.op, DimShuffle)
                        and ancestor.owner.op.is_left_expand_dims
                    )
                    and blockdiag_result in ancestor.owner.inputs
                ):
                    client_idx = client.inputs.index(ancestor)
                    break
            else:  # no-break
                # Not a simple left expand_dims between dot and block_diag
                return None

        other_input = client.inputs[1 - client_idx]

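For context, a minimal sketch (editor's illustration, not part of this diff) of the graph pattern the hunk above now handles: when the other operand of the dot is batched, Blockwise pads the block_diag output with a left expand_dims DimShuffle to equalize the number of batch dims, so BlockDiagonal is no longer a direct input of the dot node. Shapes below are illustrative.

import pytensor.tensor as pt

a = pt.tensor("a", shape=(4, 2))
b = pt.tensor("b", shape=(2, 4))
x = pt.linalg.block_diag(a, b)          # core shape (6, 6)
d = pt.tensor("d", shape=(3, 1, 6, 6))  # batched right operand

# x reaches the Blockwise dot through a DimShuffle that only prepends
# broadcastable batch dims (a "left expand_dims"); the rewrite now walks
# through that DimShuffle instead of bailing out.
out = x @ d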
224 changes: 126 additions & 98 deletions tests/tensor/rewriting/test_math.py
@@ -148,6 +148,7 @@
)
from pytensor.tensor.variable import TensorConstant
from tests import unittest_tools as utt
from tests.unittest_tools import assert_equal_computations


rewrite_mode = config.mode
@@ -4857,119 +4858,146 @@ def test_local_dot_to_mul_unspecified_length_1():
)


@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
)
@pytest.mark.parametrize(
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
)
def test_local_block_diag_dot_to_dot_block_diag(
left_multiply, batch_blockdiag, batch_other
):
"""
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
"""
class TestBlockDiagDotToDotBlockDiag:
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
)
@pytest.mark.parametrize(
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
)
def test_rewrite_applies(self, left_multiply, batch_blockdiag, batch_other):
"""
        Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
"""

def has_blockdiag(graph):
return any(
(
var.owner
and (
isinstance(var.owner.op, BlockDiagonal)
or (
isinstance(var.owner.op, Blockwise)
and isinstance(var.owner.op.core_op, BlockDiagonal)
def has_blockdiag(graph):
return any(
(
var.owner
and (
isinstance(var.owner.op, BlockDiagonal)
or (
isinstance(var.owner.op, Blockwise)
and isinstance(var.owner.op.core_op, BlockDiagonal)
)
)
)
for var in ancestors([graph])
)
for var in ancestors([graph])
)

a = tensor("a", shape=(4, 2))
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
c = tensor("c", shape=(4, 4))
x = pt.linalg.block_diag(a, b, c)
a = tensor("a", shape=(4, 2))
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
c = tensor("c", shape=(4, 4))
x = pt.linalg.block_diag(a, b, c)

d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))

# Test multiple clients are all rewritten
if left_multiply:
out = x @ d
else:
out = d @ x
# Test multiple clients are all rewritten
if left_multiply:
out = x @ d
else:
out = d @ x

assert has_blockdiag(out)
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
assert not has_blockdiag(fn.maker.fgraph.outputs[0])
assert has_blockdiag(out)
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
assert not has_blockdiag(fn.maker.fgraph.outputs[0])

n_dots_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn.maker.fgraph.apply_nodes
)
assert n_dots_rewrite == 3
n_dots_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Dot | Dot22)
)
for node in fn.maker.fgraph.apply_nodes
)
assert n_dots_rewrite == 3

fn_expected = pytensor.function(
[a, b, c, d],
out,
mode=Mode(linker="py", optimizer=None),
)
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
fn_expected = pytensor.function(
[a, b, c, d],
out,
mode=Mode(linker="py", optimizer=None),
)
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])

n_dots_no_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn_expected.maker.fgraph.apply_nodes
)
assert n_dots_no_rewrite == 1
n_dots_no_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (
isinstance(node.op, Blockwise)
and isinstance(node.op.core_op, Dot | Dot22)
)
for node in fn_expected.maker.fgraph.apply_nodes
)
assert n_dots_no_rewrite == 1

rng = np.random.default_rng()
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
rng = np.random.default_rng()
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

rewrite_out = fn(a_val, b_val, c_val, d_val)
expected_out = fn_expected(a_val, b_val, c_val, d_val)
np.testing.assert_allclose(
rewrite_out,
expected_out,
atol=1e-6 if config.floatX == "float32" else 1e-12,
rtol=1e-6 if config.floatX == "float32" else 1e-12,
)
rewrite_out = fn(a_val, b_val, c_val, d_val)
expected_out = fn_expected(a_val, b_val, c_val, d_val)
np.testing.assert_allclose(
rewrite_out,
expected_out,
atol=1e-6 if config.floatX == "float32" else 1e-12,
rtol=1e-6 if config.floatX == "float32" else 1e-12,
)
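A quick NumPy check (editor's sketch, not part of the test suite) of the identity the docstring of test_rewrite_applies relies on: multiplying by block_diag(x, y) is the same as multiplying each block by the matching row-slice of the other operand and concatenating the results.

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
y = rng.normal(size=(3, 5))
z = rng.normal(size=(2 + 5, 6))  # rows split as z[:2] for x and z[2:] for y

np.testing.assert_allclose(
    block_diag(x, y) @ z,
    np.concatenate([x @ z[:2], y @ z[2:]], axis=0),
)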

def test_rewrite_does_not_apply(self):
# Regression test for https://github.com/pymc-devs/pytensor/issues/1836

@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
rng = np.random.default_rng()
a_size = int(rng.uniform(1, int(0.8 * size)))
b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
c_size = size - a_size - b_size
        # Shapes match if either R is transposed or y is, but not by default
y = pt.tensor("y", shape=(7, 9))
R1 = pt.tensor("R1", shape=(2, 3))
R2 = pt.tensor("R2", shape=(5, 6))
R = pt.linalg.block_diag(R1, R2)

a = tensor("a", shape=(a_size, a_size))
b = tensor("b", shape=(b_size, b_size))
c = tensor("c", shape=(c_size, c_size))
d = tensor("d", shape=(size,))
        # This could be rewritten in the future; if that's the case, remove this condition
original = dot(R.mT, y)
rewritten = rewrite_graph(
original, include=("canonicalize", "stabilize", "specialize")
)
assert_equal_computations([rewritten], [original])

x = pt.linalg.block_diag(a, b, c)
out = x @ d
# This is unlikely to ever be rewritten
original = dot(R.exp(), y.mT)
rewritten = rewrite_graph(
original, include=("canonicalize", "stabilize", "specialize")
)
assert_equal_computations([rewritten], [original])
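An editor's aside on the "could be rewritten in the future" comment above, with a NumPy sketch (not part of the test): the transposed case is reachable in principle because the transpose of a block-diagonal matrix is the block diagonal of the transposed blocks; the current rewrite simply does not attempt it, which is what this test pins down.

import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
R1 = rng.normal(size=(2, 3))
R2 = rng.normal(size=(5, 6))

# block_diag(R1, R2).T equals block_diag(R1.T, R2.T), so a future rewrite
# could still split dot(R.mT, y) block by block.
np.testing.assert_allclose(block_diag(R1, R2).T, block_diag(R1.T, R2.T))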

@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
def test_benchmark(self, benchmark, size, rewrite):
rng = np.random.default_rng()
a_size = int(rng.uniform(1, int(0.8 * size)))
b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
c_size = size - a_size - b_size

a = tensor("a", shape=(a_size, a_size))
b = tensor("b", shape=(b_size, b_size))
c = tensor("c", shape=(c_size, c_size))
d = tensor("d", shape=(size,))

x = pt.linalg.block_diag(a, b, c)
out = x @ d

mode = get_default_mode()
if not rewrite:
mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
fn = pytensor.function([a, b, c, d], out, mode=mode)

a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

benchmark(
fn,
a_val,
b_val,
c_val,
d_val,
)
mode = get_default_mode()
if not rewrite:
mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
fn = pytensor.function([a, b, c, d], out, mode=mode)

a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)

benchmark(
fn,
a_val,
b_val,
c_val,
d_val,
)