[TIR] Add T.thread_return() for early thread exit in CUDA kernels#18134
Merged
tqchen merged 1 commit intoapache:mainfrom Jul 14, 2025
Merged
[TIR] Add T.thread_return() for early thread exit in CUDA kernels#18134tqchen merged 1 commit intoapache:mainfrom
T.thread_return() for early thread exit in CUDA kernels#18134tqchen merged 1 commit intoapache:mainfrom
Conversation
This commit implements T.thread_return() functionality that allows threads
to exit early from CUDA kernels. The feature is useful for cases where
threads need to conditionally return based on thread indices or other
conditions.
Key changes:
- Add thread_return builtin in TIR
- Implement CUDA codegen for thread_return
- Add Python bindings for T.thread_return()
- Update TIR IR builder to support thread_return
- Add tests demonstrating thread_return usage
Example usage:
```python
@T.prim_func
def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
for i in T.thread_binding(16, thread="blockIdx.x"):
for j in T.thread_binding(32, thread="threadIdx.x"):
if j >= 16:
T.thread_return() # Early exit for threads with j >= 16
B[i, j] = A[i, j]
```
and generate code is:
```cuda
extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) {
if (16 <= ((int)threadIdx.x)) {
return;
}
B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))];
}
```
Member
Author
|
cc @LeiWang1999 |
tqchen
approved these changes
Jul 14, 2025
ShiboXing
pushed a commit
to ShiboXing/tvm
that referenced
this pull request
Aug 10, 2025
…pache#18134) This commit implements T.thread_return() functionality that allows threads to exit early from CUDA kernels. The feature is useful for cases where threads need to conditionally return based on thread indices or other conditions. Key changes: - Add thread_return builtin in TIR - Implement CUDA codegen for thread_return - Add Python bindings for T.thread_return() - Update TIR IR builder to support thread_return - Add tests demonstrating thread_return usage Example usage: ```python @T.prim_func def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): for i in T.thread_binding(16, thread="blockIdx.x"): for j in T.thread_binding(32, thread="threadIdx.x"): if j >= 16: T.thread_return() # Early exit for threads with j >= 16 B[i, j] = A[i, j] ``` and generate code is: ```cuda extern "C" __global__ void __launch_bounds__(32) main_kernel(float* __restrict__ A, float* __restrict__ B) { if (16 <= ((int)threadIdx.x)) { return; } B[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))] = A[((((int)blockIdx.x) * 16) + ((int)threadIdx.x))]; } ```
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This commit implements T.thread_return() functionality that allows threads to exit early from CUDA kernels. The feature is useful for cases where threads need to conditionally return based on thread indices or other conditions.
Key changes:
Example usage:
and generate code is: