Skip to content

Commit 429d426

Browse files
masahi authored and mehrdadh committed
[TIR] VNNI and ARM dot product intrinsic for tensorization (apache#10925)
1 parent dda8d2c commit 429d426

7 files changed

Lines changed: 311 additions & 62 deletions

File tree

python/tvm/script/tir/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr:
124124
def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
125125
def evaluate(value: PrimExpr) -> None: ...
126126
def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
127+
def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
128+
def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
127129
def store(
128130
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
129131
) -> None: ...
@@ -143,7 +145,7 @@ def preflattened_buffer(
143145
) -> Buffer: ...
144146

145147
"""
146-
Intrinsics - tvm builtin
148+
Intrinsics - tvm builtin
147149
"""
148150

149151
def tvm_thread_allreduce(

python/tvm/script/tir/special_stmt.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@
2525

2626
import tvm.tir
2727
from tvm.runtime import Object, String
28-
from tvm import te
2928
from tvm.target import Target
3029
from tvm.ir import Span
31-
from tvm.tir import IntImm, IterVar
30+
from tvm.tir import IntImm, IterVar, Var
3231

3332
from .node import BufferSlice
3433
from .utils import buffer_slice_to_region
@@ -800,7 +799,7 @@ def var(dtype, span):
800799
self.context.report_error(
801800
f"VarDef expected assign to only one var, but got {names}", span
802801
)
803-
v = te.var(names[0], dtype, span=span)
802+
v = Var(names[0], dtype, span=span)
804803
self.context.update_symbol(v.name, v, self.node)
805804

806805
super().__init__(var, def_symbol=True)
@@ -821,7 +820,7 @@ def buffer_var(dtype, storage_scope, span):
821820
f"VarDef expected assign to only one var, but got {names}", span
822821
)
823822
ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
824-
v = te.var(names[0], ptr_type, span=span)
823+
v = Var(names[0], ptr_type, span=span)
825824
self.context.update_symbol(v.name, v, self.node)
826825

827826
super().__init__(buffer_var, def_symbol=True)
@@ -841,7 +840,7 @@ def env_thread(env_name, span):
841840
self.context.report_error(
842841
f"VarDef expected assign to only one var, but got {names}", span
843842
)
844-
v = te.var(names[0], span=span)
843+
v = Var(names[0], dtype="int32", span=span)
845844
self.context.func_var_env_dict[v] = env_name
846845
self.context.update_symbol(v.name, v, self.node)
847846

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-import
18+
"""Intrinsics for tensorization."""
19+
from .x86 import *
20+
from .arm_cpu import *
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,missing-function-docstring
18+
"""Intrinsics for ARM tensorization."""
19+
from tvm.script import tir as T
20+
from .. import TensorIntrin
21+
22+
23+
# TODO(masahi): Parametrize the TVMScript description of dot product by
24+
# shape and dtype, and share the common description with x86.
25+
26+
27+
@T.prim_func
28+
def dot_product_4x4_i8i8i32_desc(
29+
A: T.Buffer((4,), "int8", offset_factor=1),
30+
B: T.Buffer((4, 4), "int8", offset_factor=1),
31+
C: T.Buffer((4,), "int32", offset_factor=1),
32+
) -> None:
33+
with T.block("root"):
34+
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
35+
T.writes(C[0:4])
36+
for i in T.serial(0, 4):
37+
with T.init():
38+
C[i] = T.int32(0)
39+
for k in T.serial(0, 4):
40+
with T.block("update"):
41+
vi, vk = T.axis.remap("SR", [i, k])
42+
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
43+
44+
45+
@T.prim_func
46+
def dot_product_4x4_i8i8i32_neon(
47+
A: T.Buffer((4,), "int8", offset_factor=1),
48+
B: T.Buffer((4, 4), "int8", offset_factor=1),
49+
C: T.Buffer((4,), "int32", offset_factor=1),
50+
) -> None:
51+
with T.block("root"):
52+
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
53+
T.writes(C[0:4])
54+
55+
A_int8 = A.vload([0], "int8x4")
56+
re_int32 = T.reinterpret(A_int8, dtype="int32")
57+
vec_ai32 = T.broadcast(re_int32, 2)
58+
vec_a = T.reinterpret(vec_ai32, dtype="int8x8")
59+
60+
vec_b = B.vload([0, 0], dtype="int8x16")
61+
62+
# TODO(masahi): Remove duplication when inlined function call is supported
63+
vec_b_low = T.vectorlow(vec_b, dtype="int8x8")
64+
65+
multiply_low = T.call_llvm_pure_intrin(
66+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
67+
T.uint32(2),
68+
vec_a,
69+
vec_b_low,
70+
dtype="int16x8",
71+
)
72+
73+
pairwise_reduction_low = T.call_llvm_pure_intrin(
74+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
75+
T.uint32(1),
76+
multiply_low,
77+
dtype="int32x4",
78+
)
79+
80+
vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")
81+
82+
multiply_high = T.call_llvm_pure_intrin(
83+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
84+
T.uint32(2),
85+
vec_a,
86+
vec_b_high,
87+
dtype="int16x8",
88+
)
89+
90+
pairwise_reduction_high = T.call_llvm_pure_intrin(
91+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
92+
T.uint32(1),
93+
multiply_high,
94+
dtype="int32x4",
95+
)
96+
97+
C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
98+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
99+
T.uint32(2),
100+
pairwise_reduction_low,
101+
pairwise_reduction_high,
102+
dtype="int32x4",
103+
)
104+
105+
106+
@T.prim_func
107+
def dot_product_4x4_i8i8i32_sdot(
108+
A: T.Buffer((4,), "int8", offset_factor=1),
109+
B: T.Buffer((4, 4), "int8", offset_factor=1),
110+
C: T.Buffer((4,), "int32", offset_factor=1),
111+
) -> None:
112+
with T.block("root"):
113+
T.reads(C[0:4], A[0:4], B[0:4, 0:4])
114+
T.writes(C[0:4])
115+
116+
A_i8x4 = A.vload([0], "int8x4")
117+
A_i32 = T.reinterpret(A_i8x4, dtype="int32")
118+
vec_ai32 = T.broadcast(A_i32, 4)
119+
vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
120+
121+
vec_b = B.vload([0, 0], dtype="int8x16")
122+
123+
C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
124+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
125+
T.uint32(3),
126+
T.int32x4(0),
127+
vec_a,
128+
vec_b,
129+
dtype="int32x4",
130+
)
131+
132+
133+
ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
134+
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
135+
136+
TensorIntrin.register(
137+
ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
138+
)
139+
140+
TensorIntrin.register(
141+
ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
142+
)
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name,missing-function-docstring
18+
"""Intrinsics for x86 tensorization."""
19+
from tvm.script import tir as T
20+
from .. import TensorIntrin
21+
22+
23+
# Tensorized intrinsic description and VNNI-specific implementation.
24+
# Equivalent to the ones in topi/x86/tensor_intrin.py
25+
26+
27+
@T.prim_func
28+
def dot_product_16x4_u8i8i32_desc(
29+
A: T.Buffer((4,), "uint8", offset_factor=1),
30+
B: T.Buffer((16, 4), "int8", offset_factor=1),
31+
C: T.Buffer((16,), "int32", offset_factor=1),
32+
) -> None:
33+
with T.block("root"):
34+
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
35+
T.writes(C[0:16])
36+
for i in T.serial(0, 16):
37+
with T.init():
38+
C[i] = T.int32(0)
39+
for k in T.serial(0, 4):
40+
with T.block("update"):
41+
vi, vk = T.axis.remap("SR", [i, k])
42+
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
43+
44+
45+
@T.prim_func
46+
def dot_product_16x4_u8i8i32_vnni(
47+
A: T.Buffer((4,), "uint8", offset_factor=1),
48+
B: T.Buffer((16, 4), "int8", offset_factor=1),
49+
C: T.Buffer((16,), "int32", offset_factor=1),
50+
) -> None:
51+
with T.block("root"):
52+
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
53+
T.writes(C[0:16])
54+
55+
A_u8x4 = A.vload([0], "uint8x4")
56+
A_i32 = T.reinterpret(A_u8x4, dtype="int32")
57+
58+
B_i8x64 = B.vload([0, 0], dtype="int8x64")
59+
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
60+
61+
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
62+
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
63+
T.uint32(0),
64+
T.int32x16(0),
65+
T.broadcast(A_i32, 16),
66+
B_i32x16,
67+
dtype="int32x16",
68+
)
69+
70+
71+
VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
72+
73+
TensorIntrin.register(
74+
VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
75+
)

