Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import atexit
import functools
import gc
import inspect
import logging as _logging
import os
Expand Down Expand Up @@ -55,6 +56,7 @@
from megatron.core.dist_checkpointing.validation import StrictHandling
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks
from torch import nn
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
from torch.distributed.checkpoint.utils import CheckpointException
Expand Down Expand Up @@ -572,6 +574,10 @@
"""Setups dist env"""
setup_parallel_ranks(self)

# Capture Cudagraph on a side stream
if self.model.config.external_cuda_graph:
torch.cuda.set_stream(torch.cuda.Stream())

# Implementation from superclass copied below in order to pass the store to the process group init
reset_seed()
self.set_world_ranks()
Expand Down Expand Up @@ -720,6 +726,31 @@
assert self.lightning_module is not None
assert isinstance(self.model, MegatronParallel)

# (TODO:) Capture the cuda graph for the first step
if self.trainer.global_step == 0 and self.model.config.external_cuda_graph:
# disable prehook
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
self.model.disable_forward_pre_hook()
param_sync_func = self.model.config.param_sync_func
self.model.config.param_sync_func = None
import argparse

partial_cg_args = argparse.Namespace()
partial_cg_args.position_embedding_type = self.model.config.position_embedding_type
partial_cg_args.seq_length = self.trainer.datamodule.seq_length
partial_cg_args.micro_batch_size = self.trainer.datamodule.micro_batch_size
cuda_graph_capture(self.model, self.model.config, partial_cg_args)

# Set grad to zero.
for model_chunk in self.model:
model_chunk.zero_grad_buffer()
for opt in self.optimizers:
opt.zero_grad()

# Collect garbage and empty unused memory.
gc.collect()
torch.cuda.empty_cache()

with self.precision_plugin.train_step_context(): # TODO: Do we need this?
# Set grad to zero.
for model_chunk in self.model:
Expand All @@ -740,6 +771,13 @@

reduced_train_loss = out["loss"]

if self.trainer.global_step == 0 and self.model.config.external_cuda_graph:
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
self.model.enable_forward_pre_hook()
self.model.config.param_sync_func = param_sync_func
param_sync_func = None
Comment thread Dismissed
cuda_graph_set_manual_hooks(self.model)

self.lightning_module.log(
"global_step",
self.trainer.global_step,
Expand Down
Loading