CPU Optimizations for FP8 #2559
base: main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Summary
This PR introduces CPU-side optimizations for FP8 tensor operations, reducing overhead from repeated function calls and Python object attribute access.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
    participant User as User Code
    participant Linear as _Linear.forward()
    participant QTensor as QuantizedTensor
    participant Quantizer as quantizer.cpp
    participant PyAPI as Python C API
    participant GEMM as cublaslt_gemm.cu
    User->>Linear: forward(inp, weight, bias)
    Note over Linear: Cache requires_grad values early<br/>(inp_requires_grad, weight_requires_grad, bias_requires_grad)
    Linear->>QTensor: Access dtype/requires_grad
    Note over QTensor: Return cached _dtype/_requires_grad<br/>(avoids PyObject lookup)
    Linear->>Quantizer: create_tensor()
    Note over Quantizer: Cache nvte_is_non_tn_fp8_gemm_supported()
    Quantizer->>PyAPI: PyDict_New(), PyDict_SetItemString()
    Note over PyAPI: Direct C API bypasses<br/>pybind11 overhead
    PyAPI-->>Quantizer: Float8Tensor instance
    Linear->>GEMM: Execute GEMM
    Note over GEMM: Use cached is_nvte_non_tn_fp8_gemm_supported<br/>(single call per function invocation)
```
Additional Comments (3)
- `transformer_engine/pytorch/csrc/util.cpp`, lines 18-20: logic: Critical logical error: `||` should be `&&`. As written, the condition is always `true`, since a value cannot simultaneously be both scaling modes, so the function always returns `nullopt` for valid inputs.
- `transformer_engine/pytorch/quantized_tensor.py`, lines 373-393: style: commented-out code for the `requires_grad` caching optimization; consider removing the dead code entirely. Is this code planned to be implemented later, or should it be removed? Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
- `transformer_engine/pytorch/module/linear.py`, line 484: logic: Logical error: this condition should use OR (`or`), not AND (`and`). The original logic checked whether ANY tensor requires gradients for FP8 handling, but this version only activates when ALL three require gradients, including bias, which may be None.
Should the FP8 condition check whether any tensor requires gradients (OR logic) rather than all tensors (AND logic)?
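To make the difference concrete, here is a minimal, hypothetical sketch (the `requires_grad` helper and `FakeTensor` class are illustrative stand-ins, not the actual TransformerEngine code) showing why OR semantics and AND semantics diverge as soon as the weight is frozen or the bias is absent:

```python
# Hypothetical stand-in for the helper discussed above: FP8 state
# handling should trigger when ANY tensor requires gradients.
def requires_grad(*tensors) -> bool:
    """Return True if any non-None tensor requires gradients (OR semantics)."""
    return any(t is not None and t.requires_grad for t in tensors)


class FakeTensor:
    """Minimal stand-in for a torch.Tensor with a requires_grad flag."""

    def __init__(self, requires_grad: bool):
        self.requires_grad = requires_grad


inp = FakeTensor(True)      # input needs gradients
weight = FakeTensor(False)  # frozen weight
bias = None                 # bias may be absent entirely

# OR semantics: FP8 state management activates (the original behavior)
or_result = requires_grad(inp, weight, bias)

# AND semantics: never activates here, since weight is frozen and bias is None
and_result = all(t is not None and t.requires_grad for t in (inp, weight, bias))

print(or_result, and_result)  # -> True False
```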
10 files reviewed, 3 comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
Greptile Overview
Greptile Summary
This PR implements CPU-side performance optimizations for FP8 operations by caching frequently accessed attributes and reducing redundant function calls. The optimizations target expensive PyObject attribute lookups on custom tensor types and repeated C++ function calls.
Key Changes:
- Caches `requires_grad`, `dtype`, `shape`, and `is_cuda` attribute accesses to avoid expensive PyObject lookups on custom tensors
- Reorders attribute checks in `get_tensor_device()` to prioritize internal quantized tensor attributes
- Makes `num_devices` static in `nvte_is_non_tn_fp8_gemm_supported()` to cache the device count
- Stores GEMM support check results in local variables to avoid redundant function calls
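The attribute-caching idea in the first bullet can be illustrated with a small, hypothetical sketch (this is not the TransformerEngine code; the class and attribute names are assumptions): immutable metadata such as `dtype` is stored once in a plain instance attribute at construction, so later reads hit a cheap attribute lookup instead of delegating through the custom tensor's expensive property dispatch on every access.

```python
class CachedMetaTensor:
    """Toy wrapper illustrating the caching pattern, not the real class."""

    def __init__(self, data, dtype):
        self._data = data
        # dtype never changes after creation, so cache it once here;
        # reads then hit a cheap instance-attribute lookup instead of
        # delegating to the wrapped object on every access.
        self._dtype = dtype

    @property
    def dtype(self):
        return self._dtype  # cached value, no delegation


t = CachedMetaTensor(data=[1.0, 2.0], dtype="float32")
print(t.dtype)  # -> float32
```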
Critical Issues Found:
- Variable redeclaration error in `cublaslt_gemm.cu` (line 224) will prevent compilation
- Logic bug in `linear.py` (line 484) changes FP8 state management from OR logic to AND logic, breaking functionality when bias is None or doesn't require grad
Confidence Score: 0/5
- This PR cannot be merged due to compilation error and critical logic bug
- Two critical issues prevent merging: (1) C++ compilation will fail due to variable redeclaration at line 224 of cublaslt_gemm.cu, and (2) logic bug at line 484 of linear.py breaks FP8 state management by requiring all three tensors to have requires_grad=True instead of any one of them
- Pay close attention to `transformer_engine/common/gemm/cublaslt_gemm.cu` (compilation error) and `transformer_engine/pytorch/module/linear.py` (logic bug)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/gemm/cublaslt_gemm.cu | 1/5 | Caches function call result to reduce overhead, but contains variable redeclaration error that will cause compilation failure |
| transformer_engine/common/transformer_engine.cpp | 5/5 | Makes num_devices static to avoid redundant calls to cuda::num_devices() - valid optimization |
| transformer_engine/pytorch/module/linear.py | 0/5 | Caches requires_grad checks for performance, but contains critical logic bug at line 484 that changes FP8 state management behavior |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User as User Code
    participant Linear as Linear Module
    participant Quantizer as Quantizer/QuantizedTensor
    participant GEMM as GEMM Operations
    participant CPP as C++ Extensions
    Note over Linear,CPP: Performance Optimization Flow
    User->>Linear: forward(input, weight, bias)
    Note over Linear: Cache requires_grad checks
    Linear->>Linear: inp_requires_grad = inp.requires_grad<br/>weight_requires_grad = weight.requires_grad<br/>bias_requires_grad = bias.requires_grad
    Linear->>Quantizer: Check if quantized tensor
    alt QuantizedTensor
        Note over Quantizer: Use cached dtype property
        Quantizer->>Quantizer: return self._dtype
        Note over Quantizer: Use cached shape/is_cuda
        Quantizer->>Quantizer: return self._data.shape
    else Regular Tensor
        Quantizer->>Linear: Standard attribute access
    end
    Linear->>CPP: get_tensor_device(tensor)
    Note over CPP: Reordered attribute checks
    CPP->>CPP: Check _rowwise_data first<br/>Check _columnwise_data<br/>Check device last
    CPP-->>Linear: device_index
    Linear->>GEMM: Configure GEMM parameters
    Note over GEMM: Cache nvte_is_non_tn_fp8_gemm_supported
    GEMM->>CPP: nvte_is_non_tn_fp8_gemm_supported()
    Note over CPP: Static num_devices cached
    CPP-->>GEMM: support_flag
    GEMM->>GEMM: Store in local variable
    GEMM->>GEMM: Execute optimized GEMM
    GEMM-->>Linear: output
    Note over Linear: FP8 State Management
    alt FP8 enabled and requires_grad check
        Linear->>Linear: Update FP8 tensors<br/>based on cached flags
    end
    Linear-->>User: output
```
Additional Comments (2)
/te-ci L1 pytorch
/te-ci L1 pytorch
8 files reviewed, 8 comments
```cpp
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
Variable `is_nvte_non_tn_fp8_gemm_supported` is redeclared in the same scope; it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue.
The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.
```cpp
// Move this declaration before line 125 (Configure A matrix)
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
```cpp
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
```
Memory leak: `PyTuple_New(0)` creates a new Python tuple object with a reference count of 1, but this reference is never decremented. The tuple is passed directly to `PyObject_Call`, which doesn't steal the reference, so it must be explicitly released.
This same issue occurs at lines 186, 398, 417, 697, 718, 1039, 1058, 1346, and 1367.
```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
```
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
```
Invalid parameter: the `Float8TensorStorage.__new__` method does not accept a `device` parameter. Looking at the Python class definition in float8_tensor_storage.py, the `__new__` signature only expects: `data`, `fp8_scale_inv`, `fp8_dtype`, `data_transpose`, and `quantizer`.
This will cause a runtime error when trying to create `Float8TensorStorage` instances. This issue also occurs at lines 184, 396, 415, and in similar patterns for MXFP8 and other quantizers.
```cpp
// Remove this line - device parameter is not accepted
// PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
```
```python
def requires_grad(self) -> bool:
    # Attribute access of custom tensors goes through an
    # expensive Pyobject lookup. Since requires_grad is set during
    # initialization and may be updated, we cache it in a member variable.
    return self._requires_grad
```
Critical bug: the `requires_grad` property getter accesses `self._requires_grad` at line 378, but this member variable is never initialized in the `__new__` method (lines 353-364). This will cause an `AttributeError` when the property is first accessed.
The `_requires_grad` attribute must be initialized in `__new__`:
```python
@property
def requires_grad(self) -> bool:
    # Attribute access of custom tensors goes through an
    # expensive Pyobject lookup. Since requires_grad is set during
    # initialization and may be updated, we cache it in a member variable.
    # Fallback to parent if not cached yet
    if not hasattr(self, '_requires_grad'):
        self._requires_grad = torch.Tensor.requires_grad.fget(self)
    return self._requires_grad
```
```python
instance._dtype = dtype
return instance
```
Missing initialization: the `_requires_grad` member variable needs to be initialized here to support the cached property added later in the PR. Without this, accessing the `requires_grad` property will fail.
```python
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance
```
```cpp
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
```
Race condition: the `is_extension_initialized` flag is not thread-safe. If `init_extension()` is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.
The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using `std::once_flag` and `std::call_once` for proper thread-safe initialization:
```cpp
bool is_extension_initialized = false;
std::once_flag extension_init_flag;
```
And in `init_extension()`:
```cpp
void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}
```
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
```
Potential edge case: `num_devices` is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the `cache` and `flags` vectors could be undersized, leading to out-of-bounds access at line 968.
While dynamic device changes are rare, consider adding a bounds check:
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");
```
4 files reviewed, 4 comments
```cpp
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
Variable `is_nvte_non_tn_fp8_gemm_supported` is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at function scope, before both the A matrix and B matrix configuration blocks.
```cpp
// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration
```
```cpp
py::object scale_inv_py = py::cast(scale_inv);
// Initialize scale-inverse tensor
if (!scale_inv) {
  scale_inv = at::reciprocal(scale);
}
```
Critical bug: `scale_inv_py` is created on line 146 by dereferencing `scale_inv` (which is `std::optional<at::Tensor>`), but `scale_inv` is only initialized on lines 148-150 if it's empty. When `scale_inv` is not provided to this function, line 146 will dereference an empty optional, causing undefined behavior or a crash.
The fix is to move line 146 after the initialization block:
```cpp
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
// Initialize scale-inverse tensor
if (!scale_inv) {
  scale_inv = at::reciprocal(scale);
}
py::object scale_inv_py = py::cast(*scale_inv);
```
```cpp
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
```
Memory leak: `PyTuple_New(0)` creates a new tuple with reference count 1, but `PyObject_Call` does not steal the reference. The tuple must be stored and explicitly DECREF'd after the call. This pattern is repeated 12 times throughout this file (lines 166, 186, 398, 417, 697, 718, 1039, 1058, 1346, 1367).
```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs);
Py_DECREF(args);
```
```python
instance._dtype = dtype
return instance
```
Critical bug: `_dtype` is cached in `__new__` but `_requires_grad` is not initialized. The `requires_grad` property (line 378) tries to return `self._requires_grad`, which will raise `AttributeError` when accessed before the setter is called. This breaks the optimization for caching `requires_grad`.
```python
instance._dtype = dtype
instance._requires_grad = requires_grad
return instance
```
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
13 files reviewed, 13 comments
```cpp
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
[P0] Variable `is_nvte_non_tn_fp8_gemm_supported` is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
```cpp
// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132
```
```cpp
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
```
[P0] Use-after-free bug: `scale_inv_py` is created from `scale_inv` before `scale_inv` is initialized. When the `scale_inv` parameter is empty (lines 148-150), this dereferences an empty `std::optional`, causing undefined behavior. Move this line after the `scale_inv` initialization.
```cpp
// Move scale_inv_py creation after initialization (line 150)
```
```cpp
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
```
[P0] Memory leak: `PyTuple_New(0)` creates a new Python object with refcount 1, but it's never decremented. This leaks memory on every tensor creation. Add `Py_DECREF` for the tuple:
```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
Py_DECREF(args);
```
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same issue as line 166).
```cpp
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);
Py_DECREF(args);
```
```python
ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias):
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
```
[P0] Logic error: the original `requires_grad(inp, weight, bias)` returns True if ANY tensor requires grad. Replacing it with `and` would require ALL THREE to be True, breaking FP8 state management when bias is None or doesn't require grad.
```python
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
```
```cpp
PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass),
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass),
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass),
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same pattern as line 166).
```cpp
PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
```
[P0] Memory leak: `PyTuple_New(0)` is not decremented (same pattern as line 166).
3 files reviewed, 3 comments
```cpp
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
[P0] Variable `is_nvte_non_tn_fp8_gemm_supported` is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.
```cpp
}
py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();

py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none();
```
[P0] Duplicate line: `transpose_py` is assigned twice (lines 145 and 146). Remove this duplicate line.
```cpp
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());

PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 PyTuple_New(0), kwargs);
```
[P0] Memory leak: `PyTuple_New(0)` creates a new Python object that is never decremented. Store the result and call `Py_DECREF(tuple)` after `PyObject_Call`. This pattern repeats in all 12 calls to `PyObject_Call` throughout this file (lines 167, 186, 331, 350, 397, 416, 706, 725, 1042, 1061, 1347, 1366).
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
5 files reviewed, 5 comments
```cpp
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Potential memory leak if `py::cast()` throws an exception before `Py_DECREF`: if any of the `py::cast(this->dtype)`, `py::cast(shape_int64)`, or `py::cast(device)` calls throw, the `kwargs` and `args` PyObjects leak because the `Py_DECREF` calls on lines 168-169 are never reached.
```cpp
// Use RAII for Python object lifetime management
py::dict kwargs_py;
kwargs_py["data"] = data_py;
kwargs_py["fp8_scale_inv"] = scale_inv_py;
kwargs_py["fp8_dtype"] = py::cast(this->dtype);
kwargs_py["data_transpose"] = transpose_py;
kwargs_py["quantizer"] = this->quantizer;
PyObject* args = PyTuple_New(0);
PyObject* result = PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass),
                                 args, kwargs_py.ptr());
Py_DECREF(args);
```
This same pattern repeats throughout the file at lines 176-193, 382-400, 411-428, 657-678, 686-707, 1013-1034, 1042-1063, 1304-1325, and 1333-1354. All instances should use pybind11's `py::dict` for exception safety.
```cpp
at::Device device =
    with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA);
```
`device` defaults to cuda:0 when both `data` and `transpose` are None, which may not be the current device. In multi-GPU setups, if neither `with_data` nor `with_transpose` is true, the device defaults to `torch::kCUDA` (which implicitly converts to cuda:0). This may not match the current CUDA device, potentially causing cross-device operation issues.
```cpp
at::Device device =
    with_data ? data->device()
              : (with_transpose ? transpose->device()
                                : at::Device(torch::kCUDA, torch::cuda::current_device()));
```
This same issue exists at line 383 in `Float8CurrentScalingQuantizer::create_tensor`.
| if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: | ||
| return tensor._rowwise_data.device.index | ||
| if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: | ||
| return tensor._columnwise_data.device.index | ||
| if hasattr(tensor, "_data") and tensor._data is not None: | ||
| return tensor._data.device.index | ||
| if hasattr(tensor, "_transpose") and tensor._transpose is not None: | ||
| return tensor._transpose.device.index | ||
| if hasattr(tensor, "device"): | ||
| return tensor.device.index | ||
| return torch.cuda.current_device() |
reordering attribute checks changes behavior when devices don't match
The new code checks quantized tensor attributes (_rowwise_data, _columnwise_data, etc.) before checking the standard .device attribute. If a tensor has both and they refer to different devices, this changes which device is returned.
For example, if tensor.device == cuda:1 but tensor._rowwise_data.device == cuda:0, the old code would return 1 but the new code returns 0. This could break multi-GPU code that relies on the .device attribute being authoritative.
Consider documenting this behavior change or adding a check that all device attributes match if multiple are present.
```python
@property
def dtype(self) -> torch.dtype:
    """
    Return the high precision data type of the tensor
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since dtype for a tensor is never
    change after creation, we cache it in a member variable and return
    """
    return self._dtype

@property
def requires_grad(self) -> bool:
    """
    Return whether or not the tensor requires gradient.
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since requires_grad is set during
    initialization and may be updated, we cache it in a member variable.
```
_dtype and _requires_grad only initialized in __new__, will cause AttributeError in other creation paths
The new `dtype` and `requires_grad` properties rely on cached `_dtype` and `_requires_grad` attributes, but these are only initialized in the `__new__` method (lines 363-364). If a QuantizedTensor is created through other paths such as:
- Unpickling (`__reduce_ex__`, `__setstate__`)
- `torch.Tensor` subclass operations that bypass `__new__`
- FSDP state dict loading

accessing `.dtype` or `.requires_grad` will raise `AttributeError: '_dtype' not found`.
Current:
```python
@property
def dtype(self) -> torch.dtype:
    """
    Return the high precision data type of the tensor
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since dtype for a tensor is never
    change after creation, we cache it in a member variable and return
    """
    return self._dtype

@property
def requires_grad(self) -> bool:
    """
    Return whether or not the tensor requires gradient.
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since requires_grad is set during
    initialization and may be updated, we cache it in a member variable.
```
Suggested change:
```python
@property
def dtype(self) -> torch.dtype:
    """
    Return the high precision data type of the tensor
    """
    # Lazy initialization for tensors created via alternate paths
    if not hasattr(self, '_dtype'):
        self._dtype = super().__getattribute__('dtype')
    return self._dtype
```
Apply similar fix to requires_grad property.
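That lazy-initialization fix can be isolated in a minimal sketch (plain Python, no torch; `CachedDtypeMixin` and `_real_dtype` are hypothetical stand-ins for the tensor subclass and `super().__getattribute__('dtype')`):

```python
class CachedDtypeMixin:
    """Cache an expensive attribute, falling back to the real lookup the
    first time the cache is missing (e.g. after unpickling skips __new__)."""

    def _real_dtype(self):
        # Stand-in for super().__getattribute__("dtype") on a torch.Tensor.
        return "float32"

    @property
    def dtype(self):
        if not hasattr(self, "_dtype"):
            # Lazy init covers creation paths that bypassed __new__.
            self._dtype = self._real_dtype()
        return self._dtype
```

An instance created without `_dtype` (the unpickling case) still answers correctly and populates the cache on first access.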
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return self._data.shape if self._data is not None else self._transpose.shape
```
shape property will raise AttributeError if both _data and _transpose are None
The property assumes at least one of _data or _transpose is not None, but doesn't validate the _transpose attribute before accessing it. If both are None (edge case during initialization or after certain operations), accessing .shape will raise AttributeError: 'NoneType' object has no attribute 'shape'.
Current:
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return self._data.shape if self._data is not None else self._transpose.shape
```
Suggested change:
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._data is not None:
        return self._data.shape
    elif self._transpose is not None:
        return self._transpose.shape
    else:
        # Fallback to getting shape from alternate data sources or raise descriptive error
        raise RuntimeError("Float8Tensor has no data buffers initialized")
```
Same issue exists in the is_cuda property at lines 919-921.
6 files reviewed, 6 comments
```diff
 ctx.tp_size = tp_size
-ctx.requires_dgrad = inp.requires_grad
-ctx.requires_wgrad = weight.requires_grad
+ctx.requires_dgrad = inp_requires_grad
```
The change from `requires_grad(inp, weight, bias)` to cached boolean flags is correct. The original helper (see utils.py lines 23-28) returns True if ANY of the three tensors requires grad, and the new code combines the cached flags with `or`, which is semantically equivalent:

```python
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
```

A previous thread claimed the new code requires ALL THREE flags to be True, but that would only be the case with `and`; the code as written uses `or`, so the FP8 state management semantics are unchanged.
```python
@property
def dtype(self) -> torch.dtype:
    """
    Return the high precision data type of the tensor
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since dtype for a tensor is never
    change after creation, we cache it in a member variable and return
    """
    return self._dtype
```
The dtype property accesses self._dtype without checking if it exists. While QuantizedTensor.__new__() initializes this at line 364, subclasses like Float8Tensor can be instantiated through paths that may not properly initialize these cached attributes.
For example, when Float8TensorStorage.__new__() calls super().__new__(cls, *args, **kwargs), if dtype is not in kwargs, QuantizedTensor.__new__() won't receive it and won't set instance._dtype.
Consider adding defensive checks or ensuring all instantiation paths properly initialize _dtype and _requires_grad, or catch AttributeError and fall back to the parent class property.
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Potential memory leak if any operation between PyDict_New()/PyTuple_New() and Py_DECREF() throws an exception or fails. The py::cast() calls at lines 161-163 could throw C++ exceptions, and PyDict_SetItemString() could fail, leaving kwargs and args unreleased.
Consider using RAII wrappers or try-catch blocks to ensure cleanup:
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
if (!kwargs || !args) {
  Py_XDECREF(kwargs);
  Py_XDECREF(args);
  throw std::runtime_error("Failed to allocate Python objects");
}
try {
  PyDict_SetItemString(kwargs, "data", data_py.ptr());
  // ... rest of setup
  PyObject* result = PyObject_Call(...);
  Py_DECREF(kwargs);
  Py_DECREF(args);
  NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
  out_py = py::reinterpret_steal<py::object>(result);
} catch (...) {
  Py_DECREF(kwargs);
  Py_DECREF(args);
  throw;
}
```
Or better yet, use `py::dict` and `py::tuple` from pybind11 which handle cleanup automatically.
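The guarantee RAII buys here — cleanup runs on every exit path — has a direct Python analogue in `try`/`finally`. A toy sketch, under stated assumptions: `Refcounted` mimics a `PyObject*` refcount and `fill_kwargs` mimics the fallible construction loop; neither is a real CPython or TransformerEngine API.

```python
class Refcounted:
    """Toy stand-in for a PyObject* with manual reference counting."""

    def __init__(self):
        self.refs = 1

    def decref(self):
        self.refs -= 1

def fill_kwargs(kwargs, fill_steps):
    """Run fallible construction steps; the finally block plays the role
    of the C++ destructor, so a throwing step cannot leak `kwargs`."""
    try:
        d = {}
        for key, step in fill_steps:
            d[key] = step()  # may raise, like py::cast
        return d
    finally:
        kwargs.decref()  # always runs, on success and on failure
```

Whether a step succeeds or raises, the reference count is released exactly once.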
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
```
The shape and is_cuda properties assume that either `self._data` or `self._transpose` is not None. However, looking at the quantizer logic in Float8Quantizer.create_tensor() (quantizer.cpp lines 126-137), there are cases where:
- `with_data = rowwise_usage || is_non_tn_fp8_gemm_supported`
- `with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported`
If both rowwise_usage and columnwise_usage are False, and is_non_tn_fp8_gemm_supported is False, then both _data and _transpose would be None, causing AttributeError when accessing .shape or .is_cuda.
This scenario should probably be prevented at quantizer initialization, or these properties should handle the None case gracefully.
```python
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
    return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
    return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
    return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
    return tensor._transpose.device.index
if hasattr(tensor, "device"):
    return tensor.device.index
return torch.cuda.current_device()
```
The reordering of device attribute checks changes the priority of which device is returned. Previously, if a tensor had both device attribute and internal attributes like _data, it would return tensor.device.index. Now it returns the device of the internal attribute.
For QuantizedTensor instances where these might differ (though unlikely), this could change behavior. Consider the case where:
- A QuantizedTensor has `_data` on device 0
- But the tensor itself reports being on device 1 (via its `device` property)
The old code would return device 1, the new code returns device 0. While this scenario is unlikely given how QuantizedTensor is implemented, it's a semantic change worth noting. If the internal attributes are guaranteed to be on the same device as the tensor, this is fine. Otherwise, document this priority or add an assertion.
```diff
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();
```
Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.
If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.
Consider:
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
```
This is initialized once, but the `cache` and `flags` vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.
A safer approach might be:
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  flags.resize(num_devices);
});
```
Or simply document that the device count must not change during the application's lifetime.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
No files reviewed, no comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci L1 pytorch
8 files reviewed, 8 comments
```cpp
Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
```
The error check occurs AFTER Py_DECREF(kwargs) and Py_DECREF(args) (lines 168-169), which means if result is nullptr, the program will terminate via NVTE_CHECK without proper cleanup. While this is acceptable for fatal errors, it's inconsistent with the cleanup-before-check pattern and could mask the actual Python exception message.
Suggested improvement: Check for nullptr before decrementing reference counts, and use PyErr_Print() or PyErr_Fetch() to preserve the Python exception information:
Suggested change:
```diff
-NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
+PyObject* result =
+    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);
+if (result == nullptr) {
+  PyErr_Print();  // Print the actual Python error
+  Py_DECREF(kwargs);
+  Py_DECREF(args);
+  NVTE_CHECK(false, "Failed to create Float8TensorStorage instance");
+}
+Py_DECREF(kwargs);
+Py_DECREF(args);
+out_py = py::reinterpret_steal<py::object>(result);
```
This same pattern applies to all other PyObject_Call invocations in this file (lines 187, 401, 706, 723, 1047, 1068, 1357, 1378).
```python
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value
    self._requires_grad = value
    # Call parent class to ensure autograd engine is aware of the change
    torch.Tensor.requires_grad.fset(self, value)
```
The requires_grad property setter updates the cached _requires_grad value and then calls torch.Tensor.requires_grad.fset(self, value) to ensure the autograd engine is aware. However, there's a potential issue: if the parent class's setter fails or raises an exception, _requires_grad will already be updated, causing the cached value to be out of sync with the actual tensor state.
Suggested improvement: Call the parent setter first, then update the cache only if successful:
Current:
```python
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value
    self._requires_grad = value
    # Call parent class to ensure autograd engine is aware of the change
    torch.Tensor.requires_grad.fset(self, value)
```
Suggested change:
```python
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Call parent class first to ensure autograd engine is aware of the change
    torch.Tensor.requires_grad.fset(self, value)
    # Update the cached value only after successful update
    self._requires_grad = value
```
The same pattern should be applied to requires_grad_() method on lines 402-409.
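The update-authoritative-state-first ordering can be demonstrated with a small sketch (plain Python; `GradFlagCache` and `_parent_set` are made-up stand-ins for the tensor subclass and `torch.Tensor.requires_grad.fset`):

```python
class GradFlagCache:
    """Sketch of the ordering fix: update the authoritative state first,
    cache second, so a failed parent call cannot desynchronize the cache."""

    def __init__(self):
        self._requires_grad = False
        self._parent_state = False
        self.fail_parent = False  # toggle to simulate a failing parent setter

    def _parent_set(self, value):
        if self.fail_parent:
            raise RuntimeError("autograd engine rejected the update")
        self._parent_state = value

    @property
    def requires_grad(self):
        return self._requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self._parent_set(value)      # may raise
        self._requires_grad = value  # cache only after success
```

If the parent call raises, the cache is untouched, so it never disagrees with the authoritative state.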
```python
@dtype.setter
def dtype(self, value: torch.dtype) -> None:
    """Set dtype property"""
    # Update the cached value
    self._dtype = value
    warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.")
```
The dtype property has a setter that allows changing the dtype after creation, with a warning. However, this is problematic because:
- The warning says "Dtype of QuantizedTensor has been changed" but doesn't indicate this could break things
- QuantizedTensor subclasses (Float8Tensor, MXFP8Tensor, etc.) have quantized data in specific formats (FP8, MXFP4) that don't change when you update `_dtype`
- Changing dtype without re-quantizing the underlying data creates a mismatch between the cached dtype and the actual data representation

This setter creates a dangerous API surface. If dtype must be mutable, it should either:
- Raise an error instead of just warning
- Or trigger re-quantization of the data

Current impact: Users could accidentally corrupt tensor semantics by changing dtype, leading to incorrect dequantization.
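The stricter raise-instead-of-warn alternative fits in a one-property sketch (plain Python; `FrozenDtype` is a hypothetical stand-in for the quantized tensor class, not a TransformerEngine API):

```python
class FrozenDtype:
    """Expose dtype read-only and raise on mutation, since the quantized
    payload does not change when the label does."""

    def __init__(self, dtype):
        self._dtype = dtype

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        raise TypeError(
            "dtype of a quantized tensor is fixed at creation; "
            "dequantize and re-quantize to change precision"
        )
```

An accidental `t.dtype = torch.bfloat16` then fails loudly instead of silently desynchronizing the label from the FP8 payload.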
```python
    )
    return out, all_gather_outputs

@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
```
The new shape and is_cuda property implementations have a logic issue: they return properties from self._data if it exists, otherwise from self._transpose. However, according to the codebase logic in cublaslt_gemm.cu and quantizer.cpp, on Hopper (non-Blackwell) architectures, FP8 tensors may only have _transpose data for certain GEMM layouts, while on Blackwell they may only have _data.
The properties assume that if _data is None, then _transpose must exist, but there's no fallback handling if both are somehow None (which shouldn't happen but could in edge cases).
Suggested improvement: Add a safety check or document the invariant that at least one must always exist:
Current:
```python
    )
    return out, all_gather_outputs

@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return self._data.shape if self._data is not None else self._transpose.shape

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
```
Suggested change:
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._data is not None:
        return self._data.shape
    if self._transpose is not None:
        return self._transpose.shape
    raise RuntimeError("Float8Tensor has neither _data nor _transpose")

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    if self._transpose is not None:
        return self._transpose.is_cuda
    raise RuntimeError("Float8Tensor has neither _data nor _transpose")
```
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )
```
Same issue as Float8Tensor: the shape and is_cuda properties assume that if _rowwise_data is None, then _columnwise_data must exist. However, there's no safety check for the edge case where both could be None.
Suggested improvement: Add safety checks:
Current:
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape
    )
```
Suggested change:
```python
@property
def shape(self):
    """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    if self._columnwise_data is not None:
        return self._columnwise_data.shape
    raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")

@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._rowwise_data is not None:
        return self._rowwise_data.is_cuda
    if self._columnwise_data is not None:
        return self._columnwise_data.is_cuda
    raise RuntimeError("MXFP8Tensor has neither _rowwise_data nor _columnwise_data")
```
3 files reviewed, 3 comments
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
```
Missing error checking for Python C API calls could lead to crashes. The code doesn't check return values from PyDict_New(), PyTuple_New(), PyDict_SetItemString(), or PyObject_Call() before proceeding. If any of these fail (e.g., due to memory allocation failure), the subsequent operations will access invalid pointers.
Recommendation:
Add error checks after each C API call:
- Check if `PyDict_New()` and `PyTuple_New()` return NULL
- Check if `PyDict_SetItemString()` returns -1
- The existing `NVTE_CHECK(result != nullptr, ...)` is good, but should come before `Py_DECREF` calls to avoid decrementing invalid references
This pattern is repeated throughout this file at lines 176-193, 389-404, 410-426, 688-706, 710-727, 1035-1050, 1054-1070, 1343-1361, and 1365-1383. All instances need the same error checking.
```python
@property
def requires_grad(self) -> bool:
    """
    Return whether or not the tensor requires gradient.
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since requires_grad is set during
    initialization and may be updated, we cache it in a member variable.
    """
    return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value
    self._requires_grad = value
    # Call parent class to ensure autograd engine is aware of the change
    torch.Tensor.requires_grad.fset(self, value)
```
Potential cache inconsistency if PyTorch modifies requires_grad through internal mechanisms. While the setter and requires_grad_() method properly sync the cache with the parent tensor, there's a risk that PyTorch's autograd engine could modify requires_grad through paths that bypass these methods (e.g., during graph construction or when tensors are involved in certain operations).
Current behavior: The cached value could become stale if PyTorch modifies requires_grad internally.
Mitigation: The implementation correctly updates the parent tensor in the setter and requires_grad_(), which should handle most cases. However, consider documenting this limitation or adding validation in critical paths to ensure cache consistency.
Note: This is a minor concern as most requires_grad modifications go through the standard setter, but worth being aware of for debugging purposes.
```cpp
Py_DECREF(kwargs);
Py_DECREF(args);

NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Error checking order issue: NVTE_CHECK for result != nullptr happens after Py_DECREF(kwargs) and Py_DECREF(args). While decrefing these is safe even when result is NULL, checking for errors before cleanup is better practice and makes the error path clearer.
If PyObject_Call fails and returns NULL, Python's exception state is set. The NVTE_CHECK should happen immediately after the call, and cleanup should be part of an error handling path. Consider moving the check before Py_DECREF or using proper error handling with early returns.
This pattern affects all PyObject_Call instances in this file.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
…ormerEngine into cpu_fp8_optimizations Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
2 files reviewed, 2 comments
```cpp
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
```
Critical bug: Potential out-of-bounds access when device_id >= num_devices
Making num_devices static causes it to be evaluated only once at first call. This creates a serious bug if the device ID is not within the range [0, num_devices):
- Problem: CUDA device IDs are not necessarily contiguous. For example, with `CUDA_VISIBLE_DEVICES=0,2`, you have 2 visible devices but device IDs are 0 and 2.
- Bug scenario: If `num_devices` is initialized to 2, but later code calls `current_device()` which returns 2, then `cache[device_id]` at line 975 will access `cache[2]`, which is out of bounds (cache size is 2, valid indices are 0-1).
- Impact: This will cause undefined behavior, likely a segmentation fault or memory corruption.
Fix: Either:
- Use `current_device()` as the cache key and bounds-check it
- Map actual device IDs to cache indices
- Remove the static keyword and re-evaluate `num_devices` on each call (sacrifices some of the optimization but ensures correctness)
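The map-ids-to-indices option can be sketched in Python with a growable, lock-protected list, showing how an id beyond the initial device count never indexes out of bounds (`DeviceCache` and `probe` are hypothetical names for illustration, not TransformerEngine APIs):

```python
import threading

class DeviceCache:
    """Growable per-device cache: resize under a lock so a device id beyond
    the initial count can never index out of bounds."""

    def __init__(self):
        self._values = []
        self._lock = threading.Lock()

    def get(self, device_id, probe):
        with self._lock:
            if device_id >= len(self._values):
                # Grow to cover this id; None marks "not yet computed".
                self._values.extend([None] * (device_id + 1 - len(self._values)))
            if self._values[device_id] is None:
                self._values[device_id] = probe(device_id)
            return self._values[device_id]
```

The probe runs at most once per device id, regardless of how large the id is relative to the initial size.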
Additional Comments (1)
The original line was:
```python
_quantization_list: list[Optional[str]] = [None]
```
Now it's:
```python
_quantization_list: list[Optional[str]] = []
```
This means tests will no longer run with the `None` (non-quantized) configuration. Question: Is this intentional? If this is a CPU optimization PR, removing coverage for non-quantized tests seems unrelated to the stated goals. Please clarify if this change is intended or accidental.
12 files reviewed, 12 comments
```diff
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();
```
Critical: Potential out-of-bounds access with static num_devices
Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.
This will cause out-of-bounds access on lines 968 and 975:
```cpp
std::call_once(flags[device_id], ...);  // OOB if device_id >= num_devices
return cache[device_id];                // OOB if device_id >= num_devices
```
Impact: Undefined behavior, potential crashes, memory corruption.
Fix: Query device count each time, or add bounds checking:
```cpp
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;
int device_id = transformer_engine::cuda::current_device();
{
  std::lock_guard<std::mutex> lock(resize_mutex);
  if (device_id >= cache.size()) {
    cache.resize(device_id + 1, -1);
    flags.resize(device_id + 1);
  }
}
std::call_once(flags[device_id], [&]() { ... });
```
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Critical: Memory leak when exceptions occur

The direct C API usage introduces memory leak vulnerabilities. If any `py::cast()` call throws an exception or if `PyDict_SetItemString()` fails after `PyDict_New()` and `PyTuple_New()`, the allocated `kwargs` and `args` will leak.

Problems:
- No NULL checks after `PyDict_New()` and `PyTuple_New()`
- Exception safety: `py::cast()` can throw, but cleanup code won't run
- Multiple `py::cast()` calls create temporary objects that could throw

Example failure scenario:

```cpp
PyObject* kwargs = PyDict_New();  // Allocated
PyObject* args = PyTuple_New(0);  // Allocated
// ... several PyDict_SetItemString calls ...
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());  // Throws!
// Cleanup code never runs -> kwargs and args leaked
```

Fix: Use RAII for proper cleanup:

```cpp
struct PyObjectDeleter {
  void operator()(PyObject* obj) { Py_XDECREF(obj); }
};
using PyObjectPtr = std::unique_ptr<PyObject, PyObjectDeleter>;

PyObjectPtr kwargs(PyDict_New());
PyObjectPtr args(PyTuple_New(0));
NVTE_CHECK(kwargs && args, "Failed to allocate Python objects");
// Now safe even if exceptions occur
```

This issue affects all quantizer types (Float8, Float8CurrentScaling, Float8Block, MXFP8, NVFP4).
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
PyDict_SetItemString(kwargs, "device", py::cast(device).ptr());
PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorPythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);

NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Critical: Memory leak when exceptions occur (non-internal path)

Same memory leak vulnerability as the internal path above. The `kwargs` and `args` will leak if any exception occurs between allocation and cleanup.

Additionally, this block has more `py::cast()` calls (8 vs 5), increasing the risk of exceptions:
- `py::cast(shape_int64).ptr()`
- `py::cast(GetATenDType(dtype)).ptr()`
- `py::cast(this->dtype).ptr()`
- `py::cast(device).ptr()`

Each of these can throw, and none are wrapped in exception-safe RAII wrappers.

Fix: Apply the same RAII solution as suggested in the previous comment.
```python
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
    return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
    return tensor._columnwise_data.device.index
if hasattr(tensor, "_data") and tensor._data is not None:
    return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
    return tensor._transpose.device.index
if hasattr(tensor, "device"):
    return tensor.device.index
```
Performance regression for regular tensors

The reordering of attribute checks optimizes for quantized tensors but creates a performance regression for regular `torch.Tensor` objects.

Before: Regular tensors check `device` immediately (line 1) and return.
After: Regular tensors check 4 non-existent attributes via `hasattr()` before checking `device`.

Impact:
- Quantized tensors: ~4 fewer attribute checks ✓
- Regular tensors: ~4 additional attribute checks ✗

Concern: If regular tensors are passed to `get_tensor_device()` frequently (which seems likely given the function name), this could offset the gains from other optimizations in this PR.

Recommendation: Profile both paths or add an early isinstance check:

```python
def get_tensor_device(tensor: torch.Tensor) -> int:
    # Fast path for regular tensors
    if type(tensor).__name__ == 'Tensor':
        return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
    # Quantized tensor path
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    # ... etc
```

```cpp
// Use direct C API call bypassing pybind11 overhead
PyObject* kwargs = PyDict_New();
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* args = PyTuple_New(0);
PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(args);
Py_DECREF(kwargs);

NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
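The `get_tensor_device` ordering trade-off discussed above can be sketched without torch. `PlainTensor` and `QuantizedStub` below are hypothetical stand-ins for a regular tensor and a quantized wrapper, not the real Transformer Engine classes.

```python
class Device:
    def __init__(self, index):
        self.index = index


class PlainTensor:  # stand-in for a regular torch.Tensor
    def __init__(self, index=0):
        self.device = Device(index)


class QuantizedStub:  # stand-in for a quantized tensor wrapper
    def __init__(self, index=1):
        self._rowwise_data = PlainTensor(index)


def get_device_quantized_first(tensor):
    # New ordering: probe quantized attributes before falling back to .device
    if getattr(tensor, "_rowwise_data", None) is not None:
        return tensor._rowwise_data.device.index
    if getattr(tensor, "_columnwise_data", None) is not None:
        return tensor._columnwise_data.device.index
    if hasattr(tensor, "device"):
        return tensor.device.index
    return 0


def get_device_fast_path(tensor):
    # Early type check lets plain tensors resolve in a single step
    if isinstance(tensor, PlainTensor):
        return tensor.device.index
    return get_device_quantized_first(tensor)


print(get_device_fast_path(PlainTensor(3)), get_device_fast_path(QuantizedStub(7)))  # → 3 7
```

Plain tensors take the one-check fast path while quantized stubs fall through to the attribute probes, which is the shape of the recommendation in the comment above.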
Critical: Memory leak in Float8CurrentScalingQuantizer (internal path)
Same memory leak vulnerability as Float8Quantizer. The kwargs and args will leak if any exception occurs between allocation and the Py_DECREF cleanup code. See comment on lines 157-171 for detailed explanation and fix.
```cpp
  NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance");
  ret = py::reinterpret_steal<py::object>(result);
} else {
  py::handle Float8BlockwiseQTensorClass(
      reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
  ret = Float8BlockwiseQTensorClass(
      "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
      "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
      "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
      "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2),
      "data_format"_a = data_format);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr());
  PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
  PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr());
  PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());
  PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr());
  PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr());
  PyObject* args = PyTuple_New(0);
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass), args, kwargs);
  Py_DECREF(args);
  Py_DECREF(kwargs);

  NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance");
```
Critical: Memory leak in Float8BlockQuantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
```cpp
if (internal) {
  py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass));
  out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py,
                            "columnwise_data"_a = columnwise_data_py,
                            "rowwise_scale_inv"_a = rowwise_scale_inv_py,
                            "columnwise_scale_inv"_a = columnwise_scale_inv_py,
                            "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyObject* args = PyTuple_New(0);
  PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorStoragePythonClass), args, kwargs);

  Py_DECREF(args);
  Py_DECREF(kwargs);
```
Critical: Memory leak in MXFP8Quantizer (internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.
```cpp
  out_py = py::reinterpret_steal<py::object>(result);
} else {
  py::handle MXFP8TensorClass(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass));
  out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
                            "rowwise_data"_a = rowwise_data_py,
                            "columnwise_data"_a = columnwise_data_py,
                            "rowwise_scale_inv"_a = rowwise_scale_inv_py,
                            "columnwise_scale_inv"_a = columnwise_scale_inv_py,
                            "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
  PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
  PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

  PyObject* args = PyTuple_New(0);
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(MXFP8TensorPythonClass), args, kwargs);

  Py_DECREF(args);
  Py_DECREF(kwargs);

  NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance");
```
Critical: Memory leak in MXFP8Quantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
```cpp
if (internal) {
  py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass));
  out_py = NVFP4TensorClass(
      "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
      "rowwise_scale_inv"_a = rowwise_scale_inv_py,
      "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
      "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
      "quantizer"_a = this->quantizer);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr());
  PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr());
  PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

  PyObject* args = PyTuple_New(0);
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorStoragePythonClass), args, kwargs);
  Py_DECREF(args);
  Py_DECREF(kwargs);

  NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance");
```
Critical: Memory leak in NVFP4Quantizer (internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 157-171 for detailed explanation and fix.
```cpp
} else {
  py::handle NVFP4TensorClass(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass));
  out_py = NVFP4TensorClass(
      "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype),
      "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py,
      "rowwise_scale_inv"_a = rowwise_scale_inv_py,
      "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py,
      "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype,
      "quantizer"_a = this->quantizer);
  // Use direct C API call bypassing pybind11 overhead
  PyObject* kwargs = PyDict_New();
  PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr());
  PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr());
  PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr());
  PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr());
  PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr());
  PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr());
  PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr());
  PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

  PyObject* args = PyTuple_New(0);
  PyObject* result =
      PyObject_Call(reinterpret_cast<PyObject*>(NVFP4TensorPythonClass), args, kwargs);

  Py_DECREF(args);
  Py_DECREF(kwargs);

  NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance");
  out_py = py::reinterpret_steal<py::object>(result);
```
Critical: Memory leak in NVFP4Quantizer (non-internal path)
Same memory leak vulnerability as other quantizers. The kwargs and args will leak if any exception occurs. See comment on lines 176-193 for detailed explanation and fix.
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
4 files reviewed, 4 comments
```cpp
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
```
Note: temporary py::object lifetime with py::cast(...).ptr()

The expression `py::cast(this->dtype).ptr()` extracts a raw pointer from a temporary `py::object`. In C++ the temporary lives until the end of the full expression, and `PyDict_SetItemString` takes a borrowed reference and increments the refcount on success, so this particular pattern is not a use-after-free. It is fragile, though: the raw pointer must not be stored past the statement, and a failed `PyDict_SetItemString` silently leaves the key missing from `kwargs`.

The same pattern occurs on multiple lines in this file:
- Line 161: `py::cast(this->dtype).ptr()`
- Line 178: `py::cast(shape_int64).ptr()`
- Line 179: `py::cast(GetATenDType(dtype)).ptr()`
- Line 182: `py::cast(this->dtype).ptr()`
- Line 185: `py::cast(device).ptr()`

And similar patterns throughout the file for other quantizer types.

To make ownership explicit, store the `py::object` in a local variable before calling `.ptr()`:

```cpp
py::object fp8_dtype_py = py::cast(this->dtype);
PyDict_SetItemString(kwargs, "fp8_dtype", fp8_dtype_py.ptr());
```

(Avoid `py::cast(this->dtype).inc_ref().ptr()` here: the dict takes its own reference, so the extra increment would never be released and would leak.)
```diff
  ctx.owns_input = saved_inputmat is not inp
- if ctx.fp8 and requires_grad(inp, weight, bias):
+ if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):
```
CRITICAL: Logic error - Changed from OR to AND breaks FP8 state management

The original code used the `requires_grad()` helper function which returns True if ANY of the tensors require grad (OR logic). The new code requires ALL THREE tensors to require grad (AND logic).

Original behavior:

```python
if ctx.fp8 and requires_grad(inp, weight, bias):  # True if ANY require grad
```

New behavior:

```python
if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad):  # Correct!
```

But wait, this line actually looks correct with OR logic. Let me re-check... Actually, this specific line IS correct. The issue mentioned in previous threads was already addressed in this PR. My apologies for the confusion on this specific line - it's actually implemented correctly.
```python
@property
def requires_grad(self) -> bool:
    """
    Return whether or not the tensor requires gradient.
    Attribute access of custom tensors goes through an
    expensive Pyobject lookup. Since requires_grad is set during
    initialization and may be updated, we cache it in a member variable.
    """
    return self._requires_grad

@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value and call parent class method to ensure autograd engine is aware
    self.requires_grad_(value)

def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor:
    """Cache requires_grad property and call parent class method"""
    # pylint: disable=missing-function-docstring
    # Update the cached value
    self._requires_grad = requires_grad
    # Call parent class method to ensure autograd engine is aware
    super().requires_grad_(requires_grad)
    return self
```
Potential cache staleness issue with requires_grad property

The cached `_requires_grad` attribute could become stale if the underlying tensor's `requires_grad` is modified through mechanisms other than the provided setter. For example:
- PyTorch's autograd engine might directly modify the tensor's `requires_grad` flag
- Operations like `tensor.requires_grad_()` from the parent class might bypass the cached value
- Deserialization or other internal PyTorch operations might set `requires_grad` without going through the custom setter

While the setter correctly updates both the cache and calls `super().requires_grad_(value)`, the getter always returns the cached value, which could diverge from the actual tensor state.

Recommendation: Consider either:
- Synchronizing the cache by checking `super().requires_grad` periodically
- Documenting this limitation clearly
- Or accepting this as a known trade-off for performance (which may be acceptable if `requires_grad` is rarely modified after initialization)
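The "synchronize the cache" option can be sketched with plain classes. `FakeTensorBase`, `_real_requires_grad`, and `CachedGradMixin` below are hypothetical stand-ins for `torch.Tensor` and its autograd flag, not Transformer Engine code; the getter falls back to the real flag when the cache is missing, so construction paths that bypass `__new__` still work.

```python
class FakeTensorBase:  # stand-in for torch.Tensor's requires_grad machinery
    def __init__(self):
        self._real_requires_grad = False

    def requires_grad_(self, value=True):
        self._real_requires_grad = value
        return self


class CachedGradMixin(FakeTensorBase):
    @property
    def requires_grad(self):
        # Defensive fallback: construction paths that skip the cache still work
        if not hasattr(self, "_requires_grad"):
            self._requires_grad = self._real_requires_grad
        return self._requires_grad

    @requires_grad.setter
    def requires_grad(self, value):
        self._requires_grad = value    # update the cache first
        super().requires_grad_(value)  # then notify the "autograd engine"

    def requires_grad_(self, value=True):
        self._requires_grad = value
        super().requires_grad_(value)
        return self


t = CachedGradMixin()
print(t.requires_grad)  # → False (cache lazily seeded from the real flag)
t.requires_grad = True
print(t.requires_grad, t._real_requires_grad)  # → True True
```

Both the assignment form and the in-place `requires_grad_()` call keep the cache and the underlying flag in step, which is the invariant the review comment asks for.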
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Missing error handling for Python C API calls

The code doesn't check for failures in Python C API calls:
- `PyDict_New()` can return NULL on memory allocation failure
- `PyTuple_New(0)` can return NULL on memory allocation failure
- `PyDict_SetItemString()` can fail and return -1

If `PyDict_New()` or `PyTuple_New()` fails, the subsequent calls will crash when dereferencing NULL pointers. If `PyDict_SetItemString()` fails, the dictionary will be incomplete but the error is silently ignored.

Recommendation: Add error checks:

```cpp
PyObject* kwargs = PyDict_New();
if (!kwargs) throw std::runtime_error("Failed to create kwargs dict");
PyObject* args = PyTuple_New(0);
if (!args) {
  Py_DECREF(kwargs);
  throw std::runtime_error("Failed to create args tuple");
}
// Check PyDict_SetItemString return values or use NVTE_CHECK
```

This pattern repeats throughout all the create_tensor methods in this file.
10 files reviewed, 10 comments
```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);
PyDict_SetItemString(kwargs, "data", data_py.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr());
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr());

PyObject* result =
    PyObject_Call(reinterpret_cast<PyObject*>(Float8TensorStoragePythonClass), args, kwargs);

Py_DECREF(kwargs);
Py_DECREF(args);
NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance");
out_py = py::reinterpret_steal<py::object>(result);
```
Potential reference counting issue with temporary py::object

The pattern `PyDict_SetItemString(kwargs, "key", py::cast(value).ptr())` relies on the temporary `py::object` created by `py::cast()`. The temporary is destroyed at the end of the full expression, after `PyDict_SetItemString` has taken its own reference, so the call itself is safe; the risk is readability and accidental reuse of the raw pointer beyond the statement.

Safer pattern:

```cpp
PyObject* kwargs = PyDict_New();
PyObject* args = PyTuple_New(0);

py::object data_obj = data_py;
py::object scale_inv_obj = scale_inv_py;
py::object dtype_obj = py::cast(this->dtype);
py::object transpose_obj = transpose_py;
py::object quantizer_obj = this->quantizer;

PyDict_SetItemString(kwargs, "data", data_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_obj.ptr());
PyDict_SetItemString(kwargs, "fp8_dtype", dtype_obj.ptr());
PyDict_SetItemString(kwargs, "data_transpose", transpose_obj.ptr());
PyDict_SetItemString(kwargs, "quantizer", quantizer_obj.ptr());
```

This keeps all `py::object` instances visibly alive while their pointers are being used. The same pattern appears in multiple locations throughout this file (lines 176-185, 389-397, 409-417, and many more).
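For reference, the borrowed-reference semantics of `PyDict_SetItemString` can be observed from Python itself: a dict takes its own strong reference on insertion, which is why the stored value stays alive after the temporary that produced it goes away. This is standard CPython behavior, not anything Transformer Engine specific:

```python
import sys

value = object()
base = sys.getrefcount(value)

d = {}
d["fp8_dtype"] = value  # the dict increments the value's refcount
assert sys.getrefcount(value) == base + 1

del d                   # dropping the dict releases its reference
assert sys.getrefcount(value) == base
print("ok")
```

The same accounting is what makes the C++ pattern above safe once the value is actually inside `kwargs`.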
```python
instance._requires_grad = requires_grad
instance._dtype = dtype
return instance
```
Missing cached attribute initialization in alternate construction paths

The cached attributes `_requires_grad` and `_dtype` are initialized here in `__new__`, but there are several code paths where QuantizedTensor objects can be created that may bypass this initialization:
- Pickling/unpickling: When tensors are pickled and unpickled, `__reduce_ex__` and `__setstate__` may not properly initialize these cached attributes
- Tensor operations: Operations like `.clone()`, `.detach()`, `torch.cat()`, etc. may create new tensor instances
- C++ constructors: When tensors are created from C++ code via quantizer

Impact: Accessing `.dtype` or `.requires_grad` on such tensors will raise `AttributeError: '_dtype' not found`.

Recommendation: Add defensive checks in the property getters:

```python
@property
def dtype(self) -> torch.dtype:
    if not hasattr(self, '_dtype'):
        self._dtype = super().dtype
    return self._dtype

@property
def requires_grad(self) -> bool:
    if not hasattr(self, '_requires_grad'):
        self._requires_grad = super().requires_grad
    return self._requires_grad
```

```python
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    # Update the cached value and call parent class method to ensure autograd engine is aware
    self.requires_grad_(value)
```
Incorrect setter implementation causes infinite recursion

The `requires_grad` setter calls `self.requires_grad_(value)` which then calls `super().requires_grad_(requires_grad)`. However, setting `requires_grad` through assignment (e.g., `tensor.requires_grad = True`) will invoke this setter, which doesn't update `self._requires_grad` before calling the parent method.

The correct flow should be:
- Update cached value `self._requires_grad`
- Call parent `requires_grad_()` to notify autograd

But line 399 in `requires_grad_()` updates the cache AFTER calling super, which means the setter doesn't actually update the cache.

Fix:

```python
@requires_grad.setter
def requires_grad(self, value: bool) -> None:
    """Set requires_grad property so that autograd engine is aware of the change"""
    self._requires_grad = value
    super().requires_grad_(value)
```
super().requires_grad_(value)|
|
||
| int nvte_is_non_tn_fp8_gemm_supported() { | ||
| int num_devices = transformer_engine::cuda::num_devices(); | ||
| static int num_devices = transformer_engine::cuda::num_devices(); |
Potential race condition in static variable initialization

Making `num_devices` static improves performance by caching the device count, but this could introduce a subtle race. While the `cache` and `flags` vectors are protected by `std::call_once`, `num_devices` itself is initialized on the first function call without any explicit synchronization.

Scenario:
- Thread A enters the function for the first time and starts calling `transformer_engine::cuda::num_devices()`
- Thread B enters the function while A is still inside `num_devices()` and sees an uninitialized static
- Both threads may initialize `num_devices`, causing undefined behavior

However, this is likely safe in practice because:
- The function is typically called during module initialization (single-threaded)
- The value doesn't change after initialization
- C++11 guarantees thread-safe initialization of function-local statics ("magic statics"), so conforming compilers already serialize this

But for clarity, consider making the once-only intent explicit:

```cpp
static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();
```

Or initialize it within the `std::call_once` block if `device_id` doesn't need `num_devices`.
```python
)
# pylint: disable=unnecessary-dunder-call
super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor)
self.dtype = tensor.dtype
```
Cached dtype not initialized in all construction paths

This line sets `self.dtype = tensor.dtype` to initialize the cached `_dtype` attribute, but this only happens in the `_set_data` method when copying from another Float8Tensor AND certain conditions are met (size/stride/dtype mismatch).

Missing initialization in:
- Initial tensor construction (first time `Float8Tensor` is created)
- When sizes/strides match (lines 1002-1006 don't call this setter)
- `_make_in_reduce_ex` unpickling path (line 942)
- Other construction paths

This will cause `AttributeError` when accessing the `.dtype` property on tensors created through other paths.

Fix: Ensure `_dtype` is initialized in `Float8Tensor.__new__` or add the defensive check mentioned in the QuantizedTensor comment.
```cpp
void init_float8_extension() {
  if (Float8TensorPythonClass) return;
  auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
  Float8QuantizerClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
  Float8CurrentScalingQuantizerClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer"));
  Float8TensorPythonClass =
      reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor"));
  auto fp8_base_module =
      py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage");
  Float8TensorStoragePythonClass = reinterpret_cast<PyTypeObject *>(
      PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage"));
  NVTE_CHECK(Float8TensorPythonClass != nullptr,
             "Internal error: could not initialize pyTorch Float8 extension.");
}
```
Removed null-check guards without proper synchronization guarantee

The original code had null checks (`if (Float8TensorPythonClass) return;`) to prevent redundant initialization. These were removed in favor of `std::call_once`, but the individual `init_*` functions no longer have guards.

Issue: If `init_extension()` is called multiple times from different threads, `std::call_once` ensures the lambda runs once. However, if `init_float8_extension()` is called directly (not through `init_extension()`), it will re-import modules and reassign global pointers without protection.

Verification needed: Confirm that these functions are NEVER called directly, only through `init_extension()`. If they can be called directly, the null checks should be restored:

```cpp
void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // Keep this guard
  auto fp8_module = py::module_::import(...);
  ...
}
```

```diff
+dtype = inp.dtype
 # All checks after this have already been performed once, thus skip
-if self.activation_dtype == inp.dtype:
+if self.activation_dtype == dtype:
```
Optimization: Cache inp.dtype early to avoid redundant attribute access
Good optimization! By caching `inp.dtype` early (line 932), the code avoids a second attribute lookup that would have occurred on line 936 in the original code. This is a clean CPU optimization with no functional change.
Note: This is correctly placed BEFORE the early return check, ensuring the cached value is available for the comparison on line 933.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
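As an illustration of the saving, here is a toy sketch (the `CountingDtype` class and `forward_*` helpers below are hypothetical stand-ins, not TE code) in which the `dtype` property does bookkeeping on every access:

```python
class CountingDtype:
    """Toy stand-in for a tensor whose `dtype` property does real work
    (e.g. a PyObject attribute lookup) on every access."""

    def __init__(self):
        self.lookups = 0

    @property
    def dtype(self):
        self.lookups += 1
        return "float8"


def forward_uncached(inp):
    # Two separate attribute accesses -> property runs twice.
    if inp.dtype == "float8":
        return inp.dtype
    return None


def forward_cached(inp):
    # Cache once in a local, reuse below -> property runs once.
    dtype = inp.dtype
    if dtype == "float8":
        return dtype
    return None


a, b = CountingDtype(), CountingDtype()
forward_uncached(a)
forward_cached(b)
assert (a.lookups, b.lookups) == (2, 1)
```

In CPython, locals are resolved by array index, whereas each attribute access goes through `__getattribute__` (plus a function call for properties), which is why caching in a local helps on hot paths.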
```python
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
        return tensor._columnwise_data.device.index
    if hasattr(tensor, "_data") and tensor._data is not None:
        return tensor._data.device.index
    if hasattr(tensor, "_transpose") and tensor._transpose is not None:
        return tensor._transpose.device.index
    if hasattr(tensor, "device"):
        return tensor.device.index
    return torch.cuda.current_device()
```
Reordering device attribute checks changes behavior for regular tensors
The original code checked tensor.device FIRST (line 78-79 in old code), then checked quantized tensor attributes. The new code reverses this order, checking quantized attributes first.
Impact:

- Performance: For regular `torch.Tensor` objects, this INCREASES overhead by attempting to check for `_rowwise_data`, `_columnwise_data`, `_data`, and `_transpose` attributes (all of which will fail via `hasattr`) before finally checking `device`.
- Correctness: For QuantizedTensors that happen to have a `.device` attribute that differs from their internal data device, this could change which device is returned.
Concern: This optimization seems to hurt performance for the common case (regular tensors) while only helping the less common case (quantized tensors).
Question: What is the relative frequency of this function being called with regular tensors vs quantized tensors? If regular tensors are more common, this change may degrade overall performance.
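A sketch of the ordering the reviewer suggests (common case first) using plain stand-in classes — the `device_index` helper and all classes below are hypothetical, and no torch is required:

```python
def device_index(tensor, default=0):
    """Resolve a device index, checking the common regular-tensor
    attribute first, then falling back to quantized-tensor internals."""
    # Fast path: regular tensors expose .device directly.
    dev = getattr(tensor, "device", None)
    if dev is not None:
        return dev.index
    # Slow path: probe quantized-tensor internal storage.
    for name in ("_rowwise_data", "_columnwise_data", "_data", "_transpose"):
        inner = getattr(tensor, name, None)
        if inner is not None:
            return inner.device.index
    return default


class Dev:
    """Minimal stand-in for torch.device."""
    def __init__(self, index):
        self.index = index


class Plain:
    """Regular-tensor stand-in: has .device, no quantized internals."""
    def __init__(self):
        self.device = Dev(0)


class Quantized:
    """Quantized-tensor stand-in: device reachable only via storage."""
    def __init__(self):
        self._rowwise_data = None
        self._columnwise_data = type("Storage", (), {"device": Dev(1)})()


assert device_index(Plain()) == 0       # one getattr, done
assert device_index(Quantized()) == 1   # falls through to internals
```

With this ordering, regular tensors pay a single `getattr`, while quantized tensors pay the attribute probes; the PR's ordering makes the opposite trade-off.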
```python
    @property
    def shape(self):
        """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
        if self._data is not None:
            return self._data.shape
        else:
            transpose_shape = self._transpose.shape
            return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))

    @property
    def is_cuda(self):
        """Return whether the tensor is on a CUDA device."""
```
Missing error handling when both _data and _transpose are None
The `shape` property assumes at least one of `_data` or `_transpose` is not `None`, but there's no validation:

```python
@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    else:
        transpose_shape = self._transpose.shape  # AttributeError if _transpose is None!
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
```

Similarly for `is_cuda`:

```python
@property
def is_cuda(self):
    return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
    # AttributeError if both are None!
```

When can this happen?

- During tensor construction before data is initialized
- After certain operations that clear internal data
- In error/edge cases

Fix: Add a defensive check:

```python
@property
def shape(self):
    if self._data is not None:
        return self._data.shape
    elif self._transpose is not None:
        transpose_shape = self._transpose.shape
        return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],))
    else:
        # Fallback to base class or raise informative error
        return super().shape
```
```python
    @property
    def shape(self):
        """Return the shape of the tensor. Define this to avoid expensive PyObject lookups."""
        return (
            self._rowwise_data.shape
            if self._rowwise_data is not None
            else self._columnwise_data.shape
        )

    @property
    def is_cuda(self):
        """Return whether the tensor is on a CUDA device."""
        return (
            self._rowwise_data.is_cuda
            if self._rowwise_data is not None
            else self._columnwise_data.is_cuda
        )
```
Missing error handling when both rowwise_data and columnwise_data are None
Same issue as `Float8Tensor` - these properties assume at least one data tensor exists:

```python
@property
def shape(self):
    return (
        self._rowwise_data.shape
        if self._rowwise_data is not None
        else self._columnwise_data.shape  # AttributeError if both are None!
    )

@property
def is_cuda(self):
    return (
        self._rowwise_data.is_cuda
        if self._rowwise_data is not None
        else self._columnwise_data.is_cuda  # AttributeError if both are None!
    )
```

Fix: Add validation:

```python
@property
def shape(self):
    if self._rowwise_data is not None:
        return self._rowwise_data.shape
    elif self._columnwise_data is not None:
        return self._columnwise_data.shape
    else:
        return super().shape  # or raise informative error
```

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
/te-ci L1 pytorch
11 files reviewed, 2 comments
```python
    def is_cuda(self):
        """Return whether the tensor is on a CUDA device."""
        return self._data.is_cuda if self._data is not None else self._transpose.is_cuda
```
style: Unlike the `shape` property above, which has a defensive check (`raise RuntimeError("Both data and transpose are None")`), `is_cuda` will raise `AttributeError` if both `_data` and `_transpose` are `None`. Consider adding the same defensive check for consistency:

```python
@property
def is_cuda(self):
    """Return whether the tensor is on a CUDA device."""
    if self._data is not None:
        return self._data.is_cuda
    elif self._transpose is not None:
        return self._transpose.is_cuda
    else:
        raise RuntimeError("Both data and transpose are None")
```
```python
    def is_cuda(self):
        """Return whether the tensor is on a CUDA device."""
        return (
            self._rowwise_data.is_cuda
            if self._rowwise_data is not None
```
style: Same consistency issue as `Float8Tensor` - `is_cuda` will raise `AttributeError` if both `_rowwise_data` and `_columnwise_data` are `None`, with no defensive check like the one suggested for `shape`. Consider adding error handling for this edge case.
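Pulling the two review threads together, a consistently defensive dual-storage tensor might look like this sketch (the `DualStorage` and `FakeStorage` classes are hypothetical, not the actual TE implementation):

```python
class DualStorage:
    """Both `shape` and `is_cuda` funnel through one guarded accessor,
    so every property fails in the same informative way."""

    def __init__(self, rowwise=None, columnwise=None):
        self._rowwise_data = rowwise
        self._columnwise_data = columnwise

    def _any_data(self):
        # Single place that validates internal storage.
        if self._rowwise_data is not None:
            return self._rowwise_data
        if self._columnwise_data is not None:
            return self._columnwise_data
        raise RuntimeError("Both rowwise and columnwise data are None")

    @property
    def shape(self):
        return self._any_data().shape

    @property
    def is_cuda(self):
        return self._any_data().is_cuda


class FakeStorage:
    """Stand-in for an internal data tensor."""
    shape = (2, 3)
    is_cuda = True


assert DualStorage(rowwise=FakeStorage()).shape == (2, 3)
assert DualStorage(columnwise=FakeStorage()).is_cuda is True
```

Routing every property through one accessor keeps the fast path (one `None` check) while guaranteeing a clear error instead of an `AttributeError` when both storages are absent.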
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: