Skip to content

Commit ce79bbe

Browse files
NimraisMykola LukashchoukScitator
authored
Feature: A new contrastive loss (Barlow Twins) (#1259)
* Add Barlow Twins loss as a new contrastive loss * update CHANGELOG.md * Add BarlowTwinsLoss into catalyst/contrib/nn/criterion/__init__.py * Add example with simple test * Delete test code * Barlow Twins cross-correlation matrix - a laconic way of off-diagonal element selection * handle zero variance * Example from >>> to code-block * add explicit BarlowTwinsLoss init check into test_criterion_init * Add simple test for Barlow Twins loss * typo std -> var * i.i.d. distributed and normalized * delete unbiased from torch.var * lambda influence testing * hidden trailing whitespace * Update catalyst/contrib/nn/criterion/contrastive.py lmbda -> lambda Co-authored-by: Sergey Kolesnikov <scitator@gmail.com> * rename parameter lmbda -> offdiag_lambda * Add ValueErrors into BarlowTwinsLoss * laconic example in BarlowTwinsLoss * rebase * Fixed typo Co-authored-by: Mykola Lukashchouk <mykola@Mac-mini-Mykola.local> Co-authored-by: Sergey Kolesnikov <scitator@gmail.com>
1 parent ede122d commit ce79bbe

4 files changed

Lines changed: 182 additions & 11 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1010

1111
- added `pre-commit` hook to run codestyle checker on commit ([#1257](https://github.com/catalyst-team/catalyst/pull/1257))
1212
- `on publish` github action for docker and docs added [#1260](https://github.com/catalyst-team/catalyst/pull/1260)
13+
- Barlow twins loss ([#1259](https://github.com/catalyst-team/catalyst/pull/1259))
1314

1415
### Changed
1516

catalyst/contrib/nn/criterion/__init__.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,40 +5,35 @@
55

66
from catalyst.contrib.nn.criterion.ce import (
77
MaskCrossEntropyLoss,
8-
SymmetricCrossEntropyLoss,
98
NaiveCrossEntropyLoss,
9+
SymmetricCrossEntropyLoss,
1010
)
1111
from catalyst.contrib.nn.criterion.circle import CircleLoss
1212
from catalyst.contrib.nn.criterion.contrastive import (
13+
BarlowTwinsLoss,
1314
ContrastiveDistanceLoss,
1415
ContrastiveEmbeddingLoss,
1516
ContrastivePairwiseEmbeddingLoss,
1617
)
1718
from catalyst.contrib.nn.criterion.dice import DiceLoss
18-
from catalyst.contrib.nn.criterion.focal import (
19-
FocalLossBinary,
20-
FocalLossMultiClass,
21-
)
22-
from catalyst.contrib.nn.criterion.gan import (
23-
GradientPenaltyLoss,
24-
MeanOutputLoss,
25-
)
19+
from catalyst.contrib.nn.criterion.focal import FocalLossBinary, FocalLossMultiClass
20+
from catalyst.contrib.nn.criterion.gan import GradientPenaltyLoss, MeanOutputLoss
2621

2722
if torch.__version__ < "1.9":
2823
from catalyst.contrib.nn.criterion.huber import HuberLoss
2924

3025
from catalyst.contrib.nn.criterion.iou import IoULoss
31-
from catalyst.contrib.nn.criterion.trevsky import TrevskyLoss, FocalTrevskyLoss
3226
from catalyst.contrib.nn.criterion.lovasz import (
3327
LovaszLossBinary,
3428
LovaszLossMultiClass,
3529
LovaszLossMultiLabel,
3630
)
3731
from catalyst.contrib.nn.criterion.margin import MarginLoss
32+
from catalyst.contrib.nn.criterion.trevsky import FocalTrevskyLoss, TrevskyLoss
3833
from catalyst.contrib.nn.criterion.triplet import (
3934
TripletLoss,
4035
TripletLossV2,
41-
TripletPairwiseEmbeddingLoss,
4236
TripletMarginLossWithSampler,
37+
TripletPairwiseEmbeddingLoss,
4338
)
4439
from catalyst.contrib.nn.criterion.wing import WingLoss

catalyst/contrib/nn/criterion/contrastive.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,96 @@ def forward(self, embeddings_pred, embeddings_true) -> torch.Tensor:
135135
return loss
136136

137137

138+
class BarlowTwinsLoss(nn.Module):
    """The Barlow Twins contrastive embedding loss.

    It has been proposed in `Barlow Twins:
    Self-Supervised Learning via Redundancy Reduction`_.

    Both embedding batches are standardized per feature, their
    cross-correlation matrix is computed, and the loss pushes the diagonal
    of that matrix towards one and the off-diagonal elements towards zero.

    Example:

    .. code-block:: python

        import torch
        from torch.nn import functional as F
        from catalyst.contrib.nn import BarlowTwinsLoss

        embeddings_left = F.normalize(torch.rand(256, 64, requires_grad=True))
        embeddings_right = F.normalize(torch.rand(256, 64, requires_grad=True))
        criterion = BarlowTwinsLoss(offdiag_lambda=1)
        criterion(embeddings_left, embeddings_right)

    .. _`Barlow Twins: Self-Supervised Learning via Redundancy Reduction`:
        https://arxiv.org/abs/2103.03230
    """

    def __init__(self, offdiag_lambda: float = 1.0, eps: float = 1e-12):
        """
        Args:
            offdiag_lambda: trade-off parameter, the weight of the
                off-diagonal (redundancy reduction) term of the loss
            eps: shift for the variance (var + eps), keeps the
                normalization finite for zero-variance features
        """
        super().__init__()
        self.offdiag_lambda = offdiag_lambda
        self.eps = eps

    def _standardize(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Center every feature over the batch and scale it by sqrt(var + eps)."""
        # NOTE: ``var`` is called with its default estimator, while the
        # cross-correlation in ``forward`` is divided by ``batch_size``;
        # this combination is kept as is because it defines the loss values.
        centered = embeddings - embeddings.mean(dim=0)
        return centered / (embeddings.var(dim=0) + self.eps).pow(1 / 2)

    def forward(
        self, embeddings_left: torch.Tensor, embeddings_right: torch.Tensor,
    ) -> torch.Tensor:
        """Forward propagation method for the contrastive loss.

        Args:
            embeddings_left: left objects embeddings [batch_size, features_dim]
            embeddings_right: right objects embeddings [batch_size, features_dim]

        Raises:
            ValueError: if the batch size is less than 2
            ValueError: if embeddings_left and embeddings_right shapes are different
            ValueError: if embeddings shapes are not in a form (batch_size, features_dim)

        Returns:
            torch.Tensor: loss
        """
        shape_left, shape_right = embeddings_left.shape, embeddings_right.shape
        if len(shape_left) != 2:
            raise ValueError(
                f"Left shape should be (batch_size, feature_dim), but got - {shape_left}!"
            )
        if len(shape_right) != 2:
            raise ValueError(
                f"Right shape should be (batch_size, feature_dim), but got - {shape_right}!"
            )
        # an empty batch is as meaningless as a single sample: the batch
        # statistics are undefined and the loss would silently become NaN
        if shape_left[0] < 2:
            raise ValueError(f"Batch size should be >= 2, but got - {shape_left[0]}!")
        if shape_left != shape_right:
            raise ValueError(f"Shapes should be equall, but got - {shape_left} and {shape_right}!")

        # normalization
        z_left = self._standardize(embeddings_left)
        z_right = self._standardize(embeddings_right)

        # cross-correlation matrix [features_dim, features_dim]
        batch_size = shape_left[0]
        cross_correlation = torch.matmul(z_left.T, z_right) / batch_size

        # the loss described in the original Barlow Twins paper:
        # encouraging on-diagonal elements to be one and off-diagonal to be zero.
        # ``torch.diagonal`` returns a view, so only out-of-place operations
        # are applied to it to leave ``cross_correlation`` untouched.
        on_diag_loss = (torch.diagonal(cross_correlation) - 1).pow(2).sum()
        # the in-place diagonal fill acts on a fresh copy of the matrix
        off_diag = cross_correlation.clone().fill_diagonal_(0)
        off_diag_loss = off_diag.pow(2).sum()

        return on_diag_loss + self.offdiag_lambda * off_diag_loss
223+
224+
138225
__all__ = [
139226
"ContrastiveEmbeddingLoss",
140227
"ContrastiveDistanceLoss",
141228
"ContrastivePairwiseEmbeddingLoss",
229+
"BarlowTwinsLoss",
142230
]

tests/catalyst/contrib/nn/test_criterion.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
# flake8: noqa
2+
import numpy as np
3+
import pytest
4+
import torch
25

36
from catalyst.contrib.nn import criterion as module
47
from catalyst.contrib.nn.criterion import CircleLoss, TripletMarginLossWithSampler
8+
from catalyst.contrib.nn.criterion.contrastive import BarlowTwinsLoss
59
from catalyst.data import AllTripletsSampler
610

711

@@ -13,6 +17,8 @@ def test_criterion_init():
1317
instance = module_class(margin=0.25, gamma=256)
1418
elif module_class == TripletMarginLossWithSampler:
1519
instance = module_class(margin=1.0, sampler_inbatch=AllTripletsSampler())
20+
elif module_class == BarlowTwinsLoss:
21+
instance = module_class(offdiag_lambda=1, eps=1e-12)
1622
else:
1723
# @TODO: very dirty trick
1824
try:
@@ -21,3 +27,84 @@ def test_criterion_init():
2127
print(module_class)
2228
instance = 1
2329
assert instance is not None
30+
31+
32+
@pytest.mark.parametrize(
    "embeddings_left,embeddings_right,offdiag_lambda,eps,true_value",
    (
        # identical 2x2 one-hot batches, equal weight of both loss terms
        (
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            1,
            1e-12,
            1,
        ),
        # same batches, off-diagonal term switched off
        (
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            0,
            1e-12,
            0.5,
        ),
        # same batches, off-diagonal term counted twice
        (
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            torch.tensor([[1.0, 0.0], [0.0, 1.0]]),
            2,
            1e-12,
            1.5,
        ),
        # identical single-feature batches of ten samples
        (
            torch.tensor(
                [
                    [-0.31887834],
                    [1.3980029],
                    [0.30775256],
                    [0.29397671],
                    [-1.47968253],
                    [-0.72796992],
                    [-0.30937596],
                    [1.16363952],
                    [-2.15524895],
                    [-0.0440765],
                ]
            ),
            torch.tensor(
                [
                    [-0.31887834],
                    [1.3980029],
                    [0.30775256],
                    [0.29397671],
                    [-1.47968253],
                    [-0.72796992],
                    [-0.30937596],
                    [1.16363952],
                    [-2.15524895],
                    [-0.0440765],
                ]
            ),
            1,
            1e-12,
            0.01,
        ),
    ),
)
def test_barlow_twins_loss(
    embeddings_left: torch.Tensor,
    embeddings_right: torch.Tensor,
    offdiag_lambda: float,
    eps: float,
    true_value: float,
):
    """
    Check the Barlow Twins loss against precomputed reference values.

    Args:
        embeddings_left: left objects embeddings [batch_size, features_dim]
        embeddings_right: right objects embeddings [batch_size, features_dim]
        offdiag_lambda: trade-off parameter
        eps: zero variance handler (var + eps)
        true_value: expected loss value
    """
    criterion = BarlowTwinsLoss(offdiag_lambda=offdiag_lambda, eps=eps)
    loss_tensor = criterion(embeddings_left, embeddings_right)
    computed = loss_tensor.item()
    assert np.isclose(computed, true_value)

0 commit comments

Comments
 (0)