Merged
36 changes: 21 additions & 15 deletions tzrec/utils/plan_util.py
@@ -212,9 +212,9 @@ def __init__(self, mem_bins_per_device: int = 100) -> None:
         # indices of sharding_options
         self._proposal_list: List[List[int]] = []
         self._current_proposal: int = -1
-        self._plan_by_hbm = True
+        self._storage_type = "hbm"
         if not torch.cuda.is_available():
-            self._plan_by_hbm = False
+            self._storage_type = "ddr"
 
     def load(
         self,
@@ -225,10 +225,7 @@ def load(
         self._reset()
         # order the sharding_option by total_storage.hbm from low to high
         for sharding_option in sorted(
-            search_space,
-            key=lambda x: x.total_storage.hbm
-            if self._plan_by_hbm
-            else x.total_storage.ddr,
+            search_space, key=lambda x: getattr(x.total_storage, self._storage_type)
        ):
             fqn = sharding_option.fqn
             if fqn not in self._sharding_options_by_fqn:
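
The new sort key is the main behavioural change in load(): instead of branching on a boolean, it reads whichever Storage field _storage_type names. A minimal sketch of that pattern, using hypothetical stand-in dataclasses rather than the real torchrec ShardingOption/Storage types:

from dataclasses import dataclass

@dataclass
class Storage:           # stand-in for a storage record with hbm/ddr byte counts
    hbm: int
    ddr: int

@dataclass
class Option:            # stand-in for a sharding option with a total_storage
    fqn: str
    total_storage: Storage

storage_type = "ddr"     # "hbm" when CUDA is available, "ddr" otherwise
options = [
    Option("table_a", Storage(hbm=300, ddr=900)),
    Option("table_b", Storage(hbm=100, ddr=400)),
]
# Same key shape as the new load(): pick the attribute named by storage_type.
ordered = sorted(options, key=lambda x: getattr(x.total_storage, storage_type))
print([o.fqn for o in ordered])  # ['table_b', 'table_a']
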
@@ -273,12 +270,12 @@ def feedback(
 
         assert storage_constraint is not None
         # are we assuming the table will be evenly sharded on all devices?
-        mem_total = sum(
-            [
-                x.storage.hbm if self._plan_by_hbm else x.storage.ddr
-                for x in storage_constraint.devices
-            ]
-        )
+        max_device_mem = 0
+        mem_total = 0
+        for x in storage_constraint.devices:
+            cur_device_mem = getattr(x.storage, self._storage_type)
+            max_device_mem = max(max_device_mem, cur_device_mem)
+            mem_total += cur_device_mem
 
         bin_count = self._mem_bins_per_device * len(storage_constraint.devices)
         bin_size = float(mem_total) / bin_count
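
The rewritten loop now tracks two quantities in one pass: the total memory across all devices, which feeds the bin_count/bin_size context lines above, and the largest single-device memory, which the later pruning check uses. A small sketch with made-up device sizes, assuming the constructor default mem_bins_per_device of 100:

# Hypothetical per-device memory (bytes); real values come from the Topology.
device_mem = [16 * 1024**3, 16 * 1024**3]  # two devices with 16 GB each

max_device_mem = 0
mem_total = 0
for cur_device_mem in device_mem:
    max_device_mem = max(max_device_mem, cur_device_mem)
    mem_total += cur_device_mem

mem_bins_per_device = 100                          # constructor default
bin_count = mem_bins_per_device * len(device_mem)  # 200 DP bins
bin_size = float(mem_total) / bin_count            # ~171.8 MB per bin
print(max_device_mem, mem_total, bin_count, bin_size)
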
@@ -304,10 +301,19 @@ def feedback(
             self._sharding_options_by_fqn.values()
         ):
             for opt_id, sharding_option in enumerate(sharding_options):
+                # prune mem of one shard > mem of one device
+                if (
+                    max(
+                        [
+                            getattr(shard.storage, self._storage_type)
+                            for shard in sharding_option.shards
+                        ]
+                    )
+                    > max_device_mem
+                ):
+                    continue
                 mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin(
-                    sharding_option.total_storage.hbm
-                    if self._plan_by_hbm
-                    else sharding_option.total_storage.ddr,
+                    getattr(sharding_option.total_storage, self._storage_type),
                     bin_size,
                 )
                 perf_by_fqn[table_id][opt_id] = sharding_option.total_perf
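
The inserted block is the actual pruning: a sharding option whose single largest shard cannot fit in the memory of the largest device can never be partitioned, so it is excluded from the DP table up front. A minimal sketch of the same check with hypothetical shard sizes:

# Hypothetical sizes; in the real code the per-shard sizes come from
# sharding_option.shards and max_device_mem from the storage constraint.
max_device_mem = 8 * 1024**3   # largest single device holds 8 GB
shard_sizes = [12 * 1024**3]   # e.g. a table-wise option: one 12 GB shard

if max(shard_sizes) > max_device_mem:
    print("pruned: the largest shard exceeds every device's memory")
else:
    print("kept: every shard fits on at least the largest device")
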
47 changes: 47 additions & 0 deletions tzrec/utils/plan_util_test.py
@@ -88,6 +88,53 @@ def test_dp_proposer(self) -> None:
         for k, v in best_grid_proposal.items():
             self.assertEqual(str(v), str(best_dp_proposal[k]))
 
+    def test_dp_proposer_with_prune(self) -> None:
+        topology = Topology(
+            world_size=2,
+            hbm_cap=(1000**3) * 10 * 2 * 4,
+            compute_device="cuda" if torch.cuda.is_available() else "cpu",
+        )
+        enumerator = EmbeddingEnumerator(topology=topology, batch_size=8196)
+        partitioner = GreedyPerfPartitioner()
+
+        tables = [
+            EmbeddingBagConfig(
+                num_embeddings=1000**i,
+                embedding_dim=10 * i,
+                name="table_" + str(i),
+                feature_names=["feature_" + str(i)],
+            )
+            for i in range(1, 4)
+        ]
+        model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
+        search_space = enumerator.enumerate(
+            module=model,
+            sharders=get_default_sharders(),
+        )
+
+        dp_proposer = DynamicProgrammingProposer()
+        dp_proposer.load(search_space)
+        best_dp_perf = float("inf")
+        best_dp_proposal = None
+        num_proposals = 0
+        proposal = dp_proposer.propose()
+        while proposal:
+            num_proposals += 1
+            try:
+                partitioner.partition(proposal, topology)
+                cur_perf = sum([x.total_perf for x in proposal])
+                if cur_perf < best_dp_perf:
+                    best_dp_proposal = {x.fqn: x for x in proposal}
+                    best_dp_perf = cur_perf
+            except PlannerError:
+                pass
+            dp_proposer.feedback(partitionable=True, storage_constraint=topology)
+            proposal = dp_proposer.propose()
+        self.assertEqual(
+            best_dp_proposal["sparse.ebc.table_3"].sharding_type,
+            "row_wise" if torch.cuda.is_available() else "table_wise",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
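
A rough check of why the GPU branch of the assertion expects row-wise sharding (fp32 weights only; the real planner also charges optimizer-state and I/O overheads, so exact sizes differ): table_3 stores 1000**3 rows of dimension 30, about 120 GB, while each of the two devices is capped at (1000**3) * 10 * 2 * 4 = 80 GB of HBM. The table-wise option's single 120 GB shard is pruned by the new check, while a row-wise split into two ~60 GB shards fits. On CPU the proposer plans by ddr capacity instead, where the test expects the table-wise layout to win.

# Ballpark arithmetic only; these are not the planner's exact estimates.
hbm_cap_per_device = (1000**3) * 10 * 2 * 4   # 80 GB, as in the test Topology
table3_bytes = (1000**3) * (10 * 3) * 4       # table_3: 1e9 rows x dim 30 x 4 B
world_size = 2

table_wise_shard = table3_bytes               # whole table on a single device
row_wise_shard = table3_bytes // world_size   # half the rows per device

print(table_wise_shard > hbm_cap_per_device)  # True  -> pruned on GPU
print(row_wise_shard > hbm_cap_per_device)    # False -> row_wise fits
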