@@ -576,66 +576,84 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
576576 assembly = lib .get_source ("asm" )
577577 return assembly
578578
579- # compile conv2d for x86 (skylake) and test assembly contains *pmadd* instructions
580- target = "llvm -mcpu=skylake-avx512"
581- name = "llvm.x86.avx512.pmaddubs.w.512"
582- llvm_id = tvm .codegen .llvm_lookup_intrinsic_id (name )
583- if llvm_id != 0 :
584- fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
585- # Sweep the input channels to check int8 robustness
586- for ic in range (1 , 24 ):
587- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
588- dtypes = fast_int8_dtypes )
589- assert "pmaddubs" in asm
590-
591- for ic in range (1 , 24 ):
592- asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
593- dtypes = fast_int8_dtypes )
594- assert "pmaddubs" in asm
595-
596-
597- # Sweep the output channels to check int8 robustness
598- for oc in range (2 , 24 ):
599- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
579+ def has_fast_int8_instruction (asm , target ):
580+ intel_device_type = None
581+ if 'skylake-avx512' in target :
582+ return "pmaddubs" in asm
583+ elif 'cascadelake' in target :
584+ return "vpdpbusd" in asm
585+ else :
586+ assert False , "Target should be Skylake or Cascadelake"
587+
588+ # compile conv2d for x86 (skylake, cascadelake) and test that the assembly contains the target's fast int8 instruction (pmaddubs on skylake, vpdpbusd on cascadelake)
589+ targets = ["llvm -mcpu=skylake-avx512" , "llvm -mcpu=cascadelake" ]
590+ name_skylake = "llvm.x86.avx512.pmaddubs.w.512"
591+ name_cascadelake = 'llvm.x86.avx512.vpdpbusd.512'
592+ llvm_id_skylake = tvm .codegen .llvm_lookup_intrinsic_id (name_skylake )
593+ llvm_id_cascadelake = tvm .codegen .llvm_lookup_intrinsic_id (name_cascadelake )
594+ for target in targets :
595+ if llvm_id_skylake != 0 and llvm_id_cascadelake :
596+ fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
597+ # Sweep the input channels to check int8 robustness
598+ # Input channels should be a multiple of 4 internally.
599+ for ic in [1 , 4 , 6 ]:
600+ asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NCHW" ,
601+ kernel_layout = 'OIHW' ,
602+ dtypes = fast_int8_dtypes )
603+ assert has_fast_int8_instruction (asm , target )
604+
605+ for ic in [1 , 4 , 6 ]:
606+ asm = _compile (ic = ic , oc = 32 , target = target , data_layout = "NHWC" ,
607+ kernel_layout = 'HWIO' ,
608+ dtypes = fast_int8_dtypes )
609+ assert has_fast_int8_instruction (asm , target )
610+
611+
612+ # Sweep the output channels to check int8 robustness
613+ # Output channels should be a multiple of 16 internally.
614+ for oc in [4 , 16 , 20 ]:
615+ asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NCHW" ,
616+ kernel_layout = 'OIHW' ,
617+ dtypes = fast_int8_dtypes )
618+ assert has_fast_int8_instruction (asm , target )
619+
620+ for oc in [4 , 16 , 20 ]:
621+ asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" ,
622+ kernel_layout = 'HWIO' ,
623+ dtypes = fast_int8_dtypes )
624+ assert has_fast_int8_instruction (asm , target )
625+
626+ # Check that both non-divisible oc and ic work
627+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
600628 dtypes = fast_int8_dtypes )
601- assert "pmaddubs" in asm
629+ assert has_fast_int8_instruction ( asm , target )
602630
603- for oc in range (2 , 24 ):
604- asm = _compile (ic = 16 , oc = oc , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
631+ asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
605632 dtypes = fast_int8_dtypes )
606- assert "pmaddubs" in asm
607-
608- # Check that both non-divisible oc and ic work
609- asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
610- dtypes = fast_int8_dtypes )
611- assert "pmaddubs" in asm
612-
613- asm = _compile (ic = 17 , oc = 29 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
633+ assert has_fast_int8_instruction (asm , target )
634+
635+ # Ensure that code is generated when datatypes are not HW supported.
636+ dtypes = ('int8' , 'int8' , 'int32' )
637+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
638+ dtypes = dtypes )
639+ # Check that the intrinsic is not present in the assembly.
640+ assert not has_fast_int8_instruction (asm , target )
641+
642+ # Ensure that code is generated when datatypes are not HW supported.
643+ dtypes = ('uint8' , 'uint8' , 'int32' )
644+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
645+ dtypes = dtypes )
646+ # Check that the intrinsic is not present in the assembly.
647+ assert not has_fast_int8_instruction (asm , target )
648+
649+ # Check that a vectorized instruction is generated for older Intel
650+ # generations, because we default to NCHWc layout.
651+ target = "llvm -mcpu=core-avx2"
652+ fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
653+ asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
614654 dtypes = fast_int8_dtypes )
615- assert "pmaddubs" in asm
616-
617- # Ensure that code is generated when datatypes are not HW supported.
618- dtypes = ('int8' , 'int8' , 'int32' )
619- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
620- dtypes = dtypes )
621- # Check that intrinisic is not present in the assembly.
622- assert "pmaddubs" not in asm
623-
624- # Ensure that code is generated when datatypes are not HW supported.
625- dtypes = ('uint8' , 'uint8' , 'int32' )
626- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NHWC" , kernel_layout = 'HWIO' ,
627- dtypes = dtypes )
628- # Check that intrinisic is not present in the assembly.
629- assert "pmaddubs" not in asm
630-
631- # Check that a vectorized instruction is generated for older Intel
632- # generations, because we default to NCHWc layout.
633- target = "llvm -mcpu=core-avx2"
634- fast_int8_dtypes = ('uint8' , 'int8' , 'int32' )
635- asm = _compile (ic = 16 , oc = 32 , target = target , data_layout = "NCHW" , kernel_layout = 'OIHW' ,
636- dtypes = fast_int8_dtypes )
637- # Check that vector int mult and add instructions are generated.
638- assert "vpmulld" in asm and "vpadd" in asm
655+ # Check that vector int mult and add instructions are generated.
656+ assert "vpmulld" in asm and "vpadd" in asm
639657
640658
641659def test_bitserial_conv2d_infer_type ():
0 commit comments