1616# under the License.
1717# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
1818# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
19- # pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
19+ # pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
2020"""PT: PyTorch frontend."""
2121import itertools
2222import logging
3636from .common import AttrCvt , get_relay_op
3737from .common import infer_shape as _infer_shape
3838from .common import infer_value as _infer_value
39+ from .common import try_infer_value
3940from .common import infer_value_simulated as _infer_value_simulated
4041from .common import infer_type as _infer_type
4142from ..prelude import Prelude , StaticTensorArrayOps
@@ -185,11 +186,8 @@ def _impl(inputs, input_types):
185186 def _get_value (val , dtype ):
186187 # dtype is a tvm dtype
187188 if isinstance (val , _expr .Expr ):
188- try :
189- ret = _infer_value (_op .cast (val , dtype ), {}).asnumpy ()
190- ret = _expr .const (ret , dtype )
191- except Exception :
192- ret = _op .cast (val , dtype )
189+ inp = _op .cast (val , dtype )
190+ ret , _ = try_infer_value (inp , lambda ret : _expr .const (ret , dtype ))
193191 else :
194192 ret = _create_typed_const (val , dtype )
195193 return ret
@@ -305,10 +303,7 @@ def _impl(inputs, input_types):
305303 dim = int (inputs [1 ])
306304 stride = int (inputs [4 ])
307305 if isinstance (inputs [2 ], _expr .Call ):
308- try :
309- begin [dim ] = np .asscalar (_infer_value (inputs [2 ], {}).asnumpy ().astype (np .int ))
310- except Exception :
311- begin [dim ] = inputs [2 ]
306+ begin [dim ], _ = try_infer_value (inputs [2 ], lambda ret : np .asscalar (ret .astype (np .int )))
312307 else :
313308 begin [dim ] = int (inputs [2 ])
314309
@@ -329,10 +324,9 @@ def _impl(inputs, input_types):
329324 target_end = int (inputs [3 ])
330325 else :
331326 if isinstance (inputs [3 ], _expr .Expr ):
332- try :
333- target_end = np .asscalar (_infer_value (inputs [3 ], {}).asnumpy ().astype (np .int ))
334- except Exception :
335- target_end = inputs [3 ]
327+ target_end , _ = try_infer_value (
328+ inputs [3 ], lambda ret : np .asscalar (ret .astype (np .int ))
329+ )
336330 else :
337331 target_end = inputs [3 ]
338332
@@ -457,10 +451,7 @@ def _impl(inputs, input_types):
457451 sort = bool (inputs [4 ])
458452
459453 if isinstance (inputs [1 ], _expr .Expr ):
460- try :
461- k = _infer_value (inputs [1 ], {}).asnumpy ().tolist ()
462- except Exception :
463- k = inputs [1 ]
454+ k , _ = try_infer_value (inputs [1 ], lambda ret : ret .tolist ())
464455 else :
465456 k = inputs [1 ]
466457
@@ -546,15 +537,15 @@ def _full_impl(data, fill_value, dtype):
546537 size .append (dim )
547538 new_shape .append (dim )
548539 else :
549- try :
550- dim = int (_infer_value (dim , {}).asnumpy ())
540+ dim , success = try_infer_value (dim , lambda ret : int (ret ), lambda : 0 )
541+ new_shape .append (dim )
542+
543+ if success :
551544 if isinstance (size , list ):
552545 size .append (dim )
553- new_shape .append (dim )
554- except Exception :
546+ else :
555547 size = None
556548 need_reshape = True
557- new_shape .append (0 )
558549 else :
559550 if isinstance (size , list ):
560551 size .append (dim )
@@ -1346,12 +1337,11 @@ def _impl(inputs, input_types):
13461337 if isinstance (s , _expr .Constant ):
13471338 tmp_shape .append (int (s .data .asnumpy ()))
13481339 elif isinstance (s , _expr .Expr ):
1349- try :
1350- dim = int ( _infer_value ( s , {}). asnumpy () )
1351- tmp_shape . append ( dim )
1352- except Exception :
1340+ dim , success = try_infer_value ( s , lambda ret : int ( ret ))
1341+ tmp_shape . append ( dim )
1342+
1343+ if not success :
13531344 is_dyn = True
1354- tmp_shape .append (s )
13551345 else :
13561346 tmp_shape .append (s )
13571347
@@ -2312,13 +2302,15 @@ def _impl(inputs, input_types):
23122302 if isinstance (inputs [1 ], _expr .Expr ):
23132303 out_size = inputs [1 ]
23142304 elif isinstance (inputs [1 ], list ):
2315- try :
2316- infer_res = [_infer_value (size , {}) for size in inputs [1 ]]
2317- out_size = [np .asscalar (res .asnumpy ().astype (np .int )) for res in infer_res ]
2318- except Exception :
2319- h = _op .expand_dims (inputs [1 ][0 ], axis = 0 )
2320- w = _op .expand_dims (inputs [1 ][1 ], axis = 0 )
2321- out_size = _op .concatenate ([h , w ], axis = 0 )
2305+ out_size = []
2306+ for i in [0 , 1 ]:
2307+ size , _ = try_infer_value (
2308+ inputs [1 ][i ],
2309+ lambda ret : ret .astype (np .int ),
2310+ lambda : _op .expand_dims (inputs [1 ][i ], axis = 0 ),
2311+ )
2312+ out_size .append (size )
2313+ out_size = _op .concatenate (out_size , axis = 0 )
23222314
23232315 data = inputs [0 ]
23242316 align_corners = inputs [4 ]
0 commit comments