9 | 9 | # See the License for the specific language governing permissions and |
10 | 10 | # limitations under the License. |
11 | 11 |
12 | | -from collections import OrderedDict |
13 | | -from typing import List |
| 12 | +from typing import Dict, List, Optional, Tuple |
14 | 13 |
15 | 14 | import torch |
16 | 15 | from torch import distributed as dist |
17 | | -from torch import nn |
18 | | -from torchrec.distributed.types import ShardingPlan, ShardingType |
19 | | - |
20 | | - |
21 | | -def sync_dp_emb_table(model: nn.Module, plan: ShardingPlan) -> None:
22 | | -    """Sync data parallel embedding table params."""
23 | | -    dp_param_names = []
24 | | -    for _, module_plan in plan.plan.items():
25 | | -        # pyre-ignore [16]
26 | | -        for param_name, param_sharding in module_plan.items():
27 | | -            if param_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
28 | | -                dp_param_names.append(param_name)
29 | | -    dp_params = OrderedDict()
30 | | -    for name, param in model.named_parameters():
31 | | -        name_parts = name.split(".")
32 | | -        if (
33 | | -            len(name_parts) > 2
34 | | -            and name_parts[-1] == "weight"
35 | | -            and name_parts[-2] in dp_param_names
36 | | -        ):
37 | | -            # pyre-ignore [16]
38 | | -            ori_t = param._original_tensor
39 | | -            if ori_t not in dp_params:
40 | | -                dp_params[ori_t] = 1
41 | | -    broadcast_works = []
42 | | -    for t in dp_params:
43 | | -        broadcast_works.append(dist.broadcast(t.detach(), src=0, async_op=True))
44 | | -    for w in broadcast_works:
45 | | -        w.wait()
| 16 | +from torch.autograd.profiler import record_function |
| 17 | +from torchrec.distributed import embeddingbag |
| 18 | +from torchrec.distributed.utils import none_throws |
| 19 | +from torchrec.modules.embedding_configs import PoolingType |
| 20 | +from torchrec.sparse.jagged_tensor import _to_offsets |
46 | 21 |
47 | 22 |
48 | 23 | def broadcast_string(s: str, src: int = 0) -> str: |
@@ -106,3 +81,80 @@ def gather_strings(s: str, dst: int = 0) -> List[str]: |
106 | 81 | gathered_strings.append(string) |
107 | 82 |
108 | 83 | return gathered_strings |
| 84 | + |
| 85 | + |
| 86 | +# The upstream create_mean_pooling_divisor modifies the kjt's lengths in place;
| 87 | +# as a temporary fix, this patched copy clones them via lengths = lengths.clone().
| 88 | +def _create_mean_pooling_divisor(
| 89 | +    lengths: torch.Tensor,
| 90 | +    keys: List[str],
| 91 | +    offsets: torch.Tensor,
| 92 | +    stride: int,
| 93 | +    stride_per_key: List[int],
| 94 | +    dim_per_key: torch.Tensor,
| 95 | +    pooling_type_to_rs_features: Dict[str, List[str]],
| 96 | +    embedding_names: List[str],
| 97 | +    embedding_dims: List[int],
| 98 | +    variable_batch_per_feature: bool,
| 99 | +    kjt_inverse_order: torch.Tensor,
| 100 | +    kjt_key_indices: Dict[str, int],
| 101 | +    kt_key_ordering: torch.Tensor,
| 102 | +    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
| 103 | +    weights: Optional[torch.Tensor] = None,
| 104 | +) -> torch.Tensor:
| 105 | +    with record_function("## ebc create mean pooling callback ##"):
| 106 | +        batch_size = (
| 107 | +            none_throws(inverse_indices)[1].size(dim=1)
| 108 | +            if variable_batch_per_feature
| 109 | +            else stride
| 110 | +        )
| 111 | +
| 112 | +        if weights is not None:
| 113 | +            # with per-sample weights, lengths becomes the per-bag sum of weights (segmented by offsets)
| 114 | +            lengths = torch.ops.fbgemm.segment_sum_csr(1, offsets.int(), weights)
| 115 | +
| 116 | +        if variable_batch_per_feature:
| 117 | +            inverse_indices = none_throws(inverse_indices)
| 118 | +            device = inverse_indices[1].device
| 119 | +            inverse_indices_t = inverse_indices[1]
| 120 | +            if len(keys) != len(inverse_indices[0]):
| 121 | +                inverse_indices_t = torch.index_select(
| 122 | +                    inverse_indices[1], 0, kjt_inverse_order
| 123 | +                )
| 124 | +            offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
| 125 | +                :-1
| 126 | +            ].unsqueeze(-1)
| 127 | +            indices = (inverse_indices_t + offsets).flatten()
| 128 | +            lengths = torch.index_select(input=lengths, dim=0, index=indices)
| 129 | +
| 130 | +        # set lengths to 1 for SUM-pooled features so only MEAN-pooled features get averaged
| 131 | +        lengths = lengths.clone()
| 132 | +        for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
| 133 | +            feature_index = kjt_key_indices[feature]
| 134 | +            feature_index = feature_index * batch_size
| 135 | +            lengths[feature_index : feature_index + batch_size] = 1
| 136 | +
| 137 | +        if len(embedding_names) != len(keys):
| 138 | +            lengths = torch.index_select(
| 139 | +                lengths.reshape(-1, batch_size),
| 140 | +                0,
| 141 | +                kt_key_ordering,
| 142 | +            ).reshape(-1)
| 143 | +
| 144 | +        # transpose to align features with keyed tensor dim_per_key
| 145 | +        lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
| 146 | +        output_size = sum(embedding_dims)
| 147 | +
| 148 | +        divisor = torch.repeat_interleave(
| 149 | +            input=lengths,
| 150 | +            repeats=dim_per_key,
| 151 | +            dim=1,
| 152 | +            output_size=output_size,
| 153 | +        )
| 154 | +        eps = 1e-6  # used to safeguard against division by zero
| 155 | +        divisor = divisor + eps
| 156 | +        return divisor.detach()
| 157 | + |
| 158 | + |
| 159 | +# pyre-ignore [9] |
| 160 | +embeddingbag._create_mean_pooling_divisor = _create_mean_pooling_divisor |
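
Below is a minimal, self-contained sketch of what the patched divisor does and why the added `lengths = lengths.clone()` matters. The tensors and names are toy values for illustration only and are not part of the diff above; in the real code path the arguments come from torchrec's sharded EmbeddingBagCollection.

    import torch

    # Toy setup: two features over a batch of 2 samples; feature 0 is SUM-pooled,
    # feature 1 is MEAN-pooled, with embedding dims 4 and 8.
    batch_size = 2
    dim_per_key = torch.tensor([4, 8])
    kjt_lengths = torch.tensor([2.0, 3.0, 1.0, 4.0])  # per-feature, per-sample bag lengths, flattened

    # Without the clone, the in-place write below would also mutate the caller's KJT
    # lengths, which is the issue the comment on the patched function refers to.
    lengths = kjt_lengths.clone()
    lengths[0:batch_size] = 1  # SUM-pooled feature keeps a divisor of 1

    # Build the divisor: one row per sample, one column per output embedding dimension.
    lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
    divisor = torch.repeat_interleave(lengths, dim_per_key, dim=1, output_size=12) + 1e-6

    pooled = torch.randn(batch_size, 12)  # stand-in for the SUM-pooled EBC output
    mean_pooled = pooled / divisor        # only feature 1's columns are actually averaged
    assert kjt_lengths[0].item() == 2.0   # the original lengths are left untouched

In torchrec this divisor is returned from the mean-pooling callback and applied to the pooled KeyedTensor, so SUM-pooled features pass through unchanged while MEAN-pooled features are divided by their bag lengths.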