Commit 6396b09

fmt
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 490505f

3 files changed: 11 additions & 8 deletions

nemo_automodel/components/distributed/megatron_fsdp.py

Lines changed: 9 additions & 6 deletions

@@ -16,8 +16,9 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed.device_mesh import DeviceMesh
 import torch.nn as nn
+from torch.distributed.device_mesh import DeviceMesh
+
 from nemo_automodel.components.distributed.config import MegatronFSDPConfig
 from nemo_automodel.components.distributed.parallelizer import (
     _get_parallel_plan,
@@ -29,6 +30,7 @@
 try:
     from megatron_fsdp import MegatronFSDP
     from megatron_fsdp.fully_shard import fully_shard_optimizer as megatron_fsdp_fully_shard_optimizer
+
     HAS_MEGATRON_FSDP = True
 except (ImportError, FileNotFoundError):
     # raise FileNotFoundError(
@@ -160,14 +162,15 @@ def parallelize(self, model, optimizer=None):
 
         return model, optimizer
 
+
 def fully_shard_optimizer(
-    model: nn.Module,
-    optimizer: torch.optim.Optimizer, preproc_state_dict_for_dcp_ckpt: bool = True
+    model: nn.Module, optimizer: torch.optim.Optimizer, preproc_state_dict_for_dcp_ckpt: bool = True
 ) -> torch.optim.Optimizer:
-    """
-    """
+    """ """
     if not isinstance(model, MegatronFSDP):
         return optimizer
     if not HAS_MEGATRON_FSDP:
-        raise ImportError("MegatronFSDP is not installed, please visit https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/distributed/fsdp/src for more information")
+        raise ImportError(
+            "MegatronFSDP is not installed, please visit https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/distributed/fsdp/src for more information"
+        )
     return megatron_fsdp_fully_shard_optimizer(optimizer)
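
For reference, the guarded-import pattern this file relies on can be exercised standalone. A minimal sketch, assuming only what the hunks above show: the module and symbol names come from the diff, while the HAS_MEGATRON_FSDP = False fallback and the nn.Module stub are illustrative assumptions (the commit elides the real except branch):

import torch
import torch.nn as nn

# Optional-dependency guard, as in megatron_fsdp.py: importing this module
# must not fail when the megatron_fsdp package is absent; the flag records
# availability so callers can raise a helpful error later.
try:
    from megatron_fsdp import MegatronFSDP
    from megatron_fsdp.fully_shard import fully_shard_optimizer as megatron_fsdp_fully_shard_optimizer

    HAS_MEGATRON_FSDP = True
except (ImportError, FileNotFoundError):
    HAS_MEGATRON_FSDP = False  # assumption: not shown in this hunk

    class MegatronFSDP(nn.Module):  # hypothetical stub so isinstance() below stays valid
        pass


def fully_shard_optimizer(
    model: nn.Module, optimizer: torch.optim.Optimizer, preproc_state_dict_for_dcp_ckpt: bool = True
) -> torch.optim.Optimizer:
    """Shard optimizer state only for MegatronFSDP-wrapped models."""
    if not isinstance(model, MegatronFSDP):
        return optimizer  # no-op for plain nn.Module models
    if not HAS_MEGATRON_FSDP:
        raise ImportError(
            "MegatronFSDP is not installed, please visit https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/distributed/fsdp/src for more information"
        )
    return megatron_fsdp_fully_shard_optimizer(optimizer)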

nemo_automodel/recipes/base_recipe.py

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 
 from nemo_automodel.components.checkpoint.checkpointing import save_config
-from nemo_automodel.components.config.loader import config_to_yaml_str, ConfigNode
+from nemo_automodel.components.config.loader import ConfigNode, config_to_yaml_str
 from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler
 from nemo_automodel.components.training.rng import StatefulRNG
 from nemo_automodel.components.training.step_scheduler import StepScheduler
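
This reorder is what an import sorter produces: with isort's default order_by_type behavior, CamelCase (class-style) names sort ahead of snake_case (function-style) names within a from-import list, so ConfigNode moves in front of config_to_yaml_str. A quick way to reproduce it, assuming isort is installed; whether this repo's fmt step invokes isort directly or ruff's import rules is not visible in the commit:

import isort

# isort's default ordering within a from-import list places class-style
# names before function-style names, matching the change in this hunk.
before = "from nemo_automodel.components.config.loader import config_to_yaml_str, ConfigNode\n"
print(isort.code(before))
# -> from nemo_automodel.components.config.loader import ConfigNode, config_to_yaml_str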

nemo_automodel/recipes/llm/train_ft.py

Lines changed: 1 addition & 1 deletion

@@ -48,11 +48,11 @@
 from nemo_automodel.components.datasets.llm.megatron_dataset import MegatronPretraining
 from nemo_automodel.components.datasets.llm.packed_sequence import pack_dataset
 from nemo_automodel.components.distributed.config import MegatronFSDPConfig
-from nemo_automodel.components.distributed.megatron_fsdp import fully_shard_optimizer
 from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx
 from nemo_automodel.components.distributed.init_utils import (
     initialize_distributed,
 )
+from nemo_automodel.components.distributed.megatron_fsdp import fully_shard_optimizer
 from nemo_automodel.components.distributed.mesh import MeshContext
 from nemo_automodel.components.distributed.pipelining import AutoPipeline
 from nemo_automodel.components.distributed.utils import FirstRankPerNode, get_sync_ctx
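
The relocated import is the one train_ft.py uses to shard optimizer state. A hedged call-site sketch: only fully_shard_optimizer and its signature come from this commit; the toy model and optimizer stand in for whatever the recipe actually builds:

import torch
import torch.nn as nn

from nemo_automodel.components.distributed.megatron_fsdp import fully_shard_optimizer

# Stand-ins for the recipe's real model/optimizer construction (assumed).
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Safe to call unconditionally: per the megatron_fsdp.py diff above, the
# function returns the optimizer unchanged unless the model is wrapped in
# MegatronFSDP.
optimizer = fully_shard_optimizer(model, optimizer, preproc_state_dict_for_dcp_ckpt=True)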
