Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,19 @@ def __init__(self, type: _TypeType, data: Any, name: str | None = None):
add_tag_trace(self)

def signature(self):
    """Return a hashable, value-based identity token for this Constant.

    The object returned here must:

    1. be hashable — ``hash(sig)`` may not raise;
    2. compare equal to itself — ``sig == sig`` is ``True`` (which rules
       out raw numpy arrays);
    3. survive pickling — ``pickle.loads(pickle.dumps(sig))`` compares
       equal to, and hashes like, the original. This is required because
       signatures are used in C module cache keys.

    The default ``(type, data)`` pair satisfies this for simple Python
    payloads (``None``, slices, ...) but not for numpy payloads (NaN,
    ndarrays), so subclasses carrying numeric data override this method.
    See ``TensorConstantSignature`` and ``ScalarConstantSignature``.
    """
    sig = (self.type, self.data)
    return sig

def __str__(self):
Expand Down Expand Up @@ -812,7 +825,7 @@ class FrozenApply(Apply):
constructing a ``FrozenApply`` with the same op and input variables returns
the cached instance.

Constants are keyed by ``(type, data_bytes)`` so that two independently
Constants are keyed by their ``signature()`` so that two independently
created Constants with the same value resolve to the same cached node.
"""

Expand All @@ -828,15 +841,9 @@ def _input_to_key(inp: Variable):
cache keys that would prevent GC from collecting chains of
FrozenApply nodes in a single pass.

Constants use their byte representation so that independently-created
equal constants (including NaN) produce the same key. Object-dtype
constants (e.g. slices) fall back to ``signature()`` since their byte
representation stores pointers, not values.
Constants use their ``signature()`` for value-based deduplication.
"""
if isinstance(inp, Constant):
a = np.asarray(inp.data)
if a.dtype.kind != "O":
return (inp.type, a.tobytes(), a.dtype.str, a.shape)
return inp.signature()
return id(inp)

Expand Down
55 changes: 31 additions & 24 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,9 +979,8 @@ def _resolve_input(inp, memo=memo):
return inp
raise ValueError(
f"Orphan {inp} found in the graph. "
"All variables must be graph inputs, "
"Constants, or produced by Apply nodes "
"reachable from the inputs."
"All variables must be graph inputs, constants, "
"or produced by Apply nodes reachable from the inputs."
)

for node in toposort(outputs, blockers=inputs):
Expand All @@ -992,38 +991,44 @@ def _resolve_input(inp, memo=memo):

memo.update(zip(node.outputs, new_node.outputs, strict=True))

# Handle outputs that are Constants or AtomicVariables not
# encountered during toposort (e.g. a graph with no Apply nodes)
for o in outputs:
if o not in memo:
# TODO: We could create those dummy ApplyOutput here and get the interned constant
if isinstance(o, Constant):
memo[o] = o

try:
frozen_outputs = tuple(memo[o] for o in outputs)
except KeyError:
# TODO: Can this ever happen if we didn't fail in the previous loop?
unmapped = [o for o in outputs if o not in memo]
raise ValueError(
f"Output variable {unmapped[0]} could not be mapped to a frozen "
"graph variable. All outputs must be graph inputs, "
"constants, or produced by Apply nodes reachable from "
"the inputs."
)
# Create dummy Output nodes for each output, mirroring FunctionGraph.
# (It also makes eq/hash cheaper)
output_nodes = []
for i, o in enumerate(outputs):
try:
resolved = memo[o]
except KeyError:
if not isinstance(o, AtomicVariable):
raise ValueError(
f"Output variable {o} could not be mapped to a frozen graph variable. "
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really possible for a user to reach here? I ask because the error isn't really actionable, so it's another pytensor-special cryptic error if you do end up here somewhere.

Copy link
Copy Markdown
Member Author

@ricardoV94 ricardoV94 Apr 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had this doubt myself before. I don't think so, it would trigger in the first loop. Wanna take it out?

One edge case is a NominalVariable output not specified as input. That should fail, which means this branch should only really accept Constants (not general AtomicVariables)

That may be findable only here but would need to check

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make it an assert if we just want it for safety, but I don't think it should ever be hit (keeping the message).

"All outputs must be graph inputs, constants, "
"or produced by Apply nodes reachable from the inputs."
)
# A constant or graph input passed directly as output;
# these are the only cases not already in memo.
out_node = FrozenApply(Output(i), (o,), ())
output_nodes.append(out_node)
# FrozenApply interning may return a cached node that holds a previously seen equal constant.
# Store the canonical constant in memo for vars_between/clients.
memo[o] = out_node.inputs[0]
else:
output_nodes.append(FrozenApply(Output(i), (resolved,), ()))

