Skip to content

Commit edec024

Browse files
[bugfix] fix feature permute when use mc-ebc and mean pooling (#134)
1 parent 2da91b4 commit edec024

1 file changed

Lines changed: 20 additions & 0 deletions

File tree

tzrec/utils/dist_util.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,21 @@ def _mc_input_dist(
190190
) -> Awaitable[Awaitable[KJTList]]:
191191
if self._embedding_module._has_uninitialized_input_dist:
192192
if isinstance(self._embedding_module, ShardedEmbeddingBagCollection):
193+
self._features_order = []
194+
# disable feature permutation in mc, because we should
195+
# permute features in mc-ebc before mean pooling callback.
196+
if self._managed_collision_collection._has_uninitialized_input_dists:
197+
self._managed_collision_collection._create_input_dists(
198+
input_feature_names=features.keys()
199+
)
200+
self._managed_collision_collection._has_uninitialized_input_dists = (
201+
False
202+
)
203+
if self._managed_collision_collection._features_order:
204+
self._features_order = (
205+
self._managed_collision_collection._features_order
206+
)
207+
self._managed_collision_collection._features_order = []
193208
if self._embedding_module._has_mean_pooling_callback:
194209
self._embedding_module._init_mean_pooling_callback(
195210
features.keys(),
@@ -199,6 +214,11 @@ def _mc_input_dist(
199214
self._embedding_module._has_uninitialized_input_dist = False
200215
if isinstance(self._embedding_module, ShardedEmbeddingBagCollection):
201216
with torch.no_grad():
217+
if self._features_order:
218+
features = features.permute(
219+
self._features_order,
220+
self._managed_collision_collection._features_order_tensor,
221+
)
202222
if self._embedding_module._has_mean_pooling_callback:
203223
ctx.divisor = _create_mean_pooling_divisor(
204224
lengths=features.lengths(),

0 commit comments

Comments
 (0)