@@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # calculate number of trainable parameters
     n_params = 0
-    for params in params_to_optimize:
-        for p in params["params"]:
+    for group in params_to_optimize:
+        for p in group["params"]:
             n_params += p.numel()
 
     accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}")
@@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     # prepare the classes required for training
     accelerator.print("prepare optimizer, data loader etc.")
-    _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
+
+    if args.fused_optimizer_groups:
+        # calculate total number of parameters
+        n_total_params = sum(len(params["params"]) for params in params_to_optimize)
+        params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups)
+
+        # split params into groups
+        grouped_params = []
+        param_group = []
+        param_group_lr = -1
+        for group in params_to_optimize:
+            lr = group["lr"]
+            for p in group["params"]:
+                if lr != param_group_lr:
+                    if param_group:
+                        grouped_params.append({"params": param_group, "lr": param_group_lr})
+                        param_group = []
+                    param_group_lr = lr
+                param_group.append(p)
+                if len(param_group) == params_per_group:
+                    grouped_params.append({"params": param_group, "lr": param_group_lr})
+                    param_group = []
+                    param_group_lr = -1
+        if param_group:
+            grouped_params.append({"params": param_group, "lr": param_group_lr})
+
+        # prepare optimizers for each group
+        optimizers = []
+        for group in grouped_params:
+            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
+            optimizers.append(optimizer)
+        optimizer = optimizers[0]  # avoid error in the following code
+
+        print(len(grouped_params))
+        logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups")
+
+    else:
+        _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
 
     # prepare the dataloader
     # note: DataLoader with num_workers=0 cannot use persistent_workers
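
For reference, the grouping logic above packs parameters into buckets of roughly `ceil(total / fused_optimizer_groups)` each, flushing a bucket early whenever the learning rate changes so that no group ever mixes learning rates, then builds one optimizer per group. A minimal, self-contained sketch of the same splitting idea (a hypothetical `split_params_into_groups` helper; `torch.optim.AdamW` stands in for whatever `train_util.get_optimizer` actually returns):

```python
import math
import torch

def split_params_into_groups(params_to_optimize, num_groups):
    """Split [{"params": [...], "lr": ...}, ...] into ~equal-sized groups,
    never mixing parameters with different learning rates in one group."""
    n_total = sum(len(g["params"]) for g in params_to_optimize)
    per_group = math.ceil(n_total / num_groups)

    grouped, bucket, bucket_lr = [], [], -1
    for g in params_to_optimize:
        lr = g["lr"]
        for p in g["params"]:
            if lr != bucket_lr:  # lr boundary: flush the current bucket
                if bucket:
                    grouped.append({"params": bucket, "lr": bucket_lr})
                    bucket = []
                bucket_lr = lr
            bucket.append(p)
            if len(bucket) == per_group:  # size boundary: flush as well
                grouped.append({"params": bucket, "lr": bucket_lr})
                bucket, bucket_lr = [], -1
    if bucket:
        grouped.append({"params": bucket, "lr": bucket_lr})
    return grouped

# toy usage: 5 parameters at two learning rates, split into 3 groups
params = [
    {"params": [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)], "lr": 1e-4},
    {"params": [torch.nn.Parameter(torch.zeros(2)) for _ in range(2)], "lr": 1e-5},
]
groups = split_params_into_groups(params, num_groups=3)
optimizers = [torch.optim.AdamW([g], lr=g["lr"]) for g in groups]
print([len(g["params"]) for g in groups], len(optimizers))  # [2, 1, 2] 3
```

Keeping `optimizer = optimizers[0]` mirrors the diff: downstream code written for a single `optimizer` variable still has one to reference.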
@@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     train_dataset_group.set_max_train_steps(args.max_train_steps)
 
     # prepare the lr scheduler
-    lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
+    if args.fused_optimizer_groups:
+        lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers]
+        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
+    else:
+        lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # experimental feature: perform fp16/bf16 training including gradients; cast the entire model to fp16/bf16
     if args.full_fp16:
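
A PyTorch LR scheduler is bound to exactly one optimizer, which is why the fused path builds one scheduler per optimizer and keeps `lr_schedulers[0]` under the old `lr_scheduler` name for code that still assumes a single scheduler. A toy sketch of the pattern (stock `ConstantLR` standing in for `train_util.get_scheduler_fix`):

```python
import torch

params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(3)]
optimizers = [torch.optim.AdamW([p], lr=1e-4) for p in params]

# one scheduler per optimizer, since a scheduler steps only the optimizer it wraps
lr_schedulers = [
    torch.optim.lr_scheduler.ConstantLR(opt, factor=0.5, total_iters=10)
    for opt in optimizers
]
lr_scheduler = lr_schedulers[0]  # keep the single-optimizer name alive, as in the diff
```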
@@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
     if args.fused_backward_pass:
         import library.adafactor_fused
+
         library.adafactor_fused.patch_adafactor_fused(optimizer)
         for param_group in optimizer.param_groups:
             for parameter in param_group["params"]:
                 if parameter.requires_grad:
+
                     def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                         if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                             accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
@@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     parameter.register_post_accumulate_grad_hook(__grad_hook)
 
+    elif args.fused_optimizer_groups:
+        for i in range(1, len(optimizers)):
+            optimizers[i] = accelerator.prepare(optimizers[i])
+            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])
+
+        global optimizer_hooked_count
+        global num_parameters_per_group
+        global parameter_optimizer_map
+        optimizer_hooked_count = {}
+        num_parameters_per_group = [0] * len(optimizers)
+        parameter_optimizer_map = {}
+        for opt_idx, optimizer in enumerate(optimizers):
+            for param_group in optimizer.param_groups:
+                for parameter in param_group["params"]:
+                    if parameter.requires_grad:
+
+                        def optimizer_hook(parameter: torch.Tensor):
+                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
+                                accelerator.clip_grad_norm_(parameter, args.max_grad_norm)
+
+                            i = parameter_optimizer_map[parameter]
+                            optimizer_hooked_count[i] += 1
+                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
+                                optimizers[i].step()
+                                optimizers[i].zero_grad()
+
+                        parameter.register_post_accumulate_grad_hook(optimizer_hook)
+                        parameter_optimizer_map[parameter] = opt_idx
+                        num_parameters_per_group[opt_idx] += 1
+
     # move Text Encoders to CPU when caching their outputs
     if args.cache_text_encoder_outputs:
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
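
The mechanism both branches rely on is `Tensor.register_post_accumulate_grad_hook` (PyTorch 2.1+), which fires after a parameter's gradient has finished accumulating during `backward()`. The fused-groups branch counts hook firings per group and steps that group's optimizer the moment its last gradient arrives, so each group's gradients can be consumed and zeroed while the backward pass is still running. A minimal runnable sketch of the counting scheme (toy shapes and `SGD`; the real code uses the optimizers built above):

```python
import torch

# two single-parameter "groups", one optimizer per group
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
optimizers = [torch.optim.SGD([p], lr=0.1) for p in params]

parameter_optimizer_map = {p: i for i, p in enumerate(params)}
num_parameters_per_group = [1, 1]
optimizer_hooked_count = {}

def optimizer_hook(parameter: torch.Tensor):
    # runs inside backward(), once this parameter's grad is fully accumulated
    i = parameter_optimizer_map[parameter]
    optimizer_hooked_count[i] += 1
    if optimizer_hooked_count[i] == num_parameters_per_group[i]:
        optimizers[i].step()       # step as soon as this group is complete
        optimizers[i].zero_grad()  # free the group's gradients immediately

for p in params:
    p.register_post_accumulate_grad_hook(optimizer_hook)

# one training step
optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset per step
loss = sum((p * p).sum() for p in params)
loss.backward()  # optimizers step inside backward; no explicit step() afterwards
```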
@@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
         for step, batch in enumerate(train_dataloader):
             current_step.value = global_step
+
+            if args.fused_optimizer_groups:
+                optimizer_hooked_count = {i: 0 for i in range(len(optimizers))}  # reset counter for each step
+
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
                     latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
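
Note the counter reset at the top of every step: the hooks compare `optimizer_hooked_count[i]` against `num_parameters_per_group[i]`, so a stale count from the previous step would make a group's optimizer step too early or not at all. Because the counter was declared `global` where the hooks were defined, rebinding it here is visible inside every hook on the next `backward()`.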
@@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
 
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
 
@@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -630,21 +711,28 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
 
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
 
                 accelerator.backward(loss)
 
-                if not args.fused_backward_pass:
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
                     if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                         params_to_clip = []
                         for m in training_models:
                             params_to_clip.extend(m.parameters())
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                     optimizer.step()
+                elif args.fused_optimizer_groups:  # the hooks have already stepped and zeroed each group's optimizer
+                    for i in range(1, len(optimizers)):
+                        lr_schedulers[i].step()
 
                 lr_scheduler.step()
-                optimizer.zero_grad(set_to_none=True)
+
+                if not (args.fused_backward_pass or args.fused_optimizer_groups):
+                    optimizer.zero_grad(set_to_none=True)
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
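
In the fused-groups branch, each group's `optimizer.step()` and `zero_grad()` already happened inside `accelerator.backward(loss)` via the hooks, so only the schedulers remain: schedulers 1..N-1 are stepped in the `elif`, and scheduler 0 by the shared `lr_scheduler.step()` that follows. The trailing `zero_grad(set_to_none=True)` is likewise skipped for both fused paths, since the hooks have already cleared each gradient as it was consumed.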
@@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group):
753841
754842 accelerator .end_training ()
755843
756- if args .save_state or args .save_state_on_train_end :
844+ if args .save_state or args .save_state_on_train_end :
757845 train_util .save_state_on_train_end (args , accelerator )
758846
759847 del accelerator # この後メモリを使うのでこれは消す
@@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser:
         help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / "
         + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値",
     )
+    parser.add_argument(
+        "--fused_optimizer_groups",
+        type=int,
+        default=None,
+        help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
+    )
     return parser
 
 
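Since the flag defaults to `None`, every `if args.fused_optimizer_groups:` check in the diff stays false unless the flag is passed; any positive integer (even 1) enables the fused path with that many optimizer groups. A quick illustration of the argparse semantics (standalone parser mirroring the definition above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fused_optimizer_groups", type=int, default=None)

args = parser.parse_args([])
assert not args.fused_optimizer_groups  # omitted -> None, fused path disabled

args = parser.parse_args(["--fused_optimizer_groups", "10"])
assert args.fused_optimizer_groups == 10  # enabled, with 10 optimizer groups
```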