Skip to content

Commit 25db045

Browse files
committed
[Relay][Frontend] Support TF Gather
1 parent 4ac64fc commit 25db045

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,19 @@ def _impl(inputs, attr, params):
676676
'Taxis', '_class'])(new_input, attr)
677677
return _impl
678678

679+
def _gather_v1():
680+
def _impl(inputs, attr, params):
681+
#axis = params[inputs.pop(2).name_hint].asnumpy()[0]
682+
axis = 0
683+
new_input = []
684+
new_input.append(inputs.pop(0))
685+
new_input.append(inputs.pop(0))
686+
return AttrCvt(op_name="take",
687+
extras={'axis': tvm.const(axis, 'int32')},
688+
ignores=['Tindices', 'Tparams', 'validate_indices', \
689+
'Taxis', '_class'])(new_input, attr)
690+
return _impl
691+
679692
def _infer_out_shapes(inputs, params):
680693
"""A method to get the output shape of an intermediate node in the relay graph."""
681694
out_type = ir_pass.infer_type(inputs)
@@ -1003,6 +1016,7 @@ def _impl(inputs, attr, params):
10031016
'Sigmoid' : AttrCvt('sigmoid'),
10041017
'Fill' : _fill(),
10051018
'GatherV2' : _gather_v2(),
1019+
'Gather' : _gather_v1(),
10061020
'StridedSlice' : _stridedSlice(),
10071021
'LRN' : _lrn(),
10081022
'Pad' : _pad('Pad'),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -473,11 +473,11 @@ def test_forward_stridedslice():
473473

474474

475475
#######################################################################
476-
# Gather
476+
# Gather, GatherV2
477477
# ------
478478

479479
def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
480-
""" One iteration of a Gather """
480+
""" One iteration of a GatherV2 """
481481

482482
tf.reset_default_graph()
483483
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
@@ -497,7 +497,7 @@ def _fill_indices(indice_value):
497497
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'GatherV2:0')
498498

499499
def test_forward_gather():
500-
'''test gather layer'''
500+
'''test GatherV2 layer'''
501501
_test_gather((4,), (1,), 1, 0, 'int32')
502502
_test_gather((4,), (1,), 1, 0, 'float32')
503503
_test_gather((1,4), (1,), [0], 0, 'int32')
@@ -509,6 +509,42 @@ def test_forward_gather():
509509
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
510510
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
511511

512+
513+
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
514+
""" One iteration of a Gather"""
515+
tf.reset_default_graph()
516+
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
517+
indices = tf.placeholder("int32", indice_shape, name="indices")
518+
tf.gather(in_data, indices)
519+
np_data = np.random.uniform(size=ip_shape).astype(dtype)
520+
521+
def _fill_indices(indice_value):
522+
indices = np.array(ip_shape, dtype=dtype)
523+
if isinstance(indice_value, int):
524+
indices = np.array([indice_value], dtype='int32')
525+
else:
526+
indices = np.asarray(indice_value, dtype='int32')
527+
return indices
528+
np_indices = _fill_indices(indice_value)
529+
530+
compare_tf_with_tvm([np_data, np_indices], ['in_data:0', 'indices:0'], 'Gather:0')
531+
532+
def test_forward_gather_v1():
533+
'''test gather layer'''
534+
#_test_gather((2,3), (1,), 1, 0, 'int32')
535+
_test_gather_v1((4,), (1, 2, 2), [[[1, 0], [0, 1]]], 'float32')
536+
_test_gather_v1((4,), (1,), 1, 'int32')
537+
_test_gather_v1((4,), (1,), 1, 'float32')
538+
_test_gather_v1((1,4), (1,), [0], 'int32')
539+
_test_gather_v1((4,), (1,2,2), [[[1,0],[0,1]]], 'float32')
540+
_test_gather_v1((2,2), (1,2,2), [[[1,0],[0,1]]], 'int32')
541+
_test_gather_v1((2,2), (1,2,2), [[[1,0],[0,1]]], 'int32')
542+
_test_gather_v1((2,2), (1,2,2), [[[1,0],[0,1]]], 'float32')
543+
_test_gather_v1((3,3,3), (1,1,2), [[[1,0]]], 'int32')
544+
_test_gather_v1((3,3,3), (1,1,2), [[[1,0]]], 'int32')
545+
_test_gather_v1((4,3,5,6), (1,4), [[2,1,0,0]], 'float32')
546+
547+
512548
#######################################################################
513549
# Split
514550
# -----
@@ -1213,6 +1249,8 @@ def test_forward_rel_ops():
12131249
test_forward_crop()
12141250
test_forward_pad()
12151251
test_forward_gather()
1252+
# Gather was used in tf 1.6 and before
1253+
# test_forward_gather_v1()
12161254
test_forward_stridedslice()
12171255
test_forward_split()
12181256
test_forward_unstack()

0 commit comments

Comments
 (0)