diff --git a/cuda_core/tests/test_utils.py b/cuda_core/tests/test_utils.py index e5e4789464..5bcdead92c 100644 --- a/cuda_core/tests/test_utils.py +++ b/cuda_core/tests/test_utils.py @@ -111,6 +111,12 @@ def _arr_is_writeable(arr): return arr.flags.writeable if hasattr(arr.flags, "writeable") else True +def _arr_dtype(arr): + if torch is not None and isinstance(arr, torch.Tensor): + return np.dtype(arr.__cuda_array_interface__["typestr"]) + return arr.dtype + + def _cpu_array_samples(): samples = [ np.empty(3, dtype=np.int32), @@ -171,7 +177,10 @@ def _check_view(self, view, in_arr): assert view.shape == expected_shape assert view.size == _arr_size(in_arr) strides_in_counts = _arr_strides_in_counts(in_arr) - assert (_arr_is_c_contiguous(in_arr) and view.strides is None) or view.strides == strides_in_counts + if view.strides is None: + assert _arr_is_c_contiguous(in_arr) + else: + assert view.strides == strides_in_counts assert view.device_id == -1 assert view.is_device_accessible is False assert view.exporting_obj is in_arr @@ -277,8 +286,8 @@ def _check_view(self, view, in_arr, dev): assert view.shape == expected_shape assert view.size == _arr_size(in_arr) strides_in_counts = _arr_strides_in_counts(in_arr) - if _arr_is_c_contiguous(in_arr): - assert view.strides in (None, strides_in_counts) + if view.strides is None: + assert _arr_is_c_contiguous(in_arr) else: assert view.strides == strides_in_counts assert view.device_id == dev.device_id @@ -343,15 +352,16 @@ def test_cuda_array_interface_gpu(self, in_arr, use_stream): def _check_view(self, view, in_arr, dev): assert isinstance(view, StridedMemoryView) - assert view.ptr == gpu_array_ptr(in_arr) - assert view.shape == in_arr.shape - assert view.size == in_arr.size - strides_in_counts = convert_strides_to_counts(in_arr.strides, in_arr.dtype.itemsize) - if in_arr.flags["C_CONTIGUOUS"]: - assert view.strides is None + assert view.ptr == _arr_ptr(in_arr) + expected_shape = tuple(in_arr.shape) + assert view.shape == expected_shape + assert view.size == _arr_size(in_arr) + strides_in_counts = _arr_strides_in_counts(in_arr) + if view.strides is None: + assert _arr_is_c_contiguous(in_arr) else: assert view.strides == strides_in_counts - assert view.dtype == in_arr.dtype + assert view.dtype == _arr_dtype(in_arr) assert view.device_id == dev.device_id assert view.is_device_accessible is True assert view.exporting_obj is in_arr