Skip to content

Commit eec0495

Browse files
authored
Merge pull request #2595 from IntelPython/reuse-tensor-isin
Implement `dpnp.isin`
2 parents 4a0d4e1 + 155910a commit eec0495

File tree

7 files changed

+222
-1
lines changed

7 files changed

+222
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
2525
* Added implementation of `dpnp.ndarray.__format__` method [#2662](https://github.com/IntelPython/dpnp/pull/2662)
2626
* Added implementation of `dpnp.ndarray.__bytes__` method [#2671](https://github.com/IntelPython/dpnp/pull/2671)
2727
* Added implementation of `dpnp.divmod` [#2674](https://github.com/IntelPython/dpnp/pull/2674)
28+
* Added implementation of `dpnp.isin` function [#2595](https://github.com/IntelPython/dpnp/pull/2595)
2829

2930
### Changed
3031

dpnp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@
498498
unique_inverse,
499499
unique_values,
500500
)
501+
from .dpnp_iface_logic import isin
501502

502503
# -----------------------------------------------------------------------------
503504
# Sorting, searching, and counting
@@ -981,6 +982,7 @@
981982

982983
# Set routines
983984
__all__ += [
985+
"isin",
984986
"unique",
985987
"unique_all",
986988
"unique_counts",

dpnp/dpnp_iface_logic.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
5454
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
5555

56+
from .dpnp_array import dpnp_array
5657
from .dpnp_utils import get_usm_allocations
5758

5859

@@ -1166,6 +1167,120 @@ def isfortran(a):
11661167
return a.flags.fnc
11671168

11681169

1170+
def isin(
1171+
element,
1172+
test_elements,
1173+
assume_unique=False, # pylint: disable=unused-argument
1174+
invert=False,
1175+
*,
1176+
kind=None, # pylint: disable=unused-argument
1177+
):
1178+
"""
1179+
Calculates ``element in test_elements``, broadcasting over `element` only.
1180+
Returns a boolean array of the same shape as `element` that is ``True``
1181+
where an element of `element` is in `test_elements` and ``False``
1182+
otherwise.
1183+
1184+
For full documentation refer to :obj:`numpy.isin`.
1185+
1186+
Parameters
1187+
----------
1188+
element : {dpnp.ndarray, usm_ndarray, scalar}
1189+
Input array.
1190+
test_elements : {dpnp.ndarray, usm_ndarray, scalar}
1191+
The values against which to test each value of `element`.
1192+
This argument is flattened if it is an array.
1193+
assume_unique : bool, optional
1194+
Ignored, as no performance benefit is gained by assuming the
1195+
input arrays are unique. Included for compatibility with NumPy.
1196+
1197+
Default: ``False``.
1198+
invert : bool, optional
1199+
If ``True``, the values in the returned array are inverted, as if
1200+
calculating ``element not in test_elements``.
1201+
``dpnp.isin(a, b, invert=True)`` is equivalent to (but faster
1202+
than) ``dpnp.invert(dpnp.isin(a, b))``.
1203+
1204+
Default: ``False``.
1205+
kind : {None, "sort"}, optional
1206+
Ignored, as the only algorithm implemented is ``"sort"``. Included for
1207+
compatibility with NumPy.
1208+
1209+
Default: ``None``.
1210+
1211+
Returns
1212+
-------
1213+
isin : dpnp.ndarray of bool dtype
1214+
Has the same shape as `element`. The values `element[isin]`
1215+
are in `test_elements`.
1216+
1217+
Examples
1218+
--------
1219+
>>> import dpnp as np
1220+
>>> element = 2*np.arange(4).reshape((2, 2))
1221+
>>> element
1222+
array([[0, 2],
1223+
[4, 6]])
1224+
>>> test_elements = np.array([1, 2, 4, 8])
1225+
>>> mask = np.isin(element, test_elements)
1226+
>>> mask
1227+
array([[False, True],
1228+
[ True, False]])
1229+
>>> element[mask]
1230+
array([2, 4])
1231+
1232+
The indices of the matched values can be obtained with `nonzero`:
1233+
1234+
>>> np.nonzero(mask)
1235+
(array([0, 1]), array([1, 0]))
1236+
1237+
The test can also be inverted:
1238+
1239+
>>> mask = np.isin(element, test_elements, invert=True)
1240+
>>> mask
1241+
array([[ True, False],
1242+
[False, True]])
1243+
>>> element[mask]
1244+
array([0, 6])
1245+
1246+
"""
1247+
1248+
dpnp.check_supported_arrays_type(element, test_elements, scalar_type=True)
1249+
if dpnp.isscalar(element):
1250+
usm_element = dpnp.as_usm_ndarray(
1251+
element,
1252+
usm_type=test_elements.usm_type,
1253+
sycl_queue=test_elements.sycl_queue,
1254+
)
1255+
usm_test = dpnp.get_usm_ndarray(test_elements)
1256+
elif dpnp.isscalar(test_elements):
1257+
usm_test = dpnp.as_usm_ndarray(
1258+
test_elements,
1259+
usm_type=element.usm_type,
1260+
sycl_queue=element.sycl_queue,
1261+
)
1262+
usm_element = dpnp.get_usm_ndarray(element)
1263+
else:
1264+
if (
1265+
dpu.get_execution_queue(
1266+
(element.sycl_queue, test_elements.sycl_queue)
1267+
)
1268+
is None
1269+
):
1270+
raise dpu.ExecutionPlacementError(
1271+
"Input arrays have incompatible allocation queues"
1272+
)
1273+
usm_element = dpnp.get_usm_ndarray(element)
1274+
usm_test = dpnp.get_usm_ndarray(test_elements)
1275+
return dpnp_array._create_from_usm_ndarray(
1276+
dpt.isin(
1277+
usm_element,
1278+
usm_test,
1279+
invert=invert,
1280+
)
1281+
)
1282+
1283+
11691284
_ISINF_DOCSTRING = """
11701285
Tests each element :math:`x_i` of the input array `x` to determine if equal to
11711286
positive or negative infinity.

dpnp/tests/test_logic.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import dpctl
12
import numpy
23
import pytest
4+
from dpctl.utils import ExecutionPlacementError
35
from numpy.testing import (
46
assert_allclose,
57
assert_array_equal,
@@ -795,3 +797,103 @@ def test_array_equal_nan(a):
795797
result = dpnp.array_equal(dpnp.array(a), dpnp.array(b), equal_nan=True)
796798
expected = numpy.array_equal(a, b, equal_nan=True)
797799
assert_equal(result, expected)
800+
801+
802+
class TestIsin:
803+
@pytest.mark.parametrize(
804+
"a",
805+
[
806+
numpy.array([1, 2, 3, 4]),
807+
numpy.array([[1, 2], [3, 4]]),
808+
],
809+
)
810+
@pytest.mark.parametrize(
811+
"b",
812+
[
813+
numpy.array([2, 4, 6]),
814+
numpy.array([[1, 3], [5, 7]]),
815+
],
816+
)
817+
def test_isin_basic(self, a, b):
818+
dp_a = dpnp.array(a)
819+
dp_b = dpnp.array(b)
820+
821+
expected = numpy.isin(a, b)
822+
result = dpnp.isin(dp_a, dp_b)
823+
assert_equal(result, expected)
824+
825+
@pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True))
826+
def test_isin_dtype(self, dtype):
827+
a = numpy.array([1, 2, 3, 4], dtype=dtype)
828+
b = numpy.array([2, 4], dtype=dtype)
829+
830+
dp_a = dpnp.array(a, dtype=dtype)
831+
dp_b = dpnp.array(b, dtype=dtype)
832+
833+
expected = numpy.isin(a, b)
834+
result = dpnp.isin(dp_a, dp_b)
835+
assert_equal(result, expected)
836+
837+
@pytest.mark.parametrize(
838+
"sh_a, sh_b", [((3, 1), (1, 4)), ((2, 3, 1), (1, 1))]
839+
)
840+
def test_isin_broadcast(self, sh_a, sh_b):
841+
a = numpy.arange(numpy.prod(sh_a)).reshape(sh_a)
842+
b = numpy.arange(numpy.prod(sh_b)).reshape(sh_b)
843+
844+
dp_a = dpnp.array(a)
845+
dp_b = dpnp.array(b)
846+
847+
expected = numpy.isin(a, b)
848+
result = dpnp.isin(dp_a, dp_b)
849+
assert_equal(result, expected)
850+
851+
def test_isin_scalar_elements(self):
852+
a = numpy.array([1, 2, 3])
853+
b = 2
854+
855+
dp_a = dpnp.array(a)
856+
dp_b = dpnp.array(b)
857+
858+
expected = numpy.isin(a, b)
859+
result = dpnp.isin(dp_a, dp_b)
860+
assert_equal(result, expected)
861+
862+
def test_isin_scalar_test_elements(self):
863+
a = 2
864+
b = numpy.array([1, 2, 3])
865+
866+
dp_a = dpnp.array(a)
867+
dp_b = dpnp.array(b)
868+
869+
expected = numpy.isin(a, b)
870+
result = dpnp.isin(dp_a, dp_b)
871+
assert_equal(result, expected)
872+
873+
def test_isin_empty(self):
874+
a = numpy.array([], dtype=int)
875+
b = numpy.array([1, 2, 3])
876+
877+
dp_a = dpnp.array(a)
878+
dp_b = dpnp.array(b)
879+
880+
expected = numpy.isin(a, b)
881+
result = dpnp.isin(dp_a, dp_b)
882+
assert_equal(result, expected)
883+
884+
def test_isin_errors(self):
885+
q1 = dpctl.SyclQueue()
886+
q2 = dpctl.SyclQueue()
887+
888+
a = dpnp.arange(5, sycl_queue=q1)
889+
b = dpnp.arange(3, sycl_queue=q2)
890+
891+
# unsupported type for elements or test_elements
892+
with pytest.raises(TypeError):
893+
dpnp.isin(dict(), a)
894+
895+
with pytest.raises(TypeError):
896+
dpnp.isin(a, dict())
897+
898+
with pytest.raises(ExecutionPlacementError):
899+
dpnp.isin(a, b)

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ def test_logic_op_1in(op, device):
564564
"greater",
565565
"greater_equal",
566566
"isclose",
567+
"isin",
567568
"less",
568569
"less_equal",
569570
"logical_and",

dpnp/tests/test_usm_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def test_logic_op_1in(op, usm_type_x):
355355
"greater",
356356
"greater_equal",
357357
"isclose",
358+
"isin",
358359
"less",
359360
"less_equal",
360361
"logical_and",

dpnp/tests/third_party/cupy/logic_tests/test_truth.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def test_with_out(self, xp, dtype):
8989
return out
9090

9191

92-
@pytest.mark.skip("isin() is not supported yet")
9392
@testing.parameterize(
9493
*testing.product(
9594
{

0 commit comments

Comments
 (0)