Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pyre_configuration
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"tzrec/*/*/*_test.py",
"tzrec/tests/*.py",
"tzrec/utils/load_class.py",
"tzrec/utils/filesystem_util.py",
"tzrec/tools/convert_easyrec_config_to_tzrec_config.py",
"tzrec/ops/triton/*.py",
],
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ alibabacloud_credentials==0.3.6
anytree
common_io @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/common_io-0.4.1%2Btunnel-py2.py3-none-any.whl
fbgemm-gpu==1.3.0
fsspec
graphlearn @ https://tzrec.oss-cn-beijing.aliyuncs.com/third_party/graphlearn-1.3.7-cp312-cp312-linux_x86_64.whl ; python_version=="3.12"
graphlearn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/graphlearn-1.3.7-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
graphlearn @ https://tzrec.oss-accelerate.aliyuncs.com/third_party/graphlearn-1.3.7-cp310-cp310-linux_x86_64.whl ; python_version=="3.10"
Expand Down
7 changes: 7 additions & 0 deletions tzrec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,10 @@
format="[%(asctime)s][%(levelname)s] %(message)s", level=_log_level
)
_load_class.auto_import()


from tzrec.utils.filesystem_util import ( # NOQA
register_external_filesystem as _register_external_filesystem, # NOQA
) # NOQA

_register_external_filesystem()
17 changes: 14 additions & 3 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
init_process_group,
)
from tzrec.utils.export_util import export_model
from tzrec.utils.filesystem_util import url_to_fs
from tzrec.utils.logging_util import ProgressLogger, logger
from tzrec.utils.plan_util import create_planner, get_default_sharders
from tzrec.version import __version__ as tzrec_version
Expand Down Expand Up @@ -922,6 +923,19 @@ def predict(
if output_columns is not None:
output_cols = [x.strip() for x in output_columns.split(",")]

device_and_backend = init_process_group()
device: torch.device = device_and_backend[0]

fs, local_path = url_to_fs(scripted_model_path)
if fs is not None:
# scripted model use io in cpp, so that we can not path to fsspec
local_path = os.environ.get("LOCAL_CACHE_DIR", local_path)
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
logger.info(f"downloading {scripted_model_path} to {local_path}.")
fs.download(scripted_model_path, local_path, recursive=True)
dist.barrier()
scripted_model_path = local_path

pipeline_config = config_util.load_pipeline_config(
os.path.join(scripted_model_path, "pipeline.config"), allow_unknown_field=True
)
Expand Down Expand Up @@ -953,9 +967,6 @@ def predict(
edit_config_json = json.loads(edit_config_json)
config_util.edit_config(pipeline_config, edit_config_json)

device_and_backend = init_process_group()
device: torch.device = device_and_backend[0]

is_rank_zero = int(os.environ.get("RANK", 0)) == 0
is_local_rank_zero = int(os.environ.get("LOCAL_RANK", 0)) == 0

Expand Down
7 changes: 6 additions & 1 deletion tzrec/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,17 +1072,19 @@ def test_eval(
def test_export(
pipeline_config_path: str,
test_dir: str,
export_dir: str = "",
asset_files: str = "",
env_str: str = "",
) -> bool:
"""Run export integration test."""
log_dir = os.path.join(test_dir, "log_export")
export_dir = export_dir or f"{test_dir}/export"
cmd_str = (
f"PYTHONPATH=. torchrun {_standalone()} "
f"--nnodes=1 --nproc-per-node=2 --log_dir {log_dir} "
"-r 3 -t 3 tzrec/export.py "
f"--pipeline_config_path {pipeline_config_path} "
f"--export_dir {test_dir}/export "
f"--export_dir {export_dir} "
)
if env_str:
cmd_str = f"{env_str} {cmd_str}"
Expand Down Expand Up @@ -1121,6 +1123,7 @@ def test_predict(
test_dir: str,
predict_threads: Optional[int] = None,
predict_steps: Optional[int] = None,
env_str: str = "",
) -> bool:
"""Run predict integration test."""
log_dir = os.path.join(test_dir, "log_predict")
Expand All @@ -1146,6 +1149,8 @@ def test_predict(
cmd_str += f"--predict_threads {predict_threads} "
if predict_steps is not None:
cmd_str += f"--predict_steps {predict_steps} "
if env_str:
cmd_str = f"{env_str} {cmd_str}"

return misc_util.run_cmd(
cmd_str, os.path.join(test_dir, "log_predict.txt"), timeout=600
Expand Down
4 changes: 3 additions & 1 deletion tzrec/utils/checkpoint_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def latest_checkpoint(model_dir: str) -> Tuple[Optional[str], int]:
latest_step: step of the latest checkpoint
"""
if "model.ckpt-" not in model_dir:
ckpt_metas = glob.glob(os.path.join(model_dir, "model.ckpt-*"))
# fsspec glob need endswith os.path.sep
ckpt_metas = glob.glob(os.path.join(model_dir, "model.ckpt-*" + os.path.sep))
ckpt_metas = list(map(lambda x: x.rstrip(os.path.sep), ckpt_metas))
if len(ckpt_metas) == 0:
model_ckpt_dir = os.path.join(model_dir, "model")
optim_ckpt_dir = os.path.join(model_dir, "optimizer")
Expand Down
13 changes: 11 additions & 2 deletions tzrec/utils/export_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.utils import checkpoint_util, config_util
from tzrec.utils.dist_util import DistributedModelParallel, init_process_group
from tzrec.utils.filesystem_util import url_to_fs
from tzrec.utils.fx_util import (
fx_mark_keyed_tensor,
fx_mark_seq_len,
Expand All @@ -75,13 +76,21 @@ def export_model(
"""Export a EasyRec model, may be a part of model in PipelineConfig."""
use_rtp = os.environ.get("USE_RTP", "0") == "1"
impl = export_rtp_model if use_rtp else export_model_normal
return impl(
fs, local_path = url_to_fs(save_dir)
if fs is not None:
# scripted model use io in cpp, so that we can not path to fsspec
local_path = os.environ.get("LOCAL_CACHE_DIR", local_path)
impl(
pipeline_config=pipeline_config,
model=model,
checkpoint_path=checkpoint_path,
save_dir=save_dir,
save_dir=local_path,
assets=assets,
)
if fs is not None and int(os.environ.get("LOCAL_RANK", 0)) == 0:
logger.info(f"uploading {local_path} to {save_dir}.")
fs.upload(local_path, save_dir, recursive=True)
shutil.rmtree(local_path)


def export_model_normal(
Expand Down
Loading