From fe437c1c806b9f6753b0b26167d20872afd5b89d Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Tue, 3 Feb 2026 08:54:26 +0000
Subject: [PATCH 01/27] code drop

Signed-off-by: Pawel Gadzinski
---
 .../debug/features/dump_tensors.py            | 327 ++++++++++++++++++
 1 file changed, 327 insertions(+)
 create mode 100644 transformer_engine/debug/features/dump_tensors.py

diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py
new file mode 100644
index 0000000000..e13983af67
--- /dev/null
+++ b/transformer_engine/debug/features/dump_tensors.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""DumpTensors Feature support for nvidia-dlframework-inspect."""
+
+from typing import Dict, Optional
+
+import torch
+
+import nvdlfw_inspect.api as debug_api
+from nvdlfw_inspect.logging import get_tensor_logger
+from nvdlfw_inspect.registry import Registry, api_method
+
+from transformer_engine.debug.features.api import TEConfigAPIMapper
+from transformer_engine.debug.features.utils import next_enabled_iter
+from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
+from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor
+
+
+@Registry.register_feature(namespace="transformer_engine")
+class DumpTensors(TEConfigAPIMapper):
+    """
+    Dump tensors to files for debugging purposes.
+
+    This feature saves tensors to disk using torch.save(). It supports dumping
+    both high-precision tensors (before quantization) and quantized tensors.
+
+    Each tensor is saved to a separate file with the iteration number, layer name,
+    and tensor name in the filename. Files are organized per-rank in distributed settings.
+
+    Parameters
+    ----------
+    high_precision_tensor : bool
+        If True, dump the high-precision tensor (before quantization).
+    quantized_tensor : bool
+        If True, dump the quantized tensor (after quantization).
+    extended_quantized_tensor_log : bool, default = False
+        If True, dump additional files with raw data and scales for quantized tensors:
+        - For Float8Tensor: raw_data (uint8), scale_inv (FP32)
+        - For MXFP8Tensor: rowwise_raw_data, columnwise_raw_data (uint8),
+          rowwise_scale_inv, columnwise_scale_inv (decoded to FP32)
+        - For NVFP4Tensor: rowwise_raw_data, columnwise_raw_data (uint8),
+          rowwise_scale_inv, columnwise_scale_inv (decoded to FP32),
+          rowwise_amax, columnwise_amax (FP32)
+    tensors/tensors_struct : List[str]
+        list of tensors to dump:
+        - activation
+        - gradient
+        - weight
+        - output
+        - wgrad
+        - dgrad
+    freq : Optional[int], default = 1
+        frequency of dumping tensors; tensors will be dumped every `freq` steps
+    start_step : Optional[int], default = 0
+        start step of dumping tensors
+    end_step : Optional[int], default = -1
+        end step of dumping tensors (-1 means no end)
+    start_end_list : Optional[list([int, int])], default = None
+        non-overlapping list of (start, end) pairs in increasing order.
+        If not None, start_step and end_step are ignored.
+
+    Example
+    -------
+    .. 
code-block:: yaml + + dump_tensors_example: + enabled: True + layers: + layer_name_regex_pattern: .*(fc1|self_attention).* + transformer_engine: + DumpTensors: + enabled: True + tensors_struct: + - tensor: activation + high_precision_tensor: True + quantized_tensor: True + extended_quantized_tensor_log: True + freq: 100 + - tensor: weight + high_precision_tensor: True + quantized_tensor: False + freq: 500 + + Output Structure + ---------------- + Files are saved to: ``nvdlfw_inspect_tensor_dumps/rank_{rank}/`` + + Basic files: + - ``{layer}_{tensor}_iter_{iter}_high_precision.pt`` + - ``{layer}_{tensor}_iter_{iter}_quantized.pt`` + + Extended files (when extended_quantized_tensor_log=True): + - ``{layer}_{tensor}_iter_{iter}_raw_data.pt`` + - ``{layer}_{tensor}_iter_{iter}_scale_inv.pt`` + - (MXFP8/NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_scale_inv.pt`` + - (NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_amax.pt`` + """ + + @api_method + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int + ): # pylint: disable=unused-argument + """API call used to determine whether to run inspect_tensor() in the forward.""" + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + return run_current, next_iter + + @api_method + def inspect_tensor( + self, + config: Dict, + layer_name: str, + tensor_name: str, + iteration: int, + tp_group: torch.distributed.ProcessGroup, + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): # pylint: disable=unused-argument + """ + API call used to dump tensors to files. + + Supports dumping both high-precision tensors and quantized tensors based on config. + """ + # Assert that rowwise and columnwise are the same (or one is None) + assert rowwise_quantized_tensor is columnwise_quantized_tensor, ( + "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and " + "columnwise_quantized_tensor to be the same object or both None." + ) + + quantized_tensor = rowwise_quantized_tensor + + dump_hp = config.get("high_precision_tensor", False) + dump_quant = config.get("quantized_tensor", False) + + if not dump_hp and not dump_quant: + debug_api.log_message( + f"Feature={self.__class__.__name__}: Neither high_precision_tensor nor " + "quantized_tensor is enabled. 
Nothing to dump.", + layer_name, + ) + return + + tensor_logger = get_tensor_logger() + + # Dump high-precision tensor + if dump_hp and tensor is not None: + tensor_logger.save_tensor( + tensor=tensor, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_high_precision", + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped high-precision {tensor_name} at iteration {iteration}", + layer_name, + ) + + # Dump quantized tensor + if dump_quant and quantized_tensor is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_quantized", + ) + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped quantized {tensor_name} at iteration {iteration}", + layer_name, + ) + + # Extended logging for quantized tensors + if config.get("extended_quantized_tensor_log", False): + self._dump_extended_quantized_info( + tensor_logger, quantized_tensor, layer_name, tensor_name, iteration + ) + + elif dump_quant and quantized_tensor is None: + debug_api.log_message( + f"Feature={self.__class__.__name__}: quantized_tensor is True but " + f"no quantized tensor available for {tensor_name}. Skipping.", + layer_name, + ) + + def _dump_extended_quantized_info( + self, + tensor_logger, + quantized_tensor: QuantizedTensor, + layer_name: str, + tensor_name: str, + iteration: int, + ): + """Dump extended debug info for quantized tensors (raw data and scales).""" + + if isinstance(quantized_tensor, Float8Tensor): + # Float8Tensor: raw_data (uint8), scale_inv (FP32) + tensor_logger.save_tensor( + tensor=quantized_tensor._data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_raw_data", + ) + tensor_logger.save_tensor( + tensor=quantized_tensor._scale_inv, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_scale_inv", + ) + + elif isinstance(quantized_tensor, MXFP8Tensor): + # MXFP8Tensor: raw data and scales (decoded from E8M0) + if quantized_tensor._rowwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._rowwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_raw_data", + ) + if quantized_tensor._columnwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._columnwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_raw_data", + ) + # Decode E8M0 scales to FP32 + if quantized_tensor._rowwise_scale_inv is not None: + decoded = torch.pow( + torch.tensor(2.0, device=quantized_tensor._rowwise_scale_inv.device), + quantized_tensor._rowwise_scale_inv.to(torch.float32) - 127.0, + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_scale_inv", + ) + if quantized_tensor._columnwise_scale_inv is not None: + decoded = torch.pow( + torch.tensor(2.0, device=quantized_tensor._columnwise_scale_inv.device), + quantized_tensor._columnwise_scale_inv.to(torch.float32) - 127.0, + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_scale_inv", + ) + + elif isinstance(quantized_tensor, NVFP4Tensor): + # NVFP4Tensor: raw data, scales (decoded from E4M3), and amax + if quantized_tensor._rowwise_data is not None: + 
tensor_logger.save_tensor( + tensor=quantized_tensor._rowwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_raw_data", + ) + if quantized_tensor._columnwise_data is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._columnwise_data, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_raw_data", + ) + # Decode E4M3 scales to FP32 + if quantized_tensor._rowwise_scale_inv is not None: + decoded = quantized_tensor._rowwise_scale_inv.view(torch.float8_e4m3fn).to( + torch.float32 + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_scale_inv", + ) + if quantized_tensor._columnwise_scale_inv is not None: + decoded = quantized_tensor._columnwise_scale_inv.view(torch.float8_e4m3fn).to( + torch.float32 + ) + tensor_logger.save_tensor( + tensor=decoded, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_scale_inv", + ) + # Amax values (already FP32) + if quantized_tensor._amax_rowwise is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._amax_rowwise, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_rowwise_amax", + ) + if quantized_tensor._amax_columnwise is not None: + tensor_logger.save_tensor( + tensor=quantized_tensor._amax_columnwise, + layer_name=layer_name, + tensor_name=tensor_name, + iteration=iteration, + suffix="_columnwise_amax", + ) From a54a743af93c198535636dad47c0c59fee9a1ba8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 3 Feb 2026 10:42:56 +0000 Subject: [PATCH 02/27] code drop Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 72 ++++ .../debug/features/dump_tensors.py | 369 ++++++++++-------- 2 files changed, 270 insertions(+), 171 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 5d6fc41ac7..5f7adc0a41 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -592,3 +592,75 @@ def test_compute_max_blockwise_dynamic_range_direct(): ) print("All direct tests for compute_max_blockwise_dynamic_range passed!") + + +# DumpTensors tests +DUMP_TENSORS_CONFIG = """ +dump: + layers: + layer_name_regex_pattern: .* + enabled: True + transformer_engine: + DumpTensors: + enabled: True + tensors: [activation] + high_precision_tensor: True + quantized_tensor: True + dump_quantized_internals: True + freq: 1 +""" + + +def test_dump_tensors_sanity(feature_dirs): + """Sanity test for DumpTensors feature - verify files are created with correct structure.""" + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: + from transformer_engine.pytorch.quantization import RecipeState + + recipe_state = RecipeState.create( + recipe.DelayedScaling(), + mode="forward", + num_quantizers=3, + ) + + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + # Check that dump file was created + dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") + assert 
os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}" + + dump_files = os.listdir(dump_dir) + assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" + + # Load and verify structure + dump_file = os.path.join(dump_dir, dump_files[0]) + data = torch.load(dump_file, weights_only=False) + + assert isinstance(data, dict), "Dump should be a dictionary" + assert "high_precision" in data, "Missing high_precision tensor" + assert "quantized" in data, "Missing quantized tensor" + + # Check internals are present (dump_quantized_internals=True) + assert "data" in data, "Missing data (raw FP8 data)" + assert "scale_inv" in data, "Missing scale_inv" + + # Verify tensor shapes match + assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" + + print("DumpTensors sanity test passed!") diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index e13983af67..136771b7ae 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -4,22 +4,95 @@ """DumpTensors Feature support for nvidia-dlframework-inspect.""" +import os from typing import Dict, Optional import torch +import torch.distributed as dist import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.logging import get_tensor_logger +from nvdlfw_inspect.logging import get_logger from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper from transformer_engine.debug.features.utils import next_enabled_iter +from transformer_engine.pytorch.constants import TE_DType_To_Torch from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor +class TensorLogger: + """Logger for saving tensors to files. Each rank saves to its own directory.""" + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if TensorLogger._initialized: + return + self.root_dir = None + self.rank = 0 + TensorLogger._initialized = True + + def initialize(self, root_log_dir: str): + """Initialize the TensorLogger with the root directory for tensor dumps.""" + self.rank = 0 + if dist.is_initialized(): + self.rank = dist.get_rank() + + self.root_dir = os.path.join( + root_log_dir, "tensor_dumps", f"rank_{self.rank}" + ) + os.makedirs(self.root_dir, exist_ok=True) + + debug_api.log_message( + f"TensorLogger initialized. Saving tensors to: {self.root_dir}", + log_level="info", + ) + + @staticmethod + def _sanitize_name(name: str) -> str: + """Sanitize layer/tensor names for use in file paths.""" + for char in ["/", "\\", ":", "*", "?", '"', "<", ">", "|", " "]: + name = name.replace(char, "_") + return name + + def save_tensor( + self, + tensor, + layer_name: str, + tensor_name: str, + iteration: int, + ): + """Save a tensor (or dict of tensors) to a file.""" + if self.root_dir is None: + raise RuntimeError( + "[TE DumpTensors] TensorLogger not initialized. " + "Call initialize() first." 
+ ) + + safe_layer_name = self._sanitize_name(layer_name) + safe_tensor_name = self._sanitize_name(tensor_name) + + filename = f"{safe_layer_name}_{safe_tensor_name}_iter_{iteration:06d}.pt" + filepath = os.path.join(self.root_dir, filename) + + torch.save(tensor, filepath) + + +def _get_tensor_logger() -> TensorLogger: + """Get the singleton TensorLogger instance.""" + return TensorLogger() + + @Registry.register_feature(namespace="transformer_engine") class DumpTensors(TEConfigAPIMapper): """ @@ -37,14 +110,10 @@ class DumpTensors(TEConfigAPIMapper): If True, dump the high-precision tensor (before quantization). quantized_tensor : bool If True, dump the quantized tensor (after quantization). - extended_quantized_tensor_log : bool, default = False - If True, dump additional files with raw data and scales for quantized tensors: - - For Float8Tensor: raw_data (uint8), scale_inv (FP32) - - For MXFP8Tensor: rowwise_raw_data, columnwise_raw_data (uint8), - rowwise_scale_inv, columnwise_scale_inv (decoded to FP32) - - For NVFP4Tensor: rowwise_raw_data, columnwise_raw_data (uint8), - rowwise_scale_inv, columnwise_scale_inv (decoded to FP32), - rowwise_amax, columnwise_amax (FP32) + dump_quantized_internals : bool, default = False + If True, include extracted internal data from quantized tensors + (raw data, scales, etc.) in the output dictionary. + Useful for offline analysis. Output format may change between versions. tensors/tensors_struct : List[str] list of tensors to dump: - activation @@ -78,7 +147,7 @@ class DumpTensors(TEConfigAPIMapper): - tensor: activation high_precision_tensor: True quantized_tensor: True - extended_quantized_tensor_log: True + dump_quantized_internals: True freq: 100 - tensor: weight high_precision_tensor: True @@ -87,17 +156,16 @@ class DumpTensors(TEConfigAPIMapper): Output Structure ---------------- - Files are saved to: ``nvdlfw_inspect_tensor_dumps/rank_{rank}/`` + Files are saved to: ``{nvdlfw_inspect_log_dir}/tensor_dumps/rank_{rank}/`` - Basic files: - - ``{layer}_{tensor}_iter_{iter}_high_precision.pt`` - - ``{layer}_{tensor}_iter_{iter}_quantized.pt`` + Each tensor is saved as a dictionary in a single file: + ``{layer}_{tensor}_iter_{iter:06d}.pt`` - Extended files (when extended_quantized_tensor_log=True): - - ``{layer}_{tensor}_iter_{iter}_raw_data.pt`` - - ``{layer}_{tensor}_iter_{iter}_scale_inv.pt`` - - (MXFP8/NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_scale_inv.pt`` - - (NVFP4) ``{layer}_{tensor}_iter_{iter}_rowwise_amax.pt`` + Dictionary keys: + - ``high_precision``: pre-quantization tensor (if high_precision_tensor=True) + - ``quantized``: quantized tensor object (if quantized_tensor=True) + - Additional internal components when dump_quantized_internals=True + (raw data, scales, etc. 
- format may change between versions) """ @api_method @@ -151,43 +219,23 @@ def inspect_tensor( ) return - tensor_logger = get_tensor_logger() + tensor_logger = _get_tensor_logger() + if tensor_logger.root_dir is None: + tensor_logger.initialize(get_logger().root_log_dir) + + # Build dictionary with all tensors to dump + dump_dict: Dict[str, torch.Tensor] = {} - # Dump high-precision tensor if dump_hp and tensor is not None: - tensor_logger.save_tensor( - tensor=tensor, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_high_precision", - ) - debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor: " - f"Dumped high-precision {tensor_name} at iteration {iteration}", - layer_name, - ) + dump_dict["high_precision"] = tensor - # Dump quantized tensor if dump_quant and quantized_tensor is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_quantized", - ) - debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor: " - f"Dumped quantized {tensor_name} at iteration {iteration}", - layer_name, - ) + dump_dict["quantized"] = quantized_tensor - # Extended logging for quantized tensors - if config.get("extended_quantized_tensor_log", False): - self._dump_extended_quantized_info( - tensor_logger, quantized_tensor, layer_name, tensor_name, iteration - ) + # Add internals for quantized tensors + if config.get("dump_quantized_internals", False): + internals = self._get_quantized_internals(quantized_tensor) + dump_dict.update(internals) elif dump_quant and quantized_tensor is None: debug_api.log_message( @@ -196,132 +244,111 @@ def inspect_tensor( layer_name, ) - def _dump_extended_quantized_info( - self, - tensor_logger, - quantized_tensor: QuantizedTensor, - layer_name: str, - tensor_name: str, - iteration: int, - ): - """Dump extended debug info for quantized tensors (raw data and scales).""" - - if isinstance(quantized_tensor, Float8Tensor): - # Float8Tensor: raw_data (uint8), scale_inv (FP32) + if dump_dict: tensor_logger.save_tensor( - tensor=quantized_tensor._data, + tensor=dump_dict, layer_name=layer_name, tensor_name=tensor_name, iteration=iteration, - suffix="_raw_data", ) - tensor_logger.save_tensor( - tensor=quantized_tensor._scale_inv, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_scale_inv", + debug_api.log_message( + f"Feature={self.__class__.__name__}, API=inspect_tensor: " + f"Dumped {tensor_name} at iteration {iteration} (keys: {list(dump_dict.keys())})", + layer_name, ) + def _get_quantized_internals( + self, + quantized_tensor: QuantizedTensor, + ) -> Dict[str, torch.Tensor]: + """Get internal components of quantized tensors (raw data, scales, etc.).""" + if isinstance(quantized_tensor, Float8Tensor): + tensors = _get_extended_tensors_fp8(quantized_tensor) + elif isinstance(quantized_tensor, Float8BlockwiseQTensor): + tensors = _get_extended_tensors_fp8_blockwise(quantized_tensor) elif isinstance(quantized_tensor, MXFP8Tensor): - # MXFP8Tensor: raw data and scales (decoded from E8M0) - if quantized_tensor._rowwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._rowwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_raw_data", - ) - if quantized_tensor._columnwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._columnwise_data, - layer_name=layer_name, - 
tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_raw_data", - ) - # Decode E8M0 scales to FP32 - if quantized_tensor._rowwise_scale_inv is not None: - decoded = torch.pow( - torch.tensor(2.0, device=quantized_tensor._rowwise_scale_inv.device), - quantized_tensor._rowwise_scale_inv.to(torch.float32) - 127.0, - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_scale_inv", - ) - if quantized_tensor._columnwise_scale_inv is not None: - decoded = torch.pow( - torch.tensor(2.0, device=quantized_tensor._columnwise_scale_inv.device), - quantized_tensor._columnwise_scale_inv.to(torch.float32) - 127.0, - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_scale_inv", - ) - + tensors = _get_extended_tensors_mxfp8(quantized_tensor) elif isinstance(quantized_tensor, NVFP4Tensor): - # NVFP4Tensor: raw data, scales (decoded from E4M3), and amax - if quantized_tensor._rowwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._rowwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_raw_data", - ) - if quantized_tensor._columnwise_data is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._columnwise_data, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_raw_data", - ) - # Decode E4M3 scales to FP32 - if quantized_tensor._rowwise_scale_inv is not None: - decoded = quantized_tensor._rowwise_scale_inv.view(torch.float8_e4m3fn).to( - torch.float32 - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_scale_inv", - ) - if quantized_tensor._columnwise_scale_inv is not None: - decoded = quantized_tensor._columnwise_scale_inv.view(torch.float8_e4m3fn).to( - torch.float32 - ) - tensor_logger.save_tensor( - tensor=decoded, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_scale_inv", - ) - # Amax values (already FP32) - if quantized_tensor._amax_rowwise is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._amax_rowwise, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_rowwise_amax", - ) - if quantized_tensor._amax_columnwise is not None: - tensor_logger.save_tensor( - tensor=quantized_tensor._amax_columnwise, - layer_name=layer_name, - tensor_name=tensor_name, - iteration=iteration, - suffix="_columnwise_amax", - ) + tensors = _get_extended_tensors_nvfp4(quantized_tensor) + else: + return {} + + # Filter out None values + return {k: v for k, v in tensors.items() if v is not None} + + +def _get_extended_tensors_fp8(tensor: Float8Tensor) -> Dict[str, torch.Tensor]: + """Get extended tensors for Float8Tensor: raw FP8 data, transpose, and scale.""" + torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + result = { + "data": tensor._data.view(torch_fp8_dtype), + "scale_inv": tensor._scale_inv, + } + if tensor._transpose is not None and not tensor._transpose_invalid: + result["transpose"] = tensor._transpose.view(torch_fp8_dtype) + return result + + +def _get_extended_tensors_fp8_blockwise( + tensor: Float8BlockwiseQTensor, +) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for Float8BlockwiseQTensor: raw FP8 data and block scales.""" + torch_fp8_dtype = 
TE_DType_To_Torch[tensor._fp8_dtype] + result: Dict[str, Optional[torch.Tensor]] = {} + + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) + + # Block scaling factors (FP32) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv + + return result + + +def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for MXFP8Tensor: raw FP8 data and block scales (E8M0).""" + torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] + result: Dict[str, Optional[torch.Tensor]] = {} + + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) + + # Block scaling factors (E8M0 format) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu) + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e8m0fnu) + + return result + + +def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch.Tensor]]: + """Get extended tensors for NVFP4Tensor: raw packed FP4 data, block scales, and amax.""" + result: Dict[str, Optional[torch.Tensor]] = {} + + # Raw data (packed FP4, 2 values per byte) + if tensor._rowwise_data is not None: + result["rowwise_data"] = tensor._rowwise_data + if tensor._columnwise_data is not None: + result["columnwise_data"] = tensor._columnwise_data + + # Block scaling factors (E4M3 format) + if tensor._rowwise_scale_inv is not None: + result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn) + if tensor._columnwise_scale_inv is not None: + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e4m3fn) + + # Input absolute maximum value (used to compute tensor scale) + if tensor._amax_rowwise is not None: + result["amax_rowwise"] = tensor._amax_rowwise + if tensor._amax_columnwise is not None: + result["amax_columnwise"] = tensor._amax_columnwise + + return result From b6e076734dc791e178bb68681a3c74fa7c119a77 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:45:41 +0000 Subject: [PATCH 03/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/dump_tensors.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 136771b7ae..b453b5e273 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -48,9 +48,7 @@ def initialize(self, root_log_dir: str): if dist.is_initialized(): self.rank = dist.get_rank() - self.root_dir = os.path.join( - root_log_dir, "tensor_dumps", f"rank_{self.rank}" - ) + self.root_dir = os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}") os.makedirs(self.root_dir, exist_ok=True) debug_api.log_message( @@ 
-75,8 +73,7 @@ def save_tensor( """Save a tensor (or dict of tensors) to a file.""" if self.root_dir is None: raise RuntimeError( - "[TE DumpTensors] TensorLogger not initialized. " - "Call initialize() first." + "[TE DumpTensors] TensorLogger not initialized. Call initialize() first." ) safe_layer_name = self._sanitize_name(layer_name) @@ -324,7 +321,9 @@ def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch if tensor._rowwise_scale_inv is not None: result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu) if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e8m0fnu) + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( + torch.float8_e8m0fnu + ) return result @@ -343,7 +342,9 @@ def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch if tensor._rowwise_scale_inv is not None: result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn) if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view(torch.float8_e4m3fn) + result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( + torch.float8_e4m3fn + ) # Input absolute maximum value (used to compute tensor scale) if tensor._amax_rowwise is not None: From dc60fe870ea3695fba6ba1e94d097994a68fe1fe Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 10:11:05 +0000 Subject: [PATCH 04/27] docs Signed-off-by: root --- docs/debug/3_api_features.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/debug/3_api_features.rst b/docs/debug/3_api_features.rst index a973a0b4fe..a8a644d5b5 100644 --- a/docs/debug/3_api_features.rst +++ b/docs/debug/3_api_features.rst @@ -14,4 +14,5 @@ Debug features .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant .. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM -.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer \ No newline at end of file +.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer +.. 
autoapiclass:: transformer_engine.debug.features.dump_tensors.DumpTensors \ No newline at end of file From e94467faff9f17f7c1b474c96e36363b6e636e86 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 10:40:51 +0000 Subject: [PATCH 05/27] nvfp4 internals support Signed-off-by: root --- tests/pytorch/debug/test_log.py | 112 +++++++++++++++--- .../debug/features/dump_tensors.py | 60 +++++++++- 2 files changed, 153 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 9341369241..0b3834d913 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -663,6 +663,22 @@ def test_compute_max_blockwise_dynamic_range_direct(): """ +NVFP4_DUMP_TENSORS_CONFIG = """ +dump: + layers: + layer_name_regex_pattern: .* + enabled: True + transformer_engine: + DumpTensors: + enabled: True + tensors: [activation] + high_precision_tensor: False + quantized_tensor: True + dump_quantized_internals: True + freq: 1 +""" + + def test_dump_tensors_sanity(feature_dirs): """Sanity test for DumpTensors feature - verify files are created with correct structure.""" if not fp8_available: @@ -693,26 +709,90 @@ def test_dump_tensors_sanity(feature_dirs): ) debug_api.step() - # Check that dump file was created - dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") - assert os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}" + # Check that dump file was created + dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") + assert os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}" - dump_files = os.listdir(dump_dir) - assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" + dump_files = os.listdir(dump_dir) + assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" - # Load and verify structure - dump_file = os.path.join(dump_dir, dump_files[0]) - data = torch.load(dump_file, weights_only=False) + # Load and verify structure + dump_file = os.path.join(dump_dir, dump_files[0]) + data = torch.load(dump_file, weights_only=False) - assert isinstance(data, dict), "Dump should be a dictionary" - assert "high_precision" in data, "Missing high_precision tensor" - assert "quantized" in data, "Missing quantized tensor" + assert isinstance(data, dict), "Dump should be a dictionary" + assert "high_precision" in data, "Missing high_precision tensor" + assert "quantized" in data, "Missing quantized tensor" - # Check internals are present (dump_quantized_internals=True) - assert "data" in data, "Missing data (raw FP8 data)" - assert "scale_inv" in data, "Missing scale_inv" + # Check internals are present (dump_quantized_internals=True) + assert "data" in data, "Missing data (raw FP8 data)" + assert "scale_inv" in data, "Missing scale_inv" - # Verify tensor shapes match - assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" + # Verify tensor shapes match + assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" print("DumpTensors sanity test passed!") + + +def test_dump_tensors_nvfp4_unpacked_codes(feature_dirs): + """Verify DumpTensors includes unpacked FP4 values in NVFP4 internals.""" + if not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + with debug_session(NVFP4_DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: + recipe_state = RecipeState.create( + recipe.NVFP4BlockScaling(), + mode="forward", + num_quantizers=3, + ) + + tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() + quantizer = 
recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="test_layer", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") + dump_files = os.listdir(dump_dir) + assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" + + data = torch.load(os.path.join(dump_dir, dump_files[0]), weights_only=False) + assert "rowwise_data" in data, "Missing packed NVFP4 rowwise_data" + assert "rowwise_data_unpacked_values" in data, "Missing unpacked NVFP4 rowwise values" + + packed = data["rowwise_data"] + unpacked = data["rowwise_data_unpacked_values"] + assert unpacked.dtype == torch.float32, "Unpacked values must be float32" + assert unpacked.shape[-1] == packed.shape[-1] * 2, ( + "Unpacked values should double the last packed dimension" + ) + assert unpacked.min().item() >= -6.0 and unpacked.max().item() <= 6.0, ( + "Decoded FP4 values should be in representable E2M1 range [-6, 6]" + ) + + # Reconstruct dequantized values from unpacked FP4 values and block scales. + # For NVFP4 rowwise path, one E4M3 scale corresponds to a block of 16 values. + assert "rowwise_block_scale_inv" in data, "Missing rowwise NVFP4 block scales" + rowwise_scale_inv = data["rowwise_block_scale_inv"].to(torch.float32) + values = unpacked.to(torch.float32) + n_rows, n_cols = values.shape + scale_tiles = (n_cols + 15) // 16 + expanded_scales = rowwise_scale_inv[:n_rows, :scale_tiles].repeat_interleave(16, dim=1)[ + :, :n_cols + ] + reconstructed = values * expanded_scales + + expected = quantized_tensor.dequantize(dtype=torch.float32) + assert torch.allclose(reconstructed, expected, atol=1e-5, rtol=1e-3), ( + "Unpacked FP4 values multiplied by block scales should match NVFP4 dequantization" + ) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index b453b5e273..4be57e3b6c 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -48,7 +48,7 @@ def initialize(self, root_log_dir: str): if dist.is_initialized(): self.rank = dist.get_rank() - self.root_dir = os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}") + self.root_dir = self._expected_root_dir(root_log_dir) os.makedirs(self.root_dir, exist_ok=True) debug_api.log_message( @@ -56,6 +56,16 @@ def initialize(self, root_log_dir: str): log_level="info", ) + def _expected_root_dir(self, root_log_dir: str) -> str: + """Return the rank-specific dump directory for the provided root log path.""" + return os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}") + + def ensure_initialized(self, root_log_dir: str) -> None: + """Reinitialize logger if debug session log directory changed.""" + expected_root_dir = self._expected_root_dir(root_log_dir) + if self.root_dir != expected_root_dir or not os.path.isdir(expected_root_dir): + self.initialize(root_log_dir) + @staticmethod def _sanitize_name(name: str) -> str: """Sanitize layer/tensor names for use in file paths.""" @@ -163,6 +173,8 @@ class DumpTensors(TEConfigAPIMapper): - ``quantized``: quantized tensor object (if quantized_tensor=True) - Additional internal components when dump_quantized_internals=True (raw data, scales, etc. 
- format may change between versions) + - For NVFP4 internals, unpacked FP4 value tensors are included for + easier offline analysis. """ @api_method @@ -217,8 +229,7 @@ def inspect_tensor( return tensor_logger = _get_tensor_logger() - if tensor_logger.root_dir is None: - tensor_logger.initialize(get_logger().root_log_dir) + tensor_logger.ensure_initialized(get_logger().root_log_dir) # Build dictionary with all tensors to dump dump_dict: Dict[str, torch.Tensor] = {} @@ -328,6 +339,45 @@ def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch return result +def _unpack_uint4_codes(packed_data: torch.Tensor) -> torch.Tensor: + """Unpack packed uint4 values stored in uint8 into uint8 tensor with values 0..15.""" + packed_uint8 = packed_data.view(torch.uint8).contiguous().view(-1) + unpacked = torch.empty(packed_uint8.numel() * 2, dtype=torch.uint8, device=packed_data.device) + unpacked[::2] = packed_uint8 & 0x0F + unpacked[1::2] = (packed_uint8 >> 4) & 0x0F + unpacked_shape = (*packed_data.shape[:-1], packed_data.shape[-1] * 2) + return unpacked.view(unpacked_shape) + + +def _decode_uint4_e2m1_to_float(unpacked_codes: torch.Tensor) -> torch.Tensor: + """Decode uint4 FP4 E2M1 codes (0..15) into float32 values.""" + # Bit layout: [sign:1][exp:2][mantissa:1], exponent bias = 1. + # Positive representable magnitudes are: 0, 0.5, 1, 1.5, 2, 3, 4, 6. + fp4_e2m1_lut = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + device=unpacked_codes.device, + dtype=torch.float32, + ) + return fp4_e2m1_lut[unpacked_codes.long()] + + def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch.Tensor]]: """Get extended tensors for NVFP4Tensor: raw packed FP4 data, block scales, and amax.""" result: Dict[str, Optional[torch.Tensor]] = {} @@ -335,8 +385,12 @@ def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch # Raw data (packed FP4, 2 values per byte) if tensor._rowwise_data is not None: result["rowwise_data"] = tensor._rowwise_data + rowwise_codes = _unpack_uint4_codes(tensor._rowwise_data) + result["rowwise_data_unpacked_values"] = _decode_uint4_e2m1_to_float(rowwise_codes) if tensor._columnwise_data is not None: result["columnwise_data"] = tensor._columnwise_data + columnwise_codes = _unpack_uint4_codes(tensor._columnwise_data) + result["columnwise_data_unpacked_values"] = _decode_uint4_e2m1_to_float(columnwise_codes) # Block scaling factors (E4M3 format) if tensor._rowwise_scale_inv is not None: From e8c8e56a2c8834ae34d1313ed48f4e278d1a6c0b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:42:02 +0000 Subject: [PATCH 06/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 0b3834d913..38ee729022 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -773,12 +773,12 @@ def test_dump_tensors_nvfp4_unpacked_codes(feature_dirs): packed = data["rowwise_data"] unpacked = data["rowwise_data_unpacked_values"] assert unpacked.dtype == torch.float32, "Unpacked values must be float32" - assert unpacked.shape[-1] == packed.shape[-1] * 2, ( - "Unpacked values should double the last packed dimension" - ) - 
assert unpacked.min().item() >= -6.0 and unpacked.max().item() <= 6.0, ( - "Decoded FP4 values should be in representable E2M1 range [-6, 6]" - ) + assert ( + unpacked.shape[-1] == packed.shape[-1] * 2 + ), "Unpacked values should double the last packed dimension" + assert ( + unpacked.min().item() >= -6.0 and unpacked.max().item() <= 6.0 + ), "Decoded FP4 values should be in representable E2M1 range [-6, 6]" # Reconstruct dequantized values from unpacked FP4 values and block scales. # For NVFP4 rowwise path, one E4M3 scale corresponds to a block of 16 values. @@ -793,6 +793,6 @@ def test_dump_tensors_nvfp4_unpacked_codes(feature_dirs): reconstructed = values * expanded_scales expected = quantized_tensor.dequantize(dtype=torch.float32) - assert torch.allclose(reconstructed, expected, atol=1e-5, rtol=1e-3), ( - "Unpacked FP4 values multiplied by block scales should match NVFP4 dequantization" - ) + assert torch.allclose( + reconstructed, expected, atol=1e-5, rtol=1e-3 + ), "Unpacked FP4 values multiplied by block scales should match NVFP4 dequantization" From b002b89dc95413dacd77ffff3855376cc6927d2d Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 10:43:46 +0000 Subject: [PATCH 07/27] lint fixes Signed-off-by: root --- transformer_engine/debug/features/dump_tensors.py | 3 --- transformer_engine/debug/features/log_fp8_tensor_stats.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 4be57e3b6c..5bb6d5a554 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -53,7 +53,6 @@ def initialize(self, root_log_dir: str): debug_api.log_message( f"TensorLogger initialized. Saving tensors to: {self.root_dir}", - log_level="info", ) def _expected_root_dir(self, root_log_dir: str) -> str: @@ -173,8 +172,6 @@ class DumpTensors(TEConfigAPIMapper): - ``quantized``: quantized tensor object (if quantized_tensor=True) - Additional internal components when dump_quantized_internals=True (raw data, scales, etc. - format may change between versions) - - For NVFP4 internals, unpacked FP4 value tensors are included for - easier offline analysis. 
""" @api_method diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index fd18d590ec..27535be70c 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -10,10 +10,9 @@ import torch import nvdlfw_inspect.api as debug_api -import transformer_engine_torch as tex - from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method +import transformer_engine_torch as tex from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter From 2816f37dad3517d9b3e8fc443473b4ffff0e0c5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:50:01 +0100 Subject: [PATCH 08/27] Update transformer_engine/debug/features/dump_tensors.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/debug/features/dump_tensors.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 5bb6d5a554..47ceaf5d4f 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -233,6 +233,12 @@ def inspect_tensor( if dump_hp and tensor is not None: dump_dict["high_precision"] = tensor + elif dump_hp and tensor is None: + debug_api.log_message( + f"Feature={self.__class__.__name__}: high_precision_tensor is True but " + f"no high-precision tensor available for {tensor_name}. Skipping.", + layer_name, + ) if dump_quant and quantized_tensor is not None: dump_dict["quantized"] = quantized_tensor From 83506af50a9d4099b938942b859fbf7de8ff28fe Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 10:56:55 +0000 Subject: [PATCH 09/27] fix Signed-off-by: root --- .../debug/features/dump_tensors.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 47ceaf5d4f..aa3debed6c 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -206,13 +206,23 @@ def inspect_tensor( Supports dumping both high-precision tensors and quantized tensors based on config. """ - # Assert that rowwise and columnwise are the same (or one is None) - assert rowwise_quantized_tensor is columnwise_quantized_tensor, ( - "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and " - "columnwise_quantized_tensor to be the same object or both None." - ) + # We support one-sided availability (only rowwise or only columnwise tensor). + # If both are present, require them to be the same object to avoid ambiguity. 
+ if ( + rowwise_quantized_tensor is not None + and columnwise_quantized_tensor is not None + and rowwise_quantized_tensor is not columnwise_quantized_tensor + ): + raise AssertionError( + "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and " + "columnwise_quantized_tensor to be the same object when both are provided." + ) - quantized_tensor = rowwise_quantized_tensor + quantized_tensor = ( + rowwise_quantized_tensor + if rowwise_quantized_tensor is not None + else columnwise_quantized_tensor + ) dump_hp = config.get("high_precision_tensor", False) dump_quant = config.get("quantized_tensor", False) From a525f82ffb312a4f0c1ec3ffd8d0639e93b434fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:57:47 +0100 Subject: [PATCH 10/27] Update transformer_engine/debug/features/dump_tensors.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/debug/features/dump_tensors.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index aa3debed6c..beee08497b 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -88,9 +88,10 @@ def save_tensor( safe_layer_name = self._sanitize_name(layer_name) safe_tensor_name = self._sanitize_name(tensor_name) - filename = f"{safe_layer_name}_{safe_tensor_name}_iter_{iteration:06d}.pt" - filepath = os.path.join(self.root_dir, filename) - + if os.path.exists(filepath): + debug_api.log_message( + f"[TE DumpTensors] Overwriting existing dump file: {filepath}" + ) torch.save(tensor, filepath) From df6605474b638772dc8e14362c51223a5fde7f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 5 Mar 2026 11:57:58 +0100 Subject: [PATCH 11/27] Update transformer_engine/debug/features/dump_tensors.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/debug/features/dump_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index beee08497b..43d15f3d7d 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -197,7 +197,7 @@ def inspect_tensor( tensor_name: str, iteration: int, tp_group: torch.distributed.ProcessGroup, - tensor: torch.Tensor, + tensor: Optional[torch.Tensor], rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, quantizer: Optional[Quantizer] = None, From ab3e90ea050f74d9e966b8e5bccfd80b4dee56d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:59:39 +0000 Subject: [PATCH 12/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- 
transformer_engine/debug/features/dump_tensors.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 43d15f3d7d..901c5687c1 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -89,9 +89,7 @@ def save_tensor( safe_tensor_name = self._sanitize_name(tensor_name) if os.path.exists(filepath): - debug_api.log_message( - f"[TE DumpTensors] Overwriting existing dump file: {filepath}" - ) + debug_api.log_message(f"[TE DumpTensors] Overwriting existing dump file: {filepath}") torch.save(tensor, filepath) From 089a4d2289f3297241e3362c1ecf1802c6462925 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:13:35 +0100 Subject: [PATCH 13/27] Update tests/pytorch/debug/test_log.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- tests/pytorch/debug/test_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 38ee729022..b502b45c64 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -685,7 +685,7 @@ def test_dump_tensors_sanity(feature_dirs): pytest.skip(reason_for_no_fp8) with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: - from transformer_engine.pytorch.quantization import RecipeState + recipe_state = RecipeState.create( recipe_state = RecipeState.create( recipe.DelayedScaling(), From a18664f6fbab67467d338392bbdbfd2a72d21067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 5 Mar 2026 13:13:54 +0100 Subject: [PATCH 14/27] Update transformer_engine/debug/features/dump_tensors.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/debug/features/dump_tensors.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 901c5687c1..b1d3bc31e0 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -87,6 +87,10 @@ def save_tensor( safe_layer_name = self._sanitize_name(layer_name) safe_tensor_name = self._sanitize_name(tensor_name) + filepath = os.path.join( + self.root_dir, + f"{safe_layer_name}_{safe_tensor_name}_iter_{iteration:06d}.pt", + ) if os.path.exists(filepath): debug_api.log_message(f"[TE DumpTensors] Overwriting existing dump file: {filepath}") From 41d17fa7cb32500ffed467f860da76e20db6143c Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 12:18:26 +0000 Subject: [PATCH 15/27] fix Signed-off-by: root --- tests/pytorch/debug/test_log.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index b502b45c64..38c05497c8 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -685,8 +685,6 @@ def 
test_dump_tensors_sanity(feature_dirs): pytest.skip(reason_for_no_fp8) with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: - recipe_state = RecipeState.create( - recipe_state = RecipeState.create( recipe.DelayedScaling(), mode="forward", From 1736cbe9083113c7ca45bce8d01c6ab3afc9c0ae Mon Sep 17 00:00:00 2001 From: root Date: Thu, 5 Mar 2026 14:03:05 +0000 Subject: [PATCH 16/27] fix Signed-off-by: root --- transformer_engine/debug/features/dump_tensors.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index b1d3bc31e0..e2e8aade46 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -295,6 +295,11 @@ def _get_quantized_internals( elif isinstance(quantized_tensor, NVFP4Tensor): tensors = _get_extended_tensors_nvfp4(quantized_tensor) else: + debug_api.log_message( + f"[TE DumpTensors] dump_quantized_internals=True but tensor type " + f"{type(quantized_tensor).__name__} is not supported for internals extraction. " + "Skipping internals." + ) return {} # Filter out None values From b78d36f862bcd18ffbb891a96de92c92fb3bdba9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:03:57 +0000 Subject: [PATCH 17/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/dump_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index e2e8aade46..a6446c77fb 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -296,7 +296,7 @@ def _get_quantized_internals( tensors = _get_extended_tensors_nvfp4(quantized_tensor) else: debug_api.log_message( - f"[TE DumpTensors] dump_quantized_internals=True but tensor type " + "[TE DumpTensors] dump_quantized_internals=True but tensor type " f"{type(quantized_tensor).__name__} is not supported for internals extraction. " "Skipping internals." ) From d98c4d06ebd32aab229f2d6e1437b698b82b3a75 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 11:51:08 +0100 Subject: [PATCH 18/27] Remove dump_quantized_internals support from DumpTensors Drop the dump_quantized_internals config option, the _get_quantized_internals method, and all helper functions for extracting scales/raw data from Float8Tensor, Float8BlockwiseQTensor, MXFP8Tensor, and NVFP4Tensor. Remove corresponding tests: test_dump_tensors_nvfp4_unpacked_codes and NVFP4_DUMP_TENSORS_CONFIG, and scale/data assertions from test_dump_tensors_sanity. 
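The saved dictionary still carries the full quantized tensor object, so its
internals remain reachable offline without the removed option. A minimal
sketch of that workflow, assuming a dump produced by the remaining sanity
test (the file path is illustrative); dequantize() is the same call the
removed NVFP4 test used for verification:

    import torch

    data = torch.load(
        "tensor_dumps/rank_0/test_layer_activation_iter_000000.pt",
        weights_only=False,
    )
    quantized = data["quantized"]
    # Recover high-precision values directly from the quantized object.
    values = quantized.dequantize(dtype=torch.float32)
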
Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 83 --------- .../debug/features/dump_tensors.py | 165 ------------------ 2 files changed, 248 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 38c05497c8..742348bce0 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -658,23 +658,6 @@ def test_compute_max_blockwise_dynamic_range_direct(): tensors: [activation] high_precision_tensor: True quantized_tensor: True - dump_quantized_internals: True - freq: 1 -""" - - -NVFP4_DUMP_TENSORS_CONFIG = """ -dump: - layers: - layer_name_regex_pattern: .* - enabled: True - transformer_engine: - DumpTensors: - enabled: True - tensors: [activation] - high_precision_tensor: False - quantized_tensor: True - dump_quantized_internals: True freq: 1 """ @@ -722,75 +705,9 @@ def test_dump_tensors_sanity(feature_dirs): assert "high_precision" in data, "Missing high_precision tensor" assert "quantized" in data, "Missing quantized tensor" - # Check internals are present (dump_quantized_internals=True) - assert "data" in data, "Missing data (raw FP8 data)" - assert "scale_inv" in data, "Missing scale_inv" - # Verify tensor shapes match assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" print("DumpTensors sanity test passed!") -def test_dump_tensors_nvfp4_unpacked_codes(feature_dirs): - """Verify DumpTensors includes unpacked FP4 values in NVFP4 internals.""" - if not nvfp4_available: - pytest.skip(reason_for_no_nvfp4) - - with debug_session(NVFP4_DUMP_TENSORS_CONFIG, feature_dirs) as log_dir: - recipe_state = RecipeState.create( - recipe.NVFP4BlockScaling(), - mode="forward", - num_quantizers=3, - ) - - tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda() - quantizer = recipe_state.make_quantizers()[0] - quantized_tensor = quantizer(tensor) - - debug_api.transformer_engine.inspect_tensor( - layer_name="test_layer", - tensor_name="activation", - iteration=0, - tp_group=None, - tensor=tensor, - quantizer=quantizer, - rowwise_quantized_tensor=quantized_tensor, - columnwise_quantized_tensor=quantized_tensor, - ) - debug_api.step() - - dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0") - dump_files = os.listdir(dump_dir) - assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" - - data = torch.load(os.path.join(dump_dir, dump_files[0]), weights_only=False) - assert "rowwise_data" in data, "Missing packed NVFP4 rowwise_data" - assert "rowwise_data_unpacked_values" in data, "Missing unpacked NVFP4 rowwise values" - - packed = data["rowwise_data"] - unpacked = data["rowwise_data_unpacked_values"] - assert unpacked.dtype == torch.float32, "Unpacked values must be float32" - assert ( - unpacked.shape[-1] == packed.shape[-1] * 2 - ), "Unpacked values should double the last packed dimension" - assert ( - unpacked.min().item() >= -6.0 and unpacked.max().item() <= 6.0 - ), "Decoded FP4 values should be in representable E2M1 range [-6, 6]" - - # Reconstruct dequantized values from unpacked FP4 values and block scales. - # For NVFP4 rowwise path, one E4M3 scale corresponds to a block of 16 values. 
- assert "rowwise_block_scale_inv" in data, "Missing rowwise NVFP4 block scales" - rowwise_scale_inv = data["rowwise_block_scale_inv"].to(torch.float32) - values = unpacked.to(torch.float32) - n_rows, n_cols = values.shape - scale_tiles = (n_cols + 15) // 16 - expanded_scales = rowwise_scale_inv[:n_rows, :scale_tiles].repeat_interleave(16, dim=1)[ - :, :n_cols - ] - reconstructed = values * expanded_scales - - expected = quantized_tensor.dequantize(dtype=torch.float32) - assert torch.allclose( - reconstructed, expected, atol=1e-5, rtol=1e-3 - ), "Unpacked FP4 values multiplied by block scales should match NVFP4 dequantization" diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index a6446c77fb..eb9dd9ca8a 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -16,12 +16,7 @@ from transformer_engine.debug.features.api import TEConfigAPIMapper from transformer_engine.debug.features.utils import next_enabled_iter -from transformer_engine.pytorch.constants import TE_DType_To_Torch from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor class TensorLogger: @@ -119,10 +114,6 @@ class DumpTensors(TEConfigAPIMapper): If True, dump the high-precision tensor (before quantization). quantized_tensor : bool If True, dump the quantized tensor (after quantization). - dump_quantized_internals : bool, default = False - If True, include extracted internal data from quantized tensors - (raw data, scales, etc.) in the output dictionary. - Useful for offline analysis. Output format may change between versions. tensors/tensors_struct : List[str] list of tensors to dump: - activation @@ -156,7 +147,6 @@ class DumpTensors(TEConfigAPIMapper): - tensor: activation high_precision_tensor: True quantized_tensor: True - dump_quantized_internals: True freq: 100 - tensor: weight high_precision_tensor: True @@ -173,8 +163,6 @@ class DumpTensors(TEConfigAPIMapper): Dictionary keys: - ``high_precision``: pre-quantization tensor (if high_precision_tensor=True) - ``quantized``: quantized tensor object (if quantized_tensor=True) - - Additional internal components when dump_quantized_internals=True - (raw data, scales, etc. 
- format may change between versions) """ @api_method @@ -255,12 +243,6 @@ def inspect_tensor( if dump_quant and quantized_tensor is not None: dump_dict["quantized"] = quantized_tensor - - # Add internals for quantized tensors - if config.get("dump_quantized_internals", False): - internals = self._get_quantized_internals(quantized_tensor) - dump_dict.update(internals) - elif dump_quant and quantized_tensor is None: debug_api.log_message( f"Feature={self.__class__.__name__}: quantized_tensor is True but " @@ -281,150 +263,3 @@ def inspect_tensor( layer_name, ) - def _get_quantized_internals( - self, - quantized_tensor: QuantizedTensor, - ) -> Dict[str, torch.Tensor]: - """Get internal components of quantized tensors (raw data, scales, etc.).""" - if isinstance(quantized_tensor, Float8Tensor): - tensors = _get_extended_tensors_fp8(quantized_tensor) - elif isinstance(quantized_tensor, Float8BlockwiseQTensor): - tensors = _get_extended_tensors_fp8_blockwise(quantized_tensor) - elif isinstance(quantized_tensor, MXFP8Tensor): - tensors = _get_extended_tensors_mxfp8(quantized_tensor) - elif isinstance(quantized_tensor, NVFP4Tensor): - tensors = _get_extended_tensors_nvfp4(quantized_tensor) - else: - debug_api.log_message( - "[TE DumpTensors] dump_quantized_internals=True but tensor type " - f"{type(quantized_tensor).__name__} is not supported for internals extraction. " - "Skipping internals." - ) - return {} - - # Filter out None values - return {k: v for k, v in tensors.items() if v is not None} - - -def _get_extended_tensors_fp8(tensor: Float8Tensor) -> Dict[str, torch.Tensor]: - """Get extended tensors for Float8Tensor: raw FP8 data, transpose, and scale.""" - torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] - result = { - "data": tensor._data.view(torch_fp8_dtype), - "scale_inv": tensor._scale_inv, - } - if tensor._transpose is not None and not tensor._transpose_invalid: - result["transpose"] = tensor._transpose.view(torch_fp8_dtype) - return result - - -def _get_extended_tensors_fp8_blockwise( - tensor: Float8BlockwiseQTensor, -) -> Dict[str, Optional[torch.Tensor]]: - """Get extended tensors for Float8BlockwiseQTensor: raw FP8 data and block scales.""" - torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] - result: Dict[str, Optional[torch.Tensor]] = {} - - if tensor._rowwise_data is not None: - result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) - if tensor._columnwise_data is not None: - result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) - - # Block scaling factors (FP32) - if tensor._rowwise_scale_inv is not None: - result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv - if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv - - return result - - -def _get_extended_tensors_mxfp8(tensor: MXFP8Tensor) -> Dict[str, Optional[torch.Tensor]]: - """Get extended tensors for MXFP8Tensor: raw FP8 data and block scales (E8M0).""" - torch_fp8_dtype = TE_DType_To_Torch[tensor._fp8_dtype] - result: Dict[str, Optional[torch.Tensor]] = {} - - if tensor._rowwise_data is not None: - result["rowwise_data"] = tensor._rowwise_data.view(torch_fp8_dtype) - if tensor._columnwise_data is not None: - result["columnwise_data"] = tensor._columnwise_data.view(torch_fp8_dtype) - - # Block scaling factors (E8M0 format) - if tensor._rowwise_scale_inv is not None: - result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e8m0fnu) - if tensor._columnwise_scale_inv is not None: - 
result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( - torch.float8_e8m0fnu - ) - - return result - - -def _unpack_uint4_codes(packed_data: torch.Tensor) -> torch.Tensor: - """Unpack packed uint4 values stored in uint8 into uint8 tensor with values 0..15.""" - packed_uint8 = packed_data.view(torch.uint8).contiguous().view(-1) - unpacked = torch.empty(packed_uint8.numel() * 2, dtype=torch.uint8, device=packed_data.device) - unpacked[::2] = packed_uint8 & 0x0F - unpacked[1::2] = (packed_uint8 >> 4) & 0x0F - unpacked_shape = (*packed_data.shape[:-1], packed_data.shape[-1] * 2) - return unpacked.view(unpacked_shape) - - -def _decode_uint4_e2m1_to_float(unpacked_codes: torch.Tensor) -> torch.Tensor: - """Decode uint4 FP4 E2M1 codes (0..15) into float32 values.""" - # Bit layout: [sign:1][exp:2][mantissa:1], exponent bias = 1. - # Positive representable magnitudes are: 0, 0.5, 1, 1.5, 2, 3, 4, 6. - fp4_e2m1_lut = torch.tensor( - [ - 0.0, - 0.5, - 1.0, - 1.5, - 2.0, - 3.0, - 4.0, - 6.0, - -0.0, - -0.5, - -1.0, - -1.5, - -2.0, - -3.0, - -4.0, - -6.0, - ], - device=unpacked_codes.device, - dtype=torch.float32, - ) - return fp4_e2m1_lut[unpacked_codes.long()] - - -def _get_extended_tensors_nvfp4(tensor: NVFP4Tensor) -> Dict[str, Optional[torch.Tensor]]: - """Get extended tensors for NVFP4Tensor: raw packed FP4 data, block scales, and amax.""" - result: Dict[str, Optional[torch.Tensor]] = {} - - # Raw data (packed FP4, 2 values per byte) - if tensor._rowwise_data is not None: - result["rowwise_data"] = tensor._rowwise_data - rowwise_codes = _unpack_uint4_codes(tensor._rowwise_data) - result["rowwise_data_unpacked_values"] = _decode_uint4_e2m1_to_float(rowwise_codes) - if tensor._columnwise_data is not None: - result["columnwise_data"] = tensor._columnwise_data - columnwise_codes = _unpack_uint4_codes(tensor._columnwise_data) - result["columnwise_data_unpacked_values"] = _decode_uint4_e2m1_to_float(columnwise_codes) - - # Block scaling factors (E4M3 format) - if tensor._rowwise_scale_inv is not None: - result["rowwise_block_scale_inv"] = tensor._rowwise_scale_inv.view(torch.float8_e4m3fn) - if tensor._columnwise_scale_inv is not None: - result["columnwise_block_scale_inv"] = tensor._columnwise_scale_inv.view( - torch.float8_e4m3fn - ) - - # Input absolute maximum value (used to compute tensor scale) - if tensor._amax_rowwise is not None: - result["amax_rowwise"] = tensor._amax_rowwise - if tensor._amax_columnwise is not None: - result["amax_columnwise"] = tensor._amax_columnwise - - return result From 23c70e513952fb96e2534842c1a41042aed1d8ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 10:52:05 +0000 Subject: [PATCH 19/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 2 -- transformer_engine/debug/features/dump_tensors.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 742348bce0..01505ded19 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -709,5 +709,3 @@ def test_dump_tensors_sanity(feature_dirs): assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" print("DumpTensors sanity test passed!") - - diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index eb9dd9ca8a..abd174f1ee 100644 --- 
a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -262,4 +262,3 @@ def inspect_tensor( f"Dumped {tensor_name} at iteration {iteration} (keys: {list(dump_dict.keys())})", layer_name, ) - From 8357ebeb261249810d18b64c402394449104312f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 12:02:58 +0100 Subject: [PATCH 20/27] Address Greptile review comments - Add dot ('.') to _sanitize_name to handle common PyTorch dotted layer names like 'encoder.layer.0.attention' - Add docstring note about pickle dependency for the 'quantized' key - Add comment explaining weights_only=False in test - Remove redundant local RecipeState import in test_nvfp4_numeric Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 5 ++--- transformer_engine/debug/features/dump_tensors.py | 7 ++++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 01505ded19..4b1daa9148 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -445,9 +445,6 @@ def test_nvfp4_numeric(feature_dirs): log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse") with debug_session(log_nvfp4_config, feature_dirs) as log_dir: - from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer - from transformer_engine.pytorch.quantization import RecipeState - recipe_state = RecipeState.create( recipe.NVFP4BlockScaling(), mode="forward", @@ -699,6 +696,8 @@ def test_dump_tensors_sanity(feature_dirs): # Load and verify structure dump_file = os.path.join(dump_dir, dump_files[0]) + # weights_only=False is required because the dump may contain QuantizedTensor objects, + # which are custom Python classes incompatible with the safe weights_only=True path. data = torch.load(dump_file, weights_only=False) assert isinstance(data, dict), "Dump should be a dictionary" diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index abd174f1ee..3482b33017 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -63,7 +63,7 @@ def ensure_initialized(self, root_log_dir: str) -> None: @staticmethod def _sanitize_name(name: str) -> str: """Sanitize layer/tensor names for use in file paths.""" - for char in ["/", "\\", ":", "*", "?", '"', "<", ">", "|", " "]: + for char in ["/", "\\", ":", "*", "?", '"', "<", ">", "|", " ", "."]: name = name.replace(char, "_") return name @@ -163,6 +163,11 @@ class DumpTensors(TEConfigAPIMapper): Dictionary keys: - ``high_precision``: pre-quantization tensor (if high_precision_tensor=True) - ``quantized``: quantized tensor object (if quantized_tensor=True) + + .. note:: + The ``quantized`` value is a pickled ``QuantizedTensor`` object. Loading it + (with ``weights_only=False``) requires the same version of TransformerEngine + to be installed. For maximum portability, dequantize the tensor before saving. 
""" @api_method From 41c671e6fd8e4cad9d0a6219d84d0e310f23a193 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 12:08:48 +0100 Subject: [PATCH 21/27] Remove portability suggestion from quantized key docstring Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/dump_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 3482b33017..5d37dd5c87 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -167,7 +167,7 @@ class DumpTensors(TEConfigAPIMapper): .. note:: The ``quantized`` value is a pickled ``QuantizedTensor`` object. Loading it (with ``weights_only=False``) requires the same version of TransformerEngine - to be installed. For maximum portability, dequantize the tensor before saving. + to be installed. """ @api_method From 0cd16e5fd37a4a037c569ecdabe3f8d9f043a7f8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 12:14:17 +0100 Subject: [PATCH 22/27] Compute rank lazily in _expected_root_dir Avoids relying on stale self.rank when ensure_initialized is called before initialize() has set the rank. Consistent with how nvdlfw_inspect logger resolves rank. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/dump_tensors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 5d37dd5c87..393d2afc60 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -52,7 +52,8 @@ def initialize(self, root_log_dir: str): def _expected_root_dir(self, root_log_dir: str) -> str: """Return the rank-specific dump directory for the provided root log path.""" - return os.path.join(root_log_dir, "tensor_dumps", f"rank_{self.rank}") + rank = dist.get_rank() if dist.is_initialized() else 0 + return os.path.join(root_log_dir, "tensor_dumps", f"rank_{rank}") def ensure_initialized(self, root_log_dir: str) -> None: """Reinitialize logger if debug session log directory changed.""" From 6f21734e6a19229a04ae333c55a459bbac910544 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 12:38:36 +0100 Subject: [PATCH 23/27] detach tensors before saving; verify dump filename in test Detach both high_precision and quantized tensors before saving to avoid serializing the autograd graph. For QuantizedTensor this is a zero-copy view (make_like), so no extra GPU allocation. Add filename format assertion to test_dump_tensors_sanity to catch regressions in _sanitize_name or the naming convention. 
Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 3 +++ transformer_engine/debug/features/dump_tensors.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 4b1daa9148..d3db6232af 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -693,6 +693,9 @@ def test_dump_tensors_sanity(feature_dirs): dump_files = os.listdir(dump_dir) assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" + assert dump_files[0] == "test_layer_activation_iter_000000.pt", ( + f"Unexpected dump filename: {dump_files[0]}" + ) # Load and verify structure dump_file = os.path.join(dump_dir, dump_files[0]) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 393d2afc60..e5fb654382 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -239,7 +239,7 @@ def inspect_tensor( dump_dict: Dict[str, torch.Tensor] = {} if dump_hp and tensor is not None: - dump_dict["high_precision"] = tensor + dump_dict["high_precision"] = tensor.detach() elif dump_hp and tensor is None: debug_api.log_message( f"Feature={self.__class__.__name__}: high_precision_tensor is True but " @@ -248,7 +248,7 @@ def inspect_tensor( ) if dump_quant and quantized_tensor is not None: - dump_dict["quantized"] = quantized_tensor + dump_dict["quantized"] = quantized_tensor.detach() elif dump_quant and quantized_tensor is None: debug_api.log_message( f"Feature={self.__class__.__name__}: quantized_tensor is True but " From 7d368118eeaf5ebf2c55b074a85ab4bee031d756 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 11:39:25 +0000 Subject: [PATCH 24/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index d3db6232af..1685c1bf37 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -693,9 +693,9 @@ def test_dump_tensors_sanity(feature_dirs): dump_files = os.listdir(dump_dir) assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}" - assert dump_files[0] == "test_layer_activation_iter_000000.pt", ( - f"Unexpected dump filename: {dump_files[0]}" - ) + assert ( + dump_files[0] == "test_layer_activation_iter_000000.pt" + ), f"Unexpected dump filename: {dump_files[0]}" # Load and verify structure dump_file = os.path.join(dump_dir, dump_files[0]) From c7b7f014906a58426a59fd70398845f3b2ce0773 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 10 Mar 2026 13:33:16 +0100 Subject: [PATCH 25/27] Add empty dump_dict log; assert QuantizedTensor type in test Log a message when no tensors are available to dump so the user has an explicit signal that no file was written. Assert that the quantized key round-trips as a QuantizedTensor to catch regressions in detach() or serialisation path. 
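The isinstance assertion guards against a subtle regression: a detach() override (or a change in the serialization path) that drops the tensor subclass would make the dumped "quantized" entry silently load back as a plain torch.Tensor. A hypothetical sketch of the invariant (class name is illustrative only):

    import torch

    class FakeQuantizedTensor(torch.Tensor):
        pass  # stand-in for a QuantizedTensor subclass

    t = torch.randn(4).as_subclass(FakeQuantizedTensor)
    # Default detach() preserves the subclass, so the round-trip type is stable:
    assert isinstance(t.detach(), FakeQuantizedTensor)
    # A detach() that returned a plain torch.Tensor would break this, and the
    # new isinstance check in test_dump_tensors_sanity would flag it.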
Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 4 ++++ transformer_engine/debug/features/dump_tensors.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 1685c1bf37..051a820692 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -18,6 +18,7 @@ is_nvfp4_available, ) from transformer_engine.pytorch.quantization import RecipeState +from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.features.utils.stats_computation import ( compute_max_blockwise_dynamic_range, @@ -706,6 +707,9 @@ def test_dump_tensors_sanity(feature_dirs): assert isinstance(data, dict), "Dump should be a dictionary" assert "high_precision" in data, "Missing high_precision tensor" assert "quantized" in data, "Missing quantized tensor" + assert isinstance(data["quantized"], QuantizedTensor), ( + f"Expected QuantizedTensor, got {type(data['quantized'])}" + ) # Verify tensor shapes match assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index e5fb654382..970e85fa9a 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -268,3 +268,9 @@ def inspect_tensor( f"Dumped {tensor_name} at iteration {iteration} (keys: {list(dump_dict.keys())})", layer_name, ) + else: + debug_api.log_message( + f"Feature={self.__class__.__name__}: No tensors available to dump for " + f"{tensor_name} at iteration {iteration}. No file written.", + layer_name, + ) From 2fcd7eb8b75f249678ef4b23b3bc22dfd469315a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:34:12 +0000 Subject: [PATCH 26/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 051a820692..0f7f465a52 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -707,9 +707,9 @@ def test_dump_tensors_sanity(feature_dirs): assert isinstance(data, dict), "Dump should be a dictionary" assert "high_precision" in data, "Missing high_precision tensor" assert "quantized" in data, "Missing quantized tensor" - assert isinstance(data["quantized"], QuantizedTensor), ( - f"Expected QuantizedTensor, got {type(data['quantized'])}" - ) + assert isinstance( + data["quantized"], QuantizedTensor + ), f"Expected QuantizedTensor, got {type(data['quantized'])}" # Verify tensor shapes match assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" From 677ad512ec8698b7581144f94b030ec06454e7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:42:10 +0100 Subject: [PATCH 27/27] Update transformer_engine/debug/features/dump_tensors.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- 
transformer_engine/debug/features/dump_tensors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/dump_tensors.py b/transformer_engine/debug/features/dump_tensors.py index 970e85fa9a..1ce195672e 100644 --- a/transformer_engine/debug/features/dump_tensors.py +++ b/transformer_engine/debug/features/dump_tensors.py @@ -210,7 +210,7 @@ def inspect_tensor( and columnwise_quantized_tensor is not None and rowwise_quantized_tensor is not columnwise_quantized_tensor ): - raise AssertionError( + raise ValueError( "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and " "columnwise_quantized_tensor to be the same object when both are provided." )
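End to end, a dump produced by this series can be consumed as in the sketch below, based on the paths and dictionary keys used in the tests above. Here log_dir is a placeholder for the debug session's log directory, and unpickling the "quantized" entry requires weights_only=False plus a matching TransformerEngine installation:

    import os
    import torch

    dump_file = os.path.join(log_dir, "tensor_dumps", "rank_0",
                             "test_layer_activation_iter_000000.pt")
    data = torch.load(dump_file, weights_only=False)  # dict of dumped tensors

    hp = data.get("high_precision")  # pre-quantization tensor, if dumped
    q = data.get("quantized")        # QuantizedTensor object, if dumped
    if hp is not None and q is not None:
        err = (hp.float() - q.dequantize(dtype=torch.float32)).abs().max()
        print(f"max quantization error: {err.item():.6f}")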