Skip to content

Commit aa416cf

Browse files
lixiaoquanMarisaKirisame
authored andcommitted
[TENSORFLOW] Convert scalar Const into tvm.relay.const (apache#3885)
* [TENSORFLOW] Convert scalar Const into tvm.relay.const * use _get_num_param() and _get_list_param()
1 parent 78fdcf9 commit aa416cf

1 file changed

Lines changed: 21 additions & 19 deletions

File tree

python/tvm/relay/frontend/tensorflow.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,12 @@ def _dim_check(attrs):
8484
return _dim_check, "Only 2d kernel supported."
8585

8686
def _get_param(params, input_node):
87+
if isinstance(input_node, _expr.Constant):
88+
return np.atleast_1d(input_node.data.asnumpy())
8789
return params.pop(input_node.name_hint).asnumpy()
8890

8991
def _get_num_param(params, input_node):
90-
return _get_param(params, input_node)[0]
92+
return _get_param(params, input_node).item()
9193

9294
def _get_list_param(params, input_node):
9395
return _get_param(params, input_node).tolist()
@@ -335,9 +337,9 @@ def _impl(inputs, attr, params):
335337
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
336338
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
337339
try:
338-
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist()
339-
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
340-
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
340+
boxes = _get_list_param(params, inputs[1])
341+
box_ind = _get_list_param(params, inputs[2])
342+
crop_size = _get_list_param(params, inputs[3])
341343
except (IndexError, KeyError):
342344
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
343345
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
@@ -505,7 +507,7 @@ def _impl(inputs, attr, params):
505507

506508
def _tile():
507509
def _impl(inputs, attr, params):
508-
reps = params[inputs.pop().name_hint].asnumpy()
510+
reps = _get_list_param(params, inputs.pop())
509511
new_input = []
510512
new_input.append(inputs.pop(0))
511513

@@ -752,7 +754,7 @@ def _impl(inputs, attr, params):
752754

753755
def _reduce(op):
754756
def _impl(inputs, attr, params):
755-
axis = params.pop(inputs[1].name_hint).asnumpy()
757+
axis = _get_list_param(params, inputs[1])
756758
axis = tuple(axis)
757759
return AttrCvt(
758760
op_name=op,
@@ -937,8 +939,8 @@ def _impl(inputs, attr, params):
937939

938940
def _clip_by_value():
939941
def _impl(inputs, attr, params):
940-
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
941-
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
942+
a_min = _get_num_param(params, inputs[1])
943+
a_max = _get_num_param(params, inputs[2])
942944
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
943945
return _impl
944946

@@ -965,10 +967,11 @@ def _impl(inputs, attr, params):
965967

966968
def _range():
967969
def _impl(inputs, attr, params):
968-
start = params.pop(inputs[0].name_hint).asnumpy()[0]
969-
limit = params.pop(inputs[1].name_hint).asnumpy()[0] \
970-
if hasattr(inputs[1], "name_hint") else params.pop('Rank').asnumpy()[0]
971-
delta = params.pop(inputs[2].name_hint).asnumpy()[0]
970+
start = _get_param(params, inputs[0])[0]
971+
limit = _get_param(params, inputs[1])[0] \
972+
if hasattr(inputs[1], "name_hint") or isinstance(inputs[1], _expr.Constant) \
973+
else params.pop('Rank').asnumpy()[0]
974+
delta = _get_param(params, inputs[2])[0]
972975
dtype = attr['dtype'].name if 'dtype' in attr else "int32"
973976
return AttrCvt(
974977
op_name="arange",
@@ -1084,7 +1087,7 @@ def _impl(inputs, attr, params):
10841087

10851088
def _topk():
10861089
def _impl(inputs, attr, params):
1087-
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
1090+
k = int(_get_num_param(params, inputs.pop(1)))
10881091
if k < 1:
10891092
raise tvm.error.OpAttributeInvalid(
10901093
'Attribute k must be positive in operator TopKV2')
@@ -1196,7 +1199,7 @@ def _impl(inputs, attr, params):
11961199

11971200
def _prod():
11981201
def _impl(inputs, attr, params):
1199-
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
1202+
axis = _get_num_param(params, inputs[1])
12001203
keepdims = attr['keep_dims']
12011204
return _op.prod(inputs[0], int(axis), keepdims=keepdims)
12021205
return _impl
@@ -2104,13 +2107,12 @@ def _parse_param(self, key, value, name, shape):
21042107
if array_ndim == 0:
21052108
new_array = np.empty([1], dtype=np_array.dtype)
21062109
new_array[0] = np_array
2107-
self._params[name] = tvm.nd.array(new_array)
2110+
self._nodes[name] = [tvm.relay.const(new_array)]
21082111
else:
21092112
self._params[name] = tvm.nd.array(np_array)
2110-
2111-
self._nodes[name] = [_expr.var(name,
2112-
shape=self._params[name].shape,
2113-
dtype=self._params[name].dtype)]
2113+
self._nodes[name] = [_expr.var(name,
2114+
shape=self._params[name].shape,
2115+
dtype=self._params[name].dtype)]
21142116
else:
21152117
if key not in ('dtype', '_output_shapes', '_class'):
21162118
raise NotImplementedError \

0 commit comments

Comments
 (0)