File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -264,6 +264,9 @@ def __init__(
264264 ]))
265265 self .ls_2 = LayerScale (d_model , ls_init_value ) if ls_init_value is not None else nn .Identity ()
266266
267+ def get_reference_weight (self ):
268+ return self .mlp .c_fc .weight
269+
267270 def attention (
268271 self ,
269272 q_x : torch .Tensor ,
@@ -516,18 +519,10 @@ def __init__(
516519 ])
517520
518521 def get_cast_dtype (self ) -> torch .dtype :
519- # Handle both ResidualAttentionBlock and CustomResidualAttentionBlock
520- if hasattr (self .resblocks [0 ], 'get_reference_weight' ):
521- # CustomResidualAttentionBlock has get_reference_weight method
522- weight = self .resblocks [0 ].get_reference_weight ()
523- if hasattr (weight , 'int8_original_dtype' ):
524- return weight .int8_original_dtype
525- return weight .dtype
526- else :
527- # ResidualAttentionBlock
528- if hasattr (self .resblocks [0 ].mlp .c_fc , 'int8_original_dtype' ):
529- return self .resblocks [0 ].mlp .c_fc .int8_original_dtype
530- return self .resblocks [0 ].mlp .c_fc .weight .dtype
522+ weight = self .resblocks [0 ].get_reference_weight ()
523+ if hasattr (weight , 'int8_original_dtype' ):
524+ return weight .int8_original_dtype
525+ return weight .dtype
531526
532527 def forward_intermediates (
533528 self ,
You can’t perform that action at this time.
0 commit comments