Commit a978e30

merrymercy authored and tmoreau89 committed

Add schedule and test for group convolution (apache#5)

* group conv pass all
* pass mobilenet
1 parent 243403d commit a978e30

9 files changed, 578 additions & 23 deletions

python/tvm/contrib/util.py

Lines changed: 1 addition & 0 deletions

@@ -143,6 +143,7 @@ def which(exec_name):
             return full_path
     return None

+
 def get_lower_ir(s):
     """Get lower ir code of a schedule.
     This is useful for debug, since you don't have to find all inputs/outputs
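For context, `get_lower_ir` lowers a schedule to readable IR without requiring the caller to enumerate its inputs and outputs. A minimal sketch of using it for debugging, written against the pre-0.5 TVM API this tree uses (the toy vector-add compute is illustrative, not from the commit):

```python
import tvm
from tvm.contrib.util import get_lower_ir

# A trivial elementwise compute, just to have something to lower.
n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
s = tvm.create_schedule(B.op)

# get_lower_ir discovers the schedule's inputs/outputs itself.
print(get_lower_ir(s))
```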

topi/python/topi/testing/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
 from .conv2d_nchw_python import conv2d_nchw_python
 from .conv2d_nhwc_python import conv2d_nhwc_python
 from .conv2d_transpose_nchw_python import conv2d_transpose_nchw_python
+from .group_conv2d import group_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
 from .softmax_python import softmax_python, log_softmax_python
topi/python/topi/testing/group_conv2d.py

Lines changed: 74 additions & 0 deletions

@@ -0,0 +1,74 @@
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-branches
+"""Group convolution in python"""
+import numpy as np
+import scipy.signal
+
+
+def group_conv2d_nchw_python(a_np, w_np, stride, padding, groups):
+    """Group convolution operator in NCHW layout.
+
+    Parameters
+    ----------
+    a_np : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    w_np : numpy.ndarray
+        4-D with shape [num_filter, in_channel // groups, filter_height, filter_width]
+
+    stride : int or a list/tuple of two ints
+        Stride size, or [stride_height, stride_width]
+
+    padding : int or str or a list/tuple of two ints
+        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+
+    groups : int
+
+    Returns
+    -------
+    b_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    batch, in_channel, in_height, in_width = a_np.shape
+    num_filter, ci_g, kernel_h, kernel_w = w_np.shape
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    if isinstance(padding, int):
+        pad_h = pad_w = padding * 2
+    elif isinstance(padding, (list, tuple)):
+        pad_h, pad_w = padding[0] * 2, padding[1] * 2
+    else:
+        pad_h = 0 if padding == 'VALID' else kernel_h - 1
+        pad_w = 0 if padding == 'VALID' else kernel_w - 1
+    pad_top = int(np.ceil(float(pad_h) / 2))
+    pad_bottom = pad_h - pad_top
+    pad_left = int(np.ceil(float(pad_w) / 2))
+    pad_right = pad_w - pad_left
+    # compute the output shape
+    out_channel = num_filter
+    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
+    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
+    b_np = np.zeros((batch, out_channel, out_height, out_width))
+
+    assert ci_g * groups == in_channel
+
+    # group computation
+    for n in range(batch):
+        for f in range(out_channel):
+            for c in range(ci_g):
+                base = f // (out_channel // groups) * ci_g
+                if pad_h > 0 or pad_w > 0:
+                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
+                    if pad_h == 0:
+                        apad[:, pad_left:-pad_right] = a_np[n, base + c]
+                    elif pad_w == 0:
+                        apad[pad_top:-pad_bottom, :] = a_np[n, base + c]
+                    else:
+                        apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, base + c]
+                else:
+                    apad = a_np[n, base + c]
+                out = scipy.signal.convolve2d(
+                    apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
+                b_np[n, f] += out[::stride_h, ::stride_w]
+    return b_np
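The loop nest above splits the output channels into `groups` blocks, and each filter convolves only its own `in_channel // groups` slice of the input; the double `np.rot90` flips the kernel so `scipy.signal.convolve2d` computes correlation, matching the other `topi.testing` reference convolutions. A small self-contained check of the new helper (the shapes are arbitrary test values, not taken from the commit):

```python
import numpy as np
from topi.testing import group_conv2d_nchw_python

batch, in_channel, size = 1, 8, 16
num_filter, groups, kernel = 8, 4, 3

a_np = np.random.uniform(size=(batch, in_channel, size, size)).astype("float32")
# Each filter sees only in_channel // groups input channels.
w_np = np.random.uniform(
    size=(num_filter, in_channel // groups, kernel, kernel)).astype("float32")

b_np = group_conv2d_nchw_python(a_np, w_np, stride=1, padding=1, groups=groups)
# padding=1 with a 3x3 kernel and stride 1 preserves the spatial size.
assert b_np.shape == (batch, num_filter, size, size)
```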

vta/config/vta_config.json

Lines changed: 4 additions & 4 deletions

@@ -8,14 +8,14 @@
   "GEMM_II" : 1,
   "TALU_II" : 2,
   "LOG_INP_WIDTH" : 3,
-  "LOG_WGT_WIDTH" : 1,
+  "LOG_WGT_WIDTH" : 3,
   "LOG_ACC_WIDTH" : 5,
   "LOG_OUT_WIDTH" : 3,
   "LOG_BATCH" : 0,
-  "LOG_BLOCK_IN" : 5,
-  "LOG_BLOCK_OUT" : 5,
+  "LOG_BLOCK_IN" : 4,
+  "LOG_BLOCK_OUT" : 4,
   "LOG_UOP_BUFF_SIZE" : 15,
-  "LOG_INP_BUFF_SIZE" : 16,
+  "LOG_INP_BUFF_SIZE" : 15,
   "LOG_WGT_BUFF_SIZE" : 18,
   "LOG_ACC_BUFF_SIZE" : 17
 }
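All of these knobs are log2 values, so this edit moves VTA from 2-bit weights on a 32x32 GEMM core to 8-bit weights on a 16x16 core, and halves the input buffer from 64 KB to 32 KB. A quick sanity check of the derived sizes (a sketch; it assumes it is run from the repo root):

```python
import json

with open("vta/config/vta_config.json") as f:
    cfg = json.load(f)

# Every LOG_* field is a log2; shift to recover the real value.
print("weight bits :", 1 << cfg["LOG_WGT_WIDTH"])                      # 8
print("GEMM tile   :", 1 << cfg["LOG_BLOCK_IN"], "x",
      1 << cfg["LOG_BLOCK_OUT"])                                       # 16 x 16
print("input buffer:", (1 << cfg["LOG_INP_BUFF_SIZE"]) // 1024, "KB")  # 32 KB
```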

vta/python/vta/top/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,8 @@
 """TVM TOPI connector, eventually most of these should go to TVM repo"""

-from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
 from . import vta_conv2d
 from . import arm_conv2d
+
 from .bitpack import bitpack
+from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
+from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d

vta/python/vta/top/arm_conv2d.py

Lines changed: 82 additions & 0 deletions

@@ -5,6 +5,88 @@
 from topi.nn import conv2d, conv2d_alter_layout
 from topi import generic

+_WORKLOADS = [
+    # resnet 18
+    Workload('float32', 'float32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
+    Workload('int8', 'int32', 224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
+    Workload('int8', 'int32', 56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+    Workload('int8', 'int32', 56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+    Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+    Workload('int8', 'int32', 28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+    Workload('int8', 'int32', 28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+    Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+    Workload('int8', 'int32', 14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+    Workload('int8', 'int32', 14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+    Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+    Workload('int8', 'int32', 7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+
+    # mobilenet float32
+    Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
+    Workload('float32', 'float32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('float32', 'float32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
+
+    # mobilenet int8
+    Workload('float32', 'float32', 224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
+    Workload('int8', 'int32', 112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
+    Workload('int8', 'int32', 7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
+]
+
+_SCHEDULES = [
+    # float32 imagenet
+    SpatialPack(1, 8, 4, 1, 4, True),
+    SpatialPack(1, 8, 4, 1, 4, True),
+    SpatialPack(1, 7, 4, 2, 4, True),
+    SpatialPack(1, 4, 8, 4, 1, True),
+    SpatialPack(1, 4, 4, 1, 16, False),
+    SpatialPack(1, 4, 8, 4, 8, False),
+    SpatialPack(1, 7, 4, 3, 8, True),
+    SpatialPack(1, 2, 8, 1, 8, True),
+    SpatialPack(2, 1, 16, 1, 4, True),
+    SpatialPack(1, 7, 4, 1, 1, True),
+    Im2ColPack(7, 4, 1, 16, True),
+    Im2ColPack(7, 4, 1, 8, False),
+    Im2ColPack(7, 4, 1, 16, False),
+
+    # float32 mobilenet
+    SpatialPack(2, 2, 4, 28, 1, True),
+    SpatialPack(1, 4, 8, 14, 1, False),
+    SpatialPack(1, 2, 16, 8, 1, True),
+    SpatialPack(1, 4, 8, 8, 8, True),
+    SpatialPack(2, 2, 8, 1, 1, False),
+    SpatialPack(1, 4, 8, 4, 8, False),
+    SpatialPack(2, 2, 8, 1, 4, False),
+    SpatialPack(2, 2, 8, 1, 8, False),
+    Im2ColPack(7, 4, 1, 16, False),
+    Im2ColPack(7, 4, 1, 4, True),
+
+    # int8 mobilenet
+    SpatialPack(2, 2, 4, 28, 1, True),
+    SpatialPack(1, 4, 8, 14, 1, False),
+    SpatialPack(1, 2, 16, 8, 1, True),
+    SpatialPack(1, 4, 8, 8, 8, True),
+    SpatialPack(2, 2, 8, 1, 1, False),
+    SpatialPack(1, 4, 8, 4, 8, False),
+    SpatialPack(2, 2, 8, 1, 4, False),
+    SpatialPack(2, 2, 8, 1, 8, False),
+    Im2ColPack(7, 4, 1, 16, False),
+    Im2ColPack(7, 4, 1, 4, True),
+]
+
 @conv2d.register(["vtacpu", "vta"])
 def compute(*args, **kwargs):
     with tvm.target.arm_cpu("vtacpu"):
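The two lists are parallel: each `Workload` describes one layer shape and the scheduler picks the `SpatialPack`/`Im2ColPack` entry at the matching index. The field layout of `Workload` is not part of this hunk; assuming it follows the usual topi ARM conv2d definition of this era (the field names below are that assumption, not taken from the diff), an entry decodes like this:

```python
from collections import namedtuple

# Assumed field order for the ARM conv2d workload tuple.
Workload = namedtuple(
    "Workload",
    ["in_dtype", "out_dtype", "height", "width", "in_filter", "out_filter",
     "hkernel", "wkernel", "hpad", "wpad", "hstride", "wstride"])

wkl = Workload('int8', 'int32', 56, 56, 64, 128, 3, 3, 1, 1, 2, 2)
# One of the resnet-18 entries above: a 56x56 int8 layer taking
# 64 channels to 128 with a 3x3 kernel, padding 1 and stride 2.
print(wkl.out_filter, wkl.hstride)  # 128 2
```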

vta/python/vta/top/vta_conv2d.py

Lines changed: 28 additions & 18 deletions

@@ -11,6 +11,8 @@
 from nnvm.top import nn as _nn
 from ..environment import get_env
 from ..ptr_alias import reinterpret
+from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d
+

 Workload = namedtuple("Conv2DWorkload",
                       ['batch', 'height', 'width', 'in_filter', 'out_filter',

@@ -262,22 +264,26 @@ def compute_conv2d(attrs, inputs, out):

     assert dilation == (1, 1), "not support dilate now"
     if is_packed_layout(layout):
-        assert groups == 1
-        env = get_env()
-        assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
-        assert env.LOG_OUT_WIDTH == 3, "only support 8bit out for now"
-        inputs = list(inputs)
-        w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
-        assert inputs[1].dtype == "int8"
-
-        # Apply bit packing if necessary
-        if w_pack_factor != 1:
-            kshape = list(topi.util.get_const_tuple(inputs[1].shape))
-            kshape[-1] *= w_pack_factor
-            inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype)
-
-        return packed_conv2d(inputs[0], inputs[1],
-                             padding, strides, out_dtype=out_dtype)
+        if groups == 1:
+            assert groups == 1
+            env = get_env()
+            assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now"
+            assert env.LOG_OUT_WIDTH == 3, "only support 8bit out for now"
+            inputs = list(inputs)
+            w_pack_factor = 1 << (3 - env.LOG_WGT_WIDTH)
+            assert inputs[1].dtype == "int8"
+
+            # Apply bit packing if necessary
+            if w_pack_factor != 1:
+                kshape = list(topi.util.get_const_tuple(inputs[1].shape))
+                kshape[-1] *= w_pack_factor
+                inputs[1] = reinterpret(inputs[1], kshape, dtype=env.wgt_dtype)
+
+            return packed_conv2d(inputs[0], inputs[1],
+                                 padding, strides, out_dtype=out_dtype)
+        else:
+            return packed_group_conv2d(inputs[0], inputs[1],
+                                       padding, strides, groups, out_dtype=out_dtype)
     return _nn.compute_conv2d(attrs, inputs, out)


@@ -286,12 +292,16 @@ def schedule_conv2d(attrs, outs, target):
     """ 2D convolution schedule.
     """
     layout = attrs["layout"]
+    groups = attrs.get_int('groups')

     if is_packed_layout(layout):
         target = tvm.target.create(target)
         if target.device_name == "vta":
-            return schedule_packed_conv2d(outs)
-        if str(target).startswith("llvm"):
+            if groups == 1:
+                return schedule_packed_conv2d(outs)
+            else:
+                return schedule_packed_group_conv2d(outs)
+        elif str(target).startswith("llvm"):
             return tvm.create_schedule([x.op for x in outs])
         raise RuntimeError("not support target %s" % target)
     return _nn.schedule_conv2d(attrs, outs, target)
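With this change both compute and schedule dispatch on `groups`: grouped layers take the new `packed_group_conv2d`/`schedule_packed_group_conv2d` path, while `groups == 1` keeps the existing path with its optional weight bit packing. The pack factor itself is simple arithmetic on the config above; a standalone sketch of that relation (the helper name is hypothetical):

```python
# How many logical weights share one 8-bit storage word.
def weight_pack_factor(log_wgt_width):
    return 1 << (3 - log_wgt_width)

assert weight_pack_factor(3) == 1  # 8-bit weights: stored as-is, no repacking
assert weight_pack_factor(1) == 4  # 2-bit weights: four packed per int8
```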
