Bug description
Training with Context Parallelism and ScaledDotProductAttention crashes with:
RuntimeError: aten.add.Tensor got mixed torch.Tensor and DTensor,
need to convert all torch.Tensor to DTensor before calling distributed operators!
The error fires during the forward pass, in F.scaled_dot_product_attention, within an activation-checkpointed transformer layer.
Reproduction
Run any CP training job that uses ScaledDotProductAttention with AC enabled (e.g. via the torchtitan/distributed/context_parallel.py path).
Root cause
apply_cp_to_forward (context_parallel.py:80) replaces ScaledDotProductAttention.forward with cp_forward, which wraps q/k/v as DTensors and then calls orig_fn, the original ScaledDotProductAttention.forward:
def cp_forward(q, k, v, **kwargs):
    q = DTensor.from_local(q, mesh, [Shard(1)], run_check=False)
    k = DTensor.from_local(k, mesh, [Shard(1)], run_check=False)
    v = DTensor.from_local(v, mesh, [Shard(1)], run_check=False)
    output = orig_fn(q, k, v, **kwargs)  # ← passes DTensors into SDPA forward
    return output.to_local() if isinstance(output, DTensor) else output
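The wrapping pattern described above can be sketched without torch. This is only a toy model: `Wrapped`, `ToyAttention`, and `apply_wrapper` are hypothetical stand-ins for DTensor, ScaledDotProductAttention, and apply_cp_to_forward, and they are not torchtitan code. It shows that the original forward ends up receiving wrapped inputs, which is the precondition for the crash.

```python
class Wrapped:
    """Hypothetical stand-in for a DTensor: a local payload tagged as distributed."""

    def __init__(self, local):
        self.local = local

    def to_local(self):
        return self.local


class ToyAttention:
    """Hypothetical stand-in for the attention module."""

    def forward(self, q, k, v):
        # Record which types the original forward actually receives.
        self.seen = tuple(type(t).__name__ for t in (q, k, v))
        return q


def apply_wrapper(module):
    orig_fn = module.forward  # capture the original bound method before patching

    def wrapped_forward(q, k, v, **kwargs):
        # Wrap every input, call the original forward, then unwrap the output.
        q, k, v = (Wrapped(t) for t in (q, k, v))
        output = orig_fn(q, k, v, **kwargs)
        return output.to_local() if isinstance(output, Wrapped) else output

    module.forward = wrapped_forward
    return module


attn = apply_wrapper(ToyAttention())
result = attn.forward([1.0], [2.0], [3.0])
```

The caller sees plain values going in and coming out, while everything inside the original forward operates on the wrapped type.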
ScaledDotProductAttention.forward transposes q/k/v (still DTensors) and then calls:
with sdpa_kernel([CUDNN_ATTENTION, FLASH_ATTENTION, MATH], set_priority=True):
    out = F.scaled_dot_product_attention(q, k, v, ...)
The SDPA backend selector (_select_sdp_backend in C++) receives DTensor inputs. Both CUDNN and FLASH backends fail their constraint
checks (contiguity / stride requirements) against DTensors. MATH is selected as the last resort.
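The fallback behavior described above can be modeled as a priority-ordered search. This is a rough sketch, not the real selector: the actual logic lives in C++ and checks many more conditions, and the `Backend` class, `select_backend` function, and `plain_strided` flag are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


@dataclass
class Backend:
    name: str
    accepts: Callable[[dict], bool]  # constraint check for this backend


def select_backend(priority: Sequence[Backend], inputs: dict) -> str:
    """Return the first backend in priority order whose constraints pass."""
    for backend in priority:
        if backend.accepts(inputs):
            return backend.name
    raise RuntimeError("no available kernel")


# Assumption: the fused backends require plain, suitably strided inputs,
# while MATH accepts anything.
PRIORITY = [
    Backend("CUDNN_ATTENTION", lambda x: x["plain_strided"]),
    Backend("FLASH_ATTENTION", lambda x: x["plain_strided"]),
    Backend("MATH", lambda x: True),
]

chosen_for_dtensor = select_backend(PRIORITY, {"plain_strided": False})
chosen_for_plain = select_backend(PRIORITY, {"plain_strided": True})
```

Under these assumptions, inputs that fail the fused backends' checks always land on MATH, which matches the path reported in this issue.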
The MATH backend decomposes SDPA into constituent ops. It creates an intermediate attention bias as a plain CUDA tensor:
attn_weight = q @ k.T # DTensor @ DTensor → DTensor
attn_weight += attn_bias # aten.add(DTensor, plain_tensor) ← crash
DTensor's dispatcher for aten.add detects the mixed types and throws.
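The failing pattern can be illustrated with a toy model that needs neither torch nor multiple processes. `ToyDTensor` and `add_bias` are hypothetical stand-ins, not the real DTensor API. The sketch shows the mixed-operand check and the conversion the error message asks for; it is not a proposed fix for torchtitan.

```python
class ToyDTensor:
    """Hypothetical stand-in for a DTensor; not the real torch API."""

    def __init__(self, values):
        self.values = list(values)

    def __add__(self, other):
        # Mirror the rule in the error message: every operand must already
        # be a distributed tensor before a distributed operator runs.
        if not isinstance(other, ToyDTensor):
            raise RuntimeError("got mixed plain tensor and ToyDTensor")
        return ToyDTensor(a + b for a, b in zip(self.values, other.values))


def add_bias(scores, bias, wrap_bias):
    """Add a plain bias to distributed scores, optionally wrapping it first."""
    if wrap_bias:
        bias = ToyDTensor(bias)
    return scores + bias


scores = ToyDTensor([0.5, 1.5])

try:
    add_bias(scores, [1.0, 1.0], wrap_bias=False)
    error = None
except RuntimeError as exc:
    error = str(exc)

fixed = add_bias(scores, [1.0, 1.0], wrap_bias=True)
```

The unwrapped call fails in the same way as the MATH decomposition, because the bias is created as a plain value after the inputs were already wrapped.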
Finally, the value on line 81 of torchtitan/torchtitan/distributed/context_parallel.py (at 6f2fa2f) should probably be a 1, but I was not able to test it.
Versions
Single script repro:
import torchtitan.hf_datasets.text_datasets
import torchtitan.components.lr_scheduler
import torchtitan.components.checkpoint
import torchtitan.components.dataloader
import torchtitan.components.optimizer
import torchtitan.components.validate
import torchtitan.components.metrics
import torchtitan.components.loss
import torchtitan.tools.logging
import torchtitan.models.common
import torchtitan.models.llama3
import torchtitan.protocols
import torchtitan.trainer
import torchtitan.config
import torchtitan
torchtitan.tools.logging.init_logger()
torchtitan.trainer.Trainer.Config(
    model_spec=torchtitan.models.llama3.model_registry("debugmodel", attn_backend="sdpa"),
    dataloader=torchtitan.hf_datasets.text_datasets.HuggingFaceTextDataLoader.Config(
        dataset="c4_test",
        num_workers=1,
    ),
    parallelism=torchtitan.trainer.ParallelismConfig(
        data_parallel_replicate_degree=1,
        data_parallel_shard_degree=-1,
        context_parallel_degree=4,
        tensor_parallel_degree=1,
        context_parallel_load_balancer=None,
    ),
    compile=torchtitan.trainer.CompileConfig(
        enable=False,
    ),
    optimizer=torchtitan.components.optimizer.OptimizersContainer.Config(
        lr=1e-4,
    ),
    training=torchtitan.config.TrainingConfig(
        local_batch_size=1,
        seq_len=512,
        steps=100,
        dtype="float32",
        mixed_precision_param="float32",
    ),
    loss=torchtitan.components.loss.CrossEntropyLoss.Config(),
    lr_scheduler=torchtitan.components.lr_scheduler.LRSchedulersContainer.Config(
        warmup_steps=10,
        decay_ratio=0,
        decay_type="linear",
        min_lr_factor=0.0,
    ),
    checkpoint=torchtitan.components.checkpoint.CheckpointManager.Config(
        interval=100,
        last_save_model_only=False,
        export_dtype="bfloat16",
        enable=True,
    ),
    activation_checkpoint=torchtitan.config.ActivationCheckpointConfig(
        mode="selective",
    ),
    metrics=torchtitan.components.metrics.MetricsProcessor.Config(
        enable_tensorboard=True,
        enable_wandb=False,
        log_freq=1,
    ),
    debug=torchtitan.config.configs.DebugConfig(
        seed=42,
        deterministic=True,
    ),
    validator=torchtitan.components.validate.Validator.Config(
        enable=True,
        freq=10,
        dataloader=torchtitan.hf_datasets.text_datasets.HuggingFaceTextDataLoader.Config(
            dataset="c4_test",
            dataset_path=None,
            infinite=False,
        ),
        steps=3,
    ),
    hf_assets_path="./tests/assets/tokenizer",
    dump_folder=".data/experiments/test",
).build().train()
Run with:
uv run torchrun --nproc-per-node 4 scripts/debug-llama3-sdpa-cp.py
torch version:
2.13.0.dev20260512+cu126