Skip to content

Conversation

@vthumbe1503
Copy link
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 2 commits January 5, 2026 18:11
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
vthumbe1503 and others added 4 commits January 6, 2026 12:34
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review January 7, 2026 17:22
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 7, 2026

Greptile Summary

This PR introduces CPU-side optimizations for FP8 tensor operations, reducing overhead from repeated function calls and Python object attribute access.

Key Changes:

  • Cached nvte_is_non_tn_fp8_gemm_supported() result in cublaslt_gemm.cu to avoid repeated function calls during GEMM configuration
  • Made num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Replaced pybind11 keyword arguments with direct Python C API calls in quantizer.cpp for lower overhead
  • Added cached dtype, requires_grad, shape, and is_cuda properties to QuantizedTensor and subclasses to avoid expensive PyObject attribute lookups
  • Cached requires_grad values at the start of _Linear.forward() to reduce repeated attribute access
  • Used std::call_once for thread-safe extension initialization in pybind.cpp

Architecture Notes:

  • The Python C API usage in quantizer.cpp is technically correct - PyDict_SetItemString increments reference counts before temporaries are destroyed
  • The requires_grad logic change in linear.py correctly preserves OR semantics (matching the original requires_grad() helper function behavior)

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - the optimizations are well-implemented and maintain existing semantics.
  • Score of 4 reflects solid CPU optimizations with correct implementation. Minor deductions for: (1) the Python C API pattern in quantizer.cpp could leak memory if exceptions occur mid-execution, though this is unlikely in practice; (2) is_cuda property in Float8Tensor lacks defensive check that shape property has.
  • quantizer.cpp - review exception safety of Python C API usage; float8_tensor.py - is_cuda property edge case

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_gemm.cu CPU optimization caching nvte_is_non_tn_fp8_gemm_supported() result to avoid repeated function calls. Clean implementation.
transformer_engine/pytorch/csrc/quantizer.cpp Replaced pybind11 keyword arguments with direct Python C API calls for performance. Multiple potential memory leak concerns if exceptions occur.
transformer_engine/pytorch/module/linear.py Cached requires_grad values at start of forward pass to avoid repeated attribute access. Correctly preserved OR semantics.
transformer_engine/pytorch/quantized_tensor.py Added cached dtype and requires_grad properties. Proper initialization in __new__ and correct sync with parent tensor.
transformer_engine/pytorch/tensor/float8_tensor.py Added cached shape and is_cuda properties. is_cuda lacks defensive check when both _data and _transpose are None.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as _Linear.forward()
    participant QTensor as QuantizedTensor
    participant Quantizer as quantizer.cpp
    participant PyAPI as Python C API
    participant GEMM as cublaslt_gemm.cu

    User->>Linear: forward(inp, weight, bias)
    Note over Linear: Cache requires_grad values early<br/>(inp_requires_grad, weight_requires_grad, bias_requires_grad)
    
    Linear->>QTensor: Access dtype/requires_grad
    Note over QTensor: Return cached _dtype/_requires_grad<br/>(avoids PyObject lookup)
    
    Linear->>Quantizer: create_tensor()
    Note over Quantizer: Cache nvte_is_non_tn_fp8_gemm_supported()
    Quantizer->>PyAPI: PyDict_New(), PyDict_SetItemString()
    Note over PyAPI: Direct C API bypasses<br/>pybind11 overhead
    PyAPI-->>Quantizer: Float8Tensor instance
    
    Linear->>GEMM: Execute GEMM
    Note over GEMM: Use cached is_nvte_non_tn_fp8_gemm_supported<br/>(single call per function invocation)
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (3)

  1. transformer_engine/pytorch/csrc/util.cpp, line 18-20 (link)

    logic: Critical logical error: || should be &&. This condition will always betruesince a value cannot simultaneously be both scaling modes, causing the function to always return nullopt for valid inputs.

  2. transformer_engine/pytorch/quantized_tensor.py, line 373-393 (link)

    style: commented-out code for requires_grad caching optimization - consider removing dead code entirely. Is this code planned to be implemented later or should it be removed?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  3. transformer_engine/pytorch/module/linear.py, line 484 (link)

    logic: Logical error: this condition should use OR (||) not AND (&&). The original logic was checking if ANY tensor requires gradients for FP8 handling, but this now only activates when ALL three require gradients, including bias which may be None.

    Should the FP8 condition check if any tensor requires gradients (OR logic) rather than all tensors (AND logic)?

10 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Greptile Overview

Greptile Summary

This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.

Key Changes:

  • Caches requires_grad, dtype, shape, and is_cuda attribute accesses to avoid expensive PyObject lookups on custom tensors
  • Reorders attribute checks in get_tensor_device() to prioritize internal quantized tensor attributes
  • Makes num_devices static in nvte_is_non_tn_fp8_gemm_supported() to cache device count
  • Stores GEMM support check results in local variables to avoid redundant function calls

