1010
1111#include < cstring>
1212
13+ #ifdef CUDA_AVAILABLE
14+ #include < executorch/backends/aoti/slim/c10/cuda/Exception.h>
15+ #include < executorch/backends/cuda/runtime/guard.h>
16+ #endif
17+
1318#include < executorch/backends/aoti/slim/c10/core/Device.h>
1419#include < executorch/backends/aoti/slim/c10/core/ScalarType.h>
1520#include < executorch/backends/aoti/slim/util/ArrayRefUtil.h>
1621#include < executorch/backends/aoti/slim/util/SharedPtr.h>
1722#include < executorch/backends/aoti/slim/util/SizeUtil.h>
1823#include < executorch/runtime/platform/assert.h>
24+ #include < executorch/runtime/platform/log.h>
1925
2026namespace executorch ::backends::aoti::slim {
2127
@@ -30,6 +36,10 @@ inline void noop(void*) {}
3036// / Default CPU device constant.
3137inline const c10::Device CPU_DEVICE = c10::Device(c10::DeviceType::CPU, 0 );
3238
39+ // / Default CUDA device constant.
40+ inline const c10::Device DEFAULT_CUDA_DEVICE =
41+ c10::Device (c10::DeviceType::CUDA, 0 );
42+
3343// / DeviceTraits template for device-specific operations.
3444// / Device-specific implementations provide allocate(), free(), and memcpy().
3545template <c10::DeviceType D>
@@ -74,6 +84,119 @@ struct DeviceTraits<c10::DeviceType::CPU> {
7484 }
7585};
7686
87+ #ifdef CUDA_AVAILABLE
88+ // / CUDA specialization of DeviceTraits.
89+ // / Provides CUDA memory allocation and copy operations using
90+ // / cudaMallocAsync/cudaFreeAsync with proper stream handling.
91+ // /
92+ // / IMPORTANT: Callers are expected to set the correct CUDA device and stream
93+ // / using CUDAStreamGuard before calling these methods. This is consistent
94+ // / with PyTorch's CUDACachingAllocator design pattern where the allocator
95+ // / assumes the caller has already set the correct device context.
96+ template <>
97+ struct DeviceTraits <c10::DeviceType::CUDA> {
98+ // / Allocates CUDA device memory on the current stream.
99+ // / Uses cudaMallocAsync for asynchronous allocation on the stream
100+ // / that is currently set via CUDAStreamGuard, similar to how
101+ // / PyTorch's CUDACachingAllocator works.
102+ // /
103+ // / NOTE: Caller must ensure the correct device is already set via
104+ // / CUDAStreamGuard. This function does NOT create a device guard internally.
105+ // /
106+ // / @param nbytes Number of bytes to allocate.
107+ // / @param device The target CUDA device (used to get the stream).
108+ // / @return Pointer to allocated device memory.
109+ static void * allocate (size_t nbytes, const c10::Device& device) {
110+ // Get the current stream for this device (set by CUDAStreamGuard if any)
111+ // This follows PyTorch's pattern where the allocator assumes the caller
112+ // has already set the correct device via CUDAStreamGuard.
113+ auto stream_result =
114+ executorch::backends::cuda::getCurrentCUDAStream (device.index ());
115+ ET_CHECK_MSG (
116+ stream_result.ok (),
117+ " Failed to get current CUDA stream for device %d" ,
118+ static_cast <int >(device.index ()));
119+
120+ cudaStream_t stream = stream_result.get ();
121+ void * data = nullptr ;
122+ ET_CUDA_CHECK (cudaMallocAsync (&data, nbytes, stream));
123+ return data;
124+ }
125+
126+ // / Frees CUDA device memory on the current stream.
127+ // / @param ptr Pointer to device memory to free.
128+ static void free (void * ptr) {
129+ // Get the current stream for the current device
130+ auto stream_result = executorch::backends::cuda::getCurrentCUDAStream (-1 );
131+ if (stream_result.ok ()) {
132+ ET_CUDA_LOG_WARN (cudaFreeAsync (ptr, stream_result.get ()));
133+ } else {
134+ // Fallback to synchronous free if we can't get the stream
135+ ET_CUDA_LOG_WARN (cudaFree (ptr));
136+ }
137+ }
138+
139+ // / Copies memory between CPU and CUDA or CUDA and CUDA.
140+ // / @param dst Destination pointer.
141+ // / @param src Source pointer.
142+ // / @param nbytes Number of bytes to copy.
143+ // / @param dst_device Destination device.
144+ // / @param src_device Source device.
145+ static void memcpy (
146+ void * dst,
147+ const void * src,
148+ size_t nbytes,
149+ const c10::Device& dst_device,
150+ const c10::Device& src_device) {
151+ cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
152+
153+ if (src_device.is_cpu ()) {
154+ direction = cudaMemcpyHostToDevice;
155+ } else if (dst_device.is_cpu ()) {
156+ direction = cudaMemcpyDeviceToHost;
157+ } else {
158+ ET_CHECK_MSG (
159+ src_device.index () == dst_device.index (),
160+ " CUDA memcpy across different device indices not supported: %d != %d" ,
161+ static_cast <int >(src_device.index ()),
162+ static_cast <int >(dst_device.index ()));
163+ }
164+
165+ ET_CUDA_CHECK (cudaMemcpy (dst, src, nbytes, direction));
166+ }
167+ };
168+ #else
169+ // / CUDA stub when CUDA_AVAILABLE is not defined.
170+ // / All operations abort with an error message.
171+ template <>
172+ struct DeviceTraits <c10::DeviceType::CUDA> {
173+ static void * allocate (size_t nbytes, const c10::Device& device) {
174+ (void )nbytes;
175+ (void )device;
176+ ET_CHECK_MSG (false , " Build with CUDA_AVAILABLE=1 to enable CUDA support" );
177+ }
178+
179+ static void free (void * ptr) {
180+ (void )ptr;
181+ ET_LOG (Error, " Build with CUDA_AVAILABLE=1 to enable CUDA support" );
182+ }
183+
184+ static void memcpy (
185+ void * dst,
186+ const void * src,
187+ size_t nbytes,
188+ const c10::Device& dst_device,
189+ const c10::Device& src_device) {
190+ (void )dst;
191+ (void )src;
192+ (void )nbytes;
193+ (void )dst_device;
194+ (void )src_device;
195+ ET_CHECK_MSG (false , " Build with CUDA_AVAILABLE=1 to enable CUDA support" );
196+ }
197+ };
198+ #endif // CUDA_AVAILABLE
199+
77200/* *
78201 * MaybeOwningStorage - A storage class that manages tensor data memory.
79202 *
@@ -93,17 +216,19 @@ struct DeviceTraits<c10::DeviceType::CPU> {
93216class MaybeOwningStorage {
94217 public:
95218 // / Constructs owning storage with allocated memory.
96- // / @param device The device for storage (must be CPU ).
219+ // / @param device The device for storage (CPU or CUDA ).
97220 // / @param nbytes Number of bytes to allocate.
98221 MaybeOwningStorage (const c10::Device& device, size_t nbytes)
99222 : device_(device), capacity_(nbytes), is_owning_(true ) {
100- ET_CHECK_MSG (
101- device.is_cpu (),
102- " Only CPU device is currently supported, got: %s" ,
103- device.str ().c_str ());
104-
105- data_ = DeviceTraits<c10::DeviceType::CPU>::allocate (nbytes, device);
106- deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
223+ if (device.is_cpu ()) {
224+ data_ = DeviceTraits<c10::DeviceType::CPU>::allocate (nbytes, device);
225+ deleter_ = DeviceTraits<c10::DeviceType::CPU>::free;
226+ } else if (device.is_cuda ()) {
227+ data_ = DeviceTraits<c10::DeviceType::CUDA>::allocate (nbytes, device);
228+ deleter_ = DeviceTraits<c10::DeviceType::CUDA>::free;
229+ } else {
230+ ET_CHECK_MSG (false , " Unsupported device type: %s" , device.str ().c_str ());
231+ }
107232 }
108233
109234 // / Default constructor is deleted - storage must have a device.
0 commit comments