Problem Description
aiter/utility/fp4_utils.py contains an MXFP4 quantization kernel (_dynamic_mxfp4_quant_kernel_asm_layout) that still uses the old round-ties-up logic. The same logic was reported in #974 and fixed in aiter/ops/triton/_triton_kernels/quant.py by PR #975.
The fix in #975 replaced the manual shift-based conversion with proper roundTiesToEven (banker's rounding). It uses three-way branching (saturate/denormal/normal masks) and the magic-number addition trick for denormals, which matches torchao and the behavior of the MI355 v_cvt_scalef32_pk_fp4_f32 instruction. However, only quant.py was patched; the same buggy logic remains in fp4_utils.py.
Affected code
https://github.com/ROCm/aiter/blob/main/aiter/utility/fp4_utils.py#L321
The kernel at this location still contains the pre-#975 conversion:
```python
# Extract sign, exponents and mantissa fields from FP32
s = qx & 0x80000000
e = (qx >> 23) & 0xFF
m = qx & 0x7FFFFF
E8_BIAS: tl.constexpr = 127
E2_BIAS: tl.constexpr = 1
# Denormal numbers
adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
```
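To check the examples on a CPU, the block above can be mirrored for a single value in plain Python. This is a minimal sketch, not part of aiter: `old_e2m1_bits`, `f32_bits` and `e2m1_to_float` are hypothetical helpers, and it assumes `qx` is the raw FP32 bit pattern (as an unsigned 32-bit integer) of the value being quantized, i.e. after dividing by the block scale.

```python
import struct

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitude for each 3-bit code


def f32_bits(x: float) -> int:
    """Raw IEEE-754 binary32 bit pattern of x (x is first rounded to FP32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def e2m1_to_float(code: int) -> float:
    """Decode a 4-bit E2M1 code (bit 3 = sign)."""
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


def old_e2m1_bits(qx: int) -> int:
    """Hypothetical scalar mirror of the pre-#975 Triton block."""
    s = qx & 0x80000000
    e = (qx >> 23) & 0xFF
    m = qx & 0x7FFFFF
    E8_BIAS, E2_BIAS = 127, 1
    # Denormal path: re-insert the implicit 1 (shifted right by one) and align it.
    # Python's >> yields 0 for shift counts >= 32; GPU behaviour for such counts is not modelled.
    if e < E8_BIAS:
        adjusted_exponents = E8_BIAS - (e + 1)
        m = (0x400000 | (m >> 1)) >> adjusted_exponents
    e = max(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
    # Keep exponent, 1 mantissa bit and 1 guard bit; +1 at the guard bit then >> 1
    # sends every exact midpoint to the larger magnitude. Saturate at code 0x7 (6.0).
    e2m1_tmp = min((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
    return (s >> 28) | e2m1_tmp


for v in (-1.25, 0.25, 2.5, 5.0):
    print(v, "->", e2m1_to_float(old_e2m1_bits(f32_bits(v))))
```

With this mirror, -1.25 maps to -1.5, and the midpoints 0.25, 2.5 and 5.0 map to 0.5, 3.0 and 6.0: the larger neighbour is chosen every time, regardless of which code is even.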
What's wrong
Two issues (same as originally reported in #974):
- Normal values use round-ties-up instead of roundTiesToEven. The (value + 1) >> 1 pattern always rounds midpoints up in magnitude (away from zero). For example, when -0.625 / scale lands exactly on -1.25 (scale = 0.5), the midpoint between the FP4 values -1.0 and -1.5, it rounds to -1.5 instead of the correct -1.0 (even mantissa).
- Denormals cannot round up properly. The manual shift-based denormal path doesn't handle rounding correctly. For example, FP32 value 0x3F000003 should round up to 0.5 but rounds down to 0.0.
Both torchao and the MI355 hardware instruction v_cvt_scalef32_pk_fp4_f32 use roundTiesToEven, so this code produces results inconsistent with the hardware and with the already-patched quant.py.
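The E2M1 grid is small enough to list every tie. The sketch below (pure Python; `tie_results` is a hypothetical helper) enumerates the midpoints between adjacent non-negative E2M1 values and compares the larger-magnitude choice made by (value + 1) >> 1 with roundTiesToEven, which picks the neighbour whose code has mantissa bit 0.

```python
# E2M1 magnitudes indexed by their 3-bit code (exponent bits, then mantissa bit).
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def tie_results():
    """For each midpoint between adjacent E2M1 values, return
    (midpoint, round-ties-up result, roundTiesToEven result)."""
    rows = []
    for lo in range(len(E2M1) - 1):
        hi = lo + 1
        mid = (E2M1[lo] + E2M1[hi]) / 2
        up = E2M1[hi]                                 # (value + 1) >> 1 takes the larger magnitude
        even = E2M1[lo] if lo % 2 == 0 else E2M1[hi]  # even code <=> mantissa bit is 0
        rows.append((mid, up, even))
    return rows


for mid, up, even in tie_results():
    flag = "  <-- differs" if up != even else ""
    print(f"{mid:>5}: ties-up -> {up:>3}, ties-to-even -> {even:>3}{flag}")
```

The two rules disagree at 0.25, 1.25, 2.5 and 5.0. At 0.75, 1.75 and 3.5 the upper neighbour already has an even code, so both rules agree there.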
Expected behavior
fp4_utils.py should use the same corrected rounding logic that was applied to quant.py in PR #975, which implements:
- Three-way masking: saturate_mask, denormal_mask, normal_mask
- Denormal conversion via the magic-number addition trick (denorm_mask_float)
- Normal conversion with proper round-to-nearest-even via a mant_odd bias
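For reference, here is a scalar, pure-Python sketch of that three-way formulation specialised to E2M1. It follows the same structure as torchao's FP32-to-low-precision conversion but is not the code from PR #975; `rtne_e2m1_bits` and the constant names other than `saturate_mask`, `denormal_mask`, `normal_mask`, `denorm_mask_float` and `mant_odd` are illustrative. It converts one FP32 bit pattern instead of a Triton block, and NaN handling is left out.

```python
import struct

# FP32 layout and E2M1 (FP4) parameters.
MBITS_F32, F32_EXP_BIAS = 23, 127
MBITS, EXP_BIAS = 1, 1
MAX_NORMAL, MIN_NORMAL = 6.0, 1.0   # largest / smallest normal E2M1 magnitude
MAX_INT = 0x7                       # E2M1 code for 6.0 (saturation value)

# Magic number 2**22: its FP32 ulp is 0.5, the E2M1 denormal step.
DENORM_EXP = (F32_EXP_BIAS - EXP_BIAS) + (MBITS_F32 - MBITS) + 1  # 149
DENORM_MASK_INT = DENORM_EXP << MBITS_F32
# Just under half an E2M1 ulp, in FP32 mantissa units.
MAGIC_ADDER = (1 << (MBITS_F32 - MBITS - 1)) - 1                  # 0x1FFFFF


def bits_to_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]


def f32_to_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]


denorm_mask_float = bits_to_f32(DENORM_MASK_INT)  # 2**22


def rtne_e2m1_bits(qx: int) -> int:
    """Hypothetical scalar FP32-bits -> E2M1 code conversion with roundTiesToEven."""
    sign = qx & 0x80000000
    qx ^= sign  # work on the magnitude; the sign is re-attached at the end
    x = bits_to_f32(qx)
    saturate_mask = x >= MAX_NORMAL
    denormal_mask = (not saturate_mask) and x < MIN_NORMAL
    normal_mask = not (saturate_mask or denormal_mask)
    if saturate_mask:
        code = MAX_INT
    elif denormal_mask:
        # Adding 2**22 makes FP32 round x to a multiple of 0.5 (ties to even); the low
        # mantissa bits of the sum are the code (0, 1, or 2 meaning promotion to 1.0).
        # Python adds in float64; packing as "<f" then rounds to FP32, which matches
        # a single FP32 add for inputs in [0, 1).
        code = f32_to_bits(x + denorm_mask_float) - DENORM_MASK_INT
    else:
        assert normal_mask
        # Rebias the exponent, then add MAGIC_ADDER + mant_odd: the kept mantissa bit
        # receives a carry above the midpoint, and at the midpoint only if it is odd.
        mant_odd = (qx >> (MBITS_F32 - MBITS)) & 1
        n = qx + ((EXP_BIAS - F32_EXP_BIAS) << MBITS_F32) + MAGIC_ADDER + mant_odd
        code = (n >> (MBITS_F32 - MBITS)) & 0xFF
    return (sign >> 28) | code


print(hex(rtne_e2m1_bits(f32_to_bits(-1.25))))
```

For -1.25 this returns code 0xA (-1.0), the even neighbour, where the pre-#975 logic yields -1.5.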
Suggested fix
Apply the same transformation from PR #975 to the _dynamic_mxfp4_quant_kernel_asm_layout kernel in fp4_utils.py. The corrected conversion block from quant.py can be used directly.
Related
- aiter/ops/triton/_triton_kernels/quant.py (merged Nov 26, 2025)
Operating System
Linux-6.8.0-60-generic-x86_64-with-glibc2.39
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD Instinct MI355X
ROCm Version
ROCm 7.1
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response