[TRITON] fix: MXFP4 mantissa rounding #975
Conversation
cc @vgokhale
Hi @hann-wang, I couldn't replicate what you said about denormal numbers in #974. From my testing, the current AITER kernel rounds one way [comparison omitted in the original thread] while your implementation rounds another [comparison omitted]. They differ only for |x| = 0.25, where your kernel follows the roundTiesToEven rule (the same rule your fix applies to normal values). So, can you please elaborate on this problem and maybe provide an input where it happens? Thanks!

Also, I couldn't get your unit test to run; it fails with:

> e2m1_value = torch.where(denormal_mask, denormal_x, e2m1_value)
> E RuntimeError: The size of tensor a (32) must match the size of tensor b (128) at non-singleton dimension 2
Here is a minimal reproducible example for the issue I mentioned.
You are right, and I made a mistake describing the issue; I have just updated the description in #974. The current AITER kernel does not follow the round-to-even rule at |x| = 0.25.
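To make the disagreement concrete, here is a toy illustration (not the AITER kernel itself) of why |x| = 0.25 is exactly the tie case for FP4 E2M1: it sits halfway between the two smallest non-negative codes, 0.0 and the denormal 0.5, so the two rounding rules diverge there and nowhere else in the denormal range.

```python
# The eight non-negative FP4 E2M1 values, indexed by their 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_half_away(x: float) -> float:
    """Round to the nearest E2M1 value, breaking ties away from zero."""
    return min(E2M1_VALUES, key=lambda v: (abs(v - x), -v))

def quantize_ties_to_even(x: float) -> float:
    """Round to the nearest E2M1 value, breaking ties toward the even code
    (IEEE 754 roundTiesToEven)."""
    near, second = sorted(E2M1_VALUES, key=lambda v: abs(v - x))[:2]
    if abs(near - x) == abs(second - x):        # exact tie between neighbors
        lo, hi = sorted((near, second))
        return lo if E2M1_VALUES.index(lo) % 2 == 0 else hi
    return near

print(quantize_half_away(0.25))      # -> 0.5
print(quantize_ties_to_even(0.25))   # -> 0.0 (code 0 is even)
```

The same tie-breaking difference shows up between any two adjacent codes (e.g. at x = 2.5, halfway between 2.0 and 3.0), which is the behavior the PR aligns for normal values as well.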
Hi @hann-wang, can you please do a merge/rebase from main? Looks like there are some conflicts that need to be resolved before merging your PR.
Pull Request Overview
This PR fixes the MXFP4 mantissa rounding implementation to address issue #974. The changes update both the PyTorch reference implementation and the Triton kernel to use a more correct rounding approach.
Key Changes:
- Implements proper round-to-nearest-even (banker's rounding) for mantissa values
- Separates handling of denormal, normal, and saturated values with explicit masking
- Adds constants for FP32 and FP4 format specifications (exponent bias, mantissa/exponent bits)
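The round-to-nearest-even step on the mantissa can be sketched at the bit level (a generic RNE sketch under the constants the review mentions, not the PR's exact code; the names here are illustrative):

```python
FP32_MANTISSA_BITS = 23
FP4_MANTISSA_BITS = 1                              # E2M1 keeps one mantissa bit
SHIFT = FP32_MANTISSA_BITS - FP4_MANTISSA_BITS     # 22 low bits are dropped

def round_mantissa_rne(m23: int) -> int:
    """Round a 23-bit FP32 mantissa field down to 1 bit with
    round-to-nearest-even. A result of 2 means the mantissa carried,
    which the caller must fold into the exponent."""
    kept = m23 >> SHIFT                             # the surviving mantissa bit
    round_bit = (m23 >> (SHIFT - 1)) & 1            # first dropped bit
    sticky = (m23 & ((1 << (SHIFT - 1)) - 1)) != 0  # OR of all lower dropped bits
    if round_bit and (sticky or (kept & 1)):        # > half, or exact tie with odd kept bit
        kept += 1
    return kept
```

On an exact tie (round bit set, sticky clear) the kept bit is bumped only when it is odd, which is precisely the roundTiesToEven behavior discussed in the thread.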
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| op_tests/triton_tests/test_quant_mxfp4.py | Updates PyTorch reference implementation with corrected MXFP4 quantization logic including proper rounding |
| aiter/ops/triton/_triton_kernels/quant.py | Updates Triton kernel implementation to match the corrected quantization algorithm |
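The "explicit masking" of denormal, normal, and saturated values noted above can be illustrated with a minimal PyTorch sketch (thresholds assume E2M1 after the block scale is applied; function and variable names are hypothetical, not taken from the PR):

```python
import torch

def e2m1_range_masks(x_abs: torch.Tensor):
    """Split |x| (already divided by the block scale) into the three E2M1
    ranges: denormal (below 1.0, the smallest normal), normal, and
    saturated (at or above 6.0, the largest representable magnitude)."""
    denormal_mask = x_abs < 1.0
    saturate_mask = x_abs >= 6.0
    normal_mask = ~denormal_mask & ~saturate_mask
    return denormal_mask, normal_mask, saturate_mask
```

Keeping the three ranges disjoint like this lets each branch apply its own rounding (denormal step of 0.5, normal RNE, clamp to 6.0) via `torch.where` without the branches interfering.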
Hi @lucas-santos-amd, I just merged the changes from main.
I'll merge the PR now, thanks for your contribution!
* fix: MXFP4 mantissa rounding
* fix: mantissa rounding in test_quant_mxfp4
* refactor dynamic_mxfp4_quant
* chore: format
* fix: mxfp4 quantization tests
* chore: format
* fix: mxfp4 quantization test with correct bitwidth and sign
* chore: restore DEBUG_MODE
* chore: align test_quant_mxfp4 with triton kernel

Co-authored-by: lucas-santos-amd <Lucas.Santos@amd.com>
fixes #974