Critical Issues Found:

  • Variable redeclaration error in cublaslt_gemm.cu (line 224) will prevent compilation
  • Logic bug in linear.py (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad

Confidence Score: 0/5

  • This PR cannot be merged due to compilation error and critical logic bug
  • Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
  • Pay close attention to transformer_engine/common/gemm/cublaslt_gemm.cu (compilation error) and transformer_engine/pytorch/module/linear.py (logic bug)

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/gemm/cublaslt_gemm.cu 1/5 Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure
transformer_engine/common/transformer_engine.cpp 5/5 Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization
transformer_engine/pytorch/module/linear.py 0/5 Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions

    Note over Linear,CPP: Performance Optimization Flow
    
    User->>Linear: forward(input, weight, bias)
    
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    
    Linear-->>User: output
Loading

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 7, 2026

Additional Comments (2)

transformer_engine/common/gemm/cublaslt_gemm.cu
variable redeclared in same scope - already declared at line 132

    // int is_nvte_non_tn_fp8_gemm_supported already declared at line 132

transformer_engine/pytorch/module/linear.py
logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True, breaking FP8 state management when bias is None or doesn't require grad

            if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile

ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.

The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

Comment on lines 165 to 166
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Memory leak: PyTuple_New(0) creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to PyObject_Call which doesn't steal the reference, so it must be explicitly released.

This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Invalid parameter: The Float8TensorStorage.__new__ method does not accept a device parameter. Looking at the Python class definition in float8_tensor_storage.py, the __new__ signature only expects: data, fp8_scale_inv, fp8_dtype, data_transpose, and quantizer.

This will cause a runtime error when trying to create Float8TensorStorage instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

Comment on lines 374 to 378
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
return self._requires_grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: The requires_grad property getter accesses self._requires_grad at line 378, but this member variable is never initialized in the __new__ method (lines 353-364). This will cause an AttributeError when the property is first accessed.

The _requires_grad attribute must be initialized in __new__:

Suggested change
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
return self._requires_grad
@property
def requires_grad(self) -> bool:
# Attribute access of custom tensors goes through an
# expensive Pyobject lookup. Since requires_grad is set during
# initialization and may be updated, we cache it in a member variable.
# Fallback to parent if not cached yet
if not hasattr(self, '_requires_grad'):
self._requires_grad = torch.Tensor.requires_grad.fget(self)
return self._requires_grad

Comment on lines +363 to 364
instance._dtype = dtype
return instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing initialization: The _requires_grad member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the requires_grad property will fail.

Suggested change
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance

PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.

The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:

Suggested change
bool is_extension_initialized = false;
bool is_extension_initialized = false;
std::once_flag extension_init_flag;

And in init_extension():

void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}

Comment on lines +964 to 966
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration

Comment on lines 146 to 150
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: scale_inv_py is created on line 146 by dereferencing scale_inv (which is std::optional<at::Tensor>), but scale_inv is only initialized on lines 148-150 if it's empty. When scale_inv is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.

The fix is to move line 146 after the initialization block:

Suggested change
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
scale_inv = at::reciprocal(scale);
}
py::object scale_inv_py = py::cast(*scale_inv);

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Memory leak: PyTuple_New(0) creates a new tuple with reference count 1, but PyObject_Call does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).

Suggested change
PyTuple_New(0), kwargs);
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs);
Py_DECREF(args);

Comment on lines +363 to 364
instance._dtype = dtype
return instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: _dtype is cached in __new__ but _requires_grad is not initialized. The requires_grad property (line 378) tries to return self._requires_grad, which will raise AttributeError when accessed before the setter is called. This breaks the optimization for caching requires_grad.

Suggested change
instance._dtype = dtype
return instance
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance

vthumbe1503 and others added 3 commits January 11, 2026 19:12
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 13 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.

Suggested change
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Use-after-free bug: scale_inv_py is created from scale_inv before scale_inv is initialized. When scale_inv parameter is empty (lines 148-150), this dereferences an empty std::optional, causing undefined behavior. Move this line after the scale_inv initialization.

Suggested change
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Move scale_inv_py creation after initialization (line 150)

PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add Py_DECREF for the tuple:

Suggested change
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same issue as line 166).

Suggested change
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);
Py_DECREF(args);


ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Logic error: The original requires_grad(inp, weight, bias) returns True if ANY tensor requires grad. This change requires ALL THREE to be True with and, breaking FP8 state management when bias is None or doesn't require grad.

Suggested change
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) is not decremented (same pattern as line 166).

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Variable is_nvte_non_tn_fp8_gemm_supported is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.

}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Duplicate line. transpose_py is assigned twice (lines 145 and 146). Remove this duplicate line.

PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
PyTuple_New(0), kwargs);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] Memory leak: PyTuple_New(0) creates a new Python object that is never decremented. Store the result and call Py_DECREF(tuple) after PyObject_Call. This pattern repeats in all 12 calls to PyObject_Call throughout this file (lines 167, 186, 331, 350, 397, 416, 706, 725, 1042, 1061, 1347, 1366).

vthumbe1503 and others added 2 commits January 11, 2026 20:04
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +156 to +171
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potential memory leak if py::cast() throws exception before Py_DECREF

If any of the py::cast(this->dtype), py::cast(shape_int64), or py::cast(device) calls throw an exception, the kwargs and args PyObjects will leak because Py_DECREF on lines 168-169 will never be reached.

Suggested change
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
// Use RAII for Python object lifetime management
py::dict kwargs_py;
kwargs_py["data"] = data_py;
kwargs_py["fp8_scale_inv"] = scale_inv_py;
kwargs_py["fp8_dtype"] = py::cast(this->dtype);
kwargs_py["data_transpose"] = transpose_py;
kwargs_py["quantizer"] = this->quantizer;
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
args, kwargs_py.ptr());
Py_DECREF(args);

This same pattern repeats throughout the file at lines 176-193, 382-400, 411-428, 657-678, 686-707, 1013-1034, 1042-1063, 1304-1325, and 1333-1354. All instances should use pybind11's py::dict for exception safety.

