Skip to content

Commit 76016ec

Browse files
[bugfix] fix correctness of kjt.lengths when ShardedEmbeddingBag’s pooling_type is mean and shard_type is row_wise (#106)
1 parent 9e43f22 commit 76016ec

2 files changed

Lines changed: 87 additions & 31 deletions

File tree

tzrec/tests/configs/multi_tower_din_fg_mock.config

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ feature_configs {
5959
hash_bucket_size: 100
6060
embedding_dim: 16
6161
embedding_name: "id_4_emb"
62+
value_dim: 0
63+
pooling: "mean"
6264
}
6365
}
6466
feature_configs {
@@ -68,6 +70,8 @@ feature_configs {
6870
hash_bucket_size: 100
6971
embedding_dim: 16
7072
embedding_name: "id_4_emb"
73+
value_dim: 0
74+
pooling: "mean"
7175
}
7276
}
7377
feature_configs {

tzrec/utils/dist_util.py

Lines changed: 83 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,40 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from collections import OrderedDict
13-
from typing import List
12+
from typing import Dict, List, Optional, Tuple
1413

1514
import torch
1615
from torch import distributed as dist
17-
from torch import nn
18-
from torchrec.distributed.types import ShardingPlan, ShardingType
19-
20-
21-
def sync_dp_emb_table(model: nn.Module, plan: ShardingPlan) -> None:
22-
"""Sync data parallel embedding table params."""
23-
dp_param_names = []
24-
for _, module_plan in plan.plan.items():
25-
# pyre-ignore [16]
26-
for param_name, param_sharding in module_plan.items():
27-
if param_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
28-
dp_param_names.append(param_name)
29-
dp_params = OrderedDict()
30-
for name, param in model.named_parameters():
31-
name_parts = name.split(".")
32-
if (
33-
len(name_parts) > 2
34-
and name_parts[-1] == "weight"
35-
and name_parts[-2] in dp_param_names
36-
):
37-
# pyre-ignore [16]
38-
ori_t = param._original_tensor
39-
if ori_t not in dp_params:
40-
dp_params[ori_t] = 1
41-
broadcast_works = []
42-
for t in dp_params:
43-
broadcast_works.append(dist.broadcast(t.detach(), src=0, async_op=True))
44-
for w in broadcast_works:
45-
w.wait()
16+
from torch.autograd.profiler import record_function
17+
from torchrec.distributed import embeddingbag
18+
from torchrec.distributed.utils import none_throws
19+
from torchrec.modules.embedding_configs import PoolingType
20+
from torchrec.sparse.jagged_tensor import _to_offsets
4621

4722

4823
def broadcast_string(s: str, src: int = 0) -> str:
@@ -106,3 +81,80 @@ def gather_strings(s: str, dst: int = 0) -> List[str]:
10681
gathered_strings.append(string)
10782

10883
return gathered_strings
84+
85+
86+
# NOTE(review): patched copy of torchrec's internal
# `embeddingbag._create_mean_pooling_divisor`. Per the original author's note,
# the upstream version modifies the KJT's `lengths` tensor in place when it
# rewrites SUM-pooled features' lengths to 1; the only intended behavioral
# delta here is the `lengths = lengths.clone()` below, which keeps the
# caller's `lengths` intact. Keep this body in sync with upstream and remove
# the override once the fix lands there. (Installed via the module-level
# assignment at the bottom of this file.)
def _create_mean_pooling_divisor(
    lengths: torch.Tensor,
    keys: List[str],
    offsets: torch.Tensor,
    stride: int,
    stride_per_key: List[int],
    dim_per_key: torch.Tensor,
    pooling_type_to_rs_features: Dict[str, List[str]],
    embedding_names: List[str],
    embedding_dims: List[int],
    variable_batch_per_feature: bool,
    kjt_inverse_order: torch.Tensor,
    kjt_key_indices: Dict[str, int],
    kt_key_ordering: torch.Tensor,
    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
    weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Build the divisor tensor used to turn SUM-pooled embeddings into MEAN.

    Produces a detached ``[batch_size, sum(embedding_dims)]`` tensor of
    per-sample feature lengths (repeated across each feature's embedding
    dims, plus an epsilon), which the caller divides pooled embeddings by.
    Features whose pooling type is SUM get a divisor of 1 so they pass
    through unchanged.

    Args:
        lengths: per-feature-per-sample value counts of the KJT, flattened
            feature-major (length ``num_features * batch_size`` in the
            fixed-batch case). NOT mutated by this patched version.
        keys: KJT feature keys.
        offsets: KJT offsets; only used (as csr row pointers) when
            ``weights`` is given.
        stride: batch size when ``variable_batch_per_feature`` is False.
        stride_per_key: per-key strides; used to index into the
            variable-batch inverse-indices layout.
        dim_per_key: embedding dim of each key, aligned with the keyed
            tensor ordering; drives ``repeat_interleave``.
        pooling_type_to_rs_features: maps pooling type value -> feature
            names; only the SUM entry is consulted here.
        embedding_names: output embedding names; a length mismatch with
            ``keys`` triggers reordering by ``kt_key_ordering``.
        embedding_dims: per-output embedding dims; their sum is the
            divisor's second dimension.
        variable_batch_per_feature: whether KJT uses variable batch sizes
            (then ``inverse_indices`` must be provided).
        kjt_inverse_order: permutation applied to inverse indices when the
            KJT keys were reordered relative to ``inverse_indices[0]``.
        kjt_key_indices: feature name -> position in the KJT key order.
        kt_key_ordering: permutation from KJT key order to keyed tensor
            order.
        inverse_indices: (keys, indices) pair for variable-batch lookup;
            required when ``variable_batch_per_feature`` is True.
        weights: optional per-value weights; when present, the divisor is
            the per-sample weight sum instead of the value count.

    Returns:
        Detached float tensor of shape ``[batch_size, sum(embedding_dims)]``.
    """
    with record_function("## ebc create mean pooling callback ##"):
        # Variable-batch KJTs carry the true batch size in the inverse
        # indices; otherwise stride is the batch size.
        batch_size = (
            none_throws(inverse_indices)[1].size(dim=1)
            if variable_batch_per_feature
            else stride
        )

        if weights is not None:
            # if we have weights, lengths is the sum of weights by offsets for feature
            lengths = torch.ops.fbgemm.segment_sum_csr(1, offsets.int(), weights)

        if variable_batch_per_feature:
            inverse_indices = none_throws(inverse_indices)
            device = inverse_indices[1].device
            inverse_indices_t = inverse_indices[1]
            # Align inverse indices with the KJT key order when the KJT was
            # built with a different/duplicated key set.
            if len(keys) != len(inverse_indices[0]):
                inverse_indices_t = torch.index_select(
                    inverse_indices[1], 0, kjt_inverse_order
                )
            # Per-key base offsets into the flattened lengths; broadcast-add
            # against each key's row of inverse indices, then gather.
            offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
                :-1
            ].unsqueeze(-1)
            indices = (inverse_indices_t + offsets).flatten()
            lengths = torch.index_select(input=lengths, dim=0, index=indices)

        # only convert the sum pooling features to be 1 lengths
        # THE FIX: clone before the in-place slice assignments below so the
        # KJT's own lengths tensor is not clobbered (upstream omits this).
        lengths = lengths.clone()
        for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
            feature_index = kjt_key_indices[feature]
            feature_index = feature_index * batch_size
            lengths[feature_index : feature_index + batch_size] = 1

        # Reorder rows from KJT key order into keyed-tensor output order
        # when the two orderings differ.
        if len(embedding_names) != len(keys):
            lengths = torch.index_select(
                lengths.reshape(-1, batch_size),
                0,
                kt_key_ordering,
            ).reshape(-1)

        # transpose to align features with keyed tensor dim_per_key
        lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
        output_size = sum(embedding_dims)

        # Expand each feature's length across that feature's embedding dims
        # so the divisor lines up element-wise with the pooled output.
        divisor = torch.repeat_interleave(
            input=lengths,
            repeats=dim_per_key,
            dim=1,
            output_size=output_size,
        )
        eps = 1e-6  # used to safe guard against 0 division
        divisor = divisor + eps
        # Detach: the divisor is a constant w.r.t. autograd.
        return divisor.detach()
157+
158+
159+
# Install the patched divisor builder over torchrec's private module-level
# function so ShardedEmbeddingBag's mean-pooling callback uses it. This is a
# monkey-patch of a private upstream symbol — fragile across torchrec
# upgrades; drop it once the clone fix is upstream.
# pyre-ignore [9]
embeddingbag._create_mean_pooling_divisor = _create_mean_pooling_divisor

0 commit comments

Comments
 (0)