Skip to content

Commit 7b28787

Browse files
authored
[FFI] Update torch stream getter to use native torch c api (#18266)
This PR updates the torch stream getter to use `_cuda_getCurrentRawStream` from the torch C API, which is also used by dynamo. This saves us from having to build a custom module with `load_inline`.
1 parent 887a9ca commit 7b28787

2 files changed

Lines changed: 3 additions & 39 deletions

File tree

ffi/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
[project]
1919
name = "apache-tvm-ffi"
20-
version = "0.1.0a6"
20+
version = "0.1.0a7"
2121
description = "tvm ffi"
2222

2323
authors = [{ name = "TVM FFI team" }]

ffi/python/tvm_ffi/cython/function.pxi

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,41 +24,6 @@ except ImportError:
2424
torch = None
2525

2626

27-
def load_torch_get_current_cuda_stream():
28-
"""Create a faster get_current_cuda_stream for torch through cpp extension.
29-
"""
30-
source = """
31-
#include <c10/cuda/CUDAStream.h>
32-
33-
int64_t get_current_cuda_stream(int device_id) {
34-
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
35-
// fast invariant, default stream is always 0
36-
if (stream.id() == 0) return 0;
37-
// convert to cudaStream_t
38-
return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
39-
}
40-
"""
41-
def fallback_get_current_cuda_stream(device_id):
42-
"""Fallback with python api"""
43-
return torch.cuda.current_stream(device_id).cuda_stream
44-
try:
45-
from torch.utils import cpp_extension
46-
result = cpp_extension.load_inline(
47-
name="get_current_cuda_stream",
48-
cpp_sources=[source],
49-
cuda_sources=[],
50-
extra_cflags=["-O3"],
51-
extra_include_paths=cpp_extension.include_paths("cuda"),
52-
functions=["get_current_cuda_stream"],
53-
)
54-
return result.get_current_cuda_stream
55-
except Exception:
56-
return fallback_get_current_cuda_stream
57-
58-
59-
torch_get_current_cuda_stream = None
60-
61-
6227
cdef inline object make_ret_small_str(TVMFFIAny result):
6328
"""convert small string to return value."""
6429
cdef TVMFFIByteArray bytes
@@ -146,9 +111,8 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args,
146111
if is_cuda and ctx_dev_type != NULL and ctx_dev_type[0] == -1:
147112
ctx_dev_type[0] = temp_dltensor.device.device_type
148113
ctx_dev_id[0] = temp_dltensor.device.device_id
149-
if torch_get_current_cuda_stream is None:
150-
torch_get_current_cuda_stream = load_torch_get_current_cuda_stream()
151-
temp_ptr = torch_get_current_cuda_stream(temp_dltensor.device.device_id)
114+
# This is an API that dynamo and other uses to get the raw stream from torch
115+
temp_ptr = torch._C._cuda_getCurrentRawStream(temp_dltensor.device.device_id)
152116
ctx_stream[0] = <TVMFFIStreamHandle>temp_ptr
153117
temp_args.append(arg)
154118
elif hasattr(arg, "__dlpack__"):

0 commit comments

Comments
 (0)