Skip to content
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

* Fixed incorrect in-place advanced indexing for 4D arrays when using `range` or `list` as index keys [#2872](https://github.com/IntelPython/dpnp/pull/2872)

### Security


Expand Down
41 changes: 32 additions & 9 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@

import warnings

import numpy

import dpnp
import dpnp.tensor as dpt
import dpnp.tensor._type_utils as dtu
Expand All @@ -46,24 +48,45 @@
from .exceptions import AxisError


def _unwrap_index_element(x):
"""
Unwrap a single index element for the tensor indexing layer.

Converts dpnp arrays to usm_ndarray and array-like objects (range, list)
to numpy arrays with intp dtype for NumPy-compatible advanced indexing.

"""

if isinstance(x, dpt.usm_ndarray):
return x
if isinstance(x, dpnp_array):
return x.get_array()
Comment thread
vlad-perevezentsev marked this conversation as resolved.
if isinstance(x, range):
return numpy.asarray(x, dtype=numpy.intp)
if isinstance(x, list):
# keep boolean lists as boolean
arr = numpy.asarray(x)
# cast empty lists (float64 in NumPy) to intp
# for correct tensor indexing
if arr.size == 0:
arr = arr.astype(numpy.intp)
return arr
return x
Comment thread
ndgrigorian marked this conversation as resolved.


def _get_unwrapped_index_key(key):
"""
Get an unwrapped index key.

Return a key where each nested instance of DPNP array is unwrapped into
USM ndarray for further processing in DPCTL advanced indexing functions.
USM ndarray, and array-like objects (range, list) are converted to numpy
arrays for further processing in advanced indexing functions.

"""

if isinstance(key, tuple):
if any(isinstance(x, dpnp_array) for x in key):
# create a new tuple from the input key with unwrapped DPNP arrays
return tuple(
x.get_array() if isinstance(x, dpnp_array) else x for x in key
)
elif isinstance(key, dpnp_array):
return key.get_array()
return key
return tuple(_unwrap_index_element(x) for x in key)
return _unwrap_index_element(key)


# pylint: disable=too-many-public-methods
Expand Down
6 changes: 0 additions & 6 deletions dpnp/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,6 @@ cdef bint _is_boolean(object x) except *:
return f in "?"
else:
return False
if callable(getattr(x, "__bool__", None)):
try:
x.__bool__()
except (TypeError, ValueError):
return False
return True
Comment thread
ndgrigorian marked this conversation as resolved.
return False


Expand Down
53 changes: 53 additions & 0 deletions dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,59 @@ def test_indexing_array_negative_strides(self):
arr[slices] = 10
assert_equal(arr, 10.0, strict=False)

@pytest.mark.parametrize(
"idx",
[
(range(2), range(2)),
([0, 1], [0, 1]),
],
ids=["range", "list"],
)
def test_array_like_index_getitem(self, idx):
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
assert_array_equal(dp_a[idx], np_a[idx])

@pytest.mark.parametrize(
"idx",
[
(range(2), range(2)),
([0, 1], [0, 1]),
],
ids=["range", "list"],
)
def test_array_like_index_setitem(self, idx):
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
np_a[idx] = 0
dp_a[idx] = 0
assert_array_equal(dp_a, np_a)

def test_array_like_index_inplace_add(self):
np_a = numpy.arange(36).reshape(2, 2, 3, 3)
dp_a = dpnp.arange(36).reshape(2, 2, 3, 3)
np_tmp = -numpy.ones((2, 3, 3), dtype=numpy.intp)
dp_tmp = -dpnp.ones((2, 3, 3), dtype=numpy.intp)

np_a[range(2), range(2)] += 2 * np_tmp
dp_a[range(2), range(2)] += 2 * dp_tmp
assert_array_equal(dp_a, np_a)

@pytest.mark.parametrize(
"idx",
[
range(2),
[0, 1],
range(0),
[],
],
ids=["range", "list", "empty_range", "empty_list"],
)
def test_array_like_single_index(self, idx):
np_a = numpy.arange(24).reshape(2, 3, 4)
dp_a = dpnp.arange(24).reshape(2, 3, 4)
assert_array_equal(dp_a[idx], np_a[idx])


class TestIx:
@pytest.mark.parametrize(
Expand Down
Loading