@@ -720,9 +720,31 @@ def forward(self, *args):
720720 c = (args [0 ][:, :, 2 ] + 5 ) * 13
721721 return torch .cat ([t .unsqueeze (2 ) for t in [a , b , c ]], 2 )
722722
class Concatenate3(Module):
    """``torch.concat`` is preserved as ``aten::concat`` only when it is
    called from inside a nested module.

    (In most cases it is lowered to ``aten::cat`` rather than
    ``aten::concat``.)
    """

    def __init__(self):
        super().__init__()

        class _Inner(Module):
            def forward(self, *args):
                x = args[0]
                # Three derived channels, each from a slice of the last axis.
                parts = [
                    (x[:, :, 0] + 2) * 7,
                    (x[:, :, 1] + 3) * 11,
                    (x[:, :, 2] + 5) * 13,
                ]
                # torch.concat (not torch.cat) on purpose — see class docstring.
                return torch.concat([p.unsqueeze(2) for p in parts], 2)

        self.mod = _Inner()

    def forward(self, *args):
        # Delegate so that torch.concat runs inside the nested module.
        return self.mod(*args)
723744 input_data = torch .rand (input_shape ).float ()
724745 verify_model (Concatenate1 ().float ().eval (), input_data = input_data )
725746 verify_model (Concatenate2 ().float ().eval (), input_data = input_data )
747+ verify_model (Concatenate3 ().float ().eval (), input_data = input_data )
726748
727749
728750@tvm .testing .uses_gpu
0 commit comments