Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
19b6b08
Initial implementation
zianglih May 9, 2026
7b0b2d0
Make 4over6 compile time for dequant
zianglih May 9, 2026
1e5b6ad
Expand 1d fwd+bwd test
zianglih May 9, 2026
99660fc
Refactor
zianglih May 9, 2026
cb2e0a3
Clean up
zianglih May 9, 2026
2c066f9
Clean up
zianglih May 9, 2026
69e8f3a
Add gemm test
zianglih May 9, 2026
009e651
Add more tests and fix offload
zianglih May 9, 2026
3153fc3
Fix offload
zianglih May 9, 2026
e31b758
Clean up arg
zianglih May 9, 2026
fcd526c
Add more test
zianglih May 9, 2026
100c378
Add more tests
zianglih May 10, 2026
1c9f26b
Clean up test
zianglih May 10, 2026
93fe922
Refactor cuh kernel impl
zianglih May 10, 2026
f4e4a4e
Further extract
zianglih May 10, 2026
b3f59ee
Clean up
zianglih May 10, 2026
31decf9
Add recipe_id
zianglih May 10, 2026
2fa6b8c
Fix failing unit tests
zianglih May 10, 2026
7df2db0
Clean up test
zianglih May 10, 2026
ce85be2
Clean up
zianglih May 10, 2026
1b68038
Refactor ref
zianglih May 10, 2026
bb722a3
Update comments and docs
zianglih May 10, 2026
fe18a1e
Drop unnecessary test_sanity workaround
zianglih May 10, 2026
522e93e
Refactor `QuantizerRole`
zianglih May 11, 2026
782b7ee
Allow separate recipe 4over6 config
zianglih May 11, 2026
d9cd12c
Support 2d
zianglih May 12, 2026
708c1ec
Refactor 2d
zianglih May 12, 2026
4d31f18
Clean up anti pattern
zianglih May 12, 2026
dfc15f2
Enforce 4over6 consistency
zianglih May 12, 2026
9453670
Update comments
zianglih May 12, 2026
6d871da
Update docs
zianglih May 12, 2026
f8338e8
Fix test
zianglih May 12, 2026
c9bc921
Drop test_fusible_ops
zianglih May 12, 2026
00ba694
Revert "Drop test_fusible_ops"
zianglih May 12, 2026
3252d4e
Refactor test_fusible_ops
zianglih May 12, 2026
3f33c1d
Refactor ref and extend cpp test
zianglih May 12, 2026
8607e03
Clean up cpp test
zianglih May 12, 2026
d3dbf34
Minor comment
zianglih May 12, 2026
565f33f
Drop doc
zianglih May 12, 2026
54b4da8
Explicit handle conditional smem buffer
zianglih May 12, 2026
fa09200
Further clean up
zianglih May 12, 2026
e57e8be
More templates
zianglih May 12, 2026
a1df319
Simplify cpp
zianglih May 12, 2026
21720da
Drop write back lifting
zianglih May 12, 2026
b1d073a
Add MAE and dedicated fast math env var
zianglih May 12, 2026
0392708
Harden cpp test
zianglih May 12, 2026
0b77a37
Add warning and err fast math coverage
zianglih May 12, 2026
81e579e
Fold test case and clean up cpp test
zianglih May 12, 2026
1e311ef
Initial 448 vs 256 implementation
zianglih May 12, 2026
38a1c4c
Use e4m3 max instead of boolean, more template
zianglih May 12, 2026
3cdd9d9
Add benchmark script and minor optimization
zianglih May 13, 2026
7deba75
Use standalone kernels
zianglih May 13, 2026
93dbf2b
Use cp async
zianglih May 13, 2026
8819d12
Add benchmark script
zianglih May 13, 2026
24e417b
Minor fix after rebase
zianglih May 13, 2026
472e5b8
Naming consistency
zianglih May 13, 2026
83e2308
Remove 4over6 benchmark
zianglih May 13, 2026
2980cb1
Refactor modes
zianglih May 19, 2026
967293f
Relax tol for `test_layernorm_mlp` for `nvfp4_4over6`
zianglih May 19, 2026
f555bf2
Minor fix recipe naming
zianglih May 19, 2026
7a4b5c0
Remove gradient 4over6 quantization and partially allow SR/RHT
zianglih May 19, 2026
e036a7c
Allow RHT in pytorch ref
zianglih May 19, 2026
f8c4373
Update transformer_engine/pytorch/csrc/quantizer.cpp
timmoon10 May 19, 2026
96fcb43
Minor fix TODO lint
zianglih May 20, 2026
3e6d4cd
Use standard nvfp4 for grad ref in test_fusible_ops.py since 4over6 i…
zianglih May 20, 2026
1a5c19d
Minor fix test-fusible_ops 4over6 helper
zianglih May 20, 2026
63b82a5
Default to 256 for 4over6
zianglih May 20, 2026
3e130f7
Reset RNG state for each TE ops test
timmoon10 May 21, 2026
5f2d761
Merge branch 'main' into 4over6
zianglih May 22, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docs/envvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,30 @@ Kernel Configuration
:Default: ``0``
:Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.

