Skip to content

Commit 33516b8

Browse files
committed
[slimtensor] Add SlimTensor-based aoti_torch_empty_strided
Pull Request resolved: #16447

Add SlimTensor-based implementations of AOTI shim functions for tensor creation:
- `aoti_torch_empty_strided()` - Creates an uninitialized SlimTensor with the specified sizes, strides, and dtype using the `empty_strided()` factory
- `aoti_torch_create_tensor_from_blob_v2()` - Creates a non-owning SlimTensor that wraps existing memory using the `from_blob()` factory

Both functions support CPU and CUDA devices and handle all 7 SlimTensor dtypes.

Changes:
- Add `memory_slim.h` and `memory_slim.cpp` with SlimTensor-based shim implementations
- Add `runtime_shims_slim` library target to TARGETS with `CUDA_AVAILABLE=1` preprocessor flag
- Add `cuda_shim_slim_cpp_unittest()` function for SlimTensor test targets

ghstack-source-id: 336216554
@exported-using-ghexport
Differential Revision: [D90126244](https://our.internmc.facebook.com/intern/diff/D90126244/)
1 parent 7689118 commit 33516b8

File tree

4 files changed

+536
-0
lines changed

4 files changed

+536
-0
lines changed

backends/cuda/runtime/shims/memory_slim.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ using c10::ScalarType;
2323
using executorch::backends::aoti::slim::empty_strided;
2424
using executorch::backends::aoti::slim::from_blob;
2525
using executorch::backends::aoti::slim::IntArrayRef;
26+
using executorch::backends::aoti::slim::makeArrayRef;
2627

2728
extern "C" {
2829

@@ -76,6 +77,51 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
7677
return Error::Ok;
7778
}
7879

80+
/**
 * Creates an uninitialized SlimTensor through the `empty_strided()` factory
 * and hands it back through `ret_new_tensor`.
 *
 * The tensor object is allocated with `new` and its address is stored in
 * `*ret_new_tensor`; this function never frees it.
 *
 * @param ndim Number of dimensions; must be non-negative.
 * @param sizes_ptr Array of `ndim` dimension sizes. May be null only when
 *     `ndim` is 0.
 * @param strides_ptr Array of `ndim` strides, or null to use the contiguous
 *     strides computed from `sizes_ptr`.
 * @param dtype Scalar type identifier, cast to `ScalarType` unchecked.
 * @param device_type Device type identifier, cast to `DeviceType` unchecked.
 * @param device_index Device index, cast to `DeviceIndex` unchecked.
 * @param ret_new_tensor Output location for the new tensor; must not be null.
 * @return `Error::Ok` on success. Each failed argument check is reported via
 *     `ET_CHECK_OR_RETURN_ERROR` with `InvalidArgument`.
 */
AOTITorchError aoti_torch_empty_strided(
    int64_t ndim,
    const int64_t* sizes_ptr,
    const int64_t* strides_ptr,
    int32_t dtype,
    int32_t device_type,
    int32_t device_index,
    Tensor** ret_new_tensor) {
  ET_CHECK_OR_RETURN_ERROR(
      ret_new_tensor != nullptr,
      InvalidArgument,
      "aoti_torch_empty_strided: ret_new_tensor is null");

  // A negative rank would wrap to an enormous length in the size_t cast
  // below and make the array views read far past the caller's buffers.
  ET_CHECK_OR_RETURN_ERROR(
      ndim >= 0,
      InvalidArgument,
      "aoti_torch_empty_strided: ndim is negative");

  ET_CHECK_OR_RETURN_ERROR(
      !(sizes_ptr == nullptr && ndim > 0),
      InvalidArgument,
      "aoti_torch_empty_strided: sizes_ptr is null but ndim > 0");

  const size_t rank = static_cast<size_t>(ndim);
  const IntArrayRef sizes(sizes_ptr, rank);

  // Backing storage for the computed strides. It is declared at function
  // scope so the non-owning view below stays valid until the tensor is built.
  std::vector<int64_t> contig_strides;
  if (strides_ptr == nullptr) {
    contig_strides =
        executorch::backends::aoti::slim::compute_contiguous_strides(sizes);
  }
  const IntArrayRef strides = (strides_ptr == nullptr)
      ? IntArrayRef(makeArrayRef(contig_strides))
      : IntArrayRef(strides_ptr, rank);

  *ret_new_tensor = new Tensor(empty_strided(
      sizes,
      strides,
      static_cast<ScalarType>(dtype),
      Device(
          static_cast<DeviceType>(device_type),
          static_cast<DeviceIndex>(device_index))));

  return Error::Ok;
}
124+
79125
} // extern "C"
80126

81127
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/memory_slim.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,28 @@ AOTI_SHIM_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2(
5757
const uint8_t* opaque_metadata,
5858
int64_t opaque_metadata_size);
5959

60+
/**
61+
* Creates an uninitialized tensor with specified dimensions, strides, and
62+
* dtype on either CPU or CUDA device. The implementation allocates the tensor object with `new` and stores its address in `*ret_new_tensor`.
63+
*
64+
* @param ndim Number of dimensions in the tensor
65+
* @param sizes_ptr Pointer to array of `ndim` dimension sizes (may be nullptr only when ndim is 0)
66+
* @param strides_ptr Pointer to array of `ndim` strides, one per dimension, or nullptr to use contiguous strides computed from the sizes
67+
* @param dtype Data type identifier (matches PyTorch scalar types)
68+
* @param device_type Device type (0=CPU, 1=CUDA)
69+
* @param device_index Device index
70+
* @param ret_new_tensor Output parameter for the created tensor (must not be nullptr)
71+
* @return AOTITorchError error code (Error::Ok on success; a failed argument check is reported with InvalidArgument)
72+
*/
73+
AOTI_SHIM_EXPORT AOTITorchError aoti_torch_empty_strided(
74+
int64_t ndim,
75+
const int64_t* sizes_ptr,
76+
const int64_t* strides_ptr,
77+
int32_t dtype,
78+
int32_t device_type,
79+
int32_t device_index,
80+
Tensor** ret_new_tensor);
81+
6082
} // extern "C"
6183

6284
} // namespace executorch::backends::cuda

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,5 @@ def define_common_targets():
7171
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")
7272

7373
# SlimTensor-based shim tests
74+
cuda_shim_slim_cpp_unittest("aoti_torch_empty_strided")
7475
cuda_shim_slim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")

0 commit comments

Comments
 (0)