2 changes: 1 addition & 1 deletion pytensor/compile/builders.py
@@ -872,7 +872,7 @@ def inner_outputs(self):

def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone()
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res

def perform(self, node, inputs, outputs):
8 changes: 6 additions & 2 deletions pytensor/graph/fg.py
@@ -838,9 +838,13 @@ def check_integrity(self) -> None:
def __repr__(self):
return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"

def clone(self, check_integrity=True) -> "FunctionGraph":
def clone(
self, check_integrity=True, clone_inner_graphs: bool = False
) -> "FunctionGraph":
"""Clone the graph."""
return self.clone_get_equiv(check_integrity)[0]
return self.clone_get_equiv(
check_integrity, clone_inner_graphs=clone_inner_graphs
)[0]

def clone_get_equiv(
self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
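The new clone_inner_graphs flag is forwarded to clone_get_equiv, so ops that own an inner FunctionGraph (such as OpFromGraph and Scan) are cloned rather than shared between the original graph and its copy. Below is a minimal sketch of the intended behaviour, based only on the API shown in this diff; the variable names are illustrative:

import pytensor.tensor as pt
from pytensor import OpFromGraph
from pytensor.graph.fg import FunctionGraph

# A graph that contains an op with an inner graph.
x = pt.scalar("x")
ofg_op = OpFromGraph([x], [x + 1])
y = pt.scalar("y")
fg = FunctionGraph([y], [ofg_op(y)], clone=False)

# Default behaviour is unchanged: the cloned graph shares the same op
# instance, and therefore the same inner FunctionGraph.
shallow = fg.clone()
assert shallow.outputs[0].owner.op is ofg_op

# With clone_inner_graphs=True the op itself is cloned, so rewrites applied
# to the clone can no longer mutate the original op's inner graph.
deep = fg.clone(clone_inner_graphs=True)
assert deep.outputs[0].owner.op is not ofg_op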
92 changes: 2 additions & 90 deletions pytensor/scalar/basic.py
@@ -4300,14 +4300,8 @@ def fgraph(self):
self._fgraph = fgraph
return self._fgraph

def clone_float32(self):
# This will not modify the fgraph or the nodes
new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs)

def clone(self):
new_ins, new_outs = composite_f32.apply(self.fgraph)
return Composite(new_ins, new_outs)
return self.__class__(self.fgraph.inputs, self.fgraph.outputs)

def output_types(self, input_types):
if tuple(input_types) != self.inputs_type:
@@ -4423,86 +4417,4 @@ def c_code(self, node, nodename, inames, onames, sub):
return self.c_code_template % d

def c_code_cache_version_outer(self) -> tuple[int, ...]:
return (6,)


class Compositef32:
# This is a dict of scalar op classes that need special handling
special: dict = {}

def apply(self, fgraph):
mapping = {}
topo = fgraph.toposort()
for i in fgraph.inputs:
if i.dtype == "float16":
mapping[i] = get_scalar_type("float32")()
if hasattr(i.tag, "test_value"):
mapping[i].tag.test_value = i.tag.test_value
else:
mapping[i] = i
for node in topo:
# Patch up for constants
for i in node.inputs:
if i not in mapping:
assert type(i) is ScalarConstant
if i.type == float16:
ni = ScalarConstant(float32, i.data)
else:
ni = i
mapping[i] = ni
if isinstance(node.op, tuple(self.special)):
self.special[type(node.op)](node, mapping)
continue
new_node = node.clone_with_new_inputs(
[mapping[inp] for inp in node.inputs], strict=False
)
# make sure we don't produce any float16.
assert not any(o.dtype == "float16" for o in new_node.outputs)
mapping.update(zip(node.outputs, new_node.outputs, strict=True))

new_ins = [mapping[inp] for inp in fgraph.inputs]
new_outs = [mapping[out] for out in fgraph.outputs]
return new_ins, new_outs


composite_f32 = Compositef32()


def handle_cast(node, mapping):
inp = mapping[node.inputs[0]]
out = node.outputs[0]
node_ok = False
if node.op.o_type == float16:
if node.inputs[0].type == float32:
# cast f32 -> f16, remove
mapping[out] = inp
return
else:
# cast to f16, convert to f32
new_out = cast(inp, "float32")
# change the node for the following if
node = new_out.owner
mapping[out] = new_out
node_ok = True
if node.inputs[0].type == float16:
if node.op.o_type == inp.type:
# cast f16 to new input type, remove
mapping[out] = inp
return
if not node_ok:
new_node = node.clone_with_new_inputs([inp], strict=False)
mapping[out] = new_node.outputs[0]


Compositef32.special[Cast] = handle_cast


def handle_composite(node, mapping):
new_op = node.op.clone_float32()
new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
assert len(new_outs) == len(node.outputs)
for o, no in zip(node.outputs, new_outs, strict=True):
mapping[o] = no


Compositef32.special[Composite] = handle_composite
return (7,)
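With the Compositef32 machinery removed, Composite.clone no longer runs a float16-to-float32 conversion pass; it simply rebuilds an op of the same class from its own inner graph. The following is an illustrative sketch of that behaviour, assuming the scalar helpers already used in this PR's tests:

from pytensor.scalar.basic import Composite, add, floats, mul

x, y = floats("xy")
comp = Composite([x, y], [mul(add(x, y), x)])

# clone() returns an independent op of the same class; the inner graph is
# rebuilt from the original inputs/outputs rather than rewritten to float32.
comp_clone = comp.clone()
assert comp_clone is not comp
assert type(comp_clone) is type(comp)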
2 changes: 1 addition & 1 deletion pytensor/scalar/loop.py
@@ -123,7 +123,7 @@ def clone(self, name=None, **kwargs):
update, until = self.outputs, None
init = self.inputs[: len(update)]
constant = self.inputs[len(update) :]
return ScalarLoop(
return self.__class__(
init=init,
update=update,
constant=constant,
2 changes: 1 addition & 1 deletion pytensor/scan/op.py
@@ -1521,7 +1521,7 @@ def inner_outputs(self):

def clone(self) -> "Scan":
res = copy(self)
res.fgraph = res.fgraph.clone()
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res

def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
2 changes: 1 addition & 1 deletion pytensor/tensor/optimize.py
@@ -211,7 +211,7 @@ def inner_outputs(self):

def clone(self):
copy_op = copy(self)
copy_op.fgraph = self.fgraph.clone()
copy_op.fgraph = self.fgraph.clone(clone_inner_graphs=True)
return copy_op

def prepare_node(
36 changes: 35 additions & 1 deletion tests/link/numba/test_compile_ops.py
@@ -1,14 +1,17 @@
import numpy as np
import pytest

from pytensor import OpFromGraph, config, function, ifelse
from pytensor import Mode, OpFromGraph, config, function, ifelse, scan
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.scan.op import Scan
from pytensor.tensor import dmatrix, dtensor3, matrix
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.slinalg import Cholesky
from tests.link.numba.test_basic import compare_numba_and_py


@@ -197,3 +200,34 @@ def test_check_and_raise():
out = assert_op(x.sum(), np.array(True))

compare_numba_and_py([x], out, [x_test_value])


def test_ofg_with_inner_scan_rewrite():
    # Regression test: compiling the outer OFG used to mutate the inner Scan's graph
ys = pt.tensor("ys", shape=(5, 3, 3))
xs = scan(
lambda y: pt.linalg.cholesky(y),
sequences=[ys],
return_updates=False,
mode=Mode(optimizer=None),
)
xs_ofg = OpFromGraph([ys], [xs])(ys)
fn = function([ys], xs_ofg, mode="NUMBA")

# Check that we have a BlockwiseWithCoreShape in the inner Scan
fn_ofg_op = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg_op, OpFromGraph)
fn_scan_op = fn_ofg_op.fgraph.outputs[0].owner.op
assert isinstance(fn_scan_op, Scan)
fn_cholesky_op = fn_scan_op.fgraph.outputs[0].owner.op
assert isinstance(fn_cholesky_op, BlockwiseWithCoreShape)
assert isinstance(fn_cholesky_op.core_op, Cholesky)

# Check original Ops aren't modified
ofg_op = xs_ofg.owner.op
assert isinstance(ofg_op, OpFromGraph)
scan_op = ofg_op.fgraph.outputs[0].owner.op
assert isinstance(scan_op, Scan)
cholesky_op = scan_op.fgraph.outputs[0].owner.op
assert isinstance(cholesky_op, Blockwise)
assert isinstance(cholesky_op.core_op, Cholesky)
41 changes: 0 additions & 41 deletions tests/scalar/test_basic.py
@@ -23,7 +23,6 @@
arctan,
arctan2,
arctanh,
cast,
complex64,
constant,
cos,
@@ -33,7 +32,6 @@
exp,
exp2,
expm1,
float16,
float32,
floats,
int8,
@@ -53,11 +51,9 @@
sin,
sinh,
sqrt,
switch,
tan,
tanh,
true_div,
uint8,
)
from pytensor.tensor.type import fscalar, imatrix, matrix
from tests.link.test_link import make_function
@@ -72,43 +68,6 @@ def test_mul_add_true():


class TestComposite:
def test_composite_clone_float32(self):
def has_f16(comp):
if any(v.type == float16 for v in comp.fgraph.variables):
return True
return False

w = int8()
x = float16()
y = float32()
cz = Composite([x, y], [tanh(x + cast(y, "float16"))])
c = Composite(
[w, x, y],
[
cz(x, y)
- cz(x, y) ** 2
+ cast(x, "int16")
+ cast(x, "float32")
+ cast(w, "float16")
- constant(np.float16(1.0))
],
)
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)

v = uint8()
w = float16()
x = float16()
y = float16()
z = float16()

c = Composite([v, w, x, y, z], [switch(v, mul(w, x, y), z)])

assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)

def test_straightforward(self):
x, y, _z = floats("xyz")
e = mul(add(x, y), true_div(x, y))