@@ -190,6 +190,21 @@ def _mc_input_dist(
190190) -> Awaitable [Awaitable [KJTList ]]:
191191 if self ._embedding_module ._has_uninitialized_input_dist :
192192 if isinstance (self ._embedding_module , ShardedEmbeddingBagCollection ):
193+ self ._features_order = []
194+ # disable feature permutation in mc, because we should
195+ # permute features in mc-ebc before mean pooling callback.
196+ if self ._managed_collision_collection ._has_uninitialized_input_dists :
197+ self ._managed_collision_collection ._create_input_dists (
198+ input_feature_names = features .keys ()
199+ )
200+ self ._managed_collision_collection ._has_uninitialized_input_dists = (
201+ False
202+ )
203+ if self ._managed_collision_collection ._features_order :
204+ self ._features_order = (
205+ self ._managed_collision_collection ._features_order
206+ )
207+ self ._managed_collision_collection ._features_order = []
193208 if self ._embedding_module ._has_mean_pooling_callback :
194209 self ._embedding_module ._init_mean_pooling_callback (
195210 features .keys (),
@@ -199,6 +214,11 @@ def _mc_input_dist(
199214 self ._embedding_module ._has_uninitialized_input_dist = False
200215 if isinstance (self ._embedding_module , ShardedEmbeddingBagCollection ):
201216 with torch .no_grad ():
217+ if self ._features_order :
218+ features = features .permute (
219+ self ._features_order ,
220+ self ._managed_collision_collection ._features_order_tensor ,
221+ )
202222 if self ._embedding_module ._has_mean_pooling_callback :
203223 ctx .divisor = _create_mean_pooling_divisor (
204224 lengths = features .lengths (),
0 commit comments