.. envvar:: NVTE_NVFP4_4OVER6

:Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``)
:Default: ``none``
:Description: Enable 4over6 adaptive NVFP4 block scaling for weights, activations, or both in the ``NVFP4BlockScaling`` recipe. For each selected FP4 block, quantization compares map-to-4 and map-to-6 candidates and stores the candidate with lower configured error. ``none`` keeps standard NVFP4. Current 4over6 support targets RL and post-training scenarios; pre-training paths that combine 4over6 with RHT are not yet implemented.

.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256

:Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``)
:Default: ``all``
:Description: Select NVFP4 4over6 quantizers that use 256 instead of 448 as the global E4M3 scale bound. By default, all 4over6 quantizers use 256. Set the env var to ``none`` (or set ``NVFP4BlockScaling(nvfp4_4over6_e4m3_use_256="none")``) to use the standard NVFP4 448 bound for all 4over6 quantizers. This option is only meaningful for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`.

.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE

:Type: ``str`` (``MAE`` or ``MSE``)
:Default: ``MAE``
:Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe.

.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path.

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
616 changes: 520 additions & 96 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu

Large diffs are not rendered by default.

68 changes: 54 additions & 14 deletions tests/cpp/operator/test_dequantize_nvfp4.cu
Comment thread
timmoon10 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
OType *output,
size_t rows,
size_t cols,
size_t scale_stride) {
constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
size_t scale_stride,
int e4m3_max) {
const float factor_inv = 1.0f / (6.0f * static_cast<float>(e4m3_max));
constexpr size_t BLOCK_SIZE = 16;
const size_t Mread = cols / BLOCK_SIZE;
const size_t bytes_per_block = BLOCK_SIZE / 2;
Expand Down Expand Up @@ -86,11 +87,18 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) {
return amax;
}

struct NVFP4DequantizeTestConfig {
NVTENVFP44Over6Mode mode = kNVTENVFP44Over6Disabled;
int e4m3_max = 448;
};

// Quantize a high-precision input to NVFP4, then dequantize and compare
// against a CPU reference computed from the quantized data.
template <typename OutputType>
void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const NVTENVFP44Over6Mode mode,
const int e4m3_max) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

Expand All @@ -105,6 +113,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,

// Configure quantized tensor amax
size_t amax_size = 1;
quantized.set_nvfp4_e4m3_max(e4m3_max);
ASSERT_EQ(quantized.nvfp4_e4m3_max(), e4m3_max);
if (row_scaled_nvfp4) {
quantized.set_row_scaled_nvfp4(true);
amax_size = rows;
Expand All @@ -116,7 +126,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,

// Quantize
if (rows > 0 && cols > 0) {
nvte_quantize(input.data(), quantized.data(), 0);
QuantizationConfigWrapper quant_config;
quant_config.set_nvfp4_4over6_mode(mode);
nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
Expand Down Expand Up @@ -146,7 +158,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
std::make_unique<OutputType[]>(rows * cols);
compute_ref_dequantize_nvfp4<OutputType>(
fp4_data, scales, amax_vals, ref_output.get(),
rows, cols, scale_stride);
rows, cols, scale_stride, e4m3_max);

// Compare results from TE and reference impls
auto [atol, rtol] = getTolerances(otype);
Expand All @@ -156,7 +168,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols,
// Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path.
template <typename OutputType>
void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
const bool row_scaled_nvfp4) {
const bool row_scaled_nvfp4,
const NVTENVFP44Over6Mode mode,
const int e4m3_max) {
using namespace test;
DType otype = TypeInfo<OutputType>::dtype;

Expand All @@ -165,6 +179,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,

Tensor quantized_compact("quantized_compact", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_compact.set_nvfp4_e4m3_max(e4m3_max);
ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), e4m3_max);
if (row_scaled_nvfp4) {
quantized_compact.set_row_scaled_nvfp4(true);
} else if (rows > 0 && cols > 0) {
Expand All @@ -174,7 +190,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
}

if (rows > 0 && cols > 0) {
nvte_quantize(input.data(), quantized_compact.data(), 0);
QuantizationConfigWrapper quant_config;
quant_config.set_nvfp4_4over6_mode(mode);
nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0);
cudaDeviceSynchronize();
}

Expand All @@ -186,6 +204,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols,
// Create tensor with same FP4 data but swizzled scales
Tensor quantized_swizzled("quantized_swizzled", std::vector<size_t>{rows, cols},
DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING);
quantized_swizzled.set_nvfp4_e4m3_max(e4m3_max);
ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), e4m3_max);
if (row_scaled_nvfp4) {
quantized_swizzled.set_row_scaled_nvfp4(true);
} else {
Expand Down Expand Up @@ -260,7 +280,8 @@ std::vector<std::pair<size_t, size_t>> nvfp4_tensor_dims = {
class DequantizeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool>> {};
bool,
NVFP4DequantizeTestConfig>> {};

TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
{
Expand All @@ -271,10 +292,12 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const NVFP4DequantizeTestConfig config = std::get<3>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode,
config.e4m3_max);
);
}

Expand All @@ -284,21 +307,29 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool()),
::testing::Bool(),
::testing::Values(NVFP4DequantizeTestConfig{},
NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448},
NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})),
[](const testing::TestParamInfo<DequantizeNVFP4TestSuite::ParamType>& info)
{
const NVFP4DequantizeTestConfig config = std::get<3>(info.param);
const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled;
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor");
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(use_4over6 ? "FourOverSix" : "Default") + "X" +
(config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448");
return name;
}
);

class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam
<std::tuple<std::pair<size_t, size_t>,
transformer_engine::DType,
bool>> {};
bool,
NVFP4DequantizeTestConfig>> {};

TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
{
Expand All @@ -309,10 +340,12 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled)
const auto tensor_size = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
const bool row_scaled_nvfp4 = std::get<2>(GetParam());
const NVFP4DequantizeTestConfig config = std::get<3>(GetParam());

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType,
performTest_dequantize_nvfp4_swizzled<OutputType>(
tensor_size.first, tensor_size.second, row_scaled_nvfp4);
tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode,
config.e4m3_max);
);
}

Expand All @@ -322,13 +355,20 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Combine(
::testing::ValuesIn(nvfp4_tensor_dims),
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Bool()),
::testing::Bool(),
::testing::Values(NVFP4DequantizeTestConfig{},
NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448},
NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})),
[](const testing::TestParamInfo<DequantizeNVFP4SwizzledTestSuite::ParamType>& info)
{
const NVFP4DequantizeTestConfig config = std::get<3>(info.param);
const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled;
std::string name = std::to_string(std::get<0>(info.param).first) + "X" +
std::to_string(std::get<0>(info.param).second) + "X" +
test::typeName(std::get<1>(info.param)) + "X" +
(std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" +
(use_4over6 ? "FourOverSix" : "Default") + "X" +
(config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448") + "X" +
"Swizzled";
return name;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,18 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) {
}
}

void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 E4M3 max is only supported for NVFP4 tensors.");
tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max);
}

int Tensor::nvfp4_e4m3_max() const {
NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING,
"NVFP4 E4M3 max is only supported for NVFP4 tensors.");
return tensor_.get_nvfp4_e4m3_max();
}

void Tensor::to_cpu() {
if (data_rowwise_) { data_rowwise_->to_cpu(); }
if (data_columnwise_) { data_columnwise_->to_cpu(); }
Expand Down
3 changes: 3 additions & 0 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,13 @@ class Tensor {
return columnwise_;
}

int nvfp4_e4m3_max() const;

void set_tensor_amax_nullptr();

void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales);
void set_row_scaled_nvfp4(bool row_scaled_nvfp4);
void set_nvfp4_e4m3_max(int nvfp4_e4m3_max);

void to_cpu();
void from_cpu();
Expand Down
Loading