@@ -24,41 +24,6 @@ except ImportError:
2424 torch = None
2525
2626
27- def load_torch_get_current_cuda_stream ():
28- """ Build a faster get_current_cuda_stream for torch via an inline cpp extension; returns the python-api fallback if loading raises.
29- """
30- source = """
31- #include <c10/cuda/CUDAStream.h>
32-
33- int64_t get_current_cuda_stream(int device_id) {
34- at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
35- // fast invariant, default stream is always 0
36- if (stream.id() == 0) return 0;
37- // convert to cudaStream_t
38- return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
39- }
40- """
41- def fallback_get_current_cuda_stream (device_id ):
42- """ Fallback: return torch.cuda.current_stream(device_id).cuda_stream through the python api."""
43- return torch.cuda.current_stream(device_id).cuda_stream
44- try :
45- from torch.utils import cpp_extension
46- result = cpp_extension.load_inline(
47- name = " get_current_cuda_stream" ,
48- cpp_sources = [source],
49- cuda_sources = [],
50- extra_cflags = [" -O3" ],
51- extra_include_paths = cpp_extension.include_paths(" cuda" ),
52- functions = [" get_current_cuda_stream" ],
53- )
54- return result.get_current_cuda_stream
55- except Exception :
56- return fallback_get_current_cuda_stream
57-
58-
59- torch_get_current_cuda_stream = None
60-
61-
6227cdef inline object make_ret_small_str(TVMFFIAny result):
6328 """ convert small string to return value."""
6429 cdef TVMFFIByteArray bytes
@@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
146111 if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0 ] == - 1 :
147112 ctx_dev_type[0 ] = temp_dltensor.device.device_type
148113 ctx_dev_id[0 ] = temp_dltensor.device.device_id
149- if torch_get_current_cuda_stream is None :
150- torch_get_current_cuda_stream = load_torch_get_current_cuda_stream()
151- temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
114+ # This is an API that dynamo and others use to get the raw stream from torch
115+ temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
152116 ctx_stream[0 ] = < TVMFFIStreamHandle> temp_ptr
153117 temp_args.append(arg)
154118 elif hasattr (arg, " __dlpack__" ):
0 commit comments