2121import tvm
2222from tvm import te
2323from tvm import autotvm
24- from ..utils import get_const_tuple , traverse_inline , simplify
25- from ..nn .pad import pad
26- from ..nn .utils import get_pad_tuple3d
24+ from ..utils import get_const_tuple , traverse_inline
2725from .tensor_intrin import intrin_wmma_load_matrix_A
2826from .tensor_intrin import intrin_wmma_load_matrix_W
2927from .tensor_intrin import intrin_wmma_store_matrix
3028from .tensor_intrin import intrin_wmma_gemm
29+ from ..nn .conv2d import conv
3130
3231
def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
    """Compute declaration for conv3d tensorcore function.

    This is a thin wrapper: the whole computation is delegated to the shared
    ``conv`` helper, requesting the "NDHWC" layout.

    Parameters
    ----------
    cfg :
        Tuning configuration. Not used by this function; it is kept in the
        signature so existing callers do not have to change.
    Input :
        Input tensor, forwarded to ``conv`` unchanged.
    Filter :
        Filter/weight tensor, forwarded to ``conv`` unchanged.
    stride :
        Stride specification, forwarded to ``conv`` unchanged.
    padding :
        Padding specification, forwarded to ``conv`` unchanged.
    dilation :
        Dilation specification, forwarded to ``conv`` unchanged.
    out_dtype :
        Requested output dtype, forwarded to ``conv`` unchanged.

    Returns
    -------
    Whatever ``conv`` returns for these arguments.
    """
    # NOTE(review): no shape or dtype validation happens here; any constraint
    # on the operands must be enforced by `conv` or by the schedule.
    # The literal 1 is presumably the `groups` argument of `conv` (grouped
    # convolution disabled) -- confirm against the signature of `conv`.
    return conv(Input, Filter, stride, padding, dilation, 1, "NDHWC", out_dtype)
10235
10336
10437def schedule_ndhwc_tensorcore_cuda (cfg , s , Conv ):
@@ -109,12 +42,9 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
10942 in_dtype = trans_paddata .dtype
11043 batch , _ , _ , _ , _ = get_const_tuple (Conv .shape )
11144 _ , _ , _ , _ , out_channels = get_const_tuple (kernel .shape )
112- paddata = s [trans_paddata ].op .input_tensors
11345
114- # inline the pad and dtype transform
46+ # inline the pad
11547 s [trans_paddata ].compute_inline ()
116- s [kernel ].compute_inline ()
117- s [paddata [0 ]].compute_inline ()
11848
11949 # Designate the memory hierarchy
12050 AS = s .cache_read (trans_paddata , "shared" , [Conv ])
@@ -172,6 +102,8 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
172102 wmma_n = 32
173103 elif wmma_m == 32 :
174104 wmma_n = 8
105+ else :
106+ raise RuntimeError ("Invalid wmma size" )
175107
176108 warp_size = 32
177109
@@ -335,8 +267,9 @@ def get_strides(extents):
335267
336268
@autotvm.register_topi_compute("conv3d_ndhwc_tensorcore.cuda")
def conv3d_ndhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, groups, out_dtype):
    """Compute conv3d with tensorcore for NDHWC layout.

    Registered as the "conv3d_ndhwc_tensorcore.cuda" compute and implemented
    by forwarding to ``ndhwc_tensorcore_cuda``.

    Parameters
    ----------
    cfg :
        Tuning configuration, passed through to ``ndhwc_tensorcore_cuda``.
    data :
        Input tensor, passed through as the convolution input.
    kernel :
        Filter tensor, passed through as the convolution filter.
    strides :
        Stride specification, passed through unchanged.
    padding :
        Padding specification, passed through unchanged.
    dilation :
        Dilation specification, passed through unchanged.
    groups :
        Must be equal to 1; it is only validated here and is not forwarded.
    out_dtype :
        Requested output dtype, passed through unchanged.

    Returns
    -------
    The result of ``ndhwc_tensorcore_cuda`` for the given arguments.

    Raises
    ------
    AssertionError
        If ``groups`` is not equal to 1.
    """
    # Grouped convolution is rejected up front: `groups` is not an argument of
    # ndhwc_tensorcore_cuda, so only the groups == 1 case can be expressed.
    assert groups == 1, "tensorcore conv3d does not support groups"
    return ndhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype)
341274
342275
@@ -346,7 +279,7 @@ def schedule_conv3d_ndhwc_tensorcore(cfg, outs):
346279 s = te .create_schedule ([x .op for x in outs ])
347280
348281 def _callback (op ):
349- if "conv3d_ndhwc_tensorcore " in op .tag :
282+ if "conv3d_ndhwc " in op .tag :
350283 schedule_ndhwc_tensorcore_cuda (cfg , s , op .output (0 ))
351284
352285 traverse_inline (s , outs [0 ].op , _callback )
0 commit comments