Skip to content

Commit c147a31

Browse files
brounezhiics
authored and committed
Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend. (#4172)
1 parent 5408d3a commit c147a31

3 files changed

Lines changed: 161 additions & 2 deletions

File tree

python/tvm/relay/frontend/tensorflow.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,24 @@ def _impl(inputs, attr, params):
436436
return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
437437
return _impl
438438

439+
def _assert():
    """Convert tf.Assert into a Relay no-op.

    NOTE: asserts are normally stripped when optimizing TensorFlow graphs,
    so lowering them to a no-op is reasonable for now. Revisit once Relay
    gains a Halt/Assert op so they can optionally be preserved.
    """
    return _no_op()
445+
446+
def _no_op():
    """Convert a TensorFlow NoOp to a Relay constant 0.

    ToDo: ideally this would be an op that returns nothing, representable
    as an empty tuple. TVM's infrastructure currently cannot run functions
    returning None or an empty tuple, so that does not work yet; once it
    does, this should be improved. In the meantime it is hard to imagine a
    real-world case where substituting a constant 0 for a no-op matters.
    """
    def _impl(inputs, attr, params):
        return tvm.relay.const(0)
    return _impl
439457

440458
def _matmul():
441459
def _impl(inputs, attr, params):
@@ -1326,6 +1344,7 @@ def _impl(inputs, attr, params):
13261344
'All' : _reduce('all'),
13271345
'ArgMax' : _argx(_op.argmax, 'argmax'),
13281346
'ArgMin' : _argx(_op.argmin, 'argmin'),
1347+
'Assert' : _assert(),
13291348
'AvgPool' : _pooling('avg_pool'),
13301349
'BatchMatMul' : _batch_matmul(),
13311350
'BatchMatMulV2' : _batch_matmul(),
@@ -1384,6 +1403,7 @@ def _impl(inputs, attr, params):
13841403
'Mod' : _elemwise('mod'),
13851404
'Mul' : _elemwise('multiply'),
13861405
'Neg' : AttrCvt('negative'),
1406+
'NoOp' : _no_op(),
13871407
'NotEqual' : _broadcast('not_equal'),
13881408
'OneHot' : _one_hot(),
13891409
'Pack' : _pack(),
@@ -2196,8 +2216,11 @@ def _parse_param(self, key, value, name, shape):
21962216
if np_array.dtype == np.dtype(object):
21972217
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
21982218
# Just leave it as placeholder.
2199-
self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')]
2200-
2219+
if shape:
2220+
var_shape = shape[name]
2221+
else:
2222+
var_shape = tensor_util.TensorShapeProtoToList(value.tensor.tensor_shape)
2223+
self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
22012224
return
22022225

22032226
array_ndim = len(np_array.shape)
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Unit tests for converting TensorFlow debugging ops to Relay."""
18+
import tensorflow as tf
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.relay.frontend.tensorflow import from_tensorflow
22+
23+
def run_relay(graph, *args):
    """Convert *graph* to Relay and evaluate it on the debug executor.

    Parameters
    ----------
    graph : tf.Graph
        The TensorFlow graph to convert.
    *args
        Positional inputs fed to the compiled Relay function.

    Returns
    -------
    The evaluation result produced by the debug executor.
    """
    # Renamed the varargs from *vars to *args: `vars` shadows the builtin
    # vars(). Callers are unaffected since varargs are positional-only.
    # add_shapes=True attaches output-shape info the Relay frontend uses.
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    ex = relay.create_executor('debug', mod=mod)
    return ex.evaluate()(*args)
27+
28+
def test_assert_true():
    """tf.Assert with an always-true condition converts to a Relay 0."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=())
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])

    with tf.Session() as sess:
        value = np.random.rand()
        # TF evaluates a passing assert to None.
        assert sess.run(assert_op, feed_dict={x: value}) is None

    # In TVM, tf.assert is converted to a no-op which currently evaluates
    # to 0, though it should probably be None or an empty tuple.
    #
    # ToDo: The frontend converter appears to drop every operand from
    # main() here — likely because x <= x is constant-true, so the
    # placeholder gets eliminated. TF does not do that; it happens in
    # Relay, and such an optimization should not change the arity of the
    # main function. We should have to pass value in here.
    np.testing.assert_allclose(0, run_relay(graph).asnumpy())
48+
49+
def test_assert_true_var_capture():
    """tf.Assert whose message captures a tensor still converts cleanly."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=())

        # Capturing a variable as part of the error message makes
        # tf.Assert emit a large and complex subgraph, so that path needs
        # coverage too.
        assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x])

    with tf.Session() as sess:
        value = np.random.rand()
        assert sess.run(assert_op, feed_dict={x: value}) is None

    # ToDo: The frontend converter is confused here as well: it wants to
    # be told what x is twice, and it types the graph's output as a
    # boolean, which is not correct — as seen above, TF says the value of
    # this graph is None. The translated function's arity should also be
    # 1, not 2.
    np.testing.assert_allclose(True, run_relay(graph, value, value).asnumpy())
69+
70+
def test_assert_false():
    """tf.Assert with a false condition raises in TF but not in TVM."""
    graph = tf.Graph()
    with graph.as_default():
        assert_op = tf.Assert(tf.constant(False), ["it failed"])

    with tf.Session() as sess:
        try:
            result = sess.run(assert_op)
            print(result)
            assert False  # TF should have thrown an exception
        except tf.errors.InvalidArgumentError as e:
            assert "it failed" in e.message

    # In TVM, tf.assert is converted to a no-op that currently evaluates
    # to 0, though it should probably be None or an empty tuple. For the
    # same reason, no error is raised here even though the asserted
    # condition is false.
    np.testing.assert_allclose(0, run_relay(graph).asnumpy())
87+
88+
89+
if __name__ == "__main__":
    # Run every test in this file when invoked directly.
    for test in (test_assert_true,
                 test_assert_true_var_capture,
                 test_assert_false):
        test()
93+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Unit tests for converting TensorFlow debugging ops to Relay."""
18+
import tensorflow as tf
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.relay.frontend.tensorflow import from_tensorflow
22+
23+
def run_relay(graph):
    """Translate *graph* to Relay and run it with the debug executor.

    The graph's extracted parameters are fed back in as keyword inputs.
    """
    module, parameters = from_tensorflow(graph.as_graph_def(add_shapes=True))
    executor = relay.create_executor('debug', mod=module)
    return executor.evaluate()(**parameters)
27+
28+
def test_no_op():
    """tf.no_op evaluates to None in TF and to a constant 0 in TVM."""
    graph = tf.Graph()
    with graph.as_default():
        no_op = tf.no_op()
    with tf.Session() as sess:
        # In TF, the value of a no-op is None.
        assert sess.run(no_op) is None

    # TVM currently translates a no-op to 0, though it should probably be
    # None or an empty tuple.
    np.testing.assert_allclose(0, run_relay(graph).asnumpy())
39+
40+
41+
if __name__ == "__main__":
    # Allow running this test file directly.
    test_no_op()
43+

0 commit comments

Comments
 (0)