Skip to content

Commit bbc6558

Browse files
committed
Simply reference weight handling for custom vs default blocks
1 parent a4f3d79 commit bbc6558

1 file changed

Lines changed: 7 additions & 12 deletions

File tree

src/open_clip/transformer.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ def __init__(
264264
]))
265265
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
266266

267+
def get_reference_weight(self):
268+
return self.mlp.c_fc.weight
269+
267270
def attention(
268271
self,
269272
q_x: torch.Tensor,
@@ -516,18 +519,10 @@ def __init__(
516519
])
517520

518521
def get_cast_dtype(self) -> torch.dtype:
519-
# Handle both ResidualAttentionBlock and CustomResidualAttentionBlock
520-
if hasattr(self.resblocks[0], 'get_reference_weight'):
521-
# CustomResidualAttentionBlock has get_reference_weight method
522-
weight = self.resblocks[0].get_reference_weight()
523-
if hasattr(weight, 'int8_original_dtype'):
524-
return weight.int8_original_dtype
525-
return weight.dtype
526-
else:
527-
# ResidualAttentionBlock
528-
if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'):
529-
return self.resblocks[0].mlp.c_fc.int8_original_dtype
530-
return self.resblocks[0].mlp.c_fc.weight.dtype
522+
weight = self.resblocks[0].get_reference_weight()
523+
if hasattr(weight, 'int8_original_dtype'):
524+
return weight.int8_original_dtype
525+
return weight.dtype
531526

532527
def forward_intermediates(
533528
self,

0 commit comments

Comments
 (0)