Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ def _arr_is_writeable(arr):
return arr.flags.writeable if hasattr(arr.flags, "writeable") else True


def _arr_dtype(arr):
if torch is not None and isinstance(arr, torch.Tensor):
return np.dtype(arr.__cuda_array_interface__["typestr"])
return arr.dtype


def _cpu_array_samples():
samples = [
np.empty(3, dtype=np.int32),
Expand Down Expand Up @@ -171,7 +177,10 @@ def _check_view(self, view, in_arr):
assert view.shape == expected_shape
assert view.size == _arr_size(in_arr)
strides_in_counts = _arr_strides_in_counts(in_arr)
assert (_arr_is_c_contiguous(in_arr) and view.strides is None) or view.strides == strides_in_counts
if view.strides is None:
assert _arr_is_c_contiguous(in_arr)
else:
assert view.strides == strides_in_counts
assert view.device_id == -1
assert view.is_device_accessible is False
assert view.exporting_obj is in_arr
Expand Down Expand Up @@ -277,8 +286,8 @@ def _check_view(self, view, in_arr, dev):
assert view.shape == expected_shape
assert view.size == _arr_size(in_arr)
strides_in_counts = _arr_strides_in_counts(in_arr)
if _arr_is_c_contiguous(in_arr):
assert view.strides in (None, strides_in_counts)
if view.strides is None:
assert _arr_is_c_contiguous(in_arr)
else:
assert view.strides == strides_in_counts
assert view.device_id == dev.device_id
Expand Down Expand Up @@ -343,15 +352,16 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream):

def _check_view(self, view, in_arr, dev):
assert isinstance(view, StridedMemoryView)
assert view.ptr == gpu_array_ptr(in_arr)
assert view.shape == in_arr.shape
assert view.size == in_arr.size
strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize)
if in_arr.flags["C_CONTIGUOUS"]:
Comment thread
leofang marked this conversation as resolved.
assert view.strides is None
assert view.ptr == _arr_ptr(in_arr)
expected_shape = tuple(in_arr.shape)
assert view.shape == expected_shape
assert view.size == _arr_size(in_arr)
strides_in_counts = _arr_strides_in_counts(in_arr)
if view.strides is None:
assert _arr_is_c_contiguous(in_arr)
else:
assert view.strides == strides_in_counts
assert view.dtype == in_arr.dtype
Comment thread
leofang marked this conversation as resolved.
assert view.dtype == _arr_dtype(in_arr)
assert view.device_id == dev.device_id
assert view.is_device_accessible is True
assert view.exporting_obj is in_arr
Expand Down
Loading