Conversation


@danisereb danisereb commented Jan 5, 2026

What does this PR do?

Type of change: new feature

Overview: Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.

Usage

export MODEL_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B
export OUTPUT_PATH=/my_home/hf_models/nvidia/OpenMath2-Llama3.1-8B-MXFP8
mkdir -p $OUTPUT_PATH

python examples/llm_ptq/hf_ptq.py \
--export_fmt hf \
--dataset cnn_dailymail \
--pyt_ckpt_path $MODEL_PATH \
--export_path $OUTPUT_PATH \
--qformat mxfp8

The hf_quant_config.json of the output checkpoint:

{
    "producer": {
        "name": "modelopt",
        "version": "0.41.0.dev50+g7a796a875"
    },
    "quantization": {
        "quant_algo": "MXFP8",
        "kv_cache_quant_algo": "FP8",
        "group_size": 32,
        "exclude_modules": [
            "lm_head"
        ]
    }
}

And config.json (only the quantization_config):

...
    "quantization_config": {
        "ignore": [
            "lm_head"
        ],
        "quant_algo": "MXFP8",
        "kv_cache_scheme": {
            "dynamic": false,
            "num_bits": 8,
            "type": "float"
        },
        "producer": {
            "name": "modelopt",
            "version": "0.41.0.dev50+g7a796a875"
        },
        "quant_method": "modelopt"
    }
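
For anyone reproducing this, a quick sanity check (not part of the PR; the paths follow the Usage section and the keys come from the dumps above) that the exported checkpoint carries the expected MXFP8 settings:

# Verify the exported checkpoint's quantization metadata (paths per Usage above).
import json
import os

output_path = os.environ["OUTPUT_PATH"]

with open(os.path.join(output_path, "hf_quant_config.json")) as f:
    quant_cfg = json.load(f)["quantization"]
assert quant_cfg["quant_algo"] == "MXFP8"
assert quant_cfg["group_size"] == 32  # MX block size
assert quant_cfg["kv_cache_quant_algo"] == "FP8"

with open(os.path.join(output_path, "config.json")) as f:
    hf_cfg = json.load(f)["quantization_config"]
assert hf_cfg["quant_method"] == "modelopt"
assert "lm_head" in hf_cfg["ignore"]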

Testing

Used hf_ptq.py to quantize the model nvidia/OpenMath2-Llama3.1-8B (available on Hugging Face); see the example command above.

Checked that the generated MXFP8 checkpoint can be loaded with vLLM (this required changes in vLLM that are not yet merged to its main branch).

Added tests for MXFP8QTensor in tests/gpu/torch/quantization/test_qtensor_cuda.py.
Added "mxfp8" to tests/examples/llm_ptq/test_llm_ptq.py.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information


copy-pr-bot bot commented Jan 5, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@danisereb danisereb marked this pull request as ready for review January 6, 2026 12:31
@danisereb danisereb requested review from a team as code owners January 6, 2026 12:31
@danisereb danisereb requested review from mxinO and sugunav14 January 6, 2026 12:31
@sugunav14 (Contributor) commented:

Could you also add the corresponding unit tests for impacted functions in quant_utils.py here? Thanks!

# Convert E8M0 biased exponent to scale factor: scale = 2^(127 - exponent)
scale_factor = torch.exp2(127 - e8m0_scale.float())

# NOTE: vLLM/flashinfer may require this behavior:
A collaborator commented:

is this required? Should we assert e8m0_scale != 0?

@danisereb (Author) replied:

As far as I understand, it doesn't align with the MXFP8 specification, but one of my teammates said it worked for him in a certain case, so I wanted to leave some documentation here for future reference.
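
For reference, a small standalone sketch of the E8M0 handling as I read the OCP MX spec (the helper name is mine, not the one in quant_utils.py): the biased exponent decodes to 2^(e - 127), 0xFF encodes NaN, and there is no zero encoding, so e8m0 == 0 is a legal scale of 2^-127 rather than something to assert against. Note that 2^(127 - e), as in the snippet above, is simply the reciprocal of this decode.

# E8M0 decode per the OCP Microscaling (MX) spec; helper name is illustrative.
import torch

E8M0_BIAS = 127
E8M0_NAN = 255  # all-ones exponent encodes NaN; there is no encoding for zero

def e8m0_to_scale(e8m0: torch.Tensor) -> torch.Tensor:
    scale = torch.exp2(e8m0.float() - E8M0_BIAS)
    return torch.where(e8m0 == E8M0_NAN, torch.full_like(scale, float("nan")), scale)

# e8m0 == 0 is valid (2**-127), so only the NaN encoding needs special handling.
print(e8m0_to_scale(torch.tensor([0, 127, 254, 255], dtype=torch.uint8)))
# -> [2**-127, 1.0, 2**127, nan]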

# sm89
PTQCommand(quant="fp8", min_sm=89),
PTQCommand(quant="fp8", kv_cache_quant="none", min_sm=89),
# sm100
PTQCommand(quant="mxfp8", min_sm=100),
A collaborator commented:

Does Hopper support mxfp8?

@danisereb (Author) commented Jan 7, 2026:

Blackwell has hardware acceleration for MXFP8; Hopper does not.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#MXFP8-and-block-scaling

From that page: "NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: MXFP8."

See what we already have for NVFP4 (the line just below the "mxfp8" one):

PTQCommand(quant="nvfp4", min_sm=100),


codecov bot commented Jan 6, 2026

Codecov Report

❌ Patch coverage is 21.50538% with 73 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.42%. Comparing base (d541324) to head (d5fced8).
⚠️ Report is 3 commits behind head on main.

Files with missing lines                                Patch %   Lines
...odelopt/torch/quantization/qtensor/mxfp8_tensor.py    21.59%   69 Missing ⚠️
.../torch/quantization/nn/modules/tensor_quantizer.py     0.00%    4 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #736      +/-   ##
==========================================
- Coverage   74.69%   74.42%   -0.27%     
==========================================
  Files         192      193       +1     
  Lines       18948    19043      +95     
==========================================
+ Hits        14153    14173      +20     
- Misses       4795     4870      +75     


@meenchen (Contributor) left a comment:

LGTM

assert dequant_tensor.shape == input_shape, (
f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}"
)
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), (
A contributor commented:

We can also compare with the fake quant here.
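
For reference, a self-contained sketch of that comparison with stand-in helpers (block size 32, E4M3 elements, a power-of-two per-block scale chosen with ceil for simplicity; this is not the MXFP8QTensor API): the quantize-then-dequantize result should match the fake-quant output exactly, which is a much tighter check than allclose against the original tensor.

# Stand-in helpers (not the ModelOpt API) to illustrate comparing the real-quant
# round trip against fake quant with tight tolerances.
import torch

BLOCK = 32
E4M3_MAX = 448.0

def _block_scale(blocks: torch.Tensor) -> torch.Tensor:
    # Power-of-two per-block scale; ceil keeps every scaled value inside E4M3 range.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    return torch.exp2(torch.ceil(torch.log2(amax / E4M3_MAX)))

def quantize(x: torch.Tensor):
    blocks = x.reshape(-1, BLOCK).float()
    scale = _block_scale(blocks)
    return (blocks / scale).to(torch.float8_e4m3fn), scale

def dequantize(q: torch.Tensor, scale: torch.Tensor, shape) -> torch.Tensor:
    return (q.float() * scale).reshape(shape)

def fake_quant(x: torch.Tensor) -> torch.Tensor:
    q, scale = quantize(x)
    return dequantize(q, scale, x.shape)

torch.manual_seed(0)
test_tensor = torch.rand(4, 64) * 2 - 1  # values in [-1, 1)
q, scale = quantize(test_tensor)
dequant_tensor = dequantize(q, scale, test_tensor.shape)

# Loose check against the unquantized input (what the test already does).
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2)
# Tight check against fake quant: both paths must round identically.
assert torch.equal(dequant_tensor, fake_quant(test_tensor))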

"test_input",
[
# FP8 E4M3 boundary test values (max is 448, various powers of 2)
torch.tensor(
A contributor commented:

The formatting looks odd; we can turn off auto-formatting for these tensors and define them at the top of the file.
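
One way to do that (a sketch; the constant name and the exact values are illustrative, with 448 and the powers of two taken from the existing comment) is to hoist the tensors to module level and fence them from the auto-formatter:

# Module-level test inputs fenced off from black/ruff auto-formatting so the
# boundary values stay on readable lines; names and values are illustrative.
import pytest
import torch

# fmt: off
FP8_E4M3_BOUNDARY_INPUTS = [
    torch.tensor([448.0, -448.0, 256.0, 128.0, 64.0, 32.0, 16.0, 8.0]),
    torch.tensor([4.0, 2.0, 1.0, 0.5, 0.25, 0.125, 0.0625, 0.03125]),
]
# fmt: on

@pytest.mark.parametrize("test_input", FP8_E4M3_BOUNDARY_INPUTS)
def test_mxfp8_boundary_values(test_input):
    # Placeholder body; the real assertions live in test_qtensor_cuda.py.
    assert test_input.abs().max() <= 448.0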
