[CuTe,Flex] varlen blocksparsity#2224
Conversation
|
Hi there, @reubenconducts ! Thank you so much for your draft. Since I also need this feature eagerly, I tried to continue development based on your branch which fixes some grammar issue (SeanLi-OI@7ccfc5e). Though it can run, but returns wrong results when batch_size > 1. I completely understand you may be busy with other priorities. If you have a moment, I’d be truly grateful for any guidance: |
|
@SeanLi-OI Yes, I will be continuing this, but not until next week. |
a4f3021 to
bc15c46
Compare
bc15c46 to
03f2f92
Compare
03f2f92 to
04d3016
Compare
ab9bbeb to
9c4370b
Compare
|
Hi @drisspg, @reubenconducts, just checking in on this PR. Are there any remaining blockers or changes needed before it can be merged? |
6b85348 to
6f736e9
Compare
This PR extends blocksparsity to the variable sequence length case. Whereas in batched blocksparsity the metadata tensors take the shapes
in varlen blocksparsity, we pack our metadata tensors to take the shapes
where
total_m_blocksis the sum of allmblocks per head (equiv. number of work tiles per head) andtotal_n_blocksis the total of allnblocks potentially processed per head across all sequences in the batch. For example, consider a varlen batch with sequences contained inseqlens_qandseqlens_k. At batch indexb, we letand define
total_m_blocks = sum_{b \in B} num_m(b) total_n_blocks = sum_{b \in B} num_m(b) * num_n(b)To properly index into the blocksparsity tensors, we use auxiliary
mCuTotalMBlocksandmCuTotalNBlockstensors, which can be prepared on host.cc @drisspg @v0i0
NOT INTENDED FOR THIS PR: