Skip to content

Commit e2e423d

Browse files
author
Tristan Konolige
committed
[TOPI] Add support for grouped conv3d
Change conv3d to use the generic conv implementation, which supports grouped convolutions. Also, remove support for non-float16 tensorcore operations, as they cause a large degradation in accuracy. The generic conv now supports the auto-scheduler.
1 parent f6f252f commit e2e423d

File tree

14 files changed

+269
-350
lines changed

14 files changed

+269
-350
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
619619
and stride_w == 1
620620
and dilation_h == 1
621621
and dilation_w == 1
622+
and attrs["groups"] == 1
622623
):
623624
strategy.add_implementation(
624625
wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
@@ -641,7 +642,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
641642
(N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0)
642643
or (N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0)
643644
or (N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0)
644-
):
645+
) and out_type == "float16":
645646
strategy.add_implementation(
646647
wrap_compute_conv3d(topi.cuda.conv3d_ndhwc_tensorcore),
647648
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc_tensorcore),

python/tvm/relay/op/strategy/generic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,8 @@ def _compute_conv3d(attrs, inputs, out_type):
543543
(dilation_d, dilation_h, dilation_w) = dilation
544544
if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
545545
raise ValueError("Dilation should be positive value")
546-
if groups != 1:
547-
raise ValueError("Not support arbitrary group number for conv3d")
548546

549-
args = [inputs[0], inputs[1], strides, padding, dilation]
547+
args = [inputs[0], inputs[1], strides, padding, dilation, groups]
550548
if need_layout:
551549
args.append(layout)
552550
args.append(out_dtype)

python/tvm/topi/cuda/conv3d.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
@autotvm.register_topi_compute("conv3d_ncdhw.cuda")
29-
def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
29+
def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
3030
"""Conv3D operator in NCDHW layout for cuda backend.
3131
3232
Parameters
@@ -49,6 +49,9 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype="float
4949
dilation: int or a list/tuple of three ints
5050
dilation size, or [dilation_depth, dilation_height, dilation_width]
5151
52+
groups: int
53+
Number of groups
54+
5255
out_dtype: str
5356
The output type. This is used for mixed precision.
5457
@@ -57,7 +60,7 @@ def conv3d_ncdhw(cfg, data, kernel, strides, padding, dilation, out_dtype="float
5760
output : tvm.te.Tensor
5861
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
5962
"""
60-
return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
63+
return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, groups, out_dtype)
6164

6265

6366
@autotvm.register_topi_schedule("conv3d_ncdhw.cuda")
@@ -82,15 +85,15 @@ def schedule_conv3d_ncdhw(cfg, outs):
8285
s = te.create_schedule([x.op for x in outs])
8386

8487
def _callback(op):
85-
if op.tag == "conv3d_ncdhw":
88+
if "conv3d_ncdhw" in op.tag:
8689
schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NCDHW", "conv3d_ncdhw.cuda")
8790

8891
traverse_inline(s, outs[0].op, _callback)
8992
return s
9093

9194

9295
@autotvm.register_topi_compute("conv3d_ndhwc.cuda")
93-
def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float32"):
96+
def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, groups, out_dtype="float32"):
9497
"""Conv3d operator in NDHWC layout for cuda backend.
9598
9699
Parameters
@@ -110,12 +113,15 @@ def conv3d_ndhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float
110113
dilation: int or a list/tuple of three ints
111114
dilation size, or [dilation_depth, dilation_height, dilation_width]
112115
116+
groups: int
117+
Number of groups
118+
113119
Returns
114120
-------
115121
Output : tvm.te.Tensor
116122
5-D with shape [batch, out_depth, out_height, out_width, out_channel]
117123
"""
118-
return nn.conv3d_ndhwc(data, kernel, strides, padding, dilation, out_dtype)
124+
return nn.conv3d_ndhwc(data, kernel, strides, padding, dilation, groups, out_dtype)
119125

120126

121127
@autotvm.register_topi_schedule("conv3d_ndhwc.cuda")
@@ -140,7 +146,7 @@ def schedule_conv3d_ndhwc(cfg, outs):
140146
s = te.create_schedule([x.op for x in outs])
141147

142148
def _callback(op):
143-
if op.tag == "conv3d_ndhwc":
149+
if "conv3d_ndhwc" in op.tag:
144150
schedule_direct_conv3d_cuda(cfg, s, op.output(0), "NDHWC", "conv3d_ndhwc.cuda")
145151

146152
traverse_inline(s, outs[0].op, _callback)
@@ -149,7 +155,7 @@ def _callback(op):
149155

