From 2bf2ab5e0c733811d663057d48988df024175387 Mon Sep 17 00:00:00 2001
From: Victor Oliveira
Date: Thu, 8 Jan 2026 00:48:15 +0000
Subject: [PATCH] ONNX: Fix FP8 quantization for the second MLP in LayernormMLP

Signed-off-by: Victor Oliveira
---
 .../pytorch/module/layernorm_mlp.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index ddb33f303c..4256028c8b 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -2243,14 +2243,23 @@ def onnx_forward(
 
         assert not TEDebugState.debug_enabled, "Debug mode is not supported in ONNX export"
         assert_warmed_up(self)
+
+        # Get quantizers
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
+            _,
+            _,
+            _,
+            _,
             fc2_input_quantizer,
             fc2_weight_quantizer,
-            output_quantizer,
-            *_,
+            fc2_output_quantizer,
+            _,
+            _,
+            _,
         ) = self._get_quantizers(False, is_grad_enabled)
+
         inp_dtype = inp.dtype
 
         fc1_weight, fc2_weight = self._get_weight_tensors()
@@ -2324,7 +2333,7 @@ def _clamped_swiglu(x, limit, alpha):
 
         fc2_out = onnx_gemm(fc2_weight, act_out, fc2_bias)
 
-        if output_quantizer is not None:
+        if fc2_output_quantizer is not None:
             raise NotImplementedError("ONNX export of quantized output is not supported")
 
         if self.return_layernorm_output:
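
Note: the regression fixed here comes from positional tuple unpacking. `_get_quantizers()` returns more entries than the old code unpacked, so the old pattern bound `fc2_input_quantizer`, `fc2_weight_quantizer`, and `output_quantizer` to the wrong slots while `*_` swallowed the rest. Below is a minimal, self-contained sketch of that pitfall, assuming a 12-element quantizer tuple with placeholder slot names; only positions 0, 1, 6, 7, and 8 are implied by the diff above, and the real layout lives in LayerNormMLP._get_quantizers().

    # Placeholder tuple standing in for the return value of _get_quantizers().
    quantizers = tuple(f"slot_{i}" for i in range(12))

    # Old unpacking: assumes the FC2 quantizers immediately follow the FC1 ones,
    # so it silently binds slots 2-4 to the FC2/output names.
    fc1_in_q, fc1_w_q, fc2_in_q, fc2_w_q, out_q, *_ = quantizers
    assert fc2_in_q == "slot_2"  # wrong slot: FC2 ran with the wrong quantizer

    # New unpacking: skip the intermediate slots explicitly so FC2 gets slots 6-8.
    (fc1_in_q, fc1_w_q, _, _, _, _,
     fc2_in_q, fc2_w_q, fc2_out_q, _, _, _) = quantizers
    assert fc2_in_q == "slot_6" and fc2_out_q == "slot_8"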