Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pytensor/tensor/rewriting/linalg/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor
from pytensor.tensor.linalg.decomposition.qr import QR
from pytensor.tensor.linalg.decomposition.svd import SVD
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.summary import SLogDet, det
from pytensor.tensor.math import Prod, log, prod
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -192,6 +193,16 @@ def det_of_triangular(fgraph, node):
return [det_val]


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def det_of_inv(fgraph, node):
"""Replace det(matrix_inverse(X)) with reciprocal(det(X))."""
match node.inputs[0].owner_op_and_inputs:
case (Blockwise(MatrixInverse()), X):
return [1 / det(X)]


@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
Expand Down
91 changes: 90 additions & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
from pytensor.tensor.rewriting.blockwise import blockwise_of
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import Subtensor, _is_provably_positive
from pytensor.tensor.type import (
complex_dtypes,
uint_dtypes,
Expand Down Expand Up @@ -689,6 +689,95 @@ def local_exp_log_nan_switch(fgraph, node):
return [new_out]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_reciprocal(fgraph, node):
"""Rewrite log(reciprocal(x)) -> -log(x)."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should do the more general as well (reciprocal is fine): log(a/b), where a or b is a non-negative constant -> log(a) - log(b) (the constant constant-folded already).

(inp,) = node.inputs
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.Reciprocal)
):
return [neg(log(inp.owner.inputs[0]))]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_div(fgraph, node):
"""Rewrite log(a / b) -> log(a) - log(b) when a or b is provably positive.

The provably-positive side is typically a constant or a shape, which the
surrounding pipeline constant-folds.
"""
(inp,) = node.inputs
if not (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.TrueDiv)
):
return None

num, den = inp.owner.inputs
if _is_provably_positive(num, strict=True) or _is_provably_positive(
den, strict=True
):
return [log(num) - log(den)]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think complexity wise we only want to do this if one is a constant so we replace 1 log and 1 division by 1 log and one subtraction, not 2 logs and one subtraction



@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([sign])
def local_sign_reciprocal(fgraph, node):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would combine the reciprocal and div cases in these rewrites, conceptually the same

"""Rewrite sign(reciprocal(x)) -> sign(x)."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here sign a/b, where one is a positive constant -> sign of the other term. If the constant is negative, 1-sign of the other. If it's mixed, can't do anything

(inp,) = node.inputs
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.Reciprocal)
):
return [sign(inp.owner.inputs[0])]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([sign])
def local_sign_div(fgraph, node):
"""Rewrite sign(a / b) using a known-sign numerator or denominator.

Provably positive side -> ``sign(other)``; negative constant side ->
``-sign(other)``. Bails out otherwise.
"""
(inp,) = node.inputs
if not (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.TrueDiv)
):
return None

num, den = inp.owner.inputs

if _is_provably_positive(num, strict=True):
return [sign(den)]
if _is_provably_positive(den, strict=True):
return [sign(num)]

for side, other in ((num, den), (den, num)):
try:
val = get_underlying_scalar_constant_value(side)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't need to be scalar, I would just check for constant and then all

except NotScalarConstantError:
continue
if np.all(val < 0):
return [neg(sign(other))]


@register_canonicalize
@register_specialize
@node_rewriter([Sum])
Expand Down
35 changes: 35 additions & 0 deletions tests/tensor/rewriting/linalg/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,3 +459,38 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn):
expected = expected_fn(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


def test_det_of_inv():
x = pt.tensor("x", shape=(3, 3))
out = det(pt.linalg.inv(x))
expected = pt.as_tensor(1.0, dtype="float64") / det(x)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"])
assert_equal_computations([rewritten], [expected])


def test_slogdet_of_inv():
x = pt.dmatrix("x")
# slogdet(inv(x)) -> (sign, logabsdet)
sign_inv, logabsdet_inv = pt.linalg.slogdet(pt.linalg.inv(x))

# expected: (sign(det(x)), -logabsdet(det(x)))
# det(inv(x)) = 1/det(x), so sign is same.
# logabsdet(inv(x)) = log(abs(1/det(x))) = -log(abs(det(x)))
sign_x, logabsdet_x = pt.linalg.slogdet(x)
expected_sign = sign_x
expected_logabsdet = -logabsdet_x

# We need stabilize for det_of_inv and log_reciprocal
# and specialize for slogdet_specialization
rewritten_sign, rewritten_logabsdet = rewrite_graph(
[sign_inv, logabsdet_inv], include=["canonicalize", "stabilize", "specialize"]
)

expected_sign_opt, expected_logabsdet_opt = rewrite_graph(
[expected_sign, expected_logabsdet],
include=["canonicalize", "stabilize", "specialize"],
)

assert_equal_computations([rewritten_sign], [expected_sign_opt])
assert_equal_computations([rewritten_logabsdet], [expected_logabsdet_opt])
80 changes: 80 additions & 0 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -5116,3 +5116,83 @@ def test_rewrite_does_not_apply(self):
original, include=("canonicalize", "stabilize", "specialize")
)
assert_equal_computations([rewritten], [original])


def test_log_reciprocal():
x = pt.dscalar("x")
out = pt.log(pt.reciprocal(x))
expected = -pt.log(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


def test_sign_reciprocal():
x = pt.dscalar("x")
out = pt.sign(pt.reciprocal(x))
expected = pt.sign(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


@pytest.mark.parametrize(
"build, expected_fn",
[
(lambda x: pt.log(3.0 / x), lambda x: pt.log(3.0) - pt.log(x)),
(lambda x: pt.log(x / 3.0), lambda x: pt.log(x) - pt.log(3.0)),
(lambda x: pt.log(1.0 / x), lambda x: -pt.log(x)),
],
ids=["pos_const_num", "pos_const_den", "one_over_x"],
)
def test_log_div_positive_constant(build, expected_fn):
x = pt.dscalar("x")
rewritten = rewrite_graph(
build(x), include=["canonicalize", "stabilize", "specialize"]
)
expected = rewrite_graph(
expected_fn(x), include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([rewritten], [expected])


def test_log_div_non_constant_not_rewritten():
x = pt.dscalar("x")
y = pt.dscalar("y")
out = pt.log(x / y)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"])
# No constant to peel off — graph should still contain a true_div.
nodes = [v.owner for v in ancestors([rewritten]) if v.owner]
assert any(
isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes
)


@pytest.mark.parametrize(
"build, expected_fn",
[
(lambda x: pt.sign(3.0 / x), lambda x: pt.sign(x)),
(lambda x: pt.sign(-3.0 / x), lambda x: -pt.sign(x)),
(lambda x: pt.sign(x / 3.0), lambda x: pt.sign(x)),
(lambda x: pt.sign(x / -3.0), lambda x: -pt.sign(x)),
],
ids=["pos_num", "neg_num", "pos_den", "neg_den"],
)
def test_sign_div_constant(build, expected_fn):
x = pt.dscalar("x")
rewritten = rewrite_graph(
build(x), include=["canonicalize", "stabilize", "specialize"]
)
expected = rewrite_graph(
expected_fn(x), include=["canonicalize", "stabilize", "specialize"]
)
assert_equal_computations([rewritten], [expected])


def test_sign_div_non_constant_not_rewritten():
x = pt.dscalar("x")
y = pt.dscalar("y")
out = pt.sign(x / y)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize", "specialize"])
nodes = [v.owner for v in ancestors([rewritten]) if v.owner]
assert any(
isinstance(getattr(node.op, "scalar_op", None), ps.TrueDiv) for node in nodes
)
Loading