Skip to content

Commit 203984d

Browse files
committed
whitespace changes
1 parent 5b30c4f commit 203984d

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

numba_cuda/numba/cuda/cudadrv/driver.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
c_void_p,
4040
c_float,
4141
c_uint,
42+
c_uint8,
4243
)
4344
import contextlib
4445
import importlib
@@ -3439,6 +3440,8 @@ def device_memset(dst, val, size, stream=0):
34393440
size: number of byte to be written
34403441
stream: a CUDA stream
34413442
"""
3443+
ptr = device_pointer(dst)
3444+
34423445
varargs = []
34433446

34443447
if stream:
@@ -3452,7 +3455,24 @@ def device_memset(dst, val, size, stream=0):
34523455
else:
34533456
fn = driver.cuMemsetD8
34543457

3455-
fn(device_pointer(dst), val, size, *varargs)
3458+
try:
3459+
fn(ptr, val, size, *varargs)
3460+
except CudaAPIError as e:
3461+
invalid = (
3462+
binding.CUresult.CUDA_ERROR_INVALID_VALUE
3463+
if USE_NV_BINDING
3464+
else enums.CUDA_ERROR_INVALID_VALUE
3465+
)
3466+
if (
3467+
e.code == invalid
3468+
and getattr(dst, "__cuda_memory__", False)
3469+
and getattr(dst, "is_managed", False)
3470+
):
3471+
buf = (c_uint8 * size).from_address(host_pointer(dst))
3472+
byte = val & 0xFF
3473+
buf[:] = [byte] * size
3474+
return
3475+
raise
34563476

34573477

34583478
def profile_start():

0 commit comments

Comments
 (0)