Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
527b718
formatted
amd-khushbu Oct 31, 2025
c25420e
formatted
amd-khushbu Nov 5, 2025
8999dae
formatting
amd-khushbu Nov 5, 2025
00e61e0
formatting
amd-khushbu Nov 6, 2025
e5d6a80
formatting
amd-khushbu Nov 7, 2025
a967b49
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 7, 2025
bc26224
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 7, 2025
c553c87
enable prefill shapes
amd-khushbu Nov 7, 2025
8818018
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 8, 2025
9debcc1
[CK TILE GEMM] Refactor block_scale_gemm examples
CongMa13 Nov 10, 2025
51ec0c2
merge with Cong's Changes
amd-khushbu Nov 11, 2025
869bc5b
adding preshuffle quant as new parameter and its associated new files
amd-khushbu Nov 11, 2025
075c36b
remove debugging statements
amd-khushbu Nov 11, 2025
0f79fa5
adding test
amd-khushbu Nov 12, 2025
b8b5709
enable preshuffle quant with permuteN
amd-khushbu Nov 12, 2025
903800f
rebase with develop
amd-khushbu Nov 13, 2025
48e7559
updating readme and corresponding gemmconfigs
amd-khushbu Nov 13, 2025
36f2f87
updating cmake file
amd-khushbu Nov 13, 2025
07700cc
fixing CI failures for grouped quant gemm
amd-khushbu Nov 14, 2025
f5856af
Merge branch 'develop' into lwpck-3984
amd-khushbu Nov 14, 2025
2275548
debugging permuteN
amd-khushbu Nov 18, 2025
a974a08
debugging
amd-khushbu Nov 20, 2025
cf3f9b5
Merge branch 'develop' into lwpck-3985
amd-khushbu Nov 20, 2025
04aaf97
debugging PermuteN
amd-khushbu Nov 24, 2025
7788979
initial commit
amd-khushbu Nov 25, 2025
f290428
working code for preshuffleb
amd-khushbu Nov 26, 2025
3447196
resolving merge conflicts
amd-khushbu Nov 26, 2025
2441260
Merge branch 'develop' into 1dQuantPreshuffleWeight
amd-khushbu Dec 4, 2025
18fe146
Merge branch 'develop' into 2dQuantPreshuffleWeight
ThomasNing Dec 4, 2025
19b78e9
Merge remote-tracking branch 'origin/develop' into 2dQuantPreshuffleW…
amd-khushbu Dec 4, 2025
3021c7a
adding test cases
amd-khushbu Dec 4, 2025
48744f2
initial commit with prints
amd-khushbu Dec 5, 2025
3ea3ca7
debugging
amd-khushbu Dec 6, 2025
c28cf0e
fine-grained working
amd-khushbu Dec 9, 2025
ec044f5
rebase with develop
amd-khushbu Dec 10, 2025
341d0e3
debugging medium grained
Dec 11, 2025
92cbe3c
fixing the tile window
amd-khushbu Dec 11, 2025
995d1a5
resolving merge conflicts
amd-khushbu Dec 11, 2025
44aaaac
formatting
amd-khushbu Dec 12, 2025
9ad0687
enabling prefill shapes
amd-khushbu Dec 12, 2025
5a3e7de
Merge branch 'develop' into lwpck-4181
amd-khushbu Dec 15, 2025
373d89d
working prefill shapes
amd-khushbu Dec 16, 2025
05ff943
Merge branch 'develop' into lwpck-4181
amd-khushbu Dec 16, 2025
cc994a7
formatted
amd-khushbu Dec 16, 2025
26a1b52
clean up
amd-khushbu Dec 17, 2025
4c382e7
code cleanup
amd-khushbu Dec 17, 2025
7aeff21
resolving merge conflicts
amd-khushbu Dec 19, 2025
27d31ba
bug fix after merging with develop
amd-khushbu Dec 19, 2025
06c4866
Merge branch 'develop' into 2d_preshuffle_quant
ThomasNing Dec 23, 2025
46543df
Merge branch 'develop' into 2d_preshuffle_quant
amd-khushbu Jan 5, 2026
748add0
Merge branch '2d_preshuffle_quant' of https://github.com/ROCm/composa…
amd-khushbu Jan 5, 2026
4d8cba1
clean up after merging with develop
amd-khushbu Jan 5, 2026
6c52e9d
added comments for the tile window and tile distribution encoding
amd-khushbu Jan 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
QuantMode == ck_tile::QuantType::RowColQuant)
{
bq_tensor_ptr = std::make_unique<ck_tile::HostTensor<BQDataType>>(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, is_row_major(bq_layout)));
}
else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ struct BQuantBlockUniversalGemmAsBsCr
constexpr index_t reg_offset = nIter;
auto pull_from_lane =
(__lane_id() & (WarpGemm::kN - 1)) * Traits::KQPerBlock + kQScale;

auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
// cross lane ops
uint32_t scale_reg_dword;
Expand Down
645 changes: 44 additions & 601 deletions include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK;
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;

using WarpTile = typename Problem::BlockGemmShape::WarpTile;
Expand All @@ -68,7 +67,8 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
BlockSize,
NPerBlock / WarpGemm::kN,
ck_tile::integer_least_multiple(WarpGemm::kN * KPerBlockBQ, get_warp_size()),
VecLoadSize,
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout,
PreshuffleQuant>;
return TileEncodingPattern::make_2d_static_tile_distribution();
Expand All @@ -83,6 +83,7 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
KPerBlockBQ, // Logical K dimension
NPerBlockBQ, // Logical N dimension
Problem::BQuantGroupSize::kN,
Problem::BQuantGroupSize::kK,
BQLayout>;

return TileEncodingPattern::make_2d_static_tile_distribution();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;

static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);

static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
Expand Down Expand Up @@ -300,9 +302,12 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Prob
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
const BQDramTileWindowStep bq_dram_tile_window_step =
(PreshuffleQuant) ? make_array(ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
0)
(PreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: is_bq_row_major ? make_array(KPerBlockBQ, 0)
: make_array(0, KPerBlockBQ);

Expand Down
166 changes: 128 additions & 38 deletions include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ template <typename BlockGemmShape,
index_t KPerTile,
index_t NPerTile,
index_t NPerQ,
index_t KPerQ,
typename BQLayout = tensor_layout::gemm::ColumnMajor,
bool PreshuffleQuant = false>
struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern
Expand All @@ -208,31 +209,6 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding
static_assert(num_warps == MWarps * NWarps * KWarps);
static_assert(KWarps == 1);

/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and applying
/// quantization scales to the B matrix based on the quantization group size (NPerQ) relative
/// to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ / WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution()
{
// Preshuffle only supported for ColumnMajor currently
Expand All @@ -241,22 +217,136 @@ struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding

if constexpr(PreshuffleQuant)
{
// ColumnMajor only for preshuffle
constexpr index_t X1 = warp_size;
constexpr index_t X0 = NPerTile / warp_size;
constexpr index_t Y1 = NWarps;
constexpr index_t Y0 = KPerTile / Y1;

return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps>,
tuple<sequence<Y0, Y1>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2>>,
tuple<sequence<0, 1>, sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{});
// =============================================================================
// PRE-SHUFFLED BQ SCALE TILE DISTRIBUTION
// =============================================================================
// For pre-shuffled quantization, the BQ scale tensor has been reorganized
// (pre-shuffled) to optimize memory access patterns during dequantization.
//
// Tile Dimensions:
// - K-axis (Y in encoding): Corresponds to the K-dimension iteration
// - N-axis (X in encoding): Flattened scale index combining N and K groups
//
// The encoding distributes work across threads such that each thread loads
// the correct pre-shuffled scale for its corresponding B-matrix elements.
// =============================================================================
if constexpr(NPerQ <= WarpGemm::kN)
{
// =========================================================================
// CASE 1: Fine-grained Quantization (NPerQ <= WarpGemm::kN)
// =========================================================================
// Multiple quantization scales exist within a single warp's N-dimension.
// Each warp processes multiple scales: WarpGemm::kN / NPerQ scales per warp.
//
// Example: NPerQ=8, WarpGemm::kN=16, KPerQ=128, BlockGemmShape::kK=256
// → 2 scales per warp in N, 2 K-groups per block
constexpr auto N1 = BlockGemmShape::kK /
KPerQ; // Number of K-dimension quantization groups per block,
// Each K-group of KPerQ elements shares the same scale.
constexpr auto N0 =
WarpGemm::kN / NPerQ; // Number of scales per warp in N-dimension, Since NPerQ
// <= WarpGemm::kN, each warp handles multiple scales.
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Elements sharing the same scale in N-dimension
constexpr auto NR0 =
warp_size /
(N0 * N1 * N2 * NR1); // Interleave factor to ensure full warp utilization
constexpr auto K1 = NWarps; // Number of warps distributed along this dimension
constexpr auto K0 = KPerTile / K1; // Iterations per warp to cover the K-tile
constexpr auto KR = 1; // No replication in K-dimension

return make_static_tile_distribution(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also add the partition reasoning of the different condition of tile distribution?

tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1>, sequence<0, 2, 0, 2, 0>>,
tuple<sequence<0, 1>, sequence<1, 0, 2, 1, 3>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else if constexpr(NPerQ < WarpGemm::kN * NWarps)
{
// =========================================================================
// CASE 2: Medium-grained Quantization (WarpGemm::kN < NPerQ < WarpGemm::kN *
// NWarps)
// =========================================================================
// Each warp handles exactly one quantization scale in N-dimension.
// Some warps share the same scale (KR > 1 creates warp grouping).
//
// Example: NPerQ=32, WarpGemm::kN=16, NWarps=4
// → KR=2 (2 warps share same scale), K1=2 (2 unique scale groups)

constexpr auto KR = NPerQ / WarpGemm::kN; // Number of warps sharing the same scale
constexpr auto K1 = NWarps / KR; // Number of distinct warp groups (unique scales)
constexpr auto K0 = KPerTile / K1; // Iterations to cover K-tile per warp group
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Scales per warp in N-dim (1 since NPerQ >= WarpGemm::kN)
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = NPerQ; // Scale broadcast factor (full NPerQ)
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor

return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NR0, NR1, KR>,
tuple<sequence<K0, K1>, sequence<N0, N1, N2>>,
tuple<sequence<0, 1, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1, 3>, sequence<1, 0, 2, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
// =========================================================================
// CASE 3: Coarse-grained Quantization (NPerQ >= WarpGemm::kN * NWarps)
// =========================================================================
// The quantization group spans ALL warps in N-dimension.
// All warps share the same scale value for their N-tiles.
//
// Example: NPerQ=128, WarpGemm::kN=16, NWarps=4
// → 128 >= 16*4=64, so all 4 warps use the same scale
constexpr auto N1 = BlockGemmShape::kK / KPerQ; // K-dimension quantization groups
constexpr auto N0 = 1; // Minimal (1) since scale is shared across N
constexpr auto N2 = 1; // Elements per thread
constexpr auto NR1 = 32; // Fixed broadcast size
constexpr auto NR0 =
warp_size / (N0 * N1 * N2 * NR1); // Remaining interleave factor
return make_static_tile_distribution(
tile_distribution_encoding<sequence<MWarps, NWarps, NR0, NR1>,
tuple<sequence<KPerTile>, sequence<N0, N1, N2>>,
tuple<sequence<0, 0>, sequence<0, 2, 0, 2>>,
tuple<sequence<0, 1>, sequence<2, 0, 3, 1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
}
else
{
/// @brief Creates a 2D tile distribution for BQ (B-matrix quantization scales)
///
/// This function determines the optimal thread distribution pattern for loading and
/// applying quantization scales to the B matrix based on the quantization group size
/// (NPerQ) relative to warp dimensions.
///
/// Three distinct distribution patterns are handled:
///
/// 1. Fine-grained quantization (NPerQ < WarpGemm::kN):
/// - Multiple quantization groups exist within a single warp's N-dimension
/// - Each warp processes multiple scales (WarpGemm::kN / NPerQ scales per warp)
/// - Distribution includes explicit replication factor (XR = NPerQ) for scale
/// broadcast
/// - Example: NPerQ=8, WarpGemm::kN=16, NWarps=4 → 2 scales per warp
///
/// 2. Medium-grained quantization (WarpGemm::kN <= NPerQ <= WarpGemm::kN * NWarps):
/// - Each warp handles exactly one quantization scale
/// - Scales are distributed across warps with replication factor XR = NPerQ /
/// WarpGemm::kN
/// - Example: NPerQ=64, WarpGemm::kN=16, NWarps=4 → 1 scale per warp, XR=4
///
/// 3. Coarse-grained quantization (NPerQ > WarpGemm::kN * NWarps):
/// - Quantization group spans multiple warps
/// - All warps share the same scale value
/// - Example: NPerQ=128, WarpGemm::kN=16, NWarps=4 → all warps use same scale
///
/// @return A static tile distribution encoding for the BQ scale tensor
if constexpr(NPerQ < WarpGemm::kN)
{
// Case 1: Fine-grained - multiple quantization scales within a single warp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV

static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
static constexpr index_t NPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kN, QuantGroupSize::kN);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, QuantGroupSize::kK);
static constexpr index_t QScalesPerBlockRow =
Expand Down Expand Up @@ -351,8 +353,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
Expand Down Expand Up @@ -426,8 +430,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
Expand Down Expand Up @@ -461,8 +467,10 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
if constexpr(PreshuffleQuant)
{
move_tile_window(bq_copy_dram_window,
{ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{}),
{((NPerBlockBQ < BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, QuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, kNPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0});
}
else
Expand Down