Skip to content

Commit 9314073

Browse files
majosminducer
authored andcommitted
add any/all/array_equal to PytatoPyOpenCLArrayContext.np
1 parent e0e1a42 commit 9314073

2 files changed

Lines changed: 32 additions & 2 deletions

File tree

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,11 @@
2626
from arraycontext.fake_numpy import (
2727
BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace,
2828
)
29+
from arraycontext.container import is_array_container
2930
from arraycontext.container.traversal import (
30-
rec_multimap_array_container, rec_map_array_container,
31+
rec_map_array_container,
32+
rec_multimap_array_container,
33+
multimap_reduce_array_container,
3134
rec_map_reduce_array_container,
3235
)
3336
import pytato as pt
@@ -158,4 +161,31 @@ def _rec_ravel(a):
158161

159162
return rec_map_array_container(_rec_ravel, a)
160163

164+
def any(self, a):
165+
return rec_map_reduce_array_container(
166+
partial(reduce, pt.logical_or),
167+
lambda subary: pt.any(subary), a)
168+
169+
def all(self, a):
170+
return rec_map_reduce_array_container(
171+
partial(reduce, pt.logical_and),
172+
lambda subary: pt.all(subary), a)
173+
174+
def array_equal(self, a, b):
175+
def as_device_scalar(bool_value):
176+
import numpy as np
177+
return self._array_context.from_numpy(
178+
np.array(int(bool_value), dtype=np.int8))
179+
180+
if type(a) != type(b):
181+
return as_device_scalar(False)
182+
elif not is_array_container(a):
183+
if a.shape != b.shape:
184+
return as_device_scalar(False)
185+
else:
186+
return pt.all(pt.equal(a, b))
187+
else:
188+
return multimap_reduce_array_container(
189+
partial(reduce, pt.logical_and), self.array_equal, a, b)
190+
161191
# }}}

test/test_arraycontext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def test_array_equal_same_as_numpy(actx_factory):
578578
lambda _np, *_args: getattr(_np, sym_name)(*_args), [ary, ary_diff_shape])
579579

580580
# Different types
581-
assert not actx.np.array_equal(ary, ary_diff_type)
581+
assert not actx.to_numpy(actx.np.array_equal(ary, ary_diff_type))
582582

583583

584584
# }}}

0 commit comments

Comments
 (0)