Skip to content

Commit 146464e

Browse files
committed
Introduce SetConvDescriptors to refactor cudnn/conv_forward.cc
1 parent 596333b commit 146464e

3 files changed

Lines changed: 101 additions & 87 deletions

File tree

src/runtime/contrib/cudnn/conv_forward.cc

Lines changed: 4 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -35,94 +35,11 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
3535
const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
3636
DLTensor* y, const std::string& conv_dtype) {
3737
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
38-
// Set Mode
39-
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
40-
// Set Format
41-
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
42-
// Set Algo
43-
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
44-
// Set Device
45-
entry_ptr->conv_entry.device = x->device;
46-
// Set Data Type
47-
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
48-
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
49-
// Dims includes N and C
50-
int full_dims = dims + 2;
51-
52-
std::vector<int> dim(full_dims);
53-
std::vector<int> tensor_stride(full_dims);
54-
55-
// Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
56-
// in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
57-
58-
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
59-
if (dims == 2) {
60-
// Set Desc
61-
CUDNN_CALL(cudnnSetConvolution2dDescriptor(
62-
entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
63-
dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
64-
int ni, ci, hi, wi;
65-
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
66-
ni = 0;
67-
ci = 3;
68-
hi = 1;
69-
wi = 2;
70-
} else {
71-
ni = 0;
72-
ci = 1;
73-
hi = 2;
74-
wi = 3;
75-
}
76-
77-
// Set Filter
78-
CUDNN_CALL(cudnnSetFilter4dDescriptor(
79-
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
80-
static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
81-
static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
82-
// Set Input
83-
CUDNN_CALL(cudnnSetTensor4dDescriptor(
84-
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
85-
static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
86-
static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
87-
// Set Output
88-
CUDNN_CALL(cudnnSetTensor4dDescriptor(
89-
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
90-
static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
91-
static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
92-
} else {
93-
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
94-
dilation, entry_ptr->conv_entry.mode,
95-
entry_ptr->conv_entry.data_type));
96-
97-
// Set Filter
98-
for (int i = 0; i < full_dims; i++) {
99-
dim[i] = static_cast<int>(w->shape[i]);
100-
}
101-
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
102-
entry_ptr->conv_entry.tensor_format, full_dims,
103-
dim.data()));
104-
// Set Input
105-
for (int i = 0; i < full_dims; i++) {
106-
dim[i] = static_cast<int>(x->shape[i]);
107-
}
108-
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
109-
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
110-
dim.data(), tensor_stride.data()));
111-
// Set Output
112-
for (int i = 0; i < full_dims; i++) {
113-
dim[i] = static_cast<int>(y->shape[i]);
114-
}
115-
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
116-
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
117-
dim.data(), tensor_stride.data()));
118-
}
119-
120-
if (cudnnGetVersion() > 7000) {
121-
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
122-
}
38+
SetConvDescriptors(entry_ptr, mode, format, algo, dims, groups, pad, stride, dilation, x, w, y,
39+
conv_dtype);
12340

124-
// Set workspace
125-
size_t workspace_size = 0;
41+
// Set workspace
42+
size_t workspace_size = 0;
12643
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
12744
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
12845
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,

src/runtime/contrib/cudnn/cudnn_utils.cc

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "cudnn_utils.h"
2424

2525
#include <dmlc/thread_local.h>
26+
#include <tvm/runtime/data_type.h>
2627
#include <tvm/runtime/registry.h>
2728

2829
namespace tvm {
@@ -160,6 +161,98 @@ void ConvEntry::CleanWorkspace() {
160161
workspace_size = 0;
161162
}
162163

164+
void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
165+
int groups, const int pad[], const int stride[], const int dilation[],
166+
DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype) {
167+
// Set Mode
168+
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
169+
// Set Format
170+
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
171+
// Set Algo
172+
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
173+
// Set Device
174+
entry_ptr->conv_entry.device = x->device;
175+
// Set Data Type
176+
entry_ptr->conv_entry.data_type =
177+
CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype));
178+
179+
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
180+
// Dims includes N and C
181+
int full_dims = dims + 2;
182+
183+
std::vector<int> dim(full_dims);
184+
std::vector<int> tensor_stride(full_dims);
185+
186+
// Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
187+
// in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
188+
189+
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
190+
if (dims == 2) {
191+
// Set Desc
192+
CUDNN_CALL(cudnnSetConvolution2dDescriptor(
193+
entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
194+
dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
195+
int ni, ci, hi, wi;
196+
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
197+
ni = 0;
198+
ci = 3;
199+
hi = 1;
200+
wi = 2;
201+
} else {
202+
ni = 0;
203+
ci = 1;
204+
hi = 2;
205+
wi = 3;
206+
}
207+
208+
// Set Filter
209+
CUDNN_CALL(cudnnSetFilter4dDescriptor(
210+
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
211+
static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
212+
static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
213+
// Set Input
214+
CUDNN_CALL(cudnnSetTensor4dDescriptor(
215+
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
216+
static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
217+
static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
218+
// Set Output
219+
CUDNN_CALL(cudnnSetTensor4dDescriptor(
220+
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
221+
static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
222+
static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
223+
} else {
224+
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
225+
dilation, entry_ptr->conv_entry.mode,
226+
entry_ptr->conv_entry.data_type));
227+
228+
// Set Filter
229+
for (int i = 0; i < full_dims; i++) {
230+
dim[i] = static_cast<int>(w->shape[i]);
231+
}
232+
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
233+
entry_ptr->conv_entry.tensor_format, full_dims,
234+
dim.data()));
235+
// Set Input
236+
for (int i = 0; i < full_dims; i++) {
237+
dim[i] = static_cast<int>(x->shape[i]);
238+
}
239+
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
240+
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
241+
dim.data(), tensor_stride.data()));
242+
// Set Output
243+
for (int i = 0; i < full_dims; i++) {
244+
dim[i] = static_cast<int>(y->shape[i]);
245+
}
246+
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
247+
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
248+
dim.data(), tensor_stride.data()));
249+
}
250+
251+
if (cudnnGetVersion() > 7000) {
252+
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
253+
}
254+
}
255+
163256
// SoftmaxEntry
164257

165258
SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); }

src/runtime/contrib/cudnn/cudnn_utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ struct CuDNNThreadEntry {
103103
static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
104104
}; // CuDNNThreadEntry
105105

106+
void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int mode, int format, int algo, int dims,
107+
int groups, const int pad[], const int stride[], const int dilation[],
108+
DLTensor* x, DLTensor* w, DLTensor* y, const std::string& conv_dtype);
109+
106110
} // namespace contrib
107111
} // namespace tvm
108112

0 commit comments

Comments
 (0)