Skip to content

Commit 19753e8

Browse files
authored
feat: pretrain dfm automodel (#36)
* init Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add sigma_min/amx Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add sigma_min/max Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * rename fientune.py to train.py Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add from_config Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * pass scheduler and model Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update param Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * introduce NeMoWanPipeline Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add mode Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update build_model_and_optimizer Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update NeMoWanPipeline Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * rename Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * move examples Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * move Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix imports Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * lint Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * more lint Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix import Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix 3rdparty & pyproject Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * add torch Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update uv.lock Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * revert 3rdparty Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update uv.lock Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * fix Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> * update uv.lock Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com> --------- Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 5bb89f5 commit 19753e8

16 files changed

Lines changed: 304 additions & 81 deletions

File tree

3rdparty/Megatron-Bridge

Submodule Megatron-Bridge updated 77 files

dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import copy
1516
import logging
1617
import os
1718
from typing import Any, Dict, Iterable, Optional, Tuple
1819

1920
import torch
2021
import torch.nn as nn
21-
from Automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
22-
from diffusers import DiffusionPipeline
22+
from diffusers import DiffusionPipeline, WanPipeline
2323
from nemo_automodel.components.distributed import parallelizer
2424
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
2525
from nemo_automodel.shared.utils import dtype_from_str
2626

27+
from dfm.src.automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
28+
2729

2830
logger = logging.getLogger(__name__)
2931

@@ -154,3 +156,71 @@ def from_pretrained(
154156
parallel_module = manager.parallelize(comp_module)
155157
setattr(pipe, comp_name, parallel_module)
156158
return pipe, created_managers
159+
160+
161+
class NeMoWanPipeline:
162+
def __init__(self, *args, **kwargs):
163+
super().__init__(*args, **kwargs)
164+
165+
@classmethod
166+
def from_pretrained(cls, *args, **kwargs):
167+
return NeMoAutoDiffusionPipeline.from_pretrained(*args, **kwargs)
168+
169+
@classmethod
170+
def from_config(
171+
cls,
172+
model_id,
173+
torch_dtype: torch.dtype = torch.bfloat16,
174+
config: dict = None,
175+
parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None,
176+
device: Optional[torch.device] = None,
177+
move_to_device: bool = True,
178+
components_to_load: Optional[Iterable[str]] = None,
179+
):
180+
# Load just the config
181+
from diffusers import WanTransformer3DModel
182+
183+
if config is None:
184+
transformer = WanTransformer3DModel.from_pretrained(
185+
model_id,
186+
subfolder="transformer",
187+
torch_dtype=torch.bfloat16,
188+
)
189+
190+
# Get config and reinitialize with random weights
191+
config = copy.deepcopy(transformer.config)
192+
del transformer
193+
194+
# Initialize with random weights
195+
transformer = WanTransformer3DModel.from_config(config)
196+
197+
# Load pipeline with random transformer
198+
pipe = WanPipeline.from_pretrained(
199+
model_id,
200+
transformer=transformer,
201+
torch_dtype=torch_dtype,
202+
)
203+
# Decide device
204+
dev = _choose_device(device)
205+
206+
# Move modules to device/dtype first (helps avoid initial OOM during sharding)
207+
if move_to_device:
208+
for name, module in _iter_pipeline_modules(pipe):
209+
if not components_to_load or name in components_to_load:
210+
logger.info("[INFO] Moving module: %s to device/dtype", name)
211+
_move_module_to_device(module, dev, torch_dtype)
212+
213+
# Use per-component FSDP2Manager init-args to parallelize components
214+
created_managers: Dict[str, FSDP2Manager] = {}
215+
if parallel_scheme is not None:
216+
assert torch.distributed.is_initialized(), "Expect distributed environment to be initialized"
217+
_init_parallelizer()
218+
for comp_name, comp_module in _iter_pipeline_modules(pipe):
219+
manager_args = parallel_scheme.get(comp_name)
220+
if manager_args is None:
221+
continue
222+
manager = FSDP2Manager(**manager_args)
223+
created_managers[comp_name] = manager
224+
parallel_module = manager.parallelize(comp_module)
225+
setattr(pipe, comp_name, parallel_module)
226+
return pipe, created_managers

dfm/src/automodel/datasets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from Automodel.datasets.wan21 import (
15+
from dfm.src.automodel.datasets.wan21 import (
1616
MetaFilesDataset,
1717
build_node_parallel_sampler,
1818
build_wan21_dataloader,

dfm/src/automodel/flow_matching/training_step_t2v.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from typing import Dict, Tuple
2020

2121
import torch
22-
from Automodel.flow_matching.time_shift_utils import (
22+
23+
from dfm.src.automodel.flow_matching.time_shift_utils import (
2324
compute_density_for_timestep_sampling,
2425
)
2526

@@ -28,8 +29,8 @@
2829

2930

3031
def step_fsdp_transformer_t2v(
31-
pipe,
32-
model_map: Dict,
32+
scheduler,
33+
model,
3334
batch,
3435
device,
3536
bf16,
@@ -40,6 +41,8 @@ def step_fsdp_transformer_t2v(
4041
logit_std: float = 1.0,
4142
flow_shift: float = 3.0,
4243
mix_uniform_ratio: float = 0.1,
44+
sigma_min: float = 0.0, # Default: no clamping (pretrain)
45+
sigma_max: float = 1.0, # Default: no clamping (pretrain)
4346
global_step: int = 0,
4447
) -> Tuple[torch.Tensor, Dict]:
4548
"""
@@ -74,7 +77,7 @@ def step_fsdp_transformer_t2v(
7477
# Flow Matching Timestep Sampling
7578
# ========================================================================
7679

77-
num_train_timesteps = pipe.scheduler.config.num_train_timesteps
80+
num_train_timesteps = scheduler.config.num_train_timesteps
7881

7982
if use_sigma_noise:
8083
use_uniform = torch.rand(1).item() < mix_uniform_ratio
@@ -96,12 +99,23 @@ def step_fsdp_transformer_t2v(
9699
# Apply flow shift: σ = shift/(shift + (1/u - 1))
97100
u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero
98101
sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0))
99-
sigma = torch.clamp(sigma, 0.0, 1.0)
102+
103+
# Clamp sigma (only if not full range [0,1])
104+
# Pretrain uses [0, 1], finetune uses [0.02, 0.55]
105+
if sigma_min > 0.0 or sigma_max < 1.0:
106+
sigma = torch.clamp(sigma, sigma_min, sigma_max)
107+
else:
108+
sigma = torch.clamp(sigma, 0.0, 1.0)
100109

101110
else:
102111
# Simple uniform without shift
103112
u = torch.rand(size=(batch_size,), device=device)
104-
sigma = u
113+
114+
# Clamp sigma (only if not full range [0,1])
115+
if sigma_min > 0.0 or sigma_max < 1.0:
116+
sigma = torch.clamp(u, sigma_min, sigma_max)
117+
else:
118+
sigma = u
105119
sampling_method = "uniform_no_shift"
106120

107121
# ========================================================================
@@ -186,10 +200,8 @@ def step_fsdp_transformer_t2v(
186200
# Forward Pass
187201
# ========================================================================
188202

189-
fsdp_model = model_map["transformer"]["fsdp_transformer"]
190-
191203
try:
192-
model_pred = fsdp_model(
204+
model_pred = model(
193205
hidden_states=noisy_latents,
194206
timestep=timesteps_for_model,
195207
encoder_hidden_states=text_embeddings,
@@ -243,7 +255,7 @@ def step_fsdp_transformer_t2v(
243255
logger.info(f"[STEP {global_step}] LOSS DEBUG")
244256
logger.info("=" * 80)
245257
logger.info("[TARGET] Flow matching: v = ε - x_0")
246-
logger.info(f"[PREDICTION] Scheduler type (inference only): {type(pipe.scheduler).__name__}")
258+
logger.info(f"[PREDICTION] Scheduler type (inference only): {type(scheduler).__name__}")
247259
logger.info("")
248260
logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]")
249261
logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]")
Lines changed: 47 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,6 @@
2222
import torch
2323
import torch.distributed as dist
2424
import wandb
25-
from Automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline
26-
from Automodel.flow_matching.training_step_t2v import (
27-
step_fsdp_transformer_t2v,
28-
)
2925
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig
3026
from nemo_automodel.components.loggers.log_utils import setup_logging
3127
from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages
@@ -36,68 +32,71 @@
3632
from torch.distributed.fsdp import MixedPrecisionPolicy
3733
from transformers.utils.hub import TRANSFORMERS_CACHE
3834

35+
from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline
36+
from dfm.src.automodel.flow_matching.training_step_t2v import (
37+
step_fsdp_transformer_t2v,
38+
)
39+
3940

4041
def build_model_and_optimizer(
4142
*,
4243
model_id: str,
44+
finetune_mode: bool,
4345
learning_rate: float,
4446
device: torch.device,
45-
bf16_dtype: torch.dtype,
47+
dtype: torch.dtype,
4648
cpu_offload: bool = False,
47-
tp_size: int = 1,
48-
cp_size: int = 1,
49-
pp_size: int = 1,
50-
dp_size: Optional[int] = None,
51-
dp_replicate_size: Optional[int] = None,
52-
use_hf_tp_plan: bool = False,
49+
fsdp_cfg: Dict[str, Any] = {},
5350
optimizer_cfg: Optional[Dict[str, Any]] = None,
54-
) -> tuple[NeMoAutoDiffusionPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
51+
) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]:
5552
"""Build the WAN 2.1 diffusion model, parallel scheme, and optimizer."""
5653

57-
logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...")
54+
logging.info("[INFO] Building NeMoWanPipeline with transformer parallel scheme...")
5855

5956
if not dist.is_initialized():
6057
logging.info("[WARN] torch.distributed not initialized; proceeding in single-process mode")
6158

6259
world_size = dist.get_world_size() if dist.is_initialized() else 1
6360

64-
if dp_size is None:
65-
denom = max(1, tp_size * cp_size * pp_size)
66-
dp_size = max(1, world_size // denom)
61+
if fsdp_cfg.get("dp_size", None) is None:
62+
denom = max(1, fsdp_cfg.get("tp_size", 1) * fsdp_cfg.get("cp_size", 1) * fsdp_cfg.get("pp_size", 1))
63+
fsdp_cfg.dp_size = max(1, world_size // denom)
6764

6865
manager_args: Dict[str, Any] = {
69-
"dp_size": dp_size,
70-
"dp_replicate_size": dp_replicate_size,
71-
"tp_size": tp_size,
72-
"cp_size": cp_size,
73-
"pp_size": pp_size,
66+
"dp_size": fsdp_cfg.get("dp_size", None),
67+
"dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None),
68+
"tp_size": fsdp_cfg.get("tp_size", 1),
69+
"cp_size": fsdp_cfg.get("cp_size", 1),
70+
"pp_size": fsdp_cfg.get("pp_size", 1),
7471
"backend": "nccl",
7572
"world_size": world_size,
76-
"use_hf_tp_plan": use_hf_tp_plan,
73+
"use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False),
7774
"activation_checkpointing": True,
7875
"mp_policy": MixedPrecisionPolicy(
79-
param_dtype=bf16_dtype,
80-
reduce_dtype=bf16_dtype,
81-
output_dtype=bf16_dtype,
76+
param_dtype=dtype,
77+
reduce_dtype=dtype,
78+
output_dtype=dtype,
8279
),
8380
}
8481

8582
parallel_scheme = {"transformer": manager_args}
8683

87-
pipe, created_managers = NeMoAutoDiffusionPipeline.from_pretrained(
84+
kwargs = {}
85+
if finetune_mode:
86+
kwargs["load_for_training"] = True
87+
kwargs["low_cpu_mem_usage"] = True
88+
init_fn = NeMoWanPipeline.from_pretrained if finetune_mode else NeMoWanPipeline.from_config
89+
90+
pipe, created_managers = init_fn(
8891
model_id,
89-
torch_dtype=bf16_dtype,
92+
torch_dtype=dtype,
9093
device=device,
9194
parallel_scheme=parallel_scheme,
92-
load_for_training=True,
9395
components_to_load=["transformer"],
96+
**kwargs,
9497
)
9598
fsdp2_manager = created_managers["transformer"]
96-
transformer_module = getattr(pipe, "transformer", None)
97-
if transformer_module is None:
98-
raise RuntimeError("transformer not found in pipeline after parallelization")
99-
100-
model_map: dict[str, Dict[str, Any]] = {"transformer": {"fsdp_transformer": transformer_module}}
99+
transformer_module = pipe.transformer
101100

102101
trainable_params = [p for p in transformer_module.parameters() if p.requires_grad]
103102
if not trainable_params:
@@ -121,7 +120,7 @@ def build_model_and_optimizer(
121120

122121
logging.info("[INFO] NeMoAutoDiffusion setup complete (pipeline + optimizer)")
123122

124-
return pipe, model_map, optimizer, fsdp2_manager.device_mesh
123+
return pipe, optimizer, getattr(fsdp2_manager, "device_mesh", None)
125124

126125

127126
def build_lr_scheduler(
@@ -198,36 +197,27 @@ def setup(self):
198197
self.logit_std = fm_cfg.get("logit_std", 1.0)
199198
self.flow_shift = fm_cfg.get("flow_shift", 3.0)
200199
self.mix_uniform_ratio = fm_cfg.get("mix_uniform_ratio", 0.1)
200+
self.sigma_min = fm_cfg.get("sigma_min", 0.0)
201+
self.sigma_max = fm_cfg.get("sigma_max", 1.0)
201202

202203
logging.info(f"[INFO] Flow matching: {'ENABLED' if self.use_sigma_noise else 'DISABLED'}")
203204
if self.use_sigma_noise:
204205
logging.info(f"[INFO] - Timestep sampling: {self.timestep_sampling}")
205206
logging.info(f"[INFO] - Flow shift: {self.flow_shift}")
206207
logging.info(f"[INFO] - Mix uniform ratio: {self.mix_uniform_ratio}")
207208

208-
tp_size = fsdp_cfg.get("tp_size", 1)
209-
cp_size = fsdp_cfg.get("cp_size", 1)
210-
pp_size = fsdp_cfg.get("pp_size", 1)
211-
dp_size = fsdp_cfg.get("dp_size", None)
212-
dp_replicate_size = fsdp_cfg.get("dp_replicate_size", None)
213-
use_hf_tp_plan = fsdp_cfg.get("use_hf_tp_plan", False)
214-
215-
(self.pipe, self.model_map, self.optimizer, self.device_mesh) = build_model_and_optimizer(
209+
(self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer(
216210
model_id=self.model_id,
211+
finetune_mode=self.cfg.get("model.mode", "finetune").lower() == "finetune",
217212
learning_rate=self.learning_rate,
218213
device=self.device,
219-
bf16_dtype=self.bf16,
214+
dtype=self.bf16,
220215
cpu_offload=self.cpu_offload,
221-
tp_size=tp_size,
222-
cp_size=cp_size,
223-
pp_size=pp_size,
224-
dp_size=dp_size,
225-
dp_replicate_size=dp_replicate_size,
226-
use_hf_tp_plan=use_hf_tp_plan,
216+
fsdp_cfg=fsdp_cfg,
227217
optimizer_cfg=self.cfg.get("optim.optimizer", {}),
228218
)
229219

230-
self.model = self.model_map["transformer"]["fsdp_transformer"]
220+
self.model = self.pipe.transformer
231221
self.peft_config = None
232222

233223
batch_cfg = self.cfg.get("batch", {})
@@ -283,6 +273,9 @@ def setup(self):
283273
raise RuntimeError("Training dataloader is empty; cannot proceed with training")
284274

285275
# Derive DP size consistent with model parallel config
276+
tp_size = fsdp_cfg.get("tp_size", 1)
277+
cp_size = fsdp_cfg.get("cp_size", 1)
278+
pp_size = fsdp_cfg.get("pp_size", 1)
286279
denom = max(1, tp_size * cp_size * pp_size)
287280
self.dp_size = fsdp_cfg.get("dp_size", None)
288281
if self.dp_size is None:
@@ -356,8 +349,8 @@ def run_train_validation_loop(self):
356349
for micro_batch in batch_group:
357350
try:
358351
loss, _ = step_fsdp_transformer_t2v(
359-
pipe=self.pipe,
360-
model_map=self.model_map,
352+
scheduler=self.pipe.scheduler,
353+
model=self.model,
361354
batch=micro_batch,
362355
device=self.device,
363356
bf16=self.bf16,
@@ -367,6 +360,8 @@ def run_train_validation_loop(self):
367360
logit_std=self.logit_std,
368361
flow_shift=self.flow_shift,
369362
mix_uniform_ratio=self.mix_uniform_ratio,
363+
sigma_min=self.sigma_min,
364+
sigma_max=self.sigma_max,
370365
global_step=global_step,
371366
)
372367
except Exception as exc:

dfm/examples/automodel/finetune/finetune.py renamed to examples/automodel/finetune/finetune.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
from __future__ import annotations
1616

17-
from Automodel.recipes.finetune import TrainWan21DiffusionRecipe
1817
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
1918

19+
from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe
20+
2021

2122
def main(default_config_path="/opt/DFM/dfm/examples/Automodel/finetune/wan2_1_t2v_flow.yaml"):
2223
cfg = parse_args_and_load_config(default_config_path)

0 commit comments

Comments
 (0)