@@ -14,6 +14,7 @@
 
 import atexit
 import functools
+import gc
 import inspect
 import logging as _logging
 import os
@@ -55,6 +56,7 @@
 from megatron.core.dist_checkpointing.validation import StrictHandling
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.optimizer import OptimizerConfig
+from megatron.training.training import cuda_graph_capture, cuda_graph_set_manual_hooks
 from torch import nn
 from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook
 from torch.distributed.checkpoint.utils import CheckpointException
@@ -572,6 +574,10 @@ def setup_distributed(self) -> None:
         """Sets up the distributed environment."""
         setup_parallel_ranks(self)
 
+        # CUDA graph capture cannot run on the default stream, so switch to a side stream up front.
+        if hasattr(self.model, 'config') and getattr(self.model.config, 'external_cuda_graph', False):
+            torch.cuda.set_stream(torch.cuda.Stream())
+
         # Implementation from the superclass is copied below in order to pass the store to the process group init.
         reset_seed()
         self.set_world_ranks()
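For reference, the side stream is needed because of a PyTorch/CUDA constraint: stream capture is not allowed on the default stream, and warm-up work must be synchronized off of it before capture begins. The sketch below is a minimal, standalone illustration of the `torch.cuda.CUDAGraph` mechanics that Megatron's `cuda_graph_capture` builds on; the model, shapes, and warm-up count are placeholders, not values from this PR.

```python
import torch

# Warm up on a side stream so no warm-up kernels are pending on the
# default stream when capture starts.
model = torch.nn.Linear(1024, 1024, device='cuda')
static_input = torch.randn(8, 1024, device='cuda')

side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture once; torch.cuda.graph() switches to a capture-safe stream internally.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay with new data: refill the static input buffer, then replay.
static_input.copy_(torch.randn(8, 1024, device='cuda'))
graph.replay()
```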
@@ -720,6 +726,35 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
         assert self.lightning_module is not None
         assert isinstance(self.model, MegatronParallel)
 
+        # Capture the external CUDA graphs on the first training step.
+        if (
+            self.trainer.global_step == 0
+            and hasattr(self.model, 'config')
+            and getattr(self.model.config, 'external_cuda_graph', False)
+        ):
+            # Temporarily disable the forward pre-hook and param sync during capture.
+            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
+                self.model.disable_forward_pre_hook()
+                param_sync_func = self.model.config.param_sync_func
+                self.model.config.param_sync_func = None
+            import argparse
+
+            # cuda_graph_capture expects a Megatron-style args namespace; build a
+            # partial one carrying only the fields the capture path reads.
+            partial_cg_args = argparse.Namespace()
+            partial_cg_args.position_embedding_type = self.model.config.position_embedding_type
+            partial_cg_args.seq_length = self.trainer.datamodule.seq_length
+            partial_cg_args.micro_batch_size = self.trainer.datamodule.micro_batch_size
+            cuda_graph_capture(self.model, self.model.config, partial_cg_args)
+
+            # Set grads back to zero after the capture warm-up.
+            for model_chunk in self.model:
+                model_chunk.zero_grad_buffer()
+            for opt in self.optimizers:
+                opt.zero_grad()
+
+            # Collect garbage and empty unused memory.
+            gc.collect()
+            torch.cuda.empty_cache()
+
         with self.precision_plugin.train_step_context():  # TODO: Do we need this?
             # Set grad to zero.
             for model_chunk in self.model:
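The `gc.collect()` / `torch.cuda.empty_cache()` pair at the end of the capture block is worth calling out: captured graphs reserve memory from a private pool, so releasing blocks the caching allocator is merely holding onto frees real headroom before the first optimizer step. A minimal sketch of the effect, assuming a CUDA device is available (the tensor size is arbitrary):

```python
import gc
import torch

# Deleting a tensor returns its block to PyTorch's caching allocator,
# but the memory stays reserved on the GPU...
x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device='cuda')
del x
print(torch.cuda.memory_reserved())  # still ~256 MiB

gc.collect()              # drop any lingering Python references first
torch.cuda.empty_cache()  # hand cached blocks back to the CUDA driver
print(torch.cuda.memory_reserved())  # now (close to) zero
```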
@@ -740,6 +775,18 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP
 
             reduced_train_loss = out["loss"]
 
+        # First step only: re-enable the pre-hook, restore param sync, and set
+        # the manual hooks for the captured graphs.
+        if (
+            self.trainer.global_step == 0
+            and hasattr(self.model, 'config')
+            and getattr(self.model.config, 'external_cuda_graph', False)
+        ):
+            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
+                # Re-enable the forward pre-hook and restore param sync.
+                self.model.enable_forward_pre_hook()
+                self.model.config.param_sync_func = param_sync_func
+                param_sync_func = None
+            cuda_graph_set_manual_hooks(self.model)
+
         self.lightning_module.log(
             "global_step",
             self.trainer.global_step,
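The disable/enable pair bracketing the first step comes down to PyTorch's ordinary hook lifecycle: with `overlap_param_gather`, the distributed optimizer drives its param all-gather from a module forward pre-hook, and a Python hook firing during capture would record its collectives into the graph (and would not fire at all during replay). Hence the strategy detaches the hook for capture and re-attaches it afterwards, with `cuda_graph_set_manual_hooks` presumably installing graph-safe equivalents on the captured modules. A generic sketch of that detach/re-attach pattern, using a stand-in callback rather than Megatron's real hook:

```python
from torch import nn

def param_gather_prehook(module, inputs):
    # Stand-in for the distributed optimizer's overlapped all-gather.
    return None

layer = nn.Linear(16, 16)

# Attach: the mechanism behind overlap_param_gather.
handle = layer.register_forward_pre_hook(param_gather_prehook)

# Detach before capture (~ disable_forward_pre_hook()).
handle.remove()
# ... CUDA graph capture happens here ...

# Re-attach after the first step (~ enable_forward_pre_hook()).
handle = layer.register_forward_pre_hook(param_gather_prehook)
```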