
Commit 5014e6f

[QNN][Legalize] Specialize for Platforms without any fast Int8 arithmetic units.
1 parent 10b77ef commit 5014e6f

2 files changed

Lines changed: 321 additions & 32 deletions
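
At a glance, the commit makes qnn.conv2d and qnn.dense legalization dispatch on the current target: CPUs with fast Int8 instructions (Intel VNNI, ARM dotprod) keep the QNN op and only adjust dtypes, while other CPUs lower the QNN op into a plain nn.conv2d / nn.dense over int16 data. A rough usage sketch follows; it is not taken from this commit, and it assumes the QNN legalization pass is exposed as relay.qnn.transform.Legalize (check your TVM version) and that mod is a Relay module already containing QNN ops.

import tvm
from tvm import relay

def legalize_under(target_str, mod):
    # The registered legalization functions consult the target that is
    # current while the pass runs, so enter the target context first.
    with tvm.target.create(target_str):
        return relay.qnn.transform.Legalize()(mod)

# Fast-int8 CPU: the QNN op is kept, with dtypes nudged toward uint8 x int8.
# legalized = legalize_under("llvm -mcpu=skylake-avx512", mod)

# CPU without fast int8: qnn.conv2d / qnn.dense are rewritten into
# cast-to-int16 + subtract(zero point) + nn.conv2d / nn.dense.
# legalized = legalize_under("llvm -mcpu=core-avx2", mod)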

File tree

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 174 additions & 26 deletions
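
The changes below lean on TVM's generic-function dispatch: qnn_conv2d_legalize and qnn_dense_legalize are declared with @tvm.target.generic_func and then overridden per target key ('cpu', 'arm_cpu'). A minimal, self-contained illustration of that mechanism, not taken from this commit and using a made-up my_legalize function, might look like this:

import tvm

# Default implementation, used when no per-target override matches.
@tvm.target.generic_func
def my_legalize(x):
    return None

# Override for targets whose keys include 'cpu' (e.g. plain "llvm").
@my_legalize.register('cpu')
def _my_legalize_cpu(x):
    return x + 1

with tvm.target.create('llvm'):      # target keys include 'cpu'
    assert my_legalize(41) == 42      # the 'cpu' override is picked
with tvm.target.create('cuda'):       # no 'cpu' key on this target
    assert my_legalize(41) is None    # falls back to the default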
@@ -22,10 +22,43 @@
 from tvm import relay
 from .. import op as reg
 
+#################################################
+# Register the functions for different operators.
+#################################################
+
 # Registering QNN Conv2D legalization function.
 @reg.register_qnn_legalize("qnn.conv2d")
 def legalize_qnn_conv2d(attrs, inputs, types):
-    """Legalizes QNN conv2d op.
+    return qnn_conv2d_legalize(attrs, inputs, types)
+
+# Registering QNN dense legalization function.
+@reg.register_qnn_legalize("qnn.dense")
+def legalize_qnn_dense(attrs, inputs, types):
+    return qnn_dense_legalize(attrs, inputs, types)
+
+# Default to None. If overridden by target, this will not be run.
+# Generic QNN Conv2D legalization function.
+@tvm.target.generic_func
+def qnn_conv2d_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+# Generic QNN Dense legalization function.
+@tvm.target.generic_func
+def qnn_dense_legalize(attrs, inputs, types):
+    """Default legalization is None."""
+    return None
+
+###################
+# Helper functions.
+###################
+
+# Helper function for lowering in the absence of fast Int8 arithmetic units.
+def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
+    """Converts QNN operators into a sequence of Relay operators that are friendly to HW that does
+    not have fast Int8 arithmetic. For example, for ARM, LLVM utilizes the assembly instructions
+    much more efficiently if the convolution or dense operator input datatypes are int16 instead of
+    int8. More details are present at https://github.com/apache/incubator-tvm/pull/4277.
 
     Parameters
     ----------
@@ -41,19 +74,27 @@ def legalize_qnn_conv2d(attrs, inputs, types):
     result : tvm.relay.Expr
         The legalized expr
     """
-    return qnn_conv2d_legalize(attrs, inputs, types)
 
-# Generic QNN Conv2D legalization function.
-@tvm.target.generic_func
-def qnn_conv2d_legalize(attrs, inputs, types):
-    """Default legalization is None."""
-    return None
+    # Collect the input exprs.
+    data, kernel = inputs
 
-# Intel x86 QNN Conv2D legalization function.
-@qnn_conv2d_legalize.register('cpu')
-def _qnn_conv2d_legalize(attrs, inputs, types):
-    """Legalizes QNN conv2d op. VNNI supports u8 x i8 fast conv/MM. If the dtypes are already good,
-    we dont transform. Else, we shift the tensor values and zero points to change the dtype.
+    input_zp = attrs['input_zero_point']
+    kernel_zp = attrs['kernel_zero_point']
+
+    shift_data = relay.subtract(relay.cast(data, dtype='int16'),
+                                relay.const(input_zp, 'int16'))
+    shift_kernel = relay.subtract(relay.cast(kernel, dtype='int16'),
+                                  relay.const(kernel_zp, 'int16'))
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    del new_attrs['kernel_zero_point']
+    del new_attrs['input_zero_point']
+    return relay_op(shift_data, shift_kernel, **new_attrs)
+
+# Helper function to change dtypes to uint8 x int8. Intel VNNI instructions prefer this setting.
+def helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay_op):
+    """Legalizes QNN conv2d/dense op for Intel HW. VNNI supports u8 x i8 fast conv/MM. If the dtypes
+    are already good, we don't transform. Else, we shift the tensor values and zero points to change
+    the dtype.
 
     Converting from int8 to uint8 can be done in following manner.
 
@@ -95,26 +136,13 @@ def _shift(data, out_dtype):
         data_modified = relay.cast(data_modified, out_dtype)
         return data_modified
 
-    def _is_int8_hw_support(target):
-        """
-        Checks to ensure that we can use Intel DLBoost instructions - Check if the target is skylake
-        and above.
-        """
-        supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
-        return supported_arches.intersection(set(target.options))
-
     # Collect the dtypes.
     data_dtype = types[0].dtype
     kernel_dtype = types[1].dtype
 
     # Collect the input exprs.
     data, kernel = inputs
 
-    # The VNNI transformations are applicable only Skylake and above.g
-    target = tvm.target.current_target(allow_none=False)
-    if not _is_int8_hw_support(target):
-        return None
-
     # VNNI supports u8 x i8 fast conv/MM. Don't do anything if it is already satisfied.
     if data_dtype == 'uint8' and kernel_dtype == 'int8':
         return None
@@ -137,4 +165,124 @@ def _is_int8_hw_support(target):
     new_attrs = {k : attrs[k] for k in attrs.keys()}
     new_attrs['input_zero_point'] = input_zp
     new_attrs['kernel_zero_point'] = kernel_zp
-    return relay.qnn.op.conv2d(data, kernel, **new_attrs)
+    return relay_op(data, kernel, **new_attrs)
+
+# Helper function to change dtypes to be the same. ARM dotprod instructions prefer this setting.
+def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
+    """Sometimes MXNet + MKLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
+    many devices like ARM prefer the datatypes to be the same for the HW units. This helper transforms
+    conv2d/dense such that both the dtypes are the same.
+
+    Parameters
+    ----------
+    attrs : tvm.attrs.Attrs
+        Attributes of current convolution
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+
+    def _shift(data, out_dtype):
+        """Shifts (adds/subtracts) the qnn tensor by +/-128."""
+        if out_dtype == 'uint8':
+            shift = 128
+        elif out_dtype == 'int8':
+            shift = -128
+        else:
+            raise ValueError("Unsupported out dtype.")
+        data_modified = relay.cast(data, 'int32')
+        data_modified = relay.add(data_modified, relay.const(shift, 'int32'))
+        data_modified = relay.cast(data_modified, out_dtype)
+        return data_modified
+
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Collect the input exprs.
+    data, kernel = inputs
+
+    if data_dtype == kernel_dtype:
+        return None
+
+    assert 'int8' in data_dtype and 'int8' in kernel_dtype, \
+        "Qnn Conv2D only accepts uint8 or int8 inputs"
+
+    # Shift input if necessary.
+    input_zp = attrs['input_zero_point']
+    data = _shift(data, kernel_dtype)
+    if data_dtype == 'int8':
+        input_zp = input_zp + 128
+    elif data_dtype == 'uint8':
+        input_zp = input_zp - 128
+    else:
+        raise RuntimeError("Qnn Conv2D only accepts uint8 or int8 inputs")
+
+    new_attrs = {k : attrs[k] for k in attrs.keys()}
+    new_attrs['input_zero_point'] = input_zp
+    return relay_op(data, kernel, **new_attrs)
+
+def is_fast_int8_hw_present():
+    """
+    Checks whether the hardware has support for fast Int8 arithmetic operations.
+    1) Intel - Skylake/CascadeLake
+    2) ARM - Dotprod
+    We can extend this function to add more device targets.
+    """
+
+    target = tvm.target.current_target(allow_none=False)
+
+    # Intel cpu
+    intel_supported_arches = {'-mcpu=skylake-avx512', '-mcpu=cascadelake'}
+    is_present_intel = intel_supported_arches.intersection(set(target.options))
+
+    # ARM cpu
+    arm_supported_attr = '+v8.2a,+dotprod'
+    is_present_arm = False
+    for opt in target.options:
+        if arm_supported_attr in opt:
+            is_present_arm = True
+
+    return is_present_intel or is_present_arm
+
+########################
+# ARM CPU legalizations.
+########################
+
+@qnn_conv2d_legalize.register('arm_cpu')
+def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be the same.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('arm_cpu')
+def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
+    # ARM prefers the dtypes to be the same.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
+
+##########################
+# Intel CPU legalizations.
+##########################
+
+@qnn_conv2d_legalize.register('cpu')
+def _qnn_conv2d_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.conv2d)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
+@qnn_dense_legalize.register('cpu')
+def _qnn_dense_legalize_intel_cpu(attrs, inputs, types):
+    # The VNNI transformations prefer uint8 x int8 datatypes.
+    if is_fast_int8_hw_present():
+        return helper_change_dtypes_to_uint8_int8(attrs, inputs, types, relay.qnn.op.dense)
+    return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.dense)
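
The arithmetic behind these helpers can be checked in isolation. The NumPy sketch below is not part of the commit; it only verifies the two identities the file relies on: shifting the values and the zero point together by +/-128 preserves value - zero_point, and widening to int16 and subtracting the zero points up front reproduces the QNN product sum((data - input_zp) * (kernel - kernel_zp)).

import numpy as np

rng = np.random.default_rng(0)

# (1) int8 <-> uint8 flip used by helper_change_dtypes_to_uint8_int8 and
# helper_change_dtypes_to_be_same: shift data and zero point by the same 128.
x_int8 = rng.integers(-128, 128, size=16).astype(np.int32)
zp = 7
x_uint8 = x_int8 + 128                                    # int8 -> uint8
assert np.array_equal(x_int8 - zp, x_uint8 - (zp + 128))
x_back = x_uint8 - 128                                    # uint8 -> int8
assert np.array_equal(x_uint8 - (zp + 128), x_back - zp)

# (2) no-fast-int8 path used by helper_no_fast_int8_hw_legalization: widen to
# int16, subtract the zero points, then run the plain (non-QNN) operator.
data = rng.integers(0, 256, size=16).astype(np.int32)       # uint8 range
kernel = rng.integers(-128, 128, size=16).astype(np.int32)  # int8 range
input_zp, kernel_zp = 5, 3
qnn_dot = np.dot(data - input_zp, kernel - kernel_zp)       # QNN definition
shift_data = (data - input_zp).astype(np.int16)
shift_kernel = (kernel - kernel_zp).astype(np.int16)
assert np.dot(shift_data.astype(np.int32), shift_kernel.astype(np.int32)) == qnn_dot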

0 commit comments
