Skip to content

Commit 1c45389

Browse files
authored
[BugFix] Fixed Inappropriate Logical Expression (#16272)
[BugFix] Fixed a comparison for splitting tensor In the `tensor_split` method, there's a comparsion that checks if the input tensor is zero-dimensional or one-dimensional long tensor. In the comparsion, there's a typo that converts the shape of the tensor to a list and compares against integer. This commit fixes the bug by comapring the length of the tensor against the integer. Signed-off-by: fazledyn-or <ataf@openrefactory.com>
1 parent a050696 commit 1c45389

1 file changed

Lines changed: 1 addition & 3 deletions

File tree

python/tvm/relay/frontend/pytorch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -595,9 +595,7 @@ def tensor_split(self, inputs, input_types):
595595
)
596596
raise AssertionError(msg)
597597

598-
if isinstance(inputs[1], torch.Tensor) and not (
599-
list(inputs[1].shape) == [] or list(inputs[1].shape) == 1
600-
):
598+
if isinstance(inputs[1], torch.Tensor) and len(inputs[1].shape) not in [0, 1]:
601599
msg = "indices_or_sections must be a zero-dimensional or one-dimensional long tensor"
602600
raise AssertionError(msg)
603601

0 commit comments

Comments
 (0)