
add async activation offloading to CPU with pinned memory pool#2581

Open
dean-mccoppin wants to merge 1 commit into pytorch:main from dean-mccoppin:feat/activation-offloading

Conversation

@dean-mccoppin
Contributor

Implements async activation offloading to CPU during the forward pass, with prefetch back before the backward pass (issue #2379). Activations are copied D2H on a dedicated CUDA stream, then prefetched H2D on a second stream, so transfers overlap with compute at near-zero overhead.

Uses a pre-allocated pinned CPU slab to avoid per-tensor `cudaMallocHost` calls, which trigger an implicit CUDA sync and kill the overlap benefit. Adds `enable_activation_offload: bool` to `TrainingConfig`. Not supported with pipeline parallelism.
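To illustrate the slab design, here is a minimal pure-Python sketch of the bookkeeping such a pool needs (aligned bump allocation, exhaustion fallback, per-step reset). The names `PinnedPool`, `alloc`, and `reset` are illustrative, not this PR's actual API; the real pool would wrap a single pinned `torch` buffer, while only the offset arithmetic is shown here:

```python
# Hypothetical sketch of a pinned-slab bump allocator (illustrative names).
# The real pool would wrap one torch.empty(..., pin_memory=True) buffer;
# only the offset/alignment bookkeeping is modeled here.

class PinnedPool:
    def __init__(self, capacity_bytes, alignment=256):
        self.capacity = capacity_bytes
        self.alignment = alignment
        self.offset = 0  # bump pointer, reset once per training step

    def alloc(self, nbytes):
        """Return a slab offset, or None to signal pool exhaustion."""
        # Round the current offset up to the alignment boundary so mixed
        # dtypes (bf16 + f32) always land at aligned addresses.
        start = -(-self.offset // self.alignment) * self.alignment
        if start + nbytes > self.capacity:
            return None  # caller falls back to a regular (unpinned) copy
        self.offset = start + nbytes
        return start

    def reset(self):
        # Safe only after the H2D prefetch stream has finished reading
        # the previous step's activations.
        self.offset = 0
```

Because allocation is a single bump, there is no per-tensor host allocation on the hot path, which is what avoids the implicit sync mentioned above.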

Found a small bug during testing: inside `with torch.cuda.stream(s)`, `torch.cuda.current_stream()` returns `s`, not the compute stream. The `wait_stream` call was therefore a no-op, so `d2h_stream` could overwrite pool memory that `h2d_stream` was still reading from the previous step. This shows up as silent gradient corruption on pool reuse across steps, but only with async streams enabled. The fix is to capture the compute stream before entering the stream context.
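The pitfall is easy to reproduce without a GPU. The sketch below models the `torch.cuda.stream` context semantics with a tiny pure-Python stream stack (`stream`, `current_stream` are mocks, not the real API) to show why capturing the compute stream inside the context waits on the wrong stream:

```python
# Pure-Python model of the stream-context pitfall (no GPU needed).
# stream()/current_stream() mimic torch.cuda semantics: inside
# `with stream(s):`, current_stream() returns s, not the compute stream.

import contextlib

_active = ["compute"]  # stream stack; "compute" is the default stream

def current_stream():
    return _active[-1]

@contextlib.contextmanager
def stream(s):
    _active.append(s)
    try:
        yield
    finally:
        _active.pop()

# Buggy pattern: capture inside the context -> the d2h stream would
# wait on itself, making wait_stream a no-op.
with stream("d2h"):
    buggy_wait_target = current_stream()   # "d2h", not "compute"!

# Fixed pattern: capture the compute stream BEFORE entering the context.
compute = current_stream()                 # "compute"
with stream("d2h"):
    fixed_wait_target = compute            # i.e. d2h_stream.wait_stream(compute)
```

In real code the fix is one line: call `torch.cuda.current_stream()` before the `with torch.cuda.stream(d2h_stream):` block and pass that handle to `wait_stream`.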

Here's what's tested:
15 unit tests covering offload eligibility, gradient correctness across multiple steps, pool alignment for mixed dtypes (bf16 + f32), pool exhaustion fallback, and composability with checkpoint_wrapper and FSDP2. Numerical smoke tests confirm float32 and bfloat16 gradients match baseline within tolerance across 3 steps with pool reuse.
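For concreteness, the gradient-match smoke check might look like the following sketch. The helper name `grads_match` and the tolerance values are illustrative assumptions, not the PR's actual thresholds:

```python
# Hypothetical shape of the numerical smoke check: compare offload-enabled
# gradients against baseline within dtype-aware tolerances. The tolerance
# values below are illustrative, not the PR's exact thresholds.

def grads_match(baseline, offloaded, rtol, atol):
    """Elementwise |b - o| <= atol + rtol * |b|, as in torch.allclose."""
    return all(
        abs(b - o) <= atol + rtol * abs(b)
        for b, o in zip(baseline, offloaded)
    )

# Looser tolerance for bfloat16, which has far fewer mantissa bits.
TOL = {"float32": (1e-5, 1e-8), "bfloat16": (1e-2, 1e-5)}
```

Running this across several steps with pool reuse is what distinguishes a correct implementation from one that silently corrupts gradients on slab recycling.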

Limitations:
The multi-GPU FSDP integration test was not run, since only one GPU is available to me. The stream ordering is correct in theory but unverified empirically.

Offloads saved activations to CPU during forward via a pre-allocated
pinned slab, then prefetches back on a separate H2D stream before
backward. Near-zero overhead vs selective AC on bandwidth-bound ops.

Key fix: capture compute_stream before entering the d2h stream context.
Inside with torch.cuda.stream(s), current_stream() returns s, so the
wait_stream was a no-op causing pool corruption on reuse across steps.

Not supported with pipeline parallel.
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Mar 15, 2026
@tianyu-l tianyu-l requested a review from soulitzer March 15, 2026 22:07