Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
bert_110m,
bert_340m,
chatglm3_6b,
e5_340m,
gemma2,
gemma2_2b,
gemma2_9b,
Expand Down Expand Up @@ -84,6 +85,7 @@
"bert_110m",
"bert_340m",
"chatglm3_6b",
"e5_340m",
"gemma_2b",
"gemma_7b",
"llama3_8b",
Expand Down
138 changes: 138 additions & 0 deletions nemo/collections/llm/recipes/bert_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.callbacks.callback import Callback
from megatron.core.distributed import DistributedDataParallelConfig

from nemo import lightning as nl
from nemo.collections.llm import BertEmbeddingLargeConfig, BertEmbeddingMiniConfig, BertEmbeddingModel
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed


def bert_embedding_model(version: str) -> run.Config[pl.LightningModule]:
"""
A function to create a Bert models.

Args:
version (str): The version of the Nemotron model to create. one of ["bert_110m", "bert_340m"].
bert_type (str): The Bert type. either "megatron" or "huggingface".

Returns:
run.Config[pl.LightningModule]: Configuration for the Bert model.
"""
config = None
if "340m" in version:
config = run.Config(BertEmbeddingLargeConfig)
elif "110m" in version:
config = run.Config(BertEmbeddingMiniConfig)

assert config is not None, f"Invalid BERT version: {version}"
return run.Config(BertEmbeddingModel, config=config)


def bert_trainer(
tensor_parallelism: int = 2,
pipeline_parallelism: int = 1,
pipeline_parallelism_type: Optional[torch.dtype] = None,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = False,
num_nodes: int = 1,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
precision: str = "bf16-mixed",
accumulate_grad_batches: int = 1,
limit_test_batches: int = 32,
limit_val_batches: int = 32,
log_every_n_steps: int = 10,
val_check_interval: int = 2000,
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for BERT models.

This function sets up the distributed training strategy and other training parameters.

Args:
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism.
sequence_parallelism (bool): Whether to use sequence parallelism.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
max_steps (int): Maximum number of training steps.
precision (str): Precision configuration, one of fp32, 16-mixed or bf16-mixed.
accumulate_grad_batches (int): Number of steps per gradient accumulation.
limit_test_batches (int): Limit the number of test batches.
limit_val_batches (int): Limit the number of validation batches.
log_every_n_steps (int): Log every n steps.
val_check_interval (int): Run validation every N steps.
callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

Returns:
run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
"""
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
pipeline_dtype=pipeline_parallelism_type,
virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
context_parallel_size=context_parallelism,
sequence_parallel=sequence_parallelism,
gradient_as_bucket_view=True,
ckpt_include_optimizer=True,
ckpt_async_save=True,
ckpt_parallel_load=True,
ddp=run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=False,
overlap_param_gather=True,
average_in_collective=True,
),
)

precision_plugin = None
if precision == "16-mixed":
precision_plugin = fp16_mixed()
elif precision == "bf16-mixed":
precision_plugin = bf16_mixed()

trainer = run.Config(
nl.Trainer,
accelerator="gpu",
callbacks=callbacks,
devices=num_gpus_per_node,
accumulate_grad_batches=accumulate_grad_batches,
limit_test_batches=limit_test_batches,
limit_val_batches=limit_val_batches,
log_every_n_steps=log_every_n_steps,
max_steps=max_steps,
num_nodes=num_nodes,
plugins=precision_plugin,
strategy=strategy,
use_distributed_sampler=False,
val_check_interval=val_check_interval,
)

return trainer
104 changes: 104 additions & 0 deletions nemo/collections/llm/recipes/e5_340m.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import lightning.pytorch as pl
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.api import finetune
from nemo.collections.llm.recipes.bert_embedding import bert_embedding_model
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe

NAME = "e5_340m"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a E5-Large (340 million) model configuration.

Returns:
run.Config[pl.LightningModule]: Configuration for the E5-Large (340 million) model.

Examples:
CLI usage:
$ nemo llm pretrain model=e5_340m ...

Python API usage:
>>> model_config = model()
>>> print(model_config)
"""
return bert_embedding_model(version=NAME)


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
dir: Optional[str] = None,
resume_path: str = "intfloat/e5-large-v2",
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
peft_scheme: Optional[str] = None,
seq_length: int = 512,
micro_batch_size: int = 4,
global_batch_size: int = 32,
) -> run.Partial:
"""
Create a fine-tuning recipe for E5-large (340 million) model.

This function sets up a complete configuration for fine-tuning, including
model, trainer, data, logging, optimization, and resumption settings.
Only SFT is currently supported for E5 model.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the fine-tuning run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning.
Allowed values: 'none'/None.
resume_path (str): Path to the NeMo checkpoint
seq_length (int): Maximum number of tokens per microbatch.
micro_batch_size (int): Micro batch size.
global_batch_size (int): Global batch size.


Returns:
run.Partial: Partial configuration for fine-tuning.

Examples:
CLI usage:
$ nemo llm finetune --factory e5_340m

Python API usage:
>>> recipe = finetune_recipe(name="e5_340m_finetune", num_nodes=1)
>>> print(recipe)

Note:
This recipe uses the Specter dataset for fine-tuning. For more information
on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
`examples/llm/finetune/` directory.
"""
recipe = default_finetune_recipe(model(), resume_path, dir, name, num_nodes, num_gpus_per_node)
datamodule = run.Config(
llm.SpecterDataModule,
seq_length=seq_length,
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
)
recipe.data = datamodule

assert peft_scheme is None or peft_scheme.lower() == 'none', 'E5 only supports SFT.'
return recipe
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,7 @@ def fwd_output_only_func(dataloader_iter, model):
and isinstance(model.module, MCoreGPTModel)
):
attention_mask = None

output_tensor = model(tokens, position_ids, attention_mask, **extra_arg)

# Advance inference sequence offset.
Expand Down