[Relay][Pytorch] Add support for aten::unflatten #16131
```python
unflattened_size = tuple(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
assert len(dshape) > dim
new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
```
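For illustration, the shape arithmetic in the snippet above can be sketched with plain tuples (a minimal sketch: `get_const_tuple` and `infer_shape_with_prelude` are TVM frontend helpers, replaced here by a literal shape):

```python
# Standalone sketch of the new_shape computation above.
dim = 0
dshape = (60, 4)              # stands in for the inferred input shape
unflattened_size = (3, 5, 4)  # sizes to unflatten dim 0 into

# Splice the unflattened sizes in place of the original dimension.
new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
print(new_shape)  # -> (3, 5, 4, 4)
```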
Add a check that `dshape[dim]` equals the product of the dimensions in `unflattened_size`.
@vvchernov Thanks!
I don't think we have to check it because torch.jit.trace does it.
They raise a RuntimeError like the one below when the shape is wrong:
```
RuntimeError: unflatten: Provided sizes [3, 5, 3, -1] don't multiply up to the size of dim 0 (60) in the input tensor
```
Should we add the check in TVM's PyTorch frontend?
Hello @mshr-h! OK, torch.jit.trace does it, but by the same reasoning we would not need `assert len(dshape) > dim` either. TVM usually rechecks all corner cases.
About -1 in `unflattened_size`: in that case we can multiply the other dimensions together and check that `dshape[dim] % mult == 0`.
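The divisibility check proposed here could be sketched as follows (hypothetical code, not from this PR):

```python
import math

# When unflattened_size contains -1, multiply the remaining dimensions
# and require that dshape[dim] is divisible by their product.
dshape_dim = 60
unflattened_size = [3, 5, -1]

mult = math.prod(s for s in unflattened_size if s != -1)
assert dshape_dim % mult == 0, "unflatten sizes don't divide dim size"
```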
Hello @mshr-h! There are two cases: `unflattened_size` either contains -1 or it does not, and you check only the first one. Example: `dshape[dim] = 8` with `unflattened_size = [2, 2, 1, 1]` passes your assert, but it is a failure case.
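A check covering both cases might look like this (a hedged sketch, not the code from this PR; `check_unflatten_size` is a hypothetical helper):

```python
import math

def check_unflatten_size(dim_size, unflattened_size):
    """Validate that unflattened_size is compatible with dim_size.

    Hypothetical helper: when a -1 is present, the product of the other
    sizes must divide dim_size; otherwise the product must equal it.
    """
    known = [s for s in unflattened_size if s != -1]
    mult = math.prod(known)
    if len(known) < len(unflattened_size):  # contains -1
        return dim_size % mult == 0
    return mult == dim_size

# The counterexample above: no -1, so divisibility alone is not enough.
assert not check_unflatten_size(8, [2, 2, 1, 1])  # 2*2*1*1 = 4 != 8
assert check_unflatten_size(60, [3, 5, -1])       # 60 % (3*5) == 0
```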
I've checked that `_op.reshape` does not check size correctness by itself. Nevertheless, I found that a dim can be not only -1 but any value from {0, -1, -2, -3, -4}; see
tvm/python/tvm/relay/op/transform.py
Line 243 in 748882a
I suggest checking only the -1 case; it looks like torch does not support the other options.
@vvchernov
Thanks! Updated the shape check.
Ah, dim can be negative but I didn't check that. I'll add the check.
Done.
Fixes #15663
Support `torch.unflatten`.
cc @jikechao @vvchernov @Hzfengsy @junrushao