
Fix FA3 ring attention signature compatibility#2399

Draft
rasdani wants to merge 1 commit into main from fix/fa3-signature-compat

Conversation

@rasdani (Contributor)

@rasdani rasdani commented May 3, 2026

Summary

Fix FlashAttention 3 ring-attention calls so they match the installed FA3 low-level function signature.

The pinned flash-attn-3 package exposes _flash_attn_forward / _flash_attn_backward with causal and tuple window_size parameters, while the PrimeRL ring-attention wrapper was always passing is_causal, window_size_left, and window_size_right. Context-parallel (CP) + FA3 model paths therefore failed before reaching the kernel.

This change keeps compatibility with both observed signature shapes:

  • causal + window_size
  • is_causal + window_size_left / window_size_right

Repro

On current main, this hits the bug through the PrimeRL FA3 ring-attention wrapper:

UV_CACHE_DIR=/tmp/prime-repro-uv-cache uv run --no-sync python - <<'PY'
import torch
import prime_rl._compat  # noqa: F401
from prime_rl.trainer.models.layers import ring_attn

q = torch.empty((1, 1, 64), dtype=torch.bfloat16)
k = torch.empty_like(q)
v = torch.empty_like(q)
cu = torch.tensor([0, 1], dtype=torch.int32)

ring_attn._fa3_varlen_forward(
    q,
    k,
    v,
    cu,
    cu,
    1,
    1,
    1.0,
    True,
    window_size=(3, 0),
)
PY

Before this patch, the installed pinned FA3 wrapper rejects the unexpected is_causal keyword.

After this patch, the same path passes causal=True and window_size=(3, 0) to the installed FA3 wrapper.

Tests

UV_CACHE_DIR=/tmp/prime-repro-uv-cache uv run --no-sync pytest tests/unit/train/models/test_ring_attn.py
UV_CACHE_DIR=/tmp/prime-repro-uv-cache uv run --no-sync ruff check src/prime_rl/trainer/models/layers/ring_attn.py tests/unit/train/models/test_ring_attn.py
git diff --check

Note

Low risk: small, localized keyword-arg plumbing change for the FlashAttention-3 wrappers plus targeted unit tests; behavior only differs in how the causal/window_size kwargs are passed into the FA3 kernel entrypoints.

Overview
Fixes FlashAttention-3 ring-attention varlen wrapper calls to handle both observed low-level FA3 signatures by setting either causal + window_size or is_causal + window_size_left/right based on the target function’s default-arg keys.

Adds unit tests that monkeypatch flash_attn_interface to validate the forward/backward wrappers pass the correct causal/window parameters for each signature variant.
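The monkeypatch-style check can be illustrated with a self-contained sketch. The fake entrypoint below stands in for a patched flash_attn_interface function (the real tests patch the module that prime_rl imports); the capture-dict pattern and call_fa3 helper are assumptions for illustration only.

```python
import inspect

# Records the kwargs the wrapper actually forwarded.
captured = {}


def fake_fa3_forward(q, k, v, causal=False, window_size=(-1, -1)):
    """Stand-in for a monkeypatched FA3 low-level forward (shape 1)."""
    captured["causal"] = causal
    captured["window_size"] = window_size


def call_fa3(fn, causal, window_size):
    """Hypothetical mirror of the wrapper's kwarg dispatch."""
    if "causal" in inspect.signature(fn).parameters:
        fn(None, None, None, causal=causal, window_size=tuple(window_size))
    else:
        fn(
            None,
            None,
            None,
            is_causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
        )


call_fa3(fake_fa3_forward, True, (3, 0))
```

After the call, captured holds causal=True and window_size=(3, 0), confirming the correct parameter names reached the (fake) kernel entrypoint; the backward path is validated the same way.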

Reviewed by Cursor Bugbot for commit 573efcf.

@rasdani rasdani requested review from S1ro1 and samsja May 3, 2026 00:15
@@ -0,0 +1,196 @@
import sys
Collaborator

Delete this pls, seems useless

@rasdani rasdani marked this pull request as draft May 3, 2026 00:23