Skip to content

Commit 461a40d

Browse files
committed
Fix batched gemm
1 parent 6c9071a commit 461a40d

2 files changed

Lines changed: 18 additions & 6 deletions

File tree

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -60,8 +60,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
6060
const long_index_t c_batch_offset =
6161
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
6262

63-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
64-
typename GridwiseGemm::EpilogueCShuffle>();
63+
using EpilogueType =
64+
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
65+
GridwiseGemm::UseDirectStore,
66+
typename GridwiseGemm::EpilogueDirectStore,
67+
typename GridwiseGemm::EpilogueCShuffle>::type;
68+
69+
constexpr index_t LDS_size =
70+
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
6571
__shared__ char p_shared[LDS_size];
6672

6773
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -84,7 +90,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8490
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
8591
});
8692

87-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
93+
auto epilogue_args = EpilogueType{};
8894

8995
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9096
p_as_grid_shift,

include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -46,8 +46,14 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
4646
std::is_same_v<c_data_type, ck::bhalf_t>)))
4747
{
4848
#endif
49-
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
50-
typename GridwiseGemm::EpilogueCShuffle>();
49+
using EpilogueType =
50+
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
51+
GridwiseGemm::UseDirectStore,
52+
typename GridwiseGemm::EpilogueDirectStore,
53+
typename GridwiseGemm::EpilogueCShuffle>::type;
54+
55+
constexpr index_t LDS_size =
56+
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
5157
// The normal approach to batching would be to increase the grid size by just stretching out
5258
// the grid Z dimension (which is the outermost dimension), but this depends on lower level
5359
// functions not directly using the Z dimension for other calculations. As it turns out, k
@@ -86,7 +92,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
8692
splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
8793
});
8894

89-
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
95+
auto epilogue_args = EpilogueType{};
9096

9197
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
9298
p_as_grid_shift,

0 commit comments

Comments
 (0)