Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@
)
from pytensor.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ExtractDiag,
Eye,
Join,
MakeVector,
ScalarFromTensor,
Expand All @@ -70,9 +73,11 @@
as_tensor_variable,
atleast_Nd,
cast,
diagonal,
fill,
get_scalar_constant_value,
join,
ones,
ones_like,
register_infer_shape,
switch,
Expand Down Expand Up @@ -1409,3 +1414,195 @@ def local_join_of_alloc(fgraph, node):
new_out = alloc(new_join, *post_join_shape)
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_alloc_diag(fgraph, node):
"""ExtractDiag(AllocDiag(x, offset=k), offset=k) -> x

When the offsets match and the axes are the standard trailing pair,
extracting the diagonal undoes the diagonal allocation.
"""
op = node.op
[inp] = node.inputs
match inp.owner_op_and_inputs:
case (AllocDiag(offset=offset, axis1=ad_axis1, axis2=ad_axis2), x):
if offset != op.offset:
return None
core_ndim = x.type.ndim
if ad_axis1 == core_ndim - 1 and ad_axis2 == core_ndim:
if op.axis1 == ad_axis1 and op.axis2 == ad_axis2:
return [x]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_eye(fgraph, node):
"""ExtractDiag(Eye(n, m, k), offset=k) -> ones(diag_length)

The diagonal of an eye matrix is a vector of ones.
"""
op = node.op
if op.axis1 != 0 or op.axis2 != 1:
return None

[inp] = node.inputs
match inp.owner_op_and_inputs:
case (Eye(), _, _, k):
if not isinstance(k, Constant):
return None
out = node.outputs[0]
diag_len = out.type.shape[-1]
if diag_len is None:
return None
if k.data.item() == op.offset:
return [ones(diag_len, dtype=out.dtype)]
else:
return [zeros(diag_len, dtype=out.dtype)]
case _:
return None


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_alloc(fgraph, node):
"""ExtractDiag(Alloc(scalar, ...)) -> Alloc(scalar, ..., diag_len)

When a constant scalar is broadcast into a matrix via Alloc, extracting the diagonal just broadcasts it into
the diagonal length.
"""
[inp] = node.inputs
match inp.owner_op_and_inputs:
case (Alloc(), value, *_shape_args):
pass
case _:
return None

if value.type.ndim != 0:
return None

out = node.outputs[0]
out_shape = out.type.shape
new_shape = []
for i, s in enumerate(out_shape):
new_shape.append(s if s is not None else out.shape[i])
return [alloc(value, *new_shape)]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_transpose(fgraph, node):
"""ExtractDiag(X.T, offset=k) -> ExtractDiag(X, offset=-k)

Eliminates a matrix transpose before diagonal extraction. Transpose is a view Op with no runtime cost, but this
rewrite eliminates a Transpose that can get in the way of other ExtractDiag rewrites firing.
"""
op = node.op
[inp] = node.inputs
ndim = inp.type.ndim
if op.axis1 != ndim - 2 or op.axis2 != ndim - 1:
return None

match inp.owner_op_and_inputs:
case (DimShuffle(is_matrix_transpose=True), inner):
return [diagonal(inner, offset=-op.offset, axis1=op.axis1, axis2=op.axis2)]
case _:
return None


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_eye_mul(fgraph, node):
"""ExtractDiag(eye * x) -> extract the non-zero diagonal values.

When a matrix is the elementwise product of an identity matrix and another tensor, extracting the diagonal
simplifies to extracting the relevant values from the non-eye factor.
"""
from pytensor.tensor.rewriting.linalg.utils import is_eye_mul

op = node.op
if op.offset != 0:
return None

[inp] = node.inputs
match is_eye_mul(inp):
case None:
return None
case (_, non_eye_term):
pass

axis1, axis2 = op.axis1, op.axis2
non_eye_bcast = non_eye_term.type.broadcastable
# Scalar-like in both matrix dims: broadcast to diagonal length
if len(non_eye_bcast) >= 2 and non_eye_bcast[axis1] and non_eye_bcast[axis2]:
out = node.outputs[0]
if any(s is None for s in out.type.shape):
return None
squeezed = non_eye_term.squeeze(axis=(axis1, axis2))
return [alloc(squeezed, *out.type.shape)]

# Full matrix in both dims: extract its diagonal
if (
len(non_eye_bcast) >= 2
and not non_eye_bcast[axis1]
and not non_eye_bcast[axis2]
):
return [diagonal(non_eye_term, offset=0, axis1=axis1, axis2=axis2)]

# Vector-like (broadcastable in exactly one of the two matrix dims):
# squeeze out the broadcastable axis to get the vector
if len(non_eye_bcast) >= 2:
if non_eye_bcast[axis1] and not non_eye_bcast[axis2]:
squeezed = non_eye_term.squeeze(axis=axis1)
return [squeezed]
if not non_eye_bcast[axis1] and non_eye_bcast[axis2]:
squeezed = non_eye_term.squeeze(axis=axis2)
return [squeezed]

return None


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_elemwise(fgraph, node):
"""ExtractDiag(Elemwise(f)(A, B, ...)) → Elemwise(f)(ExtractDiag(A), ...)

Pushes diagonal extraction inside elementwise ops so the full matrix intermediate is never materialized. Only
fires when the Elemwise output has no other clients.
"""
op = node.op
[inp] = node.inputs
if not (inp.owner and isinstance(inp.owner.op, Elemwise)):
return None

