Skip to content

perf: use 1x4 warp layout for small query length#322

Merged
yzh119 merged 7 commits into
mainfrom
another-prefill-1x4
Jun 21, 2024
Merged

perf: use 1x4 warp layout for small query length#322
yzh119 merged 7 commits into
mainfrom
another-prefill-1x4

Conversation

@yzh119
Copy link
Copy Markdown
Collaborator

@yzh119 yzh119 commented Jun 20, 2024

Duplicate of #304 and #185, just rebased on main.

This PR can accelerate GQA, we will release v0.0.6 after this PR gets merged (ETA: tonight).

@yzh119 yzh119 changed the title perm: use 1x4 warp layout for small query length perf: use 1x4 warp layout for small query length Jun 20, 2024
@yzh119 yzh119 merged commit 4e89b4d into main Jun 21, 2024
yzh119 added a commit that referenced this pull request Jun 21, 2024
🤖 I have created a release *beep* *boop*
---


##
[0.0.6](v0.0.5...v0.0.6)
(2024-06-21)


### Performance Improvements

* use 1x4 warp layout for small query length
([#322](#322))
([4e89b4d](4e89b4d))

---
This PR was generated with [Release
Please](https://github.com/googleapis/release-please). See
[documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <expye@outlook.com>
@yzh119 yzh119 mentioned this pull request Jun 21, 2024
yzh119 added a commit that referenced this pull request Jun 21, 2024
Some last commits for bugfix are missing for #322.
yzh119 added a commit that referenced this pull request Jun 21, 2024
Disable #322 for v0.0.6 release because binary size is too large.
v0.0.6 will only include bugfix at the moment.
@yzh119 yzh119 deleted the another-prefill-1x4 branch June 30, 2024 07:14
yzh119 added a commit that referenced this pull request Jul 4, 2024
Changes:
1. Prefetch page indices (we have already done such optimization on
decode kernels, but not on append/prefill kernels which was used in
GQA).
2. Unlock 1x4 warp layout in
#322, we didn't enable
this because the binary size is too large, we should further reduce some
unnecessary template arguments.
3. Optimize `threadblock_sync_mdo_states` for efficient merging
attention states of multiple warps in a threadblock. Our previous
implementation assumes small shared memory size and interleaves shared
memory reads/writes with computations, which is not as efficient as a
bulk shared memory access.

After this PR, the GQA kernel execution time (on H100) for setting
`batch_size=128, seq_len=1024, num_qo_heads=32, num_kv_heads=4,
head_dim=128` was improved from 133us to 103us.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant