
Make cute dsl mxfp8/nvfp4 quantizer bitwise exact#3387

Open
zianglih wants to merge 4 commits into flashinfer-ai:main from zianglih:bitwise

Conversation

@zianglih
Contributor

@zianglih zianglih commented May 21, 2026

📌 Description

@HumansAnd

We want to make sure the FlashInfer quantization backends are bitwise identical to the TE-style implementation for RL use cases.

Currently, the CuTe DSL MXFP8 backend is not bitwise identical on subnormal edge cases, and the CuTe DSL NVFP4 backend does not honor the TRTLLM_DISABLE_FP4_QUANT_FAST_MATH environment variable: fast math is always enabled.

  • MXFP8
    • Added MXFP8 bitwise-exact checks against the reference implementation.
    • Fixed MXFP8 CuTe scale conversion edge cases and host scale-output initialization.
  • NVFP4
    • Expanded NVFP4 TE-reference exactness coverage to include the CuTe DSL backend.
    • Skipped unsupported CuTe NVFP4 modes in the TE-reference test: per-token activation and 4over6.
    • Made CuTe NVFP4 honor TRTLLM_DISABLE_FP4_QUANT_FAST_MATH=1 for bitwise-exact math.
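
To illustrate why a bitwise check is stricter than a value comparison, here is a minimal sketch using only the Python standard library. The helper name `f32_bytes` is hypothetical and not part of FlashInfer; it only shows the general idea of comparing encodings rather than values.

```python
import struct


def f32_bytes(x: float) -> bytes:
    """Return the little-endian IEEE 754 binary32 encoding of x (hypothetical helper)."""
    return struct.pack("<f", x)


# Signed zeros compare equal by value but differ in the sign bit.
value_equal_zeros = (0.0 == -0.0)
bitwise_equal_zeros = (f32_bytes(0.0) == f32_bytes(-0.0))

# A NaN never compares equal by value, yet the same NaN has one fixed encoding.
nan = float("nan")
value_equal_nan = (nan == nan)
bitwise_equal_nan = (f32_bytes(nan) == f32_bytes(nan))

print(value_equal_zeros, bitwise_equal_zeros)
print(value_equal_nan, bitwise_equal_nan)
```

A tolerance-based test would accept both zeros as identical, while a bitwise test would flag them, which is the kind of difference an exactness requirement is meant to catch.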

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Optional precise NVFP4 quantization mode (env-controlled) that uses round-to-nearest math, with new RN reciprocal/scale helpers and per-kernel switch for fast-math behavior.
  • Bug Fixes

    • Fixed an uninitialized buffer in the FP8 quantization CPU path.
  • Tests

    • Tightened validation to require exact parity, improved mismatch diagnostics, added an extreme-scale MXFP8 test and reference utilities, and backend parametrization.

Review Change Stack

@coderabbitai
Contributor

coderabbitai Bot commented May 21, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c2a58ab0-1ee2-44f2-bdde-ad980d8d8a58

📥 Commits

Reviewing files that changed from the base of the PR and between c2f74b9 and a16c15c.

📒 Files selected for processing (1)
  • tests/utils_fp8.py

📝 Walkthrough


Adds RN FP32 reciprocal and RN NVFP4 scale helpers, threads a disable_fast_math env flag through NVFP4 kernels (Linear, Swizzled, TMA) to select RN vs fast-math paths, implements MXFP8 reference quantization with scale swizzling, and replaces tolerance-based tests with exact trace-based assertions.
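
As a rough analogy for why a round-to-nearest division path can differ from a reciprocal-based fast path, the sketch below compares a single correctly rounded division with a reciprocal followed by a multiply. It uses Python's 64-bit floats rather than the FP32 PTX instructions named above, and the function names are hypothetical; the actual fast-math path in the kernels may use different approximations.

```python
def divide_exact(a: float, b: float) -> float:
    """Correctly rounded division: one rounding step."""
    return a / b


def divide_via_reciprocal(a: float, b: float) -> float:
    """Reciprocal then multiply: two rounding steps, standing in for a fast path."""
    return a * (1.0 / b)


# Search small integers n for cases where n / n and n * (1 / n) disagree.
mismatches = [
    n for n in range(1, 200)
    if divide_exact(float(n), float(n)) != divide_via_reciprocal(float(n), float(n))
]

print(len(mismatches) > 0)
```

Each disagreement is at most one unit in the last place, which a tolerance check would hide but a bitwise comparison would report.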

Changes

NVFP4 Disable Fast-Math and Quantization Precision

| Layer | File(s) | Summary |
|---|---|---|
| RN precision helper intrinsics | flashinfer/cute_dsl/fp4_common.py | Three new DSL user operations (rcp_rn, nvfp4_compute_output_scale_rn, nvfp4_scale_from_amax_rn) implement round-to-nearest FP32 division/multiplication with PTX div.rn.f32/mul.rn.f32 and zero handling. |
| disable_fast_math control infrastructure | flashinfer/quantization/kernels/nvfp4_quantize.py | Adds env helper _env_flag_enabled, threads disable_fast_math into kernel constructors (Linear, Swizzled, TMA) and cached compilation functions, and reads TRTLLM_DISABLE_FP4_QUANT_FAST_MATH at runtime to choose variants. |
| NVFP4 block processing with precision branching | flashinfer/quantization/kernels/nvfp4_quantize.py, flashinfer/quantization/quantization_cute_dsl_utils.py | Block helpers (process_nvfp4_block_half, _bfloat, _fp8) and TMA _quantize_sf_block branch on disable_fast_math: RN helpers (rcp_rn, nvfp4_scale_from_amax_rn, nvfp4_compute_output_scale_rn) when enabled, otherwise previous fast-math approximations are used. |
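
The flag plumbing summarized above can be sketched roughly as follows. This is an illustration under stated assumptions, not the code in the PR: `env_flag_enabled`, `get_kernel_variant`, and `select_kernel` are hypothetical names, the cached function merely returns a string in place of a compiled kernel, and treating only the value "1" as enabled is an assumption based on the `=1` usage in the description.

```python
import functools
import os


def env_flag_enabled(name: str) -> bool:
    """Hypothetical helper: only the value "1" enables the flag (an assumption)."""
    return os.environ.get(name, "0").strip() == "1"


@functools.lru_cache(maxsize=None)
def get_kernel_variant(kernel_name: str, disable_fast_math: bool) -> str:
    """Stand-in for a cached compile step; the flag is part of the cache key."""
    math_mode = "rn" if disable_fast_math else "fast_math"
    return f"{kernel_name}[{math_mode}]"


def select_kernel(kernel_name: str) -> str:
    # Read the environment at call time so a changed flag selects another variant.
    disable_fast_math = env_flag_enabled("TRTLLM_DISABLE_FP4_QUANT_FAST_MATH")
    return get_kernel_variant(kernel_name, disable_fast_math)


os.environ["TRTLLM_DISABLE_FP4_QUANT_FAST_MATH"] = "1"
print(select_kernel("linear"))
```

Passing the flag as an argument to the cached function keeps the two variants from sharing a cache entry, which would otherwise let a fast-math kernel be reused after the flag is set.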

MXFP8 Quantization Reference and Test Improvements

| Layer | File(s) | Summary |
|---|---|---|
| MXFP8 reference quantization and scale swizzling | tests/utils_fp8.py | Adds mxfp8_quantize_reference to compute FP8 outputs and per-block UE8M0 scale bytes, and _swizzle_mxfp8_scales to reorder/zero-pad scale bytes for supported SfLayouts. |
| MXFP8 exact quantization assertion | tests/utils_fp8.py | Adds assert_mxfp8_quantize_exact which recomputes reference outputs and enforces exact bitwise equality for FP8 outputs and scale bytes, reporting mismatch counts and first differing indices. |
| MXFP8 test suite updates | tests/utils/test_fp8_quantize.py | Replaces host dequantization/tolerance checks with assert_mxfp8_quantize_exact across tests, tightens backend-parity to strict torch.equal, and adds extreme-scale input tests. |
| NVFP4 test backend parameterization | tests/utils/test_fp4_quantize.py | Parameterizes test_nvfp4_quantize_te_reference over backend (cuda/cute-dsl), adds cute-dsl-specific skips for per-token and 4over6 cases, and forwards backend into nvfp4_quantize calls. |
| Scale factor buffer initialization | flashinfer/quantization/fp8_quantization.py | CPU quantization path now allocates out_sf with torch.zeros(...) instead of torch.empty(...) to ensure initialized buffer before host routine. |
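
The exact-equality assertion with mismatch diagnostics described above might look roughly like the following. This is a standard-library sketch operating on raw `bytes`, not the `assert_mxfp8_quantize_exact` function from the PR; the name `assert_bytes_exact` and its message format are hypothetical.

```python
def assert_bytes_exact(name: str, actual: bytes, expected: bytes) -> None:
    """Raise AssertionError unless both buffers match byte for byte."""
    if len(actual) != len(expected):
        raise AssertionError(
            f"{name} length mismatch: actual={len(actual)}, expected={len(expected)}"
        )
    # Collect every differing position so the report can include a count.
    diff = [i for i, (a, e) in enumerate(zip(actual, expected)) if a != e]
    if diff:
        first = diff[0]
        raise AssertionError(
            f"{name} mismatch: {len(diff)}/{len(expected)} bytes differ; "
            f"first at index {first}: "
            f"actual=0x{actual[first]:02x}, expected=0x{expected[first]:02x}"
        )


# Identical buffers pass silently.
assert_bytes_exact("scales", b"\x7f\x80", b"\x7f\x80")
```

Reporting the count and the first differing index makes a failing bitwise test far easier to triage than a bare boolean comparison.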

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly Related PRs

Suggested Labels

run-ci, op: attention

Suggested Reviewers

  • yzh119
  • bkryu
  • cyx-6
  • jimmyzho
  • nv-yunzheq
  • IwakuraRein

"I nibble on bits of RN delight,
I hop through scales in the moonlit night.
Exact bytes swizzle, tests sing clear,
Fast-math naps while precision draws near.
A rabbit applauds — code snug and sincere."

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 38.89% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and concisely summarizes the main change: making CuTe DSL MXFP8/NVFP4 quantizers bitwise exact, which aligns with the substantive changes across multiple files. |
| Description check | ✅ Passed | The PR description comprehensively addresses the template structure with clear sections: describes the core objective (bitwise exactness), lists specific MXFP8 and NVFP4 fixes, links related issues, and confirms completion of pre-commit and testing checklists. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |



@zianglih zianglih changed the title from "Make cute dsl mxfp8/nvfp4 backend bitwise exact" to "Make cute dsl mxfp8/nvfp4 quantizer bitwise exact" May 21, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a disable_fast_math configuration for NVFP4 quantization, adding specialized round-to-nearest PTX operations and supporting logic across the quantization kernels. It also refines the MXFP8 reference implementation and scale swizzling to ensure bitwise exactness with host references. Feedback was provided regarding the rounding logic in the Python reference implementation, specifically noting that it should explicitly handle saturation and non-finite inputs to match the PTX behavior.

Comment thread flashinfer/trace/templates/quantize.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (5)
flashinfer/quantization/quantization_cute_dsl_utils.py (1)

22-22: 💤 Low value

Clarify the tiny subnormal handling logic.

The new predicates p_exp_zero, p_tiny_sub suppress the mantissa bump when the input is a very small denormal (exponent=0 and mantissa ≤ 0x400000). Consider adding a brief inline comment explaining why this threshold was chosen, as it may not be obvious to future maintainers why half the mantissa range is special-cased.

Also applies to: 169-184
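
For readers unfamiliar with the bit layout involved, the two predicates named in this comment can be sketched in plain Python as below. This only mirrors the condition as the comment states it (exponent field equal to zero and mantissa field at most 0x400000); the helper names are hypothetical, and the sketch does not reproduce the kernel's actual scale conversion or the mantissa bump itself.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the IEEE 754 binary32 bit pattern of x as an unsigned integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def tiny_subnormal_predicates(x: float) -> tuple:
    """Evaluate (p_exp_zero, p_tiny_sub) as described in the review comment."""
    bits = f32_bits(x)
    exponent = (bits >> 23) & 0xFF   # 8-bit biased exponent field
    mantissa = bits & 0x7FFFFF       # 23-bit mantissa field
    p_exp_zero = exponent == 0
    p_tiny_sub = p_exp_zero and mantissa <= 0x400000
    return p_exp_zero, p_tiny_sub


# 2**-127 is a subnormal whose mantissa field is exactly 0x400000.
print(tiny_subnormal_predicates(2.0 ** -127))
# 1.5 * 2**-127 is a subnormal with mantissa field 0x600000, above the threshold.
print(tiny_subnormal_predicates(1.5 * 2.0 ** -127))
```

Note that under this reading zero also satisfies both predicates, since its exponent and mantissa fields are both zero.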

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/quantization/quantization_cute_dsl_utils.py` at line 22, Add a
brief inline comment next to the predicates p_exp_zero and p_tiny_sub that
explains the chosen threshold 0x400000: note that 0x400000 is half of the 23-bit
mantissa range (i.e. half the max mantissa for single-precision), and we
suppress the mantissa "bump" for denormals with mantissa ≤ 0x400000 to avoid
incorrectly promoting very small subnormal values into normal range during
rounding (preserving expected underflow/rounding semantics). Place the comment
where p_exp_zero and p_tiny_sub are defined and also mirror it near the related
logic around the mantissa bump handling further down (the block currently around
the other predicate usage).
tests/utils/test_fp8_quantize.py (1)

57-71: ⚡ Quick win

Same dtype normalization concern for scale factors.

Similar to the quantized output comparison, if ref_sf has a different dtype than a_sf, the torch.equal comparison will fail even if the byte representations match. Consider applying consistent dtype normalization here as well.

Suggested fix
+    actual_sf = a_sf.view(torch.uint8) if a_sf.dtype != torch.uint8 else a_sf
+    expected_sf = ref_sf.view(torch.uint8) if ref_sf.dtype != torch.uint8 else ref_sf
-    assert a_sf.shape == ref_sf.shape, (
-        f"scale factors shape mismatch: actual={a_sf.shape}, expected={ref_sf.shape}"
+    assert actual_sf.shape == expected_sf.shape, (
+        f"scale factors shape mismatch: actual={actual_sf.shape}, expected={expected_sf.shape}"
     )
-    if not torch.equal(a_sf, ref_sf):
-        mismatch = a_sf != ref_sf
+    if not torch.equal(actual_sf, expected_sf):
+        mismatch = actual_sf != expected_sf
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/utils/test_fp8_quantize.py` around lines 57 - 71, The scale-factor
comparison fails when ref_sf has a different dtype than a_sf; normalize ref_sf
to a_sf's dtype (and device if needed) before equality checks and element-wise
comparisons. Specifically, before torch.equal(a_sf, ref_sf) convert ref_sf via
ref_sf = ref_sf.to(dtype=a_sf.dtype, device=a_sf.device) (or an equivalent
conversion), then use that normalized ref_sf for mismatch = a_sf != ref_sf,
torch.nonzero(...), and for the reported actual/expected values so the element
comparison and printed values use matching dtypes.
tests/trace/test_mxfp8_quantize_reference_correctness.py (2)

41-42: 💤 Low value

Redundant torch.cuda.is_available() guard.

Per repository conventions, tests assume CUDA is available. The _skip_if_not_sm100() call at line 15 already handles unsupported hardware. This guard is unnecessary.

Suggested fix
-    if torch.cuda.is_available():
-        torch.cuda.synchronize()
+    torch.cuda.synchronize()

Based on learnings: "Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/trace/test_mxfp8_quantize_reference_correctness.py` around lines 41 -
42, Remove the redundant CUDA availability guard and call
torch.cuda.synchronize() unconditionally: the test file
tests/trace/test_mxfp8_quantize_reference_correctness.py already calls
_skip_if_not_sm100() at the top, so delete the `if torch.cuda.is_available():`
check and leave a direct `torch.cuda.synchronize()` call where the guarded call
now sits to ensure synchronization after CUDA operations.

34-40: 💤 Low value

Inconsistent view conversion between q_api and s_api comparisons.

Line 35-36 converts both q_api and q_ref to uint8 views for comparison. However, line 40 only converts s_api to uint8 while comparing directly against s_ref. If s_ref is already uint8 from the reference, this is fine, but if s_api and s_ref have different dtypes, the comparison may fail unexpectedly.

Consider applying consistent view conversion:

Suggested fix
-    torch.testing.assert_close(s_api.view(torch.uint8), s_ref, atol=0, rtol=0)
+    torch.testing.assert_close(
+        s_api.view(torch.uint8),
+        s_ref.view(torch.uint8),
+        atol=0,
+        rtol=0,
+    )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/trace/test_mxfp8_quantize_reference_correctness.py` around lines 34 -
40, The two assert_close checks use inconsistent view conversion: the first
converts both q_api and q_ref to torch.uint8 before comparing, but the second
only converts s_api; update the second assertion so both operands are compared
with the same uint8 view (i.e., apply view(torch.uint8) to both s_api and s_ref)
or otherwise ensure s_ref is explicitly converted to the same dtype as s_api;
locate the torch.testing.assert_close call involving s_api and s_ref and make
the operand dtypes consistent (match the pattern used for q_api/q_ref).
flashinfer/trace/templates/quantize.py (1)

164-210: 💤 Low value

Consider using named constants or importing SfLayout enum for layout values.

Magic layout values 0, 1, 2 are fragile. If the SfLayout enum values change elsewhere, this reference implementation would silently produce incorrect results without any compile-time or runtime warning.

Suggested approach
# At top of function or module level
_SF_LAYOUT_128x4 = 0
_SF_LAYOUT_8x4 = 1
_SF_LAYOUT_LINEAR = 2

# Then use in conditionals:
if layout == _SF_LAYOUT_LINEAR:
    ...
elif layout == _SF_LAYOUT_128x4:
    ...
elif layout == _SF_LAYOUT_8x4:
    ...

Or import the actual enum if available and compare against SfLayout.layout_linear.value, etc.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/trace/templates/quantize.py` around lines 164 - 210, Replace the
magic numeric layout checks with named constants or the SfLayout enum to avoid
brittle literals: locate the conditional branches that test the variable layout
(the if/elif/else block handling values 0,1,2 in this module) and either define
module-level constants like _SF_LAYOUT_128x4/_SF_LAYOUT_8x4/_SF_LAYOUT_LINEAR
and use them in the comparisons, or import the SfLayout enum and compare against
SfLayout.layout_128x4.value, SfLayout.layout_8x4.value, and
SfLayout.layout_linear.value (or against the enum members if layout is an enum),
then update the if/elif clauses accordingly and keep the existing index
calculations intact.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 86f02475-3b51-44e6-84ac-442a100e4cfe

📥 Commits

Reviewing files that changed from the base of the PR and between f6f01a4 and ad0aa7a.

📒 Files selected for processing (8)
  • flashinfer/cute_dsl/fp4_common.py
  • flashinfer/quantization/fp8_quantization.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • flashinfer/quantization/quantization_cute_dsl_utils.py
  • flashinfer/trace/templates/quantize.py
  • tests/trace/test_mxfp8_quantize_reference_correctness.py
  • tests/utils/test_fp4_quantize.py
  • tests/utils/test_fp8_quantize.py

@bkryu
Collaborator

bkryu commented May 21, 2026

Hi @zianglih, do you mind updating the PR description and/or linking an issue that explains the motivation: what this change is and why it is needed?

It seems the PR is trying to add a "no-fast-math" path, but I'm curious to know why.

@zianglih
Contributor Author

Hi @bkryu, I have added the description.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tests/utils/test_fp8_quantize.py (1)

302-305: ⚡ Quick win

Cover -inf explicitly in the extreme-scale fixture.

This new case only exercises the positive-infinity overflow path. Adding a -inf block would close the remaining sign-specific extreme and make this edge-case test harder to regress.

Suggested tweak
-    a[:, 32:64] = float("inf")
-    a[:, 64:96] = 448.0
-    a[:, 96:128] = -448.0
+    a[:, 32:48] = float("inf")
+    a[:, 48:64] = float("-inf")
+    a[:, 64:96] = 448.0
+    a[:, 96:128] = -448.0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/utils/test_fp8_quantize.py` around lines 302 - 305, The extreme-scale
test constructs tensor a and currently sets a[:, 32:64] = inf, a[:, 64:96] =
448.0, a[:, 96:128] = -448.0; modify this fixture to also assign a[:, 0:32] =
float("-inf") (or another unused slice) so the tensor explicitly covers negative
infinity and exercises the negative-infinity overflow path in the fp8
quantization tests (update any related comments to reflect the added -inf
block).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 3a71dfa0-5a69-4365-8e4b-a59afd833f6d

📥 Commits

Reviewing files that changed from the base of the PR and between ad0aa7a and 5fff12a.

📒 Files selected for processing (2)
  • tests/utils/test_fp8_quantize.py
  • tests/utils_fp8.py

@zianglih zianglih marked this pull request as draft May 22, 2026 01:15
@zianglih zianglih marked this pull request as ready for review May 22, 2026 01:15

2 participants