-
Notifications
You must be signed in to change notification settings - Fork 134
Expand file tree
/
Copy pathutils.py
More file actions
247 lines (205 loc) · 8.79 KB
/
utils.py
File metadata and controls
247 lines (205 loc) · 8.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import ContextDecorator, nullcontext
from datetime import datetime, timedelta
from typing import Optional
import torch
import torch.distributed
import torch.distributed as dist
logger = logging.getLogger(__name__)
def _create_gloo_group():
"""
Create a Gloo process group for barrier operations.
This allows us to use monitored_barrier with Gloo backend while
keeping NCCL for the main training operations.
Returns:
ProcessGroup: Gloo process group for barriers
"""
if not dist.is_initialized():
return None
try:
# Create a Gloo group for barrier operations
gloo_group = dist.new_group(backend="gloo")
logger.debug("Created Gloo group for barrier operations")
return gloo_group
except Exception as e:
logger.warning(f"Failed to create Gloo group: {e}")
return None
def _barrier_with_timeout(timeout: timedelta, group=None):
"""
A timeout wrapper for torch.distributed.barrier() using Gloo backend.
This approach creates a separate Gloo process group for barrier operations
while keeping the main NCCL backend for training operations.
Args:
timeout: Maximum time to wait for the barrier
group: Process group for the barrier operation
Returns:
bool: True if barrier completed successfully, False if timeout occurred
"""
if not dist.is_initialized() or dist.get_world_size() == 1:
return True
# Use Gloo group for barrier operations
gloo_group = _create_gloo_group()
if gloo_group is None:
# Fallback to regular barrier if Gloo group creation fails
try:
dist.barrier(group=group)
return True
except Exception as e:
logger.warning(f"Barrier failed: {e}")
return False
try:
# Use monitored_barrier with Gloo group
dist.monitored_barrier(group=gloo_group, timeout=timeout)
return True
except Exception as e:
logger.warning(f"Monitored barrier failed: {e}")
return False
finally:
# Clean up the Gloo group
try:
dist.destroy_process_group(gloo_group)
except Exception:
pass
class FirstRankPerNode(ContextDecorator):
"""
Context manager to enforce rank0 to process section over other ranks.
- Lets RANK==0 run the protected code first on each node.
- Inserts an extra barrier across *only* the node-local rank-0 processes.
- Works on a single GPU (no env flags, no distributed initialisation).
Note: it is assumed the scoped code is not torch.distributed heavy.
"""
def __enter__(self, timeout=timedelta(hours=10)):
"""
Create / bootstrap a (distributed) proc. group that rank0 enters first.
Returns:
bool: ``True`` - if the current process is node-rank-0
``False`` - otherwise
"""
self._created_pg = False
self._node0_group = None
self._first = True # default for single-GPU / no-dist case
self._timeout = timeout
if not dist.is_initialized() or dist.get_world_size() == 1:
# pure single GPU
return True
# Figure out rank
self._first = dist.get_rank() == 0
# Synchronisation logic
if not self._first:
# Non-rank-0 processes wait for their node-rank-0
# Use Gloo group for monitored_barrier to avoid NCCL timeout issues
success = _barrier_with_timeout(timeout=self._timeout)
if not success:
logger.warning("Barrier timed out, continuing anyway")
return self._first
def __exit__(self, exc_type, exc_val, exc_tb):
"""
Tear down the context.
1. If the current process was the first on its node, release the
waiting peer ranks by issuing a barrier.
2. If an exception occurred, abort the *entire* distributed job.
3. If this context manager created the process group, destroy it.
Args:
exc_type (Type[BaseException] | None): Exception class if one
occurred inside the ``with`` block.
exc_val (BaseException | None): The raised exception instance.
exc_tb (TracebackType | None): Traceback associated with the
exception.
Returns:
bool: ``False`` so that any exception raised inside the ``with``
block is propagated to the caller (standard CM semantics).
"""
try:
if self._first and dist.is_initialized():
# Re-sync the whole world so that non-rank-0s can proceed
# Use Gloo group for monitored_barrier to avoid NCCL timeout issues
success = _barrier_with_timeout(timeout=self._timeout)
if not success:
logger.warning("Barrier timed out during exit, continuing anyway")
if exc_type is not None:
# Log the exception and make the error visible
logger.error(
"Exception inside FirstRankPerNode: %s: %s",
exc_type.__name__,
exc_val,
exc_info=(exc_type, exc_val, exc_tb),
)
raise SystemExit(1) from exc_val
finally:
if self._created_pg:
dist.destroy_process_group()
# propagate any exception to outer scope
return False
def barrier_and_log(string: str) -> None:
"""
Perform a distributed barrier and then log a message on rank 0.
Args:
string: The message string to log.
"""
if torch.distributed.is_initialized():
torch.distributed.barrier()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
logger.info("[{}] datetime: {} ".format(string, time_str))
def reduce_loss(
loss_store: list[torch.Tensor],
total_num_tokens: torch.Tensor,
per_token_loss: bool = True,
dp_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reduce loss across all ranks.
Args:
loss_store: List of loss tensors to reduce.
total_num_tokens: Total number of tokens to divide the loss by.
per_token_loss: Whether to divide the loss by the number of tokens.
dp_group: Process group to reduce the loss across.
Returns:
Tuple of reduced loss and denominator.
"""
loss = torch.sum(torch.stack(loss_store).float()).view(1).clone().detach()
if dp_group is not None:
dist.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=dp_group)
if per_token_loss:
denominator = total_num_tokens.clone().detach().to(torch.int)
else:
denominator = torch.tensor([len(loss_store)], dtype=torch.int, device="cuda")
if dp_group is not None:
dist.all_reduce(denominator, op=torch.distributed.ReduceOp.SUM, group=dp_group)
return loss, denominator
def get_sync_ctx(model, is_optim_step, defer_fsdp_grad_sync: bool):
"""
Get the synchronization context for the model.
Args:
model: The model to synchronize.
is_optim_step: Whether the current step is an optimizer step.
defer_fsdp_grad_sync: Controls FSDP2 gradient synchronization during gradient accumulation.
- True: disable gradient sync on non-final micro-batches (saves comm, can increase peak memory).
- False: always sync gradients on every micro-batch (more comm, lower peak memory).
Returns:
A context manager that synchronizes the model.
"""
# Use `no_sync` on DDP models when we are *not* on the final micro-batch for
# this gradient update (i.e., when `is_grad` is False). This avoids an
# all-reduce for every micro-batch and greatly improves throughput.
sync_ctx = nullcontext()
if isinstance(model, dist.fsdp._fully_shard._fully_shard.FSDPModule):
if defer_fsdp_grad_sync:
model.set_requires_gradient_sync(is_optim_step)
else:
model.set_requires_gradient_sync(True)
elif isinstance(model, torch.nn.parallel.DistributedDataParallel) and not is_optim_step:
sync_ctx = model.no_sync()
return sync_ctx