|
26 | 26 | from arraycontext.fake_numpy import ( |
27 | 27 | BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, |
28 | 28 | ) |
| 29 | +from arraycontext.container import is_array_container |
29 | 30 | from arraycontext.container.traversal import ( |
30 | | - rec_multimap_array_container, rec_map_array_container, |
| 31 | + rec_map_array_container, |
| 32 | + rec_multimap_array_container, |
| 33 | + multimap_reduce_array_container, |
31 | 34 | rec_map_reduce_array_container, |
32 | 35 | ) |
33 | 36 | import pytato as pt |
@@ -158,4 +161,31 @@ def _rec_ravel(a): |
158 | 161 |
|
159 | 162 | return rec_map_array_container(_rec_ravel, a) |
160 | 163 |
|
| 164 | + def any(self, a): |
| 165 | + return rec_map_reduce_array_container( |
| 166 | + partial(reduce, pt.logical_or), |
| 167 | + lambda subary: pt.any(subary), a) |
| 168 | + |
| 169 | + def all(self, a): |
| 170 | + return rec_map_reduce_array_container( |
| 171 | + partial(reduce, pt.logical_and), |
| 172 | + lambda subary: pt.all(subary), a) |
| 173 | + |
| 174 | + def array_equal(self, a, b): |
| 175 | + def as_device_scalar(bool_value): |
| 176 | + import numpy as np |
| 177 | + return self._array_context.from_numpy( |
| 178 | + np.array(int(bool_value), dtype=np.int8)) |
| 179 | + |
| 180 | + if type(a) != type(b): |
| 181 | + return as_device_scalar(False) |
| 182 | + elif not is_array_container(a): |
| 183 | + if a.shape != b.shape: |
| 184 | + return as_device_scalar(False) |
| 185 | + else: |
| 186 | + return pt.all(pt.equal(a, b)) |
| 187 | + else: |
| 188 | + return multimap_reduce_array_container( |
| 189 | + partial(reduce, pt.logical_and), self.array_equal, a, b) |
| 190 | + |
161 | 191 | # }}} |
0 commit comments