Skip to content

Commit 78c4261

Browse files
committed
Adjust deallocation stream for legacy memory resources to avoid platform-dependent errors. Add dependence on mempool_device where needed for certain tests.
1 parent af22c81 commit 78c4261

File tree

4 files changed

+23
-21
lines changed

4 files changed

+23
-21
lines changed

cuda_core/cuda/core/experimental/_memory/_legacy.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def deallocate(self, ptr: DevicePointerT, size, stream):
6262
stream : Stream
6363
The stream on which to perform the deallocation synchronously.
6464
"""
65-
stream.sync()
65+
if stream is not None:
66+
stream.sync()
6667
(err,) = driver.cuMemFreeHost(ptr)
6768
raise_if_driver_error(err)
6869

@@ -97,10 +98,11 @@ def allocate(self, size, stream=None) -> Buffer:
9798
stream = default_stream()
9899
err, ptr = driver.cuMemAlloc(size)
99100
raise_if_driver_error(err)
100-
return Buffer._init(ptr, size, self)
101+
return Buffer._init(ptr, size, self, stream)
101102

102103
def deallocate(self, ptr, size, stream):
103-
stream.sync()
104+
if stream is not None:
105+
stream.sync()
104106
(err,) = driver.cuMemFree(ptr)
105107
raise_if_driver_error(err)
106108

cuda_core/tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,16 @@ def ipc_memory_resource(ipc_device):
110110
return mr
111111

112112

113+
@pytest.fixture
114+
def mempool_device():
115+
"""Obtains a device suitable for mempool tests, or skips."""
116+
device = Device()
117+
device.set_current()
118+
119+
if not device.properties.memory_pools_supported:
120+
pytest.skip("Device does not support mempool operations")
121+
122+
return device
123+
124+
113125
skipif_need_cuda_headers = pytest.mark.skipif(helpers.CUDA_INCLUDE_PATH is None, reason="need CUDA header")

cuda_core/tests/test_graph_mem.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def free(self, buffers):
7676

7777

7878
@pytest.mark.parametrize("mode", ["no_graph", "global", "thread_local", "relaxed"])
79-
def test_graph_alloc(init_cuda, mode):
79+
def test_graph_alloc(mempool_device, mode):
8080
"""Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
8181
NBYTES = 64
82-
device = Device()
82+
device = mempool_device
8383
stream = device.create_stream()
8484
dmr = DeviceMemoryResource(device)
8585
gmr = GraphMemoryResource(device)
@@ -118,10 +118,10 @@ def apply_kernels(mr, stream, out):
118118

119119
@pytest.mark.skipif(IS_WINDOWS or IS_WSL, reason="auto_free_on_launch not supported on Windows")
120120
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
121-
def test_graph_alloc_with_output(init_cuda, mode):
121+
def test_graph_alloc_with_output(mempool_device, mode):
122122
"""Test for memory allocated in a graph being used outside the graph."""
123123
NBYTES = 64
124-
device = Device()
124+
device = mempool_device
125125
stream = device.create_stream()
126126
gmr = GraphMemoryResource(device)
127127

@@ -157,8 +157,8 @@ def test_graph_alloc_with_output(init_cuda, mode):
157157

158158

159159
@pytest.mark.parametrize("mode", ["global", "thread_local", "relaxed"])
160-
def test_graph_mem_set_attributes(init_cuda, mode):
161-
device = Device()
160+
def test_graph_mem_set_attributes(mempool_device, mode):
161+
device = mempool_device
162162
stream = device.create_stream()
163163
gmr = GraphMemoryResource(device)
164164
mman = GraphMemoryTestManager(gmr, stream, mode)

cuda_core/tests/test_memory.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,6 @@
3939
POOL_SIZE = 2097152 # 2MB size
4040

4141

42-
@pytest.fixture(scope="function")
43-
def mempool_device():
44-
"""Obtains a device suitable for mempool tests, or skips."""
45-
device = Device()
46-
device.set_current()
47-
48-
if not device.properties.memory_pools_supported:
49-
pytest.skip("Device does not support mempool operations")
50-
51-
return device
52-
53-
5442
class DummyDeviceMemoryResource(MemoryResource):
5543
def __init__(self, device):
5644
self.device = device

0 commit comments

Comments
 (0)