[Frontend][PyTorch] Add: Relay stft operator #11190
Conversation
@jsheng-jian Thanks! I'm very curious how you implemented stft (without fft in TVM)!
Please address the CI issue; there is a warning error from the doc build.
python/tvm/relay/op/transform.py (Outdated)

    win_length : int
        The size of window frame and STFT filter
    window : relay.Expr
        A 1-D tensor window frame
In PyTorch, the window argument is optional. So shouldn't we support that too?
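For illustration, a minimal sketch of how the converter could fall back to PyTorch's default when window is None (torch.stft treats a missing window as all ones of length win_length; the helper name, the relay.stft entry point, and the float32 dtype are assumptions, not this PR's actual code):

```python
from tvm import relay

def convert_stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
    # Hypothetical frontend helper: torch.stft with window=None behaves
    # like a rectangular (all-ones) window of length win_length.
    if window is None:
        window = relay.ones((win_length,), dtype="float32")  # dtype is an assumption
    return relay.stft(data, n_fft, hop_length, win_length, window, normalized, onesided)
```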
python/tvm/relay/op/transform.py (Outdated)

    Returns
    -------
    output : relay.Expr
        Tensor containing the STFT result
Document the output shape. I had to read the type rel to see what the output shape looks like.
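For reference, the shape being asked about can be sketched as follows, assuming the op mirrors torch.stft with return_complex=False (this is an inference from the discussion, not a quote from the type rel):

```python
def stft_output_shape(batch, signal_length, n_fft, hop_length, onesided):
    # onesided keeps only the non-negative frequency bins.
    n_freqs = n_fft // 2 + 1 if onesided else n_fft
    # Frame count for a non-centered STFT.
    n_frames = (signal_length - n_fft) // hop_length + 1
    # The trailing dimension of 2 holds the real and imaginary parts.
    return (batch, n_freqs, n_frames, 2)
```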
python/tvm/topi/cuda/stft.py (Outdated)

    with ib.for_range(0, output_ptr.shape[0]) as batch:
        with ib.for_range(0, output_ptr.shape[1]) as row:
            with ib.for_range(0, output_ptr.shape[2]) as col:
This looks weird. You try to parallelize over the batch dim but nothing is parallelized.
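For context, parallelizing an ir_builder CUDA kernel typically means binding extents to thread axes instead of plain for_range loops; a rough sketch under assumed names (emit_stft_body, and that shape[1] fits within the thread limit):

```python
from tvm import te

def emit_stft_body(ib, output_ptr):
    # Sketch: map batches to CUDA blocks and rows to threads, so only
    # the innermost column loop remains serial.
    bx = te.thread_axis("blockIdx.x")
    tx = te.thread_axis("threadIdx.x")
    ib.scope_attr(bx, "thread_extent", output_ptr.shape[0])
    ib.scope_attr(tx, "thread_extent", output_ptr.shape[1])  # assumes <= max threads per block
    with ib.for_range(0, output_ptr.shape[2]) as col:
        pass  # ... compute the (bx, tx, col) output element here ...
```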
python/tvm/topi/stft.py (Outdated)

    with ib.for_range(0, output_ptr.shape[0]) as batch:
        # https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft
        with ib.for_range(0, output_ptr.shape[1], kind="parallel") as row:
fuse the outer loop to have one big parallel loop.
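I.e., continuing the quoted snippet, something along these lines (the index arithmetic is a sketch, not the merged code):

```python
# Fuse batch and row into one big parallel loop, then recover the two
# indices with floor division and modulo.
with ib.for_range(0, output_ptr.shape[0] * output_ptr.shape[1], kind="parallel") as fused:
    batch = fused // output_ptr.shape[1]
    row = fused % output_ptr.shape[1]
    # ... original per-(batch, row) body goes here ...
```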
python/tvm/topi/stft.py (Outdated)

    output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf")
    loop_kind = "vectorize"
    if hasattr(output_shape[2], "name") and output_shape[2].name == "any_dim":
if isinstance(output_shape[2], tir.expr.Any)
The type is tir.expr.SizeVar, updated.
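So the guard presumably ended up roughly like this (a sketch covering both types mentioned in the thread, not a quote from the final diff):

```python
from tvm import tir

loop_kind = "vectorize"
# A dynamic innermost extent (tir.Any or tir.SizeVar) has no constant
# trip count, so the loop cannot be vectorized.
if isinstance(output_shape[2], (tir.Any, tir.SizeVar)):
    loop_kind = "serial"
```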
    verify_trace_model(test_fn(3, 3, 3, False, "reflect", False, True), [input, window], targets)
    window = torch.tensor([1, 3], dtype=torch.int32)
    verify_trace_model(test_fn(2, 1, 2, False, "reflect", False, True), [input, window], targets)
Please add a test for the window=None case.
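A hedged sketch of such a test, tracing torch.stft directly with window=None (verify_trace_model and targets come from the surrounding test file; the wrapper and shapes are assumptions):

```python
import torch

def stft_no_window(x):
    # window=None makes torch.stft use a rectangular window of ones.
    return torch.stft(x, n_fft=8, hop_length=4, win_length=8, window=None,
                      center=False, normalized=False, onesided=True,
                      return_complex=False)

verify_trace_model(stft_no_window, [torch.rand(2, 32)], targets)
```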
    tvm.testing.assert_allclose(op_res.numpy(), ref_res, rtol=1e-5)
    relay.backend.te_compiler.get().clear()
python/tvm/relay/op/transform.py (Outdated)

    def stft(data, n_fft, hop_length, win_length, window, normalized, onesided):
        """
        The STFT computes the Fourier transform of short overlapping windows of the input.
        This giving frequency components of the signal as they change over time.
giving -> gives
Fix the same typo in other files too.
tests/python/relay/test_op_level3.py (Outdated)

    )

    def verify_func2(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=["vm"]):
Why do you need this? It looks identical to verify_func.
I want to expose rtol, atol, and kinds. Should I update the original function instead of creating a new one?
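Updating the original seems cleaner; roughly (the body is elided, and replacing the mutable list default is an extra suggestion, not part of the PR):

```python
def verify_func(target, dev, func, data, ref_res, rtol=1e-5, atol=1e-7, kinds=None):
    if kinds is None:  # avoid a mutable default argument
        kinds = ["vm"]
    # ... existing body, passing rtol/atol through to assert_allclose and
    # iterating over the requested executor kinds ...
```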
python/tvm/topi/cuda/stft.py
Outdated
| win_length, | ||
| window, | ||
| normalized, | ||
| onesided, # pylint: disable=unused-argument |
You can remove # pylint: disable=unused-argument and add unused-argument to L17.
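I.e., fold it into the file's module-level pylint directive, something like the following (the other checks listed are assumptions about what L17 already disables):

```python
# pylint: disable=invalid-name, too-many-arguments, unused-argument
```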
python/tvm/topi/cuda/stft.py (Outdated)

    win_length,
    window_ptr,
    normalized,
    onesided,  # pylint: disable=unused-argument
python/tvm/topi/stft.py (Outdated)

    win_length,
    window,
    normalized,
    onesided,  # pylint: disable=unused-argument
same as the comment in topi/cuda/stft.py
python/tvm/topi/stft.py (Outdated)

    win_length,
    window_ptr,
    normalized,
    onesided,  # pylint: disable=unused-argument
The CI failure seems to be unrelated.

It looks like the CI failure is resolved in the main branch; do I need to rebase my changes?

Yes, please send another job.
* Add: Relay stft operator
* fix doc
* address PR comments
* address additional comments
This PR adds the stft, amax, and amin operators.
torch.stft