Skip to content

Commit 6db849f

Browse files
gdengk and guyueh1
authored and committed
[MoE] Partial Cudagraph support for MoE (NVIDIA-NeMo#14362)
* local experimental setup

  Signed-off-by: gdeng <gdeng@nvidia.com>

* local working version

  Signed-off-by: gdeng <gdeng@nvidia.com>

* cleanup

  Signed-off-by: gdeng <gdeng@nvidia.com>

* Apply isort and black reformatting

  Signed-off-by: gdengk <gdengk@users.noreply.github.com>

* fix the issue

  Signed-off-by: gdeng <gdeng@nvidia.com>

* Apply isort and black reformatting

  Signed-off-by: gdengk <gdengk@users.noreply.github.com>

---------

Signed-off-by: gdeng <gdeng@nvidia.com>
Signed-off-by: gdengk <gdengk@users.noreply.github.com>
Co-authored-by: gdengk <gdengk@users.noreply.github.com>
Signed-off-by: Guyue Huang <guyueh@nvidia.com>
1 parent 02baadb commit 6db849f

2 files changed

Lines changed: 48 additions & 1 deletion

File tree

nemo/collections/llm/gpt/model/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None, vp_sta
326326
Returns:
327327
MCoreGPTModel: Configured Megatron Core GPT model instance
328328
"""
329-
if self.enable_cuda_graph:
329+
if self.enable_cuda_graph or self.external_cuda_graph:
330330
assert HAVE_TE, "Transformer Engine is required for cudagraphs."
331331
assert getattr(self, "use_te_rng_tracker", False), (
332332
"Transformer engine's RNG tracker is required for cudagraphs, it can be "

nemo/lightning/pytorch/strategies/megatron_strategy.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import atexit
1616
import functools
17+
import gc
1718
import inspect
1819
import logging as _logging
1920
import os
@@ -55,6 +56,7 @@
5556
from megatron.core.dist_checkpointing.validation import StrictHandling
5657
from megatron.core.distributed import DistributedDataParallelConfig
5758
from megatron.core.optimizer import OptimizerConfig
59+
from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks
5860
from torch import nn
5961
from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
6062
from torch.distributed.checkpoint.utils import CheckpointException
@@ -572,6 +574,10 @@ def setup_distributed(self) -> None:
572574
"""Setups dist env"""
573575
setup_parallel_ranks(self)
574576

577+
# Capture the external cudagraph on a side stream
578+
if hasattr(self.model, 'config') and getattr(self.model.config, 'external_cuda_graph', False):
579+
torch.cuda.set_stream(torch.cuda.Stream())
580+
575581
# Implementation from superclass copied below in order to pass the store to the process group init
576582
reset_seed()
577583
self.set_world_ranks()
@@ -720,6 +726,35 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
720726
assert self.lightning_module is not None
721727
assert isinstance(self.model, MegatronParallel)
722728

729+
# Capture the external cuda graph for the first step
730+
if (
731+
self.trainer.global_step == 0
732+
and hasattr(self.model, 'config')
733+
and getattr(self.model.config, 'external_cuda_graph', False)
734+
):
735+
# disable the prehook
736+
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
737+
self.model.disable_forward_pre_hook()
738+
param_sync_func = self.model.config.param_sync_func
739+
self.model.config.param_sync_func = None
740+
import argparse
741+
742+
partial_cg_args = argparse.Namespace()
743+
partial_cg_args.position_embedding_type = self.model.config.position_embedding_type
744+
partial_cg_args.seq_length = self.trainer.datamodule.seq_length
745+
partial_cg_args.micro_batch_size = self.trainer.datamodule.micro_batch_size
746+
cuda_graph_capture(self.model, self.model.config, partial_cg_args)
747+
748+
# Set grad to zero.
749+
for model_chunk in self.model:
750+
model_chunk.zero_grad_buffer()
751+
for opt in self.optimizers:
752+
opt.zero_grad()
753+
754+
# Collect garbage and empty unused memory.
755+
gc.collect()
756+
torch.cuda.empty_cache()
757+
723758
with self.precision_plugin.train_step_context(): # TODO: Do we need this?
724759
# Set grad to zero.
725760
for model_chunk in self.model:
@@ -740,6 +775,18 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
740775

741776
reduced_train_loss = out["loss"]
742777

778+
if (
779+
self.trainer.global_step == 0
780+
and hasattr(self.model, 'config')
781+
and getattr(self.model.config, 'external_cuda_graph', False)
782+
):
783+
if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
784+
# enable the prehook
785+
self.model.enable_forward_pre_hook()
786+
self.model.config.param_sync_func = param_sync_func
787+
param_sync_func = None
788+
cuda_graph_set_manual_hooks(self.model)
789+
743790
self.lightning_module.log(
744791
"global_step",
745792
self.trainer.global_step,

0 commit comments

Comments (0)