Skip to content

Commit 6c5cc7f

Browse files
Added Quantize Configs to grouped Quantization
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
1 parent 1f1ab92 commit 6c5cc7f

3 files changed

Lines changed: 34 additions & 10 deletions

File tree

tests/cpp/operator/test_cast_nvfp4_transpose_grouped.cu

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ void compare_nvfp4_tensors(const std::string& name,
227227
}
228228
}
229229

230-
constexpr bool print_detailed_summary = true;
230+
bool print_detailed_summary = false;
231231
if (print_detailed_summary) {
232232
// Always report summary - either success or failure
233233
std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl;
@@ -492,7 +492,11 @@ void performTest(const ShapeRepresentation shape_rep,
492492
&offsets_tensor, sizeof(offsets_tensor));
493493
}
494494

495-
nvte_group_quantize(in_group_tensor, out_group_tensor, 0);
495+
QuantizationConfigWrapper quant_config;
496+
quant_config.set_use_fast_math(use_fast_math);
497+
quant_config.set_stochastic_rounding(false);
498+
499+
nvte_group_quantize_v2(in_group_tensor, out_group_tensor, quant_config, 0);
496500
cudaDeviceSynchronize();
497501
auto err = cudaGetLastError();
498502
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
@@ -548,14 +552,14 @@ void performTest(const ShapeRepresentation shape_rep,
548552

549553
// {shape_representation, num_tensors, [logical_shape_M, logical_shape_K], [M_i], [K_i]}
550554
std::vector<std::vector<size_t>> grouped_input_config = {
551-
// {SAME_BOTH_DIMS, 1, 128,128},
552-
// {SAME_BOTH_DIMS, 2, 256,128},
553-
// {VARYING_FIRST_DIM, 2, 512,128, 128,384},
554-
// {VARYING_FIRST_DIM, 3, 1024,160, 128,384,512},
555-
// {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512},
556-
// {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
557-
// {VARYING_LAST_DIM, 3, 256,896, 128,256,512},
558-
// {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
555+
{SAME_BOTH_DIMS, 1, 128,128},
556+
{SAME_BOTH_DIMS, 2, 256,128},
557+
{VARYING_FIRST_DIM, 2, 512,128, 128,384},
558+
{VARYING_FIRST_DIM, 3, 1024,160, 128,384,512},
559+
{VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512},
560+
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
561+
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
562+
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
559563
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
560564
};
561565

transformer_engine/common/cast/cast.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
5656
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
5757
}
5858

59+
void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output,
60+
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
61+
NVTE_API_CALL(nvte_group_quantize_v2);
62+
using namespace transformer_engine;
63+
64+
constexpr bool IS_ACT = false;
65+
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, quant_config, stream);
66+
}
67+
5968
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
6069
NVTETensor workspace, cudaStream_t stream) {
6170
NVTE_API_CALL(nvte_quantize_dbias);

transformer_engine/common/include/transformer_engine/cast.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
124124
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
125125
const NVTEQuantizationConfig quant_config, cudaStream_t stream);
126126

127+
128+
/*! \brief Casts input grouped tensor to quantized grouped output tensor, with advanced quantization options.
129+
*
130+
* \param[in] input Input grouped tensor to be cast.
131+
* \param[in,out] output Output grouped quantized tensor.
132+
* \param[in] quant_config Quantization configuration.
133+
* \param[in] stream CUDA stream used for the operation.
134+
*/
135+
void nvte_group_quantize_v2(const NVTEGroupedTensor input, NVTEGroupedTensor output,
136+
const NVTEQuantizationConfig quant_config, cudaStream_t stream);
137+
127138
/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns.
128139
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
129140
* the block quantization (MXFP8) of the specified shape of the block will be used.

0 commit comments

Comments
 (0)