Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
fe437c1
code drop
pggPL Feb 3, 2026
a54a743
code drop
pggPL Feb 3, 2026
b6e0767
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 3, 2026
76d362c
Merge branch 'main' into inpsect_tensor_dump_support
pggPL Mar 5, 2026
dc60fe8
docs
pggPL Mar 5, 2026
e94467f
nvfp4 internals support
pggPL Mar 5, 2026
e8c8e56
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
b002b89
lint fixes
pggPL Mar 5, 2026
2816f37
Update transformer_engine/debug/features/dump_tensors.py
pggPL Mar 5, 2026
83506af
fix
pggPL Mar 5, 2026
a525f82
Update transformer_engine/debug/features/dump_tensors.py
pggPL Mar 5, 2026
df66054
Update transformer_engine/debug/features/dump_tensors.py
pggPL Mar 5, 2026
ab3e90e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
089a4d2
Update tests/pytorch/debug/test_log.py
pggPL Mar 5, 2026
a18664f
Update transformer_engine/debug/features/dump_tensors.py
pggPL Mar 5, 2026
41d17fa
fix
pggPL Mar 5, 2026
1736cbe
fix
pggPL Mar 5, 2026
b78d36f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2026
d98c4d0
Remove dump_quantized_internals support from DumpTensors
pggPL Mar 10, 2026
23c70e5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
8357ebe
Address Greptile review comments
pggPL Mar 10, 2026
41c671e
Remove portability suggestion from quantized key docstring
pggPL Mar 10, 2026
0cd16e5
Compute rank lazily in _expected_root_dir
pggPL Mar 10, 2026
6f21734
detach tensors before saving; verify dump filename in test
pggPL Mar 10, 2026
7d36811
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
c7b7f01
Add empty dump_dict log; assert QuantizedTensor type in test
pggPL Mar 10, 2026
2fcd7eb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 10, 2026
677ad51
Update transformer_engine/debug/features/dump_tensors.py
pggPL Mar 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/debug/3_api_features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ Debug features
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
.. autoapiclass:: transformer_engine.debug.features.dump_tensors.DumpTensors
77 changes: 74 additions & 3 deletions tests/pytorch/debug/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_computation import (
compute_max_blockwise_dynamic_range,
Expand Down Expand Up @@ -445,9 +446,6 @@ def test_nvfp4_numeric(feature_dirs):
log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")

with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantization import RecipeState

recipe_state = RecipeState.create(
recipe.NVFP4BlockScaling(),
mode="forward",
Expand Down Expand Up @@ -644,3 +642,76 @@ def test_compute_max_blockwise_dynamic_range_direct():
)

print("All direct tests for compute_max_blockwise_dynamic_range passed!")


# DumpTensors tests
DUMP_TENSORS_CONFIG = """
dump:
layers:
layer_name_regex_pattern: .*
enabled: True
transformer_engine:
DumpTensors:
enabled: True
tensors: [activation]
high_precision_tensor: True
quantized_tensor: True
freq: 1
"""


def test_dump_tensors_sanity(feature_dirs) -> None:
    """Sanity test for DumpTensors feature - verify files are created with correct structure."""
    # DumpTensors pickles quantized tensors, so an FP8-capable device is required.
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)

    with debug_session(DUMP_TENSORS_CONFIG, feature_dirs) as log_dir:
        # Build a real DelayedScaling quantizer so the dumped "quantized" entry
        # is a genuine QuantizedTensor, not a plain torch.Tensor.
        recipe_state = RecipeState.create(
            recipe.DelayedScaling(),
            mode="forward",
            num_quantizers=3,
        )

        tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
        quantizer = recipe_state.make_quantizers()[0]
        quantized_tensor = quantizer(tensor)

        # Feed the same object as both rowwise and columnwise quantized tensor;
        # DumpTensors requires them to be identical when both are provided.
        debug_api.transformer_engine.inspect_tensor(
            layer_name="test_layer",
            tensor_name="activation",
            iteration=0,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=quantized_tensor,
            columnwise_quantized_tensor=quantized_tensor,
        )
        debug_api.step()

        # Check that dump file was created
        dump_dir = os.path.join(log_dir, "tensor_dumps", "rank_0")
        assert os.path.exists(dump_dir), f"Dump directory not created: {dump_dir}"

        dump_files = os.listdir(dump_dir)
        assert len(dump_files) == 1, f"Expected 1 dump file, got {len(dump_files)}"
        # Filename format is {layer}_{tensor}_iter_{iteration:06d}.pt
        assert (
            dump_files[0] == "test_layer_activation_iter_000000.pt"
        ), f"Unexpected dump filename: {dump_files[0]}"

        # Load and verify structure
        dump_file = os.path.join(dump_dir, dump_files[0])
        # weights_only=False is required because the dump may contain QuantizedTensor objects,
        # which are custom Python classes incompatible with the safe weights_only=True path.
        data = torch.load(dump_file, weights_only=False)

        assert isinstance(data, dict), "Dump should be a dictionary"
        assert "high_precision" in data, "Missing high_precision tensor"
        assert "quantized" in data, "Missing quantized tensor"
        assert isinstance(
            data["quantized"], QuantizedTensor
        ), f"Expected QuantizedTensor, got {type(data['quantized'])}"

        # Verify tensor shapes match
        assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch"

        print("DumpTensors sanity test passed!")
