Skip to content

Commit bad445e

Browse files
Merge pull request #2945 from devitocodes/investigate-hanging
compiler: Fixup DefFunction reconstruction
2 parents cf29367 + d33ee66 commit bad445e

2 files changed

Lines changed: 38 additions & 1 deletion

File tree

devito/symbolics/extended_sympy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sympy
1010
from sympy import Expr, Function, Number, Tuple, cacheit, sympify
1111
from sympy.core.decorators import call_highest_priority
12+
from sympy.core.function import Application
1213
from sympy.logic.boolalg import BooleanFunction
1314

1415
from devito.finite_differences.elementary import Max, Min
@@ -718,7 +719,13 @@ def __new__(cls, name, arguments=None, template=None, **kwargs):
718719
if _template:
719720
args.append(Tuple(*_template))
720721

721-
obj = Function.__new__(cls, *args)
722+
# `Function.__new__` and `Application.__new__` are both cached by
723+
# SymPy. DefFunction subclasses may attach reconstruction kwargs as
724+
# side attributes after this base constructor returns; going through
725+
# the cached route could then alias a previous object and mutate it
726+
# during reconstruction. Call Application's uncached constructor
727+
# explicitly instead of using super()/Function.__new__.
728+
obj = Application.__new__.__wrapped__(cls, *args)
722729
obj._name = name
723730
obj._arguments = tuple(_arguments)
724731
obj._template = tuple(_template)

tests/test_symbolics.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,36 @@ def __new__(cls, name=None, arguments=None, p0=None, p1=None, p2=None):
934934
assert func1.p1 == (g,)
935935
assert func1.p2 == 'bar'
936936

937+
def test_custom_def_function_reconstruction_no_aliasing(self):
938+
939+
class MyDefFunction(DefFunction):
940+
__rargs__ = ('name', 'arguments')
941+
__rkwargs__ = ('p0',)
942+
943+
def __new__(cls, name=None, arguments=None, p0=None):
944+
obj = super().__new__(cls, name=name, arguments=arguments)
945+
obj.p0 = p0
946+
return obj
947+
948+
def _hashable_content(self):
949+
return super()._hashable_content() + (self.p0,)
950+
951+
grid = Grid(shape=(4, 4))
952+
953+
f = Function(name='f', grid=grid)
954+
g = Function(name='g', grid=grid)
955+
956+
func0 = MyDefFunction(name='foo', arguments=f.indexify(), p0=f)
957+
h0 = hash(func0)
958+
959+
func1 = func0.func(p0=g)
960+
961+
assert func1 is not func0
962+
assert func1 != func0
963+
assert hash(func0) == h0
964+
assert func0.p0 is f
965+
assert func1.p0 is g
966+
937967
def test_reduce_to_number(self):
938968
grid = Grid(shape=(4, 4))
939969
x, _ = grid.dimensions

0 commit comments

Comments
 (0)