Conversation

@mshr-h
Contributor

@mshr-h mshr-h commented Nov 14, 2023

@github-actions github-actions bot requested a review from Hzfengsy November 14, 2023 02:43
@mshr-h mshr-h force-pushed the pytorch-linalg_vector_norm branch from b4e0e40 to f2e6107 Compare November 14, 2023 04:19
    return _op.scatter_nd(source, indices, values, mode)

def linalg_vector_norm(self, inputs, input_types):
    data = inputs[0]
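
For reference, the semantics the converter has to reproduce can be sketched in plain Python. This is a hedged stand-in for `torch.linalg.vector_norm` over a flattened input, not the TVM code under review; names are illustrative:

```python
import math

def reference_vector_norm(xs, ord=2.0):
    """Sketch of torch.linalg.vector_norm semantics on a flat list."""
    if ord == 0:
        # ord=0 counts the non-zero entries
        return float(sum(1 for x in xs if x != 0))
    if math.isinf(ord):
        mags = [abs(x) for x in xs]
        # +inf -> largest magnitude, -inf -> smallest magnitude
        return float(max(mags) if ord > 0 else min(mags))
    # general case: (sum |x|^ord) ** (1/ord)
    return sum(abs(x) ** ord for x in xs) ** (1.0 / ord)
```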
Contributor
It looks like this method is based on torch.linalg.vector_norm. The latter assumes that the input data type is float, double, or a complex one; dtype should also be real or complex. Could you check it?

Contributor Author

@mshr-h mshr-h Nov 14, 2023

Thanks for your review. That's true: it's based on torch.linalg.vector_norm, which supports float, double, and complex dtypes as input. However, it seems the PyTorch frontend doesn't support complex data types (see convert_pt_to_tvm_type).

Contributor

Hello @mshr-h! My idea was to add something like assert data.dtype == float or data.dtype == double, and maybe add a TODO for further support of complex values, but I do not think that is needed just now.
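
The suggested guard could look something like this (a minimal sketch; the names are assumptions, and in the TVM frontend the dtype is typically carried as a string such as "float32" rather than a Python type):

```python
def check_vector_norm_dtype(dtype):
    """Hypothetical dtype guard for the linalg_vector_norm converter.
    TODO: support complex inputs once the frontend can represent them."""
    supported = ("float32", "float64")  # float and double
    assert dtype in supported, (
        f"linalg.vector_norm expects float or double input, got {dtype}"
    )
```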

Contributor Author

Hi @vvchernov! I've added an assertion and test cases for double-precision input data.

Contributor

@vvchernov vvchernov left a comment

Thank you for your work! LGTM, see my small comment

@mshr-h mshr-h marked this pull request as ready for review November 15, 2023 12:50
@github-actions github-actions bot requested a review from junrushao November 20, 2023 01:40

class VectorNorm5(torch.nn.Module):
    def forward(self, x):
        return torch.linalg.vector_norm(x, ord=0)
Member

If you only need a different ord for each test case, you don't need to create a separate class for each one. Please clean them up.
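
The cleanup amounts to one module parameterized by ord instead of VectorNorm1..VectorNorm5. A plain-Python sketch of the pattern (the real test would subclass torch.nn.Module and call torch.linalg.vector_norm; everything here is illustrative):

```python
class VectorNorm:
    """One test module covering all `ord` values via a constructor
    argument, instead of a separate class per ord."""
    def __init__(self, order=2.0):
        self.order = order

    def forward(self, xs):
        # stand-in for torch.linalg.vector_norm on a flat list
        if self.order == 0:
            return float(sum(1 for x in xs if x != 0))
        return sum(abs(x) ** self.order for x in xs) ** (1.0 / self.order)

# each previous VectorNormN class becomes one instantiation
cases = [VectorNorm(order) for order in (0, 1.0, 2.0, 3.0)]
```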

Contributor Author

@masahi Revised as suggested. Please review. Thanks!

@masahi masahi merged commit 3fd3a63 into apache:main Nov 26, 2023
@mshr-h mshr-h deleted the pytorch-linalg_vector_norm branch November 27, 2023 00:46

Development

Successfully merging this pull request may close these issues.

[FRONTEND][PyTorch] NotImplementedError: The following operators are not implemented: ['aten::linalg_vector_norm']