276 changes: 276 additions & 0 deletions transformer_engine/debug/features/dump_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DumpTensors Feature support for nvidia-dlframework-inspect."""

import os
from typing import Dict, Optional

import torch
import torch.distributed as dist

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.logging import get_logger
from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.debug.features.utils import next_enabled_iter
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer


class TensorLogger:
    """Process-wide singleton that saves tensors to files via ``torch.save``.

    In distributed runs every rank writes into its own subdirectory
    (``{root_log_dir}/tensor_dumps/rank_{rank}``) so files never collide.
    """

    _instance = None  # the single shared instance
    _initialized = False  # guards __init__ against re-running on repeated construction

    # Maps every path-unsafe character (plus spaces and dots) to "_";
    # one str.translate pass instead of a chain of str.replace calls.
    _SANITIZE_TABLE = str.maketrans(
        {char: "_" for char in ["/", "\\", ":", "*", "?", '"', "<", ">", "|", " ", "."]}
    )

    def __new__(cls):
        # Classic singleton: every TensorLogger() call yields the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if TensorLogger._initialized:
            return
        self.root_dir: Optional[str] = None  # set by initialize()
        self.rank: int = 0
        TensorLogger._initialized = True

    @staticmethod
    def _current_rank() -> int:
        """Rank of this process, or 0 when torch.distributed is not initialized."""
        return dist.get_rank() if dist.is_initialized() else 0

    def initialize(self, root_log_dir: str):
        """Initialize the TensorLogger with the root directory for tensor dumps."""
        # Compute the rank once and reuse it for both the attribute and the
        # dump path (previously it was computed here and again in
        # _expected_root_dir).
        self.rank = self._current_rank()
        self.root_dir = self._expected_root_dir(root_log_dir)
        os.makedirs(self.root_dir, exist_ok=True)

        debug_api.log_message(
            f"TensorLogger initialized. Saving tensors to: {self.root_dir}",
        )

    def _expected_root_dir(self, root_log_dir: str) -> str:
        """Return the rank-specific dump directory for the provided root log path."""
        # Rank is looked up lazily so this is correct even before
        # torch.distributed is initialized.
        return os.path.join(root_log_dir, "tensor_dumps", f"rank_{self._current_rank()}")

    def ensure_initialized(self, root_log_dir: str) -> None:
        """Reinitialize logger if debug session log directory changed."""
        expected_root_dir = self._expected_root_dir(root_log_dir)
        # Also re-create the directory if it was removed between steps.
        if self.root_dir != expected_root_dir or not os.path.isdir(expected_root_dir):
            self.initialize(root_log_dir)

    @staticmethod
    def _sanitize_name(name: str) -> str:
        """Sanitize layer/tensor names for use in file paths."""
        return name.translate(TensorLogger._SANITIZE_TABLE)

    def save_tensor(
        self,
        tensor,
        layer_name: str,
        tensor_name: str,
        iteration: int,
    ):
        """Save a tensor (or dict of tensors) to a file.

        The file is ``{layer}_{tensor}_iter_{iteration:06d}.pt`` inside this
        rank's dump directory; an existing file is overwritten (with a log
        message).

        Raises
        ------
        RuntimeError
            If initialize() has not been called yet.
        """
        if self.root_dir is None:
            raise RuntimeError(
                "[TE DumpTensors] TensorLogger not initialized. Call initialize() first."
            )

        safe_layer_name = self._sanitize_name(layer_name)
        safe_tensor_name = self._sanitize_name(tensor_name)
        filepath = os.path.join(
            self.root_dir,
            f"{safe_layer_name}_{safe_tensor_name}_iter_{iteration:06d}.pt",
        )

        if os.path.exists(filepath):
            debug_api.log_message(f"[TE DumpTensors] Overwriting existing dump file: {filepath}")
        torch.save(tensor, filepath)


def _get_tensor_logger() -> TensorLogger:
    """Return the process-wide TensorLogger singleton (created on first use)."""
    shared_logger = TensorLogger()
    return shared_logger


@Registry.register_feature(namespace="transformer_engine")
class DumpTensors(TEConfigAPIMapper):
    """
    Dump tensors to files for debugging purposes.

    This feature saves tensors to disk using torch.save(). It supports dumping
    both high-precision tensors (before quantization) and quantized tensors.

    Each tensor is saved to a separate file with the iteration number, layer name,
    and tensor name in the filename. Files are organized per-rank in distributed settings.

    Parameters
    ----------
    high_precision_tensor : bool
        If True, dump the high-precision tensor (before quantization).
    quantized_tensor : bool
        If True, dump the quantized tensor (after quantization).
    tensors/tensors_struct : List[str]
        list of tensors to dump:
            - activation
            - gradient
            - weight
            - output
            - wgrad
            - dgrad
    freq : Optional[int], default = 1
        frequency of dumping tensors, tensors will be dumped every `freq` steps
    start_step : Optional[int], default = 0
        start step of dumping tensors
    end_step : Optional[int], default = -1
        end step of dumping tensors (-1 means no end)
    start_end_list : Optional[list([int, int])], default = None
        non-overlapping list of (start, end) pairs in incremental order.
        If not None, will ignore start_step and end_step

    Example
    -------
    .. code-block:: yaml

        dump_tensors_example:
          enabled: True
          layers:
            layer_name_regex_pattern: .*(fc1|self_attention).*
          transformer_engine:
            DumpTensors:
              enabled: True
              tensors_struct:
                - tensor: activation
                  high_precision_tensor: True
                  quantized_tensor: True
                  freq: 100
                - tensor: weight
                  high_precision_tensor: True
                  quantized_tensor: False
                  freq: 500

    Output Structure
    ----------------
    Files are saved to: ``{nvdlfw_inspect_log_dir}/tensor_dumps/rank_{rank}/``

    Each tensor is saved as a dictionary in a single file:
    ``{layer}_{tensor}_iter_{iter:06d}.pt``

    Dictionary keys:
        - ``high_precision``: pre-quantization tensor (if high_precision_tensor=True)
        - ``quantized``: quantized tensor object (if quantized_tensor=True)

    .. note::
        The ``quantized`` value is a pickled ``QuantizedTensor`` object. Loading it
        (with ``weights_only=False``) requires the same version of TransformerEngine
        to be installed.
    """

    @api_method
    def inspect_tensor_enabled(
        self, config: Dict, layer_name: str, tensor_name: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call used to determine whether to run inspect_tensor() in the forward."""
        # Delegate the schedule decision (start/end window + frequency) and
        # forward the (run_current, next_iter) pair unchanged.
        return next_enabled_iter(
            config.get("start_step", None),
            config.get("end_step", None),
            config.get("start_end_list", None),
            config.get("freq", 1),
            iteration,
        )

    @api_method
    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
        tensor: Optional[torch.Tensor],
        rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
        columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
        quantizer: Optional[Quantizer] = None,
    ):  # pylint: disable=unused-argument
        """
        API call used to dump tensors to files.

        Supports dumping both high-precision tensors and quantized tensors based on config.
        """
        # One-sided availability (only rowwise or only columnwise) is fine.
        # When both sides are passed they must be the very same object,
        # otherwise it would be ambiguous which one to dump.
        both_provided = (
            rowwise_quantized_tensor is not None and columnwise_quantized_tensor is not None
        )
        if both_provided and rowwise_quantized_tensor is not columnwise_quantized_tensor:
            raise ValueError(
                "[NVTORCH INSPECT ERROR] DumpTensors expects rowwise_quantized_tensor and "
                "columnwise_quantized_tensor to be the same object when both are provided."
            )

        quantized = rowwise_quantized_tensor
        if quantized is None:
            quantized = columnwise_quantized_tensor

        want_hp = config.get("high_precision_tensor", False)
        want_quant = config.get("quantized_tensor", False)

        # Nothing requested at all -> log once and bail out early.
        if not (want_hp or want_quant):
            debug_api.log_message(
                f"Feature={self.__class__.__name__}: Neither high_precision_tensor nor "
                "quantized_tensor is enabled. Nothing to dump.",
                layer_name,
            )
            return

        logger = _get_tensor_logger()
        logger.ensure_initialized(get_logger().root_log_dir)

        # Collect everything to dump into one dictionary / one file.
        payload: Dict[str, torch.Tensor] = {}

        if want_hp:
            if tensor is not None:
                payload["high_precision"] = tensor.detach()
            else:
                debug_api.log_message(
                    f"Feature={self.__class__.__name__}: high_precision_tensor is True but "
                    f"no high-precision tensor available for {tensor_name}. Skipping.",
                    layer_name,
                )

        if want_quant:
            if quantized is not None:
                payload["quantized"] = quantized.detach()
            else:
                debug_api.log_message(
                    f"Feature={self.__class__.__name__}: quantized_tensor is True but "
                    f"no quantized tensor available for {tensor_name}. Skipping.",
                    layer_name,
                )

        if not payload:
            # Requested tensors were all unavailable; record that no file was written.
            debug_api.log_message(
                f"Feature={self.__class__.__name__}: No tensors available to dump for "
                f"{tensor_name} at iteration {iteration}. No file written.",
                layer_name,
            )
            return

        logger.save_tensor(
            tensor=payload,
            layer_name=layer_name,
            tensor_name=tensor_name,
            iteration=iteration,
        )
        debug_api.log_message(
            f"Feature={self.__class__.__name__}, API=inspect_tensor: "
            f"Dumped {tensor_name} at iteration {iteration} (keys: {list(payload.keys())})",
            layer_name,
        )
3 changes: 1 addition & 2 deletions transformer_engine/debug/features/log_fp8_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine_torch as tex

from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
import transformer_engine_torch as tex

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
Expand Down
Loading