Commit 9465231 ("up")

1 parent e80f6c9
1 file changed: +6, −0 lines
src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 0 deletions
@@ -1284,6 +1284,12 @@ def _flash_attention_3_hub_backward_op(
 ):
     query, key, value = ctx.saved_tensors
     kernel_fn = ctx._hub_kernel
+    # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
+    # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
+    # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
+    # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
+    # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
+    # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
     with torch.enable_grad():
         query_r = query.detach().requires_grad_(True)
         key_r = key.detach().requires_grad_(True)
