Skip to content

Commit 39a1ad7

Browse files
jwfromm authored and Xingyu Zhou committed
[AutoTVM] Add batch_matmul to tunable operations (apache#4242)
* Batch matmul tuning running but with errors. * Default x86 schedule as good as before. * Code Cleanup * Remove unused argument. * improved template documentation. * Silly lint fix * Removed leftover comment. * Moved cfg declaration to schedule for batch_matmul * Moved x86 dense cfg declaration to schedule. * lint fix * Removed duplicate cfg declaration in dense. * Reverted changes to dense.
1 parent 976d816 commit 39a1ad7

3 files changed

Lines changed: 67 additions & 18 deletions

File tree

python/tvm/autotvm/task/relay_integration.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
117117
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
118118
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
119119
tvm.relay.op.nn.dense: [topi.nn.dense],
120+
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
120121
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
121122
}
122123

python/tvm/autotvm/task/topi_integration.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(self, allow_duplicate=False):
8787
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
8888
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
8989
topi.nn.dense: "topi_nn_dense",
90+
topi.nn.batch_matmul: "topi_nn_batch_matmul",
9091
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
9192
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
9293
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
@@ -103,6 +104,7 @@ def __init__(self, allow_duplicate=False):
103104
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
104105
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
105106
topi.nn.dense: [topi.generic.schedule_dense],
107+
topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
106108
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
107109
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
108110
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
@@ -118,6 +120,7 @@ def __init__(self, allow_duplicate=False):
118120
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
119121
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
120122
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
123+
topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x),
121124
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
122125
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
123126
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
@@ -226,6 +229,15 @@ def _topi_nn_dense(*args, **kwargs):
226229
return s, [data, weight, bias, C]
227230
return s, [data, weight, C]
228231

232+
@register("topi_nn_batch_matmul")
def _topi_nn_batch_matmul(*args, **kwargs):
    """AutoTVM template wrapper for batch_matmul.

    Deserializes the task arguments into the two input tensors, builds
    the batch_matmul compute, and returns the generic schedule together
    with the argument buffers expected by the tuner.
    """
    # Template calls are positional-only by convention.
    assert not kwargs, "Do not support kwargs in template function call"
    A, B = deserialize_args(args)
    C = topi.nn.batch_matmul(A, B)
    sch = topi.generic.schedule_batch_matmul([C])
    return sch, [A, B, C]
240+
229241
@register("topi_nn_bitserial_conv2d_nhwc")
230242
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
231243
args = deserialize_args(args)

topi/python/topi/x86/batch_matmul.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,26 @@
1818
"""x86 batch_matmul operators"""
1919
from __future__ import absolute_import as _abs
2020
import tvm
21+
from tvm import autotvm
22+
from tvm.autotvm.task.space import SplitEntity
2123
from tvm.contrib import cblas
22-
from topi.nn import batch_matmul, batch_matmul_default
23-
from .. import generic
24+
from .. import generic, nn
2425
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
2526

26-
@batch_matmul.register(["cpu"])
27-
def batch_matmul_x86(x, y):
27+
28+
@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
29+
def _declaration_batch_matmul_nopack(cfg, x, y):
2830
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
2931
data in batch.
3032
3133
Parameters
3234
----------
35+
cfg : ConfigSpace
36+
Autotvm tuning space config file
3337
x : tvm.Tensor
3438
3-D with shape [batch, M, K]
35-
3639
y : tvm.Tensor
3740
3-D with shape [batch, N, K]
38-
3941
Returns
4042
-------
4143
output : tvm.Tensor
@@ -44,17 +46,37 @@ def batch_matmul_x86(x, y):
4446
target = tvm.target.current_target()
4547
if "cblas" in target.libs:
4648
return cblas.batch_matmul(x, y, False, True)
47-
return batch_matmul_default(x, y)
4849

49-
@generic.schedule_batch_matmul.register(["cpu"])
50-
def schedule_batch_matmul(outs):
50+
assert len(x.shape) == 3 and len(
51+
y.shape) == 3, "only support 3-dim batch_matmul"
52+
XB, M, XK = get_const_tuple(x.shape)
53+
YB, N, YK = get_const_tuple(y.shape)
54+
assert XB == YB, "batch dimension doesn't match"
55+
assert XK == YK, "shapes of x and y is inconsistant"
56+
B = XB
57+
K = XK
58+
if cfg.is_fallback:
59+
_default_batch_matmul_nopack_config(cfg, M, N, K)
60+
61+
k = tvm.reduce_axis((0, K), name='k')
62+
C = tvm.compute(
63+
(B, M, N),
64+
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
65+
tag='batch_matmul')
66+
return C
67+
68+
69+
@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
70+
def schedule_batch_matmul(cfg, outs):
5171
"""Schedule for batch_matmul
5272
5373
Parameters
5474
----------
55-
outs: Array of Tensor
56-
The computation graph description of batch_matmul
57-
in the format of an array of tensors.
75+
cfg : ConfigSpace
76+
AutoTVM tuning space config file.
77+
outs : Array of Tensor
78+
The computation graph description of batch_matmul
79+
in the format of an array of tensors.
5880
5981
Returns
6082
-------
@@ -71,16 +93,22 @@ def _callback(op):
7193
if "batch_matmul" in op.tag:
7294
C = op.output(0)
7395
A, B = s[C].op.input_tensors
74-
_, M, N = get_const_tuple(C.shape)
96+
_, M, K = get_const_tuple(A.shape)
97+
_, _, N = get_const_tuple(C.shape)
98+
99+
# create tuning space
100+
cfg.define_split("tile_y", M, num_outputs=2)
101+
cfg.define_split("tile_x", N, num_outputs=2)
102+
cfg.define_split("tile_k", K, num_outputs=2)
103+
75104
k, = s[C].op.reduce_axis
76-
ko, ki = s[C].split(k, 16)
105+
106+
ko, ki = cfg["tile_k"].apply(s, C, k)
77107
CC = s.rfactor(C, ki)
78108

79109
b, y, x = s[C].op.axis
80-
y_bn = get_max_power2_factor(M, 8)
81-
x_bn = get_max_power2_factor(N, 8)
82-
yo, yi = s[C].split(y, y_bn)
83-
xo, xi = s[C].split(x, x_bn)
110+
yo, yi = cfg["tile_y"].apply(s, C, y)
111+
xo, xi = cfg["tile_x"].apply(s, C, x)
84112
s[C].reorder(b, yo, xo, yi, xi)
85113
bxyo = s[C].fuse(b, yo, xo)
86114
s[C].parallel(bxyo)
@@ -94,3 +122,11 @@ def _callback(op):
94122

95123
traverse_inline(s, outs[0].op, _callback)
96124
return s
125+
126+
127+
def _default_batch_matmul_nopack_config(cfg, M, N, K):
    """Populate a fallback split configuration for batch_matmul.

    Used when AutoTVM has no tuned record: the spatial axes are tiled by
    the largest power-of-two factor capped at 8, and the reduction axis
    by a fixed inner factor of 16.
    """
    bn_y = get_max_power2_factor(M, 8)
    bn_x = get_max_power2_factor(N, 8)
    cfg["tile_y"] = SplitEntity([M // bn_y, bn_y])
    cfg["tile_x"] = SplitEntity([N // bn_x, bn_x])
    cfg["tile_k"] = SplitEntity([K // 16, 16])

0 commit comments

Comments
 (0)