Make cute dsl mxfp8/nvfp4 quantizer bitwise exact#3387
Conversation
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: defaults Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughAdds RN FP32 reciprocal and RN NVFP4 scale helpers, threads a ChangesNVFP4 Disable Fast-Math and Quantization Precision
MXFP8 Quantization Reference and Test Improvements
🎯 3 (Moderate) | ⏱️ ~20 minutes Possibly Related PRs
Suggested Labels
Suggested Reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces a disable_fast_math configuration for NVFP4 quantization, adding specialized round-to-nearest PTX operations and supporting logic across the quantization kernels. It also refines the MXFP8 reference implementation and scale swizzling to ensure bitwise exactness with host references. Feedback was provided regarding the rounding logic in the Python reference implementation, specifically noting that it should explicitly handle saturation and non-finite inputs to match the PTX behavior.
There was a problem hiding this comment.
🧹 Nitpick comments (5)
flashinfer/quantization/quantization_cute_dsl_utils.py (1)
22-22: 💤 Low valueClarify the tiny subnormal handling logic.
The new predicates
p_exp_zero,p_tiny_subsuppress the mantissa bump when the input is a very small denormal (exponent=0 and mantissa ≤ 0x400000). Consider adding a brief inline comment explaining why this threshold was chosen, as it may not be obvious to future maintainers why half the mantissa range is special-cased.Also applies to: 169-184
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/quantization/quantization_cute_dsl_utils.py` at line 22, Add a brief inline comment next to the predicates p_exp_zero and p_tiny_sub that explains the chosen threshold 0x400000: note that 0x400000 is half of the 23-bit mantissa range (i.e. half the max mantissa for single-precision), and we suppress the mantissa "bump" for denormals with mantissa ≤ 0x400000 to avoid incorrectly promoting very small subnormal values into normal range during rounding (preserving expected underflow/rounding semantics). Place the comment where p_exp_zero and p_tiny_sub are defined and also mirror it near the related logic around the mantissa bump handling further down (the block currently around the other predicate usage).tests/utils/test_fp8_quantize.py (1)
57-71: ⚡ Quick winSame dtype normalization concern for scale factors.
Similar to the quantized output comparison, if
ref_sfhas a different dtype thana_sf, thetorch.equalcomparison will fail even if the byte representations match. Consider applying consistent dtype normalization here as well.Suggested fix
+ actual_sf = a_sf.view(torch.uint8) if a_sf.dtype != torch.uint8 else a_sf + expected_sf = ref_sf.view(torch.uint8) if ref_sf.dtype != torch.uint8 else ref_sf - assert a_sf.shape == ref_sf.shape, ( - f"scale factors shape mismatch: actual={a_sf.shape}, expected={ref_sf.shape}" + assert actual_sf.shape == expected_sf.shape, ( + f"scale factors shape mismatch: actual={actual_sf.shape}, expected={expected_sf.shape}" ) - if not torch.equal(a_sf, ref_sf): - mismatch = a_sf != ref_sf + if not torch.equal(actual_sf, expected_sf): + mismatch = actual_sf != expected_sf🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/utils/test_fp8_quantize.py` around lines 57 - 71, The scale-factor comparison fails when ref_sf has a different dtype than a_sf; normalize ref_sf to a_sf's dtype (and device if needed) before equality checks and element-wise comparisons. Specifically, before torch.equal(a_sf, ref_sf) convert ref_sf via ref_sf = ref_sf.to(dtype=a_sf.dtype, device=a_sf.device) (or an equivalent conversion), then use that normalized ref_sf for mismatch = a_sf != ref_sf, torch.nonzero(...), and for the reported actual/expected values so the element comparison and printed values use matching dtypes.tests/trace/test_mxfp8_quantize_reference_correctness.py (2)
41-42: 💤 Low valueRedundant
torch.cuda.is_available()guard.Per repository conventions, tests assume CUDA is available. The
_skip_if_not_sm100()call at line 15 already handles unsupported hardware. This guard is unnecessary.Suggested fix
- if torch.cuda.is_available(): - torch.cuda.synchronize() + torch.cuda.synchronize()Based on learnings: "Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/trace/test_mxfp8_quantize_reference_correctness.py` around lines 41 - 42, Remove the redundant CUDA availability guard and call torch.cuda.synchronize() unconditionally: the test file tests/trace/test_mxfp8_quantize_reference_correctness.py already calls _skip_if_not_sm100() at the top, so delete the `if torch.cuda.is_available():` check and leave a direct `torch.cuda.synchronize()` call where the guarded call now sits to ensure synchronization after CUDA operations.
34-40: 💤 Low valueInconsistent view conversion between
q_apiands_apicomparisons.Line 35-36 converts both
q_apiandq_reftouint8views for comparison. However, line 40 only convertss_apitouint8while comparing directly againsts_ref. Ifs_refis alreadyuint8from the reference, this is fine, but ifs_apiands_refhave different dtypes, the comparison may fail unexpectedly.Consider applying consistent view conversion:
Suggested fix
- torch.testing.assert_close(s_api.view(torch.uint8), s_ref, atol=0, rtol=0) + torch.testing.assert_close( + s_api.view(torch.uint8), + s_ref.view(torch.uint8), + atol=0, + rtol=0, + )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/trace/test_mxfp8_quantize_reference_correctness.py` around lines 34 - 40, The two assert_close checks use inconsistent view conversion: the first converts both q_api and q_ref to torch.uint8 before comparing, but the second only converts s_api; update the second assertion so both operands are compared with the same uint8 view (i.e., apply view(torch.uint8) to both s_api and s_ref) or otherwise ensure s_ref is explicitly converted to the same dtype as s_api; locate the torch.testing.assert_close call involving s_api and s_ref and make the operand dtypes consistent (match the pattern used for q_api/q_ref).flashinfer/trace/templates/quantize.py (1)
164-210: 💤 Low valueConsider using named constants or importing
SfLayoutenum for layout values.Magic layout values 0, 1, 2 are fragile. If the
SfLayoutenum values change elsewhere, this reference implementation would silently produce incorrect results without any compile-time or runtime warning.Suggested approach
# At top of function or module level _SF_LAYOUT_128x4 = 0 _SF_LAYOUT_8x4 = 1 _SF_LAYOUT_LINEAR = 2 # Then use in conditionals: if layout == _SF_LAYOUT_LINEAR: ... elif layout == _SF_LAYOUT_128x4: ... elif layout == _SF_LAYOUT_8x4: ...Or import the actual enum if available and compare against
SfLayout.layout_linear.value, etc.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/trace/templates/quantize.py` around lines 164 - 210, Replace the magic numeric layout checks with named constants or the SfLayout enum to avoid brittle literals: locate the conditional branches that test the variable layout (the if/elif/else block handling values 0,1,2 in this module) and either define module-level constants like _SF_LAYOUT_128x4/_SF_LAYOUT_8x4/_SF_LAYOUT_LINEAR and use them in the comparisons, or import the SfLayout enum and compare against SfLayout.layout_128x4.value, SfLayout.layout_8x4.value, and SfLayout.layout_linear.value (or against the enum members if layout is an enum), then update the if/elif clauses accordingly and keep the existing index calculations intact.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@flashinfer/quantization/quantization_cute_dsl_utils.py`:
- Line 22: Add a brief inline comment next to the predicates p_exp_zero and
p_tiny_sub that explains the chosen threshold 0x400000: note that 0x400000 is
half of the 23-bit mantissa range (i.e. half the max mantissa for
single-precision), and we suppress the mantissa "bump" for denormals with
mantissa ≤ 0x400000 to avoid incorrectly promoting very small subnormal values
into normal range during rounding (preserving expected underflow/rounding
semantics). Place the comment where p_exp_zero and p_tiny_sub are defined and
also mirror it near the related logic around the mantissa bump handling further
down (the block currently around the other predicate usage).
In `@flashinfer/trace/templates/quantize.py`:
- Around line 164-210: Replace the magic numeric layout checks with named
constants or the SfLayout enum to avoid brittle literals: locate the conditional
branches that test the variable layout (the if/elif/else block handling values
0,1,2 in this module) and either define module-level constants like
_SF_LAYOUT_128x4/_SF_LAYOUT_8x4/_SF_LAYOUT_LINEAR and use them in the
comparisons, or import the SfLayout enum and compare against
SfLayout.layout_128x4.value, SfLayout.layout_8x4.value, and
SfLayout.layout_linear.value (or against the enum members if layout is an enum),
then update the if/elif clauses accordingly and keep the existing index
calculations intact.
In `@tests/trace/test_mxfp8_quantize_reference_correctness.py`:
- Around line 41-42: Remove the redundant CUDA availability guard and call
torch.cuda.synchronize() unconditionally: the test file
tests/trace/test_mxfp8_quantize_reference_correctness.py already calls
_skip_if_not_sm100() at the top, so delete the `if torch.cuda.is_available():`
check and leave a direct `torch.cuda.synchronize()` call where the guarded call
now sits to ensure synchronization after CUDA operations.
- Around line 34-40: The two assert_close checks use inconsistent view
conversion: the first converts both q_api and q_ref to torch.uint8 before
comparing, but the second only converts s_api; update the second assertion so
both operands are compared with the same uint8 view (i.e., apply
view(torch.uint8) to both s_api and s_ref) or otherwise ensure s_ref is
explicitly converted to the same dtype as s_api; locate the
torch.testing.assert_close call involving s_api and s_ref and make the operand
dtypes consistent (match the pattern used for q_api/q_ref).
In `@tests/utils/test_fp8_quantize.py`:
- Around line 57-71: The scale-factor comparison fails when ref_sf has a
different dtype than a_sf; normalize ref_sf to a_sf's dtype (and device if
needed) before equality checks and element-wise comparisons. Specifically,
before torch.equal(a_sf, ref_sf) convert ref_sf via ref_sf =
ref_sf.to(dtype=a_sf.dtype, device=a_sf.device) (or an equivalent conversion),
then use that normalized ref_sf for mismatch = a_sf != ref_sf,
torch.nonzero(...), and for the reported actual/expected values so the element
comparison and printed values use matching dtypes.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 86f02475-3b51-44e6-84ac-442a100e4cfe
📒 Files selected for processing (8)
flashinfer/cute_dsl/fp4_common.pyflashinfer/quantization/fp8_quantization.pyflashinfer/quantization/kernels/nvfp4_quantize.pyflashinfer/quantization/quantization_cute_dsl_utils.pyflashinfer/trace/templates/quantize.pytests/trace/test_mxfp8_quantize_reference_correctness.pytests/utils/test_fp4_quantize.pytests/utils/test_fp8_quantize.py
|
Hi @zianglih Do you mind updating the PR description and/or link an issue with the motivation, what it is, and why? Seems like the PR is trying to add a "no-fast-math" path but I'm curious to know why |
|
Hi @bkryu , I have added the description. |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)
302-305: ⚡ Quick winCover
-infexplicitly in the extreme-scale fixture.This new case only exercises the positive-infinity overflow path. Adding a
-infblock would close the remaining sign-specific extreme and make this edge-case test harder to regress.Suggested tweak
- a[:, 32:64] = float("inf") - a[:, 64:96] = 448.0 - a[:, 96:128] = -448.0 + a[:, 32:48] = float("inf") + a[:, 48:64] = float("-inf") + a[:, 64:96] = 448.0 + a[:, 96:128] = -448.0🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/utils/test_fp8_quantize.py` around lines 302 - 305, The extreme-scale test constructs tensor a and currently sets a[:, 32:64] = inf, a[:, 64:96] = 448.0, a[:, 96:128] = -448.0; modify this fixture to also assign a[:, 0:32] = float("-inf") (or another unused slice) so the tensor explicitly covers negative infinity and exercises the negative-infinity overflow path in the fp8 quantization tests (update any related comments to reflect the added -inf block).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/utils/test_fp8_quantize.py`:
- Around line 302-305: The extreme-scale test constructs tensor a and currently
sets a[:, 32:64] = inf, a[:, 64:96] = 448.0, a[:, 96:128] = -448.0; modify this
fixture to also assign a[:, 0:32] = float("-inf") (or another unused slice) so
the tensor explicitly covers negative infinity and exercises the
negative-infinity overflow path in the fp8 quantization tests (update any
related comments to reflect the added -inf block).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 3a71dfa0-5a69-4365-8e4b-a59afd833f6d
📒 Files selected for processing (2)
tests/utils/test_fp8_quantize.pytests/utils_fp8.py
📌 Description
@HumansAnd
We want to make sure FlashInfer quantization backends are bitwise identical with TE style implementation for RL use cases.
Currently cute dsl mxfp8 backend is not bitwise identical on subnormal edge cases, and cute dsl nvfp4 backend does not honor
TRTLLM_DISABLE_FP4_QUANT_FAST_MATHenv var and has fast math always enabled.TRTLLM_DISABLE_FP4_QUANT_FAST_MATH=1for bitwise-exact math.🔍 Related Issues
TRTLLM_DISABLE_FP4_QUANT_FAST_MATHintroduced by:🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
New Features
Bug Fixes
Tests