Comment on lines +151 to +152
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

device defaults to cuda:0 when both data and transpose are None, may not be current device

In multi-GPU setups, if neither with_data nor with_transpose are true, the device defaults to torch::kCUDA (which implicitly converts to cuda:0). This may not match the current CUDA device, potentially causing cross-device operation issues.

Suggested change
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA);
at::Device device =
with_data ? data->device() : (with_transpose ? transpose->device() : at::Device(torch::kCUDA, torch::cuda::current_device()));

This same issue exists at line 383 in Float8CurrentScalingQuantizer::create_tensor.

Comment on lines 79 to 89
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reordering attribute checks changes behavior when devices don't match

The new code checks quantized tensor attributes (_rowwise_data, _columnwise_data, etc.) before checking the standard .device attribute. If a tensor has both and they refer to different devices, this changes which device is returned.

For example, if tensor.device == cuda:1 but tensor._rowwise_data.device == cuda:0, the old code would return 1 but the new code returns 0. This could break multi-GPU code that relies on the .device attribute being authoritative.

Consider documenting this behavior change or adding a check that all device attributes match if multiple are present.

Comment on lines 367 to 383
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype

@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_dtype and _requires_grad only initialized in __new__, will cause AttributeError in other creation paths

The new dtype and requires_grad properties rely on cached _dtype and _requires_grad attributes, but these are only initialized in the __new__ method (lines 363-364). If a QuantizedTensor is created through other paths such as:

  • Unpickling (__reduce_ex__, __setstate__)
  • torch.Tensor subclass operations that bypass __new__
  • FSDP state dict loading

Accessing .dtype or .requires_grad will raise AttributeError: '_dtype' not found.

