diff --git a/CHANGELOG.md b/CHANGELOG.md index 752bf2ad4b3..d924fc61bfc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added support for buffer protocol objects as advanced index keys in `dpnp.ndarray` [#2889](https://github.com/IntelPython/dpnp/pull/2889) + ### Changed ### Deprecated diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 02cd655fcef..365f2759d2c 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -52,26 +52,31 @@ def _unwrap_index_element(x): """ Unwrap a single index element for the tensor indexing layer. - Converts dpnp arrays to usm_ndarray and array-like objects (range, list) - to numpy arrays with intp dtype for NumPy-compatible advanced indexing. + Converts dpnp arrays to usm_ndarray and array-like objects (range, list, + buffer protocol objects) to numpy arrays for NumPy-compatible advanced + indexing. Scalars and slices pass through to the tensor layer. """ - if isinstance(x, dpt.usm_ndarray): + if ( + x is None + or x is Ellipsis + or isinstance(x, (dpt.usm_ndarray, slice, numpy.ndarray)) + ): return x if isinstance(x, dpnp_array): return x.get_array() - if isinstance(x, range): - return numpy.asarray(x, dtype=numpy.intp) - if isinstance(x, list): - # keep boolean lists as boolean - arr = numpy.asarray(x) - # cast empty lists (float64 in NumPy) to intp - # for correct tensor indexing - if arr.size == 0: - arr = arr.astype(numpy.intp) - return arr - return x + # scalars (int, bool, numpy scalars) pass through to the tensor layer + if isinstance(x, (int, numpy.generic)): + return x + + # convert array-like objects (range, list, buffer protocol) to numpy + arr = numpy.asarray(x) + # cast empty arrays (float64 in NumPy) to intp + # for correct tensor indexing + if arr.size == 0 and arr.dtype.kind == "f": + arr = arr.astype(numpy.intp) + return arr def _get_unwrapped_index_key(key): diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 2edc8214f3e..19273fef36c 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -1,3 +1,4 @@ +import array import functools import dpctl @@ -406,6 +407,65 @@ def test_array_like_single_index(self, idx): dp_a = dpnp.arange(24).reshape(2, 3, 4) assert_array_equal(dp_a[idx], np_a[idx]) + def test_buffer_protocol_getitem(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds], np_a[inds]) + + def test_buffer_protocol_paired_index(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds, inds], np_a[inds, inds]) + + def test_buffer_protocol_setitem(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + np_a[inds, inds] = 0 + dp_a[inds, inds] = 0 + assert_array_equal(dp_a, np_a) + + def test_memoryview_getitem(self): + inds = memoryview(array.array("l", [0, 1, 2])) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds], np_a[inds]) + + def test_bytearray_getitem(self): + inds = bytearray(b"\x00\x01\x02") + np_a = numpy.arange(10) + dp_a = dpnp.arange(10) + assert_array_equal(dp_a[inds], np_a[inds]) + + @pytest.mark.parametrize( + "idx", + [ + 1.0, + 1 + 0j, + numpy.float64(1.0), + numpy.complex128(1.0), + "a", + [0.5, 1.5], + ], + ids=[ + "float", + "complex", + "np.float64", + "np.complex128", + "str", + "float_list", + ], + ) + def test_invalid_index(self, idx): + dp_a = dpnp.arange(12).reshape(3, 4) + with pytest.raises((IndexError, TypeError)): + dp_a[idx] + class TestIx: @pytest.mark.parametrize(