Fix: Add torch.no_grad() decorator to MMLU evaluate metric#4073

Closed
fa-ina-tic wants to merge 1 commit into
quic:developfrom
fa-ina-tic:fix/torch-eval
Conversation

@fa-ina-tic
@fa-ina-tic fa-ina-tic commented Feb 2, 2026

Hi team!
I noticed a memory issue during MMLU evaluation and wanted to share a quick fix.

Problem
Running python -m Examples.torch.evaluate --eval-mmlu in the quantization recipe causes VRAM usage to grow continuously during evaluation, eventually producing a CUDA out-of-memory error.

Root Cause
The GenericMMLU.evaluate() method ran the model without disabling gradient computation, unlike --eval-ppl, which properly wraps inference in a no_grad context.
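The effect of the missing no_grad context can be seen with a minimal sketch (a toy linear layer standing in for the model): outside no_grad, autograd records a graph for every forward pass, so intermediate activations accumulate across the evaluation loop instead of being freed.

```python
import torch

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# Without no_grad: autograd records the computation graph, so
# intermediate activations are retained for a backward pass that
# evaluation never runs -- this is what accumulates VRAM.
out_with_graph = model(x)
assert out_with_graph.requires_grad

# Inside no_grad: no graph is recorded, so intermediates can be
# freed immediately after each batch.
with torch.no_grad():
    out_no_graph = model(x)
assert not out_no_graph.requires_grad
```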

Changes
Added @torch.no_grad() decorator to GenericMMLU.evaluate().
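The shape of the fix, sketched on a hypothetical skeleton of the metric class (the real implementation lives in GenAITests/shared/helpers/metrics.py; the method body here is illustrative only):

```python
import torch

class GenericMMLU:
    """Hypothetical skeleton of the MMLU metric class."""

    @torch.no_grad()  # the one-line fix: no autograd graph is built inside
    def evaluate(self, model, batch):
        # Forward pass for scoring only; gradients are never needed here.
        outputs = model(input_ids=batch["input_ids"])
        return outputs
```

torch.no_grad works as a decorator as well as a context manager, so this disables gradient tracking for the entire method without touching its body.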

Note
Other evaluation metrics (Interactive, TrickyPrompts, Prompts) may have the same issue, but I have not verified this. Additional fixes may be needed in a follow-up PR if confirmed.

Additional Info (for reproduction)
Machine

Operating System: Ubuntu 24.04.3 LTS
Kernel: Linux 6.8.0-88-generic
Architecture: x86-64

Python setup

absl-py==2.4.0
accelerate==1.12.0
aimet-onnx @ file:///workplace/aimet_onnx-2.22.0+cu121-cp310-cp310-manylinux_2_34_x86_64.whl
aimet-torch @ file:///workplace/aimet_torch-2.22.0+cu121-py310-none-any.whl
aiohappyeyeballs==2.6.1
aiohttp==3.13.3
aiosignal==1.4.0
anyio==4.12.1
async-timeout==5.0.1
attrs==25.4.0
bleach==6.3.0
bokeh==3.2.2
certifi==2026.1.4
cffi==2.0.0
charset-normalizer==3.4.4
clarabel==0.11.1
colorcet==3.1.0
coloredlogs==15.0.1
contourpy==1.3.2
cuda-bindings==12.9.4
cuda-pathfinder==1.3.3
cvxpy==1.6.0
cycler==0.12.1
dataclasses==0.8
datasets==4.5.0
dill==0.4.0
exceptiongroup==1.3.1
filelock==3.20.3
flatbuffers==25.12.19
fonttools==4.61.1
frozenlist==1.8.0
fsspec==2025.10.0
grpcio==1.76.0
h11==0.16.0
h5py==3.15.1
hf-xet==1.2.1
holoviews==1.18.3
httpcore==1.0.9
httpx==0.28.1
huggingface-hub==0.36.0
humanfriendly==10.0
hvplot==0.9.2
idna==3.11
iniconfig==2.3.0
jinja2==3.1.6
joblib==1.5.3
jsonschema==4.26.0
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
linkify-it-py==2.0.3
markdown==3.10.1
markdown-it-py==4.0.0
markupsafe==3.0.3
matplotlib==3.10.8
mdit-py-plugins==0.5.0
mdurl==0.1.2
ml-dtypes==0.5.4
mpmath==1.3.0
multidict==6.7.1
multiprocess==0.70.18
networkx==3.4.2
numpy==1.24.4
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.5
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.4.5
nvidia-nvtx-cu12==12.8.90
onnx==1.20.1
onnx-ir==0.1.15
onnx2torch==1.5.15
onnxruntime-extensions==0.14.0
onnxruntime-gpu==1.23.2
onnxscript==0.6.0
onnxsim==0.4.36
osqp==1.1.0
packaging==26.0
pandas==2.3.3
panel==1.3.8
param==2.3.1
peft==0.18.1
pillow==12.1.0
pluggy==1.6.0
propcache==0.4.1
protobuf==6.33.5
psutil==7.2.2
pyarrow==23.0.0
pybind11==3.0.1
pycparser==3.0
pygments==2.19.2
pyparsing==3.3.2
pytest==9.0.2
python-dateutil==2.9.0.post0
pytz==2025.2
pyviz-comms==3.0.6
pyyaml==6.0.3
referencing==0.37.0
regex==2026.1.15
requests==2.32.5
rich==14.3.2
rpds-py==0.30.0
safetensors==0.7.0
scikit-learn==1.7.2
scipy==1.8.1
scs==3.2.11
setuptools==80.10.2
six==1.17.0
sympy==1.14.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
threadpoolctl==3.6.0
tokenizers==0.21.4
tomli==2.4.0
torch==2.10.0
torchvision==0.25.0
tornado==6.5.4
tqdm==4.67.2
transformers==4.53.0
triton==3.6.0
typing-extensions==4.15.0
tzdata==2025.3
uc-micro-py==1.0.3
urllib3==2.6.3
webencodings==0.5.1
werkzeug==3.1.5
xxhash==3.6.0
xyzservices==2025.11.0
yarl==1.22.0

command

python -m Examples.torch.evaluate \
    --model-id "meta-llama/Llama-3.2-1B-Instruct" \
    --checkpoint "./torch_pcq" \
    --eval-mmlu

full error log

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 285/285 [00:00<00:00, 28345.27 examples/s]
Evaluating MMLU:  23%|██████████████████████████████▏                                                                                                  | 3283/14042 [15:10<49:42,  3.61it/s]
Traceback (most recent call last):
  File "/usr/local/bin/python-3.10.19/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/bin/python-3.10.19/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/workplace/aimet-develop/Examples/torch/evaluate.py", line 219, in <module>
    mmlu_score = MMLU.evaluate(generator, tokenizer, CONTEXT_LENGTH)
  File "/workplace/aimet-develop/GenAITests/shared/helpers/metrics.py", line 127, in evaluate
    outputs = model(input_ids=batch["input_ids"])
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/aimet-develop/GenAITests/shared/models/generator.py", line 373, in forward
    local_outputs = self.model(*prepared_inputs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/aimet-develop/GenAITests/shared/models/utils/model_utils.py", line 49, in forward
    lm_logits, new_past_key_values = self.model(
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 552, in forward
    outputs: BaseModelOutputWithPast = self.model(
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/utils/generic.py", line 943, in wrapper
    output = func(self, *args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 440, in forward
    layer_outputs = decoder_layer(
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/modeling_layers.py", line 83, in __call__
    return super().__call__(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 306, in forward
    hidden_states = self.mlp(hidden_states)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 151, in forward
    down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/base.py", line 262, in __call__
    return super().__call__(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/utils.py", line 256, in wrapper
    return fn(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 1716, in forward
    return super().forward(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 534, in forward
    output = super().forward(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 213, in forward
    return super().forward(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/base.py", line 272, in forward
    return super().forward(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 493, in __torch_function__
    return super().__torch_function__(impl, types, args, kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/overrides.py", line 2108, in __torch_function__
    return func(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 569, in wrapper
    return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/nn/true_quant.py", line 101, in _quantize_dequantize_if_applicable
    data = quantizer(data)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/quantization/affine/quantizer.py", line 1039, in forward
    output = quantize_dequantize(
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/quantization/affine/backends/__init__.py", line 367, in quantize_dequantize
    return get_backend().quantize_dequantize(
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/experimental/onnx/_export.py", line 470, in wrapper
    return python_fn(*args, **kwargs)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/quantization/affine/backends/torch_builtins.py", line 277, in quantize_dequantize
    qdq_tensor = QuantDequantFunc.apply(
  File "/workplace/.venv/lib/python3.10/site-packages/torch/autograd/function.py", line 583, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/quantization/affine/backends/torch_builtins.py", line 637, in forward
    return impl(ctx, tensor, scale, offset, qmin, qmax)
  File "/workplace/.venv/lib/python3.10/site-packages/aimet_torch/v2/quantization/affine/backends/torch_builtins.py", line 651, in _forward_impl
    mask = (qmin <= x_round) & (x_round <= qmax)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB. GPU 0 has a total capacity of 93.10 GiB of which 5.31 MiB is free. Process 672050 has 9.72 GiB memory in use. Process 3138231 has 83.35 GiB memory in use. Of the allocated memory 82.66 GiB is allocated by PyTorch, and 33.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Signed-off-by: fa-ina-tic <sooh0601@gmail.com>
@quic-hitameht
Contributor

Thanks @fa-ina-tic for opening the PR. I have merged a fix, and it should resolve the issue. Feel free to test it out, and you can close this PR whenever you're ready.

@fa-ina-tic
Author

Thanks for your work, @quic-hitameht, it works perfectly :)
I'll close this PR.

@fa-ina-tic fa-ina-tic closed this Feb 3, 2026
@fa-ina-tic fa-ina-tic deleted the fix/torch-eval branch February 9, 2026 05:05