Skip to content

Commit 924245a

Browse files
committed
[Relay] add ClipByValue and Neg in tf frontend
1 parent 78a0f47 commit 924245a

2 files changed

Lines changed: 18 additions & 0 deletions

File tree

python/tvm/relay/frontend/tensorflow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,13 @@ def _impl(inputs, attr, params):
934934
return AttrCvt(op_name="where")(inputs, attr)
935935
return _impl
936936

937+
def _clip_by_value():
938+
def _impl(inputs, attr, params):
939+
a_min = params.pop(inputs[1].name_hint).asnumpy()[0]
940+
a_max = params.pop(inputs[2].name_hint).asnumpy()[0]
941+
return _op.clip(inputs[0], a_min=a_min, a_max=a_max)
942+
return _impl
943+
937944
def _reverse_v2():
938945
def _impl(inputs, attr, params):
939946
axis = params.pop(inputs[1].name_hint).asnumpy()[0]
@@ -1190,6 +1197,7 @@ def _impl(inputs, attr, params):
11901197
'Cast' : _cast(),
11911198
'Ceil' : AttrCvt('ceil'),
11921199
'CheckNumerics' : _check_numerics(),
1200+
'ClipByValue' : _clip_by_value(),
11931201
'Concat' : _concat(),
11941202
'ConcatV2' : _concatV2(),
11951203
'Conv2D' : _conv('conv'),
@@ -1223,6 +1231,7 @@ def _impl(inputs, attr, params):
12231231
'Mean' : _mean(),
12241232
'Minimum' : _elemwise('minimum'),
12251233
'Mul' : _elemwise('multiply'),
1234+
'Neg' : AttrCvt('negative'),
12261235
'NotEqual' : _broadcast('not_equal'),
12271236
'Pack' : _pack(),
12281237
'Pad' : _pad('Pad'),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,14 @@ def test_forward_log():
15581558
tf.log(in_data, name="log")
15591559
compare_tf_with_tvm([np_data], ['in_data:0'], 'log:0')
15601560

1561+
def test_forward_negative():
1562+
"""test operator Neg """
1563+
np_data = np.random.uniform(-100, 100, size=(224, 224, 3)).astype(np.float32)
1564+
tf.reset_default_graph()
1565+
in_data = tf.placeholder(tf.float32, (224, 224, 3), name="in_data")
1566+
tf.negative(in_data, name="negative")
1567+
compare_tf_with_tvm([np_data], ['in_data:0'], 'negative:0')
1568+
15611569
def test_forward_softplus():
15621570
"""test operator Softplus"""
15631571
np_data = np.random.uniform(1, 10, size=(2, 3, 5)).astype(np.float32)
@@ -1708,6 +1716,7 @@ def test_placeholder():
17081716
test_forward_pow_exp()
17091717
test_forward_sign()
17101718
test_forward_log()
1719+
test_forward_negative()
17111720
test_forward_softplus()
17121721
test_forward_sqrt()
17131722
test_forward_rsqrt()

0 commit comments

Comments
 (0)