[core] Enable CP for kernels-based attention backends #12812
Conversation
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
Only FA2 provides these.
If you take a closer look, there is an equivalent for FA3. FA2 just renames its backward to the wrapped_xxx form:
- The original backward is noted in https://github.com/Dao-AILab/flash-attention/blob/a8780f2a17099fc1a3e7b00d7f5d9e08c5b71142/flash_attn/flash_attn_interface.py#L330-L333 (which is essentially just fancy ABI wrapping)
- On lower torch versions this leads to https://github.com/Dao-AILab/flash-attention/blob/a8780f2a17099fc1a3e7b00d7f5d9e08c5b71142/flash_attn/flash_attn_interface.py#L242

So I expect that once torch comes around to FA3 we will get the same standardization, but for now the FA3 equivalent is just the unwrapped backward.
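For reference, a minimal sketch (not the PR's implementation) of how such dotted attribute paths can be resolved lazily to the underlying flash-attention ops; the FA2 path comes from the diff above, and the same mechanism would simply point at FA3's unwrapped backward instead:

```python
# Illustrative sketch only. Resolves a string such as
# "flash_attn_interface._wrapped_flash_attn_backward" (FA2, from the diff above)
# to a callable; an FA3 config would carry the unwrapped backward's path instead.
import importlib


def resolve_attn_op(attr_path: str):
    """Resolve a dotted 'module.attr' path to the attention op it names."""
    module_path, op_name = attr_path.rsplit(".", 1)
    module = importlib.import_module(module_path)  # e.g. flash_attn_interface
    return getattr(module, op_name)                # e.g. _wrapped_flash_attn_backward
```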
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@DN6 a gentle ping.
DN6 left a comment
One comment about the FA3 backward. Not a merge blocker since it mostly affects CP-based training.
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)

out = kernel_fn(
This would result in a second forward pass during the backward op, right? Would it make sense to just raise an error here, similar to sage attention?
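For context, the pattern being questioned is a recompute-in-backward scheme: the kernel's forward is re-run with grad-enabled leaf tensors so that torch.autograd can derive the input gradients. A hedged sketch of that pattern follows; `kernel_fn` is a placeholder for the kernels-provided attention forward, not a real diffusers symbol.

```python
# Hedged sketch of the recompute-in-backward pattern referenced above;
# `kernel_fn` stands in for the kernels-provided attention forward.
import torch


class RecomputeAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value, kernel_fn):
        ctx.save_for_backward(query, key, value)
        ctx.kernel_fn = kernel_fn
        with torch.no_grad():
            return kernel_fn(query, key, value)

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value = ctx.saved_tensors
        # Re-run the forward with grad-enabled leaves: this is the "second
        # forward pass during the backward op" the comment points at.
        query_r = query.detach().requires_grad_(True)
        key_r = key.detach().requires_grad_(True)
        value_r = value.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.kernel_fn(query_r, key_r, value_r)
        grad_q, grad_k, grad_v = torch.autograd.grad(
            out, (query_r, key_r, value_r), grad_out
        )
        return grad_q, grad_k, grad_v, None
```

The alternative raised in the comment is to skip the extra compute entirely and raise an error when backward is attempted, as is done for sage attention.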
What does this PR do?
Adds CP support to the kernels-based attention backends.

Our CP support is quickly gaining traction. Currently, we have a few attention backends that are fully based on kernels. In order for their adoption to grow and make them a bit more complete in terms of feature parity, I think we should make them CP-compatible, too.

Code to test:
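The original test snippet is not reproduced in this conversation view. Purely as an illustration, a CP run exercising a kernels-based backend might look roughly like the sketch below; the checkpoint, the backend string, and the ContextParallelConfig / enable_parallelism / set_attention_backend names are assumptions based on diffusers' context-parallel API, not the author's original script.

```python
# Illustrative sketch only, not the author's test code. Launch with:
#   torchrun --nproc_per_node=2 test_cp_kernels.py
# API names and the backend string below are assumptions.
import torch
import torch.distributed as dist
from diffusers import AutoModel, ContextParallelConfig

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Any supported DiT checkpoint; FLUX is only used as an example here.
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Pick a kernels-based attention backend and shard the sequence across ranks.
transformer.set_attention_backend("_flash_3_hub")  # assumed backend name
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

# ... run the pipeline / a forward pass and compare against the single-GPU output ...

dist.destroy_process_group()
```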
Outputs: