Skip to content

Commit 60a93de

Browse files
cyx-6tqchen
authored andcommitted
[FFI] Construct NDArray.strides by default (apache#18272)
This PR updates NDArray.strides to construct strides by default
1 parent e73677f commit 60a93de

3 files changed

Lines changed: 23 additions & 6 deletions

File tree

include/tvm/ffi/container/ndarray.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class NDArrayObj : public Object, public DLTensor {
151151
protected:
152152
// backs up the shape of the NDArray
153153
Optional<Shape> shape_data_;
154+
Optional<Shape> stride_data_;
154155

155156
static void DLManagedTensorDeleter(DLManagedTensor* tensor) {
156157
NDArrayObj* obj = static_cast<NDArrayObj*>(tensor->manager_ctx);
@@ -184,9 +185,11 @@ class NDArrayObjFromNDAlloc : public NDArrayObj {
184185
this->ndim = static_cast<int>(shape.size());
185186
this->dtype = dtype;
186187
this->shape = const_cast<int64_t*>(shape.data());
187-
this->strides = nullptr;
188+
Shape strides = Shape(details::MakeStridesFromShape(this->ndim, this->shape));
189+
this->strides = const_cast<int64_t*>(strides.data());
188190
this->byte_offset = 0;
189191
this->shape_data_ = std::move(shape);
192+
this->stride_data_ = std::move(strides);
190193
alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
191194
}
192195

@@ -202,9 +205,10 @@ class NDArrayObjFromDLPack : public NDArrayObj {
202205
public:
203206
explicit NDArrayObjFromDLPack(TDLPackManagedTensor* tensor) : tensor_(tensor) {
204207
*static_cast<DLTensor*>(this) = tensor_->dl_tensor;
205-
// set strides to nullptr if the tensor is contiguous.
206-
if (IsContiguous(tensor->dl_tensor)) {
207-
this->strides = nullptr;
208+
if (tensor_->dl_tensor.strides == nullptr) {
209+
Shape strides = Shape(details::MakeStridesFromShape(ndim, shape));
210+
this->strides = const_cast<int64_t*>(strides.data());
211+
this->stride_data_ = std::move(strides);
208212
}
209213
}
210214

include/tvm/ffi/container/shape.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,17 @@ TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeInplaceShape(IterType begin, IterType end
9191
return p;
9292
}
9393

94+
TVM_FFI_INLINE ObjectPtr<ShapeObj> MakeStridesFromShape(int64_t ndim, int64_t* shape) {
95+
int64_t* strides_data;
96+
ObjectPtr<ShapeObj> strides = details::MakeEmptyShape(ndim, &strides_data);
97+
int64_t stride = 1;
98+
for (int i = ndim - 1; i >= 0; --i) {
99+
strides_data[i] = stride;
100+
stride *= shape[i];
101+
}
102+
return strides;
103+
}
104+
94105
} // namespace details
95106

96107
/*!

tests/cpp/test_ndarray.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ TEST(NDArray, DLPack) {
6969
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
7070
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
7171
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
72-
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
72+
EXPECT_EQ(dlpack->dl_tensor.strides[0], 6);
73+
EXPECT_EQ(dlpack->dl_tensor.strides[1], 3);
74+
EXPECT_EQ(dlpack->dl_tensor.strides[2], 1);
7375
EXPECT_EQ(nd.use_count(), 2);
7476
{
7577
NDArray nd2 = NDArray::FromDLPack(dlpack);
@@ -96,7 +98,7 @@ TEST(NDArray, DLPackVersioned) {
9698
EXPECT_EQ(dlpack->dl_tensor.device.device_type, kDLCPU);
9799
EXPECT_EQ(dlpack->dl_tensor.device.device_id, 0);
98100
EXPECT_EQ(dlpack->dl_tensor.byte_offset, 0);
99-
EXPECT_EQ(dlpack->dl_tensor.strides, nullptr);
101+
EXPECT_EQ(dlpack->dl_tensor.strides[0], 1);
100102

101103
EXPECT_EQ(nd.use_count(), 2);
102104
{

0 commit comments

Comments
 (0)