tests/python/unittest/test_meta_schedule_tune_relay.py

Lines changed: 2 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from tvm.target.target import Target
3939
from tvm.tir.schedule import BlockRV, Schedule
4040
from tvm.tir.schedule.trace import Trace
41+
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
42+
4143

4244
logging.basicConfig()
4345
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
@@ -328,57 +330,6 @@ def get_output(data, lib):
328330
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
329331

330332

331-
# Tensorized intrinsic description and VNNI-specific implementation.
332-
# Equivalent to the ones in topi/x86/tensor_intrin.py
333-
334-
335-
@T.prim_func
336-
def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
337-
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
338-
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
339-
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
340-
341-
with T.block("root"):
342-
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
343-
T.writes(C[0:16])
344-
for i in T.serial(0, 16):
345-
with T.init():
346-
C[i] = T.int32(0)
347-
for k in T.serial(0, 4):
348-
with T.block("update"):
349-
vi, vk = T.axis.remap("SR", [i, k])
350-
C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
351-
352-
353-
@T.prim_func
354-
def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
355-
A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
356-
B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
357-
C = T.match_buffer(c, (16,), "int32", offset_factor=1)
358-
359-
with T.block("root"):
360-
T.reads(C[0:16], A[0:4], B[0:16, 0:4])
361-
T.writes(C[0:16])
362-
363-
A_u8x4 = A.vload([0], "uint8x4")
364-
A_i32 = T.reinterpret(A_u8x4, dtype="int32")
365-
366-
B_i8x64 = B.vload([0, 0], dtype="int8x64")
367-
B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
368-
369-
C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin( # Note: this is an update +=
370-
T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
371-
T.uint32(0),
372-
T.int32x16(0),
373-
T.broadcast(A_i32, 16),
374-
B_i32x16,
375-
dtype="int32x16",
376-
)
377-
378-
379-
VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
380-
381-
382333
def schedule_dense(dense_block, M, do_tune, sch):
383334
"""
384335
Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
@@ -546,10 +497,6 @@ def schedule_fn(task, sch):
546497

547498
@pytest.mark.skip("Requires cascadelake")
548499
def test_tune_relay_manual_tir_vnni():
549-
# Register a pair of an intrinsic description for 16x4 dot product, and its
550-
# VNNI-specific implementation.
551-
tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)
552-
553500
manual_tir_common(do_tune=False)
554501

555502
"""

0 commit comments

Comments (0)