File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 8383from tzrec .utils .fx_util import symbolic_trace
8484from tzrec .utils .logging_util import ProgressLogger , logger
8585from tzrec .utils .plan_util import create_planner , get_default_sharders
86- from tzrec .utils .state_dict_util import init_parameters , validate_state
86+ from tzrec .utils .state_dict_util import init_parameters
8787from tzrec .version import __version__ as tzrec_version
8888
8989
@@ -778,7 +778,9 @@ def _script_model(
778778 model .load_state_dict (state_dict , strict = False )
779779
780780 # for mc modules, we should validate and sort mch buffers
781- validate_state (model )
781+ # but we do not gather state dict now, validate_state already
782+ # in post_load_state_dict hook
783+ # validate_state(model)
782784
783785 batch = next (iter (dataloader ))
784786
Original file line number Diff line number Diff line change @@ -24,7 +24,6 @@ def validate_state(model: nn.Module) -> None:
2424 if validate_log_flag :
2525 logger .info ("validate states..." )
2626 validate_log_flag = False
27- m .validate_state ()
2827 if isinstance (m , MCHManagedCollisionModule ):
2928 # fix output_segments_tensor is a meta tensor.
3029 output_segments = [
@@ -36,6 +35,7 @@ def validate_state(model: nn.Module) -> None:
3635 dtype = torch .int64 ,
3736 device = m ._current_iter_tensor .device ,
3837 )
38+ m .validate_state ()
3939
4040
4141def init_parameters (module : nn .Module , device : torch .device ) -> None :
You can’t perform that action at this time.
0 commit comments