@@ -1544,6 +1544,34 @@ def _test_flatten(start_dim, end_dim):
15441544 verify_model (_test_flatten (- 3 , - 2 ), inp )
15451545
15461546
1547+ #@tvm.testing.uses_gpu
1548+ def test_unflatten ():
1549+ """test_unflatten"""
1550+
1551+ def _test_unflatten (dim , unflattened_size ):
1552+ return lambda inp : torch .unflatten (inp , dim , unflattened_size )
1553+
1554+ inp = torch .rand (60 )
1555+
1556+ # [60] -> [3, 5, 2, 2]
1557+ verify_model (_test_unflatten (0 , (3 , 5 , 2 , 2 )), inp )
1558+ verify_model (_test_unflatten (0 , (- 1 , 5 , 2 , 2 )), inp )
1559+ verify_model (_test_unflatten (0 , (3 , - 1 , 2 , 2 )), inp )
1560+ verify_model (_test_unflatten (0 , (3 , 5 , - 1 , 2 )), inp )
1561+ verify_model (_test_unflatten (0 , (3 , 5 , 2 , - 1 )), inp )
1562+
1563+ inp = torch .rand (3 , 4 , 1 )
1564+
1565+ # [3, 4, 1] -> [3, 2, 2, 1]
1566+ verify_model (_test_unflatten (1 , (2 , 2 )), inp )
1567+ verify_model (_test_unflatten (1 , (- 1 , 2 )), inp )
1568+
1569+ inp = torch .rand (5 , 12 , 3 )
1570+
1571+ # [5, 12, 3] -> [5, 2, 2, 3, 1, 1, 3]
1572+ verify_model (_test_unflatten (- 2 , (2 , 2 , 3 , 1 , 1 )), inp )
1573+
1574+
15471575@tvm .testing .uses_gpu
15481576def test_forward_transpose ():
15491577 """test_forward_transpose"""
0 commit comments