Skip to content

Commit ed46bf2

Browse files
fix zch export
1 parent 1efd8c7 commit ed46bf2

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

tzrec/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
from tzrec.utils.fx_util import symbolic_trace
8484
from tzrec.utils.logging_util import ProgressLogger, logger
8585
from tzrec.utils.plan_util import create_planner, get_default_sharders
86-
from tzrec.utils.state_dict_util import init_parameters, validate_state
86+
from tzrec.utils.state_dict_util import init_parameters
8787
from tzrec.version import __version__ as tzrec_version
8888

8989

@@ -778,7 +778,9 @@ def _script_model(
778778
model.load_state_dict(state_dict, strict=False)
779779

780780
# for mc modules, we should validate and sort mch buffers
781-
validate_state(model)
781+
# but we do not gather state dict now, validate_state already
782+
# in post_load_state_dict hook
783+
# validate_state(model)
782784

783785
batch = next(iter(dataloader))
784786

tzrec/utils/state_dict_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def validate_state(model: nn.Module) -> None:
2424
if validate_log_flag:
2525
logger.info("validate states...")
2626
validate_log_flag = False
27-
m.validate_state()
2827
if isinstance(m, MCHManagedCollisionModule):
2928
# fix output_segments_tensor is a meta tensor.
3029
output_segments = [
@@ -36,6 +35,7 @@ def validate_state(model: nn.Module) -> None:
3635
dtype=torch.int64,
3736
device=m._current_iter_tensor.device,
3837
)
38+
m.validate_state()
3939

4040

4141
def init_parameters(module: nn.Module, device: torch.device) -> None:

0 commit comments

Comments
 (0)