Skip to content

Commit 36f984b

Browse files
committed
Fixed the CI issue
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent: 4a07d88 · commit: 36f984b

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

tests/unit/torch/export/test_export_diffusers.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717

1818
import pytest
19+
import torch
1920
from _test_utils.torch.diffusers_models import get_tiny_dit, get_tiny_flux, get_tiny_unet
2021

2122
pytest.importorskip("diffusers")
@@ -93,7 +94,6 @@ def _process_stub(*_args, **_kwargs):
9394
("int8", mtq.INT8_DEFAULT_CFG),
9495
("int8_smoothquant", mtq.INT8_SMOOTHQUANT_CFG),
9596
("fp8", mtq.FP8_DEFAULT_CFG),
96-
("fp4", mtq.NVFP4_DEFAULT_CFG),
9797
],
9898
)
9999
def test_export_diffusers_real_quantized(tmp_path, model_factory, config_id, quant_cfg):
@@ -108,12 +108,30 @@ def _calib_fn(m):
108108

109109
mtq.quantize(model, quant_cfg, forward_loop=_calib_fn)
110110

111-
try:
112-
export_hf_checkpoint(model, export_dir=export_dir)
113-
except AssertionError as e:
114-
if "block size" in str(e) and config_id == "fp4":
115-
pytest.skip(f"Tiny model weights incompatible with FP4 block quantization: {e}")
116-
raise
111+
export_hf_checkpoint(model, export_dir=export_dir)
112+
113+
config_path = export_dir / "config.json"
114+
assert config_path.exists()
115+
116+
config_data = _load_config(config_path)
117+
assert "quantization_config" in config_data
118+
119+
120+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FP4 export requires NVIDIA GPU")
121+
def test_export_diffusers_real_quantized_fp4(tmp_path):
122+
"""FP4 export test using get_tiny_dit (the only tiny model with FP4-compatible weight shapes)."""
123+
model = get_tiny_dit()
124+
export_dir = tmp_path / "export_DiTTransformer2DModel_fp4_real_quant"
125+
126+
def _calib_fn(m):
127+
param = next(m.parameters())
128+
dummy_inputs = generate_diffusion_dummy_inputs(m, param.device, param.dtype)
129+
assert dummy_inputs is not None
130+
m(**dummy_inputs)
131+
132+
mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=_calib_fn)
133+
134+
export_hf_checkpoint(model, export_dir=export_dir)
117135

118136
config_path = export_dir / "config.json"
119137
assert config_path.exists()

0 commit comments

Comments (0)