Suggested change
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype
@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
"""
# Lazy initialization for tensors created via alternate paths
if not hasattr(self, '_dtype'):
self._dtype = super().__getattribute__('dtype')
return self._dtype

Apply similar fix to requires_grad property.

Comment on lines 913 to 916
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape property will raise AttributeError if both _data and _transpose are None

The property assumes at least one of _data or _transpose is not None, but doesn't validate the _transpose attribute before accessing it. If both are None (edge case during initialization or after certain operations), accessing .shape will raise AttributeError: 'NoneType' object has no attribute 'shape'.

Suggested change
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._data is not None:
return self._data.shape
elif self._transpose is not None:
return self._transpose.shape
else:
# Fallback to getting shape from alternate data sources or raise descriptive error
raise RuntimeError("Float8Tensor has no data buffers initialized")

Same issue exists in the is_cuda property at lines 919-921.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

6 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile

ctx.tp_size = tp_size
ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.requires_dgrad = inp_requires_grad
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic change from requires_grad(inp, weight, bias) to (inp_requires_grad or weight_requires_grad or bias_requires_grad) is correct for the optimization, but this breaks the FP8 state management semantic.

The original requires_grad() function returns True if ANY tensor requires grad (see utils.py line 23-28). The new code uses and instead of or:

if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):

However, looking more carefully at the actual change, the new code uses or, which is semantically equivalent to the old requires_grad(inp, weight, bias) call. This is actually correct!

But wait - I need to double-check the previous thread comment. The previous thread says "logic change from original requires_grad(inp, weight, bias) which returns True if ANY tensor requires grad. New code requires ALL THREE to be True". Let me verify this is indeed using or not and...

Comment on lines +367 to +375
@property
def dtype(self) -> torch.dtype:
"""
Return the high precision data type of the tensor
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since dtype for a tensor is never
change after creation, we cache it in a member variable and return
"""
return self._dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype property accesses self._dtype without checking if it exists. While QuantizedTensor.__new__() initializes this at line 364, subclasses like Float8Tensor can be instantiated through paths that may not properly initialize these cached attributes.

For example, when Float8TensorStorage.__new__() calls super().__new__(cls, *args, **kwargs), if dtype is not in kwargs, QuantizedTensor.__new__() won't receive it and won't set instance._dtype.

Consider adding defensive checks or ensuring all instantiation paths properly initialize _dtype and _requires_grad, or catch AttributeError and fall back to the parent class property.

Comment on lines +157 to +171
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential memory leak if any operation between PyDict_New()/PyTuple_New() and Py_DECREF() throws an exception or fails. The py::cast() calls at lines 161-163 could throw C++ exceptions, and PyDict_SetItemString() could fail, leaving kwargs and args unreleased.

Consider using RAII wrappers or try-catch blocks to ensure cleanup:

PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
if (!kwargs || !args) {
    Py_XDECREF(kwargs);
    Py_XDECREF(args);
    throw std::runtime_error("Failed to allocate Python objects");
}

try {
    PyDict_SetItemString(kwargs, "data", data_py.ptr());
    // ... rest of setup
    PyObject* result = PyObject_Call(...);
    Py_DECREF(kwargs);
    Py_DECREF(args);
    NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
    out_py = py::reinterpret_steal<py::object>(result);
} catch (...) {
    Py_DECREF(kwargs);
    Py_DECREF(args);
    throw;
}

Or better yet, use py::dict and py::tuple from pybind11 which handle cleanup automatically.

Comment on lines 913 to 921
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The shape and is_cuda properties assume that either self._data or self._transpose is not None. However, looking at the quantizer logic in Float8Quantizer.create_tensor() (quantizer.cpp lines 126-137), there are cases where:

  • with_data = rowwise_usage || is_non_tn_fp8_gemm_supported
  • with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported

If both rowwise_usage and columnwise_usage are False, and is_non_tn_fp8_gemm_supported is False, then both _data and _transpose would be None, causing AttributeError when accessing .shape or .is_cuda.

This scenario should probably be prevented at quantizer initialization, or these properties should handle the None case gracefully.

Comment on lines 79 to 89
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reordering of device attribute checks changes the priority of which device is returned. Previously, if a tensor had both device attribute and internal attributes like _data, it would return tensor.device.index. Now it returns the device of the internal attribute.

For QuantizedTensor instances where these might differ (though unlikely), this could change behavior. Consider the case where:

  • A QuantizedTensor has _data on device 0
  • But the tensor itself reports being on device 1 (via its device property)

The old code would return device 1, the new code returns device 0. While this scenario is unlikely given how QuantizedTensor is implemented, it's a semantic change worth noting. If the internal attributes are guaranteed to be on the same device as the tensor, this is fine. Otherwise, document this priority or add an assertion.


int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.

If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.

Consider:

static int num_devices = transformer_engine::cuda::num_devices();

This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.

A safer approach might be:

static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  flags.resize(num_devices);
});

Or simply document that the device count must not change during the application's lifetime.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No files reviewed, no comments

Edit Code Review Agent Settings | Greptile

vthumbe1503 and others added 2 commits January 12, 2026 08:27
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile


Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error check occurs AFTER Py_DECREF(kwargs) and Py_DECREF(args) (lines 168-169), which means if result is nullptr, the program will terminate via NVTE_CHECK without proper cleanup. While this is acceptable for fatal errors, it's inconsistent with the cleanup-before-check pattern and could mask the actual Python exception message.

Suggested improvement: Check for nullptr before decrementing reference counts, and use PyErr_Print() or PyErr_Fetch() to preserve the Python exception information:

Suggested change
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
if (result == nullptr) {
PyErr_Print(); // Print the actual Python error
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(false, "Failed to create Float8TensorStorage instance");
}
Py_DECREF(kwargs);
Py_DECREF(args);
out_py = py::reinterpret_steal<py::object>(result);

This same pattern applies to all other PyObject_Call invocations in this file (lines 187, 401, 706, 723, 1047, 1068, 1357, 1378).

Comment on lines 387 to 393
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The requires_grad property setter updates the cached _requires_grad value and then calls torch.Tensor.requires_grad.fset(self, value) to ensure the autograd engine is aware. However, there's a potential issue: if the parent class's setter fails or raises an exception, _requires_grad will already be updated, causing the cached value to be out of sync with the actual tensor state.

Suggested improvement: Call the parent setter first, then update the cache only if successful:

Suggested change
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Call parent class first to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
# Update the cached value only after successful update
self._requires_grad = value

The same pattern should be applied to requires_grad_() method on lines 402-409.

Comment on lines 395 to 400
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
"""Set dtype property"""
# Update the cached value
self._dtype = value
warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dtype property has a setter that allows changing the dtype after creation, with a warning. However, this is problematic because:

  1. The warning says "Dtype of QuantizedTensor has been changed" but doesn't indicate this could break things
  2. QuantizedTensor subclasses (Float8Tensor, MXFP8Tensor, etc.) have quantized data in specific formats (FP8, MXFP4) that don't change when you update _dtype
  3. Changing dtype without re-quantizing the underlying data creates a mismatch between the cached dtype and the actual data representation

This setter creates a dangerous API surface. If dtype must be mutable, it should either:

  • Raise an error instead of just warning
  • Or trigger re-quantization of the data

Current impact: Users could accidentally corrupt tensor semantics by changing dtype, leading to incorrect dequantization.

Comment on lines 910 to 920
)
return out, all_gather_outputs

@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new shape and is_cuda property implementations have a logic issue: they return properties from self._data if it exists, otherwise from self._transpose. However, according to the codebase logic in cublaslt_gemm.cu and quantizer.cpp, on Hopper (non-Blackwell) architectures, FP8 tensors may only have _transpose data for certain GEMM layouts, while on Blackwell they may only have _data.

The properties assume that if _data is None, then _transpose must exist, but there's no fallback handling if both are somehow None (which shouldn't happen but could in edge cases).

Suggested improvement: Add a safety check or document the invariant that at least one must always exist:

Suggested change
)
return out, all_gather_outputs
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return self._data.shape if self._data is not None else self._transpose.shape
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._data is not None:
return self._data.shape
if self._transpose is not None:
return self._transpose.shape
raise RuntimeError("Float8Tensor has neither _data nor _transpose")
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
if self._data is not None:
return self._data.is_cuda
if self._transpose is not None:
return self._transpose.is_cuda
raise RuntimeError("Float8Tensor has neither _data nor _transpose")

Comment on lines +810 to +817
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return (
self._rowwise_data.shape
if self._rowwise_data is not None
else self._columnwise_data.shape
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as Float8Tensor: the shape and is_cuda properties assume that if _rowwise_data is None, then _columnwise_data must exist. However, there's no safety check for the edge case where both could be None.

Suggested improvement: Add safety checks:

Suggested change
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return (
self._rowwise_data.shape
if self._rowwise_data is not None
else self._columnwise_data.shape
)
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._rowwise_data is not None:
return self._rowwise_data.shape
if self._columnwise_data is not None:
return self._columnwise_data.shape
raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
if self._rowwise_data is not None:
return self._rowwise_data.is_cuda
if self._columnwise_data is not None:
return self._columnwise_data.is_cuda
raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +157 to +170
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error checking for Python C API calls could lead to crashes. The code doesn't check return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), or PyObject_Call() before proceeding. If any of these fail (e.g., due to memory allocation failure), the subsequent operations will access invalid pointers.

Recommendation:
Add error checks after each C API call:

  • Check if PyDict_New() and PyTuple_New() return NULL
  • Check if PyDict_SetItemString() returns -1
  • The existing NVTE_CHECK(result != nullptr, ...) is good, but should come before Py_DECREF calls to avoid decrementing invalid references

This pattern is repeated throughout this file at lines 176-193, 389-404, 410-426, 688-706, 710-727, 1035-1050, 1054-1070, 1343-1361, and 1365-1383. All instances need the same error checking.

Comment on lines 377 to 393
@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
"""
return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value
self._requires_grad = value
# Call parent class to ensure autograd engine is aware of the change
torch.Tensor.requires_grad.fset(self, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential cache inconsistency if PyTorch modifies requires_grad through internal mechanisms. While the setter and requires_grad_() method properly sync the cache with the parent tensor, there's a risk that PyTorch's autograd engine could modify requires_grad through paths that bypass these methods (e.g., during graph construction or when tensors are involved in certain operations).

Current behavior: The cached value could become stale if PyTorch modifies requires_grad internally.

Mitigation: The implementation correctly updates the parent tensor in the setter and requires_grad_(), which should handle most cases. However, consider documenting this limitation or adding validation in critical paths to ensure cache consistency.

Note: This is a minor concern as most requires_grad modifications go through the standard setter, but worth being aware of for debugging purposes.

Comment on lines +189 to +193
Py_DECREF(kwargs);
Py_DECREF(args);

NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error checking order issue: NVTE_CHECK for result != nullptr happens after Py_DECREF(kwargs) and Py_DECREF(args). While decrefing these is safe even when result is NULL, checking for errors before cleanup is better practice and makes the error path clearer.

If PyObject_Call fails and returns NULL, Python's exception state is set. The NVTE_CHECK should happen immediately after the call, and cleanup should be part of an error handling path. Consider moving the check before Py_DECREF or using proper error handling with early returns.

This pattern affects all PyObject_Call instances in this file.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +964 to 966
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical bug: Potential out-of-bounds access when device_id >= num_devices

Making num_devices static causes it to be evaluated only once at first call. This creates a serious bug if the device ID is not within the range [0, num_devices):

  1. Problem: CUDA device IDs are not necessarily contiguous. For example, with CUDA_VISIBLE_DEVICES=0,2, you have 2 visible devices but device IDs are 0 and 2.
  2. Bug scenario: If num_devices is initialized to 2, but later code calls current_device() which returns 2, then cache[device_id] at line 975 will access cache[2] which is out of bounds (cache size is 2, valid indices are 0-1).
  3. Impact: This will cause undefined behavior, likely a segmentation fault or memory corruption.

Fix: Either:

  • Use current_device() as the cache key and bounds-check it
  • Map actual device IDs to cache indices
  • Remove the static keyword and re-evaluate num_devices on each call (sacrifices some of the optimization but ensures correctness)

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

tests/pytorch/test_fusible_ops.py
Behavior change: Removing None from _quantization_list changes test coverage

The original line was:

_quantization_list: list[Optional[str]] = [None]

Now it's:

_quantization_list: list[Optional[str]] = []

This means tests will no longer run with quantization=None (i.e., without quantization). This removes test coverage for the non-quantized code path.

Question: Is this intentional? If this is a CPU optimization PR, removing coverage for non-quantized tests seems unrelated to the stated goals. Please clarify if this change is intended or accidental.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

12 files reviewed, 12 comments

Edit Code Review Agent Settings | Greptile


int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Potential out-of-bounds access with static num_devices

Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.

This will cause out-of-bounds access on lines 968 and 975:

std::call_once(flags[device_id], ...);  // OOB if device_id >= num_devices
return cache[device_id];                 // OOB if device_id >= num_devices

Impact: Undefined behavior, potential crashes, memory corruption.

Fix: Query device count each time, or add bounds checking:

int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;

int device_id = transformer_engine::cuda::current_device();
{
  std::lock_guard<std::mutex> lock(resize_mutex);
  if (device_id >= cache.size()) {
    cache.resize(device_id + 1, -1);
    flags.resize(device_id + 1);
  }
}
std::call_once(flags[device_id], [&]() { ... });

Comment on lines +157 to +171
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak when exceptions occur

The direct C API usage introduces memory leak vulnerabilities. If any py::cast() call throws an exception or if PyDict_SetItemString() fails after PyDict_New() and PyTuple_New(), the allocated kwargs and args will leak.

Problems:

  1. No NULL checks after PyDict_New() and PyTuple_New()
  2. Exception safety: py::cast() can throw, but cleanup code won't run
  3. Multiple py::cast() calls create temporary objects that could throw

Example failure scenario:

PyObject* kwargs = PyDict_New();  // Allocated
PyObject* args = PyTuple_New(0);   // Allocated
// ... several PyDict_SetItemString calls ...
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());  // Throws!
// Cleanup code never runs -> kwargs and args leaked