150156
@autotvm.register_topi_compute("conv3d_cudnn.cuda")
151157
def conv3d_cudnn(
152-
cfg, data, kernel, strides, padding, dilation, layout="NCDHW", out_dtype="float32"
158+
cfg, data, kernel, strides, padding, dilation, groups, layout="NCDHW", out_dtype="float32"
153159
):
154160
"""Conv3D operator for cuda backend.
155161
@@ -194,6 +200,8 @@ def conv3d_cudnn(
194200
raise ValueError("Unsupported layout %s in cudnn" % layout)
195201
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
196202

203+
assert groups == 1, "conv3d_cudnn does not support groups"
204+
197205
# handle dilation
198206
stride_d, stride_h, stride_w = (
199207
(strides, strides, strides) if isinstance(strides, int) else strides

python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py

Lines changed: 9 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -21,84 +21,17 @@
2121
import tvm
2222
from tvm import te
2323
from tvm import autotvm
24-
from ..utils import get_const_tuple, traverse_inline, simplify
25-
from ..nn.pad import pad
26-
from ..nn.utils import get_pad_tuple3d
24+
from ..utils import get_const_tuple, traverse_inline
2725
from .tensor_intrin import intrin_wmma_load_matrix_A
2826
from .tensor_intrin import intrin_wmma_load_matrix_W
2927
from .tensor_intrin import intrin_wmma_store_matrix
3028
from .tensor_intrin import intrin_wmma_gemm
29+
from ..nn.conv2d import conv
3130

3231

3332
def ndhwc_tensorcore_cuda(cfg, Input, Filter, stride, padding, dilation, out_dtype):
3433
"""Compute declaration for conv3d tensorcore function"""
35-
assert isinstance(stride, int) or len(stride) == 3
36-
assert isinstance(dilation, int) or len(dilation) == 3
37-
38-
if isinstance(stride, int):
39-
stride_d = stride_h = stride_w = stride
40-
else:
41-
stride_d, stride_h, stride_w = stride
42-
43-
if isinstance(dilation, int):
44-
dilation_d = dilation_h = dilation_w = dilation
45-
else:
46-
dilation_d, dilation_h, dilation_w = dilation
47-
48-
batch, in_depth, in_height, in_width, in_channel = get_const_tuple(Input.shape)
49-
kernel_d, kernel_h, kernel_w, _, num_filter = get_const_tuple(Filter.shape)
50-
assert (
51-
(batch % 16 == 0 and in_channel % 16 == 0 and num_filter % 16 == 0)
52-
or (batch % 8 == 0 and in_channel % 16 == 0 and num_filter % 32 == 0)
53-
or (batch % 32 == 0 and in_channel % 16 == 0 and num_filter % 8 == 0)
54-
), (
55-
"The shape of (batch, in_channel, num_filter) "
56-
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
57-
)
58-
59-
# compute the output shape
60-
dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
61-
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
62-
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
63-
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
64-
padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w)
65-
)
66-
out_channel = num_filter
67-
out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
68-
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
69-
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
70-
pad_before = [0, pad_front, pad_top, pad_left, 0]
71-
pad_after = [0, pad_back, pad_down, pad_right, 0]
72-
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
73-
rc = te.reduce_axis((0, in_channel), name="rc")
74-
rz = te.reduce_axis((0, kernel_d), name="rz")
75-
ry = te.reduce_axis((0, kernel_h), name="ry")
76-
rx = te.reduce_axis((0, kernel_w), name="rx")
77-
# convert data type of input feature maps and weights
78-
# TODO: add checking here, datatype casting may cause precision loss
79-
TransPaddedInput = te.compute(
80-
PaddedInput.shape, lambda n, d, h, w, c: PaddedInput[n, d, h, w, c].astype("float16")
81-
)
82-
TransFilter = te.compute(
83-
Filter.shape, lambda d, h, w, i, o: Filter[d, h, w, i, o].astype("float16")
84-
)
85-
Output = te.compute(
86-
(batch, out_depth, out_height, out_width, out_channel),
87-
lambda nn, zz, yy, xx, ff: te.sum(
88-
TransPaddedInput[
89-
nn,
90-
zz * stride_d + rz * dilation_d,
91-
yy * stride_h + ry * dilation_h,
92-
xx * stride_w + rx * dilation_w,
93-
rc,
94-
].astype(out_dtype)
95-
* TransFilter[rz, ry, rx, rc, ff].astype(out_dtype),
96-
axis=[rz, ry, rx, rc],
97-
),
98-
name="Conv3dOutput",
99-
tag="conv3d_ndhwc_tensorcore",
100-
)
101-
return Output
34+
return conv(Input, Filter, stride, padding, dilation, 1, "NDHWC", out_dtype)
10235

10336

10437
def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
@@ -109,12 +42,9 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
10942
in_dtype = trans_paddata.dtype
11043
batch, _, _, _, _ = get_const_tuple(Conv.shape)
11144
_, _, _, _, out_channels = get_const_tuple(kernel.shape)
112-
paddata = s[trans_paddata].op.input_tensors
11345

114-
# inline the pad and dtype transform
46+
# inline the pad
11547
s[trans_paddata].compute_inline()
116-
s[kernel].compute_inline()
117-
s[paddata[0]].compute_inline()
11848

11949
# Designate the memory hierarchy
12050
AS = s.cache_read(trans_paddata, "shared", [Conv])
@@ -172,6 +102,8 @@ def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
172102
wmma_n = 32
173103
elif wmma_m == 32:
174104
wmma_n = 8
105+
else:
106+
raise RuntimeError("Invalid wmma size")
175107

176108
warp_size = 32
177109

@@ -335,8 +267,9 @@ def get_strides(extents):
335267

336268

337269
@autotvm.register_topi_compute("conv3d_ndhwc_tensorcore.cuda")
338-
def conv3d_ndhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, out_dtype):
270+
def conv3d_ndhwc_tensorcore(cfg, data, kernel, strides, padding, dilation, groups, out_dtype):
339271
"""Compute conv3d with tensorcore for NDHWC layout"""
272+
assert groups == 1, "tensorcore conv3d does not support groups"
340273
return ndhwc_tensorcore_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype)
341274

342275

@@ -346,7 +279,7 @@ def schedule_conv3d_ndhwc_tensorcore(cfg, outs):
346279
s = te.create_schedule([x.op for x in outs])
347280

348281
def _callback(op):
349-
if "conv3d_ndhwc_tensorcore" in op.tag:
282+
if "conv3d_ndhwc" in op.tag:
350283
schedule_ndhwc_tensorcore_cuda(cfg, s, op.output(0))
351284

352285
traverse_inline(s, outs[0].op, _callback)

0 commit comments

Comments (0)