[TIR] Support cross-threaad reduction lowering with thread-broadcasting rewrite#15192
Merged
junrushao merged 1 commit intoapache:mainfrom Jul 3, 2023
Merged
Conversation
Collaborator
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
junrushao
approved these changes
Jul 2, 2023
Member
|
@tvm-bot rerun |
yzh119
approved these changes
Jul 2, 2023
Member
yzh119
left a comment
There was a problem hiding this comment.
Overall LGTM, some slight comments :)
tqchen
requested changes
Jul 2, 2023
…ng rewrite This PR enhances the LowerCrossThreadReduction pass with the thread-broadcasting block rewrite. Specifically, previously whenever a TIR block has thread-broadcast behavior (i.e., there exists some thread var which is free for the block), we never insert a predicate for the block and therefore the generated final code has race condition, which sometimes lead to wrong computation results. This PR enhances the pass by collecting thread var information along transformation, and rewrite the thread-broadcast TIR block with additional predicate clauses which bound the thread vars and effectively state that "only execute the block when `thread_var == 0`". Therefore, the race condition issue in such blocks is resolved.
93b7ea3 to
460736c
Compare
tqchen
approved these changes
Jul 3, 2023
junrushao
added a commit
to junrushao/tvm
that referenced
this pull request
Jul 6, 2023
This PR improves the Decode-GEMV scheduling by further analyzing its epilogue pattern. The existing behavior assumes that the outcome of cross-thread reduction stays in register files local to each thread, which is further used to calculate the epilogue in the same thread. This strategy means the cross-thread reduction outcome is stored only on thread 0, while the other threads cannot participate in subsequent computation (i.e. epilogue). Related: apache#15192. When the epilogue is relatively lightweight, i.e. elementwise add, casting on scalars, this strategy is optimal. However, once the outcome needs to be broadcasted to compute over a non-trivial region, for example, act as a normalizer of `np.mean`, it would become much slower because only one thread in a thread block is effectively used. In this case, we will need to broadcast the cross-thread reduction outcome in shared memory, making it visible to other threads, and then bind the compute region to all threads in the threadblock.
tqchen
pushed a commit
that referenced
this pull request
Jul 6, 2023
This PR improves the Decode-GEMV scheduling by further analyzing its epilogue pattern. The existing behavior assumes that the outcome of cross-thread reduction stays in register files local to each thread, which is further used to calculate the epilogue in the same thread. This strategy means the cross-thread reduction outcome is stored only on thread 0, while the other threads cannot participate in subsequent computation (i.e. epilogue). Related: #15192. When the epilogue is relatively lightweight, i.e. elementwise add, casting on scalars, this strategy is optimal. However, once the outcome needs to be broadcasted to compute over a non-trivial region, for example, act as a normalizer of `np.mean`, it would become much slower because only one thread in a thread block is effectively used. In this case, we will need to broadcast the cross-thread reduction outcome in shared memory, making it visible to other threads, and then bind the compute region to all threads in the threadblock.
MasterJH5574
added a commit
to MasterJH5574/tvm
that referenced
this pull request
Jul 21, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by apache#15192.
MasterJH5574
added a commit
to MasterJH5574/tvm
that referenced
this pull request
Jul 21, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by apache#15192.
MasterJH5574
added a commit
to MasterJH5574/tvm
that referenced
this pull request
Jul 21, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by apache#15192.
MasterJH5574
added a commit
to MasterJH5574/tvm
that referenced
this pull request
Jul 21, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by apache#15192.
MasterJH5574
added a commit
to MasterJH5574/tvm
that referenced
this pull request
Jul 22, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by apache#15192.
tqchen
pushed a commit
that referenced
this pull request
Jul 22, 2023
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass. For the cross-thread reduction write-back block, prior to this PR, its predicate is the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory, where each thread is supposed to have a copy of the final value, while the final value is only stored by the first thread. In this PR, the predicate is changed to be the conjunction of the clauses from the two parts: * the clause of the original reduction block's predicate which contains spatial loop var, * `t == 0` for each reduction thread dim **only when the write-back buffer is global or shared**. So the first part ensures that the write-back will not go out of bound, and the second part ensures that when the write-back buffer is local, every thread gets a value and when the write-back buffer is non-local, only one thread writes the value out. Meanwhile, this PR fixes the cross-thread broadcasting detection with the awareness of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., same set of `blockIdx`) of the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting, and will not add a predicate to it. Otherwise, we will add the predicate according to the broadcasting handling introduced by #15192.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR enhances the LowerCrossThreadReduction pass with the thread-broadcasting block rewrite.
Specifically, previously whenever a TIR block has thread-broadcast behavior (i.e., there exists some thread var which is free for the block), we never insert a predicate for the block and therefore the generated final code has race condition, which sometimes lead to wrong computation results.
This PR enhances the pass by collecting thread var information along transformation, and rewrite the thread-broadcast TIR block with additional predicate clauses which bound the thread vars and effectively state that "only execute the block when
thread_var == 0". Therefore, the race condition issue in such blocks is resolved.