Skip to content

🐛 [Bug] Accuracy issue with dynamic shapes for rotary embeddings #3978

@zhaoyuanh

Description

@zhaoyuanh

Bug Description

from typing import Optional, Tuple
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch_tensorrt


class CosmosRotaryPosEmbed(nn.Module):
    """Build cos/sin rotary position tables over a (frames, height, width) grid.

    The channel budget ``hidden_size`` is split into three groups: height and
    width each receive ``hidden_size // 6 * 2`` channels and the temporal axis
    receives whatever is left. Every axis uses a base of ``10000`` multiplied by
    its own ``*_ntk_factor``, which is derived from ``rope_scale``.

    Args:
        hidden_size: Number of channels to distribute across the three axes.
        max_size: Per-axis upper bounds, ordered like ``patch_size``. They are
            divided by the patch size; the largest quotient sets the length of
            the position index that ``forward`` slices from.
        patch_size: Per-axis patch size; ``forward`` applies index 0 to frames,
            index 1 to height and index 2 to width.
        base_fps: Reference frame rate used to rescale temporal positions when
            ``fps`` is passed to ``forward``.
        rope_scale: Per-axis scale; index 0 feeds the temporal factor, index 1
            the height factor and index 2 the width factor.
    """

    def __init__(
        self,
        hidden_size: int,
        max_size: Tuple[int, int, int] = (128, 240, 240),
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        base_fps: int = 24,
        rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
    ) -> None:
        super().__init__()

        self.max_size = [extent // step for extent, step in zip(max_size, patch_size)]
        self.patch_size = patch_size
        self.base_fps = base_fps

        # Height and width get equal shares; time takes the remainder.
        spatial_dim = hidden_size // 6 * 2
        self.dim_h = spatial_dim
        self.dim_w = spatial_dim
        self.dim_t = hidden_size - self.dim_h - self.dim_w

        # Per-axis multiplier applied to the 10000 base in forward().
        self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
        self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
        self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))

    @staticmethod
    def _inverse_freqs(theta: float, dim: int, device: torch.device) -> torch.Tensor:
        """Return ``1 / theta ** (k / dim)`` for even ``k`` in ``[0, dim)``, capped at ``dim // 2`` entries."""
        exponents = torch.arange(0, dim, 2, device=device, dtype=torch.float32)[: (dim // 2)] / dim
        return 1.0 / (theta**exponents)

    def forward(
        self,
        hidden_states: torch.Tensor,
        fps: Optional[int] = None,
        num_ranks: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the cos/sin tables for the grid implied by ``hidden_states``.

        Args:
            hidden_states: 5-D tensor unpacked as
                ``(batch, channels, frames, height, width)``; only its shape
                and device are used.
            fps: When ``None`` the temporal positions are used unscaled (the
                image case); otherwise they are multiplied by
                ``base_fps / fps`` (the video case).
            num_ranks: When given, the temporal extent is multiplied by
                ``num_ranks`` before the table is built and only the first of
                ``num_ranks`` equal row chunks is returned.

        Returns:
            ``(cos, sin)`` float32 tensors with one row per grid position
            (frames, then height, then width, flattened). Each row holds the
            temporal, height and width angles concatenated, repeated twice.
        """
        _, _, num_frames, height, width = hidden_states.shape
        grid_t = num_frames // self.patch_size[0]
        grid_h = height // self.patch_size[1]
        grid_w = width // self.patch_size[2]
        if num_ranks is not None:
            grid_t = grid_t * num_ranks

        device = hidden_states.device
        positions = torch.arange(max(self.max_size), device=device, dtype=torch.float32)

        freqs_h = self._inverse_freqs(10000.0 * self.h_ntk_factor, self.dim_h, device)
        freqs_w = self._inverse_freqs(10000.0 * self.w_ntk_factor, self.dim_w, device)
        freqs_t = self._inverse_freqs(10000.0 * self.t_ntk_factor, self.dim_t, device)

        # Use expand() instead of repeat() for torch_tensorrt compatibility
        angles_h = torch.outer(positions[:grid_h], freqs_h)
        angles_h = angles_h[None, :, None, :].expand(grid_t, -1, grid_w, -1)
        angles_w = torch.outer(positions[:grid_w], freqs_w)
        angles_w = angles_w[None, None, :, :].expand(grid_t, grid_h, -1, -1)

        # Temporal positions are rescaled only when a frame rate is supplied.
        t_positions = positions[:grid_t]
        if fps is not None:
            t_positions = t_positions / fps * self.base_fps
        angles_t = torch.outer(t_positions, freqs_t)
        angles_t = angles_t[:, None, None, :].expand(-1, grid_h, grid_w, -1)

        per_axis = [angles_t, angles_h, angles_w]
        freqs = torch.cat(per_axis + per_axis, dim=-1).flatten(0, 2).float()
        cos, sin = torch.cos(freqs), torch.sin(freqs)

        if num_ranks is not None:
            # Split the rows into num_ranks equal chunks and keep the first one.
            cos = cos.view(num_ranks, cos.shape[0] // num_ranks, *cos.shape[1:])[0]
            sin = sin.view(num_ranks, sin.shape[0] // num_ranks, *sin.shape[1:])[0]
        return cos, sin


def export_attention(model, hidden_states, fps, num_ranks):
    with torch.no_grad():
        # Only mark sequence length as dynamic, like run_llm.py does
        # Don't mark batch dimension as dynamic to avoid constraint violations
        seq_len = torch.export.Dim("seq_len", min=1, max=16)
        print("Trying to export the model using torch.export.export()..")
        # strict=False only enables autograd tracing and excludes dynamo.
        # Use tuple format like export_llm - only mark sequence length (dim 1) as dynamic
        ep = torch.export.export(
            model,
            args=(hidden_states, fps, num_ranks),
            kwargs={},
            dynamic_shapes=({2: seq_len}, None, None), 
            strict=False,
        )

    return ep


def compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug):
    """Export ``model`` and compile the exported program with Torch-TensorRT.

    Args:
        model: Module handed to ``export_attention``.
        hidden_states: Example tensor input, used for both export and compile.
        fps: Example second positional argument.
        num_ranks: Example third positional argument.
        min_block_size: Forwarded to ``torch_tensorrt.dynamo.compile``.
        debug: When truthy, compilation runs inside
            ``torch_tensorrt.logging.debug()`` and ``debug=True`` is forwarded
            to the compiler; otherwise a no-op context is used.

    Returns:
        The module returned by ``torch_tensorrt.dynamo.compile``.
    """
    exported_program = export_attention(model, hidden_states, fps, num_ranks)

    # Precision flags and the remaining compiler settings, gathered in one place.
    compile_options = {
        "enabled_precisions": {torch.bfloat16},
        "use_explicit_typing": False,
        "use_fp32_acc": False,
        "disable_tf32": True,
        "use_python_runtime": True,
        "debug": debug,
        "offload_module_to_cpu": False,
        "min_block_size": min_block_size,
    }

    logging_scope = torch_tensorrt.logging.debug() if debug else nullcontext()
    with logging_scope:
        return torch_tensorrt.dynamo.compile(
            exported_program,
            inputs=[hidden_states, fps, num_ranks],
            **compile_options,
        )


if __name__ == "__main__":
    # Reproduction driver: run the rotary embedding in eager PyTorch, compile
    # it with Torch-TensorRT, run the compiled module on the same inputs, and
    # report the largest absolute element-wise difference for cos and sin.
    device = "cuda"
    debug = False
    enable_pytorch_run = True
    min_block_size = 1
    attention_head_dim = 128

    # hidden_size = num_attention_heads * attention_head_dim

    with torch.inference_mode():
        model = CosmosRotaryPosEmbed(
            hidden_size=attention_head_dim,
            max_size=(128, 240, 240),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
        ).to(device)

        # Cast the module and build the example input in bfloat16.
        input_dtype = torch.bfloat16
        model = model.to(torch.bfloat16)

        # Example input: random values created first, then moved to the device.
        hidden_states = torch.randn(1, 17, 8, 88, 160, dtype=input_dtype).to(device)
        fps = 30
        num_ranks = 2

        # Eager PyTorch reference.
        pyt_output_cos, pyt_output_sin = model(hidden_states, fps, num_ranks)
        print("PyTorch output shape:", pyt_output_cos.shape, pyt_output_sin.shape)
        print("Pytorch output:", pyt_output_cos.flatten(), pyt_output_sin.flatten())

        # Torch-TensorRT path via torch.export + torch_tensorrt.dynamo.compile.
        # (The original report also kept a commented-out alternative using
        # torch.compile(model, backend="torch_tensorrt", options={
        #     "enabled_precisions": {input_dtype}, "use_python_runtime": True,
        #     "min_block_size": min_block_size}, dynamic=None).)
        trt_model = compile_torchtrt(model, hidden_states, fps, num_ranks, min_block_size, debug)
        trt_model = trt_model.to(device)

        trt_output_cos, trt_output_sin = trt_model(hidden_states, fps, num_ranks)
        print("TensorRT output shape:", trt_output_cos.shape, trt_output_sin.shape)
        print("TensorRT output:", trt_output_cos.flatten(), trt_output_sin.flatten())

    # Largest absolute deviation between the two runs, per output.
    diff_cos = (pyt_output_cos - trt_output_cos).abs().max().item()
    diff_sin = (pyt_output_sin - trt_output_sin).abs().max().item()
    print(f"Max difference between PyTorch and TRT: {diff_cos}, {diff_sin}")

    # A deviation counts as a match only when it is strictly below the tolerance.
    tolerance = 0.01
    print(
        f"✅ Results match! (difference: {diff_cos} < {tolerance})"
        if diff_cos < tolerance
        else f"⚠️  Results differ! (difference: {diff_cos} >= {tolerance})"
    )
    print(
        f"✅ Results match! (difference: {diff_sin} < {tolerance})"
        if diff_sin < tolerance
        else f"⚠️  Results differ! (difference: {diff_sin} >= {tolerance})"
    )

Here is the result:

⚠️  Results differ! (difference: 0.11663024872541428 >= 0.01)
⚠️  Results differ! (difference: 0.1167231947183609 >= 0.01)

When I disable dynamic shapes (for example, by commenting out `dynamic_shapes=({2: seq_len}, None, None)`), the Torch-TRT results match the PyTorch results.

To Reproduce

Steps to reproduce the behavior:

  1. Run the Python script above.

Expected behavior

The Torch-TRT output matches the Torch output.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.10.0.dev0
  • PyTorch Version (e.g. 1.0): 2.9.0
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): PYTHON_ONLY=1 pip install -e .
  • Are you using local sources or building from archives: No
  • Python version: 3.10
  • CUDA version: 12.9
  • GPU models and configuration: Nvidia B200
  • Any other relevant information: None

Additional context

Metadata

Metadata

Assignees

Labels

bug (Something isn't working) · story: LLM & Generative AI (Large language models (GPT2, Llama, Mistral, Qwen), diffusion models (FLUX, SD), VLMs, MoE, attentio…)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions