diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 774fae3594..a1cf80dd25 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -479,7 +479,12 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): """ if call.__name__ == "inspect_tensor": kwargs_copy = kwargs.copy() - for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]: + for k in [ + "quantizer", + "columnwise_quantized_tensor", + "rowwise_quantized_tensor", + "tp_size", + ]: if k not in call.__code__.co_varnames: kwargs_copy.pop(k) else: @@ -490,6 +495,10 @@ def call_feature(self, call, feat_config, layer_name, **kwargs): "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.", DeprecationWarning, ) + kwargs_copy = kwargs.copy() + for k in ["tp_size"]: + if k not in call.__code__.co_varnames: + kwargs_copy.pop(k, None) return call(feat_config, layer_name, **kwargs_copy) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index fd18d590ec..cf11964e25 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -311,6 +311,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. @@ -357,7 +358,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) STATS_BUFFERS.try_add_buffer( diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 18ac8619f3..8a76f4edcf 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -148,6 +148,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[QuantizedTensor] = None, columnwise_quantized_tensor: Optional[QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. @@ -199,7 +200,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) # Add nvfp4_ prefix to all stats for internal use diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 76e61fab24..5e6ce137bd 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -184,6 +184,7 @@ def inspect_tensor( rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, + tp_size: int = 1, ): # pylint: disable=unused-argument """API call used to collect the data about the tensor before process_tensor()/quantization.""" @@ -214,7 +215,7 @@ def inspect_tensor( ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group + tensor_name, tp_group, tp_size ) for stat in config["stats"]: diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index d691c1828c..813fb2addc 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -12,7 +12,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState -def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup): +def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup, tp_size: int): """ Returns the statistics reduction parameters for the tensor. """ @@ -20,8 +20,14 @@ def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGr reduction_group = debug_api.get_tensor_reduction_group() reduce_within_microbatch = tensor_name != "weight" if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group + if TEDebugState.weight_tensor_tp_group_reduce and tp_size > 1: + # Do not overwrite with `None`: in torch.distributed collectives + # group=None means the default/world process group. + if tp_group is not None: + reduction_group = tp_group + else: + # "Reduce in TP group" requested, but TP group is missing. + skip_reduction = True else: skip_reduction = True return skip_reduction, reduction_group, reduce_within_microbatch diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 57a5967079..ed5fdd4660 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -53,6 +53,7 @@ def __init__( tensor_name: str, parent_quantizer: Optional[Quantizer], tp_group: torch.distributed.ProcessGroup, + tp_size: int, ): super().__init__(rowwise=True, columnwise=True) @@ -60,6 +61,7 @@ def __init__( self.tensor_name = tensor_name self.parent_quantizer = parent_quantizer self.tp_group = tp_group # used in inspect_tensor calls + self.tp_size = tp_size self.iteration = TEDebugState.get_iteration() # Configure parent quantizer @@ -263,6 +265,7 @@ def _call_inspect_tensor_api( "tensor_name": self.tensor_name, "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, + "tp_size": self.tp_size, "columnwise_quantized_tensor": columnwise_gemm_tensor, "rowwise_quantized_tensor": rowwise_gemm_tensor, "quantizer": self.parent_quantizer, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f3e7b57cf1..02607e45e5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1082,7 +1082,7 @@ def _get_debug_quantizers(self): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( [ - DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group) + DebugQuantizer(self.name + f".gemm_{q_id}", name, q, self.tp_group, self.tp_size) for q_id, q in enumerate(qs) ] for name, qs in zip(names, original_quantizers) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce0581024a..420f2a613a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1646,7 +1646,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 16e620fd94..e58201b660 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2373,6 +2373,7 @@ def make_debug(prefix, offset): label, None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], self.tp_group, + self.tp_size, ) for i, label in enumerate(labels) ] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 31dac4d329..be0c76c1c8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1513,7 +1513,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] return tuple( - DebugQuantizer(self.name, name, q, self.tp_group) + DebugQuantizer(self.name, name, q, self.tp_group, self.tp_size) for name, q in zip(names, original_quantizers) )