Skip to content

Commit e391ea1

Browse files
committed
add rabbit feedback
Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
1 parent 2931f61 commit e391ea1

File tree

2 files changed: +11 additions, -10 deletions

2 files changed: +11 additions, -10 deletions

modelopt/torch/quantization/model_calib.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,13 +486,12 @@ def get_error_func(self) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
486486
def local_hessian_error(x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
487487
"""Compute local Hessian-weighted error."""
488488
original_shape = x.shape
489-
dw = (x - xq).view(-1, 1, bs) # (num_blocks, 1, block_size)
490-
# Repeat hessian for each output channel
491-
hessian_expanded = hessian.repeat(
492-
cout, 1, 1
493-
) # (num_blocks, block_size, block_size)
494-
# Per-block loss: (num_blocks,)
495-
block_loss = (dw @ hessian_expanded @ dw.transpose(-1, -2)).squeeze(-1).squeeze(-1)
489+
# Reshape to (cout, num_blocks_per_cin, block_size)
490+
dw = (x - xq).view(cout, -1, bs)
491+
# Use einsum to avoid materializing cout-repeated Hessian
492+
# dw: (cout, n_blocks, bs), hessian: (n_blocks, bs, bs) -> (cout, n_blocks)
493+
block_loss = torch.einsum("cnb,nbd,cnd->cn", dw, hessian, dw)
494+
block_loss = block_loss.reshape(-1)
496495
error = block_loss.unsqueeze(-1).expand(-1, bs).reshape(original_shape)
497496
return error
498497

@@ -522,12 +521,14 @@ def forward(self, input, *args, **kwargs):
522521
# Setup helpers for all quantized linear modules
523522
name_to_module = dict(model.named_modules())
524523
weight_quantizers_info = []
524+
all_patched_modules = [] # Track all modules for cleanup (including disabled ones)
525525

526526
for name, module in name_to_module.items():
527527
if is_quantized_linear(module) and module.weight_quantizer.is_enabled:
528528
with enable_weight_access_and_writeback(module, model, name_to_module):
529529
module.local_hessian = LocalHessianHelper(module, name)
530530
module.local_hessian.setup()
531+
all_patched_modules.append((name, module))
531532
if module.local_hessian.is_enabled:
532533
weight_quantizers_info.append((name, module))
533534

@@ -619,7 +620,7 @@ def quant_func(x, amax, quantizer=weight_quantizer):
619620

620621
# Cleanup and free memory
621622
LocalHessianHelper.cache_mode = False
622-
for name, module in weight_quantizers_info:
623+
for name, module in all_patched_modules:
623624
module.local_hessian.cleanup()
624625

625626
print_rank_0("local_hessian: Calibration complete.")

tests/gpu/torch/quantization/test_quantize_cuda.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
mtq.NVFP4_AWQ_LITE_CFG,
8888
mtq.NVFP4_AWQ_CLIP_CFG,
8989
mtq.NVFP4_AWQ_FULL_CFG,
90-
mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG,
90+
mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
9191
mtq.MXFP8_DEFAULT_CFG,
9292
mtq.MXFP6_DEFAULT_CFG,
9393
mtq.MXFP4_DEFAULT_CFG,
@@ -114,7 +114,7 @@ def test_quantize(model_cls, config):
114114
mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
115115
NVFP4_WEIGHT_ACT_MSE_CFG,
116116
NVFP4_WEIGHT_MSE_FP8_SWEEP_CFG,
117-
mtq.NVFP4_LOCAL_HESSIAN_WEIGHT_ONLY_CFG,
117+
mtq.NVFP4_W4A4_WEIGHT_LOCAL_HESSIAN_CFG,
118118
]:
119119
if get_cuda_ext_mx() is None:
120120
pytest.skip("cuda_ext_mx is not available")

0 commit comments

Comments (0)