From c2adcafac4de0d13973124d211cb3a6b5725dfb0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 4 Mar 2026 09:14:19 +0000 Subject: [PATCH 1/3] code drop Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/utils/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index d691c1828c..c741ad6353 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -21,7 +21,13 @@ def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGr reduce_within_microbatch = tensor_name != "weight" if tensor_name == "weight": if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group + # Do not overwrite with `None`: in torch.distributed collectives + # group=None means the default/world process group. + if tp_group is not None: + reduction_group = tp_group + else: + # "Reduce in TP group" requested, but TP group is missing. + skip_reduction = True else: skip_reduction = True return skip_reduction, reduction_group, reduce_within_microbatch From 2117ea104611d51f7f9334912a83b17b6fc1de36 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 10:52:38 +0100 Subject: [PATCH 2/3] [Debug] Pass tp_size to DebugQuantizer and use it in get_reduction_params Use tp_size to determine whether tensor parallelism is active instead of checking tp_group is None (which is ambiguous since None means world group in torch.distributed). Also add tp_size to the backward-compat kwargs filtering in call_feature so custom features without tp_size in their inspect_tensor signature continue to work. Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/api.py | 11 ++++++++++- .../debug/features/log_fp8_tensor_stats.py | 3 ++- .../debug/features/log_nvfp4_tensor_stats.py | 3 ++- transformer_engine/debug/features/log_tensor_stats.py | 3 ++- transformer_engine/debug/features/utils/__init__.py | 6 ++++-- .../debug/pytorch/debug_quantization.py | 3 +++ transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 1 + transformer_engine/pytorch/module/linear.py | 2 +- 10 files changed, 27 insertions(+), 9 deletions(-) diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 774fae3594..a1cf80dd25 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -479,7 +479,12 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): """ if call.__name__ == "inspect_tensor": kwargs_copy = kwargs.copy() - for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]: + for k in [ + "quantizer", + "columnwise_quantized_tensor", + "rowwise_quantized_tensor", + "tp_size", + ]: if k not in call.__code__.co_varnames: kwargs_copy.pop(k) else: @@ -490,6 +495,10 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.", DeprecationWarning, ) + kwargs_copy = kwargs.copy() + for k in ["tp_size"]: + if k not in call.__code__.co_varnames: + kwargs_copy.pop(k, None) return call(feat_config, layer_name, **kwargs_copy) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index fd18d590ec..cf11964e25 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -311,6 +311,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. @@ -357,7 +358,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) STATS_BUFFERS.try_add_buffer( diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 18ac8619f3..8a76f4edcf 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -148,6 +148,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[QuantizedTensor] = None, columnwise_quantized_tensor: Optional[QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. @@ -199,7 +200,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) # Add nvfp4_ prefix to all stats for internal use diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 76e61fab24..5e6ce137bd 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -184,6 +184,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): # pylint: disable=unused-argument """API call used to collect the data about the tensor before process_tensor()/quantization.""" @@ -214,7 +215,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) for stat in config["stats"]: diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index c741ad6353..75bab4be9b 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -12,7 +12,9 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState -def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup): +def get_reduction_params( + tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int +): """ Returns the statistics reduction parameters for the tensor. """ @@ -20,7 +22,7 @@ def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGr reduction_group = debug_api.get_tensor_reduction_group() reduce_within_microbatch = tensor_name != "weight" if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: + if TEDebugState.weight_tensor_tp_group_reduce and tp_size > 1: # Do not overwrite with `None`: in torch.distributed collectives # group=None means the default/world process group. if tp_group is not None: diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 57a5967079..ed5fdd4660 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -53,6 +53,7 @@ def __init__( tensor_name: str, parent_quantizer: Optional[Quantizer], tp_group: torch.distributed.ProcessGroup, + tp_size: int, ): super().__init__(rowwise=True, columnwise=True) @@ -60,6 +61,7 @@ def __init__( self.tensor_name = tensor_name self.parent_quantizer = parent_quantizer self.tp_group = tp_group # used in inspect_tensor calls + self.tp_size = tp_size self.iteration = TEDebugState.get_iteration() # Configure parent quantizer @@ -263,6 +265,7 @@ def _call_inspect_tensor_api( "tensor_name": self.tensor_name, "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, + "tp_size": self.tp_size, "columnwise_quantized_tensor": columnwise_gemm_tensor, "rowwise_quantized_tensor": rowwise_gemm_tensor, "quantizer": self.parent_quantizer, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..02607e45e5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1082,7 +1082,7 @@ def _get_debug_quantizers(self): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( [ - DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group) + DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group, self.tp_size) for q_id, q in enumerate(qs) ] for name, qs in zip(names, original_quantizers) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce0581024a..420f2a613a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1646,7 +1646,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 16e620fd94..e58201b660 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2373,6 +2373,7 @@ def make_debug(prefix, offset): label, None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], self.tp_group, + self.tp_size, ) for i, label in enumerate(labels) ] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 31dac4d329..be0c76c1c8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1513,7 +1513,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) ) From add5be3c60fa195485404c09dcdef711297e3435 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 09:53:32 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/utils/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index 75bab4be9b..813fb2addc 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -12,9 +12,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState -def get_reduction_params( - tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int -): +def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int): """ Returns the statistics reduction parameters for the tensor. """