@@ -51,45 +51,48 @@ def dot_product_4x4_i8i8i32_neon(
5151 vec_ai32 = T .broadcast (re_int32 , 2 )
5252 vec_a = T .reinterpret (vec_ai32 , dtype = "int8x8" )
5353
54- vec_b = B .vload ([0 , 0 ], dtype = "int8x8" )
54+ vec_b = B .vload ([0 , 0 ], dtype = "int8x16" )
55+
56+ # TODO(masahi): Remove duplication when inlined function call is supported
57+ vec_b_low = T .vectorlow (vec_b , dtype = "int8x8" )
5558
56- multiply = T .call_llvm_pure_intrin (
59+ multiply_low = T .call_llvm_pure_intrin (
5760 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.smull.v8i16" ),
5861 T .uint32 (2 ),
5962 vec_a ,
60- vec_b ,
63+ vec_b_low ,
6164 dtype = "int16x8" ,
6265 )
6366
64- pair1 = T .call_llvm_pure_intrin (
67+ pairwise_reduction_low = T .call_llvm_pure_intrin (
6568 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.saddlp.v4i32.v8i16" ),
6669 T .uint32 (1 ),
67- multiply ,
70+ multiply_low ,
6871 dtype = "int32x4" ,
6972 )
7073
71- vec_b_2 = B . vload ([ 2 , 0 ] , dtype = "int8x8" )
74+ vec_b_high = T . vectorhigh ( vec_b , dtype = "int8x8" )
7275
73- multiply_2 = T .call_llvm_pure_intrin (
76+ multiply_high = T .call_llvm_pure_intrin (
7477 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.smull.v8i16" ),
7578 T .uint32 (2 ),
7679 vec_a ,
77- vec_b_2 ,
80+ vec_b_high ,
7881 dtype = "int16x8" ,
7982 )
8083
81- pair2 = T .call_llvm_pure_intrin (
84+ pairwise_reduction_high = T .call_llvm_pure_intrin (
8285 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.saddlp.v4i32.v8i16" ),
8386 T .uint32 (1 ),
84- multiply_2 ,
87+ multiply_high ,
8588 dtype = "int32x4" ,
8689 )
8790
8891 C [T .ramp (T .int32 (0 ), 1 , 4 )] += T .call_llvm_pure_intrin (
8992 T .llvm_lookup_intrinsic_id ("llvm.aarch64.neon.addp.v4i32" ),
9093 T .uint32 (2 ),
91- pair1 ,
92- pair2 ,
94+ pairwise_reduction_low ,
95+ pairwise_reduction_high ,
9396 dtype = "int32x4" ,
9497 )
9598
0 commit comments