Skip to content

Commit 47f35be

Browse files
committed
add dgrad stub
1 parent ebed032 commit 47f35be

1 file changed

Lines changed: 149 additions & 0 deletions

File tree

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file Use external cudnn utils function
22+
*/
23+
#include <tvm/runtime/data_type.h>
24+
#include <tvm/runtime/device_api.h>
25+
#include <tvm/runtime/registry.h>
26+
27+
#include "cudnn_utils.h"
28+
29+
namespace tvm {
30+
namespace contrib {
31+
32+
using namespace runtime;
33+
34+
void ConvolutionBackwardData(int mode, int format, int algo, int dims, int groups, const int pad[],
35+
const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
36+
DLTensor* y, const std::string& conv_dtype) {
37+
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
38+
// Set Mode
39+
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
40+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
41+
y->shape, x->dtype, conv_dtype);
42+
// Set Device
43+
entry_ptr->conv_entry.device = x->device;
44+
// Set Algo
45+
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
46+
47+
// Set workspace
48+
size_t workspace_size = 0;
49+
CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
50+
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
51+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
52+
entry_ptr->conv_entry.fwd_algo, &workspace_size));
53+
entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
54+
CUDNN_CALL(cudnnConvolutionForward(
55+
entry_ptr->handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
56+
entry_ptr->conv_entry.input_desc, x->data, entry_ptr->conv_entry.filter_desc, w->data,
57+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
58+
entry_ptr->conv_entry.workspace, workspace_size,
59+
CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
60+
entry_ptr->conv_entry.output_desc, y->data));
61+
}
62+
63+
void BackwardDataFindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
64+
const int dilation[], const int x_dim[], const int w_dim[],
65+
const int y_dim[], const std::string& data_dtype,
66+
const std::string& conv_dtype, TVMRetValue* ret) {
67+
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
68+
const int full_dims = dims + 2;
69+
std::vector<int64_t> x_dim_int64(full_dims);
70+
std::vector<int64_t> w_dim_int64(full_dims);
71+
std::vector<int64_t> y_dim_int64(full_dims);
72+
for (int i = 0; i < full_dims; ++i) {
73+
x_dim_int64[i] = x_dim[i];
74+
w_dim_int64[i] = w_dim[i];
75+
y_dim_int64[i] = y_dim[i];
76+
}
77+
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
78+
w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
79+
conv_dtype);
80+
81+
int returned_algo_count = 0;
82+
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
83+
CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
84+
entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
85+
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
86+
CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));
87+
88+
const std::vector<std::string> fwd_algo_names{"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
89+
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM",
90+
"CUDNN_CONVOLUTION_FWD_ALGO_GEMM",
91+
"CUDNN_CONVOLUTION_FWD_ALGO_DIRECT",
92+
"CUDNN_CONVOLUTION_FWD_ALGO_FFT",
93+
"CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING",
94+
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD",
95+
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};
96+
97+
auto best_algo = perf_results[0].algo;
98+
LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
99+
<< fwd_algo_names[best_algo];
100+
for (int i = 0; i < returned_algo_count; ++i) {
101+
LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
102+
<< " - time: " << perf_results[i].time << " ms"
103+
<< ", Memory: " << perf_results[i].memory;
104+
}
105+
106+
ret[0] = best_algo;
107+
}
108+
109+
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv2d.backward_data")
110+
.set_body([](TVMArgs args, TVMRetValue* ret) {
111+
int mode = args[0];
112+
int format = args[1];
113+
int algo = args[2];
114+
int pad_v[2], stride_v[2], dilation_v[2];
115+
for (int i = 0; i < 2; i++) {
116+
pad_v[i] = args[3 + i];
117+
stride_v[i] = args[5 + i];
118+
dilation_v[i] = args[7 + i];
119+
}
120+
DLTensor* x = args[9];
121+
DLTensor* w = args[10];
122+
DLTensor* y = args[11];
123+
std::string conv_dtype = args[12];
124+
int groups = args[13];
125+
126+
ConvolutionBackwardData(mode, format, algo, 2, groups, pad_v, stride_v, dilation_v, x, w, y,
127+
conv_dtype);
128+
});
129+
130+
TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.backward_data_find_algo")
131+
.set_body([](TVMArgs args, TVMRetValue* ret) {
132+
int format = args[0];
133+
int dims = args[1];
134+
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
135+
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
136+
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
137+
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
138+
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
139+
int* y_dim = static_cast<int*>(static_cast<void*>(args[7]));
140+
std::string data_dtype = args[8];
141+
std::string conv_dtype = args[9];
142+
int groups = args[10];
143+
144+
BackwardDataFindAlgo(format, dims, groups, pad, stride, dilation, x_dim, w_dim, y_dim,
145+
data_dtype, conv_dtype, ret);
146+
});
147+
148+
} // namespace contrib
149+
} // namespace tvm

0 commit comments

Comments
 (0)