Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 33 additions & 13 deletions monai/networks/blocks/crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

Expand All @@ -44,6 +44,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
Expand All @@ -62,6 +63,7 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
Comment thread
virginiafdez marked this conversation as resolved.
Outdated
"""

super().__init__()
Expand All @@ -81,6 +83,17 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False"
)

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.context_input_size = context_input_size if context_input_size else hidden_size
Expand All @@ -101,6 +114,7 @@ def __init__(

self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
Expand Down Expand Up @@ -145,23 +159,29 @@ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
q = q.view(b, t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, t, hs)
k = k.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
v = v.view(b, kv_t, self.num_heads, c // self.num_heads).transpose(1, 2) # (b, nh, kv_t, hs)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
if self.use_flash_attention:
x = torch.nn.functional.scaled_dot_product_attention(q, k, v).contiguous()
else:

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
# apply relative positional embedding if defined
att_mat = (
self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
Comment thread
KumoLiu marked this conversation as resolved.
Outdated
)
Comment thread
virginiafdez marked this conversation as resolved.
Outdated

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
Expand Down
46 changes: 33 additions & 13 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import
from monai.utils import optional_import, pytorch_after

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")

Expand All @@ -42,6 +43,7 @@ def __init__(
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
"""
Args:
Expand All @@ -59,6 +61,7 @@ def __init__(
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

"""

Expand All @@ -82,6 +85,17 @@ def __init__(
if causal and sequence_length is None:
raise ValueError("sequence_length is necessary for causal attention.")

if use_flash_attention and not pytorch_after(minor=13, major=1, patch=0):
raise ValueError(
"use_flash_attention is only supported for PyTorch versions >= 2.0."
"Upgrade your PyTorch or set the flag to False."
)
if use_flash_attention and save_attn:
raise ValueError(
"save_attn has been set to True, but use_flash_attention is also set"
"to True. save_attn can only be used if use_flash_attention is False"
)

self.num_heads = num_heads
self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
Expand All @@ -97,6 +111,7 @@ def __init__(
self.attention_dtype = attention_dtype
self.causal = causal
self.sequence_length = sequence_length
self.use_flash_attention = use_flash_attention

if causal and sequence_length is not None:
# causal mask to ensure that attention is only applied to the left in the input sequence
Expand Down Expand Up @@ -130,23 +145,28 @@ def forward(self, x):
q = q.to(self.attention_dtype)
k = k.to(self.attention_dtype)

att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
if self.use_flash_attention:
x = F.scaled_dot_product_attention(q, k, v)
Comment thread
KumoLiu marked this conversation as resolved.
Outdated
else:
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
# apply relative positional embedding if defined
att_mat = (
self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
Comment thread
ericspod marked this conversation as resolved.
Outdated
)

if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
if self.causal:
att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[1], : x.shape[1]] == 0, float("-inf"))
Comment thread
virginiafdez marked this conversation as resolved.
Outdated

att_mat = att_mat.softmax(dim=-1)
att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()
if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
self.att_mat = att_mat.detach()

att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = self.out_rearrange(x)
x = self.out_proj(x)
x = self.drop_output(x)
Expand Down
8 changes: 7 additions & 1 deletion monai/networks/blocks/spatialattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class SpatialAttentionBlock(nn.Module):
num_channels: number of input channels. Must be divisible by num_head_channels.
num_head_channels: number of channels per head.
attention_dtype: cast attention operations to this dtype.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

"""

Expand All @@ -44,6 +45,7 @@ def __init__(
norm_num_groups: int = 32,
norm_eps: float = 1e-6,
attention_dtype: Optional[torch.dtype] = None,
use_flash_attention: bool = False,
) -> None:
super().__init__()

Expand All @@ -54,7 +56,11 @@ def __init__(
raise ValueError("num_channels must be divisible by num_head_channels")
num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
self.attn = SABlock(
hidden_size=num_channels, num_heads=num_heads, qkv_bias=True, attention_dtype=attention_dtype
hidden_size=num_channels,
num_heads=num_heads,
qkv_bias=True,
attention_dtype=attention_dtype,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor):
Expand Down
10 changes: 9 additions & 1 deletion monai/networks/blocks/transformerblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
causal: bool = False,
sequence_length: int | None = None,
with_cross_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
"""
Args:
Expand All @@ -45,6 +46,7 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

"""

Expand All @@ -66,13 +68,19 @@ def __init__(
save_attn=save_attn,
causal=causal,
sequence_length=sequence_length,
use_flash_attention=use_flash_attention,
)
self.norm2 = nn.LayerNorm(hidden_size)
self.with_cross_attention = with_cross_attention

self.norm_cross_attn = nn.LayerNorm(hidden_size)
self.cross_attn = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, qkv_bias=qkv_bias, causal=False
hidden_size=hidden_size,
num_heads=num_heads,
dropout_rate=dropout_rate,
qkv_bias=qkv_bias,
causal=False,
use_flash_attention=use_flash_attention,
)

def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
Expand Down
4 changes: 4 additions & 0 deletions monai/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class DiffusionUNetTransformerBlock(nn.Module):
dropout: dropout probability to use.
cross_attention_dim: size of the context vector for cross attention.
upcast_attention: if True, upcast attention operations to full precision.
use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.

"""

Expand All @@ -77,6 +78,7 @@ def __init__(
dropout: float = 0.0,
cross_attention_dim: int | None = None,
upcast_attention: bool = False,
use_flash_attention: bool = False,
) -> None:
super().__init__()
self.attn1 = SABlock(
Expand All @@ -86,6 +88,7 @@ def __init__(
dim_head=num_head_channels,
dropout_rate=dropout,
attention_dtype=torch.float if upcast_attention else None,
use_flash_attention=use_flash_attention,
)
self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
self.attn2 = CrossAttentionBlock(
Expand All @@ -96,6 +99,7 @@ def __init__(
dim_head=num_head_channels,
dropout_rate=dropout,
attention_dtype=torch.float if upcast_attention else None,
use_flash_attention=use_flash_attention,
)
self.norm1 = nn.LayerNorm(num_channels)
self.norm2 = nn.LayerNorm(num_channels)
Expand Down
16 changes: 15 additions & 1 deletion tests/test_crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

einops, has_einops = optional_import("einops")

Expand Down Expand Up @@ -50,10 +51,16 @@ class TestResBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_CABLOCK)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
Comment thread
ericspod marked this conversation as resolved.
# Without flash attention
net = CrossAttentionBlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"]))
self.assertEqual(result.shape, expected_shape)
# With flash attention
net = CrossAttentionBlock(**input_param, use_flash_attention=True)
Comment thread
virginiafdez marked this conversation as resolved.
Outdated
with eval_mode(net):
result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"]))
self.assertEqual(result.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
Expand All @@ -62,6 +69,13 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

@SkipIfBeforePyTorchVersion((1, 13))
def test_save_attn_with_flash_attention(self):
with self.assertRaises(ValueError):
CrossAttentionBlock(
hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
)

@skipUnless(has_einops, "Requires einops")
def test_attention_dim_not_multiple_of_heads(self):
with self.assertRaises(ValueError):
Expand Down Expand Up @@ -119,7 +133,7 @@ def test_access_attn_matrix(self):
# no of elements is zero
assert no_matrix_acess_blk.att_mat.nelement() == 0

# be able to acess the attention matrix
# be able to acess the attention matrix.
matrix_acess_blk = CrossAttentionBlock(
hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.networks.blocks.selfattention import SABlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
from tests.utils import SkipIfBeforePyTorchVersion

einops, has_einops = optional_import("einops")

Expand Down Expand Up @@ -49,11 +50,17 @@ class TestResBlock(unittest.TestCase):

@parameterized.expand(TEST_CASE_SABLOCK)
@skipUnless(has_einops, "Requires einops")
@SkipIfBeforePyTorchVersion((0, 2))
Comment thread
virginiafdez marked this conversation as resolved.
Outdated
def test_shape(self, input_param, input_shape, expected_shape):
net = SABlock(**input_param)
with eval_mode(net):
result = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
# With flash attention
net_fa = SABlock(**input_param, use_flash_attention=True)
with eval_mode(net):
result_fa = net_fa(torch.randn(input_shape))
self.assertEqual(result_fa.shape, expected_shape)

def test_ill_arg(self):
with self.assertRaises(ValueError):
Expand All @@ -62,6 +69,11 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)

@SkipIfBeforePyTorchVersion((1, 13))
def test_save_attn_with_flash_attention(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)

def test_attention_dim_not_multiple_of_heads(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)
Expand Down