Fix: Use RAII for proper cleanup:

struct PyObjectDeleter {
  void operator()(PyObject* obj) { Py_XDECREF(obj); }
};
using PyObjectPtr = std::unique_ptr<PyObject, PyObjectDeleter>;

PyObjectPtr kwargs(PyDict_New());
PyObjectPtr args(PyTuple_New(0));
NVTE_CHECK(kwargs && args, "Failed to allocate Python objects");
// Now safe even if exceptions occur

This issue affects all quantizer types (Float8, Float8CurrentScaling, Float8Block, MXFP8, NVFP4).

Comment on lines +176 to +193
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);

NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak when exceptions occur (non-internal path)

Same memory leak vulnerability as the internal path above. The kwargs and args will leak if any exception occurs between allocation and cleanup.

Additionally, this block has more py::cast() calls (8 vs 5), increasing the risk of exceptions:

  • py::cast(shape_int64).ptr()
  • py::cast(GetATenDType(dtype)).ptr()
  • py::cast(this->dtype).ptr()
  • py::cast(device).ptr()

Each of these can throw, and none are wrapped in exception-safe RAII wrappers.

Fix: Apply the same RAII solution as suggested in the previous comment.

Comment on lines 79 to 88
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance regression for regular tensors

The reordering of attribute checks optimizes for quantized tensors but creates a performance regression for regular torch.Tensor objects.

