Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions cuda_core/cuda/core/_utils/cuda_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,9 @@ cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
cpdef inline int _check_runtime_error(error) except?-1:
if error == _RUNTIME_SUCCESS:
return 0
name_err, name = runtime.cudaGetErrorName(error)
if name_err != _RUNTIME_SUCCESS:
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
name = name.decode()
# `_check_error()` reaches this path only for `runtime.cudaError_t` values.
# Use the enum name directly because Windows hybrid cudart can lag that table.
name = error.name
expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
if expl is not None:
raise CUDAError(f"{name}: {expl}")
Expand Down
10 changes: 1 addition & 9 deletions cuda_core/tests/test_cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,14 @@ def test_check_driver_error():


def test_check_runtime_error():
num_unexpected = 0
for error in runtime.cudaError_t:
if error == runtime.cudaError_t.cudaSuccess:
assert cuda_utils._check_runtime_error(error) == 0
else:
with pytest.raises(cuda_utils.CUDAError) as e:
cuda_utils._check_runtime_error(error)
msg = str(e)
if "UNEXPECTED ERROR CODE" in msg:
num_unexpected += 1
else:
# Example repr(error): <cudaError_t.cudaErrorUnknown: 999>
enum_name = repr(error).split(".", 1)[1].split(":", 1)[0]
assert enum_name in msg
# Smoke test: We don't want most to be unexpected.
assert num_unexpected < len(driver.CUresult) * 0.5
assert error.name in msg


def test_driver_error_enum_has_non_empty_docstring():
Expand Down
Loading