Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions scripts/mini_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,16 @@ def create(arch: str, output_dir: Path) -> None:
def verify(arch: str, model_dir: Path) -> None:
preset = ARCH_PRESETS[arch]
print(f"Verifying HF <-> PrimeRL roundtrip for {model_dir}...")
device = torch.device("cuda")

trust_remote_code = preset["hf_model_class"] is None
config = AutoConfig.from_pretrained(str(model_dir), trust_remote_code=trust_remote_code)
config._attn_implementation = "sdpa"

with torch.device("cuda"), default_dtype(torch.float32):
# Avoid loading inside a torch.device("cuda") context. Recent transformers
# versions treat this like a default-device/device_map workflow and require
# accelerate for from_pretrained().
with default_dtype(torch.float32):
hf_model = _load_hf_model(preset, model_dir, config)
prime_model = preset["prime_model_class"]._from_config(config)

Expand All @@ -155,10 +159,14 @@ def verify(arch: str, model_dir: Path) -> None:
prime_model.load_state_dict(state_dict)

inject_prime_lm_head(prime_model, chunk_size=None)
hf_model = hf_model.to(device)
prime_model = prime_model.to(device)

with torch.device("cuda"), default_dtype(torch.float32):
with default_dtype(torch.float32):
input_ids = torch.randint(0, config.vocab_size, (1, 64))
position_ids = torch.arange(1, 65).unsqueeze(0)
input_ids = input_ids.to(device)
position_ids = position_ids.to(device)

hf_output = hf_model(input_ids=input_ids, position_ids=position_ids)
prime_output = prime_model(input_ids, position_ids)
Expand All @@ -170,9 +178,10 @@ def verify(arch: str, model_dir: Path) -> None:

with torch.no_grad():
roundtrip_state_dict = prime_model.convert_to_hf(prime_model.state_dict())
with torch.device("cuda"), default_dtype(torch.float32):
with default_dtype(torch.float32):
hf_roundtrip = _create_hf_model_from_config(preset, config)
hf_roundtrip.load_state_dict(roundtrip_state_dict)
hf_roundtrip = hf_roundtrip.to(device)

hf_roundtrip_output = hf_roundtrip(input_ids=input_ids, position_ids=position_ids)
roundtrip_diff = hf_roundtrip_output.logits - hf_output.logits
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ class Glm4MoeConfig(PretrainedConfig):
"norm": (["hidden_states"], ["hidden_states"]),
}

@property
def head_dim(self) -> int:
    """Per-attention-head dimension.

    Returns the value explicitly assigned through the setter when one
    exists; otherwise falls back to ``hidden_size // num_attention_heads``.
    """
    fallback = self.hidden_size // self.num_attention_heads
    return getattr(self, "_head_dim", fallback)

@head_dim.setter
def head_dim(self, value: int) -> None:
    """Record an explicit per-head dimension, overriding the fallback."""
    self._head_dim = value

def __init__(
self,
vocab_size=151552,
Expand Down
9 changes: 3 additions & 6 deletions src/prime_rl/trainer/models/glm4_moe/converting_glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,13 @@ def convert_tt_layer_to_hf(state_dict: dict[str, Tensor], layer_index: int):
state_dict[f"model.layers.{i}.mlp.gate.weight"] = state_dict[f"model.layers.{i}.mlp.router.gate.weight"]
del state_dict[f"model.layers.{i}.mlp.router.gate.weight"]

# Routed experts - convert to per-expert format (compatible with vLLM and transformers)
# Routed experts - convert to fused HF format (transformers 5.0+)
w1 = state_dict.pop(f"model.layers.{i}.mlp.experts.w1") # (num_experts, moe_dim, dim)
w2 = state_dict.pop(f"model.layers.{i}.mlp.experts.w2") # (num_experts, dim, moe_dim)
w3 = state_dict.pop(f"model.layers.{i}.mlp.experts.w3") # (num_experts, moe_dim, dim)

num_experts = w1.shape[0]
for j in range(num_experts):
state_dict[f"model.layers.{i}.mlp.experts.{j}.gate_proj.weight"] = w1[j]
state_dict[f"model.layers.{i}.mlp.experts.{j}.down_proj.weight"] = w2[j]
state_dict[f"model.layers.{i}.mlp.experts.{j}.up_proj.weight"] = w3[j]
state_dict[f"model.layers.{i}.mlp.experts.gate_up_proj"] = torch.cat([w1, w3], dim=1)
state_dict[f"model.layers.{i}.mlp.experts.down_proj"] = w2


def convert_tt_to_hf_moe(state_dict: dict[str, Tensor]):
Expand Down
27 changes: 27 additions & 0 deletions src/prime_rl/trainer/rl/broadcast/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Literal

import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor

Expand All @@ -20,6 +21,29 @@
from prime_rl.utils.utils import get_broadcast_dir, get_step_path


def _convert_glm4_moe_fused_experts_to_per_expert(state_dict: dict[str, torch.Tensor]) -> None:
"""Convert GLM4-MoE fused expert tensors to per-expert weights for vLLM live reload."""
layer_prefixes = {
key.removesuffix(".gate_up_proj")
for key in state_dict
if key.startswith("model.layers.") and key.endswith(".mlp.experts.gate_up_proj")
}

for prefix in layer_prefixes:
gate_up_proj = state_dict.pop(f"{prefix}.gate_up_proj")
down_proj = state_dict.pop(f"{prefix}.down_proj")

num_experts, fused_dim, _ = gate_up_proj.shape
moe_dim = fused_dim // 2
gate_proj = gate_up_proj[:, :moe_dim, :]
up_proj = gate_up_proj[:, moe_dim:, :]

for expert_idx in range(num_experts):
state_dict[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_proj[expert_idx].contiguous()
state_dict[f"{prefix}.{expert_idx}.down_proj.weight"] = down_proj[expert_idx].contiguous()
state_dict[f"{prefix}.{expert_idx}.up_proj.weight"] = up_proj[expert_idx].contiguous()


class FileSystemWeightBroadcast(WeightBroadcast):
"""Broadcast weights into the inference engine via shared filesystem."""

Expand Down Expand Up @@ -51,6 +75,9 @@ def broadcast_weights(self, model: nn.Module, step: int) -> None:

state_dict = revert_weight_conversion(model, state_dict)

if getattr(getattr(model, "config", None), "model_type", None) == "glm4_moe":
_convert_glm4_moe_fused_experts_to_per_expert(state_dict)

for idx in self.multi_run_manager.ready_to_update_idxs:
self.logger.debug(
f"Broadcasting weights for run {idx} (ready_to_update={self.multi_run_manager.ready_to_update[idx]})"
Expand Down