Before: Regular tensors check device immediately (line 1) and return
After: Regular tensors check 4 non-existent attributes via hasattr() before checking device

Impact:

  • Quantized tensors: ~4 fewer attribute checks ✓
  • Regular tensors: ~4 additional attribute checks ✗

Concern: If regular tensors are passed to get_tensor_device() frequently (which seems likely given the function name), this could offset the gains from other optimizations in this PR.

Recommendation: Profile both paths or add an early isinstance check:

def get_tensor_device(tensor: torch.Tensor) -> int:
    # Fast path for regular tensors
    if type(tensor).__name__ == 'Tensor':
        return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
    
    # Quantized tensor path
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    # ... etc

Comment on lines +390 to +406
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* args = PyTuple_New(0);
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in Float8CurrentScalingQuantizer (internal path)

Same memory leak vulnerability as Float8Quantizer. The kwargs and args will leak if any exception occurs between allocation and the Py_DECREF cleanup code. See comment on lines 157-171 for detailed explanation and fix.

Comment on lines +706 to +728

NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance");
ret = py::reinterpret_steal<py::object>(result);
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
ret = Float8BlockwiseQTensorClass(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
"data_format"_a = data_format);
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr());
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr());
PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr());
PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr());
PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());
PyObject* args = PyTuple_New(0);
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in Float8BlockQuantizer (non-internal path)

Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.

Comment on lines 1034 to +1050
if (internal) {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), args, kwargs);

Py_DECREF(args);
Py_DECREF(kwargs);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in MXFP8Quantizer (internal path)

Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.

Comment on lines +1052 to +1072
out_py = py::reinterpret_steal<py::object>(result);
} else {
py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py,
"columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* args = PyTuple_New(0);
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass), args, kwargs);

Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in MXFP8Quantizer (non-internal path)

Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.

Comment on lines 1343 to +1361
if (internal) {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
out_py = NVFP4TensorClass(
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr());
PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr());
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* args = PyTuple_New(0);
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in NVFP4Quantizer (internal path)

Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.

Comment on lines 1363 to +1385
} else {
py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
out_py = NVFP4TensorClass(
"shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
"rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
"rowwise_scale_inv"_a = rowwise_scale_inv_py,
"columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
"amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer);
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr());
PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr());
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* args = PyTuple_New(0);
PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass), args, kwargs);

Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Critical: Memory leak in NVFP4Quantizer (non-internal path)

Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Memory safety bug - Use-after-free with temporary py::object

The expression py::cast(this->dtype).ptr() creates a temporary py::object that is destroyed immediately after .ptr() returns, but before PyDict_SetItemString can increment its reference count. This results in passing a dangling pointer to PyDict_SetItemString, which is undefined behavior and will likely cause crashes or memory corruption.

The same issue occurs on multiple lines in this file:

  • Line 161: py::cast(this->dtype).ptr()
  • Line 178: py::cast(shape_int64).ptr()
  • Line 179: py::cast(GetATenDType(dtype)).ptr()
  • Line 182: py::cast(this->dtype).ptr()
  • Line 185: py::cast(device).ptr()

And similar patterns throughout the file for other quantizer types.

Fix: Store the py::object in a local variable before calling .ptr():

py::object fp8_dtype_py = py::cast(this->dtype);
PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_py.ptr());

Or use py::cast(this->dtype).inc_ref().ptr() to explicitly increment the reference count before extracting the pointer.


ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Logic error - Changed from OR to AND breaks FP8 state management

The original code used the requires_grad() helper function which returns True if ANY of the tensors require grad (OR logic). The new code requires ALL THREE tensors to require grad (AND logic).

Original behavior:

if ctx.fp8 and requires_grad(inp, weight, bias):  # True if ANY require grad

New behavior:

if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):  # Correct!

But wait, this line actually looks correct with OR logic. Let me re-check... Actually, this specific line IS correct. The issue mentioned in previous threads was already addressed in this PR. My apologies for the confusion on this specific line - it's actually implemented correctly.

Comment on lines +384 to +407
@property
def requires_grad(self) -> bool:
"""
Return whether or not the tensor requires gradient.
Attribute access of custom tensors goes through an
expensive Pyobject lookup. Since requires_grad is set during
initialization and may be updated, we cache it in a member variable.
"""
return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value and call parent class method to ensure autograd engine is aware
self.requires_grad_(value)

