
[Weight Transfer] NIXL + MX Integration #2389

Open

S1ro1 wants to merge 35 commits into main from nixl_mx

Conversation

S1ro1 (Collaborator) commented May 2, 2026

Consolidates all the existing PRs for NIXL + MX into a clean implementation where I actually watched my agents.

[Image: BF16 <-> BF16 and BF16 <-> FP8 transfer times, Qwen3 30B on a dummy task]
[Image: Qwen3 235B A22B BF16 <-> BF16 transfer time comparison]

Note

High Risk
Introduces a new RDMA-based weight transfer backend and changes orchestrator/trainer/inference coordination logic, which can affect distributed training stability and correctness. Also adds new runtime dependencies (Model Express, protobuf) and Docker-launched services that may impact deployment and CI.

Overview
Enables a new nixl_mx weight broadcast option that pushes weights trainer→inference via NIXL RDMA while using Model Express (MX) for rendezvous/metadata, wiring this through configs, orchestrator init, and inference worker extensions.

Implements the NIXL+MX transport stack: a new vLLM worker (NIXLMxWeightUpdateWorker) with an /init_nixl_mx endpoint, a trainer-side NIXLMxWeightBroadcast with slot-based buffer allocation and conversion (bf16 passthrough / FP8 blockwise), and a TransportPlan that posts RDMA WRITEs based on MX-published tensor descriptors and expert maps.
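For context on the conversion step: bf16 tensors are staged into their slots unchanged, while the FP8 path quantizes per tile with one scale per block. A self-contained sketch of blockwise FP8 quantization in that style (the function name and 128 block size are assumptions, not the PR's code):

```python
import torch

def quantize_fp8_blockwise(w: torch.Tensor, block: int = 128):
    """Quantize a 2-D weight to float8_e4m3fn with one scale per
    (block x block) tile. Illustrative sketch only; the PR's actual
    conversion may differ in block size, padding, and scale dtype."""
    rows, cols = w.shape
    # Pad both dims up to a multiple of the block size.
    w = torch.nn.functional.pad(w, (0, -cols % block, 0, -rows % block)).float()
    # View as (rows/block, block, cols/block, block) tiles.
    tiles = w.unflatten(1, (-1, block)).unflatten(0, (-1, block))
    amax = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / torch.finfo(torch.float8_e4m3fn).max  # e4m3 max = 448
    q = (tiles / scale).to(torch.float8_e4m3fn)
    # Reassemble tiles into a padded 2-D tensor, then crop the padding.
    q = q.flatten(2, 3).flatten(0, 1)[:rows, :cols]
    return q, scale.squeeze(1).squeeze(-1)  # quantized weight + per-block scales
```

On the receiving side, dequantization multiplies each tile back by its scale, which is the layout blockwise-fp8 weight loaders generally expect.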

Updates SLURM templates to optionally launch the MX + Redis stack via docker compose, sets UCX/LD paths for RDMA, and adjusts orchestrator scheduling/weight update flow to use MX status handshakes instead of filesystem markers when in nixl_mx mode. Adds unit/integration tests for MX rendezvous and conversion/slot planning, and adds modelexpress/protobuf dependencies plus a tilelang CUDA runtime preload shim.
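The handshake change in the last point is the key behavioral difference: instead of the inference side polling for a STABLE marker file, both sides exchange status through MX. A toy illustration of that control flow (the in-memory `Rendezvous` below is a stand-in; the real one talks to the Model Express service):

```python
import threading
import time

class Rendezvous:
    """In-memory stand-in for the MX status channel (illustrative only)."""

    def __init__(self) -> None:
        self._status: dict[str, str] = {}
        self._cond = threading.Condition()

    def set_status(self, peer: str, status: str) -> None:
        with self._cond:
            self._status[peer] = status
            self._cond.notify_all()

    def wait_for_status(self, peer: str, status: str, timeout: float) -> None:
        with self._cond:
            ok = self._cond.wait_for(lambda: self._status.get(peer) == status, timeout)
        if not ok:
            raise TimeoutError(f"{peer!r} never reached {status!r}")

rdv = Rendezvous()

def trainer() -> None:
    time.sleep(0.1)  # stand-in for posting the RDMA WRITEs
    rdv.set_status("source", "READY")  # replaces writing a STABLE marker file

def inference_worker() -> None:
    rdv.wait_for_status("source", "READY", timeout=30)  # replaces wait_for_path(...)

t = threading.Thread(target=trainer)
t.start()
inference_worker()
t.join()
```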

Reviewed by Cursor Bugbot for commit dabaa19. Bugbot is set up for automated code reviews on this repo.

Comment thread src/prime_rl/transport/nixl_agent.py
Comment thread src/prime_rl/trainer/rl/broadcast/__init__.py Outdated
Comment thread src/prime_rl/trainer/models/glm_moe_dsa/sparse_mla_attention.py Outdated
Comment thread src/prime_rl/transport/classic_cuda_pool.py
S1ro1 added 3 commits May 3, 2026 02:56
These files were on the branch from prior unrelated work (commit 40ff419,
plus sonic_ep.py accidentally staged with the config scaffold). Saved to
../scratchpad/nixl_mx_unrelated/ for future reapplication.
Comment thread src/prime_rl/trainer/models/slots.py
Comment thread src/prime_rl/inference/vllm/server.py
Comment thread src/prime_rl/configs/rl.py
Comment thread src/prime_rl/trainer/rl/broadcast/nixl_mx.py Outdated
Comment thread src/prime_rl/trainer/rl/broadcast/nixl_mx.py Outdated
Comment thread src/prime_rl/trainer/rl/broadcast/nixl_mx.py Outdated
Comment thread src/prime_rl/configs/rl.py
Comment thread src/prime_rl/inference/vllm/worker/nixl_mx.py
S1ro1 and others added 2 commits May 3, 2026 16:01
Signed-off-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Comment thread src/prime_rl/inference/vllm/server.py Outdated
Comment thread src/prime_rl/trainer/rl/broadcast/nixl_mx.py Outdated
Comment thread src/prime_rl/transport/transport_plan.py
Comment thread tests/unit/transport/conftest.py
Comment thread src/prime_rl/configs/rl.py
Comment thread src/prime_rl/trainer/rl/broadcast/nixl_mx.py
Comment thread src/prime_rl/orchestrator/orchestrator.py
S1ro1 and others added 2 commits May 5, 2026 20:30
Signed-off-by: Matej Sirovatka <54212263+S1ro1@users.noreply.github.com>
Comment thread src/prime_rl/orchestrator/scheduler.py
```python
@torch.no_grad()
def update_weights_from_path(self, weight_dir: str | None = None) -> None:
    """Block until the trainer's RDMA push completes, then recompute the MLA absorbed weights and return; the orchestrator can then call `/resume`."""
    self.rendezvous.wait_for_all_peers_ready(timeout=1200)
```

Hardcoded timeout ignores configurable timeout setting

Medium Severity

update_weights_from_path calls self.rendezvous.wait_for_all_peers_ready(timeout=1200) with a hardcoded 1200-second timeout. The trainer and orchestrator both use a configurable timeout from their respective configs. If a user configures a different timeout (e.g., for very large models needing longer), the inference worker won't respect it and could time out prematurely or wait too long.
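A plausible shape for the fix, assuming the configured timeout is plumbed through the /init_nixl_mx payload and stored on the worker (the attribute name below is an assumption, not the PR's code):

```python
@torch.no_grad()
def update_weights_from_path(self, weight_dir: str | None = None) -> None:
    """Block until the trainer's RDMA push completes, then recompute the
    MLA absorbed weights and return; the orchestrator can then call `/resume`."""
    # Use the timeout received at /init_nixl_mx time; keep 1200s only as a fallback.
    timeout = getattr(self, "weight_update_timeout", 1200)  # assumed attribute
    self.rendezvous.wait_for_all_peers_ready(timeout=timeout)
```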


Reviewed by Cursor Bugbot for commit 9d9b2cc.

cursor (Bot) left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.

There are 3 total unresolved issues (including 1 from previous review).


Reviewed by Cursor Bugbot for commit 75efd8c.

```python
)
self.checkpoint_ready.clear()
wait_for_ckpt_start_time = time.perf_counter()
await wait_for_path(get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) / "STABLE")
```

Scheduler filesystem probe broken for NIXL+MX mode

Medium Severity

_compute_next_ckpt_step always calls get_latest_ckpt_step on the filesystem broadcast directory, but NIXL+MX never writes checkpoint files there. This means latest_ckpt_step is always 0, forcing the scheduler into effectively strict-async behavior regardless of the strict_async_level setting. In non-strict mode with filesystem/NCCL, the orchestrator can jump ahead to the latest available checkpoint; this optimization is silently lost with NIXL+MX.
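One way to restore the jump-ahead optimization would be to branch on the broadcast backend and ask the MX rendezvous for the latest pushed step instead of scanning the filesystem. A sketch (the config field and `latest_ready_step` API below are assumptions, not the PR's code):

```python
def _compute_next_ckpt_step(self, current_step: int) -> int:
    # Sketch: in nixl_mx mode no checkpoint files are written, so ask MX
    # which step the trainer last marked READY instead of probing disk.
    if self.config.weight_broadcast.type == "nixl_mx":  # assumed config field
        latest_ckpt_step = self.rendezvous.latest_ready_step()  # assumed API
    else:
        latest_ckpt_step = get_latest_ckpt_step(get_broadcast_dir(self.config.output_dir))
    return max(current_step + 1, latest_ckpt_step)
```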


Reviewed by Cursor Bugbot for commit 75efd8c.

```python
self.rendezvous.set_status(p2p_pb2.SOURCE_STATUS_READY)
self.logger.debug(f"NIXL+MX push completed in {time.perf_counter() - start:.2f}s")

dist.barrier()
```

Non-primary HSDP ranks skip lazy_init but hit barrier deadlock risk

Medium Severity

In broadcast_weights, when dp_replicate > 1, the master rank (rank 0) calls self.rendezvous.wait_for_all_peers_ready, which blocks synchronously waiting for the orchestrator, while all other ranks (including non-primary HSDP ranks) proceed directly to dist.barrier(). If the orchestrator is slow to signal, rank 0 blocks before the barrier while the other ranks are already waiting at it. This is fine structurally, but wait_for_all_peers_ready is a synchronous call that busy-polls with time.sleep(0.05), which could delay entry into the barrier on large clusters.
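If the poll loop ever becomes a measurable delay at scale, one mitigation is to let only rank 0 poll MX and fan the result out through the process group, so the other ranks block inside the collective rather than idling at a barrier behind a Python sleep loop. A sketch, not the PR's code (assumes an initialized NCCL process group):

```python
import torch
import torch.distributed as dist

def wait_for_orchestrator_then_sync(rendezvous, timeout: float) -> None:
    # Rank 0 does the synchronous MX poll; everyone else blocks on the
    # broadcast, which then doubles as the barrier. If rank 0 times out,
    # the other ranks fall back to the collective's own timeout.
    ready = torch.zeros(1, dtype=torch.int32, device="cuda")  # NCCL needs a CUDA tensor
    if dist.get_rank() == 0:
        rendezvous.wait_for_all_peers_ready(timeout=timeout)
        ready.fill_(1)
    dist.broadcast(ready, src=0)
    if ready.item() != 1:
        raise RuntimeError("rank 0 did not report peers ready")
```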


Reviewed by Cursor Bugbot for commit 75efd8c.
