diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index f0c85d8c3f0..6d1f2e5d00e 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -91,13 +91,13 @@ jobs: - name: Build conda package id: build_conda_pkg continue-on-error: true - run: conda build --no-test --python ${{ matrix.python }} --numpy 2.0 ${{ env.channels-list }} conda-recipe + run: conda-build --no-test --python ${{ matrix.python }} --numpy 2.0 ${{ env.channels-list }} conda-recipe env: MAX_BUILD_CMPL_MKL_VERSION: '2026.0a0' - name: ReBuild conda package if: steps.build_conda_pkg.outcome == 'failure' - run: conda build --no-test --python ${{ matrix.python }} --numpy 2.0 ${{ env.channels-list }} conda-recipe + run: conda-build --no-test --python ${{ matrix.python }} --numpy 2.0 ${{ env.channels-list }} conda-recipe env: MAX_BUILD_CMPL_MKL_VERSION: '2026.0a0' diff --git a/CHANGELOG.md b/CHANGELOG.md index 81b7578044f..debbacf480d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +* Added support for buffer protocol objects as advanced index keys in `dpnp.ndarray` [#2889](https://github.com/IntelPython/dpnp/pull/2889) + ### Changed ### Deprecated diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 02cd655fcef..5c5c28a7002 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -52,8 +52,9 @@ def _unwrap_index_element(x): """ Unwrap a single index element for the tensor indexing layer. - Converts dpnp arrays to usm_ndarray and array-like objects (range, list) - to numpy arrays with intp dtype for NumPy-compatible advanced indexing. + Converts dpnp arrays to usm_ndarray and array-like objects (range, list, + buffer protocol objects) to numpy arrays for NumPy-compatible advanced + indexing. """ @@ -71,6 +72,16 @@ def _unwrap_index_element(x): if arr.size == 0: arr = arr.astype(numpy.intp) return arr + if isinstance(x, numpy.ndarray): + return x + # convert buffer protocol objects (array.array, memoryview, etc.) + try: + mv = memoryview(x) + except TypeError: + return x + # 0-d buffers are handled by the tensor layer + if mv.ndim > 0: + return numpy.asarray(x) return x diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 2edc8214f3e..aaa9f03e7ae 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -1,3 +1,4 @@ +import array import functools import dpctl @@ -406,6 +407,41 @@ def test_array_like_single_index(self, idx): dp_a = dpnp.arange(24).reshape(2, 3, 4) assert_array_equal(dp_a[idx], np_a[idx]) + def test_buffer_protocol_getitem(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds], np_a[inds]) + + def test_buffer_protocol_paired_index(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds, inds], np_a[inds, inds]) + + def test_buffer_protocol_setitem(self): + inds = array.array("l") + inds.frombytes(numpy.arange(3).tobytes()) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + np_a[inds, inds] = 0 + dp_a[inds, inds] = 0 + assert_array_equal(dp_a, np_a) + + def test_memoryview_getitem(self): + inds = memoryview(array.array("l", [0, 1, 2])) + np_a = numpy.arange(12).reshape(3, 4) + dp_a = dpnp.arange(12).reshape(3, 4) + assert_array_equal(dp_a[inds], np_a[inds]) + + def test_bytearray_getitem(self): + inds = bytearray(b"\x00\x01\x02") + np_a = numpy.arange(10) + dp_a = dpnp.arange(10) + assert_array_equal(dp_a[inds], np_a[inds]) + class TestIx: @pytest.mark.parametrize(