2323
2424
2525@T .prim_func
26- def dot_product_desc (a : T .handle , b : T .handle , c : T .handle ) -> None :
26+ def dot_product_16x4_desc (a : T .handle , b : T .handle , c : T .handle ) -> None :
2727 A = T .match_buffer (a , (4 ,), "uint8" , offset_factor = 1 )
2828 B = T .match_buffer (b , (16 , 4 ), "int8" , offset_factor = 1 )
2929 C = T .match_buffer (c , (16 ,), "int32" , offset_factor = 1 )
@@ -41,7 +41,7 @@ def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
4141
4242
4343@T .prim_func
44- def dot_product_intrin (a : T .handle , b : T .handle , c : T .handle ) -> None :
44+ def dot_product_16x4_vnni_impl (a : T .handle , b : T .handle , c : T .handle ) -> None :
4545 A = T .match_buffer (a , (4 ,), "uint8" , offset_factor = 1 )
4646 B = T .match_buffer (b , (16 , 4 ), "int8" , offset_factor = 1 )
4747 C = T .match_buffer (c , (16 ,), "int32" , offset_factor = 1 )
@@ -66,6 +66,6 @@ def dot_product_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
6666 )
6767
6868
69- INTRIN_NAME = "dot_16x1x16_uint8_int8_int32_cascadelake "
69+ VNNI_INTRIN = "dot_16x4_vnni "
7070
71- TensorIntrin .register (INTRIN_NAME , dot_product_desc , dot_product_intrin )
71+ TensorIntrin .register (VNNI_INTRIN , dot_product_16x4_desc , dot_product_16x4_vnni_impl )
0 commit comments