Commit 15e3cd9

Add mistral small3 24B config and recipe
Signed-off-by: Joosung Yoon <joosungy@nvidia.com>
1 parent 709da78 commit 15e3cd9

5 files changed

Lines changed: 340 additions & 1 deletion

nemo/collections/llm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -141,6 +141,7 @@
     MistralConfig7B,
     MistralModel,
     MistralNeMoConfig12B,
+    MistralSmall3Config24B,
     MixtralConfig,
     MixtralConfig8x3B,
     MixtralConfig8x7B,
@@ -251,6 +252,7 @@
     "MaskedTokenLossReduction",
     "MistralConfig7B",
     "MistralNeMoConfig12B",
+    "MistralSmall3Config24B",
     "MistralModel",
     "MixtralConfig",
     "MixtralConfig8x3B",

nemo/collections/llm/gpt/model/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -110,7 +110,12 @@
     Llama33NemotronSuper49BConfig,
     LlamaNemotronModel,
 )
-from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel, MistralNeMoConfig12B
+from nemo.collections.llm.gpt.model.mistral import (
+    MistralConfig7B,
+    MistralModel,
+    MistralNeMoConfig12B,
+    MistralSmall3Config24B,
+)
 from nemo.collections.llm.gpt.model.mixtral import (
     MixtralConfig,
     MixtralConfig8x3B,
@@ -194,6 +199,7 @@
     "MistralConfig7B",
     "MistralModel",
     "MistralNeMoConfig12B",
+    "MistralSmall3Config24B",
     "MixtralConfig8x3B",
     "MixtralConfig8x7B",
     "MixtralConfig8x22B",

nemo/collections/llm/gpt/model/mistral.py

Lines changed: 20 additions & 0 deletions
@@ -102,6 +102,26 @@ class MistralNeMoConfig123B(MistralConfig7B):
     params_dtype: torch.dtype = torch.bfloat16


+@dataclass
+class MistralSmall3Config24B(MistralConfig7B):
+    """
+    https://mistral.ai/news/mistral-small-3/
+    """
+
+    num_layers: int = 40
+    hidden_size: int = 5120
+    ffn_hidden_size: int = 32768
+    num_attention_heads: int = 32
+    kv_channels: int = 128
+    seq_length: int = 32768
+
+    window_size: List[int] = None
+    cp_comm_type: str = None
+    rotary_percent: float = 1.0
+    rotary_base: float = 100000000.0
+    params_dtype: torch.dtype = torch.bfloat16
+
+
 class MistralModel(GPTModel):
     """ """
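
The new config inherits MistralConfig7B defaults and overrides the 24B dimensions. A short sketch of direct use (not part of the commit; it assumes the usual GPTModel constructor that takes the config as its first argument, and the seq_length override is purely illustrative):

from nemo.collections.llm import MistralModel, MistralSmall3Config24B

# Build the 24B config; any dataclass field can be overridden at construction time
config = MistralSmall3Config24B(seq_length=8192)
model = MistralModel(config=config)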

nemo/collections/llm/recipes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,7 @@
     mamba2_hybrid_8b,
     mistral_7b,
     mistral_nemo_12b,
+    mistral_small3_24b,
     mixtral_8x7b,
     mixtral_8x7b_16k,
     mixtral_8x7b_64k,
@@ -171,6 +172,7 @@
     "nemotron_nano_12b_v2",
     "mistral_7b",
     "mistral_nemo_12b",
+    "mistral_small3_24b",
     "hyena_base",
     "hyena_1b",
     "hyena_7b",
nemo/collections/llm/recipes/mistral_small3_24b.py

Lines changed: 309 additions & 0 deletions
@@ -0,0 +1,309 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import lightning.pytorch as pl
import nemo_run as run
import torch
from lightning.pytorch.callbacks.callback import Callback
from megatron.core.distributed import DistributedDataParallelConfig

from nemo import lightning as nl
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs  # required by finetune_recipe's packed-sequence branch
from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralSmall3Config24B
from nemo.collections.llm.peft import PEFT_STR2CLS
from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mistral_small3_24b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
    """
    Factory function to create a Mistral-Small-3-24B model configuration.

    Returns:
        run.Config[pl.LightningModule]: Configuration for the Mistral-Small-3-24B model.

    Examples:
        CLI usage:
            $ nemo llm pretrain model=mistral_small3_24b ...

        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    return run.Config(MistralModel, config=run.Config(MistralSmall3Config24B))


def trainer(
    tensor_parallelism: int = 4,
    pipeline_parallelism: int = 2,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = True,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for Mistral-Small-3-24B model.

    This function sets up the distributed training strategy and other training parameters.

    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=mistral_small3_24b ...

        Python API usage:
            >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
            >>> print(trainer_config)

    Note:
        For more information on distributed training strategies, refer to the
        NeMo documentation on multi-GPU and multi-node training.
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=True,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
        ),
    )

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
    )

    return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
    """
    Create a pre-training recipe for Mistral-Small-3-24B model.

    This function sets up a complete configuration for pre-training, including
    model, trainer, data, logging, optimization, and resumption settings.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for pre-training.

    Examples:
        CLI usage:
            $ nemo llm pretrain --factory mistral_small3_24b
            $ nemo llm pretrain --factory "mistral_small3_24b(num_nodes=2, name='my_pretrain')"

        Python API usage:
            >>> recipe = pretrain_recipe(name="mistral_small3_24b", num_nodes=2)
            >>> print(recipe)
    """
    return run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[run.Config(TimingCallback)],
        ),
        data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )


@run.cli.factory(target=pretrain, name=NAME + "_optimized")
def pretrain_recipe_performance(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn: Callable = pretrain,
) -> run.Partial:
    """
    Create a performance-optimized pre-training recipe for Mistral-Small-3-24B model.

    This recipe enables performance optimizations that may not be suitable for all use cases.
    It builds upon the standard pre-training recipe and adds additional performance enhancements.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.

    Returns:
        run.Partial: Partial configuration for performance-optimized pre-training.

    Examples:
        $ nemo llm pretrain --factory mistral_small3_24b_optimized

        Python API usage:
            >>> recipe = pretrain_recipe_performance(name="mistral_small3_24b_perf", num_nodes=4)
            >>> print(recipe)

    Note:
        Use this recipe with caution and only when you need maximum performance.
        It may not be suitable for all hardware configurations or use cases.
    """
    recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

    recipe.trainer.callbacks.append(
        run.Config(
            MegatronCommOverlapCallback,
            tp_comm_overlap=True,
        )
    )
    return recipe


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
    dir: Optional[str] = None,
    resume_path: str = "mistralai/Mistral-Small-3-24B-Instruct-2501",
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    peft_scheme: Optional[str] = 'lora',
    seq_length: Optional[int] = None,
    packed_sequence: bool = False,
) -> run.Partial:
    """
    Create a fine-tuning recipe for Mistral-Small-3-24B model.

    This function sets up a complete configuration for fine-tuning, including
    model, trainer, data, logging, optimization, and resumption settings.
    The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None.

    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        resume_path (str): Path to the NeMo checkpoint
        name (str): Name of the fine-tuning run.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning.
            Allowed values: 'lora'/'dora'/'none'/None.
        seq_length (int): Maximum number of tokens per microbatch.
        packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training
            efficiency. Default sequence length is 2048.

    Returns:
        run.Partial: Partial configuration for fine-tuning.

    Examples:
        CLI usage:
            $ nemo llm finetune --factory mistral_small3_24b

        Python API usage:
            >>> recipe = finetune_recipe(name="mistral_small3_24b_finetune", num_nodes=2)
            >>> print(recipe)

    Note:
        This recipe uses the SQuAD dataset for fine-tuning.
    """

    # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K
    if seq_length is None:
        seq_length = 4096 if packed_sequence else 2048

    recipe = default_finetune_recipe(
        model(),
        resume_path,
        dir,
        name,
        num_nodes,
        num_gpus_per_node,
        packed_sequence,
    )
    if peft_scheme is None or peft_scheme.lower() == 'none':
        recipe.trainer.strategy.tensor_model_parallel_size = 4
        recipe.trainer.strategy.pipeline_model_parallel_size = 2
        recipe.optim.config.lr = 5e-6
    elif peft_scheme.lower() in ['lora', 'dora']:
        recipe.peft = run.Config(
            PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32
        )
        recipe.optim.config.lr = 1e-4
    else:
        raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")

    # Sequence length settings in the model and dataset must agree
    recipe.model.config.seq_length = seq_length
    recipe.data.seq_length = seq_length
    if packed_sequence:
        recipe.data.dataset_kwargs = {'pad_to_max_length': True}
        recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length)

    return recipe
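
For completeness, one way the new factories might be driven end to end with nemo_run. This is a hedged sketch rather than part of the commit, and the LocalExecutor settings are illustrative:

import nemo_run as run

from nemo.collections.llm.recipes.mistral_small3_24b import finetune_recipe, pretrain_recipe

# Pre-training with the defaults defined above (TP=4 x PP=2 -> 8 GPUs on one node)
pretrain_cfg = pretrain_recipe(name="mistral_small3_24b_pretrain", num_nodes=1, num_gpus_per_node=8)
run.run(pretrain_cfg, executor=run.LocalExecutor(ntasks_per_node=8, launcher="torchrun"))

# LoRA fine-tuning from the checkpoint path referenced by the recipe
finetune_cfg = finetune_recipe(name="mistral_small3_24b_lora", peft_scheme="lora")
run.run(finetune_cfg, executor=run.LocalExecutor(ntasks_per_node=8, launcher="torchrun"))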