def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor:
"""Cache requires_grad property and call parent class method"""
# pylint: disable=missing-function-docstring
# Update the cached value
self._requires_grad = requires_grad
# Call parent class method to ensure autograd engine is aware
super().requires_grad_(requires_grad)
return self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential cache staleness issue with requires_grad property

The cached _requires_grad attribute could become stale if the underlying tensor's requires_grad is modified through mechanisms other than the provided setter. For example:

  1. PyTorch's autograd engine might directly modify the tensor's requires_grad flag
  2. Operations like tensor.requires_grad_() from the parent class might bypass the cached value
  3. Deserialization or other internal PyTorch operations might set requires_grad without going through the custom setter

While the setter correctly updates both the cache and calls super().requires_grad_(value), the getter always returns the cached value, which could diverge from the actual tensor state.

Recommendation: Consider either:

  • Synchronizing the cache by checking super().requires_grad periodically
  • Documenting this limitation clearly
  • Or accepting this as a known trade-off for performance (which may be acceptable if requires_grad is rarely modified after initialization)

Comment on lines +157 to +171
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error handling for Python C API calls

The code doesn't check for failures in Python C API calls:

  1. PyDict_New() can return NULL on memory allocation failure
  2. PyTuple_New(0) can return NULL on memory allocation failure
  3. PyDict_SetItemString() can fail and return -1

If PyDict_New() or PyTuple_New() fails, the subsequent calls will crash when dereferencing NULL pointers. If PyDict_SetItemString() fails, the dictionary will be incomplete but the error is silently ignored.

Recommendation: Add error checks:

PyObject* kwargs = PyDict_New();
if (!kwargs) throw std::runtime_error("Failed to create kwargs dict");
PyObject* args = PyTuple_New(0);
if (!args) {
    Py_DECREF(kwargs);
    throw std::runtime_error("Failed to create args tuple");
}
// Check PyDict_SetItemString return values or use NVTE_CHECK

This pattern repeats throughout all the create_tensor methods in this file.

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 files reviewed, 10 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +157 to +171
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential reference counting issue with temporary py::object

The pattern PyDict_SetItemString(kwargs, "key", py::cast(value).ptr()) creates temporary py::object instances that may lead to subtle reference counting issues. The temporary object from py::cast() will be destroyed at the end of the statement, potentially before PyDict_SetItemString completes in edge cases.

Safer pattern:

PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);

py::object data_obj = data_py;
py::object scale_inv_obj = scale_inv_py;
py::object dtype_obj = py::cast(this->dtype);
py::object transpose_obj = transpose_py;
py::object quantizer_obj = this->quantizer;

