diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py
index 055497a8ff..f924288a86 100644
--- a/pytensor/compile/builders.py
+++ b/pytensor/compile/builders.py
@@ -872,7 +872,7 @@ def inner_outputs(self):
 
     def clone(self):
         res = copy(self)
-        res.fgraph = res.fgraph.clone()
+        res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
         return res
 
     def perform(self, node, inputs, outputs):
diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py
index 2728d50a50..a0a71fab4f 100644
--- a/pytensor/graph/fg.py
+++ b/pytensor/graph/fg.py
@@ -838,9 +838,13 @@ def check_integrity(self) -> None:
     def __repr__(self):
         return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})"
 
-    def clone(self, check_integrity=True) -> "FunctionGraph":
+    def clone(
+        self, check_integrity=True, clone_inner_graphs: bool = False
+    ) -> "FunctionGraph":
         """Clone the graph."""
-        return self.clone_get_equiv(check_integrity)[0]
+        return self.clone_get_equiv(
+            check_integrity, clone_inner_graphs=clone_inner_graphs
+        )[0]
 
     def clone_get_equiv(
         self, check_integrity: bool = True, attach_feature: bool = True, **kwargs
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index af0b0b7173..fa717b8d65 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -4300,14 +4300,8 @@ def fgraph(self):
         self._fgraph = fgraph
         return self._fgraph
 
-    def clone_float32(self):
-        # This will not modify the fgraph or the nodes
-        new_ins, new_outs = composite_f32.apply(self.fgraph)
-        return Composite(new_ins, new_outs)
-
     def clone(self):
-        new_ins, new_outs = composite_f32.apply(self.fgraph)
-        return Composite(new_ins, new_outs)
+        return self.__class__(self.fgraph.inputs, self.fgraph.outputs)
 
     def output_types(self, input_types):
         if tuple(input_types) != self.inputs_type:
@@ -4423,86 +4417,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d
 
     def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (6,)
-
-
-class Compositef32:
-    # This is a dict of scalar op classes that need special handling
-    special: dict = {}
-
-    def apply(self, fgraph):
-        mapping = {}
-        topo = fgraph.toposort()
-        for i in fgraph.inputs:
-            if i.dtype == "float16":
-                mapping[i] = get_scalar_type("float32")()
-                if hasattr(i.tag, "test_value"):
-                    mapping[i].tag.test_value = i.tag.test_value
-            else:
-                mapping[i] = i
-        for node in topo:
-            # Patch up for constants
-            for i in node.inputs:
-                if i not in mapping:
-                    assert type(i) is ScalarConstant
-                    if i.type == float16:
-                        ni = ScalarConstant(float32, i.data)
-                    else:
-                        ni = i
-                    mapping[i] = ni
-            if isinstance(node.op, tuple(self.special)):
-                self.special[type(node.op)](node, mapping)
-                continue
-            new_node = node.clone_with_new_inputs(
-                [mapping[inp] for inp in node.inputs], strict=False
-            )
-            # make sure we don't produce any float16.
-            assert not any(o.dtype == "float16" for o in new_node.outputs)
-            mapping.update(zip(node.outputs, new_node.outputs, strict=True))
-
-        new_ins = [mapping[inp] for inp in fgraph.inputs]
-        new_outs = [mapping[out] for out in fgraph.outputs]
-        return new_ins, new_outs
-
-
-composite_f32 = Compositef32()
-
-
-def handle_cast(node, mapping):
-    inp = mapping[node.inputs[0]]
-    out = node.outputs[0]
-    node_ok = False
-    if node.op.o_type == float16:
-        if node.inputs[0].type == float32:
-            # cast f32 -> f16, remove
-            mapping[out] = inp
-            return
-        else:
-            # cast to f16, convert to f32
-            new_out = cast(inp, "float32")
-            # change the node for the following if
-            node = new_out.owner
-            mapping[out] = new_out
-            node_ok = True
-    if node.inputs[0].type == float16:
-        if node.op.o_type == inp.type:
-            # cast f16 to new input type, remove
-            mapping[out] = inp
-            return
-    if not node_ok:
-        new_node = node.clone_with_new_inputs([inp], strict=False)
-        mapping[out] = new_node.outputs[0]
-
-
-Compositef32.special[Cast] = handle_cast
-
-
-def handle_composite(node, mapping):
-    new_op = node.op.clone_float32()
-    new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
-    assert len(new_outs) == len(node.outputs)
-    for o, no in zip(node.outputs, new_outs, strict=True):
-        mapping[o] = no
-
-
-Compositef32.special[Composite] = handle_composite
+        return (7,)
diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py
index de2dfb4f30..888db3e52b 100644
--- a/pytensor/scalar/loop.py
+++ b/pytensor/scalar/loop.py
@@ -123,7 +123,7 @@ def clone(self, name=None, **kwargs):
             update, until = self.outputs, None
         init = self.inputs[: len(update)]
         constant = self.inputs[len(update) :]
-        return ScalarLoop(
+        return self.__class__(
             init=init,
             update=update,
             constant=constant,
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index 6efeddc8bb..cf6bebfb82 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1521,7 +1521,7 @@ def inner_outputs(self):
 
     def clone(self) -> "Scan":
         res = copy(self)
-        res.fgraph = res.fgraph.clone()
+        res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
         return res
 
     def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py
index 7653d01b54..3bfc8be09f 100644
--- a/pytensor/tensor/optimize.py
+++ b/pytensor/tensor/optimize.py
@@ -211,7 +211,7 @@ def inner_outputs(self):
 
     def clone(self):
         copy_op = copy(self)
-        copy_op.fgraph = self.fgraph.clone()
+        copy_op.fgraph = self.fgraph.clone(clone_inner_graphs=True)
         return copy_op
 
     def prepare_node(
diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py
index 5e263762ae..f6af836029 100644
--- a/tests/link/numba/test_compile_ops.py
+++ b/tests/link/numba/test_compile_ops.py
@@ -1,14 +1,17 @@
 import numpy as np
 import pytest
 
-from pytensor import OpFromGraph, config, function, ifelse
+from pytensor import Mode, OpFromGraph, config, function, ifelse, scan
 from pytensor import tensor as pt
 from pytensor.compile import ViewOp
 from pytensor.graph import vectorize_graph
 from pytensor.raise_op import assert_op
 from pytensor.scalar import Add
+from pytensor.scan.op import Scan
 from pytensor.tensor import dmatrix, dtensor3, matrix
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.slinalg import Cholesky
 from tests.link.numba.test_basic import compare_numba_and_py
 
 
@@ -197,3 +200,34 @@ def test_check_and_raise():
     out = assert_op(x.sum(), np.array(True))
 
     compare_numba_and_py([x], out, [x_test_value])
+
+
+def test_ofg_with_inner_scan_rewrite():
+    # Regression test where inner scan would be mutated when compiling outer OFG
+    ys = pt.tensor("ys", shape=(5, 3, 3))
+    xs = scan(
+        lambda y: pt.linalg.cholesky(y),
+        sequences=[ys],
+        return_updates=False,
+        mode=Mode(optimizer=None),
+    )
+    xs_ofg = OpFromGraph([ys], [xs])(ys)
+    fn = function([ys], xs_ofg, mode="NUMBA")
+
+    # Check that we have a BlockwiseWithCoreShape in the inner Scan
+    fn_ofg_op = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(fn_ofg_op, OpFromGraph)
+    fn_scan_op = fn_ofg_op.fgraph.outputs[0].owner.op
+    assert isinstance(fn_scan_op, Scan)
+    fn_cholesky_op = fn_scan_op.fgraph.outputs[0].owner.op
+    assert isinstance(fn_cholesky_op, BlockwiseWithCoreShape)
+    assert isinstance(fn_cholesky_op.core_op, Cholesky)
+
+    # Check original Ops aren't modified
+    ofg_op = xs_ofg.owner.op
+    assert isinstance(ofg_op, OpFromGraph)
+    scan_op = ofg_op.fgraph.outputs[0].owner.op
+    assert isinstance(scan_op, Scan)
+    cholesky_op = scan_op.fgraph.outputs[0].owner.op
+    assert isinstance(cholesky_op, Blockwise)
+    assert isinstance(cholesky_op.core_op, Cholesky)
diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py
index 0f4b0db607..5fadaa002a 100644
--- a/tests/scalar/test_basic.py
+++ b/tests/scalar/test_basic.py
@@ -23,7 +23,6 @@
     arctan,
     arctan2,
     arctanh,
-    cast,
     complex64,
     constant,
     cos,
@@ -33,7 +32,6 @@
     exp,
     exp2,
     expm1,
-    float16,
     float32,
     floats,
     int8,
@@ -53,11 +51,9 @@
     sin,
     sinh,
     sqrt,
-    switch,
     tan,
     tanh,
     true_div,
-    uint8,
 )
 from pytensor.tensor.type import fscalar, imatrix, matrix
 from tests.link.test_link import make_function
@@ -72,43 +68,6 @@ def test_mul_add_true():
 
 
 class TestComposite:
-    def test_composite_clone_float32(self):
-        def has_f16(comp):
-            if any(v.type == float16 for v in comp.fgraph.variables):
-                return True
-            return False
-
-        w = int8()
-        x = float16()
-        y = float32()
-        cz = Composite([x, y], [tanh(x + cast(y, "float16"))])
-        c = Composite(
-            [w, x, y],
-            [
-                cz(x, y)
-                - cz(x, y) ** 2
-                + cast(x, "int16")
-                + cast(x, "float32")
-                + cast(w, "float16")
-                - constant(np.float16(1.0))
-            ],
-        )
-        assert has_f16(c)
-        nc = c.clone_float32()
-        assert not has_f16(nc)
-
-        v = uint8()
-        w = float16()
-        x = float16()
-        y = float16()
-        z = float16()
-
-        c = Composite([v, w, x, y, z], [switch(v, mul(w, x, y), z)])
-
-        assert has_f16(c)
-        nc = c.clone_float32()
-        assert not has_f16(nc)
-
     def test_straightforward(self):
         x, y, _z = floats("xyz")
         e = mul(add(x, y), true_div(x, y))