Skip to content

Commit 65121c8

Browse files
authored
[Relay][Frontend] Add support for aten::concat (#16199)
* Update pytorch.py * Add concat test * rm whitespace * Add diable docstring * update comment
1 parent bf071de commit 65121c8

2 files changed

Lines changed: 23 additions & 0 deletions

File tree

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4051,6 +4051,7 @@ def create_convert_map(self):
40514051
"aten::squeeze": self.squeeze,
40524052
"aten::unsqueeze": self.unsqueeze,
40534053
"aten::cat": self.concatenate,
4054+
"aten::concat": self.concatenate,
40544055
"aten::slice": self.slice,
40554056
"aten::narrow": self.narrow,
40564057
"aten::split": self.split,

tests/python/frontend/pytorch/test_forward.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,9 +720,31 @@ def forward(self, *args):
720720
c = (args[0][:, :, 2] + 5) * 13
721721
return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2)
722722

723+
class Concatenate3(Module):
724+
"""
725+
torch.concat is preserved as aten::concat only when in a nested module.
726+
(In the most cases, It is converted to aten::cat instead of aten::concat.)
727+
"""
728+
729+
def __init__(self):
730+
super().__init__()
731+
732+
class _Concatenate(Module):
733+
def forward(self, *args):
734+
a = (args[0][:, :, 0] + 2) * 7
735+
b = (args[0][:, :, 1] + 3) * 11
736+
c = (args[0][:, :, 2] + 5) * 13
737+
return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2)
738+
739+
self.mod = _Concatenate()
740+
741+
def forward(self, *args):
742+
return self.mod(*args)
743+
723744
input_data = torch.rand(input_shape).float()
724745
verify_model(Concatenate1().float().eval(), input_data=input_data)
725746
verify_model(Concatenate2().float().eval(), input_data=input_data)
747+
verify_model(Concatenate3().float().eval(), input_data=input_data)
726748

727749

728750
@tvm.testing.uses_gpu

0 commit comments

Comments
 (0)