PyDict_SetItemString(kwargs, "data", data_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", dtype_obj.ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_obj.ptr());
PyDict_SetItemString(kwargs, "quantizer", quantizer_obj.ptr());

This ensures all py::object instances remain alive while their pointers are being used. The same issue exists in multiple locations throughout this file (lines 176-185, 389-397, 409-417, and many more).

Comment on lines +358 to 360
instance._requires_grad = requires_grad
instance._dtype = dtype
return instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing cached attribute initialization in alternate construction paths

The cached attributes _requires_grad and _dtype are initialized here in __new__, but there are several code paths where QuantizedTensor objects can be created that may bypass this initialization:

  1. Pickling/unpickling: When tensors are pickled and unpickled, __reduce_ex__ and __setstate__ may not properly initialize these cached attributes
  2. Tensor operations: Operations like .clone(), .detach(), torch.cat(), etc. may create new tensor instances
  3. C++ constructors: When tensors are created from C++ code via quantizer

Impact: Accessing .dtype or .requires_grad on such tensors will raise AttributeError: '_dtype' not found.

Recommendation: Add defensive checks in the property getters:

@property
def dtype(self) -> torch.dtype:
    if not hasattr(self, '_dtype'):
        self._dtype = super().dtype
    return self._dtype

@property
def requires_grad(self) -> bool:
    if not hasattr(self, '_requires_grad'):
        self._requires_grad = super().requires_grad
    return self._requires_grad

Comment on lines +389 to +393
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
"""Set requires_grad property so that autograd engine is aware of the change"""
# Update the cached value and call parent class method to ensure autograd engine is aware
self.requires_grad_(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect setter implementation causes infinite recursion

The requires_grad setter calls self.requires_grad_(value) which then calls super().requires_grad_(requires_grad). However, setting requires_grad through assignment (e.g., tensor.requires_grad = True) will invoke this setter, which doesn't update self._requires_grad before calling the parent method.

The correct flow should be:

  1. Update cached value self._requires_grad
  2. Call parent requires_grad_() to notify autograd

But line 399 in requires_grad_() updates the cache AFTER calling super, which means the setter doesn't actually update the cache.

Fix:

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    self._requires_grad = value
    super().requires_grad_(value)


int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential race condition in static variable initialization

Making num_devices static improves performance by caching the device count, but this introduces a subtle race condition. While the cache and flags vectors are protected by std::call_once, num_devices itself is initialized on first function call without synchronization.

Scenario:

  1. Thread A enters function first time, starts calling transformer_engine::cuda::num_devices()
  2. Thread B enters function while A is still in num_devices(), sees uninitialized static
  3. Both threads may initialize num_devices, causing undefined behavior

However, this is likely safe in practice because:

  • The function is typically called during module initialization (single-threaded)
  • The value doesn't change after initialization
  • Most compilers guarantee thread-safe static initialization in C++11

But for correctness, consider wrapping in the existing std::call_once:

static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();

Or initialize it within the std::call_once block if device_id doesn't need num_devices.

)
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
self.dtype = tensor.dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cached dtype not initialized in all construction paths

This line sets self.dtype = tensor.dtype to initialize the cached _dtype attribute, but this only happens in the _set_data method when copying from another Float8Tensor AND certain conditions are met (size/stride/dtype mismatch).

Missing initialization in:

  1. Initial tensor construction (first time Float8Tensor is created)
  2. When sizes/strides match (lines 1002-1006 don't call this setter)
  3. _make_in_reduce_ex unpickling path (line 942)
  4. Other construction paths

This will cause AttributeError when accessing .dtype property on tensors created through other paths.

Fix: Ensure _dtype is initialized in Float8Tensor.__new__ or add the defensive check mentioned in the QuantizedTensor comment.

Comment on lines 40 to 54
void init_float8_extension() {
if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
Float8TensorPythonClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
auto fp8_base_module =
py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
NVTE_CHECK(Float8TensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch Float8 extension.");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed null-check guards without proper synchronization guarantee

The original code had null checks (if (Float8TensorPythonClass) return;) to prevent redundant initialization. These were removed in favor of std::call_once, but the individual init_* functions no longer have guards.

Issue: If init_extension() is called multiple times from different threads, std::call_once ensures the lambda runs once. However, if init_float8_extension() is called directly (not through init_extension()), it will re-import modules and reassign global pointers without protection.

Verification needed: Confirm that these functions are NEVER called directly, only through init_extension(). If they can be called directly, the null checks should be restored:

void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // Keep this guard
  auto fp8_module = py::module_::import(...);
  ...
}

Comment on lines +932 to +934
dtype = inp.dtype
# All checks after this have already been performed once, thus skip
if self.activation_dtype == inp.dtype:
if self.activation_dtype == dtype:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optimization: Cache inp.dtype early to avoid redundant attribute access

Good optimization! By caching inp.dtype early (line 932), the code avoids a second attribute lookup that would have occurred on line 936 in the original code. This is a clean CPU optimization with no functional change.

Note: This is correctly placed BEFORE the early return check, ensuring the cached value is available for the comparison on line 933.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines 79 to 89
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reordering device attribute checks changes behavior for regular tensors

The original code checked tensor.device FIRST (line 78-79 in old code), then checked quantized tensor attributes. The new code reverses this order, checking quantized attributes first.

Impact:

  • Performance: For regular torch.Tensor objects, this INCREASES overhead by attempting to check for _rowwise_data, _columnwise_data, _data, and _transpose attributes (all of which will fail via hasattr) before finally checking device.
  • Correctness: For QuantizedTensors that happen to have a .device attribute that differs from their internal data device, this could change which device is returned.

Concern: This optimization seems to hurt performance for the common case (regular tensors) while only helping the less common case (quantized tensors).

Question: What is the relative frequency of this function being called with regular tensors vs quantized tensors? If regular tensors are more common, this change may degrade overall performance.

Comment on lines 913 to 924
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
if self._data is not None:
return self._data.shape
else:
transpose_shape = self._transpose.shape
return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error handling when both _data and _transpose are None

The shape property assumes at least one of _data or _transpose is not None, but there's no validation:

@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    else:
        transpose_shape = self._transpose.shape  # AttributeError if _transpose is None!
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))

Similarly for is_cuda:

@property
def is_cuda(self):
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
    # AttributeError if both are None!

When can this happen?

  • During tensor construction before data is initialized
  • After certain operations that clear internal data
  • In error/edge cases

Fix: Add defensive check:

@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    elif self._transpose is not None:
        transpose_shape = self._transpose.shape
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
    else:
        # Fallback to base class or raise informative error
        return super().shape

Comment on lines +809 to +824
@property
def shape(self):
"""Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
return (
self._rowwise_data.shape
if self._rowwise_data is not None
else self._columnwise_data.shape
)

@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return (
self._rowwise_data.is_cuda
if self._rowwise_data is not None
else self._columnwise_data.is_cuda
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing error handling when both rowwise_data and columnwise_data are None

Same issue as Float8Tensor - these properties assume at least one data tensor exists:

@property
def shape(self):
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape  # AttributeError if both are None!
    )

@property
def is_cuda(self):
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda  # AttributeError if both are None!
    )

Fix: Add validation:

@property
def shape(self):
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    elif self._columnwise_data is not None:
        return self._columnwise_data.shape
    else:
        return super().shape  # or raise informative error

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503
Copy link
Collaborator Author

/te-ci L1 pytorch

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

11 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +925 to +927
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Unlike the shape property above which has a defensive check (raise RuntimeError("Both data and transpose are None")), is_cuda will raise AttributeError if both _data and _transpose are None. Consider adding the same defensive check for consistency:

Suggested change
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
@property
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
if self._data is not None:
return self._data.is_cuda
elif self._transpose is not None:
return self._transpose.is_cuda
else:
raise RuntimeError("Both data and transpose are None")

Comment on lines +819 to +823
def is_cuda(self):
"""Return whether the tensor is on a CUDA device."""
return (
self._rowwise_data.is_cuda
if self._rowwise_data is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Same consistency issue as Float8Tensor - is_cuda will raise AttributeError if both _rowwise_data and _columnwise_data are None, while there's no defensive check like shape should have. Consider adding error handling for this edge case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant