Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 142 additions & 72 deletions timm/layers/pos_embed_sincos.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,9 @@ def swap_shape_xy(seq: List[int]) -> List[int]:
return [seq[1], seq[0]] + list(seq[2:])


def build_fourier_pos_embed(
def _build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
bands: torch.Tensor,
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
Expand All @@ -100,46 +96,10 @@ def build_fourier_pos_embed(
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
"""

Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.

Returns:

"""
if bands is None:
if in_pixels:
bands = pixel_freq_bands(
num_bands,
float(max_res),
linear_bands=linear_bands,
device=device,
)
else:
bands = freq_bands(
num_bands,
temperature=temperature,
step=1,
device=device,
)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype

if grid_indexing == 'xy':
feat_shape = swap_shape_xy(feat_shape)
Expand Down Expand Up @@ -170,6 +130,92 @@ def build_fourier_pos_embed(
return out


def _compute_bands(
bands: Optional[torch.Tensor],
num_bands: int,
max_res: int,
temperature: float,
linear_bands: bool,
in_pixels: bool,
device: Optional[torch.device],
dtype: torch.dtype,
) -> torch.Tensor:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(
num_bands,
float(max_res),
linear_bands=linear_bands,
device=device,
)
else:
bands = freq_bands(
num_bands,
temperature=temperature,
step=1,
device=device,
)
return bands


def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
) -> List[torch.Tensor]:
"""

Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
grid_offset: Constant offset to add to grid for non-pixel freq.
grid_indexing: Indexing mode for meshgrid ('ij' or 'xy')
dtype: Output dtype.
device: Output device.

Returns:

"""
bands = _compute_bands(
bands=bands,
num_bands=num_bands,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
in_pixels=in_pixels,
device=device,
dtype=dtype,
)
return _build_fourier_pos_embed(
feat_shape,
bands,
include_grid,
in_pixels,
ref_feat_shape,
grid_offset,
grid_indexing,
device,
dtype,
)


class FourierEmbed(nn.Module):

def __init__(
Expand Down Expand Up @@ -206,7 +252,7 @@ def init_non_persistent_buffers(self) -> None:
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
emb = _build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
Expand Down Expand Up @@ -336,6 +382,35 @@ def apply_keep_indices_nlc(
return pos_embed.gather(-2, keep_indices)


def _build_rotary_pos_embed(
feat_shape: List[int],
bands: torch.Tensor,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
grid_offset: float = 0.,
grid_indexing: str = 'ij',
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
):
sin_emb, cos_emb = _build_fourier_pos_embed(
feat_shape,
bands=bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
grid_offset=grid_offset,
grid_indexing=grid_indexing,
device=device,
dtype=dtype,
)
num_spatial_dim = 1
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb


def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -369,27 +444,26 @@ def build_rotary_pos_embed(
Returns:

"""
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands = _compute_bands(
bands=bands,
num_bands=dim // 4,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
grid_offset=grid_offset,
grid_indexing=grid_indexing,
device=device,
dtype=dtype,
)
num_spatial_dim = 1
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
return _build_rotary_pos_embed(
feat_shape,
bands,
in_pixels,
ref_feat_shape,
grid_offset,
grid_indexing,
device,
dtype,
)


class RotaryEmbedding(nn.Module):
Expand Down Expand Up @@ -480,12 +554,10 @@ def _compute_bands(self, device=None, dtype=None):
return bands.to(device=device, dtype=dtype)

def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
emb_sin, emb_cos = build_rotary_pos_embed(
bands = self._compute_bands(device, dtype)
emb_sin, emb_cos = _build_rotary_pos_embed(
feat_shape=feat_shape,
dim=self.dim,
max_res=self.max_res,
temperature=self.temperature,
linear_bands=self.linear_bands,
bands=bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
Expand Down Expand Up @@ -514,7 +586,7 @@ def update_feat_shape(self, feat_shape: List[int]):
def get_embed(self, shape: Optional[List[int]] = None):
if shape is not None and self.bands is not None:
# rebuild embeddings every call, use if target shape changes
return build_rotary_pos_embed(
return _build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
Expand Down Expand Up @@ -614,12 +686,10 @@ def _compute_bands(self, device=None, dtype=None):
return bands.to(device=device, dtype=dtype)

def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32):
embeds = build_rotary_pos_embed(
bands = self._compute_bands(device, dtype)
embeds = _build_rotary_pos_embed(
feat_shape=feat_shape,
dim=self.dim,
max_res=self.max_res,
temperature=self.temperature,
linear_bands=self.linear_bands,
bands=bands,
in_pixels=self.in_pixels,
ref_feat_shape=self.ref_feat_shape,
grid_offset=self.grid_offset,
Expand Down Expand Up @@ -647,7 +717,7 @@ def update_feat_shape(self, feat_shape: List[int]):
def get_embed(self, shape: Optional[List[int]] = None):
if shape is not None and self.bands is not None:
# rebuild embeddings from cached bands every call, use if target shape changes
embeds = build_rotary_pos_embed(
embeds = _build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
Expand Down Expand Up @@ -691,7 +761,7 @@ def get_batch_embeds(
max_w = max(w for h, w in shapes)

# Generate embeddings for max size ONCE
sin_emb, cos_emb = build_rotary_pos_embed(
sin_emb, cos_emb = _build_rotary_pos_embed(
feat_shape=(max_h, max_w),
bands=self.bands,
in_pixels=self.in_pixels,
Expand Down