
RuntimeError: aten.add.Tensor got mixed torch.Tensor and DTensor when using Context Parallel SDPA #3340

@francesco-bertolotti

Description


Bug description

Training with Context Parallelism and ScaledDotProductAttention crashes with:

  RuntimeError: aten.add.Tensor got mixed torch.Tensor and DTensor,
  need to convert all torch.Tensor to DTensor before calling distributed operators!

The error fires during the forward pass inside F.scaled_dot_product_attention, inside an activation-checkpointed transformer layer.

Reproduction

Run any CP training job that uses ScaledDotProductAttention with activation checkpointing (AC) enabled, i.e. the code path in torchtitan/distributed/context_parallel.py.

Root cause

apply_cp_to_forward (context_parallel.py:80) wraps ScaledDotProductAttention.forward with cp_forward, which wraps q/k/v as DTensors
and then calls orig_fn, the original ScaledDotProductAttention.forward:

  def cp_forward(q, k, v, **kwargs):
      q = DTensor.from_local(q, mesh, [Shard(1)], run_check=False)
      k = DTensor.from_local(k, mesh, [Shard(1)], run_check=False)
      v = DTensor.from_local(v, mesh, [Shard(1)], run_check=False)
      output = orig_fn(q, k, v, **kwargs)          # ← passes DTensors into SDPA forward
      return output.to_local() if isinstance(output, DTensor) else output

ScaledDotProductAttention.forward transposes q/k/v (still DTensors) and then calls:

  with sdpa_kernel([CUDNN_ATTENTION, FLASH_ATTENTION, MATH], set_priority=True):
      out = F.scaled_dot_product_attention(q, k, v, ...)

The SDPA backend selector (_select_sdp_backend in C++) receives DTensor inputs. Both CUDNN and FLASH backends fail their constraint
checks (contiguity / stride requirements) against DTensors. MATH is selected as the last resort.

The MATH backend decomposes SDPA into constituent ops. It creates an intermediate attention bias as a plain CUDA tensor:

  attn_weight = q @ k.T          # DTensor @ DTensor → DTensor
  attn_weight += attn_bias        # aten.add(DTensor, plain_tensor) ← crash

DTensor's dispatcher for aten.add detects the mixed types and throws.

Finally, this should probably be a 1, but I was not able to test it.

Versions

Single script repro:

import torchtitan.hf_datasets.text_datasets
import torchtitan.components.lr_scheduler
import torchtitan.components.checkpoint
import torchtitan.components.dataloader
import torchtitan.components.optimizer
import torchtitan.components.validate
import torchtitan.components.metrics
import torchtitan.components.loss
import torchtitan.tools.logging
import torchtitan.models.common
import torchtitan.models.llama3
import torchtitan.protocols
import torchtitan.trainer
import torchtitan.config
import torchtitan

torchtitan.tools.logging.init_logger()

torchtitan.trainer.Trainer.Config(
    model_spec=torchtitan.models.llama3.model_registry("debugmodel", attn_backend="sdpa"),
    dataloader=torchtitan.hf_datasets.text_datasets.HuggingFaceTextDataLoader.Config(
        dataset="c4_test",
        num_workers=1,
    ),
    parallelism=torchtitan.trainer.ParallelismConfig(
        data_parallel_replicate_degree=1,
        data_parallel_shard_degree=-1,
        context_parallel_degree=4,
        tensor_parallel_degree=1,
        context_parallel_load_balancer=None,
    ),
    compile=torchtitan.trainer.CompileConfig(
        enable=False,
    ),
    optimizer=torchtitan.components.optimizer.OptimizersContainer.Config(
        lr=1e-4,
    ),
    training=torchtitan.config.TrainingConfig(
        local_batch_size=1,
        seq_len=512,
        steps=100,
        dtype="float32",
        mixed_precision_param="float32",
    ),
    loss = torchtitan.components.loss.CrossEntropyLoss.Config(),
    lr_scheduler=torchtitan.components.lr_scheduler.LRSchedulersContainer.Config(
        warmup_steps=10,
        decay_ratio=0,
        decay_type="linear",
        min_lr_factor=0.0,
    ),
    checkpoint=torchtitan.components.checkpoint.CheckpointManager.Config(
        interval=100,
        last_save_model_only=False,
        export_dtype="bfloat16",
        enable=True,
    ),
    activation_checkpoint=torchtitan.config.ActivationCheckpointConfig(
        mode="selective",
    ),
    metrics=torchtitan.components.metrics.MetricsProcessor.Config(
        enable_tensorboard=True,
        enable_wandb=False,
        log_freq=1,
    ),
    debug=torchtitan.config.configs.DebugConfig(
        seed=42,
        deterministic=True,
    ),
    validator=torchtitan.components.validate.Validator.Config(
        enable=True,
        freq=10,
        dataloader=torchtitan.hf_datasets.text_datasets.HuggingFaceTextDataLoader.Config(
            dataset="c4_test",
            dataset_path=None,
            infinite=False,
        ),
        steps=3,
    ),
    hf_assets_path="./tests/assets/tokenizer",
    dump_folder=".data/experiments/test",
).build().train()

Run with:

uv run torchrun --nproc-per-node 4 scripts/debug-llama3-sdpa-cp.py

torch version:
2.13.0.dev20260512+cu126

Metadata


Assignees

Labels

wontfix (This will not be worked on)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
