@@ -76,10 +76,10 @@ def free(self, buffers):
7676
7777
7878@pytest .mark .parametrize ("mode" , ["no_graph" , "global" , "thread_local" , "relaxed" ])
79- def test_graph_alloc (init_cuda , mode ):
79+ def test_graph_alloc (mempool_device , mode ):
8080 """Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
8181 NBYTES = 64
82- device = Device ()
82+ device = mempool_device
8383 stream = device .create_stream ()
8484 dmr = DeviceMemoryResource (device )
8585 gmr = GraphMemoryResource (device )
@@ -118,10 +118,10 @@ def apply_kernels(mr, stream, out):
118118
119119@pytest .mark .skipif (IS_WINDOWS or IS_WSL , reason = "auto_free_on_launch not supported on Windows" )
120120@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
121- def test_graph_alloc_with_output (init_cuda , mode ):
121+ def test_graph_alloc_with_output (mempool_device , mode ):
122122 """Test for memory allocated in a graph being used outside the graph."""
123123 NBYTES = 64
124- device = Device ()
124+ device = mempool_device
125125 stream = device .create_stream ()
126126 gmr = GraphMemoryResource (device )
127127
@@ -157,8 +157,8 @@ def test_graph_alloc_with_output(init_cuda, mode):
157157
158158
159159@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
160- def test_graph_mem_set_attributes (init_cuda , mode ):
161- device = Device ()
160+ def test_graph_mem_set_attributes (mempool_device , mode ):
161+ device = mempool_device
162162 stream = device .create_stream ()
163163 gmr = GraphMemoryResource (device )
164164 mman = GraphMemoryTestManager (gmr , stream , mode )
@@ -209,12 +209,12 @@ def test_graph_mem_set_attributes(init_cuda, mode):
209209
210210
211211@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
212- def test_gmr_check_capture_state (init_cuda , mode ):
212+ def test_gmr_check_capture_state (mempool_device , mode ):
213213 """
214214 Test expected errors (and non-errors) using GraphMemoryResource with graph
215215 capture.
216216 """
217- device = Device ()
217+ device = mempool_device
218218 stream = device .create_stream ()
219219 gmr = GraphMemoryResource (device )
220220
@@ -233,12 +233,12 @@ def test_gmr_check_capture_state(init_cuda, mode):
233233
234234
235235@pytest .mark .parametrize ("mode" , ["global" , "thread_local" , "relaxed" ])
236- def test_dmr_check_capture_state (init_cuda , mode ):
236+ def test_dmr_check_capture_state (mempool_device , mode ):
237237 """
238238 Test expected errors (and non-errors) using DeviceMemoryResource with graph
239239 capture.
240240 """
241- device = Device ()
241+ device = mempool_device
242242 stream = device .create_stream ()
243243 dmr = DeviceMemoryResource (device )
244244
0 commit comments