@@ -287,7 +287,7 @@ class Sub(BinaryBase):
287287 relax_op = relax .op .subtract
288288
    @classmethod
    def _impl_v7(cls, bb, inputs, attr, params):
        """Convert ONNX Sub (opset 7+) by delegating to the shared binary-op lowering."""
        # All broadcasting / constant handling lives in BinaryBase.base_impl,
        # which reads cls.relax_op (relax.op.subtract for this class).
        return cls.base_impl(bb, inputs, attr, params)
292292
293293
@@ -298,7 +298,7 @@ class Mul(BinaryBase):
298298 relax_op = relax .op .multiply
299299
    @classmethod
    def _impl_v7(cls, bb, inputs, attr, params):
        """Convert ONNX Mul (opset 7+) by delegating to the shared binary-op lowering."""
        # BinaryBase.base_impl emits cls.relax_op (relax.op.multiply here).
        return cls.base_impl(bb, inputs, attr, params)
303303
304304
@@ -309,7 +309,7 @@ class Div(BinaryBase):
309309 relax_op = relax .op .divide
310310
    @classmethod
    def _impl_v7(cls, bb, inputs, attr, params):
        """Convert ONNX Div (opset 7+) by delegating to the shared binary-op lowering."""
        # BinaryBase.base_impl emits cls.relax_op (relax.op.divide here).
        return cls.base_impl(bb, inputs, attr, params)
314314
315315
@@ -320,7 +320,24 @@ class Pow(BinaryBase):
320320 relax_op = relax .op .power
321321
    @classmethod
    def _impl_v7(cls, bb, inputs, attr, params):
        """Convert ONNX Pow (opset 7+) by delegating to the shared binary-op lowering."""
        # BinaryBase.base_impl emits cls.relax_op (relax.op.power here).
        return cls.base_impl(bb, inputs, attr, params)
325+
326+
class Mod(BinaryBase):
    """Converts an onnx Mod node into an equivalent Relax expression.

    ONNX semantics: with ``fmod == 0`` (the default) the result takes the
    sign of the divisor (Python/NumPy ``mod`` / ``floor_mod``); with
    ``fmod == 1`` it takes the sign of the dividend (C ``fmod`` / truncated
    ``mod``).
    """

    numpy_op = _np.mod
    relax_op = relax.op.mod

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        # The numpy op (used for constant folding) and the relax op must stay
        # consistent pairs: _np.mod <-> floor_mod (sign of divisor) and
        # _np.fmod <-> mod (truncated, sign of dividend).  The previous code
        # had the numpy ops swapped between the two branches.
        # NOTE(review): assigning to cls.* mutates shared class state; safe
        # only because base_impl reads it immediately on the next line.
        if attr.get("fmod", 0) == 0:
            cls.numpy_op = _np.mod
            cls.relax_op = relax.op.floor_mod
        else:
            cls.numpy_op = _np.fmod
            cls.relax_op = relax.op.mod
        return cls.base_impl(bb, inputs, attr, params)
325342
326343
@@ -523,6 +540,23 @@ def _impl_v13(cls, bb, inputs, attr, params):
523540 return relax .op .nn .log_softmax (inputs [0 ], axis = axis )
524541
525542
class Hardmax(OnnxOpConverter):
    """Converts an onnx Hardmax node into an equivalent Relax expression."""

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        """Lower Hardmax as one_hot(argmax(data, axis)) along the given axis."""
        data = inputs[0]
        axis = attr.get("axis", -1)
        dtype = data.struct_info.dtype
        # Depth of the one-hot axis; requires a statically known extent.
        depth = int(data.struct_info.shape[axis])
        winners = relax.op.argmax(data, axis=axis)
        one = relax.PrimValue(tvm.tir.const(1.0, dtype))
        zero = relax.PrimValue(tvm.tir.const(0.0, dtype))
        return relax.op.one_hot(winners, one, zero, depth, axis)
559+
526560class Transpose (OnnxOpConverter ):
527561 """Converts an onnx Transpose node into an equivalent Relax expression."""
528562
@@ -731,6 +765,22 @@ def _impl_v1(cls, bb, inputs, attr, params):
731765 return relax .op .prod (relax .op .shape_to_tensor (relax .op .shape_of (inputs [0 ])))
732766
733767
class EyeLike(OnnxOpConverter):
    """Convert an onnx EyeLike node into an equivalent Relax expression."""

    @classmethod
    def _impl_v9(cls, bb, inputs, attr, params):
        """Emit an identity-like 2-D tensor with the input's shape.

        Per the ONNX spec, the optional ``dtype`` attribute selects the
        output dtype; when it is absent the input tensor's dtype is used.
        The ``k`` attribute shifts the diagonal (default 0, the main one).
        """
        k = attr.get("k", 0)
        input_dtype = inputs[0].struct_info.dtype
        # Honor the 'dtype' attribute instead of raising on a mismatch with
        # the input dtype: the spec explicitly allows them to differ.
        dtype = get_type(attr["dtype"]) if "dtype" in attr else input_dtype
        return relax.op.eye_like(inputs[0], k, dtype)
783+
734784class Gemm (OnnxOpConverter ):
735785 """Convert an onnx Gemm node into an equivalent Relax expression."""
736786
@@ -2520,13 +2570,13 @@ def _impl_v11(cls, bb, inputs, attr, params):
25202570 depth = get_constant (inputs [1 ], params )
25212571 values = get_constant (inputs [2 ], params )
25222572 axis = attr .get ("axis" , - 1 )
2523- dtype = values .struct_info .dtype
25242573 assert isinstance (depth , relax .Constant ), "Only constant depth currently supported."
25252574 depth = depth .data .numpy ().tolist ()
25262575 assert isinstance (values , relax .Constant ), "Only constant values currently supported."
25272576 values = values .data .numpy ().tolist ()
25282577 off_value , on_value = values
2529- return bb .emit_te (topi .one_hot , indices , on_value , off_value , depth , axis , dtype )
2578+ off_value , on_value = relax .PrimValue (off_value ), relax .PrimValue (on_value )
2579+ return relax .op .one_hot (indices , on_value , off_value , depth , axis )
25302580
25312581
25322582class Unique (OnnxOpConverter ):
@@ -2800,7 +2850,7 @@ def _get_convert_map():
28002850 "Sub" : Sub ,
28012851 "Mul" : Mul ,
28022852 "Div" : Div ,
2803- # "Mod": Mod,
2853+ "Mod" : Mod ,
28042854 "Less" : Less ,
28052855 "LessOrEqual" : LessOrEqual ,
28062856 "Greater" : Greater ,
@@ -2870,7 +2920,7 @@ def _get_convert_map():
28702920 "Sigmoid" : Sigmoid ,
28712921 "Softmax" : Softmax ,
28722922 "LogSoftmax" : LogSoftmax ,
2873- # "Hardmax": Hardmax,
2923+ "Hardmax" : Hardmax ,
28742924 "Transpose" : Transpose ,
28752925 "Unsqueeze" : Unsqueeze ,
28762926 "Where" : Where ,
@@ -2889,7 +2939,7 @@ def _get_convert_map():
28892939 "ScatterND" : ScatterND ,
28902940 # "Compress": Compress,
28912941 "Size" : Size ,
2892- # "EyeLike": EyeLike,
2942+ "EyeLike" : EyeLike ,
28932943 # Normalization
28942944 "BatchNormalization" : BatchNormalization ,
28952945 "LayerNormalization" : LayerNormalization ,
0 commit comments