
Commit 03b4444

masahi authored and committed
[CUDNN] Support gradient kernels (apache#9986)
* Dgrad nchw, nhwc, fp16 working

  Squashed commits (all authored by Masahiro Masuda <masahi129@gmail.com>):
  - 426e5dc (Tue Jan 18 11:48:53 2022 +0900) black
  - 211a58b (Tue Jan 18 11:43:52 2022 +0900) fp16 also works
  - c2a34d4 (Tue Jan 18 11:36:36 2022 +0900) nhwc test also worked
  - c0609ab (Tue Jan 18 11:21:23 2022 +0900) nchw test worked
  - 2bf68c7 (Tue Jan 18 10:41:35 2022 +0900) add test stub
  - c86b128 (Tue Jan 18 10:32:09 2022 +0900) add python definition stub
  - 3166952 (Tue Jan 18 06:57:18 2022 +0900) bwd filter compiled
  - e311ba3 (Tue Jan 18 06:27:55 2022 +0900) dgrad compiled
  - 47f35be (Tue Jan 18 06:16:43 2022 +0900) add dgrad stub
  - ebed032 (Mon Jan 17 17:01:56 2022 +0900) cpplint
  - 834f54a (Mon Jan 17 16:55:58 2022 +0900) remove cudnn get output
  - dcbd9c9 (Mon Jan 17 16:28:07 2022 +0900) more refactor
  - 146464e (Mon Jan 17 15:57:35 2022 +0900) Introduce SetConvdescriptors to refactor cudnn/conv_forward.cc

* add python function for cudnn wgrad
* adding wgrad test
* black
* wgrad nchw and nhwc worked
* remove bwd algo name stuff
* compute output shape properly
* swap arg order in wgrad
* add kernel size arg in test
* black
* cleanup
* more fix
* fix dgrad test
* support running relay conv2d_backward_weight directly with cudnn
* black
* refactor reference function to support nhwc
* removed unused function
* lint
* enable offloading conv2d_transpose to cudnn dgrad
* relax tol
* name fix, remove print
1 parent 398fd17 commit 03b4444

17 files changed

Lines changed: 996 additions & 74 deletions


python/tvm/contrib/cudnn.py

Lines changed: 416 additions & 44 deletions
Large diffs are not rendered by default.
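The new wrappers this file gains, cudnn.conv_backward_filter (wgrad) and cudnn.conv_backward_data (dgrad), are not rendered above, but their call signatures can be read off the TOPI call sites further down. A minimal sketch of a wgrad call, assuming a cuDNN-enabled build of TVM with a visible GPU, and hypothetical shapes:

import tvm
from tvm import te
from tvm.contrib import cudnn

# Hypothetical NCHW shapes: x is the conv input, dy the gradient of the conv
# output. For a 14x14 input, 3x3 kernel, stride 1, padding 1, the output is
# also 14x14.
dy = te.placeholder((8, 32, 14, 14), name="dy", dtype="float32")
x = te.placeholder((8, 16, 14, 14), name="x", dtype="float32")

# Argument order as used by topi.cuda.conv2d_backward_weight_cudnn below:
# dy, x, kernel_size, padding, stride, dilation, then keyword arguments.
dw = cudnn.conv_backward_filter(
    dy,
    x,
    (3, 3),            # kernel_size
    (1, 1),            # padding
    (1, 1),            # stride
    (1, 1),            # dilation
    conv_mode=1,       # 1 = cross-correlation in cuDNN
    tensor_format=0,   # 0 = NCHW, 1 = NHWC
    conv_dtype="float32",
    groups=1,
)
# dw is a te.Tensor of shape (32, 16, 3, 3), backed by a cuDNN extern call.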

python/tvm/relay/op/nn/_nn.py

Lines changed: 4 additions & 0 deletions
@@ -1062,6 +1062,10 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
 reg.register_injective_schedule("nn.batch_to_space_nd")
 
 
+reg.register_strategy("nn.conv2d_backward_weight", strategy.conv2d_backward_weight_strategy)
+reg.register_pattern("nn.conv2d_backward_weight", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 @reg.register_legalize("nn.conv2d_backward_weight")
 def legalize_conv2d_backward_weight(attrs, inputs, types):
     """Legalize conv2d_backward_weight op.

python/tvm/relay/op/strategy/cuda.py

Lines changed: 28 additions & 0 deletions
@@ -564,6 +564,25 @@ def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@conv2d_backward_weight_strategy.register(["cuda"])
+def conv2d_backward_weight_strategy_cuda(attrs, inputs, out_type, target):
+    """conv2d_backward_weight cuda strategy"""
+    strategy = _op.OpStrategy()
+    if target.kind.name == "cuda" and "cudnn" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_conv2d_backward_weight(topi.cuda.conv2d_backward_weight_cudnn),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="conv2d_backward_weight_strategy.cudnn",
+            plevel=15,
+        )
+    else:
+        raise RuntimeError(
+            "conv2d_backward_weight on cuda is currently only supported with cudnn. "
+            "Please run Legalize pass to decompose this op into supported ops."
+        )
+    return strategy
+
+
 @conv2d_transpose_strategy.register(["cuda", "gpu"])
 def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d_transpose cuda strategy"""
@@ -579,6 +598,15 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
         name="conv2d_transpose_nchw.cuda",
     )
+
+    if target.kind.name == "cuda" and "cudnn" in target.libs and attrs.kernel_layout == "IOHW":
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_cudnn),
+            wrap_topi_schedule(topi.generic.schedule_extern),
+            name="conv2d_transpose.cudnn.cuda",
+            plevel=25,
+        )
+    # TODO(masahi): Support conv2d_transpose NHWC.
     return strategy
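Both additions are gated on a cuDNN-enabled CUDA target, and the plevel values matter because TVM picks the implementation with the highest priority level. A minimal check using standard TVM target syntax:

import tvm

# "-libs=cudnn" is what makes `"cudnn" in target.libs` true in the strategy
# functions above; plevel=25 then outranks the default
# conv2d_transpose_nchw.cuda schedule during implementation selection.
target = tvm.target.Target("cuda -libs=cudnn")
assert target.kind.name == "cuda" and "cudnn" in target.libs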

python/tvm/relay/op/strategy/generic.py

Lines changed: 38 additions & 0 deletions
@@ -1849,3 +1849,41 @@ def einsum_strategy(attrs, inputs, out_type, target):
         name="einsum.generic",
     )
     return strategy
+
+
+# conv2d_backward_weight
+def wrap_compute_conv2d_backward_weight(topi_compute):
+    """wrap conv2d_backward_weight topi compute"""
+
+    def _compute_conv2d_backward_weight(attrs, inputs, out_dtype):
+        kernel_size = get_const_tuple(attrs.kernel_size)
+        padding = get_const_tuple(attrs.padding)
+        strides = get_const_tuple(attrs.strides)
+        dilation = get_const_tuple(attrs.dilation)
+        groups = attrs.groups
+        out_dtype = attrs.out_dtype
+        layout = attrs.data_layout
+        out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+        out = topi_compute(
+            inputs[0],
+            inputs[1],
+            kernel_size,
+            padding,
+            strides,
+            dilation,
+            groups,
+            layout,
+            out_dtype,
+        )
+        return [out]
+
+    return _compute_conv2d_backward_weight
+
+
+@override_native_generic_func("conv2d_backward_weight_strategy")
+def conv2d_backward_weight_strategy(attrs, inputs, out_type, target):
+    """wgrad generic strategy"""
+    raise RuntimeError(
+        "conv2d_backward_weight is currently only supported with cudnn. "
+        "Please run Legalize pass to decompose this op into supported ops."
+    )
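Both error messages point to the same escape hatch: on targets without cuDNN, relay.transform.Legalize invokes the legalization registered in _nn.py above and decomposes the op into plain ops. A minimal sketch, assuming `mod` is an IRModule that contains nn.conv2d_backward_weight:

import tvm
from tvm import relay

# Hypothetical `mod`: an IRModule containing nn.conv2d_backward_weight.
with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.Legalize()(mod)
# After this pass the op is rewritten into ops every backend can compile.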

python/tvm/topi/cuda/conv2d.py

Lines changed: 19 additions & 0 deletions
@@ -123,3 +123,22 @@ def conv2d_cudnn(
 def schedule_conv2d_cudnn(cfg, outs):
     """Create the schedule for conv2d_cudnn"""
     return generic.schedule_extern(outs)
+
+
+def conv2d_backward_weight_cudnn(
+    dy, x, kernel_size, padding, stride, dilation, groups, layout, output_dtype
+):
+    """Compute conv2d wgrad using CuDNN library"""
+    assert layout in ["NCHW", "NHWC"]
+    return cudnn.conv_backward_filter(
+        dy,
+        x,
+        kernel_size,
+        padding,
+        stride,
+        dilation,
+        conv_mode=1,
+        tensor_format=0 if layout == "NCHW" else 1,
+        conv_dtype=output_dtype,
+        groups=groups,
+    )
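A short usage sketch with hypothetical NHWC placeholders; the argument order matches the wrapper in generic.py above, and the NHWC layout selects tensor_format=1 (again assuming a cuDNN-enabled build with a visible GPU):

from tvm import te, topi

# Hypothetical NHWC shapes: dy is the output gradient, x the input.
dy = te.placeholder((8, 14, 14, 32), name="dy", dtype="float32")
x = te.placeholder((8, 14, 14, 16), name="x", dtype="float32")
dw = topi.cuda.conv2d_backward_weight_cudnn(
    dy, x, (3, 3), (1, 1), (1, 1), (1, 1), 1, "NHWC", "float32"
)
# Per the reference docstring below, the NHWC weight gradient is OHWI,
# so dw has shape (32, 3, 3, 16).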

python/tvm/topi/cuda/conv2d_transpose_nchw.py

Lines changed: 8 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import tvm
 from tvm import te
+from tvm.contrib import cudnn
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
 from .. import nn
@@ -286,3 +287,10 @@ def _callback(op):
     traverse_inline(s, outs[0].op, _callback)
 
     return s
+
+
+def conv2d_transpose_cudnn(x, w, stride, padding, out_dtype, output_padding=(0, 0)):
+    """Compute conv2d_tranpose using cudnn dgrad kernel"""
+    return cudnn.conv_backward_data(
+        x, w, padding, stride, (1, 1), 1, 0, out_dtype, groups=1, output_padding=output_padding
+    )
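cuDNN's backward-data (dgrad) kernel computes exactly a transposed convolution, which is why the hard-coded positional arguments above are dilation (1, 1), conv_mode 1 (cross-correlation), and tensor_format 0 (NCHW). A usage sketch with hypothetical shapes; note the IOHW weight layout that the strategy in cuda.py requires:

from tvm import te, topi

# Hypothetical shapes: x is NCHW; w is IOHW (32 input channels, 16 output),
# i.e. the OIHW filter of the corresponding forward convolution.
x = te.placeholder((8, 32, 14, 14), name="x", dtype="float32")
w = te.placeholder((32, 16, 4, 4), name="w", dtype="float32")
y = topi.cuda.conv2d_transpose_cudnn(x, w, (2, 2), (1, 1), "float32")
# Spatial dims follow (in - 1) * stride - 2 * pad + kernel + output_padding:
# (14 - 1) * 2 - 2 * 1 + 4 + 0 = 28, so y is (8, 16, 28, 28).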

python/tvm/topi/nn/conv2d_transpose.py

Lines changed: 0 additions & 1 deletion
@@ -298,7 +298,6 @@ def conv2d_transpose_legalize(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-
     data, kernel = inputs
     kernel_layout = attrs["kernel_layout"]
     if attrs["data_layout"] == "NHWC":

python/tvm/topi/testing/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -75,4 +75,4 @@
 from .nll_loss import nll_loss
 from .dense import dense
 from .searchsorted import searchsorted_ref
-from .conv2d_backcward_weight_python import conv2d_backward_weight_nchw_python
+from .conv2d_backcward_weight_python import conv2d_backward_weight_python

python/tvm/topi/testing/conv2d_backcward_weight_python.py

Lines changed: 43 additions & 1 deletion
@@ -42,7 +42,7 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
 
     Returns
     -------
-    b_np : np.ndarray
+    dw_np : np.ndarray
        4-D with shape [num_filter, in_channel, filter_height, filter_width]
 
     """
@@ -74,3 +74,45 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
                     dw[k, c, r, s] = acc
 
     return dw
+
+
+def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
+    """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
+
+    Parameters
+    ----------
+    dy_np : numpy.ndarray
+        4-D with shape [batch, in_channel, out_height, out_width] for NCHW layout
+
+    x_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width] for NCHW layout
+
+    kernel_size : tuple of two ints
+        Height and width of the weight
+
+    stride : tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : tuple of two ints
+        Spatial padding, or [pad_h, pad_w]
+
+    layout: string
+        Layout of dy_np and x_np
+
+    Returns
+    -------
+    dw_np : np.ndarray
+        Tensor of shape [num_filter, in_channel, filter_height, filter_width] for NCHW layout,
+        [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
+    """
+    if layout == "NCHW":
+        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
+
+    dw_np_oihw = conv2d_backward_weight_nchw_python(
+        np.transpose(dy_np, [0, 3, 1, 2]),
+        np.transpose(x_np, [0, 3, 1, 2]),
+        kernel_size,
+        stride,
+        padding,
+    )
+    return np.transpose(dw_np_oihw, [0, 2, 3, 1])
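A short usage sketch of the reference function with hypothetical shapes; a 3x3 kernel with stride 1 and padding 1 keeps the 14x14 spatial dims, so the argument order (dy, x, kernel_size, stride, padding, layout) produces an OIHW weight gradient:

import numpy as np
from tvm.topi.testing import conv2d_backward_weight_python

# Hypothetical NCHW shapes: batch 2, 4 input channels, 8 output channels.
dy_np = np.random.randn(2, 8, 14, 14).astype("float32")
x_np = np.random.randn(2, 4, 14, 14).astype("float32")
dw_np = conv2d_backward_weight_python(dy_np, x_np, (3, 3), (1, 1), (1, 1), layout="NCHW")
assert dw_np.shape == (8, 4, 3, 3)  # [num_filter, in_channel, kh, kw]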

python/tvm/topi/testing/conv2d_transpose_python.py

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
             dilated_a_np.shape[2] + bpad_top + bpad_bottom,
             dilated_a_np.shape[3] + bpad_left + bpad_right,
         )
-    )
+    ).astype(a_np.dtype)
     padded_a_np[
         :,
         :,
@@ -83,7 +83,7 @@ def _conv2d_transpose_nchw_python(a_np, w_np, stride, padding, output_padding):
     # convolution stage
     out_h = (in_h - 1) * stride_h - fpad_top - fpad_bottom + filter_h + opad_h
     out_w = (in_w - 1) * stride_w - fpad_left - fpad_right + filter_w + opad_w
-    b_np = np.zeros((batch, out_c, out_h, out_w))
+    b_np = np.zeros((batch, out_c, out_h, out_w)).astype(a_np.dtype)
     for n in range(batch):
         for f in range(out_c):
             for c in range(in_c):
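The two .astype(a_np.dtype) additions matter for the new fp16 tests: np.zeros defaults to float64, so without them the reference output silently promotes fp16 inputs and its dtype no longer matches the kernel under test. A tiny illustration of the NumPy behavior:

import numpy as np

a_np = np.random.randn(1, 4, 8, 8).astype("float16")
print(np.zeros((1, 4, 8, 8)).dtype)                     # float64, NumPy default
print(np.zeros((1, 4, 8, 8)).astype(a_np.dtype).dtype)  # float16, matches input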
