-
Notifications
You must be signed in to change notification settings - Fork 182
Cleanup frozen FunctionGraph #2043
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: v3
Are you sure you want to change the base?
Changes from all commits
d53917c
054bb36
0210b63
f789d33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -958,10 +958,31 @@ class ScalarVariable(_scalar_py_operators, Variable): | |
| ScalarType.variable_type = ScalarVariable | ||
|
|
||
|
|
||
| class ScalarConstantSignature(tuple): | ||
| """Signature for ScalarConstant that handles NaN equality and hashing.""" | ||
|
|
||
| def __eq__(self, other): | ||
| if type(self) is not type(other): | ||
| return False | ||
| (t0, d0), (t1, d1) = self, other | ||
| if t0 != t1: | ||
| return False | ||
| return (d0 == d1) or (np.isnan(d0) and np.isnan(d1)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there any other corner cases with np.inf? I don't think so — just checking. |
||
|
|
||
| def __hash__(self): | ||
| t, d = self | ||
| if np.isnan(d): | ||
| return hash((type(self), t, "NaN")) | ||
| return hash((type(self), t, d)) | ||
|
|
||
|
|
||
| class ScalarConstant(ScalarVariable, Constant): | ||
| def __init__(self, *args, **kwargs): | ||
| Constant.__init__(self, *args, **kwargs) | ||
|
|
||
| def signature(self): | ||
| return ScalarConstantSignature((self.type, self.data)) | ||
|
|
||
|
|
||
| # Register ScalarConstant as the type of Constant corresponding to ScalarType | ||
| ScalarType.constant_type = ScalarConstant | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,10 +5,10 @@ | |
|
|
||
| from pytensor.configdefaults import config | ||
| from pytensor.graph.basic import NominalVariable | ||
| from pytensor.graph.fg import FunctionGraph, Output | ||
| from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph, Output | ||
| from pytensor.graph.utils import MissingInputError | ||
| from pytensor.printing import debugprint | ||
| from pytensor.scalar.basic import add, float64, mul | ||
| from pytensor.scalar.basic import ScalarConstant, add, float64, mul | ||
| from tests.graph.utils import ( | ||
| MyConstant, | ||
| MyOp, | ||
|
|
@@ -810,8 +810,6 @@ def test_cross_graph_output_identity(self): | |
| assert all(a is b for a, b in zip(ffg1.outputs, ffg2.outputs)) | ||
|
|
||
| def test_pickle_round_trip(self): | ||
| from pytensor.scalar.basic import add, float64, mul | ||
|
|
||
| x, y = float64("x"), float64("y") | ||
| ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() | ||
|
|
||
|
|
@@ -822,8 +820,6 @@ def test_pickle_round_trip(self): | |
| assert all(o1 is o2 for o1, o2 in zip(ffg.outputs, ffg2.outputs)) | ||
|
|
||
| def test_pickle_with_constants(self): | ||
| from pytensor.scalar.basic import ScalarConstant, add, float64 | ||
|
|
||
| x = float64("x") | ||
| c = ScalarConstant(float64, 3.14) | ||
| ffg = FunctionGraph([x], [add(x, c)]).freeze() | ||
|
|
@@ -843,8 +839,6 @@ def test_pickle_identity_output(self): | |
|
|
||
| def test_pickle_multi_output_shared_subexpr(self): | ||
| """Pickle round-trip with multiple outputs sharing subexpressions.""" | ||
| from pytensor.scalar.basic import add, float64, mul | ||
|
|
||
| x, y = float64("x"), float64("y") | ||
| shared = add(x, y) | ||
| out1 = mul(shared, x) | ||
|
|
@@ -857,8 +851,6 @@ def test_pickle_multi_output_shared_subexpr(self): | |
|
|
||
| def test_pickle_hash_stability(self): | ||
| """Hash is the same before and after pickle, and across independent constructions.""" | ||
| from pytensor.scalar.basic import add, float64, mul | ||
|
|
||
| x, y = float64("x"), float64("y") | ||
| ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() | ||
| h_before = hash(ffg) | ||
|
|
@@ -878,17 +870,13 @@ def test_different_arity_not_equal(self): | |
| assert ffg1 != ffg2 | ||
|
|
||
| def test_orphan_non_constant_raises(self): | ||
| from pytensor.graph.fg import FrozenFunctionGraph | ||
|
|
||
| var1 = MyVariable("x") | ||
| orphan = MyVariable("orphan") | ||
| out = op1(var1, orphan) | ||
| with pytest.raises(ValueError, match=r"Orphan.*orphan"): | ||
| FrozenFunctionGraph([var1], [out]) | ||
|
|
||
| def test_unmapped_output_raises(self): | ||
| from pytensor.graph.fg import FrozenFunctionGraph | ||
|
|
||
| var1 = MyVariable("x") | ||
| disconnected = MyVariable("disconnected") | ||
| with pytest.raises(ValueError, match="could not be mapped"): | ||
|
|
@@ -903,8 +891,6 @@ def test_interned_constant_in_variables(self): | |
| cache hit overwrites memo[c2]=c1, c2 is evicted from variables | ||
| while the cache-miss node still references it. | ||
| """ | ||
| from pytensor.graph.fg import FrozenFunctionGraph | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should consider adding the ruff rule requiring all imports to be at the top of the file; I'm embarrassed I checked this in. |
||
|
|
||
| op_shared = MyOp("shared") | ||
| op_unique = MyOp("unique") | ||
|
|
||
|
|
@@ -925,6 +911,33 @@ def test_interned_constant_in_variables(self): | |
| for inp in node.inputs: | ||
| assert inp in fg.variables | ||
|
|
||
| def test_constant_output_equality(self): | ||
| """FFGs with distinct but equal constant outputs should be equal.""" | ||
| c1 = ScalarConstant(float64, 3.14) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are other circle constants, you know: https://www.youtube.com/watch?v=dUGFb_HgG1c |
||
| c2 = ScalarConstant(float64, 3.14) | ||
| assert c1 is not c2 | ||
|
|
||
| ffg1 = FrozenFunctionGraph([], [c1]) | ||
| ffg2 = FrozenFunctionGraph([], [c2]) | ||
| assert ffg1 == ffg2 | ||
| assert hash(ffg1) == hash(ffg2) | ||
| assert ffg1.outputs == ffg2.outputs | ||
|
|
||
| def test_output_clients(self): | ||
| """Output variables should have dummy Output node clients.""" | ||
| x, y = float64("x"), float64("y") | ||
| ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() | ||
|
|
||
| for i, out in enumerate(ffg.outputs): | ||
| out_clients = ffg.clients[out] | ||
| output_clients = [ | ||
| (node, idx) for node, idx in out_clients if isinstance(node.op, Output) | ||
| ] | ||
| assert len(output_clients) == 1 | ||
| node, idx = output_clients[0] | ||
| assert node.op.idx == i | ||
| assert idx == 0 | ||
|
|
||
| def test_freeze_unfreeze_round_trip(self): | ||
| x, y = float64("x"), float64("y") | ||
| ffg = FunctionGraph([x, y], [mul(add(x, y), y)]).freeze() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it really possible for a user to reach here? I ask because the error isn't really actionable, so it's another pytensor-special cryptic error if you do end up here somewhere.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had this doubt myself before. I don't think so — it would trigger in the first loop. Want to take it out?
One edge case is a NominalVariable output that is not specified as an input. That should fail, which means this branch should only really accept Constants (not general AtomicVariables).
That may be findable only here, but we would need to check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can make it an assert if we just want it for safety, but I don't think it should ever be hit (keeping the message).