3131from packaging .version import Version
3232from PIL import Image
3333from scripts .ar_validate import validate_ar
34- from torch .distributed .tensor .experimental ._context_parallel ._attention import _SDPAMerger
3534from torch .utils .data import Dataset
3635from transformers import AutoProcessor , Trainer , TrainerCallback
3736from transformers .trainer_pt_utils import LabelSmoother
@@ -679,13 +678,9 @@ def patch_ring_attention_for_ttt():
679678 # Torch Ring Attention only supports no mask or causal mask. We apply the following patches to enable TTT mask.
680679
681680 if Version (torch .__version__ ) < Version ("2.10.0" ):
682- raise RuntimeError (
683- f"Context parallel TTT only supported for PyTorch >= 2.10.0. "
684- f"Got { torch .__version__ } . "
685- f"Please use torch 2.10.0 or cp_size=1."
686- )
687-
688- from torch .distributed .tensor .experimental ._context_parallel import _attention
681+ from torch .distributed .tensor .experimental import _attention
682+ else :
683+ from torch .distributed .tensor .experimental ._context_parallel import _attention
689684
690685 # 1. Disable load balance, which is designed for causal mask.
691686 # This affect how buffers are sharded. So need to be done permanently before accelerate/hf trainer init.
@@ -702,11 +697,11 @@ def patch_ring_attention_for_ttt():
702697 )
703698
704699 # 3. Patch merger to skip the blank shard to avoid difference in output.
705- original_sdpa_merger_step = _SDPAMerger .step
700+ original_sdpa_merger_step = _attention . _SDPAMerger .step
706701
def patched_sdpa_merger_step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool):
    """Merge one ring-attention shard, skipping blank shards.

    Delegates to the unpatched ``_SDPAMerger.step`` only when the shard's
    log-sum-exp carries signal (``lse.sum() > 0``); a blank shard is dropped
    (returns ``None``) so it cannot perturb the merged output.
    """
    if lse.sum() > 0:
        # Real shard: defer to the original merger implementation.
        return original_sdpa_merger_step(self, out, lse, partial)
    # Blank shard: nothing to merge.
    return None
711706
712- _SDPAMerger .step = patched_sdpa_merger_step
707+ _attention . _SDPAMerger .step = patched_sdpa_merger_step
0 commit comments