
[Draft] QAD algorithms - SmoothLAQ, AdaRound #1150

Closed
realAsma wants to merge 33 commits into main from asma/QAD-new-algorithms

Conversation

@realAsma
Contributor

@realAsma realAsma commented Mar 31, 2026

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added post-training quantization (PTQ) script for FP4 quantization of language models.
    • Introduced advanced quantization algorithms: AdaRound, SmoothLSQ, LSQ, LAQ, SmoothLAQ with learnable scale support.
    • Added AdaRound and quantization-error regularization training options.
    • Enabled Liger fused kernel support for improved training performance.
    • Added local JSONL dataset loading with tokenization caching for QAT.
    • Introduced per-parameter learning rate configuration via YAML patterns.
  • Documentation

    • Added comprehensive QAT README with dataset and caching examples.
  • Tests

    • Added validation tests for AdaRound training, Liger fused loss, and QERR regularization.

realAsma and others added 30 commits February 28, 2026 13:04
Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>

Respect narrow_range in IntCastSTEFunction for scale learning

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor

Fix scale_after_dequant for non-NVFP4 quantizers

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
…quant

- Broaden local_hessian_calibrate to handle INT block quant (not just FP4)
- Support mse, local_hessian, and max methods in scale_after_dequant
- Add _convert_to_static_block_quantizers helper for max_calibrate path

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
- Refactor utils.py to support local JSONL datasets (files/dirs) via
  --dataset and tokenized dataset caching via --dataset_cache_path
- Normalize Daring-Anteater conversations to standard messages format
- Add distributed-aware tokenization with per-rank sharding and merging
- Wire dataset_cache_path through launch.sh and main.py DataArguments
- Update README with local JSONL dataset and caching examples
- Remove unused NVFP4StaticQuantizer import in model_calib.py
- Fix import ordering in vllm plugin and test_quantize_api

Made-with: Cursor
Cast _per_block_scale and _per_tensor_scale to float32 before scale
computation, then cast the final scale back to the input dtype. This
prevents mixed-precision issues during fake quantization with learned
scales.

Made-with: Cursor
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
- Simplify KDTrainer by assigning compute_loss_func in __init__ instead
  of overriding train() and compute_loss(); remove stale methods.
- Fix QADTrainer MRO to (KDTrainer, QATTrainer) so KD loss takes
  precedence during training.
- Guard modelopt state restore with `not is_quantized()` to avoid
  double-restoring an already-quantized model.
- Use tq.to() instead of tq.to_empty() in FSDP2 prepare patch.
- Remove unnecessary float32 dtype casts in StaticBlockScaleQuantizer.
- Add attn_implementation arg, fineweb_edu pretraining dataset,
  fp32 model loading, and pretrain tokenizer to LLM QAT example.

Made-with: Cursor
Signed-off-by: realAsma <akuriparambi@nvidia.com>
…r rename

- Rename _fp4_cast to _cast_ste, supporting both FP4 and INT cast
- Fix NVFP4StaticQuantizer -> StaticBlockScaleQuantizer references in adaround
- Add NVFP4StaticAdaRoundQuantizer import/restore in conversion.py
- Fix test config to use fp8_scale_sweep for matching calibration

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eight support

- Add dist_loss_weight, beta_start, beta_end, freeze_weight to AdaRoundConfig
- Store hyperparams on NVFP4StaticAdaRoundQuantizer via from_nvfp4_quantizer
- Auto-detect adaround quantizers in QATTrainer, add annealed dist_loss to compute_loss
- Detach floor cast in _cast_ste when freeze_weight=True (only round_logits get grads)
- Add trainer tests: QATTrainer (with/without adaround), QADTrainer with adaround+KD

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
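The learned-rounding mechanism the commits above configure (trainable round_logits with a temperature) can be sketched as follows. This is a hedged illustration, not the modelopt API: the function name and signature are hypothetical, and real FP4 quantization uses a non-uniform grid.

```python
import math

# AdaRound-style learned rounding (illustrative): each weight element carries
# a trainable logit, and a sigmoid of that logit replaces the hard rounding
# offset, so training can learn whether to round up or down.
def adaround_value(w: float, scale: float, round_logit: float) -> float:
    soft_offset = 1.0 / (1.0 + math.exp(-round_logit))  # sigmoid in (0, 1)
    return (math.floor(w / scale) + soft_offset) * scale

# With a strongly positive logit the offset saturates toward 1 (round up);
# strongly negative saturates toward 0 (round down).
up = adaround_value(0.3, 0.5, 100.0)    # ≈ (0 + 1) * 0.5 = 0.5
down = adaround_value(0.3, 0.5, -100.0)  # ≈ (0 + 0) * 0.5 = 0.0
```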
… kernel

Convert learnable DTensor parameters (_per_block_scale, round_logits) to
local tensors inside the quantizer computation path, matching the existing
to_local() conversion for inputs. Also fix fp4_step_size triton kernel to
preserve input dtype instead of hardcoding float32.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…"loss"

Remove compute_loss override entirely. Instead, compute dist_loss after
super().training_step() and do a separate accelerator.backward() so
gradients accumulate naturally. The Trainer's auto-logged "loss" now
reflects base_loss (comparable with non-adaround jobs), while
adaround/dist_loss and adaround/beta appear as separate TensorBoard scalars.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…izer

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
…en param support

Move training-time knobs (beta_start, beta_end, dist_loss_weight, temperature)
out of AdaRoundConfig and NVFP4StaticAdaRoundQuantizer into a new
AdaRoundTrainingArguments dataclass on QATTrainer. Add trainable_params and
frozen_params (fnmatch patterns) to QuantizationArguments so QATTrainer can
configure requires_grad before optimizer creation. Remove redundant detach
logic from _cast_ste.

Made-with: Cursor
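The trainable_params/frozen_params fnmatch patterns described above can be sketched like this. The helper name and the parameter names are illustrative, not the actual QATTrainer implementation.

```python
from fnmatch import fnmatch

# Pattern-based parameter selection (illustrative): decide requires_grad per
# parameter from fnmatch patterns before the optimizer is created.
def match_trainable(param_names, trainable_patterns):
    return {
        name: any(fnmatch(name, pat) for pat in trainable_patterns)
        for name in param_names
    }

flags = match_trainable(
    ["model.lm_head.weight", "model.layers.0.weight_quantizer.round_logits"],
    ["*round_logits*"],
)
# Only the round_logits parameter matches; everything else would be frozen.
```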
…assthrough

Adaround-specific args (beta_start, beta_end, dist_loss_weight, temperature,
freeze_weights) are passed through to main.py unchanged, so they don't need
explicit parsing. Unknown args now collect into EXTRA_ARGS instead of erroring,
making the script extensible without modification.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
- Fix adaround metrics not appearing in logs: override QATTrainer.log()
  to merge metrics before the callback pipeline (ProgressCallback was
  printing before _AdaRoundAuxCallback could inject them via on_log)
- Fix _liger_loss_func passing unexpected arg to zero-arg _compute()
- Update test_adaround_trainer to verify adaround quantizers and logged
  metrics directly instead of referencing removed _adaround_aux_callback

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Introduce a post-optimizer-step callback that computes per-weight
MSE between quantized and original weights and applies independent
manual SGD steps to push weights toward quantized grid points.
The regularization coefficient is linearly annealed from
qerr_coeff_start to qerr_coeff_stop over training.

Qerr and AdaRound are mutually exclusive (ValueError if both set).

Made-with: Cursor
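The linear annealing from qerr_coeff_start to qerr_coeff_stop described above can be sketched as a pure function. Names are hypothetical; the real callback presumably reads these from its training arguments.

```python
# Linearly anneal the regularization coefficient over training (illustrative).
def qerr_coeff(step: int, total_steps: int, coeff_start: float, coeff_stop: float) -> float:
    frac = min(step / max(total_steps, 1), 1.0)  # clamp past the end of training
    return coeff_start + frac * (coeff_stop - coeff_start)
```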
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Made-with: Cursor
PyYAML safe_load parses bare exponential notation (e.g. 1e-5) as
strings instead of floats. This caused a TypeError in the LR scheduler
when base_lr (a string) was multiplied by the lambda output (a float).

Made-with: Cursor
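The quirk above is easy to reproduce: YAML 1.1's float resolver requires a '.' in the mantissa, so bare "1e-5" loads as a string while "1.0e-5" loads as a float. An explicit cast guards the scheduler against the string case.

```python
import yaml

# PyYAML safe_load: "1e-5" does not match the YAML 1.1 float pattern.
cfg = yaml.safe_load("lr: 1e-5\nlr_ok: 1.0e-5")
base_lr = float(cfg["lr"])  # explicit cast fixes the TypeError in the scheduler
```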
…s_coeff

- Qerr MSE is now always computed and reported (qerr/mse, qerr/coeff) every
  step. Default coefficients are 0 (monitor-only, no gradient applied).
  Set non-zero qerr_coeff_start/stop to enable active regularization.
- Rename dist_loss_weight -> dist_loss_coeff for consistency.
- Parse QuantErrorTrainingArguments in main.py and pass qerr_args to trainer.
- Extract _compute_mse helper in _QuantErrorAuxCallback.
- Replace mutual exclusion error with if/elif (adaround takes priority).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…t dims

Reshape weight to [num_blocks, block_size] before calling _cast_ste in
_compute_mse, matching what the quantizer forward pass does. Without this,
MLP layers with intermediate_size=6144 trigger Triton's "arange's range
must be a power of 2" error.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
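The reshape above can be illustrated without Triton: blockwise kernels iterate a [num_blocks, block_size] view where block_size is a power of two, so flattening first keeps the kernel's inner range a power of two even when a weight dimension (e.g. 6144) is not. This sketch uses plain lists, not the actual kernel.

```python
# Reshape a flat weight into [num_blocks, block_size] blocks (illustrative).
def to_blocks(flat, block_size=16):
    assert len(flat) % block_size == 0
    return [flat[i:i + block_size] for i in range(0, len(flat), block_size)]

blocks = to_blocks(list(range(6144)), block_size=16)  # 384 blocks of 16
```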
…heck

Add --qerr_reduction flag to QuantErrorTrainingArguments with 'sum' as
default. The previous .mean() reduction produced per-element gradients too
small to move quantization error. The .sum() reduction aggregates across
all weight elements for stronger gradients. Logged metric key now reflects
the reduction used (qerr/sum or qerr/mean). Also restores the
adaround_args/qerr_args mutual exclusivity ValueError.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
HfArgumentParser creates default instances of both dataclasses, so both
are always non-None. The real exclusivity is handled by the if/elif in
_setup_training where adaround takes priority. Remove the now-incorrect
ValueError and its corresponding test.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…izer

The squared error `(q_weight - weight) ** 2` was allowing gradients to
flow back through the quantizer's forward pass, corrupting weight
updates. Adding `.detach()` ensures qerr only produces gradients w.r.t.
the original weight, which is the correct STE-like behavior.

Also skip the adaround aux step when round_logits are frozen (e.g. in
the freeze-round paradigm) to avoid unnecessary computation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
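The effect of the .detach() fix above can be shown numerically without autograd. With the quantized target treated as a constant, the loss gradient with respect to the weight is -2 * (q - w), which pushes w toward the grid point; nothing flows back through the quantizer. The quantizer here is a uniform-grid stand-in, not the real FP4 path.

```python
# Gradient of (q - w)**2 with q detached (illustrative stand-in quantizer).
def quantize(w, step=0.5):
    return round(w / step) * step

def qerr_grad(w):
    q = quantize(w)        # treated as a constant, as .detach() ensures
    return -2.0 * (q - w)  # d/dw of (q - w)**2 with q held fixed

g = qerr_grad(0.3)  # q = 0.5, grad ≈ -0.4: an SGD step moves w up toward 0.5
```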
…tion

LSQ computes s*Q(x/s) without pre-dividing weights, allowing the quantization
grid to adapt as learned scales change during training. Refactors shared helpers
out of scale_after_dequant for reuse. Includes lint/format fixes from pre-commit.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
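The s*Q(x/s) forward above can be sketched with a uniform-grid simplification (real FP4 levels are non-uniform; q_max = 6.0 mirrors the E2M1 maximum). The point is that the grid spacing follows s, so the quantization grid adapts as s is learned.

```python
# Uniform-grid simplification of the LSQ forward (illustrative).
def lsq_forward(x: float, s: float, q_max: float = 6.0) -> float:
    q = max(-q_max, min(q_max, round(x / s)))  # Q(x / s): round, then clamp
    return s * q  # grid spacing follows the learned scale s

y = lsq_forward(1.3, 0.5)  # round(2.6) = 3 -> 0.5 * 3 = 1.5
```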
…ithms

Rename scale_after_dequant -> smooth_lsq to match paper terminology. Add two
new FP4 weight quantization algorithms: LAQ (learns a_max, derives s = a_max/Q_max,
no weight pre-division) and SmoothLAQ (learns a_max, weights pre-divided by a_max,
forward: Q_STE(w_a * Q_max) * s). Also cast LSQ/LAQ division to fp32 for precision.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- New _compute_laq_params() returns raw amaxes instead of roundtripping
  through per_block_scale (bad values default to Q_max so scale=1)
- enable_laq/enable_smooth_laq accept amax directly, bypass
  _enable_learnable_scales; rename _amax_param -> _amax_learnt
- SmoothLAQ no longer pre-divides weights; stores _amax_frozen buffer
  and divides by frozen scale in forward (optimizer updates original w)
- New _quantize_scale() helper shared across all 4 learnable algorithms
  for FP8 scale quantization (reused for both dequant and frozen scales)
- Unified _fake_quantize: compute scale -> quant input -> cast -> dequant
- amax property raises RuntimeError for SmoothLAQ (ambiguous)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Replace the deprecated _scale_after_dequant flag with _smooth_lsq and
generalize AdaRound to work with any learnable-scale init algorithm
(smooth_lsq, lsq, laq, smooth_laq) or plain calibration (max, mse,
local_hessian) which auto-converts via smooth_lsq.

- Rename _scale_after_dequant -> _smooth_lsq across quantizer, calib,
  conversion, config, and tests
- Replace smooth_lsq_args with init_algorithm in AdaRoundConfig
- Add _compute_weight_scaled() for per-mode weight scaling
- Add parametrized test_adaround_with_init_algorithms covering all 7 algos

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extract _amax_to_scale(amax, max_bound) that computes scale = amax/max_bound
with zero-guard, replacing duplicated scale computation across LAQ, SmoothLAQ,
LSQ, SmoothLSQ in _fake_quantize, _compute_block_scales, _compute_laq_params,
and _compute_weight_scaled.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
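The shared helper described above can be sketched as follows; the signature and the unit-scale fallback for degenerate blocks are assumptions from the commit description, not the actual implementation.

```python
# scale = amax / max_bound with a zero-guard (illustrative).
def amax_to_scale(amax: float, max_bound: float, eps: float = 1e-12) -> float:
    if amax <= eps:
        return 1.0  # assumption: all-zero blocks fall back to unit scale
    return amax / max_bound
```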
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma requested review from a team as code owners March 31, 2026 18:55
@realAsma realAsma requested review from ChenhanYu and mxinO March 31, 2026 18:55
@realAsma realAsma changed the title [QAD] New algorithms [Draft] New algorithms Mar 31, 2026
@coderabbitai
Contributor

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Walkthrough

This pull request introduces learnable-scale quantization algorithms (smooth_lsq, lsq, laq, smooth_laq, adaround), new training infrastructure for adaround and quantization-error regularization, a post-training quantization script, updated dataset handling for QAT examples, distillation loss refactoring, and corresponding test coverage.

Changes

Cohort / File(s) Summary
QAT Examples & Documentation
examples/llm_qat/README.md, examples/llm_qat/create_ptq.py, examples/llm_qat/launch.sh, examples/llm_qat/main.py, examples/llm_qat/utils.py
New PTQ script with FP4 quantization configs and algorithm registry. Extended main.py with attention implementation and cache path options. Refactored dataset handling via _load_cached_dataset, distributed sharding, on-disk caching, and support for JSONL inputs. Updated launch.sh argument parsing for dataset and cache paths. Added README documentation for end-to-end QAT with local datasets.
Quantization Algorithms & Configuration
modelopt/torch/quantization/config.py, modelopt/torch/quantization/mode.py, modelopt/torch/quantization/model_calib.py
New Pydantic algorithm config classes: SmoothLSQConfig, LSQConfig, LAQConfig, SmoothLAQConfig, AdaRoundConfig. Added corresponding mode descriptors to CalibrateModeRegistry. Introduced public calibration functions smooth_lsq, lsq, laq, smooth_laq, adaround with scale calibration and learnable initialization support. Generalized "static block" handling and added AdaRound conversion helpers.
Quantizer Classes
modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Replaced NVFP4StaticQuantizer with new StaticBlockScaleQuantizer supporting FP4/INT block quantization. Added learnable-scale modes (enable_smooth_lsq, enable_lsq, enable_laq, enable_smooth_laq) with trainable per-block scales and amax. Introduced new NVFP4StaticAdaRoundQuantizer with AdaRound-specific rounding, dist_loss, and STE casting.
Triton Kernels
modelopt/torch/quantization/triton/fp4_kernel.py, modelopt/torch/quantization/triton/fp4_kernel_hopper.py
Added new static_blockwise_fp4_cast and static_blockwise_fp4_step_size kernels with rounding support. Updated fp4_fake_quant_kernel masking to handle zero, NaN, and infinity checks matching CUDA behavior.
Quantization Restoration & Model Logic
modelopt/torch/quantization/conversion.py, modelopt/torch/quantization/model_quant.py
Updated quantizer restoration to use StaticBlockScaleQuantizer and handle NVFP4StaticAdaRoundQuantizer conversion. Modified quantize() to conditionally apply mode replacement only when model is not already quantized, using set_quantizer_by_cfg for quantized models.
Training Infrastructure
modelopt/torch/opt/plugins/transformers.py, modelopt/torch/quantization/plugins/transformers_trainer.py
Added ModelOptTrainerArguments dataclass with trainable/frozen parameter patterns and per-parameter LR config support. Updated ModelOptHFTrainer to apply parameter freezing, LR configuration via YAML, and optional Liger fused-loss integration. Introduced AdaRoundTrainingArguments and QuantErrorTrainingArguments dataclasses. Extended QATTrainer with adaround_args and qerr_args, callbacks for AdaRound dist_loss and quantization-error regularization. Reordered QADTrainer base classes.
Distillation Loss
modelopt/torch/distill/plugins/huggingface.py
Refactored KDTrainer to install loss functions during initialization. Added support for custom KD loss (compute_kd_loss) with masking or Liger fused JSD (_liger_loss_func) depending on configuration. Added helpers _get_lm_head, _get_teacher_lm_head, _setup_liger_fused_loss. Modified LMLogitsLoss to return per-token unreduced KL-div losses. Removed previous compute_loss and train overrides.
Tensor Quantization Utilities
modelopt/torch/quantization/tensor_quant.py
Added straight-through estimator (STE) functions: FP4CastSTEFunction with FP4 casting and gradient clamping, IntCastSTEFunction with integer quantization bounds, fp4_step_size helper for FP4 level gaps.
Minor Cleanup
modelopt/torch/opt/plugins/huggingface.py
Removed blank line in patched __init__ wrapper.
Test Coverage
tests/gpu/torch/quantization/test_adaround_trainer.py, tests/gpu/torch/quantization/test_liger_loss.py, tests/gpu/torch/quantization/test_qerr_trainer.py, tests/gpu/torch/quantization/test_quantize_cuda.py, tests/unit/torch/opt/plugins/test_lr_config.py
Added AdaRound dist_loss integration tests with QATTrainer and QADTrainer. Added Liger fused-loss validation tests comparing fused vs non-fused training. Added quantization-error regularization tests. Extended test_quantize_cuda.py with smooth_lsq, lsq, laq, smooth_laq, and adaround gradient/behavior tests. Added unit tests for LR config loading, pattern matching, and optimizer construction.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant QAT as QAT Trainer
    participant Model as Model
    participant Quantizer as StaticBlockScale<br/>Quantizer
    participant AdaroundCB as AdaRound<br/>Callback

    User->>QAT: Initialize with adaround_args
    QAT->>Model: Check for NVFP4StaticAdaRoundQuantizer
    Model->>Quantizer: Enable AdaRound mode
    Quantizer->>Quantizer: Create trainable round_logits
    
    loop Training Step
        QAT->>Model: Forward pass
        Model->>Quantizer: Fake quantize with learned rounding
        Quantizer-->>Model: Rounded logits
        QAT->>AdaroundCB: Compute dist_loss from round_logits
        AdaroundCB->>Quantizer: Backprop to round_logits
        AdaroundCB->>AdaroundCB: Apply custom update with beta annealing
        AdaroundCB-->>QAT: Log adaround/dist_loss, beta
    end
    
    User->>QAT: Save model
    QAT->>Model: Export quantized weights with learned rounding
sequenceDiagram
    participant Client as Caller
    participant QAT as QATTrainer
    participant Loss as LMLogitsLoss
    participant Fused as LigerFusedLinearJSD<br/>(optional)

    Client->>QAT: Initialize with distill_config and use_liger_kernel
    QAT->>QAT: Convert model to distillation form
    
    alt use_liger_kernel enabled
        QAT->>QAT: Check lm_head compatibility
        QAT->>Fused: Create fused JSD loss
        QAT->>QAT: Set compute_loss_func to fused path
    else use_liger_kernel disabled
        QAT->>Loss: Create compute_kd_loss with masking
        QAT->>QAT: Set compute_loss_func to KD path
    end
    
    loop Training Step
        Client->>QAT: Run training step
        QAT->>QAT: Forward student and teacher
        QAT->>Loss: Compute per-token KL-div or JSD
        Loss-->>QAT: Unreduced per-token losses (B*S,)
        QAT->>QAT: Reduce and backprop
    end

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 71.68%, below the required 80.00% threshold. Resolution: write docstrings for the functions that are missing them.
  • Title check — ❓ Inconclusive: the title '[QAD] New algorithms' refers to real aspects of the changeset (introducing LSQ, LAQ, SmoothLAQ, AdaRound, and Qerr algorithms) but is overly broad and generic; it does not clearly communicate the primary change or scope. Consider a more specific title such as 'Add LSQ, LAQ, SmoothLAQ, and AdaRound quantization algorithms with training support'.

✅ Passed checks (1 passed)

  • Description Check — ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


@realAsma realAsma changed the title [Draft] New algorithms [Draft] QAD algorithms - SmoothLAQ, AdaRound Mar 31, 2026

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 12

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_qat/create_ptq.py`:
- Around line 137-145: The CLI defaults for dataset arguments are non-portable
and can cause stale cache reuse; update the p.add_argument call for "--dataset"
to remove the hardcoded default and instead make the arg required (e.g.,
p.add_argument("--dataset", type=str, required=True, help=...)), and change the
"--dataset_cache_path" default from "dataset_cache" to an empty string (or None)
so cache reuse is explicit (e.g., p.add_argument("--dataset_cache_path",
type=str, default="", help=...)); adjust any downstream code that assumes a
non-empty cache path to handle empty/None appropriately before returning
p.parse_args().

In `@examples/llm_qat/launch.sh`:
- Around line 53-55: The current launch.sh silently collects any unrecognized
token into EXTRA_ARGS which gets forwarded to main.py (which uses
HfArgumentParser.parse_args_into_dataclasses() and will crash on unknown args);
update the script to stop appending arbitrary arguments in the default *) branch
of the case: either remove the EXTRA_ARGS mechanism entirely, or
validate/whitelist tokens before appending (e.g., check $1 against a set of
allowed flags) and otherwise emit a clear error and exit; reference EXTRA_ARGS,
the default *) case in launch.sh, and main.py /
HfArgumentParser.parse_args_into_dataclasses() when making the change.

In `@examples/llm_qat/utils.py`:
- Around line 254-255: The code incorrectly treats pad_token_id == 0 as missing
by using "pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id"; change
the logic to detect only None so a zero ID is preserved (e.g., check
tokenizer.pad_token_id is None or use a ternary that uses tokenizer.eos_token_id
only when pad_token_id is None). Update the assignment around pad_token and the
subsequent None-check to use explicit None checks referencing
tokenizer.pad_token_id and tokenizer.eos_token_id so valid 0 pad IDs are not
overridden.
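The falsy-zero pitfall flagged above is easy to reproduce; the token ids here are illustrative, not taken from a specific tokenizer.

```python
# `or` treats a valid pad id of 0 as missing; an explicit None check does not.
pad_token_id = 0   # a perfectly valid pad token id
eos_token_id = 2

buggy = pad_token_id or eos_token_id  # -> 2: pad id 0 is silently overridden
fixed = pad_token_id if pad_token_id is not None else eos_token_id  # -> 0
```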
- Around line 96-104: The cache key currently omits tokenizer identity causing
tokenization mismatches; modify _build_cache_path to accept a tokenizer
identifier (e.g., tokenizer.name_or_path) and include it in the hashed string,
and then thread that identifier through the caller _load_cached_dataset so it
passes tokenizer.name_or_path into _build_cache_path when constructing the cache
path; ensure the identifier is stable (name_or_path) and included in the same
f"{...}" input used to compute cache_key.

In `@modelopt/torch/distill/plugins/huggingface.py`:
- Around line 56-73: Ensure the fused LIGER JSD path is only enabled when there
is exactly one distillation loss of the expected LMLogitsLoss type: in
_setup_liger_fused_loss check that model._layers_to_loss contains a single entry
and that that loss object is the LMLogitsLoss (or has the specific attributes
_temperature and logits inputs) before setting use_liger_kernel=True and
compute_loss_func=_liger_loss_func; otherwise set use_liger_kernel=False and
compute_loss_func=compute_kd_loss. Apply the same exact precondition check in
the other setup block mentioned (the similar code around lines 91-124) so
_liger_loss_func is never selected when multiple or different loss types are
configured.

In `@modelopt/torch/opt/plugins/transformers.py`:
- Around line 417-431: The patched forward assignment in _fsdp_forward_redirect
must be undone even on exceptions: wrap the invocation of fsdp_module in a
try/finally so fsdp_module.forward is always reset to original_forward; perform
the assignment fsdp_module.forward = wrapped_forward, call fsdp_module(...)
inside try, and in finally restore fsdp_module.forward = original_forward. Also
replace the string sentinel argument ("_fsdp_redirect") with a dummy tensor
created on the module's device/dtype (e.g., a torch.zeros tensor derived from a
model parameter or buffer) so forward pre-hooks receive a properly-typed input;
keep the inner wrapped_forward returning fn() unchanged.

In `@modelopt/torch/quantization/config.py`:
- Around line 268-279: Add the new INT3_BLOCKWISE_WEIGHT_ONLY_CFG preset to the
string-based selector by adding an entry for "INT3_BLOCKWISE_WEIGHT_ONLY_CFG" in
the choices mapping (the same mapping that currently lists other presets) so
that the config can be selected by name; update the choices dict/switch where
presets are registered (reference symbol: choices) to include a key pointing to
INT3_BLOCKWISE_WEIGHT_ONLY_CFG so examples that validate against
modelopt.torch.quantization.config.choices can select it.

In `@modelopt/torch/quantization/conversion.py`:
- Around line 132-140: The current sequential checks can call
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module) without ensuring
module is first a StaticBlockScaleQuantizer; update the restoration logic in
conversion.py so that when state["_is_nvfp4_static_adaround_quantizer"] is true
you first ensure or convert the module to a StaticBlockScaleQuantizer (i.e., run
StaticBlockScaleQuantizer.from_tensor_quantizer(module) if needed) or explicitly
validate that state["_is_nvfp4_static_quantizer"] is also true and raise/handle
a clear error; specifically touch the block that references
StaticBlockScaleQuantizer, NVFP4StaticAdaRoundQuantizer, from_nvfp4_quantizer,
and set_from_modelopt_state to enforce the invariant or perform the conversion
before calling from_nvfp4_quantizer.

In `@modelopt/torch/quantization/nn/modules/tensor_quantizer.py`:
- Around line 1555-1581: The conversion method from_nvfp4_quantizer should
reject non-FP4/ non-NVFP4 block quantizers instead of blindly mutating any
StaticBlockScaleQuantizer; add a guard near the start of
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer that inspects the source
quantizer's format/block-type (e.g., a property like block_format / is_fp4 /
nvfp4 flag on tq) and raise/assert if it is not an FP4/NVFP4 config, and do the
same check in the other related conversion methods mentioned (the other
from_nvfp4_quantizer conversion blocks at the ranges you noted) so AdaRound
remains FP4-only unless you implement a separate INT-specific AdaRound path.

In `@modelopt/torch/quantization/plugins/transformers_trainer.py`:
- Around line 308-312: The trainer's _setup_adaround() registers the AdaRound
callback but never propagates AdaRoundTrainingArguments.temperature to the
quantizers, so non-default temperatures are ignored; update _setup_adaround() to
iterate the model's quantizers (e.g., instances of NVFP4StaticAdaRoundQuantizer)
and set their temperature property from self.args.ada_round_args.temperature (or
self.ada_rounding_args / AdaRoundTrainingArguments.temperature where stored)
before adding the _AdaRoundAuxCallback, ensuring each
NVFP4StaticAdaRoundQuantizer.temperature (or equivalent attribute) is assigned
the trainer's temperature value so AdaRound uses the configured temperature.
- Around line 676-695: The branch in _compute_mse incorrectly uses
quantizer._cast_ste for any StaticBlockScaleQuantizer that has no learnable
modes, which incorrectly drops block/tensor scales for calibrated
(non-pre-divided) weights; change the condition so _cast_ste is used only when
the quantizer is a pre-divided/smooth_lsq case (i.e., when quantizer._smooth_lsq
is True), and for other StaticBlockScaleQuantizer instances call
quantizer(weight) instead; preserve the existing reshape logic using
quantizer._block_reshape_size and then reshape back to orig_shape before
computing sq_err.
- Around line 650-672: The current registration loop (_weight_entries) includes
quantized weights even when they are frozen or not in the optimizer, causing
qerr to backprop through tensors with no grads; update the loop that builds
self._weight_entries (the block using weight_attr_names, quantizer_attr_names,
weight_quantizer and is_enabled) to only append weights that are
torch.nn.Parameter AND weight.requires_grad is True AND the weight's id is
present in the optimizer param mapping (pid_to_group); similarly ensure the
later population of self._param_group_idx and self._multiplier only iterates
over these filtered _weight_entries (so id(weight) exists in pid_to_group) to
avoid creating entries for optimizer-less/frozen weights — apply the same guard
to the analogous registration block elsewhere that performs the same work.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b794a3ad-578d-4f61-90b1-5c3c9601564c

📥 Commits

Reviewing files that changed from the base of the PR and between 74a8694 and 67e69ea.

📒 Files selected for processing (23)
  • examples/llm_qat/README.md
  • examples/llm_qat/create_ptq.py
  • examples/llm_qat/launch.sh
  • examples/llm_qat/main.py
  • examples/llm_qat/utils.py
  • modelopt/torch/distill/plugins/huggingface.py
  • modelopt/torch/opt/plugins/huggingface.py
  • modelopt/torch/opt/plugins/transformers.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/conversion.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/model_quant.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/plugins/transformers_trainer.py
  • modelopt/torch/quantization/tensor_quant.py
  • modelopt/torch/quantization/triton/fp4_kernel.py
  • modelopt/torch/quantization/triton/fp4_kernel_hopper.py
  • tests/gpu/torch/quantization/test_adaround_trainer.py
  • tests/gpu/torch/quantization/test_liger_loss.py
  • tests/gpu/torch/quantization/test_qerr_trainer.py
  • tests/gpu/torch/quantization/test_quantize_cuda.py
  • tests/unit/torch/opt/plugins/test_lr_config.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/opt/plugins/huggingface.py

Comment on lines +137 to +145
    p.add_argument(
        "--dataset",
        type=str,
        default="/home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl",
    )
    p.add_argument("--eval_size", type=int, default=0)
    p.add_argument("--train_size", type=int, default=0)
    p.add_argument("--dataset_cache_path", type=str, default="dataset_cache")
    return p.parse_args()

⚠️ Potential issue | 🟠 Major

Use portable, non-stale defaults for the dataset arguments.

The current --dataset default only exists on one machine, and the fixed "dataset_cache" default will silently reuse tokenized data from previous runs with different inputs. Making --dataset required and leaving --dataset_cache_path empty avoids both footguns.

Suggested fix
     p.add_argument(
         "--dataset",
         type=str,
-        default="/home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl",
+        required=True,
     )
@@
-    p.add_argument("--dataset_cache_path", type=str, default="dataset_cache")
+    p.add_argument("--dataset_cache_path", type=str, default="")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_qat/create_ptq.py` around lines 137 - 145, The CLI defaults for
dataset arguments are non-portable and can cause stale cache reuse; update the
p.add_argument call for "--dataset" to remove the hardcoded default and instead
make the arg required (e.g., p.add_argument("--dataset", type=str,
required=True, help=...)), and change the "--dataset_cache_path" default from
"dataset_cache" to an empty string (or None) so cache reuse is explicit (e.g.,
p.add_argument("--dataset_cache_path", type=str, default="", help=...)); adjust
any downstream code that assumes a non-empty cache path to handle empty/None
appropriately before returning p.parse_args().
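A runnable sketch of the suggested CLI shape (argument names taken from the snippet above; the help strings are illustrative): `--dataset` becomes required and `--dataset_cache_path` defaults to empty so cache reuse is always an explicit choice.

```python
import argparse

def parse_args(argv=None):
    """Sketch of the suggested fix: no machine-specific dataset default,
    no fixed cache directory that silently survives across runs."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", type=str, required=True,
                   help="Path to a local JSONL dataset.")
    p.add_argument("--eval_size", type=int, default=0)
    p.add_argument("--train_size", type=int, default=0)
    p.add_argument("--dataset_cache_path", type=str, default="",
                   help="Empty: derive a per-input cache path instead of reusing a fixed one.")
    return p.parse_args(argv)

# Forgetting --dataset now fails fast instead of pointing at a nonexistent path.
args = parse_args(["--dataset", "data/blend_sft.jsonl"])
```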

Comment on lines 53 to 55
  *)
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
-    exit 1
+    EXTRA_ARGS="${EXTRA_ARGS:-} $1"
     ;;

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify main.py uses parse_args_into_dataclasses without remainder handling
rg -n "parse_args_into_dataclasses|parse_known_args" examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 114


🏁 Script executed:

# Get context around line 189 in main.py to understand HfArgumentParser usage
head -200 examples/llm_qat/main.py | tail -30

Repository: NVIDIA/Model-Optimizer

Length of output: 996


🏁 Script executed:

# Find where EXTRA_ARGS is used in launch.sh command invocation
rg -n "EXTRA_ARGS" examples/llm_qat/launch.sh -A 2 -B 2

Repository: NVIDIA/Model-Optimizer

Length of output: 429


🏁 Script executed:

# Check the imports and dataclass setup in main.py
head -50 examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2113


🏁 Script executed:

# Check if there's any error handling around parse_args_into_dataclasses
grep -A 10 -B 5 "parse_args_into_dataclasses" examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 698


🏁 Script executed:

# Search for any tests or usage of EXTRA_ARGS in launch.sh to see if it's actually expected to work
grep -n "EXTRA_ARGS\|unknown\|unrecognized" examples/llm_qat/launch.sh

Repository: NVIDIA/Model-Optimizer

Length of output: 207


🏁 Script executed:

# Check if there's any documentation or comments about handling extra arguments
grep -n "arg\|option" examples/llm_qat/launch.sh | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 119


🏁 Script executed:

# Check the dataclass definitions to see what arguments are expected
grep -n "class.*Arguments.*:" examples/llm_qat/main.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 208


🏁 Script executed:

# Look for field definitions in the dataclasses to understand what arguments are supported
grep -A 2 "field(" examples/llm_qat/main.py | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 1233


🏁 Script executed:

# Check if there's any documentation or comments about EXTRA_ARGS in launch.sh
head -100 examples/llm_qat/launch.sh | tail -50

Repository: NVIDIA/Model-Optimizer

Length of output: 1606


🏁 Script executed:

# Check if dataclasses have any special field configuration for handling extra args
grep -n "allow_extra_keys\|extra_keys\|allow_extra" examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for any sys.argv manipulation or argument filtering before parse_args_into_dataclasses
grep -B 20 "parse_args_into_dataclasses" examples/llm_qat/main.py | grep -n "sys\|argv\|args"

Repository: NVIDIA/Model-Optimizer

Length of output: 213


🏁 Script executed:

# Check the actual command being executed to see the full context
sed -n '170,190p' examples/llm_qat/launch.sh

Repository: NVIDIA/Model-Optimizer

Length of output: 570


🏁 Script executed:

# Check if there are any example usage or tests showing how EXTRA_ARGS is meant to be used
find examples/llm_qat -type f \( -name "*.md" -o -name "README*" -o -name "*.txt" \) | xargs grep -l "EXTRA_ARGS" 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 50


🏁 Script executed:

# Look for any documentation about the script usage
head -50 examples/llm_qat/launch.sh

Repository: NVIDIA/Model-Optimizer

Length of output: 3424


🏁 Script executed:

# Verify the exact HfArgumentParser import and check if there's any custom handling
grep -A 5 "HfArgumentParser" examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 240


🏁 Script executed:

# Check the transformers library version to understand HfArgumentParser behavior
grep -r "transformers" examples/llm_qat/ | grep -E "requirements|setup|version" | head -10

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Search for any tests or examples in the repo that show EXTRA_ARGS being used
find . -type f -name "*.sh" -o -name "*.md" | xargs grep -l "EXTRA_ARGS" 2>/dev/null | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 94


🏁 Script executed:

# Check if there's any try-catch or error handling in the train function
sed -n '185,210p' examples/llm_qat/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1243


Unrecognized arguments forwarded to main.py will cause a runtime crash.

The EXTRA_ARGS mechanism (line 54) silently collects unrecognized flags and appends them to the Python command (line 181), but main.py uses HfArgumentParser.parse_args_into_dataclasses() which calls argparse.parse_args() internally—not parse_known_args(). Any unrecognized argument will raise an error at runtime.

Either:

  1. Switch main.py to use parse_known_args() and handle the remainder, or
  2. Validate that EXTRA_ARGS only contains arguments recognized by main.py, or
  3. Remove the EXTRA_ARGS mechanism if it's not actively needed
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_qat/launch.sh` around lines 53 - 55, The current launch.sh
silently collects any unrecognized token into EXTRA_ARGS which gets forwarded to
main.py (which uses HfArgumentParser.parse_args_into_dataclasses() and will
crash on unknown args); update the script to stop appending arbitrary arguments
in the default *) branch of the case: either remove the EXTRA_ARGS mechanism
entirely, or validate/whitelist tokens before appending (e.g., check $1 against
a set of allowed flags) and otherwise emit a clear error and exit; reference
EXTRA_ARGS, the default *) case in launch.sh, and main.py /
HfArgumentParser.parse_args_into_dataclasses() when making the change.
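Option 2 above can be sketched as a small whitelist check in the `*)` branch; the flag list here is hypothetical and would need to mirror what `main.py` actually accepts.

```shell
#!/usr/bin/env bash
# Sketch: only forward known flags via EXTRA_ARGS, so a typo fails in
# launch.sh instead of crashing later inside HfArgumentParser.
ALLOWED_EXTRA_FLAGS="--seed --logging_steps --save_steps"  # assumed list

is_allowed_flag() {
    local flag="${1%%=*}"   # strip any =value suffix
    case " $ALLOWED_EXTRA_FLAGS " in
        *" $flag "*) return 0 ;;
        *) return 1 ;;
    esac
}

EXTRA_ARGS=""
for tok in --seed=42 --ouput_dir=/tmp; do   # note the deliberate typo
    if is_allowed_flag "$tok"; then
        EXTRA_ARGS="$EXTRA_ARGS $tok"
    else
        echo "rejected: $tok"
    fi
done
echo "EXTRA_ARGS:$EXTRA_ARGS"
```

In the real script the rejected branch would `exit 1` rather than just echo.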

Comment on lines +96 to +104
def _build_cache_path(
dataset: str, dataset_cache_path: str, max_length: int, train_size: int, eval_size: int
) -> str:
if dataset_cache_path:
return dataset_cache_path
cache_key = hashlib.sha1(
f"{dataset}|{max_length}|{train_size}|{eval_size}".encode()
).hexdigest()[:12]
return os.path.join(tempfile.gettempdir(), f"llm_qat_tokenized_{cache_key}")

⚠️ Potential issue | 🟠 Major

Include tokenizer identity in the derived cache path.

The hash only uses dataset/length/split sizes. A second run with a different tokenizer but the same dataset will silently reload the first run's cached token IDs and labels.

Suggested fix
 def _build_cache_path(
-    dataset: str, dataset_cache_path: str, max_length: int, train_size: int, eval_size: int
+    dataset: str,
+    dataset_cache_path: str,
+    tokenizer_id: str,
+    max_length: int,
+    train_size: int,
+    eval_size: int,
 ) -> str:
     if dataset_cache_path:
         return dataset_cache_path
     cache_key = hashlib.sha1(
-        f"{dataset}|{max_length}|{train_size}|{eval_size}".encode()
+        f"{dataset}|{tokenizer_id}|{max_length}|{train_size}|{eval_size}".encode()
     ).hexdigest()[:12]

Then thread a stable tokenizer identifier through _load_cached_dataset(), e.g. tokenizer.name_or_path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_qat/utils.py` around lines 96 - 104, The cache key currently
omits tokenizer identity causing tokenization mismatches; modify
_build_cache_path to accept a tokenizer identifier (e.g.,
tokenizer.name_or_path) and include it in the hashed string, and then thread
that identifier through the caller _load_cached_dataset so it passes
tokenizer.name_or_path into _build_cache_path when constructing the cache path;
ensure the identifier is stable (name_or_path) and included in the same f"{...}"
input used to compute cache_key.

Comment on lines +254 to +255
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
if pad_token is None:

⚠️ Potential issue | 🟡 Minor

Don't treat pad_token_id == 0 as missing.

0 is a valid token ID. Using or here falls through to eos_token_id, so some tokenizers will pad with the wrong token.

Suggested fix
-        pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
+        pad_token = (
+            tokenizer.pad_token_id
+            if tokenizer.pad_token_id is not None
+            else tokenizer.eos_token_id
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_qat/utils.py` around lines 254 - 255, The code incorrectly
treats pad_token_id == 0 as missing by using "pad_token = tokenizer.pad_token_id
or tokenizer.eos_token_id"; change the logic to detect only None so a zero ID is
preserved (e.g., check tokenizer.pad_token_id is None or use a ternary that uses
tokenizer.eos_token_id only when pad_token_id is None). Update the assignment
around pad_token and the subsequent None-check to use explicit None checks
referencing tokenizer.pad_token_id and tokenizer.eos_token_id so valid 0 pad IDs
are not overridden.
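The fix is easy to exercise with a stub tokenizer whose pad ID is legitimately `0` (the class below is a stand-in, not a real `transformers` tokenizer):

```python
class FakeTokenizer:
    # Hypothetical tokenizer where pad_token_id == 0 is a valid token.
    pad_token_id = 0
    eos_token_id = 2

def resolve_pad_token(tokenizer):
    """Explicit None check so a zero pad_token_id is preserved instead of
    silently falling through to eos_token_id."""
    pad_token = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    if pad_token is None:
        raise ValueError("Tokenizer has neither pad_token_id nor eos_token_id")
    return pad_token

print(resolve_pad_token(FakeTokenizer()))  # 0, not 2
```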

Comment on lines +56 to +73
def _setup_liger_fused_loss(self):
"""Set up fused JSD for KD.

Args:
model: The model to compute loss for.
inputs: The inputs to the model.
No-op when called from ModelOptHFTrainer.__init__ (teacher not yet created).
Re-called from KDTrainer.__init__ after _convert_to_distillation_model().
"""
if not model.training:
_compute_loss_func = self.compute_loss_func
self.compute_loss_func = None
model = self.accelerator.unwrap_model(self.model)
if not hasattr(model, "_teacher_model"):
return
teacher = model._teacher_model
if not hasattr(model, "lm_head") or not hasattr(teacher, "lm_head"):
self.use_liger_kernel = False
self.compute_loss_func = self.compute_kd_loss
return

loss_fn = next(iter(model._layers_to_loss.values()))
self._liger_temperature = getattr(loss_fn, "_temperature", 1.0)
self.compute_loss_func = self._liger_loss_func

⚠️ Potential issue | 🟠 Major

Only enable fused KD for the single-LMLogitsLoss case.

This code reads just the first entry from model._layers_to_loss, and _liger_loss_func() always computes one fused JSD term. If a caller configures multiple distillation losses, or a different criterion type, use_liger_kernel=True silently changes the objective instead of reproducing compute_kd_loss(). Please validate that precondition here and fall back to compute_kd_loss() otherwise.

Also applies to: 91-124

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/distill/plugins/huggingface.py` around lines 56 - 73, Ensure
the fused LIGER JSD path is only enabled when there is exactly one distillation
loss of the expected LMLogitsLoss type: in _setup_liger_fused_loss check that
model._layers_to_loss contains a single entry and that that loss object is the
LMLogitsLoss (or has the specific attributes _temperature and logits inputs)
before setting use_liger_kernel=True and compute_loss_func=_liger_loss_func;
otherwise set use_liger_kernel=False and compute_loss_func=compute_kd_loss.
Apply the same exact precondition check in the other setup block mentioned (the
similar code around lines 91-124) so _liger_loss_func is never selected when
multiple or different loss types are configured.
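The precondition can be factored into a small predicate; the classes below are stubs for the real criterion types, used only to show the shape of the check:

```python
class LMLogitsLoss:
    """Stub for the logits-level distillation criterion the fused JSD implements."""
    _temperature = 2.0

class OtherLoss:
    """Stub for any other distillation criterion."""

def can_use_fused_jsd(layers_to_loss):
    """The fused path only reproduces compute_kd_loss() when there is exactly
    one configured loss and it is the logits-level criterion; everything else
    should fall back to compute_kd_loss()."""
    if len(layers_to_loss) != 1:
        return False
    (loss_fn,) = layers_to_loss.values()
    return isinstance(loss_fn, LMLogitsLoss)

print(can_use_fused_jsd({"lm_head": LMLogitsLoss()}))              # True
print(can_use_fused_jsd({"a": LMLogitsLoss(), "b": OtherLoss()}))  # False
```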

Comment on lines 132 to 140
-        if state.get("_is_nvfp4_static_quantizer") and not isinstance(
-            module, NVFP4StaticQuantizer
-        ):
-            NVFP4StaticQuantizer.from_tensor_quantizer(module)
+        if state.get("_is_nvfp4_static_quantizer") and not isinstance(
+            module, StaticBlockScaleQuantizer
+        ):
+            StaticBlockScaleQuantizer.from_tensor_quantizer(module)
+        if state.get("_is_nvfp4_static_adaround_quantizer") and not isinstance(
+            module, NVFP4StaticAdaRoundQuantizer
+        ):
+            NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module)
         module.set_from_modelopt_state(quantizer_state_dict[name])

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if _is_nvfp4_static_quantizer is set when creating AdaRound quantizer
ast-grep --pattern 'from_nvfp4_quantizer($$$) {
  $$$
}'

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

#!/bin/bash
# Search for where the flags are set in the codebase
echo "=== Searching for _is_nvfp4_static_quantizer flag assignments ==="
rg "_is_nvfp4_static_quantizer" -n --context 3

echo ""
echo "=== Searching for _is_nvfp4_static_adaround_quantizer flag assignments ==="
rg "_is_nvfp4_static_adaround_quantizer" -n --context 3

Repository: NVIDIA/Model-Optimizer

Length of output: 2995


🏁 Script executed:

#!/bin/bash
# Find the implementation of from_nvfp4_quantizer and from_tensor_quantizer
echo "=== Searching for from_nvfp4_quantizer method ==="
rg "def from_nvfp4_quantizer" -A 10 -B 2

echo ""
echo "=== Searching for from_tensor_quantizer method ==="
rg "def from_tensor_quantizer" -A 10 -B 2

Repository: NVIDIA/Model-Optimizer

Length of output: 2674


🏁 Script executed:

#!/bin/bash
# Check class hierarchy and understand the relationship
echo "=== Searching for class definitions ==="
rg "class StaticBlockScaleQuantizer|class NVFP4StaticAdaRoundQuantizer" -B 2 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 1617


🏁 Script executed:

#!/bin/bash
# Find where state is saved/created for these quantizers
echo "=== Looking for state_dict or similar save patterns ==="
rg "state\[" modelopt/torch/quantization/nn/modules/tensor_quantizer.py -B 2 -A 2

echo ""
echo "=== Check the full from_nvfp4_quantizer implementation ==="
rg "def from_nvfp4_quantizer" -A 30 modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1973


🏁 Script executed:

#!/bin/bash
# Find how states are serialized/created
echo "=== Search for places where flags are used in state tracking ==="
rg "get_state|state_dict" modelopt/torch/quantization/nn/modules/tensor_quantizer.py -B 2 -A 2

echo ""
echo "=== Check how quantizers are initialized from saved state ==="
rg "set_from_modelopt_state" -B 5 -A 10 modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1117


🏁 Script executed:

#!/bin/bash
# Understand the full restoration flow in conversion.py
echo "=== Full restore_quantizer_state function context ==="
rg "def restore_quantizer_state" -A 20 modelopt/torch/quantization/conversion.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1345


🏁 Script executed:

#!/bin/bash
# Check what properties are saved and if there's any validation
echo "=== Check _get_properties_for_modelopt_state ==="
rg "_get_properties_for_modelopt_state" -A 15 modelopt/torch/quantization/nn/modules/tensor_quantizer.py

echo ""
echo "=== Check if flags are explicitly tracked in saved properties ==="
rg "_is_nvfp4" modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1517


🏁 Script executed:

#!/bin/bash
# Check if there's any scenario where flags could be set independently
echo "=== Search for all places where _is_nvfp4_static_quantizer is set ==="
rg "_is_nvfp4_static_quantizer\s*=" modelopt/torch/quantization/

echo ""
echo "=== Search for all places where _is_nvfp4_static_adaround_quantizer is set ==="
rg "_is_nvfp4_static_adaround_quantizer\s*=" modelopt/torch/quantization/

echo ""
echo "=== Check if there's any initialization that sets one without the other ==="
rg "class NVFP4StaticAdaRoundQuantizer" -A 30 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | grep -A 30 "__init__"

Repository: NVIDIA/Model-Optimizer

Length of output: 995


🏁 Script executed:

#!/bin/bash
# Check if there's validation or tests that verify state consistency
echo "=== Search for tests of restore_quantizer_state ==="
fd "test.*conversion" modelopt/torch/quantization/tests/ --type f

echo ""
echo "=== Check test files ==="
rg "restore_quantizer_state" modelopt/torch/quantization/tests/ -B 2 -A 5

Repository: NVIDIA/Model-Optimizer

Length of output: 357


🏁 Script executed:

#!/bin/bash
# Check if there's any documentation about state consistency
echo "=== Check docstrings and comments about state flags ==="
rg "_is_nvfp4" -B 5 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | grep -A 5 -B 5 "is_nvfp4"

echo ""
echo "=== Check if there's validation during state restoration ==="
rg "set_from_modelopt_state" -A 20 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 1752


🏁 Script executed:

#!/bin/bash
# Verify the actual restore flow more carefully
echo "=== Full context around restore_quantizer_state for loop ==="
rg "def restore_quantizer_state" -A 50 modelopt/torch/quantization/conversion.py | head -60

echo ""
echo "=== Check if flags are initialized on new instances ==="
rg "class StaticBlockScaleQuantizer" -A 15 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | grep -A 15 "__init__"

Repository: NVIDIA/Model-Optimizer

Length of output: 2784


🏁 Script executed:

#!/bin/bash
# Let me trace through the exact order of operations in restore
echo "=== Verify the order: conversion happens BEFORE set_from_modelopt_state ==="
rg "for name, module in model.named_modules" -A 15 modelopt/torch/quantization/conversion.py | head -25

Repository: NVIDIA/Model-Optimizer

Length of output: 1196


Add validation to ensure both flags are set together for AdaRound quantizers.

The two conversion checks are independent if statements and execute sequentially. The from_nvfp4_quantizer method requires its input to be a StaticBlockScaleQuantizer (enforced by assertion). If a saved state has _is_nvfp4_static_adaround_quantizer=True but _is_nvfp4_static_quantizer=False, the first conversion is skipped and the second fails on the assertion.

While normal operation ensures both flags are set together when creating an NVFP4StaticAdaRoundQuantizer, there is no explicit validation during state restoration. Consider adding a check to ensure saved states maintain this invariant, or restructure the restoration logic to handle this edge case gracefully (e.g., convert to StaticBlockScaleQuantizer first if needed before checking for AdaRound).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/conversion.py` around lines 132 - 140, The
current sequential checks can call
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module) without ensuring
module is first a StaticBlockScaleQuantizer; update the restoration logic in
conversion.py so that when state["_is_nvfp4_static_adaround_quantizer"] is true
you first ensure or convert the module to a StaticBlockScaleQuantizer (i.e., run
StaticBlockScaleQuantizer.from_tensor_quantizer(module) if needed) or explicitly
validate that state["_is_nvfp4_static_quantizer"] is also true and raise/handle
a clear error; specifically touch the block that references
StaticBlockScaleQuantizer, NVFP4StaticAdaRoundQuantizer, from_nvfp4_quantizer,
and set_from_modelopt_state to enforce the invariant or perform the conversion
before calling from_nvfp4_quantizer.
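The "convert first, then upgrade" restructuring can be sketched with stub classes (these mimic only the in-place `__class__` swap pattern visible in the PR, not the full quantizer API):

```python
class TensorQuantizer:
    """Stub base quantizer."""

class StaticBlockScaleQuantizer(TensorQuantizer):
    @classmethod
    def from_tensor_quantizer(cls, tq):
        tq.__class__ = cls  # in-place conversion, as in the PR
        return tq

class NVFP4StaticAdaRoundQuantizer(StaticBlockScaleQuantizer):
    @classmethod
    def from_nvfp4_quantizer(cls, tq):
        assert isinstance(tq, StaticBlockScaleQuantizer)
        tq.__class__ = cls
        return tq

def restore_quantizer(module, state):
    """Ensure the static-block conversion happens before the AdaRound upgrade,
    so a state carrying only the adaround flag cannot trip the assertion."""
    needs_adaround = state.get("_is_nvfp4_static_adaround_quantizer")
    if state.get("_is_nvfp4_static_quantizer") or needs_adaround:
        if not isinstance(module, StaticBlockScaleQuantizer):
            StaticBlockScaleQuantizer.from_tensor_quantizer(module)
    if needs_adaround and not isinstance(module, NVFP4StaticAdaRoundQuantizer):
        NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer(module)
    return module

m = restore_quantizer(TensorQuantizer(),
                      {"_is_nvfp4_static_adaround_quantizer": True})
print(type(m).__name__)  # NVFP4StaticAdaRoundQuantizer
```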

Comment on lines +1555 to +1581
def from_nvfp4_quantizer(
cls,
tq: StaticBlockScaleQuantizer,
weight_scaled: torch.Tensor | None = None,
) -> "NVFP4StaticAdaRoundQuantizer":
"""Convert an NVFP4StaticQuantizer to NVFP4StaticAdaRoundQuantizer in-place.

Args:
tq: The NVFP4StaticQuantizer to convert.
weight_scaled: Pre-scaled weight tensor of shape ``[num_blocks, block_size]``.
If provided, :meth:`enable_adaround` is called immediately.
"""
assert isinstance(tq, StaticBlockScaleQuantizer), (
f"Expected StaticBlockScaleQuantizer, got {type(tq)}"
)

if isinstance(tq, cls):
if weight_scaled is not None:
tq.enable_adaround(weight_scaled)
return tq
tq.__class__ = cls
tq._is_nvfp4_static_adaround_quantizer = True
tq._adaround_enabled = False
tq.temperature = 1.0
if weight_scaled is not None:
tq.enable_adaround(weight_scaled)
return tq

⚠️ Potential issue | 🔴 Critical

Reject non-FP4 quantizers in from_nvfp4_quantizer().

After this PR StaticBlockScaleQuantizer also represents INT block formats, but this conversion path still initializes AdaRound with fp4_cast_ste() / fp4_step_size() and _cast_ste() stays FP4-only. Converting an INT4/INT8 block quantizer here will learn rounding decisions on the wrong grid. Please guard this to the supported FP4 config, or add an int-specific AdaRound implementation.

Also applies to: 1595-1609, 1638-1648

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/nn/modules/tensor_quantizer.py` around lines 1555
- 1581, The conversion method from_nvfp4_quantizer should reject non-FP4/
non-NVFP4 block quantizers instead of blindly mutating any
StaticBlockScaleQuantizer; add a guard near the start of
NVFP4StaticAdaRoundQuantizer.from_nvfp4_quantizer that inspects the source
quantizer's format/block-type (e.g., a property like block_format / is_fp4 /
nvfp4 flag on tq) and raise/assert if it is not an FP4/NVFP4 config, and do the
same check in the other related conversion methods mentioned (the other
from_nvfp4_quantizer conversion blocks at the ranges you noted) so AdaRound
remains FP4-only unless you implement a separate INT-specific AdaRound path.
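One possible shape for the guard, assuming the FP4 format is identified by `_num_bits == (2, 1)` (the encoding quoted in the summary review below; the stub class and error message are illustrative):

```python
class StubQuantizer:
    """Hypothetical quantizer exposing only the attribute the guard inspects."""
    def __init__(self, num_bits):
        self._num_bits = num_bits

def assert_fp4_block_quantizer(tq):
    """Reject non-FP4 block quantizers before attaching AdaRound, since
    _cast_ste()/fp4_step_size() assume the FP4 (2-exponent, 1-mantissa) grid."""
    if getattr(tq, "_num_bits", None) != (2, 1):
        raise TypeError(
            f"AdaRound conversion only supports FP4 (num_bits == (2, 1)); "
            f"got num_bits={tq._num_bits!r}"
        )

assert_fp4_block_quantizer(StubQuantizer((2, 1)))   # ok
try:
    assert_fp4_block_quantizer(StubQuantizer(4))    # e.g. an INT4 block quantizer
except TypeError as e:
    print("rejected:", e)
```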

Comment on lines +308 to +312
def _setup_adaround(self):
"""Set up AdaRound: register aux callback and freeze parent weights."""
self._adaround_pending_metrics = {}
self.add_callback(_AdaRoundAuxCallback(trainer=self))
self._freeze_adaround_weights()

⚠️ Potential issue | 🟠 Major

Propagate adaround_args.temperature into the quantizers.

AdaRoundTrainingArguments.temperature is stored on the trainer, but _setup_adaround() never writes it to any NVFP4StaticAdaRoundQuantizer. Non-default values are ignored and AdaRound always runs at temperature = 1.0.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 308
- 312, The trainer's _setup_adaround() registers the AdaRound callback but never
propagates AdaRoundTrainingArguments.temperature to the quantizers, so
non-default temperatures are ignored; update _setup_adaround() to iterate the
model's quantizers (e.g., instances of NVFP4StaticAdaRoundQuantizer) and set
their temperature property from self.args.ada_round_args.temperature (or
self.ada_rounding_args / AdaRoundTrainingArguments.temperature where stored)
before adding the _AdaRoundAuxCallback, ensuring each
NVFP4StaticAdaRoundQuantizer.temperature (or equivalent attribute) is assigned
the trainer's temperature value so AdaRound uses the configured temperature.
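A minimal sketch of the missing propagation step; the `Model` class here is a stand-in for a `torch.nn.Module` tree, and only the `temperature` attribute is taken from the PR snippet:

```python
class NVFP4StaticAdaRoundQuantizer:
    temperature = 1.0  # default set by from_nvfp4_quantizer() in the PR

class Model:
    """Stand-in exposing named_modules() like torch.nn.Module."""
    def __init__(self, quantizers):
        self._quantizers = quantizers
    def named_modules(self):
        return [(f"layers.{i}.weight_quantizer", q)
                for i, q in enumerate(self._quantizers)]

def propagate_adaround_temperature(model, temperature):
    """The step _setup_adaround() should add: push the configured AdaRound
    temperature into every AdaRound quantizer before training starts."""
    for _name, module in model.named_modules():
        if isinstance(module, NVFP4StaticAdaRoundQuantizer):
            module.temperature = temperature

model = Model([NVFP4StaticAdaRoundQuantizer() for _ in range(2)])
propagate_adaround_temperature(model, 4.0)
print([q.temperature for q in model._quantizers])  # [4.0, 4.0]
```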

Comment on lines +650 to +672
model = self._trainer.accelerator.unwrap_model(self._trainer.model)
self._weight_entries = [] # list of (weight_param, quantizer)
for _name, module in model.named_modules():
for weight_name in weight_attr_names(module):
wq_name = quantizer_attr_names(weight_name).weight_quantizer
quantizer = getattr(module, wq_name, None)
if quantizer is None or not quantizer.is_enabled:
continue
weight = getattr(module, weight_name, None)
if not isinstance(weight, torch.nn.Parameter):
continue
self._weight_entries.append((weight, quantizer))

pid_to_group = {}
for group_idx, group in enumerate(self._trainer.optimizer.param_groups):
for p in group["params"]:
pid_to_group[id(p)] = group_idx
self._param_group_idx = {}
self._multiplier = {}
for weight, _q in self._weight_entries:
self._param_group_idx[id(weight)] = pid_to_group.get(id(weight), 0)
self._multiplier[id(weight)] = torch.zeros(1, device=weight.device)


⚠️ Potential issue | 🟠 Major

QERR should ignore frozen weights.

This helper registers every quantized weight, even when it has requires_grad=False or is absent from the optimizer. In LoRA / trainable_params runs, the qerr_coeff > 0 path backprops through those tensors and later uses weight.grad even though it is None. Filter _weight_entries down to trainable optimizer-owned weights here.

Also applies to: 717-728

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 650
- 672, The current registration loop (_weight_entries) includes quantized
weights even when they are frozen or not in the optimizer, causing qerr to
backprop through tensors with no grads; update the loop that builds
self._weight_entries (the block using weight_attr_names, quantizer_attr_names,
weight_quantizer and is_enabled) to only append weights that are
torch.nn.Parameter AND weight.requires_grad is True AND the weight's id is
present in the optimizer param mapping (pid_to_group); similarly ensure the
later population of self._param_group_idx and self._multiplier only iterates
over these filtered _weight_entries (so id(weight) exists in pid_to_group) to
avoid creating entries for optimizer-less/frozen weights — apply the same guard
to the analogous registration block elsewhere that performs the same work.
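The filter can be expressed as a pure function over (weight, quantizer) pairs and the optimizer's param groups; `SimpleNamespace` stands in for `torch.nn.Parameter` here so the sketch stays framework-free:

```python
from types import SimpleNamespace

def collect_qerr_entries(weight_quantizer_pairs, optimizer_param_groups):
    """Keep only weights that are trainable AND owned by the optimizer, so the
    QERR penalty never backprops through frozen (e.g. LoRA-base) weights
    whose .grad would stay None."""
    optimizer_param_ids = {
        id(p) for group in optimizer_param_groups for p in group["params"]
    }
    return [
        (w, q)
        for w, q in weight_quantizer_pairs
        if getattr(w, "requires_grad", False) and id(w) in optimizer_param_ids
    ]

trainable = SimpleNamespace(requires_grad=True)
frozen = SimpleNamespace(requires_grad=False)
pairs = [(trainable, "q0"), (frozen, "q1")]
groups = [{"params": [trainable, frozen]}]
print(len(collect_qerr_entries(pairs, groups)))  # 1
```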

Comment on lines +676 to +695
def _compute_mse(self, weight, quantizer):
"""Compute MSE between original and quantized weight."""
if isinstance(quantizer, StaticBlockScaleQuantizer) and not (
quantizer._lsq or quantizer._laq or quantizer._smooth_laq
):
# smooth_lsq: weights are pre-divided, use raw cast
orig_shape = weight.shape
if hasattr(quantizer, "_block_reshape_size"):
w = weight.reshape(quantizer._block_reshape_size)
else:
w = weight
q_weight = quantizer._cast_ste(w)
q_weight = q_weight.reshape(orig_shape)
else:
q_weight = quantizer(weight)
sq_err = (q_weight.detach() - weight) ** 2
if self._trainer.qerr_args.qerr_reduction == "sum":
return sq_err.sum()
return sq_err.mean()


⚠️ Potential issue | 🟠 Major

The _cast_ste shortcut is wrong for plain static-block quantizers.

This branch also catches a calibrated StaticBlockScaleQuantizer with no learnable mode enabled. In that case the stored weight is still in real scale, so _cast_ste(w) drops the block/tensor scales and the QERR metric/update targets the wrong grid. Keep the shortcut for pre-divided smooth_lsq weights only; use quantizer(weight) for the other static-block cases.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/quantization/plugins/transformers_trainer.py` around lines 676
- 695, The branch in _compute_mse incorrectly uses quantizer._cast_ste for any
StaticBlockScaleQuantizer that has no learnable modes, which incorrectly drops
block/tensor scales for calibrated (non-pre-divided) weights; change the
condition so _cast_ste is used only when the quantizer is a
pre-divided/smooth_lsq case (i.e., when quantizer._smooth_lsq is True), and for
other StaticBlockScaleQuantizer instances call quantizer(weight) instead;
preserve the existing reshape logic using quantizer._block_reshape_size and then
reshape back to orig_shape before computing sq_err.
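The corrected branch condition can be isolated as a predicate; the stub below carries only the mode flags visible in the snippet, plus the `_smooth_lsq` flag the fix assumes marks pre-divided weights:

```python
class StaticBlockScaleQuantizer:
    """Stub carrying only the learnable-mode flags referenced above."""
    def __init__(self, lsq=False, laq=False, smooth_laq=False, smooth_lsq=False):
        self._lsq, self._laq = lsq, laq
        self._smooth_laq, self._smooth_lsq = smooth_laq, smooth_lsq

def use_raw_cast(quantizer):
    """The _cast_ste shortcut is valid only for smooth_lsq, where the stored
    weight is already pre-divided; a plain calibrated static-block quantizer
    must go through quantizer(weight) so block/tensor scales are applied."""
    return (
        isinstance(quantizer, StaticBlockScaleQuantizer)
        and quantizer._smooth_lsq
        and not (quantizer._lsq or quantizer._laq or quantizer._smooth_laq)
    )

print(use_raw_cast(StaticBlockScaleQuantizer(smooth_lsq=True)))  # True
print(use_raw_cast(StaticBlockScaleQuantizer()))                 # False: plain calibrated
```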

Collaborator

@cjluo-nv cjluo-nv left a comment

Summary: This PR introduces five new quantization algorithms (SmoothLSQ, LSQ, LAQ, SmoothLAQ, AdaRound) with learnable-scale quantizer support, refactors the base quantizer class from NVFP4StaticQuantizer to StaticBlockScaleQuantizer, adds training infrastructure (AdaRound dist_loss callback, quantization error regularization, Liger kernel fusion, per-parameter LR config), and extends the QAT example with local dataset support and caching. This is a very large PR (~3500 lines of diff across 23 files).

Issues Found:

  1. [Correctness] Hardcoded internal path in create_ptq.py:138
    The default --dataset argument is /home/scratch.akuriparambi_coreai/datasets/qat_blend_sft/blend_sft.jsonl — an internal filesystem path that won't exist outside NVIDIA. This should be changed to require the argument or use a publicly available dataset.

  2. [Correctness] Tokenization format change in utils.py:_make_tokenize_fn
    The new tokenizer prepends role labels (f"{role}: {content}\n", e.g. "User: hello\n") whereas the original get_daring_anteater tokenized only the conversation value (conversation["value"] + "\n"). This silently changes the tokenization semantics for the Daring-Anteater dataset path, which flows through _normalize_to_messages_make_tokenize_fn. The role prefix tokens weren't in the original training data and will affect loss masking since only "Assistant" role tokens get labels.

  3. [Correctness] _load_cached_dataset tokenizes then splits vs original split-then-tokenize
    The original get_daring_anteater shuffled, selected, tokenized (via .map()), then did train_test_split. The new _load_cached_dataset shuffles, selects, splits, then tokenizes per split. While functionally similar, the split happens on raw data vs tokenized data, meaning the seed-42 split will produce different train/test assignments than before — a silent regression for anyone comparing against old results.

  4. [Correctness] launch.sh silently swallows unknown args
    Lines 52-53 change the behavior from erroring on unknown args to collecting them in EXTRA_ARGS. This is convenient for pass-through but means typos in argument names (e.g. --ouput_dir) will silently be passed to main.py rather than caught early. At minimum, a warning would be appropriate.
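One possible warn-and-pass-through pattern (illustrative only, not the actual launch.sh; flag names are assumptions):

```shell
#!/usr/bin/env bash
# Sketch: collect unknown flags into EXTRA_ARGS for pass-through to main.py,
# but emit a warning so typos like --ouput_dir remain visible in the log.
EXTRA_ARGS=()
parse_args() {
  while [ $# -gt 0 ]; do
    case "$1" in
      --model) MODEL="$2"; shift 2 ;;
      --*)
        echo "WARNING: unknown argument '$1' will be passed through to main.py" >&2
        EXTRA_ARGS+=("$1"); shift ;;
      *) shift ;;
    esac
  done
}
parse_args --model llama --ouput_dir /tmp/out
echo "model=$MODEL extra=${EXTRA_ARGS[*]}"
```

This keeps the convenience of pass-through while surfacing the misspelled flag on stderr.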

  5. [Correctness] model_quant.py:quantize() skip-quantize path
    The new logic at line 231-236 skips apply_mode for already-quantized models and only calls set_quantizer_by_cfg. This means if a user calls quantize() twice with a different quant_cfg, the module structure won't be updated (e.g., no new quantizer modules added), only existing quantizer attributes get set. This could silently produce wrong results if the second config targets quantizers that don't exist yet. The behavior change should at minimum be documented.

  6. [Correctness] _apply_gradient_checkpointing_defaults silently overrides user setting
    In transformers.py:207-215, use_reentrant=True is forcibly overridden to False with only a warning. The original code in main.py set use_reentrant=True (now removed). While use_reentrant=False is generally better, forcibly overriding a user's explicit setting is surprising behavior for a base trainer class.

  7. [Correctness] QATTrainer rejects non-distributed mode
    Lines 258-262 raise ValueError for ParallelMode.NOT_DISTRIBUTED, but the test fixtures monkeypatch to NOT_PARALLEL. This means the tests bypass a check that would fire in real single-GPU usage. If single-GPU is truly unsupported, this needs clearer documentation; if it should work, the check needs refinement.

  8. [Duplicated Code] is_static_block_scale check repeated 3+ times
    The pattern checking module.is_static_block_quant and module._block_sizes is not None and ((module._num_bits == (2, 1) and ...) or isinstance(module._num_bits, int)) appears in mse_calibrate, local_hessian_calibrate, and _convert_to_static_block_quantizers. Extract this to a helper like _is_eligible_for_static_block_scale(quantizer).

  9. [Duplicated Code] Config class boilerplate
    SmoothLSQConfig, LSQConfig, LAQConfig, SmoothLAQConfig in config.py are nearly identical — each has only a method literal and identical scale_algorithm field with the same description. Consider a base class or factory to reduce the ~100 lines of duplication.
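One way the factory could look (class and field names are modeled on the review's description, not copied from the PR):

```python
# Sketch: generate the near-identical config classes from one template
# instead of repeating ~100 lines of boilerplate.
from dataclasses import dataclass, field

def _make_algo_config(name: str, method_literal: str):
    @dataclass
    class _AlgoConfig:
        method: str = method_literal
        scale_algorithm: str = field(
            default="max",
            metadata={"description": "Algorithm used to initialize the scale."},
        )
    _AlgoConfig.__name__ = name
    _AlgoConfig.__qualname__ = name
    return _AlgoConfig

SmoothLSQConfig = _make_algo_config("SmoothLSQConfig", "smooth_lsq")
LSQConfig = _make_algo_config("LSQConfig", "lsq")
LAQConfig = _make_algo_config("LAQConfig", "laq")
SmoothLAQConfig = _make_algo_config("SmoothLAQConfig", "smooth_laq")

print(SmoothLSQConfig().method)  # smooth_lsq
```

If the real configs are pydantic models rather than dataclasses, the same factory shape applies with a shared base model instead.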

  10. [Readability] StaticBlockScaleQuantizer._fake_quantize is very long
    The _fake_quantize method handles 6+ code paths (smooth_lsq, lsq, laq, smooth_laq, FP4 static, INT static, fallback). This 50+ line method with deeply nested conditions would benefit from being split into named helpers (e.g. _fake_quantize_learnable, _fake_quantize_static).
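A dispatch-table sketch of the suggested split (helper names are suggestions, not from the diff; real handlers would carry the actual quantization math):

```python
# Sketch: replace the 50+ line branch ladder with named helpers selected
# from a small dispatch table.
class QuantizerSketch:
    def __init__(self, mode):
        self.mode = mode

    def _fake_quantize(self, x):
        handlers = {
            "smooth_lsq": self._fake_quantize_learnable,
            "lsq": self._fake_quantize_learnable,
            "laq": self._fake_quantize_learnable,
            "smooth_laq": self._fake_quantize_learnable,
            "fp4_static": self._fake_quantize_static,
            "int_static": self._fake_quantize_static,
        }
        return handlers.get(self.mode, self._fake_quantize_fallback)(x)

    def _fake_quantize_learnable(self, x):
        return ("learnable", x)  # placeholder for the learnable-scale path

    def _fake_quantize_static(self, x):
        return ("static", x)     # placeholder for the static block-scale path

    def _fake_quantize_fallback(self, x):
        return ("fallback", x)   # placeholder for the default path

print(QuantizerSketch("lsq")._fake_quantize(1.0))  # ('learnable', 1.0)
```

Each helper then stays short enough to test in isolation.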

  11. [Readability] Class-level mutable state as class attributes
    StaticBlockScaleQuantizer defines _smooth_lsq, _lsq, _laq, _smooth_laq as class-level booleans (line 1293-1296 in the diff). These are instance-level state mutated by enable_* methods. While Python handles this correctly (instance assignment shadows class attribute), it's a confusing pattern — these should be set in from_tensor_quantizer or __init__.
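A minimal illustration of the pattern and the suggested fix (attribute name `_lsq` is taken from the review; everything else is a stand-in):

```python
# Sketch: class-level flag vs. instance-level flag.
class Confusing:
    _lsq = False  # reads like shared class state

    def enable_lsq(self):
        self._lsq = True  # actually creates an instance attribute that
                          # shadows the class attribute

class Clearer:
    def __init__(self):
        self._lsq = False  # obviously per-instance from the start

    def enable_lsq(self):
        self._lsq = True

a, b = Confusing(), Confusing()
a.enable_lsq()
print(a._lsq, b._lsq)  # True False
```

The behavior is correct either way because instance assignment shadows the class attribute, but initializing in `__init__` (or `from_tensor_quantizer`) makes the per-instance intent explicit to readers.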

  12. [Tests] Missing test for tokenization change
    The tokenization format change in _make_tokenize_fn (prepending role labels) has no test coverage. The existing tests only cover the quantization algorithms and trainer integration.

  13. [Correctness] torch.float32 default dtype change in main.py:202
    Changed from torch_dtype=torch.bfloat16 to torch.float32 with a comment about mixed precision. This doubles the memory requirement for model loading. The justification comment is thin — mixed precision training typically handles bf16 casting itself, but loading in fp32 means the model sits in fp32 until the first forward pass. This could cause OOM on memory-constrained setups. Should at least be configurable.
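A sketch of making the load dtype configurable rather than hardcoded (the `--model_dtype` flag is an assumption, not in the PR; the mapping uses strings so the example runs without torch — real code would pass `getattr(torch, args.model_dtype)` to `from_pretrained`):

```python
# Sketch: expose the model-loading dtype as a CLI flag with bf16 default.
import argparse

DTYPES = {
    "bfloat16": "torch.bfloat16",
    "float16": "torch.float16",
    "float32": "torch.float32",
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_dtype",
    choices=DTYPES,
    default="bfloat16",
    help="dtype passed as torch_dtype to from_pretrained",
)
args = parser.parse_args(["--model_dtype", "float32"])
print(DTYPES[args.model_dtype])  # torch.float32
```

This preserves the old bf16 default for memory-constrained setups while letting users who need fp32 master weights opt in.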

  14. [Correctness] _QuantErrorAuxCallback._compute_mse calls mse.backward() outside autocast
    The QERR callback (line 735 in transformers_trainer.py) calls mse.backward() on a manually computed MSE. If mixed precision is active, this backward pass happens outside the trainer's autocast context, which could lead to dtype mismatches or suboptimal gradient precision.

Suggestions:

  • The PR is very large (23 files, ~3500 lines). Consider splitting into: (1) StaticBlockScaleQuantizer refactor, (2) learnable-scale algorithms, (3) AdaRound, (4) training infrastructure (Liger, LR config, QERR), (5) example/dataset improvements.
  • Add a migration note for the NVFP4StaticQuantizer → StaticBlockScaleQuantizer rename, even though the alias is preserved.
  • The _dataset_cache module-level dict in utils.py is a global mutable singleton that persists across calls — consider documenting this clearly or using a more explicit caching mechanism.
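One explicit alternative to the module-level dict (function name and signature are illustrative; the real loader takes more parameters, and an `lru_cache` holds strong references to its values, which matters for large datasets):

```python
# Sketch: replace a hidden module-level cache dict with functools.lru_cache,
# which makes the caching policy visible and inspectable.
from functools import lru_cache

@lru_cache(maxsize=8)
def load_tokenized_dataset(path: str, seed: int = 42):
    # The expensive load + tokenize would happen here; results are cached
    # per (path, seed) key for the lifetime of the process.
    return f"dataset({path}, seed={seed})"

d1 = load_tokenized_dataset("data.jsonl")
d2 = load_tokenized_dataset("data.jsonl")
print(d1 is d2)  # True
print(load_tokenized_dataset.cache_info().hits)  # 1
```

`cache_info()` and `cache_clear()` also give callers a way to observe and reset the cache, which a bare module-level dict does not.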

Overall Assessment: The algorithmic work (learnable-scale quantizers, AdaRound) appears well-designed with good test coverage for the core quantization paths. However, the PR bundles too many orthogonal changes (training infra, dataset handling, Liger integration, gradient checkpointing policy changes) making it hard to review safely. The tokenization format change and the quantize() skip-path are the highest-risk correctness concerns. The hardcoded internal path must be fixed before merge.

@realAsma realAsma closed this Apr 1, 2026
@realAsma realAsma deleted the asma/QAD-new-algorithms branch April 1, 2026 12:13