[Relax][PyTorch] Fix output shape of torch.nn.functional.scaled_dot_product_attention#17379
Conversation
torch.nn.functional.scaled_dot_product_attention
|
we can transpose to get the expected result. Thanks for the effort! |
456c72e to
185d28c
Compare
|
MSC E2E test is failing. Seems like we also need to change something other than relax frontend. Link to the ci log: https://ci.tlcpack.ai/blue/organizations/jenkins/tvm-unity/detail/PR-17379/6/pipeline/ tests/python/contrib/test_msc/test_translate_torch.py::test_attention FAILED
|
185d28c to
43268e1
Compare
43268e1 to
a783823
Compare
a783823 to
a2b29c0
Compare
torch.nn.functional.scaled_dot_product_attention outputs in the shape of
(N, ..., L, E_v)butrelax.op.nn.attentiondoes(N, L, ..., E_v)so the output should also be transposed.Maybe we should add E2E tests in
tests/python/nightly/to check the relax torch frontend.cc: @yongwww