@@ -35,94 +35,11 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
3535 const int stride[], const int dilation[], DLTensor* x, DLTensor* w,
3636 DLTensor* y, const std::string& conv_dtype) {
3737 CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal ();
38- // Set Mode
39- entry_ptr->conv_entry .mode = static_cast <cudnnConvolutionMode_t>(mode);
40- // Set Format
41- entry_ptr->conv_entry .tensor_format = static_cast <cudnnTensorFormat_t>(format);
42- // Set Algo
43- entry_ptr->conv_entry .fwd_algo = static_cast <cudnnConvolutionFwdAlgo_t>(algo);
44- // Set Device
45- entry_ptr->conv_entry .device = x->device ;
46- // Set Data Type
47- entry_ptr->conv_entry .data_type = CuDNNDataType::DLTypeToCuDNNType (String2DLDataType (conv_dtype));
48- cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType (x->dtype );
49- // Dims includes N and C
50- int full_dims = dims + 2 ;
51-
52- std::vector<int > dim (full_dims);
53- std::vector<int > tensor_stride (full_dims);
54-
55- // Note: For 2D tensor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
56- // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
57-
58- CUDNN_CALL (cudnnSetConvolutionGroupCount (entry_ptr->conv_entry .conv_desc , groups));
59- if (dims == 2 ) {
60- // Set Desc
61- CUDNN_CALL (cudnnSetConvolution2dDescriptor (
62- entry_ptr->conv_entry .conv_desc , pad[0 ], pad[1 ], stride[0 ], stride[1 ], dilation[0 ],
63- dilation[1 ], entry_ptr->conv_entry .mode , entry_ptr->conv_entry .data_type ));
64- int ni, ci, hi, wi;
65- if (entry_ptr->conv_entry .tensor_format == CUDNN_TENSOR_NHWC) {
66- ni = 0 ;
67- ci = 3 ;
68- hi = 1 ;
69- wi = 2 ;
70- } else {
71- ni = 0 ;
72- ci = 1 ;
73- hi = 2 ;
74- wi = 3 ;
75- }
76-
77- // Set Filter
78- CUDNN_CALL (cudnnSetFilter4dDescriptor (
79- entry_ptr->conv_entry .filter_desc , data_type, entry_ptr->conv_entry .tensor_format ,
80- static_cast <int >(w->shape [ni]), static_cast <int >(w->shape [ci]),
81- static_cast <int >(w->shape [hi]), static_cast <int >(w->shape [wi])));
82- // Set Input
83- CUDNN_CALL (cudnnSetTensor4dDescriptor (
84- entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .tensor_format , data_type,
85- static_cast <int >(x->shape [ni]), static_cast <int >(x->shape [ci]),
86- static_cast <int >(x->shape [hi]), static_cast <int >(x->shape [wi])));
87- // Set Output
88- CUDNN_CALL (cudnnSetTensor4dDescriptor (
89- entry_ptr->conv_entry .output_desc , entry_ptr->conv_entry .tensor_format , data_type,
90- static_cast <int >(y->shape [ni]), static_cast <int >(y->shape [ci]),
91- static_cast <int >(y->shape [hi]), static_cast <int >(y->shape [wi])));
92- } else {
93- CUDNN_CALL (cudnnSetConvolutionNdDescriptor (entry_ptr->conv_entry .conv_desc , dims, pad, stride,
94- dilation, entry_ptr->conv_entry .mode ,
95- entry_ptr->conv_entry .data_type ));
96-
97- // Set Filter
98- for (int i = 0 ; i < full_dims; i++) {
99- dim[i] = static_cast <int >(w->shape [i]);
100- }
101- CUDNN_CALL (cudnnSetFilterNdDescriptor (entry_ptr->conv_entry .filter_desc , data_type,
102- entry_ptr->conv_entry .tensor_format , full_dims,
103- dim.data ()));
104- // Set Input
105- for (int i = 0 ; i < full_dims; i++) {
106- dim[i] = static_cast <int >(x->shape [i]);
107- }
108- GetCudnnStride (full_dims, dim.data (), tensor_stride.data ());
109- CUDNN_CALL (cudnnSetTensorNdDescriptor (entry_ptr->conv_entry .input_desc , data_type, full_dims,
110- dim.data (), tensor_stride.data ()));
111- // Set Output
112- for (int i = 0 ; i < full_dims; i++) {
113- dim[i] = static_cast <int >(y->shape [i]);
114- }
115- GetCudnnStride (full_dims, dim.data (), tensor_stride.data ());
116- CUDNN_CALL (cudnnSetTensorNdDescriptor (entry_ptr->conv_entry .output_desc , data_type, full_dims,
117- dim.data (), tensor_stride.data ()));
118- }
119-
120- if (cudnnGetVersion () > 7000 ) {
121- CUDNN_CALL (cudnnSetConvolutionMathType (entry_ptr->conv_entry .conv_desc , CUDNN_TENSOR_OP_MATH))
122- }
38+ SetConvDescriptors (entry_ptr, mode, format, algo, dims, groups, pad, stride, dilation, x, w, y,
39+ conv_dtype);
12340
124- // Set workspace
125- size_t workspace_size = 0 ;
41+ // Set workspace
42+ size_t workspace_size = 0 ;
12643 CUDNN_CALL (cudnnGetConvolutionForwardWorkspaceSize (
12744 entry_ptr->handle , entry_ptr->conv_entry .input_desc , entry_ptr->conv_entry .filter_desc ,
12845 entry_ptr->conv_entry .conv_desc , entry_ptr->conv_entry .output_desc ,
0 commit comments