Skip to content

Commit 56f2e9a

Browse files
committed
Use vectorlow/vectorhigh in the ARM CPU tensor intrinsic
1 parent 995cc8d commit 56f2e9a

3 files changed

Lines changed: 19 additions & 15 deletions

File tree

python/tvm/script/tir/__init__.pyi

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr:
124124
def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
125125
def evaluate(value: PrimExpr) -> None: ...
126126
def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
127+
def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
128+
def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
127129
def store(
128130
var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
129131
) -> None: ...
@@ -143,7 +145,7 @@ def preflattened_buffer(
143145
) -> Buffer: ...
144146

145147
"""
146-
Intrinsics - tvm builtin
148+
Intrinsics - tvm builtin
147149
"""
148150

149151
def tvm_thread_allreduce(

python/tvm/tir/tensor_intrin/arm_cpu.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -51,45 +51,48 @@ def dot_product_4x4_i8i8i32_neon(
5151
vec_ai32 = T.broadcast(re_int32, 2)
5252
vec_a = T.reinterpret(vec_ai32, dtype="int8x8")
5353

54-
vec_b = B.vload([0, 0], dtype="int8x8")
54+
vec_b = B.vload([0, 0], dtype="int8x16")
55+
56+
# TODO(masahi): Remove duplication when inlined function call is supported
57+
vec_b_low = T.vectorlow(vec_b, dtype="int8x8")
5558

56-
multiply = T.call_llvm_pure_intrin(
59+
multiply_low = T.call_llvm_pure_intrin(
5760
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
5861
T.uint32(2),
5962
vec_a,
60-
vec_b,
63+
vec_b_low,
6164
dtype="int16x8",
6265
)
6366

64-
pair1 = T.call_llvm_pure_intrin(
67+
pairwise_reduction_low = T.call_llvm_pure_intrin(
6568
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
6669
T.uint32(1),
67-
multiply,
70+
multiply_low,
6871
dtype="int32x4",
6972
)
7073

71-
vec_b_2 = B.vload([2, 0], dtype="int8x8")
74+
vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")
7275

73-
multiply_2 = T.call_llvm_pure_intrin(
76+
multiply_high = T.call_llvm_pure_intrin(
7477
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
7578
T.uint32(2),
7679
vec_a,
77-
vec_b_2,
80+
vec_b_high,
7881
dtype="int16x8",
7982
)
8083

81-
pair2 = T.call_llvm_pure_intrin(
84+
pairwise_reduction_high = T.call_llvm_pure_intrin(
8285
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
8386
T.uint32(1),
84-
multiply_2,
87+
multiply_high,
8588
dtype="int32x4",
8689
)
8790

8891
C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
8992
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
9093
T.uint32(2),
91-
pair1,
92-
pair2,
94+
pairwise_reduction_low,
95+
pairwise_reduction_high,
9396
dtype="int32x4",
9497
)
9598

tests/python/unittest/test_tir_schedule_tensorize.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -593,5 +593,4 @@ def test_tensorize_arm_dot():
593593

594594

595595
if __name__ == "__main__":
596-
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
597-
test_tensorize_arm_dot()
596+
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)