Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None, vp_sta
Returns:
MCoreGPTModel: Configured Megatron Core GPT model instance
"""
if self.enable_cuda_graph:
if self.enable_cuda_graph or self.external_cuda_graph:
assert HAVE_TE, "Transformer Engine is required for cudagraphs."
assert getattr(self, "use_te_rng_tracker", False), (
"Transformer engine's RNG tracker is required for cudagraphs, it can be "
Expand Down
47 changes: 47 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import atexit
import functools
import gc
import inspect
import logging as _logging
import os
Expand Down Expand Up @@ -55,6 +56,7 @@
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks
from torch import nn
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
from torch.distributed.checkpoint.utils import CheckpointException
Expand Down Expand Up @@ -572,6 +574,10 @@ def setup_distributed(self) -> None:
"""Set up the distributed environment."""
setup_parallel_ranks(self)

# Capture the external cudagraph on a side stream
if hasattr(self.model, 'config') and getattr(self.model.config, 'external_cuda_graph', False):
torch.cuda.set_stream(torch.cuda.Stream())

# Implementation from superclass copied below in order to pass the store to the process group init
reset_seed()
self.set_world_ranks()
Expand Down Expand Up @@ -720,6 +726,35 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
assert self.lightning_module is not None
assert isinstance(self.model, MegatronParallel)

# Capture the external CUDA graph on the first training step (global_step == 0).
if (
self.trainer.global_step == 0
and hasattr(self.model, 'config')
and getattr(self.model.config, 'external_cuda_graph', False)
):
# Disable the model's forward pre-hook before capturing the CUDA graph.
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
self.model.disable_forward_pre_hook()
param_sync_func = self.model.config.param_sync_func
self.model.config.param_sync_func = None
import argparse

partial_cg_args = argparse.Namespace()
partial_cg_args.position_embedding_type = self.model.config.position_embedding_type
partial_cg_args.seq_length = self.trainer.datamodule.seq_length
partial_cg_args.micro_batch_size = self.trainer.datamodule.micro_batch_size
cuda_graph_capture(self.model, self.model.config, partial_cg_args)

# Set grad to zero.
for model_chunk in self.model:
model_chunk.zero_grad_buffer()
for opt in self.optimizers:
opt.zero_grad()

# Collect garbage and empty unused memory.
gc.collect()
torch.cuda.empty_cache()

with self.precision_plugin.train_step_context(): # TODO: Do we need this?
# Set grad to zero.
for model_chunk in self.model:
Expand All @@ -740,6 +775,18 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP

reduced_train_loss = out["loss"]

if (
self.trainer.global_step == 0
and hasattr(self.model, 'config')
and getattr(self.model.config, 'external_cuda_graph', False)
):
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
# Re-enable the model's forward pre-hook now that the first step has run.
self.model.enable_forward_pre_hook()
self.model.config.param_sync_func = param_sync_func
param_sync_func = None
cuda_graph_set_manual_hooks(self.model)

self.lightning_module.log(
"global_step",
self.trainer.global_step,
Expand Down
Loading