self.inputs: tuple[Variable, ...] = nominal_inputs
self.outputs: tuple[Variable, ...] = frozen_outputs
self.outputs: tuple[Variable, ...] = tuple(
node.inputs[0] for node in output_nodes
)
self.apply_nodes: frozenset[Apply] = frozenset(sorted_apply_nodes)
self._toposort: tuple[Apply, ...] = tuple(sorted_apply_nodes)
self._output_nodes: tuple[Apply, ...] = tuple(output_nodes)
self._variables: frozenset[Variable] | None = None
self._clients: dict[Variable, list[ClientType]] | None = None

def __reduce__(self):
return FrozenFunctionGraph, (self.inputs, self.outputs)

def __hash__(self):
return hash(self.outputs)
return hash(self._output_nodes)

def __eq__(self, other):
if self is other:
Expand Down Expand Up @@ -1057,6 +1062,8 @@ def clients(self) -> dict[Variable, list[ClientType]]: # type: ignore[override]
for node in self.toposort():
for i, inp in enumerate(node.inputs):
clients[inp].append((node, i))
for out_node in self._output_nodes:
clients[out_node.inputs[0]].append((out_node, 0))
self._clients = clients
return self._clients

Expand Down
21 changes: 21 additions & 0 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,10 +958,31 @@ class ScalarVariable(_scalar_py_operators, Variable):
ScalarType.variable_type = ScalarVariable


class ScalarConstantSignature(tuple):
    """Hashable ``(type, data)`` signature for ScalarConstant.

    Overrides tuple equality and hashing so that NaN-valued constants
    compare equal to each other (under IEEE 754, ``nan != nan``, which
    would otherwise give two equal constants distinct cache keys) while
    keeping ``hash`` consistent with ``__eq__``.
    """

    def __eq__(self, other):
        # Only another ScalarConstantSignature can compare equal; a plain
        # tuple with the same contents must not.
        if type(self) is not type(other):
            return False
        (t0, d0), (t1, d1) = self, other
        if t0 != t1:
            return False
        # Treat two NaNs as equal explicitly. bool(...) keeps the result a
        # plain bool even when the data are numpy scalars (np.bool_).
        return bool(d0 == d1) or bool(np.isnan(d0) and np.isnan(d1))

    def __ne__(self, other):
        # tuple defines __ne__, so without this override ``!=`` would fall
        # back to element-wise tuple comparison and disagree with __eq__
        # when the data is NaN.
        return not self == other

    def __hash__(self):
        t, d = self
        # All NaNs hash alike so that hash stays consistent with __eq__.
        if np.isnan(d):
            return hash((type(self), t, "NaN"))
        return hash((type(self), t, d))


class ScalarConstant(ScalarVariable, Constant):
    """A graph Constant carrying a scalar value.

    Combines the operator interface of ``ScalarVariable`` with the
    value-holding behavior of ``Constant``, and provides a NaN-aware
    ``signature()``.
    """

    def __init__(self, *args, **kwargs):
        # Call Constant.__init__ explicitly rather than relying on the MRO,
        # so the Constant initializer runs regardless of ScalarVariable's
        # own __init__ — NOTE(review): presumably ScalarVariable.__init__
        # would otherwise win the MRO lookup; confirm against the bases.
        Constant.__init__(self, *args, **kwargs)

    def signature(self):
        # The default Constant.signature() — a plain (type, data) tuple —
        # breaks for NaN data because NaN != NaN; wrap in the NaN-aware
        # ScalarConstantSignature instead.
        return ScalarConstantSignature((self.type, self.data))


