-
Notifications
You must be signed in to change notification settings - Fork 258
[CK_Tile] Support for various group sizes Preshuffle quant for 2d block scale gemm #3445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
- Split cpp file to reduce building time - Support multiple GemmConfig
- Update Readme
- Add support for rowcol and tensor GEMM operations
- Update README
ThomasNing
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few comments. Please also add the tests and next step I believe to integrate that with the preshuffled.
| if constexpr(NPerQ <= WarpGemm::kN) | ||
| { | ||
| constexpr auto N1 = BlockGemmShape::kK / KPerQ; | ||
| constexpr auto N0 = WarpGemm::kN / NPerQ; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the variable naming here with N1 = Kvalues and k0 = nwarps will cause confusion. Could we restructure the variable names and add the comments?
| constexpr auto K0 = KPerTile / K1; | ||
| constexpr auto KR = 1; | ||
|
|
||
| return make_static_tile_distribution( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we also add the partition reasoning of the different condition of tile distribution?
| ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); | ||
| constexpr auto tile_window_height = block_n / warp_n; | ||
| auto block_n_idx = i_n / block_n; | ||
| constexpr auto tile_window_height = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could I understand the reason of this condition?
Proposed changes
This PR introduces Preshuffle Quant for group size N = (1, 8, 16, 32, 64) for both prefill and decode shapes for 2d block scale Gemm.
Changes include:
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered