@@ -473,11 +473,11 @@ def test_forward_stridedslice():
473473
474474
475475#######################################################################
476- # Gather
476+ # Gather, GatherV2
477477# ------
478478
479479def _test_gather (ip_shape , indice_shape , indice_value , axis , dtype ):
480- """ One iteration of a Gather """
480+ """ One iteration of a GatherV2 """
481481
482482 tf .reset_default_graph ()
483483 in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
@@ -497,7 +497,7 @@ def _fill_indices(indice_value):
497497 compare_tf_with_tvm ([np_data , np_indices ], ['in_data:0' , 'indices:0' ], 'GatherV2:0' )
498498
499499def test_forward_gather ():
500- '''test gather layer'''
500+ '''test GatherV2 layer'''
501501 _test_gather ((4 ,), (1 ,), 1 , 0 , 'int32' )
502502 _test_gather ((4 ,), (1 ,), 1 , 0 , 'float32' )
503503 _test_gather ((1 ,4 ), (1 ,), [0 ], 0 , 'int32' )
@@ -509,6 +509,42 @@ def test_forward_gather():
509509 _test_gather ((3 ,3 ,3 ), (1 ,1 ,2 ), [[[1 ,0 ]]], 2 , 'int32' )
510510 _test_gather ((4 ,3 ,5 ,6 ), (1 ,4 ), [[2 ,1 ,0 ,0 ]], 0 , 'float32' )
511511
512+
513+ def _test_gather_v1 (ip_shape , indice_shape , indice_value , dtype ):
514+ """ One iteration of a Gather"""
515+ tf .reset_default_graph ()
516+ in_data = tf .placeholder (dtype , ip_shape , name = "in_data" )
517+ indices = tf .placeholder ("int32" , indice_shape , name = "indices" )
518+ tf .gather (in_data , indices )
519+ np_data = np .random .uniform (size = ip_shape ).astype (dtype )
520+
521+ def _fill_indices (indice_value ):
522+ indices = np .array (ip_shape , dtype = dtype )
523+ if isinstance (indice_value , int ):
524+ indices = np .array ([indice_value ], dtype = 'int32' )
525+ else :
526+ indices = np .asarray (indice_value , dtype = 'int32' )
527+ return indices
528+ np_indices = _fill_indices (indice_value )
529+
530+ compare_tf_with_tvm ([np_data , np_indices ], ['in_data:0' , 'indices:0' ], 'Gather:0' )
531+
532+ def test_forward_gather_v1 ():
533+ '''test gather layer'''
534+ #_test_gather((2,3), (1,), 1, 0, 'int32')
535+ _test_gather_v1 ((4 ,), (1 , 2 , 2 ), [[[1 , 0 ], [0 , 1 ]]], 'float32' )
536+ _test_gather_v1 ((4 ,), (1 ,), 1 , 'int32' )
537+ _test_gather_v1 ((4 ,), (1 ,), 1 , 'float32' )
538+ _test_gather_v1 ((1 ,4 ), (1 ,), [0 ], 'int32' )
539+ _test_gather_v1 ((4 ,), (1 ,2 ,2 ), [[[1 ,0 ],[0 ,1 ]]], 'float32' )
540+ _test_gather_v1 ((2 ,2 ), (1 ,2 ,2 ), [[[1 ,0 ],[0 ,1 ]]], 'int32' )
541+ _test_gather_v1 ((2 ,2 ), (1 ,2 ,2 ), [[[1 ,0 ],[0 ,1 ]]], 'int32' )
542+ _test_gather_v1 ((2 ,2 ), (1 ,2 ,2 ), [[[1 ,0 ],[0 ,1 ]]], 'float32' )
543+ _test_gather_v1 ((3 ,3 ,3 ), (1 ,1 ,2 ), [[[1 ,0 ]]], 'int32' )
544+ _test_gather_v1 ((3 ,3 ,3 ), (1 ,1 ,2 ), [[[1 ,0 ]]], 'int32' )
545+ _test_gather_v1 ((4 ,3 ,5 ,6 ), (1 ,4 ), [[2 ,1 ,0 ,0 ]], 'float32' )
546+
547+
512548#######################################################################
513549# Split
514550# -----
@@ -1213,6 +1249,8 @@ def test_forward_rel_ops():
12131249 test_forward_crop ()
12141250 test_forward_pad ()
12151251 test_forward_gather ()
1252+ # Gather was used in tf 1.6 and before
1253+ # test_forward_gather_v1()
12161254 test_forward_stridedslice ()
12171255 test_forward_split ()
12181256 test_forward_unstack ()
0 commit comments