@@ -918,27 +918,30 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
918918 return pattern
919919
920920
921- class TanhParams :
921+ class LutActivationParams :
922922 """
923- This class will parse a call to a ethos-u.tanh composite function
924- and extract the parameter information .
923+ A parent class for LUT based activation functions that extract the input and
924+ output tensors and check whether they are valid .
925925 """
926926
927- composite_name = "ethos-u.tanh"
928-
929927 def __init__ (self , func_body : Call ):
930928 self .ofm = TensorParams (func_body )
931929 self .ifm = TensorParams (func_body .args [0 ].args [0 ].args [0 ])
932930
933931 def is_valid (self ):
934932 """
935- This function checks whether reshape has compatible attributes with the NPU
933+ This function checks whether activation has compatible attributes with the NPU
936934 """
937935 if not check_valid_dtypes ([self .ifm , self .ofm ], supported_dtypes = [np .int8 ]):
938936 return False
939937 return True
940938
941939
940+ class TanhParams (LutActivationParams ):
941+
942+ composite_name = "ethos-u.tanh"
943+
944+
942945def tanh_pattern ():
943946 """Create pattern for tanh"""
944947 dequant = is_op ("qnn.dequantize" )(wildcard (), is_constant (), is_constant ())
@@ -947,6 +950,23 @@ def tanh_pattern():
947950 return quant
948951
949952
953+ class SigmoidParams (LutActivationParams ):
954+ """
955+ This class will parse a call to a ethos-u.sigmoid composite function
956+ and extract the parameter information.
957+ """
958+
959+ composite_name = "ethos-u.sigmoid"
960+
961+
962+ def sigmoid_pattern ():
963+ """Create pattern for sigmoid"""
964+ dequant = is_op ("qnn.dequantize" )(wildcard (), is_constant (), is_constant ())
965+ sigmoid = is_op ("sigmoid" )(dequant )
966+ quant = is_op ("qnn.quantize" )(sigmoid , is_constant (), is_constant ())
967+ return quant
968+
969+
950970class MeanParams :
951971 """
952972 This class will parse a call to ethosu.mean composite function
@@ -1087,35 +1107,6 @@ def concat_pattern():
10871107 return concat
10881108
10891109
1090- class SigmoidParams :
1091- """
1092- This class will parse a call to a ethos-u.sigmoid composite function
1093- and extract the parameter information.
1094- """
1095-
1096- composite_name = "ethos-u.sigmoid"
1097-
1098- def __init__ (self , func_body : Call ):
1099- self .ofm = TensorParams (func_body )
1100- self .ifm = TensorParams (func_body .args [0 ].args [0 ].args [0 ])
1101-
1102- def is_valid (self ):
1103- """
1104- This function checks whether sigmoid has compatible attributes with the NPU
1105- """
1106- if not check_valid_dtypes ([self .ifm , self .ofm ], supported_dtypes = [np .int8 ]):
1107- return False
1108- return True
1109-
1110-
1111- def sigmoid_pattern ():
1112- """Create pattern for sigmoid"""
1113- dequant = is_op ("qnn.dequantize" )(wildcard (), is_constant (), is_constant ())
1114- sigmoid = is_op ("sigmoid" )(dequant )
1115- quant = is_op ("qnn.quantize" )(sigmoid , is_constant (), is_constant ())
1116- return quant
1117-
1118-
11191110@register_pattern_table ("ethos-u" )
11201111def pattern_table () -> List [Tuple [str , tvm .relay .dataflow_pattern .DFPattern , Callable ]]:
11211112 return [
0 commit comments