File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -1284,6 +1284,12 @@ def _flash_attention_3_hub_backward_op(
12841284):
12851285 query , key , value = ctx .saved_tensors
12861286 kernel_fn = ctx ._hub_kernel
1287+ # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
1288+ # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
1289+ # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
1290+ # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
1291+ # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
1292+ # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
12871293 with torch .enable_grad ():
12881294 query_r = query .detach ().requires_grad_ (True )
12891295 key_r = key .detach ().requires_grad_ (True )
You can’t perform that action at this time.
0 commit comments