@@ -84,10 +84,12 @@ def _dim_check(attrs):
8484 return _dim_check , "Only 2d kernel supported."
8585
8686def _get_param (params , input_node ):
87+ if isinstance (input_node , _expr .Constant ):
88+ return np .atleast_1d (input_node .data .asnumpy ())
8789 return params .pop (input_node .name_hint ).asnumpy ()
8890
8991def _get_num_param (params , input_node ):
90- return _get_param (params , input_node )[ 0 ]
92+ return _get_param (params , input_node ). item ()
9193
9294def _get_list_param (params , input_node ):
9395 return _get_param (params , input_node ).tolist ()
@@ -335,9 +337,9 @@ def _impl(inputs, attr, params):
335337 # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
336338 # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
337339 try :
338- boxes = params . pop ( inputs [1 ]. name_hint ). asnumpy (). tolist ( )
339- box_ind = params . pop ( inputs [2 ]. name_hint ). asnumpy (). tolist ( )
340- crop_size = params . pop ( inputs [3 ]. name_hint ). asnumpy (). tolist ( )
340+ boxes = _get_list_param ( params , inputs [1 ])
341+ box_ind = _get_list_param ( params , inputs [2 ])
342+ crop_size = _get_list_param ( params , inputs [3 ])
341343 except (IndexError , KeyError ):
342344 boxes = _infer_value (inputs [1 ], params ).asnumpy ().tolist ()
343345 box_ind = _infer_value (inputs [2 ], params ).asnumpy ().tolist ()
@@ -505,7 +507,7 @@ def _impl(inputs, attr, params):
505507
506508def _tile ():
507509 def _impl (inputs , attr , params ):
508- reps = params [ inputs .pop (). name_hint ]. asnumpy ( )
510+ reps = _get_list_param ( params , inputs .pop ())
509511 new_input = []
510512 new_input .append (inputs .pop (0 ))
511513
@@ -752,7 +754,7 @@ def _impl(inputs, attr, params):
752754
753755def _reduce (op ):
754756 def _impl (inputs , attr , params ):
755- axis = params . pop ( inputs [1 ]. name_hint ). asnumpy ( )
757+ axis = _get_list_param ( params , inputs [1 ])
756758 axis = tuple (axis )
757759 return AttrCvt (
758760 op_name = op ,
@@ -937,8 +939,8 @@ def _impl(inputs, attr, params):
937939
938940def _clip_by_value ():
939941 def _impl (inputs , attr , params ):
940- a_min = params . pop ( inputs [1 ]. name_hint ). asnumpy ()[ 0 ]
941- a_max = params . pop ( inputs [2 ]. name_hint ). asnumpy ()[ 0 ]
942+ a_min = _get_num_param ( params , inputs [1 ])
943+ a_max = _get_num_param ( params , inputs [2 ])
942944 return _op .clip (inputs [0 ], a_min = a_min , a_max = a_max )
943945 return _impl
944946
@@ -965,10 +967,11 @@ def _impl(inputs, attr, params):
965967
966968def _range ():
967969 def _impl (inputs , attr , params ):
968- start = params .pop (inputs [0 ].name_hint ).asnumpy ()[0 ]
969- limit = params .pop (inputs [1 ].name_hint ).asnumpy ()[0 ] \
970- if hasattr (inputs [1 ], "name_hint" ) else params .pop ('Rank' ).asnumpy ()[0 ]
971- delta = params .pop (inputs [2 ].name_hint ).asnumpy ()[0 ]
970+ start = _get_param (params , inputs [0 ])[0 ]
971+ limit = _get_param (params , inputs [1 ])[0 ] \
972+ if hasattr (inputs [1 ], "name_hint" ) or isinstance (inputs [1 ], _expr .Constant ) \
973+ else params .pop ('Rank' ).asnumpy ()[0 ]
974+ delta = _get_param (params , inputs [2 ])[0 ]
972975 dtype = attr ['dtype' ].name if 'dtype' in attr else "int32"
973976 return AttrCvt (
974977 op_name = "arange" ,
@@ -1084,7 +1087,7 @@ def _impl(inputs, attr, params):
10841087
10851088def _topk ():
10861089 def _impl (inputs , attr , params ):
1087- k = int (params . pop ( inputs .pop (1 ). name_hint ). asnumpy ( ))
1090+ k = int (_get_num_param ( params , inputs .pop (1 )))
10881091 if k < 1 :
10891092 raise tvm .error .OpAttributeInvalid (
10901093 'Attribute k must be positive in operator TopKV2' )
@@ -1196,7 +1199,7 @@ def _impl(inputs, attr, params):
11961199
11971200def _prod ():
11981201 def _impl (inputs , attr , params ):
1199- axis = params . pop ( inputs [1 ]. name_hint ). asnumpy ()[ 0 ]
1202+ axis = _get_num_param ( params , inputs [1 ])
12001203 keepdims = attr ['keep_dims' ]
12011204 return _op .prod (inputs [0 ], int (axis ), keepdims = keepdims )
12021205 return _impl
@@ -2104,13 +2107,12 @@ def _parse_param(self, key, value, name, shape):
21042107 if array_ndim == 0 :
21052108 new_array = np .empty ([1 ], dtype = np_array .dtype )
21062109 new_array [0 ] = np_array
2107- self ._params [name ] = tvm .nd . array (new_array )
2110+ self ._nodes [name ] = [ tvm .relay . const (new_array )]
21082111 else :
21092112 self ._params [name ] = tvm .nd .array (np_array )
2110-
2111- self ._nodes [name ] = [_expr .var (name ,
2112- shape = self ._params [name ].shape ,
2113- dtype = self ._params [name ].dtype )]
2113+ self ._nodes [name ] = [_expr .var (name ,
2114+ shape = self ._params [name ].shape ,
2115+ dtype = self ._params [name ].dtype )]
21142116 else :
21152117 if key not in ('dtype' , '_output_shapes' , '_class' ):
21162118 raise NotImplementedError \
0 commit comments