Skip to content

Commit 5383603

Browse files
committed
norm the correct decoder dimension
1 parent 764d4ac commit 5383603

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

trainers/matroyshka_batch_top_k.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ def forward(self, x: t.Tensor, output_features: bool = False):
7878
@t.no_grad()
7979
def set_decoder_norm_to_unit_norm(self):
8080
eps = t.finfo(self.W_dec.dtype).eps
81-
norm = t.norm(self.W_dec.data, dim=0, keepdim=True)
81+
norm = t.norm(self.W_dec.data, dim=1, keepdim=True)
82+
8283
self.W_dec.data /= norm + eps
8384

8485
@t.no_grad()

0 commit comments

Comments
 (0)