Commit 0663edf
fix(infra): keep model.to(device) on unsharded post-shard load
Persistent buffers initialized via torch.tensor()/torch.ones() inside init_empty_weights() (e.g. Gemma4's Gemma4ClippableLinear input_min/max, Gemma4TextDecoderLayer layer_scalar) stay on CPU, because the context only patches register_parameter, not register_buffer. The post-shard load path then unconditionally skipped model.to(device), leaving these buffers stranded on CPU and tripping a cuda:0-vs-cpu device-mismatch error in torch.clamp.
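A minimal torch-free sketch of the gap described above: a stand-in init context that reroutes register_parameter to the meta device but leaves register_buffer untouched, so eagerly created buffers keep their CPU placement. All names here are illustrative, not the real init_empty_weights implementation.

```python
from contextlib import contextmanager

class FakeModule:
    """Stand-in for nn.Module tracking where params/buffers land."""
    def __init__(self):
        self.params = {}
        self.bufs = {}
    def register_parameter(self, name, device):
        self.params[name] = device
    def register_buffer(self, name, device):
        self.bufs[name] = device

@contextmanager
def init_empty_weights_like(cls):
    """Patch ONLY register_parameter, mirroring the gap the commit fixes."""
    orig = cls.register_parameter
    cls.register_parameter = lambda self, name, dev: orig(self, name, "meta")
    try:
        yield
    finally:
        cls.register_parameter = orig

with init_empty_weights_like(FakeModule):
    m = FakeModule()
    m.register_parameter("weight", "cpu")   # rerouted -> lands on "meta"
    m.register_buffer("input_min", "cpu")   # NOT patched -> stays on "cpu"

# Parameters get materialized/sharded later, but the CPU buffer is only
# rescued by model.to(device) -- the call the old code path skipped.
```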
The skip exists for FSDP's reset_sharded_param issue with tied params under TP>1 (pytorch/pytorch#151085). Narrow it to its actual precondition: any DTensor in the model, so single-GPU, DDP, and other unsharded configs still run model.to(device). Add unit coverage for both the unsharded and DTensor-sharded checkpoint load paths.
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
1 parent 8222a4f · commit 0663edf
2 files changed
Lines changed: 52 additions & 9 deletions
File tree
- nemo_automodel/_transformers
- tests/unit_tests/_transformers
[Diff for nemo_automodel/_transformers (hunk around original lines 565–575, now 565–584; 13 additions, 4 deletions): the rendered diff table did not survive extraction, so the changed code is not recoverable here.]
[Diff for tests/unit_tests/_transformers (hunks around original lines 190–197 and 217–225, now 190–197 and 217–259; 39 additions, 5 deletions): the rendered diff table did not survive extraction, so the new test code is not recoverable here.]