Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 35 additions & 54 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
# NOQA
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist
from torchrec.inference.modules import quantize_embeddings
from torchrec.inference.state_dict_transform import (
state_dict_to_device,
)
from torchrec.optim.apply_optimizer_in_backward import (
apply_optimizer_in_backward, # NOQA
)
Expand Down Expand Up @@ -70,6 +67,7 @@
from tzrec.models.model import BaseModel, CudaExportWrapper, ScriptWrapper, TrainWrapper
from tzrec.models.tdm import TDM, TDMEmbedding
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.utils import BaseModule
from tzrec.ops import Kernel
from tzrec.optim import optimizer_builder
from tzrec.optim.lr_scheduler import BaseLR
Expand All @@ -85,7 +83,7 @@
from tzrec.utils.fx_util import symbolic_trace
from tzrec.utils.logging_util import ProgressLogger, logger
from tzrec.utils.plan_util import create_planner, get_default_sharders
from tzrec.utils.state_dict_util import state_dict_gather, validate_state
from tzrec.utils.state_dict_util import fix_mch_state, init_parameters
from tzrec.version import __version__ as tzrec_version


Expand Down Expand Up @@ -764,8 +762,8 @@ def evaluate(

def _script_model(
pipeline_config: EasyRecConfig,
model: nn.Module,
state_dict: Dict[str, Any],
model: BaseModule,
state_dict: Optional[Dict[str, Any]],
dataloader: DataLoader,
save_dir: str,
) -> None:
Expand All @@ -774,15 +772,13 @@ def _script_model(
if is_rank_zero:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
model = model.to_empty(device="cpu")
logger.info("gather states to cpu model...")
model.set_is_inference(True)
if state_dict is not None:
model.to_empty(device="cpu")
model.load_state_dict(state_dict, strict=False)

state_dict_gather(state_dict, model.state_dict())
dist.barrier()

if is_rank_zero:
# for mc modules, we should validate and sort mch buffers
validate_state(model)
# for mc modules, fix output_segments_tensor is a meta tensor.
fix_mch_state(model)

batch = next(iter(dataloader))

Expand Down Expand Up @@ -851,11 +847,21 @@ def export(
model specified by model_dir in pipeline_config_path.
asset_files (str, optional): more files will be copied to export_dir.
"""
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
if not is_rank_zero:
logger.warning("Only first rank will be used for export now.")
return
else:
if os.environ.get("WORLD_SIZE") != "1":
logger.warning(
"export only support WORLD_SIZE=1 now, we set WORLD_SIZE to 1."
)
os.environ["WORLD_SIZE"] = "1"

pipeline_config = config_util.load_pipeline_config(pipeline_config_path)
ori_pipeline_config = copy.copy(pipeline_config)

device, _ = init_process_group()
is_rank_zero = int(os.environ.get("RANK", 0)) == 0
dist.init_process_group("gloo")
if is_rank_zero:
if os.path.exists(export_dir):
raise RuntimeError(f"directory {export_dir} already exist.")
Expand Down Expand Up @@ -887,19 +893,9 @@ def export(
features,
list(data_config.label_fields),
)
model = ScriptWrapper(model)

planner = create_planner(
device=device,
batch_size=data_config.batch_size,
)
plan = planner.collective_plan(
model, get_default_sharders(), dist.GroupMember.WORLD
)
if is_rank_zero:
logger.info(str(plan))

model = DistributedModelParallel(module=model, device=device, plan=plan)
InferWrapper = CudaExportWrapper if is_aot() else ScriptWrapper
model = InferWrapper(model)
init_parameters(model, torch.device("cpu"))

if not checkpoint_path:
checkpoint_path, _ = checkpoint_util.latest_checkpoint(
Expand All @@ -922,23 +918,8 @@ def export(
else:
raise ValueError("checkpoint path should be specified.")

checkpoint_pg = dist.new_group(backend="gloo")
if is_rank_zero:
logger.info("copy sharded state_dict to cpu...")
cpu_state_dict = state_dict_to_device(
model.state_dict(), pg=checkpoint_pg, device=torch.device("cpu")
)

cpu_model = _create_model(
pipeline_config.model_config,
features,
list(data_config.label_fields),
)
cpu_model.set_is_inference(True)

InferWrapper = CudaExportWrapper if is_aot() else ScriptWrapper
if isinstance(cpu_model, MatchModel):
for name, module in cpu_model.named_children():
if isinstance(model.model, MatchModel):
for name, module in model.model.named_children():
if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):
wrapper = (
TowerWrapper if isinstance(module, MatchTower) else TowerWoEGWrapper
Expand All @@ -948,28 +929,28 @@ def export(
_script_model(
ori_pipeline_config,
tower,
cpu_state_dict,
model.state_dict(),
dataloader,
tower_export_dir,
)
for asset in assets:
shutil.copy(asset, tower_export_dir)
elif isinstance(cpu_model, TDM):
for name, module in cpu_model.named_children():
elif isinstance(model.model, TDM):
for name, module in model.model.named_children():
if isinstance(module, EmbeddingGroup):
emb_module = InferWrapper(TDMEmbedding(module, name))
_script_model(
ori_pipeline_config,
emb_module,
cpu_state_dict,
model.state_dict(),
dataloader,
os.path.join(export_dir, "embedding"),
)
break
_script_model(
ori_pipeline_config,
InferWrapper(cpu_model),
cpu_state_dict,
model,
None,
dataloader,
os.path.join(export_dir, "model"),
)
Expand All @@ -978,8 +959,8 @@ def export(
else:
_script_model(
ori_pipeline_config,
InferWrapper(cpu_model),
cpu_state_dict,
model,
None,
dataloader,
export_dir,
)
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.dat import DAT
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DATTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dbmtl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import multi_task_rank_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DBMTLTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dc2vr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import multi_task_rank_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DC2VRTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/deepfm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.deepfm import DeepFM
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, seq_encoder_pb2
from tzrec.protos.models import rank_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DeepFMTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dlrm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.dlrm import DLRM
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2
from tzrec.protos.models import rank_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DLRMTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dssm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.dssm import DSSM
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DSSMTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/dssm_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.dssm_v2 import DSSMV2
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class DSSMV2Test(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/hstu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
tower_pb2,
)
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class HSTUTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/mind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.mind import MIND
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class MINDTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/mmoe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import multi_task_rank_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class MMoETest(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tzrec/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_features_in_feature_groups(
TRAIN_FWD_TYPE = Tuple[torch.Tensor, TRAIN_OUT_TYPE]


class TrainWrapper(nn.Module):
class TrainWrapper(BaseModule):
"""Model train wrapper for pipeline."""

def __init__(self, module: nn.Module) -> None:
Expand Down Expand Up @@ -214,7 +214,7 @@ def forward(self, batch: Batch) -> TRAIN_FWD_TYPE:
return total_loss, (losses, predictions, batch)


class ScriptWrapper(nn.Module):
class ScriptWrapper(BaseModule):
"""Model inference wrapper for jit.script."""

def __init__(self, module: nn.Module) -> None:
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/multi_tower_din_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import rank_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class MultiTowerDINTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/multi_tower_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import rank_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class MultiTowerTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/ple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
tower_pb2,
)
from tzrec.protos.models import multi_task_rank_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class PLETest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/rocket_launching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
seq_encoder_pb2,
)
from tzrec.protos.models import general_rank_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class RocketLaunchingTest(unittest.TestCase):
Expand Down
3 changes: 2 additions & 1 deletion tzrec/models/tdm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from tzrec.models.tdm import TDM
from tzrec.protos import feature_pb2, loss_pb2, model_pb2, module_pb2, tower_pb2
from tzrec.protos.models import match_model_pb2
from tzrec.utils.test_util import TestGraphType, create_test_model, init_parameters
from tzrec.utils.state_dict_util import init_parameters
from tzrec.utils.test_util import TestGraphType, create_test_model


class TDMTest(unittest.TestCase):
Expand Down
Loading