Skip to content

Commit e1019f3

Browse files
masahimasa
authored and committed
[Torch] Clean up usage of try ... infer_value() ... except (apache#6504)
* clean up infer value usage * try silence pylint * remove unused variable * make on_failure optional * make on_success optional True Co-authored-by: masa <masa@pop-os.localdomain>
1 parent 001d08c commit e1019f3

2 files changed

Lines changed: 44 additions & 35 deletions

File tree

python/tvm/relay/frontend/common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,23 @@ def infer_value_simulated(input_val, params):
563563
return output_value
564564

565565

566+
def try_infer_value(val, on_success=None, on_failure=None):
    """Attempt to constant-fold ``val`` with ``infer_value``.

    Parameters
    ----------
    val : tvm.relay.Expr
        The expression to evaluate.
    on_success : callable, optional
        Applied to the inferred numpy value before returning it.
    on_failure : callable, optional
        Called (with no arguments) to produce the fallback result when
        inference fails; if omitted, ``val`` itself is returned.

    Returns
    -------
    (result, success) : tuple
        ``result`` is the (possibly transformed) inferred value on success,
        or the fallback on failure; ``success`` is True iff inference worked.
    """
    try:
        inferred = infer_value(val, {}).asnumpy()
        # NOTE(review): on_success runs inside the try, so an exception raised
        # by the callback is also treated as an inference failure.
        result = on_success(inferred) if on_success else inferred
        return result, True
    except Exception:  # pylint: disable=broad-except
        if on_failure is not None:
            return on_failure(), False
        return val, False
581+
582+
566583
def new_var(name_hint, type_annotation=None, shape=None, dtype="float32"):
567584
return _expr.var(name_hint, type_annotation, shape, dtype)
568585

python/tvm/relay/frontend/pytorch.py

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
1818
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
19-
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, unnecessary-comprehension
19+
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
2020
"""PT: PyTorch frontend."""
2121
import itertools
2222
import logging
@@ -36,6 +36,7 @@
3636
from .common import AttrCvt, get_relay_op
3737
from .common import infer_shape as _infer_shape
3838
from .common import infer_value as _infer_value
39+
from .common import try_infer_value
3940
from .common import infer_value_simulated as _infer_value_simulated
4041
from .common import infer_type as _infer_type
4142
from ..prelude import Prelude, StaticTensorArrayOps
@@ -185,11 +186,8 @@ def _impl(inputs, input_types):
185186
def _get_value(val, dtype):
186187
# dtype is a tvm dtype
187188
if isinstance(val, _expr.Expr):
188-
try:
189-
ret = _infer_value(_op.cast(val, dtype), {}).asnumpy()
190-
ret = _expr.const(ret, dtype)
191-
except Exception:
192-
ret = _op.cast(val, dtype)
189+
inp = _op.cast(val, dtype)
190+
ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
193191
else:
194192
ret = _create_typed_const(val, dtype)
195193
return ret
@@ -305,10 +303,7 @@ def _impl(inputs, input_types):
305303
dim = int(inputs[1])
306304
stride = int(inputs[4])
307305
if isinstance(inputs[2], _expr.Call):
308-
try:
309-
begin[dim] = np.asscalar(_infer_value(inputs[2], {}).asnumpy().astype(np.int))
310-
except Exception:
311-
begin[dim] = inputs[2]
306+
begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
312307
else:
313308
begin[dim] = int(inputs[2])
314309

@@ -329,10 +324,9 @@ def _impl(inputs, input_types):
329324
target_end = int(inputs[3])
330325
else:
331326
if isinstance(inputs[3], _expr.Expr):
332-
try:
333-
target_end = np.asscalar(_infer_value(inputs[3], {}).asnumpy().astype(np.int))
334-
except Exception:
335-
target_end = inputs[3]
327+
target_end, _ = try_infer_value(
328+
inputs[3], lambda ret: np.asscalar(ret.astype(np.int))
329+
)
336330
else:
337331
target_end = inputs[3]
338332

@@ -457,10 +451,7 @@ def _impl(inputs, input_types):
457451
sort = bool(inputs[4])
458452

459453
if isinstance(inputs[1], _expr.Expr):
460-
try:
461-
k = _infer_value(inputs[1], {}).asnumpy().tolist()
462-
except Exception:
463-
k = inputs[1]
454+
k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
464455
else:
465456
k = inputs[1]
466457

@@ -546,15 +537,15 @@ def _full_impl(data, fill_value, dtype):
546537
size.append(dim)
547538
new_shape.append(dim)
548539
else:
549-
try:
550-
dim = int(_infer_value(dim, {}).asnumpy())
540+
dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0)
541+
new_shape.append(dim)
542+
543+
if success:
551544
if isinstance(size, list):
552545
size.append(dim)
553-
new_shape.append(dim)
554-
except Exception:
546+
else:
555547
size = None
556548
need_reshape = True
557-
new_shape.append(0)
558549
else:
559550
if isinstance(size, list):
560551
size.append(dim)
@@ -1346,12 +1337,11 @@ def _impl(inputs, input_types):
13461337
if isinstance(s, _expr.Constant):
13471338
tmp_shape.append(int(s.data.asnumpy()))
13481339
elif isinstance(s, _expr.Expr):
1349-
try:
1350-
dim = int(_infer_value(s, {}).asnumpy())
1351-
tmp_shape.append(dim)
1352-
except Exception:
1340+
dim, success = try_infer_value(s, lambda ret: int(ret))
1341+
tmp_shape.append(dim)
1342+
1343+
if not success:
13531344
is_dyn = True
1354-
tmp_shape.append(s)
13551345
else:
13561346
tmp_shape.append(s)
13571347

@@ -2312,13 +2302,15 @@ def _impl(inputs, input_types):
23122302
if isinstance(inputs[1], _expr.Expr):
23132303
out_size = inputs[1]
23142304
elif isinstance(inputs[1], list):
2315-
try:
2316-
infer_res = [_infer_value(size, {}) for size in inputs[1]]
2317-
out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res in infer_res]
2318-
except Exception:
2319-
h = _op.expand_dims(inputs[1][0], axis=0)
2320-
w = _op.expand_dims(inputs[1][1], axis=0)
2321-
out_size = _op.concatenate([h, w], axis=0)
2305+
out_size = []
2306+
for i in [0, 1]:
2307+
size, _ = try_infer_value(
2308+
inputs[1][i],
2309+
lambda ret: ret.astype(np.int),
2310+
lambda: _op.expand_dims(inputs[1][i], axis=0),
2311+
)
2312+
out_size.append(size)
2313+
out_size = _op.concatenate(out_size, axis=0)
23222314

23232315
data = inputs[0]
23242316
align_corners = inputs[4]

0 commit comments

Comments
 (0)