Skip to content

Commit 4e1a835

Browse files
WanZzzzzz, qiyuw, and gautham-kollu
authored
fp4 support (#14625)
Signed-off-by: qiyuw <qiyuw@nvidia.com> Co-authored-by: qiyuw <qiyuw@nvidia.com> Co-authored-by: gautham-kollu <gkollu@nvidia.com>
1 parent 6217032 commit 4e1a835

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

nemo/lightning/pytorch/plugins/mixed_precision.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,9 @@ class DtypeConfig:
6969
# fp8 related
7070
fp8: str = None
7171
fp8_recipe: str = "delayed"
72+
# fp4 related
73+
fp4: str = None
74+
fp4_recipe: str = "nvfp4"
7275
first_last_layers_bf16: bool = False
7376
fp8_margin: int = 0
7477
fp8_amax_history_len: int = 1
@@ -116,6 +119,9 @@ def __init__(
116119
fp8_multi_head_attention: bool = False,
117120
fp8_params: bool = None,
118121
fp8_param_gather: bool = None,
122+
# fp4 related
123+
fp4: str = None,
124+
fp4_recipe: str = "nvfp4",
119125
fp16_loss_scale: float = None,
120126
fp16_initial_loss_scale: float = 4294967296,
121127
fp16_min_loss_scale: float = 1.0,
@@ -161,6 +167,8 @@ def __init__(
161167
fp8_multi_head_attention=fp8_multi_head_attention,
162168
fp8_param=fp8_param_gather,
163169
fp8_param_gather=fp8_param_gather,
170+
fp4=fp4,
171+
fp4_recipe=fp4_recipe,
164172
num_layers_at_start_in_bf16=num_layers_at_start_in_bf16,
165173
num_layers_at_end_in_bf16=num_layers_at_end_in_bf16,
166174
reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,

0 commit comments

Comments (0)