@@ -69,6 +69,9 @@ class DtypeConfig:
6969 # fp8 related
7070 fp8 : str = None
7171 fp8_recipe : str = "delayed"
72+ # fp4 related
73+ fp4 : str = None
74+ fp4_recipe : str = "nvfp4"
7275 first_last_layers_bf16 : bool = False
7376 fp8_margin : int = 0
7477 fp8_amax_history_len : int = 1
@@ -116,6 +119,9 @@ def __init__(
116119 fp8_multi_head_attention : bool = False ,
117120 fp8_params : bool = None ,
118121 fp8_param_gather : bool = None ,
122+ # fp4 related
123+ fp4 : str = None ,
124+ fp4_recipe : str = "nvfp4" ,
119125 fp16_loss_scale : float = None ,
120126 fp16_initial_loss_scale : float = 4294967296 ,
121127 fp16_min_loss_scale : float = 1.0 ,
@@ -161,6 +167,8 @@ def __init__(
161167 fp8_multi_head_attention = fp8_multi_head_attention ,
162168 fp8_param = fp8_param_gather ,
163169 fp8_param_gather = fp8_param_gather ,
170+ fp4 = fp4 ,
171+ fp4_recipe = fp4_recipe ,
164172 num_layers_at_start_in_bf16 = num_layers_at_start_in_bf16 ,
165173 num_layers_at_end_in_bf16 = num_layers_at_end_in_bf16 ,
166174 reuse_grad_buf_for_mxfp8_param_ag = reuse_grad_buf_for_mxfp8_param_ag ,
0 commit comments