import tvm
from tvm import relay

from .. import op as reg
2424
#################################################
# Register the functions for different operators.
#################################################

# Registering QNN Conv2D legalization function.
@reg.register_qnn_legalize("qnn.conv2d")
def legalize_qnn_conv2d(attrs, inputs, types):
    """Legalizes QNN conv2d op by dispatching to the target-specific implementation.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return qnn_conv2d_legalize(attrs, inputs, types)
33+
# Registering QNN dense legalization function.
@reg.register_qnn_legalize("qnn.dense")
def legalize_qnn_dense(attrs, inputs, types):
    """Legalizes QNN dense op by dispatching to the target-specific implementation.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current dense
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return qnn_dense_legalize(attrs, inputs, types)
38+
# Default to None. If overridden by target, this will not be run.
# Generic QNN Conv2D legalization function.
@tvm.target.generic_func
def qnn_conv2d_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
45+
# Generic QNN dense legalization function.
@tvm.target.generic_func
def qnn_dense_legalize(attrs, inputs, types):
    """Default legalization is None."""
    return None
51+
###################
# Helper functions.
###################
55+
# Helper function for lowering in the absence of fast Int8 arithmetic units.
def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that do
    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
    much more efficiently if the convolution or dense operator input datatypes are int16 instead
    of int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    # Collect the input exprs.
    data, kernel = inputs

    input_zp = attrs['input_zero_point']
    kernel_zp = attrs['kernel_zero_point']

    # Upcast to int16 and fold the zero points into the tensor values, so the
    # QNN op can be replaced by the plain (non-quantized) Relay op below.
    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
                                relay.const(input_zp, 'int16'))
    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
                                  relay.const(kernel_zp, 'int16'))

    # The zero points are now baked in; drop them from the attributes since
    # the plain Relay op does not accept them.
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    del new_attrs['kernel_zero_point']
    del new_attrs['input_zero_point']
    return relay_op(shift_data, shift_kernel, **new_attrs)
92+
93+ # Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
94+ def helper_change_dtypes_to_uint8_int8 (attrs , inputs , types , relay_op ):
95+ """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
96+ are already good, we dont transform. Else, we shift the tensor values and zero points to change
97+ the dtype.
5798
5899 Converting from int8 to uint8 can be done in following manner.
59100
@@ -95,26 +136,13 @@ def _shift(data, out_dtype):
95136 data_modified = relay .cast (data_modified , out_dtype )
96137 return data_modified
97138
98- def _is_int8_hw_support (target ):
99- """
100- Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
101- and above.
102- """
103- supported_arches = {'-mcpu=skylake-avx512' , '-mcpu=cascadelake' }
104- return supported_arches .intersection (set (target .options ))
105-
106139 # Collect the dtypes.
107140 data_dtype = types [0 ].dtype
108141 kernel_dtype = types [1 ].dtype
109142
110143 # Collect the input exprs.
111144 data , kernel = inputs
112145
113- # The VNNI transformations are applicable only Skylake and above.g
114- target = tvm .target .current_target (allow_none = False )
115- if not _is_int8_hw_support (target ):
116- return None
117-
118146 # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
119147 if data_dtype == 'uint8' and kernel_dtype == 'int8' :
120148 return None
@@ -137,4 +165,124 @@ def _is_int8_hw_support(target):
137165 new_attrs = {k : attrs [k ] for k in attrs .keys ()}
138166 new_attrs ['input_zero_point' ] = input_zp
139167 new_attrs ['kernel_zero_point' ] = kernel_zp
140- return relay .qnn .op .conv2d (data , kernel , ** new_attrs )
168+ return relay_op (data , kernel , ** new_attrs )
169+
# Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
    """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
    many devices like ARM prefer the datatypes to be same for the HW units. This helper transforms
    conv2d/dense such that both the dtypes are same.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """

    def _shift(data, out_dtype):
        """Shifts (adds/subtracts) the qnn tensor by +/-128."""
        if out_dtype == 'uint8':
            shift = 128
        elif out_dtype == 'int8':
            shift = -128
        else:
            raise ValueError("Unsupported out dtype.")
        # Go through int32 so the +/-128 shift cannot overflow the 8-bit range.
        data_modified = relay.cast(data, 'int32')
        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
        data_modified = relay.cast(data_modified, out_dtype)
        return data_modified

    # Collect the dtypes.
    data_dtype = types[0].dtype
    kernel_dtype = types[1].dtype

    # Collect the input exprs.
    data, kernel = inputs

    # Nothing to do when the dtypes already match.
    if data_dtype == kernel_dtype:
        return None

    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
        "Qnn Conv2D only accepts uint8 or int8 inputs"

    # Shift the data tensor to the kernel's dtype and adjust its zero point
    # by the same +/-128 so the quantized values stay equivalent.
    input_zp = attrs['input_zero_point']
    data = _shift(data, kernel_dtype)
    if data_dtype == 'int8':
        input_zp = input_zp + 128
    elif data_dtype == 'uint8':
        input_zp = input_zp - 128
    else:
        raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs")

    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs['input_zero_point'] = input_zp
    return relay_op(data, kernel, **new_attrs)
230+
def is_fast_int8_hw_present():
    """
    Checks whether the hardware has support for fast Int8 arithmetic operations.
    1) Intel - Skylake/CascadeLake
    2) ARM - Dotprod
    We can extend this function to add more device targets.

    Returns
    -------
    bool
        True if the current target advertises fast Int8 support.
    """

    target = tvm.target.current_target(allow_none=False)

    # Intel cpu - DLBoost/VNNI is present on Skylake-AVX512 and CascadeLake.
    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
    is_present_intel = bool(intel_supported_arches.intersection(target.options))

    # ARM cpu - fast Int8 needs the v8.2a dotprod extension.
    arm_supported_attr = '+v8.2a,+dotprod'
    is_present_arm = any(arm_supported_attr in opt for opt in target.options)

    return is_present_intel or is_present_arm
253+
########################
# ARM CPU legalizations.
########################
257+
@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
    """ARM conv2d legalization: dotprod units prefer matching input dtypes."""
    # ARM prefers the dtypes to be same.
    if is_fast_int8_hw_present():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
    # No fast Int8 units: lower to a plain int16 Relay conv2d.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
264+
@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
    """ARM dense legalization: dotprod units prefer matching input dtypes."""
    # ARM prefers the dtypes to be same.
    if is_fast_int8_hw_present():
        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
    # No fast Int8 units: lower to a plain int16 Relay dense.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
271+
##########################
# Intel CPU legalizations.
##########################
275+
@qnn_conv2d_legalize.register('cpu')
def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
    """Intel conv2d legalization: VNNI prefers uint8 data x int8 kernel."""
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_hw_present():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
    # No fast Int8 units: lower to a plain int16 Relay conv2d.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
282+
@qnn_dense_legalize.register('cpu')
def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
    """Intel dense legalization: VNNI prefers uint8 data x int8 kernel."""
    # The VNNI transformations prefer uint8 x int8 datatypes.
    if is_fast_int8_hw_present():
        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
    # No fast Int8 units: lower to a plain int16 Relay dense.
    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
0 commit comments