@@ -486,13 +486,12 @@ def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
    """Return the local Hessian-weighted error between ``x`` and ``xq``.

    Closes over ``cout`` (output channels), ``bs`` (block size) and
    ``hessian`` — assumed shape (n_blocks, bs, bs) — from the enclosing
    scope (TODO confirm against the enclosing `get_error_func`).
    """
    shape_in = x.shape
    # Group the elementwise residual into per-output-channel blocks:
    # (cout, n_blocks_per_cin, bs).
    delta = (x - xq).view(cout, -1, bs)
    # Per-(channel, block) quadratic form delta @ H @ delta^T, computed
    # with einsum so the Hessian is never materialized cout times.
    per_block = torch.einsum("cnb,nbd,cnd->cn", delta, hessian, delta).reshape(-1)
    # Broadcast each block's scalar loss over its bs elements and restore
    # the caller's original tensor shape.
    return per_block.unsqueeze(-1).expand(-1, bs).reshape(shape_in)
498497
@@ -522,12 +521,14 @@ def forward(self, input, *args, **kwargs):
522521 # Setup helpers for all quantized linear modules
523522 name_to_module = dict (model .named_modules ())
524523 weight_quantizers_info = []
524+ all_patched_modules = [] # Track all modules for cleanup (including disabled ones)
525525
526526 for name , module in name_to_module .items ():
527527 if is_quantized_linear (module ) and module .weight_quantizer .is_enabled :
528528 with enable_weight_access_and_writeback (module , model , name_to_module ):
529529 module .local_hessian = LocalHessianHelper (module , name )
530530 module .local_hessian .setup ()
531+ all_patched_modules .append ((name , module ))
531532 if module .local_hessian .is_enabled :
532533 weight_quantizers_info .append ((name , module ))
533534
@@ -619,7 +620,7 @@ def quant_func(x, amax, quantizer=weight_quantizer):
619620
620621 # Cleanup and free memory
621622 LocalHessianHelper .cache_mode = False
622- for name , module in weight_quantizers_info :
623+ for name , module in all_patched_modules :
623624 module .local_hessian .cleanup ()
624625
625626 print_rank_0 ("local_hessian: Calibration complete." )
0 commit comments