Skip to content

Commit bfbef63

Browse files
committed
add support for aten::unflatten
1 parent 707492a commit bfbef63

2 files changed

Lines changed: 39 additions & 0 deletions

File tree

python/tvm/relay/frontend/pytorch.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1546,6 +1546,16 @@ def flatten(self, inputs, input_types):
15461546
out = _op.squeeze(out, axis=squeeze_axes)
15471547
return out
15481548

1549+
def unflatten(self, inputs, input_types):
1550+
data = inputs[0]
1551+
dim = int(inputs[1])
1552+
unflattened_size = tuple(inputs[2])
1553+
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
1554+
assert len(dshape) > dim
1555+
new_shape = dshape[:dim] + unflattened_size + dshape[dim+1:]
1556+
out = _op.reshape(data, new_shape)
1557+
return out
1558+
15491559
def addmm(self, inputs, input_types):
15501560
input_mat = inputs[0]
15511561
mat1 = inputs[1]
@@ -3945,6 +3955,7 @@ def create_convert_map(self):
39453955
"aten::t": self.transpose,
39463956
"aten::numpy_T": self.numpy_T,
39473957
"aten::flatten": self.flatten,
3958+
"aten::unflatten": self.unflatten,
39483959
"aten::addmm": self.addmm,
39493960
"aten::size": self.size,
39503961
"aten::view": self.view,

tests/python/frontend/pytorch/test_forward.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,6 +1544,34 @@ def _test_flatten(start_dim, end_dim):
15441544
verify_model(_test_flatten(-3, -2), inp)
15451545

15461546

1547+
#@tvm.testing.uses_gpu
1548+
def test_unflatten():
1549+
"""test_unflatten"""
1550+
1551+
def _test_unflatten(dim, unflattened_size):
1552+
return lambda inp: torch.unflatten(inp, dim, unflattened_size)
1553+
1554+
inp = torch.rand(60)
1555+
1556+
# [60] -> [3, 5, 2, 2]
1557+
verify_model(_test_unflatten(0, (3, 5, 2, 2)), inp)
1558+
verify_model(_test_unflatten(0, (-1, 5, 2, 2)), inp)
1559+
verify_model(_test_unflatten(0, (3, -1, 2, 2)), inp)
1560+
verify_model(_test_unflatten(0, (3, 5, -1, 2)), inp)
1561+
verify_model(_test_unflatten(0, (3, 5, 2, -1)), inp)
1562+
1563+
inp = torch.rand(3, 4, 1)
1564+
1565+
# [3, 4, 1] -> [3, 2, 2, 1]
1566+
verify_model(_test_unflatten(1, (2, 2)), inp)
1567+
verify_model(_test_unflatten(1, (-1, 2)), inp)
1568+
1569+
inp = torch.rand(5, 12, 3)
1570+
1571+
# [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
1572+
verify_model(_test_unflatten(-2, (2, 2, 3, 1, 1)), inp)
1573+
1574+
15471575
@tvm.testing.uses_gpu
15481576
def test_forward_transpose():
15491577
"""test_forward_transpose"""

0 commit comments

Comments
 (0)