Skip to content

Commit 7b3f1e5

Browse files
authored
Merge branch 'dev' into optimize_hybrid_ep
2 parents 3ae9ffb + 23e092f commit 7b3f1e5

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

megatron/core/models/gpt/gpt_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,8 @@ def _postprocess(
         if not self.post_process:
             return hidden_states
 
-        if self.config.mtp_num_layers is not None:
+        # Skip when mtp_num_layers is None or 0
+        if self.config.mtp_num_layers:
             mtp_labels = labels.clone()
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
             hidden_states = hidden_states_list[0]

0 commit comments

Comments
 (0)