def signature(self):
    """Build the value-based identity key for this ``Constant``.

    The key is the plain pair ``(self.type, self.data)``.  Whatever is
    returned from here is expected to honour three properties:

    * it can be hashed, i.e. ``hash(sig)`` does not raise;
    * comparing it with itself yields the boolean ``True`` rather than an
      array of element-wise results;
    * it survives pickling: the unpickled copy compares equal to the
      original and has the same hash (needed for C module cache keys).

    A bare tuple meets these requirements for simple Python payloads such
    as ``None`` or slices.  It does not for numpy payloads (NaN values,
    arrays), so subclasses that carry numeric data have to provide their
    own implementation; ``TensorConstantSignature`` and
    ``ScalarConstantSignature`` are the examples to follow.
    """
    const_type, const_data = self.type, self.data
    return const_type, const_data
""" if isinstance(inp, Constant): - a = np.asarray(inp.data) - if a.dtype.kind != "O": - return (inp.type, a.tobytes(), a.dtype.str, a.shape) return inp.signature() return id(inp) diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index cc053a18ee..e7958c20a3 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -979,9 +979,8 @@ def _resolve_input(inp, memo=memo): return inp raise ValueError( f"Orphan {inp} found in the graph. " - "All variables must be graph inputs, " - "Constants, or produced by Apply nodes " - "reachable from the inputs." + "All variables must be graph inputs, constants, " + "or produced by Apply nodes reachable from the inputs." ) for node in toposort(outputs, blockers=inputs): @@ -992,30 +991,36 @@ def _resolve_input(inp, memo=memo): memo.update(zip(node.outputs, new_node.outputs, strict=True)) - # Handle outputs that are Constants or AtomicVariables not - # encountered during toposort (e.g. a graph with no Apply nodes) - for o in outputs: - if o not in memo: - # TODO: We could create those dummy ApplyOutput here and get the interned constant - if isinstance(o, Constant): - memo[o] = o - - try: - frozen_outputs = tuple(memo[o] for o in outputs) - except KeyError: - # TODO: Can this ever happen if we didn't fail in the previous look? - unmapped = [o for o in outputs if o not in memo] - raise ValueError( - f"Output variable {unmapped[0]} could not be mapped to a frozen " - "graph variable. All outputs must be graph inputs, " - "constants, or produced by Apply nodes reachable from " - "the inputs." - ) + # Create dummy Output nodes for each output, mirroring FunctionGraph. + # (It also makes eq/hash cheaper) + output_nodes = [] + for i, o in enumerate(outputs): + try: + resolved = memo[o] + except KeyError: + if not isinstance(o, AtomicVariable): + raise ValueError( + f"Output variable {o} could not be mapped to a frozen graph variable. 
class ScalarConstantSignature(tuple):
    """Hashable ``(type, data)`` pair identifying a scalar constant by value.

    Behaves like a 2-tuple, except that two signatures of the same type whose
    data are both NaN compare equal and hash identically, so independently
    created NaN constants resolve to the same key.

    ``__eq__``, ``__ne__`` and ``__hash__`` are overridden together: ``tuple``
    ships its own ``__ne__``, and inheriting it would make two NaN signatures
    simultaneously ``==`` and ``!=``.
    """

    @staticmethod
    def _is_nan(value) -> bool:
        """Return ``True`` only when ``value`` is a NaN scalar.

        ``np.isnan`` raises ``TypeError`` for inputs it cannot coerce to a
        supported dtype (e.g. Python integers beyond the int64 range, or
        non-numeric objects); such values are by definition not NaN.
        """
        try:
            return bool(np.isnan(value))
        except TypeError:
            return False

    def __eq__(self, other):
        """Return a plain ``bool``: same class, equal type, and equal data (NaN == NaN)."""
        if type(self) is not type(other):
            return False
        (t0, d0), (t1, d1) = self, other
        if t0 != t1:
            return False
        if self._is_nan(d0):
            # NaN never equals itself under ``==``; treat NaN as equal to NaN.
            return self._is_nan(d1)
        # NOTE(review): ``0.0 == -0.0`` is True, so signed zeros share one
        # signature here -- confirm that merging them is intended.
        # ``bool`` turns a ``numpy.bool_`` result into a real ``bool``.
        return bool(d0 == d1)

    def __ne__(self, other):
        """Return the negation of ``__eq__`` instead of ``tuple``'s element-wise ``!=``."""
        return not (self == other)

    def __hash__(self):
        """Return a hash consistent with ``__eq__``: all NaN payloads share one bucket."""
        t, d = self
        if self._is_nan(d):
            return hash((type(self), t, "NaN"))
        return hash((type(self), t, d))
test_pickle_identity_output(self): def test_pickle_multi_output_shared_subexpr(self): """Pickle round-trip with multiple outputs sharing subexpressions.""" - from pytensor.scalar.basic import add, float64, mul - x, y = float64("x"), float64("y") shared = add(x, y) out1 = mul(shared, x) @@ -857,8 +851,6 @@ def test_pickle_multi_output_shared_subexpr(self): def test_pickle_hash_stability(self): """Hash is the same before and after pickle, and across independent constructions.""" - from pytensor.scalar.basic import add, float64, mul - x, y = float64("x"), float64("y") ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() h_before = hash(ffg) @@ -878,8 +870,6 @@ def test_different_arity_not_equal(self): assert ffg1 != ffg2 def test_orphan_non_constant_raises(self): - from pytensor.graph.fg import FrozenFunctionGraph - var1 = MyVariable("x") orphan = MyVariable("orphan") out = op1(var1, orphan) @@ -887,8 +877,6 @@ def test_orphan_non_constant_raises(self): FrozenFunctionGraph([var1], [out]) def test_unmapped_output_raises(self): - from pytensor.graph.fg import FrozenFunctionGraph - var1 = MyVariable("x") disconnected = MyVariable("disconnected") with pytest.raises(ValueError, match="could not be mapped"): @@ -903,8 +891,6 @@ def test_interned_constant_in_variables(self): cache hit overwrites memo[c2]=c1, c2 is evicted from variables while the cache-miss node still references it. 
def test_constant_output_equality(self):
    """Frozen graphs built from separate but equal-valued constant outputs match."""
    first_const = ScalarConstant(float64, 3.14)
    second_const = ScalarConstant(float64, 3.14)
    # The two constants must be distinct objects for the test to mean anything.
    assert first_const is not second_const

    frozen_first = FrozenFunctionGraph([], [first_const])
    frozen_second = FrozenFunctionGraph([], [second_const])
    assert frozen_first == frozen_second
    assert hash(frozen_first) == hash(frozen_second)
    assert frozen_first.outputs == frozen_second.outputs


def test_output_clients(self):
    """Each output variable has exactly one dummy ``Output`` node among its clients."""
    x, y = float64("x"), float64("y")
    frozen = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze()

    for position, out_var in enumerate(frozen.outputs):
        # Keep only the clients whose op is the dummy Output marker.
        dummy_clients = []
        for client_node, input_idx in frozen.clients[out_var]:
            if isinstance(client_node.op, Output):
                dummy_clients.append((client_node, input_idx))
        assert len(dummy_clients) == 1
        out_node, out_idx = dummy_clients[0]
        assert out_node.op.idx == position
        assert out_idx == 0