@@ -141,7 +141,9 @@ def forward(

         if position_ids is None:
             seq_len = embed_input.shape[1]
-            position_ids = torch.arange(seq_len, device=embed_input.device).unsqueeze(0).expand(embed_input.shape[0], -1)
+            position_ids = (
+                torch.arange(seq_len, device=embed_input.device).unsqueeze(0).expand(embed_input.shape[0], -1)
+            )
         position_embeddings = self._rotary_emb(embed_input, position_ids)
         position_embeddings_compress = self._rotary_emb_compress(embed_input, position_ids)
@@ -235,9 +237,7 @@ def forward(
         per_depth_h: list[torch.Tensor] = []
         cur_input_ids = input_ids
         if embed_inputs is not None and len(embed_inputs) != len(self.layers):
-            raise ValueError(
-                f"Expected {len(self.layers)} MTP embedding tensors, got {len(embed_inputs)}"
-            )
+            raise ValueError(f"Expected {len(self.layers)} MTP embedding tensors, got {len(embed_inputs)}")
         if embed_inputs is None and (cur_input_ids is None or embed_fn is None):
             raise ValueError("MTP requires either embed_inputs or both input_ids and embed_fn")
@@ -263,7 +263,9 @@ def forward(
 def build_mtp_config_from_hf(config, *, loss_scaling_factor: float = 0.1) -> MTPConfig:
     """Build an MTPConfig from a DeepseekV4Config."""
     num_layers = int(getattr(config, "num_nextn_predict_layers", 0) or 0)
-    return MTPConfig(num_layers=num_layers, layer_pattern="*" if num_layers > 0 else "", loss_scaling_factor=loss_scaling_factor)
+    return MTPConfig(
+        num_layers=num_layers, layer_pattern="*" if num_layers > 0 else "", loss_scaling_factor=loss_scaling_factor
+    )


 def build_deepseek_v4_mtp(
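As a quick sanity check of the reformatted builder above, here is a minimal usage sketch of `build_mtp_config_from_hf`. The `SimpleNamespace` stand-in for a `DeepseekV4Config` and the chosen field values are assumptions for illustration only; they are not part of this PR, and the helper is assumed to be importable from the module touched in this diff.

```python
# Minimal sketch (assumed usage, not part of this PR): exercise
# build_mtp_config_from_hf with a stand-in for DeepseekV4Config.
from types import SimpleNamespace

# Hypothetical config carrying only the field the builder reads.
hf_config = SimpleNamespace(num_nextn_predict_layers=2)

mtp_config = build_mtp_config_from_hf(hf_config, loss_scaling_factor=0.1)
assert mtp_config.num_layers == 2
assert mtp_config.layer_pattern == "*"  # enabled when num_layers > 0

# With zero MTP layers, the pattern collapses to the empty string.
empty = build_mtp_config_from_hf(SimpleNamespace(num_nextn_predict_layers=0))
assert empty.num_layers == 0 and empty.layer_pattern == ""
```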