Skip to content

Commit 912993f

Browse files
authored
[ARM] Fix int8 NCHWc compute and alter layout (#10839)
This PR fixes a bug in the TE ARM int8 compute for NCHWc conv2d, introduced in #10310. The compute itself, not the schedule, is broken for the following reasons: * We are using `n_elems = 8` in https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L350. Thus, the innermost axis of the transformed kernel has extent 8: https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_alter_op.py#L375 * In the TE compute, we iterate over the innermost axis `ic_s_inner` of the kernel at https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L577. `ic_s_inner` has extent `n_elems` according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L566. `n_elems` is 4 by default according to https://github.com/apache/tvm/blob/f6f252f0abc8f621a96506739f9534083d1fe213/python/tvm/topi/nn/conv2d.py#L478 * The ARM code that calls this compute does not explicitly pass `n_elems`, according to https://github.com/apache/tvm/blob/e9091d6c68d5d70c28881e5c75bfe72e385c1f4d/python/tvm/topi/arm_cpu/conv2d_int8.py#L106-L108 * Thus, even though the innermost axis of the kernel has extent 8, the TE compute only loops over `n_elems = 4` of the input channel dimension. Initially, I tried to keep `n_elems = 8` in alter layout and fix the intrinsic definition. But `n_elems = 8` breaks tensorization pattern matching, since now the compute is doing a 4x8 innermost loop while this intrinsic is supposed to do a 4x4 dot product, see https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L467-L479. Setting `num_int8_elements = 8` there did fix the tensorize pattern matching, but the result was still incorrect. 
Rather than fixing the intrin implementation in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/tensor_intrin.py#L492 to adapt for a 4x8 dot product, I settled on setting `n_elems = 4` in alter layout. It turned out this change was enough to get the correct output. Moreover, `n_elems = 8` is simply wrong for the dot product path in https://github.com/apache/tvm/blob/7896108fc41663a1fecbb52345194a93278e9e28/python/tvm/topi/arm_cpu/conv2d_int8.py#L154-L155, which computes a 4x4 dot product in one instruction. @tkonolige I suggest doing a perf benchmark again, since the numbers in #10310 are invalid. cc @mbrookhart @Mousius @junrushao1994 @vinx13
1 parent 63bb3b9 commit 912993f

6 files changed

Lines changed: 26 additions & 23 deletions

File tree

python/tvm/topi/arm_cpu/conv2d_alter_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
347347
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
348348
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
349349

350-
n_elems = 8
350+
n_elems = 4
351351

352352
if cfg.is_fallback:
353353
_get_default_config_int8(

python/tvm/topi/arm_cpu/conv2d_int8.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
5757
n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
5858
in_channel = ic_chunk * ic_bn
5959

60-
oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, n_elems = get_const_tuple(kernel.shape)
60+
oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn, _ = get_const_tuple(kernel.shape)
6161
num_filter = oc_chunk * oc_bn
6262
else:
6363
# data is nchw, implicitly treat it as nchw1c
@@ -103,8 +103,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
103103
if len(data.shape) == 4:
104104
data, kernel = _pack_data(cfg, data, kernel)
105105

106+
n_elems = int(kernel.shape[-1])
107+
106108
return nn.conv2d_NCHWc_int8(
107-
data, kernel, strides, padding, dilation, layout, out_layout, out_dtype
109+
data, kernel, strides, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
108110
)
109111

110112

@@ -149,7 +151,8 @@ def _callback(op):
149151

150152
args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
151153
# int8 conv kernel is 7-dim
152-
_, _, kh, kw, _, _, _ = get_const_tuple(kernel_vec.shape)
154+
_, _, kh, kw, _, _, n_elems = get_const_tuple(kernel_vec.shape)
155+
assert n_elems == 4
153156
dtype = "uint" if data.dtype == "uint8" else "int"
154157
if is_dotprod_available():
155158
intrin = dot_int8_int8_int32_neon_82(int32_lanes=4, dtype=dtype)

python/tvm/topi/arm_cpu/tensor_intrin.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -614,21 +614,22 @@ def _instr(index):
614614
ib.emit(outs[0].vstore(0, tvm.tir.const(0, int_32xl)))
615615
return ib.get()
616616

617-
def pairwise_add_mul(idx):
618-
# this broadcasts data to the vector size
619-
a_int8 = ins[0].vload([0], "int8x4")
620-
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
621-
vec_ai32 = re_int32.astype("int32x2")
622-
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
617+
# this broadcasts data to the vector size
618+
a_int8 = ins[0].vload([0], "int8x4")
619+
re_int32 = tvm.tir.call_intrin("int32", "tir.reinterpret", a_int8)
620+
vec_ai32 = re_int32.astype("int32x2")
621+
vec_a = tvm.tir.call_intrin(int_8xl, "tir.reinterpret", vec_ai32)
623622

624-
vec_b = ins[1].vload([idx * 2, 0], int_8xl) # we take two inputs at a time
623+
vec_b = ins[1].vload([0, 0], "int8x16")
625624

625+
def pairwise_add_mul(extract_half):
626+
vec_b_half = tvm.tir.call_intrin("int8x8", extract_half, vec_b)
626627
multiply = tvm.tir.call_llvm_pure_intrin(
627628
"int16x8",
628629
"llvm.aarch64.neon.smull.v8i16", # saturating pairwise multiplication
629630
tvm.tir.const(2, "uint32"),
630631
vec_a,
631-
vec_b,
632+
vec_b_half,
632633
)
633634
pairwise_reduction = tvm.tir.call_llvm_pure_intrin(
634635
"int32x4",
@@ -638,8 +639,8 @@ def pairwise_add_mul(idx):
638639
)
639640
return pairwise_reduction
640641

641-
pair_1 = pairwise_add_mul(0)
642-
pair_2 = pairwise_add_mul(1)
642+
pair_1 = pairwise_add_mul("tir.vectorlow")
643+
pair_2 = pairwise_add_mul("tir.vectorhigh")
643644
quad_reduction = tvm.tir.call_llvm_pure_intrin(
644645
"int32x4",
645646
"llvm.aarch64.neon.addp.v4i32",

python/tvm/topi/nn/conv2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,6 @@ def conv2d_NCHWc_int8(
486486
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(
487487
kernel.shape
488488
)
489-
num_filter = oc_chunk * oc_bn
490489
groups = ic_chunk // ic_chunk_group
491490

492491
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1

python/tvm/topi/x86/conv2d_int8.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def _pack_data(cfg, data, kernel):
120120
kernel = te.compute(
121121
(oc_chunk, ic_chunk, kh, kw, ic_bn // n_elems, oc_bn, n_elems),
122122
lambda occ, icc, k_h, k_w, icbc, ocb, icbb: kernel[
123-
occ * oc_bn + ocb, icc * ic_bn + icbc * ic_bn // n_elems + icbb, k_h, k_w
123+
occ * oc_bn + ocb, icc * ic_bn + icbc * n_elems + icbb, k_h, k_w
124124
],
125125
name="kernel_vec",
126126
)

tests/python/topi/python/test_topi_conv2d_int8.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import tvm
2222
from tvm import te
2323
from tvm import autotvm
24-
from tvm.autotvm.task.space import FallbackConfigEntity
2524
from tvm import topi
2625
import tvm.topi.testing
2726
from tvm.contrib.pickle_memoize import memoize
@@ -34,6 +33,7 @@
3433
from common import Int8Fallback
3534
import tvm.testing
3635
import pytest
36+
import platform
3737

3838

3939
def compile_conv2d_NHWC_gemm_int8_arm(
@@ -299,7 +299,6 @@ def get_ref_data():
299299

300300
a_np, w_np, b_np, c_np = get_ref_data()
301301

302-
print("Running on target: %s" % target)
303302
with tvm.target.Target(target):
304303
C = compute(
305304
A,
@@ -311,8 +310,6 @@ def get_ref_data():
311310
"NCHW",
312311
out_dtype,
313312
)
314-
print(C.shape)
315-
print(bias.shape)
316313
if add_bias:
317314
C = topi.add(C, bias)
318315
if add_relu:
@@ -342,6 +339,8 @@ def get_ref_data():
342339
if build_only:
343340
return
344341

342+
print("Running on target: %s" % target)
343+
345344
func(*run_args)
346345

347346
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
@@ -364,14 +363,15 @@ def get_ref_data():
364363
# ),
365364
]
366365

367-
# TODO(tvm-team): Properly run ARM code on CI aarch64 environment
366+
build_only_aarch64 = platform.machine() != "aarch64"
367+
368368
targets.append(
369369
(
370370
"llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod",
371371
topi.arm_cpu.conv2d_NCHWc_int8,
372372
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
373373
8,
374-
True,
374+
build_only_aarch64,
375375
)
376376
)
377377

@@ -382,7 +382,7 @@ def get_ref_data():
382382
topi.arm_cpu.conv2d_NCHWc_int8,
383383
topi.arm_cpu.schedule_conv2d_NCHWc_int8,
384384
8,
385-
True,
385+
build_only_aarch64,
386386
)
387387
)
388388

0 commit comments

Comments
 (0)