@@ -23,13 +23,15 @@ def before():
2323 x = relay .var ("x" , shape = (10 , 20 ))
2424 y = relay .add (x , relay .const (1 , "float32" ))
2525 z = relay .exp (y )
26- return relay .Function ([x ], z )
26+ w = relay .squeeze (z )
27+ return relay .Function ([x ], w )
2728
2829 def expected ():
2930 x = relay .var ("p" , shape = (10 , 20 ))
3031 y = relay .add (x , relay .const (1 , "float32" ))
3132 z = relay .exp (y )
32- f1 = relay .Function ([x ], z )
33+ w = relay .squeeze (z )
34+ f1 = relay .Function ([x ], w )
3335 x = relay .var ("x" , shape = (10 , 20 ))
3436 y = relay .Call (f1 , [x ])
3537 return relay .Function ([x ], y )
@@ -503,6 +505,38 @@ def expected(dshape):
503505 assert relay .ir_pass .alpha_equal (zz , after )
504506
505507
508+ def test_fuse_parallel_injective ():
509+ """Test fusing parallel injective ops to an elemwise op."""
510+ def before ():
511+ x = relay .var ("x" , shape = (10 , 20 ))
512+ y = relay .add (x , relay .const (1 , "float32" ))
513+ z = relay .squeeze (y )
514+ u = relay .transpose (y , axes = [0 , 1 ])
515+ w = relay .left_shift (z , u )
516+ return relay .Function ([x ], w )
517+
518+ def expected ():
519+ x = relay .var ("p" , shape = (10 , 20 ))
520+ y = relay .add (x , relay .const (1 , "float32" ))
521+ z = relay .squeeze (y )
522+ u = relay .transpose (y , axes = [0 , 1 ])
523+ w = relay .left_shift (z , u )
524+ f1 = relay .Function ([x ], w )
525+ x = relay .var ("x" , shape = (10 , 20 ))
526+ y = relay .Call (f1 , [x ])
527+ return relay .Function ([x ], y )
528+
529+ z = before ()
530+ z = relay .ir_pass .infer_type (z )
531+ zz = relay .ir_pass .fuse_ops (z , opt_level = 0 )
532+ assert not relay .ir_pass .free_vars (zz )
533+ zz = relay .ir_pass .fuse_ops (z , opt_level = 2 )
534+ zz = relay .ir_pass .infer_type (zz )
535+ assert not relay .ir_pass .free_vars (zz )
536+ after = relay .ir_pass .infer_type (expected ())
537+ assert relay .ir_pass .alpha_equal (zz , after )
538+
539+
506540if __name__ == "__main__" :
507541 test_fuse_simple ()
508542 test_conv2d_fuse ()
@@ -515,3 +549,4 @@ def expected(dshape):
515549 test_tuple_intermediate ()
516550 test_tuple_consecutive ()
517551 test_inception_like ()
552+ test_fuse_parallel_injective ()
0 commit comments