Skip to content
1 change: 1 addition & 0 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}

Expand Down
12 changes: 12 additions & 0 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, allow_duplicate=False):
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
topi.nn.dense: "topi_nn_dense",
topi.nn.batch_matmul: "topi_nn_batch_matmul",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
Expand All @@ -103,6 +104,7 @@ def __init__(self, allow_duplicate=False):
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
Expand All @@ -118,6 +120,7 @@ def __init__(self, allow_duplicate=False):
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
Expand Down Expand Up @@ -226,6 +229,15 @@ def _topi_nn_dense(*args, **kwargs):
return s, [data, weight, bias, C]
return s, [data, weight, C]

@register("topi_nn_batch_matmul")
def _topi_nn_batch_matmul(*args, **kwargs):
    """AutoTVM task template for batch_matmul.

    Rebuilds the input placeholders from the serialized task arguments,
    emits the batch_matmul compute, and returns the generic schedule
    together with the argument tensors.
    """
    # Templates are invoked with positional, serialized args only.
    assert not kwargs, "Do not support kwargs in template function call"
    lhs, rhs = deserialize_args(args)
    out = topi.nn.batch_matmul(lhs, rhs)
    sched = topi.generic.schedule_batch_matmul([out])
    return sched, [lhs, rhs, out]

@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
Expand Down
66 changes: 49 additions & 17 deletions topi/python/topi/x86/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,26 @@
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from .. import generic, nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

@batch_matmul.register(["cpu"])
def batch_matmul_x86(x, y):

@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
def _declaration_batch_matmul_nopack(cfg, x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.

Parameters
----------
cfg : ConfigSpace
Autotvm tuning space config file
x : tvm.Tensor
3-D with shape [batch, M, K]

y : tvm.Tensor
3-D with shape [batch, N, K]

Returns
-------
output : tvm.Tensor
Expand All @@ -44,17 +46,41 @@ def batch_matmul_x86(x, y):
target = tvm.target.current_target()
if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)

@generic.schedule_batch_matmul.register(["cpu"])
def schedule_batch_matmul(outs):
assert len(x.shape) == 3 and len(
y.shape) == 3, "only support 3-dim batch_matmul"
XB, M, XK = get_const_tuple(x.shape)
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistant"
B = XB
K = XK
# create tuning space
cfg.define_split("tile_y", M, num_outputs=2)
cfg.define_split("tile_x", N, num_outputs=2)
cfg.define_split("tile_k", K, num_outputs=2)
if cfg.is_fallback:
_default_batch_matmul_nopack_config(cfg, M, N, K)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part should belong to the schedule instead of the declaration. I suggest moving them to the schedule function like other ops.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually extremely similar to the topi dense declaration in x86 as it's based directly on it. I would argue that the functional similarity between dense and batch_matmul encourages us to keep the syntax as close as possible to make transferring optimizations simple. If you feel strongly that configuration declarations should be in the schedule, I'd by happy to move both the batch_matmul and the dense declarations.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally don't think that's the best practice, because it would be tedious when we want to reuse the compute function on a different target (e.g., CUDA). It would also be confusing when someone tries to improve the schedule in the future, so I think it would be great to change both dense and batch_matmul. On the other hand, I would also be happy to know others' opinions.

cc @icemelon9

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to move both to be in the schedule.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for moving this into the schedule.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest push moves the configuration declaration in both batch_matmul and dense to the scheduling function. However, note that in the dense declaration, some splits are actually used in the computation (tile_k in dense_no_pack, for example) and so cannot be moved. This means that the declarations aren't all located in the same place. Do you guys prefer it this way, or should we leave dense alone and only change batch_matmul?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case I personally prefer to leave the dense there, and file a separate PR to refactor the dense compute.
cc @yzhliu @icemelon9 @Laurawly

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comaniac I agree thats the best way to proceed. I reverted the changes to dense in the latest commit.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comaniac In certain cases, the config space must be defined in the compute, as the compute needs to use it to define an intermediate compute stage.


k = tvm.reduce_axis((0, K), name='k')
C = tvm.compute(
(B, M, N),
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
tag='batch_matmul')
return C


@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
def schedule_batch_matmul(cfg, outs):
"""Schedule for batch_matmul

Parameters
Comment thread
jwfromm marked this conversation as resolved.
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
cfg : ConfigSpace
AutoTVM tuning space config file.
outs : Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.

Returns
-------
Expand All @@ -73,14 +99,12 @@ def _callback(op):
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
k, = s[C].op.reduce_axis
ko, ki = s[C].split(k, 16)
ko, ki = cfg["tile_k"].apply(s, C, k)
CC = s.rfactor(C, ki)

b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 8)
x_bn = get_max_power2_factor(N, 8)
yo, yi = s[C].split(y, y_bn)
xo, xi = s[C].split(x, x_bn)
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(b, yo, xo, yi, xi)
bxyo = s[C].fuse(b, yo, xo)
s[C].parallel(bxyo)
Expand All @@ -94,3 +118,11 @@ def _callback(op):

traverse_inline(s, outs[0].op, _callback)
return s


def _default_batch_matmul_nopack_config(cfg, M, N, K):
    """Fill *cfg* with a fallback tiling for the no-pack batch_matmul.

    Used when AutoTVM has no tuned record: M and N are tiled by their
    largest power-of-two factor up to 8, and the reduction axis K is
    split with an inner factor of 16.

    NOTE(review): the K split assumes K is a multiple of 16 (and >= 16)
    for the outer extent to be meaningful — confirm against callers.
    """
    y_factor = get_max_power2_factor(M, 8)
    x_factor = get_max_power2_factor(N, 8)
    cfg["tile_y"] = SplitEntity([M // y_factor, y_factor])
    cfg["tile_x"] = SplitEntity([N // x_factor, x_factor])
    cfg["tile_k"] = SplitEntity([K // 16, 16])