
[core] Enable CP for kernels-based attention backends #12812

Merged
sayakpaul merged 13 commits into main from enable-cp-kernels
Feb 19, 2026

Conversation

@sayakpaul (Member) commented Dec 9, 2025

What does this PR do?

Adds context parallel (CP) support to the kernels-based attention backends.

Our CP support is quickly gaining traction. Currently, we have a few attention backends that are fully based on kernels. To help their adoption grow and bring them closer to feature parity with the other backends, I think we should make them CP-compatible, too.

Code to test:
import argparse
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig, AutoModel


CKPT_ID = "black-forest-labs/FLUX.1-dev"

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cp-backend",
        type=str,
        choices=["ring", "ulysses", "unified"],
        default="ulysses",
        help="Context parallel backend to use.",
    )
    parser.add_argument(
        "--attn-backend",
        type=str,
        choices=["flash_hub", "_flash_3_hub", "sage_hub"],
        default="flash_hub",
        help="Attention backend to use.",
    )
    return parser.parse_args()


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    args = parse_args()

    device = setup_distributed()
    world_size = dist.get_world_size()
    if args.cp_backend == "ring":
        cp_config = ContextParallelConfig(ring_degree=world_size)
    elif args.cp_backend == "unified":
        cp_config = ContextParallelConfig(ring_degree=world_size // 2, ulysses_degree=world_size // 2)
    else:
        cp_config = ContextParallelConfig(ulysses_degree=world_size)

    transformer = AutoModel.from_pretrained(
        CKPT_ID,
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        parallel_config=cp_config,
    )

    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
    ).to(device)
    pipeline.transformer.set_attention_backend(args.attn_backend)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output_{args.cp_backend}_{args.attn_backend}.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
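Note on launching the snippet: it needs a distributed launcher. Saving it as, say, test_cp_kernels.py (file name chosen here for illustration), it can be run with torchrun --nproc_per_node=2 test_cp_kernels.py --cp-backend ulysses --attn-backend flash_hub. The unified branch splits the world size evenly between ring and Ulysses degrees, so it is presumably meant for a world size of 4 (2 x 2).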

Outputs:

[Output images omitted from the page: FA2 + Ulysses, FA3 + Ulysses, SAGE + Ulysses; Ring, Ulysses, Unified.]

sayakpaul requested a review from DN6 on Dec 9, 2025, 09:47
sayakpaul added the performance label (Anything related to performance improvements, profiling and benchmarking) on Dec 9, 2025
Comment on lines +280 to +281
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
sayakpaul (Member Author):
Only FA2 provides these.

Reply:
If you take a closer look, there is an equivalent for FA3; FA2 just renames its backward as wrapped_xxx.

So I expect that when torch comes around to FA3 we will get the same standardization, but for now the equivalent is just

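(Aside: the wrapped_forward_attr / wrapped_backward_attr values above are dotted "module.attribute" strings. A minimal, purely illustrative sketch of how such a string could be resolved to a callable, shown as an assumption about the mechanism rather than the PR's actual code:)

import importlib

# Illustrative only: resolve a dotted path such as
# "flash_attn_interface._wrapped_flash_attn_forward" to the underlying callable.
# Per the exchange above, FA3 exposes an equivalent under a different name.
def resolve_wrapped_attr(path: str):
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)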
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions (Contributor):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot added the stale label (Issues that haven't received updates) on Jan 10, 2026
sayakpaul removed the stale label on Jan 11, 2026
sayakpaul (Member Author):

@DN6 a gentle ping.

sayakpaul added the roadmap label (Add to current release roadmap) on Feb 16, 2026
@DN6 (Collaborator) left a comment:

One comment about the FA3 backward. Not a merge blocker, since it mostly affects CP-based training.

key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)

out = kernel_fn(
DN6 (Collaborator):
This would result in a second forward pass during the backward op, right? Would it make sense to just raise an error here, similar to sage attention?

sayakpaul (Member Author):
Added a comment in 9465231
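For context on this exchange, here is a rough sketch of the recompute pattern being discussed, with assumed names rather than the PR's actual implementation: when a hub kernel only ships a forward, gradients can be obtained by re-running that forward under autograd inside backward, at the cost of the extra forward pass noted above.

import torch

# Illustrative sketch, not diffusers code: wrap a forward-only attention kernel so
# that backward recomputes the forward with grad-tracking inputs and differentiates
# through it. kernel_fn is assumed to take (query, key, value) and return a tensor.
class ForwardOnlyAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, query, key, value, kernel_fn):
        ctx.save_for_backward(query, key, value)
        ctx.kernel_fn = kernel_fn
        return kernel_fn(query, key, value)

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value = ctx.saved_tensors
        query_r = query.detach().requires_grad_(True)
        key_r = key.detach().requires_grad_(True)
        value_r = value.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.kernel_fn(query_r, key_r, value_r)  # second forward pass
        grads = torch.autograd.grad(out, (query_r, key_r, value_r), grad_out)
        return (*grads, None)  # no gradient for kernel_fn

The alternative DN6 mentions would be to raise an error in this path instead, as is done for sage attention, trading training support for avoiding the silent recompute.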

sayakpaul merged commit 99daaa8 into main on Feb 19, 2026
12 checks passed
github-project-automation bot moved this from In Progress to Done in Diffusers Roadmap 0.37 on Feb 19, 2026
sayakpaul deleted the enable-cp-kernels branch on February 19, 2026, 12:46