Skip to content

Commit c6079be

Browse files
committed
import from different paths for different pytorch versions
Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 8eca187 commit c6079be

1 file changed

Lines changed: 5 additions & 10 deletions

File tree

examples/speculative_decoding/eagle_utils.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from packaging.version import Version
3232
from PIL import Image
3333
from scripts.ar_validate import validate_ar
34-
from torch.distributed.tensor.experimental._context_parallel._attention import _SDPAMerger
3534
from torch.utils.data import Dataset
3635
from transformers import AutoProcessor, Trainer, TrainerCallback
3736
from transformers.trainer_pt_utils import LabelSmoother
@@ -679,13 +678,9 @@ def patch_ring_attention_for_ttt():
679678
# Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
680679

681680
if Version(torch.__version__) < Version("2.10.0"):
682-
raise RuntimeError(
683-
f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
684-
f"Got {torch.__version__}. "
685-
f"Please use torch 2.10.0 or cp_size=1."
686-
)
687-
688-
from torch.distributed.tensor.experimental._context_parallel import _attention
681+
from torch.distributed.tensor.experimental import _attention
682+
else:
683+
from torch.distributed.tensor.experimental._context_parallel import _attention
689684

690685
# 1. Disable load balance, which is designed for causal mask.
691686
# This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
@@ -702,11 +697,11 @@ def patch_ring_attention_for_ttt():
702697
)
703698

704699
# 3. Patch merger to skip the blank shard to avoid difference in output.
705-
original_sdpa_merger_step = _SDPAMerger.step
700+
original_sdpa_merger_step = _attention._SDPAMerger.step
706701

707702
def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
708703
if lse.sum() <= 0:
709704
return
710705
return original_sdpa_merger_step(self, out, lse, partial)
711706

712-
_SDPAMerger.step = patched_sdpa_merger_step
707+
_attention._SDPAMerger.step = patched_sdpa_merger_step

0 commit comments

Comments (0)