if len(fgraph.clients[inp]) > 1:
return None

axis1, axis2 = op.axis1, op.axis2
new_inputs = []
for elem_inp in inp.owner.inputs:
b1 = elem_inp.type.broadcastable[axis1]
b2 = elem_inp.type.broadcastable[axis2]
if not b1 and not b2:
new_inputs.append(
diagonal(elem_inp, offset=op.offset, axis1=axis1, axis2=axis2)
)
elif b1 and b2:
new_inputs.append(elem_inp.squeeze(axis=(axis1, axis2)))
else:
return None

return [inp.owner.op(*new_inputs)]
161 changes: 161 additions & 0 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
vector,
)
from tests import unittest_tools as utt
from tests.unittest_tools import assert_equal_computations


rewrite_mode = config.mode
Expand Down Expand Up @@ -2072,3 +2073,163 @@ def test_topological_fill_sink_broadcastable_change():
topological_fill_sink.rewrite(fg)
[new_out] = fg.outputs
assert equal_computations([new_out], [a + b])


class TestExtractDiagRewrites:
rewrite_kw = dict(include=("canonicalize", "stabilize", "specialize"))

def test_extract_diag_of_alloc_diag(self):
v = pt.vector("v", shape=(5,))
out = pt.diagonal(pt.diag(v))
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [v])

def test_extract_diag_of_alloc_diag_offset_match(self):
v = pt.vector("v", shape=(4,))
out = pt.diagonal(pt.diag(v, k=1), offset=1)
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [v])

def test_extract_diag_of_alloc_diag_offset_mismatch(self):
"""When offsets differ, the rewrite should not fire."""
v = pt.vector("v", shape=(4,))
out = pt.diagonal(pt.diag(v, k=1), offset=0)
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [out])

def test_extract_diag_of_eye_square(self):
out = pt.diagonal(pt.eye(5))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.float64(1.0), 5)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_rectangular(self):
out = pt.diagonal(pt.eye(3, 5))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.float64(1.0), 3)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_symbolic(self):
n = pt.iscalar("n")
out = pt.diagonal(pt.eye(n, n, 0))
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [out])

def test_extract_diag_of_eye_offset_mismatch(self):
out = pt.diagonal(pt.eye(5), offset=1)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.zeros(4, dtype=out.dtype)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_zeros(self):
out = pt.diagonal(pt.zeros((5, 5)))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.float64(0.0), 5)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_ones(self):
out = pt.diagonal(pt.ones((4, 6)))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.float64(1.0), 4)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_full(self):
out = pt.diagonal(pt.full((3, 3), 7.0))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.asarray(7.0, dtype=out.dtype), 3)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_alloc_non_scalar_no_rewrite(self):
"""Alloc of a non-scalar value should not be rewritten."""
v = pt.vector("v", shape=(5,))
out = pt.diagonal(pt.alloc(v, 5, 5))
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [out])

def test_extract_diag_of_eye_mul_matrix(self):
x = pt.matrix("x", shape=(5, 5))
out = pt.diagonal(pt.eye(5) * x)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_mul_scalar(self):
s = pt.scalar("s")
out = pt.diagonal(pt.eye(5) * s)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(s, 5)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_mul_row(self):
v = pt.row("v", shape=(1, 5))
out = pt.diagonal(pt.eye(5) * v)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = v.squeeze(axis=0)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_mul_col(self):
v = pt.col("v", shape=(5, 1))
out = pt.diagonal(pt.eye(5) * v)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = v.squeeze(axis=1)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_eye_mul_nonzero_offset(self):
"""With offset != 0, the eye diagonal is zeros, so the whole product is zeros."""
x = pt.matrix("x", shape=(5, 5))
out = pt.diagonal(pt.eye(5) * x, offset=1)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.alloc(np.zeros(1, dtype=out.dtype), np.int64(4))
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_transpose(self):
x = pt.matrix("x", shape=(4, 6))
out = pt.diagonal(x.T)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_transpose_offset(self):
x = pt.matrix("x", shape=(4, 6))
out = pt.diagonal(x.T, offset=2)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x, offset=-2)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_batched_transpose(self):
x = pt.tensor("x", shape=(3, 4, 5))
out = pt.diagonal(x.mT, axis1=-2, axis2=-1)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x, axis1=-2, axis2=-1)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_elemwise_unary(self):
x = pt.matrix("x", shape=(5, 4))
out = pt.diagonal(pt.exp(x))
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.exp(pt.diagonal(x))
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_elemwise_binary(self):
x = pt.matrix("x", shape=(5, 5))
y = pt.matrix("y", shape=(5, 5))
out = pt.diagonal(x + y)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x) + pt.diagonal(y)
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_elemwise_scalar_broadcast(self):
x = pt.matrix("x", shape=(5, 5))
s = pt.scalar("s")
out = pt.diagonal(x * s)
rewritten = rewrite_graph(out, **self.rewrite_kw)
expected = pt.diagonal(x) * s
assert_equal_computations([rewritten], [expected])

def test_extract_diag_of_elemwise_mixed_broadcast_no_rewrite(self):
"""Bail when an input is broadcastable in exactly one diagonal axis."""
x = pt.matrix("x", shape=(5, 5))
r = pt.row("r", shape=(1, 5))
out = pt.diagonal(x + r)
rewritten = rewrite_graph(out, **self.rewrite_kw)
assert_equal_computations([rewritten], [out])
Loading