# Register ScalarConstant as the type of Constant corresponding to ScalarType
ScalarType.constant_type = ScalarConstant
Expand Down
45 changes: 29 additions & 16 deletions tests/graph/test_fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from pytensor.configdefaults import config
from pytensor.graph.basic import NominalVariable
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph, Output
from pytensor.graph.utils import MissingInputError
from pytensor.printing import debugprint
from pytensor.scalar.basic import add, float64, mul
from pytensor.scalar.basic import ScalarConstant, add, float64, mul
from tests.graph.utils import (
MyConstant,
MyOp,
Expand Down Expand Up @@ -810,8 +810,6 @@ def test_cross_graph_output_identity(self):
assert all(a is b for a, b in zip(ffg1.outputs, ffg2.outputs))

def test_pickle_round_trip(self):
from pytensor.scalar.basic import add, float64, mul

x, y = float64("x"), float64("y")
ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze()

Expand All @@ -822,8 +820,6 @@ def test_pickle_round_trip(self):
assert all(o1 is o2 for o1, o2 in zip(ffg.outputs, ffg2.outputs))

def test_pickle_with_constants(self):
from pytensor.scalar.basic import ScalarConstant, add, float64

x = float64("x")
c = ScalarConstant(float64, 3.14)
ffg = FunctionGraph([x], [add(x, c)]).freeze()
Expand All @@ -843,8 +839,6 @@ def test_pickle_identity_output(self):

def test_pickle_multi_output_shared_subexpr(self):
"""Pickle round-trip with multiple outputs sharing subexpressions."""
from pytensor.scalar.basic import add, float64, mul

x, y = float64("x"), float64("y")
shared = add(x, y)
out1 = mul(shared, x)
Expand All @@ -857,8 +851,6 @@ def test_pickle_multi_output_shared_subexpr(self):

def test_pickle_hash_stability(self):
"""Hash is the same before and after pickle, and across independent constructions."""
from pytensor.scalar.basic import add, float64, mul

x, y = float64("x"), float64("y")
ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze()
h_before = hash(ffg)
Expand All @@ -878,17 +870,13 @@ def test_different_arity_not_equal(self):
assert ffg1 != ffg2

def test_orphan_non_constant_raises(self):
from pytensor.graph.fg import FrozenFunctionGraph

var1 = MyVariable("x")
orphan = MyVariable("orphan")
out = op1(var1, orphan)
with pytest.raises(ValueError, match=r"Orphan.*orphan"):
FrozenFunctionGraph([var1], [out])

def test_unmapped_output_raises(self):
from pytensor.graph.fg import FrozenFunctionGraph

var1 = MyVariable("x")
disconnected = MyVariable("disconnected")
with pytest.raises(ValueError, match="could not be mapped"):
Expand All @@ -903,8 +891,6 @@ def test_interned_constant_in_variables(self):
cache hit overwrites memo[c2]=c1, c2 is evicted from variables
while the cache-miss node still references it.
"""
from pytensor.graph.fg import FrozenFunctionGraph
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider adding the ruff rule that all imports need to be at the top of the file; I'm embarrassed I checked this in.


op_shared = MyOp("shared")
op_unique = MyOp("unique")

Expand All @@ -925,6 +911,33 @@ def test_interned_constant_in_variables(self):
for inp in node.inputs:
assert inp in fg.variables

def test_constant_output_equality(self):
"""FFGs with distinct but equal constant outputs should be equal."""
c1 = ScalarConstant(float64, 3.14)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other circle constants you know: https://www.youtube.com/watch?v=dUGFb_HgG1c

c2 = ScalarConstant(float64, 3.14)
assert c1 is not c2

ffg1 = FrozenFunctionGraph([], [c1])
ffg2 = FrozenFunctionGraph([], [c2])
assert ffg1 == ffg2
assert hash(ffg1) == hash(ffg2)
assert ffg1.outputs == ffg2.outputs

def test_output_clients(self):
    """Each output variable has exactly one dummy Output node as a client."""
    a = float64("x")
    b = float64("y")
    frozen = FunctionGraph([a, b], [mul(add(a, b), b)]).freeze()

    for position, out_var in enumerate(frozen.outputs):
        # Collect only the clients whose op is a dummy Output node.
        dummy_clients = []
        for client_node, client_idx in frozen.clients[out_var]:
            if isinstance(client_node.op, Output):
                dummy_clients.append((client_node, client_idx))
        assert len(dummy_clients) == 1
        only_node, only_idx = dummy_clients[0]
        assert only_node.op.idx == position
        assert only_idx == 0

def test_freeze_unfreeze_round_trip(self):
x, y = float64("x"), float64("y")
ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze()
Expand Down
Loading