Skip to content

Commit f818ef3

Browse files
committed
Respond to review comments
1 parent 534c35c commit f818ef3

File tree

2 files changed

+35
-40
lines changed

2 files changed

+35
-40
lines changed

python/tvm/relay/backend/contrib/ethosu/legalize.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter
1818
"""A set of passes to legalize some of operations for the NPU"""
19-
from typing import List, Type
19+
from typing import List, Type, Callable
2020
import math
2121

2222
import numpy as np # type: ignore
@@ -125,7 +125,9 @@ def __call__(self, *args, **kwargs):
125125
pass
126126

127127

128-
def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
128+
def get_lut_from_func(
129+
ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]
130+
) -> List[int]:
129131
"""Method to calculate the values of the lookup table based on the calculation function"""
130132
lut_values = list()
131133
# Only int8 is currently supported
@@ -144,13 +146,15 @@ def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
144146
class LutActivationRewriter(DFPatternCallback):
145147
"""A class to create an identity operator with the LUT"""
146148

147-
def __init__(self, params_class, activation_type, calc_func):
149+
def __init__(
150+
self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]
151+
):
148152
super().__init__(require_type=True, rewrite_once=True)
149153
self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
150154
self.activation_type = activation_type
151155
self.calc_func = calc_func
152156

153-
def callback(self, pre, post, node_map):
157+
def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
154158
id_input = post.args[0]
155159

156160
quantize_args = post.op.body.args
@@ -205,7 +209,7 @@ def __call__(self, *args, **kwargs):
205209
pass
206210

207211

208-
def sigmoid_calc_func(x):
212+
def sigmoid_calc_func(x: float) -> float:
209213
"""Function to calculate the values for sigmoid"""
210214
# Thse limits are inherited from TFLite
211215
upper_limit = 8.0

python/tvm/relay/op/contrib/ethosu.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
918918
return pattern
919919

920920

921-
class TanhParams:
921+
class LutActivationParams:
922922
"""
923-
This class will parse a call to a ethos-u.tanh composite function
924-
and extract the parameter information.
923+
A parent class for LUT based activation functions that extract the input and
924+
output tensors and check whether they are valid.
925925
"""
926926

927-
composite_name = "ethos-u.tanh"
928-
929927
def __init__(self, func_body: Call):
930928
self.ofm = TensorParams(func_body)
931929
self.ifm = TensorParams(func_body.args[0].args[0].args[0])
932930

933931
def is_valid(self):
934932
"""
935-
This function checks whether reshape has compatible attributes with the NPU
933+
This function checks whether activation has compatible attributes with the NPU
936934
"""
937935
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
938936
return False
939937
return True
940938

941939

940+
class TanhParams(LutActivationParams):
941+
942+
composite_name = "ethos-u.tanh"
943+
944+
942945
def tanh_pattern():
943946
"""Create pattern for tanh"""
944947
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
@@ -947,6 +950,23 @@ def tanh_pattern():
947950
return quant
948951

949952

953+
class SigmoidParams(LutActivationParams):
954+
"""
955+
This class will parse a call to a ethos-u.sigmoid composite function
956+
and extract the parameter information.
957+
"""
958+
959+
composite_name = "ethos-u.sigmoid"
960+
961+
962+
def sigmoid_pattern():
963+
"""Create pattern for sigmoid"""
964+
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
965+
sigmoid = is_op("sigmoid")(dequant)
966+
quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
967+
return quant
968+
969+
950970
class MeanParams:
951971
"""
952972
This class will parse a call to ethosu.mean composite function
@@ -1087,35 +1107,6 @@ def concat_pattern():
10871107
return concat
10881108

10891109

1090-
class SigmoidParams:
1091-
"""
1092-
This class will parse a call to a ethos-u.sigmoid composite function
1093-
and extract the parameter information.
1094-
"""
1095-
1096-
composite_name = "ethos-u.sigmoid"
1097-
1098-
def __init__(self, func_body: Call):
1099-
self.ofm = TensorParams(func_body)
1100-
self.ifm = TensorParams(func_body.args[0].args[0].args[0])
1101-
1102-
def is_valid(self):
1103-
"""
1104-
This function checks whether sigmoid has compatible attributes with the NPU
1105-
"""
1106-
if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int8]):
1107-
return False
1108-
return True
1109-
1110-
1111-
def sigmoid_pattern():
1112-
"""Create pattern for sigmoid"""
1113-
dequant = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant())
1114-
sigmoid = is_op("sigmoid")(dequant)
1115-
quant = is_op("qnn.quantize")(sigmoid, is_constant(), is_constant())
1116-
return quant
1117-
1118-
11191110
@register_pattern_table("ethos-u")
11201111
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
11211112
return [

0 commit comments

Comments
 (0)