[WIP,Cute,Flex,Sm100] vectorized mask mod application #2261

Draft

reubenconducts wants to merge 1 commit into Dao-AILab:main from reubenconducts:rstern/mask-vec

Conversation

Contributor

reubenconducts commented Feb 17, 2026

Follow-up to #2236. The vectorization approach has two parts:

  • Vectorize mask application, to compile down to r2p
  • Vectorize mask evaluation

The latter is important, for example, when mask_mod depends on aux_tensors that are contiguous in the kv index, or when the aux_tensors don't depend on the kv index at all.

mask_mods still emit TensorSSAs, but they need not be single values. These are treated as bit-packed masks.
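As a rough illustration of the bit-packed representation (a hypothetical Python sketch, not the PR's actual CuTe code), each column's boolean mask value becomes one bit of a Uint32 word:

```python
def pack_mask_bits(mask_bools):
    """Pack per-column booleans into 32-bit words, LSB-first.

    Hypothetical sketch: names and layout are illustrative only.
    """
    words = [0] * ((len(mask_bools) + 31) // 32)
    for i, keep in enumerate(mask_bools):
        if keep:
            words[i // 32] |= 1 << (i % 32)
    return words


def mask_bit(words, i):
    """Read back the mask bit for column i."""
    return (words[i // 32] >> (i % 32)) & 1
```

The packed words play the role of the multi-valued TensorSSA described above: one Uint32 covers 32 columns instead of one predicate per column.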

cc @drisspg

Comment thread on flash_attn/cute/mask.py:

    # 2: application, where it is applied to compile down to r2p
    #
    # evaluation
    num_mask_vals = (ncol + 32 - 1) // 32
Collaborator


I think we need an R2P-width global constant; this is where the 32 comes from, right?

Collaborator


Or I guess it's vecsize?

Contributor Author


The 32 does come from the R2P width, yes.

Contributor Author


Actually, sorry, that's not the case: it was just my choice to keep the bitmask in Uint32s.

Collaborator

@drisspg left a comment


Couple of clarifying questions, but this looks good. I just put up an autotuning PR: pytorch/pytorch#176055

It helps a lot in some cases.
