From 93ee022791f324a8a553adb3a1cbacdd356d53e1 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 5 Jan 2026 18:11:01 +0000 Subject: [PATCH 01/31] add all the optimizations Signed-off-by: Varun Thumbe --- .../common/gemm/cublaslt_gemm.cu | 12 +- .../common/transformer_engine.cpp | 2 +- .../pytorch/cpp_extensions/gemm.py | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 7 +- transformer_engine/pytorch/csrc/quantizer.cpp | 289 +++++++++++++----- transformer_engine/pytorch/csrc/util.cpp | 2 +- transformer_engine/pytorch/module/base.py | 5 +- transformer_engine/pytorch/module/linear.py | 19 +- .../pytorch/quantized_tensor.py | 9 +- .../pytorch/tensor/float8_tensor.py | 7 + 10 files changed, 255 insertions(+), 101 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 4b7d8179b0..52e819363c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. 
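// Sketch (outside the diff): the hoisting pattern applied in the hunks above,
// where nvte_is_non_tn_fp8_gemm_supported() is queried once and the result is
// reused by both the A- and B-matrix branches instead of re-invoking the
// capability check per operand. is_feature_supported() below is a hypothetical
// stand-in for that device-capability query.
#include <cstdio>

static int is_feature_supported() {
  std::puts("querying device capability");  // imagine a costly device-attribute lookup
  return 1;
}

static void configure_operands(bool is_A_transposed, bool is_B_transposed) {
  const int feature_supported = is_feature_supported();  // one query, reused below
  if (!feature_supported && !is_A_transposed) {
    // fall back to column-wise (transposed) A data, as the Hopper TN-only path does
  }
  if (!feature_supported && is_B_transposed) {
    // fall back to column-wise (transposed) B data
  }
}

int main() { configure_operands(false, true); }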
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), @@ -1107,4 +1109,4 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } cublas_path(); } -} +} \ No newline at end of file diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 370d9723cf..82c50c4ebd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 2a97e2ac71..7fe37b5f54 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int: The order of attributes checked is important to also minimize overhead. """ - if hasattr(tensor, "device"): - return tensor.device.index if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: return tensor._rowwise_data.device.index if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: @@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int: return tensor._data.device.index if hasattr(tensor, "_transpose") and tensor._transpose is not None: return tensor._transpose.device.index + if hasattr(tensor, "device"): + return tensor.device.index return torch.cuda.current_device() diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e73eca7861..4b14e8c019 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +bool is_extension_initialized = false; void init_float8_extension() { - if (Float8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); @@ -54,7 +54,6 @@ void init_float8_extension() { } void init_mxfp8_extension() { - if (MXFP8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); MXFP8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); @@ -69,7 +68,6 @@ void init_mxfp8_extension() { } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( @@ -90,7 +88,6 @@ void init_float8blockwise_extension() { } void init_nvfp4_extensions() { - if (NVFP4TensorPythonClass) return; auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); NVFP4QuantizerClass = reinterpret_cast( 
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); @@ -105,10 +102,12 @@ void init_nvfp4_extensions() { } void init_extension() { + if (is_extension_initialized) return; init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); + is_extension_initialized = true; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a73efc008a..6c3d835b20 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -121,9 +121,9 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -134,7 +134,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -143,26 +143,55 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - + py::object scale_inv_py = py::cast(scale_inv); // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } - + at::Device device = with_data ? data->device() : (with_transpose ? 
transpose->device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* result = PyObject_Call( + reinterpret_cast(Float8TensorPythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -185,10 +214,10 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -328,7 +357,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || 
is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -337,13 +367,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } - // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { @@ -351,23 +380,52 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - + at::Device device = with_data ? data_tensor.device() : (with_transpose ? transpose_tensor.device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; + py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* result = PyObject_Call( + reinterpret_cast(Float8TensorPythonClass), + PyTuple_New(0), + 
kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -406,10 +464,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -629,22 +687,49 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); + ret = py::reinterpret_steal(result); } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), - "data_format"_a = data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", 
py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(Float8BlockwiseQTensorPythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); + ret = py::reinterpret_steal(result); } return {std::move(tensor), std::move(ret)}; @@ -950,20 +1035,45 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(MXFP8TensorStoragePythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(MXFP8TensorPythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ MXFP8 tensor @@ -1234,22 +1344,49 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle 
NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass( - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(NVFP4TensorStoragePythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); - out_py = NVFP4TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call( + reinterpret_cast(NVFP4TensorPythonClass), + PyTuple_New(0), + kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ tensor diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 96fd2ccb3a..35e97b80cb 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -15,7 +15,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING || 
input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..368b61b382 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -929,12 +929,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if torch.is_autocast_enabled(): self.activation_dtype = torch_get_autocast_gpu_dtype() return - + dtype = inp.dtype # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: + if self.activation_dtype == dtype: return - dtype = inp.dtype if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f3220d5860..a49652a2c2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -93,7 +93,6 @@ def forward( non_tensor_args: Tuple, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - ( is_first_microbatch, fp8, @@ -130,6 +129,10 @@ def forward( debug, ) = non_tensor_args + inp_requires_grad = inp.requires_grad + weight_requires_grad = weight.requires_grad + bias_requires_grad = bias.requires_grad if bias is not None else False + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -141,7 +144,7 @@ def forward( # Configure tensor-parallel communication tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) @@ -254,7 +257,7 @@ def forward( # Configure quantizer # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - columnwise_usage = is_grad_enabled and inp.requires_grad + columnwise_usage = is_grad_enabled and inp_requires_grad if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -379,7 +382,7 @@ def forward( ctx.weight_quantizer = weight_quantizer ctx.backward_input_needs_gather = ( - weight.requires_grad and parallel_mode == "column" and sequence_parallel + weight_requires_grad and parallel_mode == "column" and sequence_parallel ) # Discard unneeded data in input tensor @@ -447,7 +450,7 @@ def forward( ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - if fuse_wgrad_accumulation and weight.requires_grad: + if fuse_wgrad_accumulation and weight_requires_grad: # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -473,12 +476,12 @@ def forward( ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size - ctx.requires_dgrad = inp.requires_grad - ctx.requires_wgrad = weight.requires_grad + ctx.requires_dgrad = inp_requires_grad + ctx.requires_wgrad = weight_requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and inp_requires_grad and weight_requires_grad and bias_requires_grad: 
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3414581f7c..3ee8031149 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -360,9 +360,16 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - + instance._dtype = dtype return instance + @property + def dtype(self) -> torch.dtype: + # Attribute access of custom tensors goes through an + # expensive Pyobject lookup. Since dtype for a tensor is never + # change after creation, we cache it in a member variable. + return self._dtype + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 43cbdcf9e6..2be8c8b0ae 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -910,6 +910,13 @@ def fsdp_post_all_gather( ) return out, all_gather_outputs + @property + def shape(self): + return self._data.shape if self._data is not None else self._transpose.shape + @property + def is_cuda(self): + return self._data.is_cuda if self._data is not None else self._transpose.is_cuda + @classmethod def _make_in_reduce_ex( cls, From 06338bc6c72eddb632c577e3e6216ebeedcef27f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 18:17:45 +0000 Subject: [PATCH 02/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 95 ++++++++----------- .../pytorch/tensor/float8_tensor.py | 1 + 3 files changed, 41 insertions(+), 57 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 52e819363c..85d89bff8e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1109,4 +1109,4 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } cublas_path(); } -} \ No newline at end of file +} diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6c3d835b20..b5612f6632 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -148,7 +148,8 @@ std::pair Float8Quantizer::create_tensor( if (!scale_inv) { scale_inv = at::reciprocal(scale); } - at::Device device = with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA); + at::Device device = + with_data ? data->device() : (with_transpose ? 
transpose->device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; if (internal) { @@ -161,10 +162,8 @@ std::pair Float8Quantizer::create_tensor( PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(Float8TensorStoragePythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -183,10 +182,8 @@ std::pair Float8Quantizer::create_tensor( PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(Float8TensorPythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -380,31 +377,30 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - at::Device device = with_data ? data_tensor.device() : (with_transpose ? transpose_tensor.device() : torch::kCUDA); + at::Device device = with_data ? data_tensor.device() + : (with_transpose ? transpose_tensor.device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - // Use direct C API call bypassing pybind11 overhead - PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "data", data_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); - PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); - - PyObject* result = PyObject_Call( - reinterpret_cast(Float8TensorStoragePythonClass), - PyTuple_New(0), - kwargs); - - Py_DECREF(kwargs); - - NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); - out_py = py::reinterpret_steal(result); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { const std::vector shape_int64(shape.begin(), shape.end()); // Use direct C API call bypassing pybind11 overhead @@ -417,10 +413,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); 
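// Sketch (outside the diff): the direct CPython construction pattern used
// throughout these hunks, shown with a hypothetical SomeTensorClass type
// object and a single "data" keyword. Both the kwargs dict and the empty
// positional-args tuple are released, matching the refcount handling the
// later "fix some bugs pointed to by copilot" commit converges on.
#include <Python.h>
#include <pybind11/pybind11.h>
namespace py = pybind11;

py::object construct_directly(PyTypeObject* SomeTensorClass, py::object data) {
  PyObject* kwargs = PyDict_New();
  PyObject* args = PyTuple_New(0);                    // empty positional args
  PyDict_SetItemString(kwargs, "data", data.ptr());   // dict takes its own reference
  PyObject* result = PyObject_Call(
      reinterpret_cast<PyObject*>(SomeTensorClass), args, kwargs);
  Py_DECREF(args);                                    // release temporaries exactly once
  Py_DECREF(kwargs);
  if (result == nullptr) throw py::error_already_set();  // surface the Python exception
  return py::reinterpret_steal<py::object>(result);   // take ownership of the new reference
}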
PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(Float8TensorPythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -698,10 +692,9 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), + PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -721,10 +714,8 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(Float8BlockwiseQTensorPythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), + PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -1044,10 +1035,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(MXFP8TensorStoragePythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), + PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -1065,10 +1054,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(MXFP8TensorPythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -1355,10 +1342,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(NVFP4TensorStoragePythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), + PyTuple_New(0), kwargs); Py_DECREF(kwargs); @@ -1378,10 +1363,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call( - reinterpret_cast(NVFP4TensorPythonClass), - PyTuple_New(0), - kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), PyTuple_New(0), kwargs); Py_DECREF(kwargs); diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 2be8c8b0ae..4e04708898 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -913,6 +913,7 @@ def fsdp_post_all_gather( @property def shape(self): return self._data.shape if self._data is not None else 
self._transpose.shape + @property def is_cuda(self): return self._data.is_cuda if self._data is not None else self._transpose.is_cuda From 50de9cdcc7770cd151e7236f042a579219bdc533 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 6 Jan 2026 12:34:01 +0000 Subject: [PATCH 03/31] requires_grad optimization Signed-off-by: Varun Thumbe --- .../pytorch/quantized_tensor.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3ee8031149..2dda791911 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -360,6 +360,7 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) + instance._requires_grad = requires_grad instance._dtype = dtype return instance @@ -370,6 +371,29 @@ def dtype(self) -> torch.dtype: # change after creation, we cache it in a member variable. return self._dtype + @property + def requires_grad(self) -> bool: + # Attribute access of custom tensors goes through an + # expensive Pyobject lookup. Since requires_grad is set during + # initialization and may be updated, we cache it in a member variable. + return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + # Update the cached value + self._requires_grad = value + # Call parent class to ensure autograd engine is aware of the change + torch.Tensor.requires_grad.fset(self, value) + + def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + # Update the cached value + self._requires_grad = requires_grad + # Call parent class method to ensure autograd engine is aware + super().requires_grad_(requires_grad) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( From 62b88e18da36e2a9b4aac8b90bc83cf80fde519a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:36:29 +0000 Subject: [PATCH 04/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/quantized_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 2dda791911..69d07b35f0 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -393,7 +393,6 @@ def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: super().requires_grad_(requires_grad) return self - def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( From 99494d7a5a7aafe60bd3e6a8fe73c1fbf80995df Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 7 Jan 2026 17:19:42 +0000 Subject: [PATCH 05/31] test if commenting out requires_grad works Signed-off-by: Varun Thumbe --- .../pytorch/quantized_tensor.py | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 2dda791911..6c9314663d 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ 
b/transformer_engine/pytorch/quantized_tensor.py @@ -360,7 +360,6 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - instance._requires_grad = requires_grad instance._dtype = dtype return instance @@ -371,27 +370,27 @@ def dtype(self) -> torch.dtype: # change after creation, we cache it in a member variable. return self._dtype - @property - def requires_grad(self) -> bool: - # Attribute access of custom tensors goes through an - # expensive Pyobject lookup. Since requires_grad is set during - # initialization and may be updated, we cache it in a member variable. - return self._requires_grad - - @requires_grad.setter - def requires_grad(self, value: bool) -> None: - # Update the cached value - self._requires_grad = value - # Call parent class to ensure autograd engine is aware of the change - torch.Tensor.requires_grad.fset(self, value) - - def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - # Update the cached value - self._requires_grad = requires_grad - # Call parent class method to ensure autograd engine is aware - super().requires_grad_(requires_grad) - return self + # @property + # def requires_grad(self) -> bool: + # # Attribute access of custom tensors goes through an + # # expensive Pyobject lookup. Since requires_grad is set during + # # initialization and may be updated, we cache it in a member variable. + # return self._requires_grad + + # @requires_grad.setter + # def requires_grad(self, value: bool) -> None: + # # Update the cached value + # self._requires_grad = value + # # Call parent class to ensure autograd engine is aware of the change + # torch.Tensor.requires_grad.fset(self, value) + + # def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + # # pylint: disable=missing-function-docstring + # # Update the cached value + # self._requires_grad = requires_grad + # # Call parent class method to ensure autograd engine is aware + # super().requires_grad_(requires_grad) + # return self def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: From b61a6a81b5e3f47685768b1a827e1a8ca3f254d1 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 7 Jan 2026 17:58:32 +0000 Subject: [PATCH 06/31] fix minor bug Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/csrc/util.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 35e97b80cb..96fd2ccb3a 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -15,7 +15,7 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING || + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } From 88dfdbdcce40c3f660430405b4884316e26c9895 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 11 Jan 2026 19:03:34 +0000 Subject: [PATCH 07/31] fix ci Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/quantized_tensor.py | 42 +++++++++---------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a49652a2c2..69fac7682d 100644 
--- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -481,7 +481,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and inp_requires_grad and weight_requires_grad and bias_requires_grad: + if ctx.fp8 and (inp_requires_grad or weight_requires_grad or bias_requires_grad): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 62fa9b1114..45f00c1d67 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -370,27 +370,27 @@ def dtype(self) -> torch.dtype: # change after creation, we cache it in a member variable. return self._dtype - # @property - # def requires_grad(self) -> bool: - # # Attribute access of custom tensors goes through an - # # expensive Pyobject lookup. Since requires_grad is set during - # # initialization and may be updated, we cache it in a member variable. - # return self._requires_grad - - # @requires_grad.setter - # def requires_grad(self, value: bool) -> None: - # # Update the cached value - # self._requires_grad = value - # # Call parent class to ensure autograd engine is aware of the change - # torch.Tensor.requires_grad.fset(self, value) - - # def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: - # # pylint: disable=missing-function-docstring - # # Update the cached value - # self._requires_grad = requires_grad - # # Call parent class method to ensure autograd engine is aware - # super().requires_grad_(requires_grad) - # return self + @property + def requires_grad(self) -> bool: + # Attribute access of custom tensors goes through an + # expensive Pyobject lookup. Since requires_grad is set during + # initialization and may be updated, we cache it in a member variable. 
+ return self._requires_grad + + @requires_grad.setter + def requires_grad(self, value: bool) -> None: + # Update the cached value + self._requires_grad = value + # Call parent class to ensure autograd engine is aware of the change + torch.Tensor.requires_grad.fset(self, value) + + def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + # Update the cached value + self._requires_grad = requires_grad + # Call parent class method to ensure autograd engine is aware + super().requires_grad_(requires_grad) + return self def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" From 5809dccb8e78fc6f6461a0618888bcaf027fd698 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 11 Jan 2026 19:12:36 +0000 Subject: [PATCH 08/31] missed a bug Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/quantized_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 45f00c1d67..69d07b35f0 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -360,6 +360,7 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) + instance._requires_grad = requires_grad instance._dtype = dtype return instance From 30fecf2e879a62e1433dfc7aa39cde7ffd4902a5 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 12 Jan 2026 00:45:19 +0530 Subject: [PATCH 09/31] Update transformer_engine/pytorch/csrc/quantizer.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 --- transformer_engine/pytorch/csrc/quantizer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b5612f6632..512d943a5c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -143,11 +143,12 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - py::object scale_inv_py = py::cast(scale_inv); + py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } + py::object scale_inv_py = py::cast(*scale_inv); at::Device device = with_data ? data->device() : (with_transpose ? 
transpose->device() : torch::kCUDA); // Construct Python FP8 tensor From 1b0d49774e54b6a616e655799b927520ebb0dda0 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 11 Jan 2026 20:04:09 +0000 Subject: [PATCH 10/31] fix some bugs pointed to by copilot Signed-off-by: Varun Thumbe --- .../common/gemm/cublaslt_gemm.cu | 7 ++- .../pytorch/csrc/extensions/pybind.cpp | 14 +++--- transformer_engine/pytorch/csrc/quantizer.cpp | 47 ++++++++++++------- .../pytorch/quantized_tensor.py | 20 +++++--- .../pytorch/tensor/float8_tensor.py | 2 + 5 files changed, 58 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 85d89bff8e..689cca74ac 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -120,6 +120,11 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Set conditions for MXFP8 and NVFP4 gemm execution. const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling + if(is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) + { + is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + } // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { @@ -129,7 +134,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { @@ -221,7 +225,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. 
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4b14e8c019..a022c57915 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,7 +35,7 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; -bool is_extension_initialized = false; +std::once_flag extension_init_flag; void init_float8_extension() { auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); @@ -102,12 +102,12 @@ void init_nvfp4_extensions() { } void init_extension() { - if (is_extension_initialized) return; - init_float8_extension(); - init_mxfp8_extension(); - init_float8blockwise_extension(); - init_nvfp4_extensions(); - is_extension_initialized = true; + std::call_once(extension_init_flag, []() { + init_float8_extension(); + init_mxfp8_extension(); + init_float8blockwise_extension(); + init_nvfp4_extensions(); + }); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 512d943a5c..33465eb78b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -143,7 +143,6 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); @@ -156,18 +155,18 @@ std::pair Float8Quantizer::create_tensor( if (internal) { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - PyTuple_New(0), kwargs); + args, kwargs); Py_DECREF(kwargs); - + Py_DECREF(args); NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); out_py = py::reinterpret_steal(result); } else { @@ -175,6 +174,7 @@ std::pair Float8Quantizer::create_tensor( // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); PyDict_SetItemString(kwargs, "data", data_py.ptr()); @@ -184,9 +184,10 @@ std::pair Float8Quantizer::create_tensor( PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); PyObject* result = - PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); Py_DECREF(kwargs); + Py_DECREF(args); NVTE_CHECK(result != nullptr, "Failed to create 
Float8Tensor instance"); out_py = py::reinterpret_steal(result); @@ -393,11 +394,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - PyTuple_New(0), kwargs); + args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); @@ -414,9 +416,11 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = - PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); @@ -693,10 +697,12 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), - PyTuple_New(0), kwargs); + args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); @@ -714,10 +720,10 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); - + PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), - PyTuple_New(0), kwargs); - + args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); @@ -1029,6 +1035,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve if (internal) { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); + PyObject* args = PyTuple_New(0); PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); @@ -1037,8 +1044,9 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), - PyTuple_New(0), kwargs); + args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); @@ -1055,9 +1063,11 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = - PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), 
PyTuple_New(0), kwargs); + PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); @@ -1343,9 +1353,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), - PyTuple_New(0), kwargs); - + args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); @@ -1364,9 +1375,11 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyObject* args = PyTuple_New(0); PyObject* result = - PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), PyTuple_New(0), kwargs); + PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args, kwargs); + Py_DECREF(args); Py_DECREF(kwargs); NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 69d07b35f0..a99e99faca 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -366,26 +366,34 @@ def __new__( @property def dtype(self) -> torch.dtype: - # Attribute access of custom tensors goes through an - # expensive Pyobject lookup. Since dtype for a tensor is never - # change after creation, we cache it in a member variable. + """ + Return the high precision data type of the tensor + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since dtype for a tensor is never + change after creation, we cache it in a member variable and return + """ return self._dtype @property def requires_grad(self) -> bool: - # Attribute access of custom tensors goes through an - # expensive Pyobject lookup. Since requires_grad is set during - # initialization and may be updated, we cache it in a member variable. + """ + Return whether or not the tensor requires gradient. + Attribute access of custom tensors goes through an + expensive Pyobject lookup. Since requires_grad is set during + initialization and may be updated, we cache it in a member variable. + """ return self._requires_grad @requires_grad.setter def requires_grad(self, value: bool) -> None: + """ Set requires_grad property so that autograd engine is aware of the change """ # Update the cached value self._requires_grad = value # Call parent class to ensure autograd engine is aware of the change torch.Tensor.requires_grad.fset(self, value) def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + """ Cache requires_grad property and call parent class method """ # pylint: disable=missing-function-docstring # Update the cached value self._requires_grad = requires_grad diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 4e04708898..066c55d582 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -912,10 +912,12 @@ def fsdp_post_all_gather( @property def shape(self): + """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" return self._data.shape if self._data is not None else self._transpose.shape @property def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" return self._data.is_cuda if self._data is not None else self._transpose.is_cuda @classmethod From 138b7bfb9fadc19ec6507a021d7cfcf413104990 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Jan 2026 20:04:53 +0000 Subject: [PATCH 11/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/gemm/cublaslt_gemm.cu | 5 ++-- transformer_engine/pytorch/csrc/quantizer.cpp | 25 +++++++++---------- .../pytorch/quantized_tensor.py | 4 +-- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 689cca74ac..0699b1876d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -120,9 +120,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Set conditions for MXFP8 and NVFP4 gemm execution. const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); - int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling - if(is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) - { + int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling + if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) { is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 33465eb78b..f8251332f1 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -162,8 +162,8 @@ std::pair Float8Quantizer::create_tensor( PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - args, kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); Py_DECREF(kwargs); Py_DECREF(args); @@ -396,8 +396,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); - PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), - args, kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -698,9 +698,8 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); - PyObject* result = - PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), - args, kwargs); + PyObject* result = PyObject_Call( + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -721,8 +720,8 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, 
"data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); - PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), - args, kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -1043,8 +1042,8 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), - args, kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); @@ -1354,8 +1353,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); - PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), - args, kwargs); + PyObject* result = + PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), args, kwargs); Py_DECREF(args); Py_DECREF(kwargs); diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index a99e99faca..b57752bc79 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -386,14 +386,14 @@ def requires_grad(self) -> bool: @requires_grad.setter def requires_grad(self, value: bool) -> None: - """ Set requires_grad property so that autograd engine is aware of the change """ + """Set requires_grad property so that autograd engine is aware of the change""" # Update the cached value self._requires_grad = value # Call parent class to ensure autograd engine is aware of the change torch.Tensor.requires_grad.fset(self, value) def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: - """ Cache requires_grad property and call parent class method """ + """Cache requires_grad property and call parent class method""" # pylint: disable=missing-function-docstring # Update the cached value self._requires_grad = requires_grad From eec1e865c44d1858769ca518e11f62795fa515f1 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 11 Jan 2026 20:19:05 +0000 Subject: [PATCH 12/31] linting error Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 69fac7682d..30f121655c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -31,7 +31,6 @@ clear_tensor_data, divide, init_method_constant, - requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, assert_dim_for_all_gather, From 8169d9c2ef3a797689c7b7166268629d33d58bed Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Mon, 12 Jan 2026 08:27:08 +0000 Subject: [PATCH 13/31] fix the error Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/quantized_tensor.py | 7 +++++++ transformer_engine/pytorch/tensor/float8_tensor.py | 2 ++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 13 +++++++++++++ transformer_engine/pytorch/tensor/nvfp4_tensor.py | 3 +++ 4 files changed, 25 insertions(+) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index b57752bc79..29a5cc7e31 100644 --- 
a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -392,6 +392,13 @@ def requires_grad(self, value: bool) -> None: # Call parent class to ensure autograd engine is aware of the change torch.Tensor.requires_grad.fset(self, value) + @dtype.setter + def dtype(self, value: torch.dtype) -> None: + """Set dtype property""" + # Update the cached value + self._dtype = value + warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") + def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: """Cache requires_grad property and call parent class method""" # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 066c55d582..8d6ebc0400 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -993,6 +993,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._dtype = tensor.dtype + self._requires_grad = tensor.requires_grad # Float8Tensor attributes self._data = tensor._data self._quantizer = tensor._quantizer.copy() diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 88081f51bf..68527aee16 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -785,6 +785,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + # Cache the attributes + self._dtype = tensor.dtype + self._requires_grad = tensor.requires_grad + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer.copy() @@ -803,6 +807,15 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting MXFP8Tensor.data data = property(_get_data, _set_data) + @property + def shape(self): + """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" + return self._rowwise_data.shape if self._rowwise_data is not None else self._columnwise_data.shape + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + return self._rowwise_data.is_cuda if self._rowwise_data is not None else self._columnwise_data.is_cuda class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 8b707af3b2..afb212960d 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -689,6 +689,9 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + # Cache the attributes + self._dtype = tensor.dtype + self._requires_grad = tensor.requires_grad self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer From 6fefaf28c1f732d2df346ca4a1b2bf395b92e797 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Jan 2026 08:27:53 +0000 Subject: [PATCH 14/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 68527aee16..59bd29c95f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -810,12 +810,21 @@ def _set_data(self, tensor: torch.Tensor) -> None: @property def shape(self): """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" - return self._rowwise_data.shape if self._rowwise_data is not None else self._columnwise_data.shape + return ( + self._rowwise_data.shape + if self._rowwise_data is not None + else self._columnwise_data.shape + ) @property def is_cuda(self): """Return whether the tensor is on a CUDA device.""" - return self._rowwise_data.is_cuda if self._rowwise_data is not None else self._columnwise_data.is_cuda + return ( + self._rowwise_data.is_cuda + if self._rowwise_data is not None + else self._columnwise_data.is_cuda + ) + class _ViewFunc(torch.autograd.Function): """View function From a5feaf944ba84e914eb8fc9fc4b4d14fc90a5a84 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 13 Jan 2026 10:56:28 +0000 Subject: [PATCH 15/31] fix the bug Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 2 +- .../pytorch/quantized_tensor.py | 20 +++++++++---------- .../pytorch/tensor/float8_tensor.py | 3 +-- .../pytorch/tensor/mxfp8_tensor.py | 5 ++--- .../pytorch/tensor/nvfp4_tensor.py | 6 +++--- 5 files changed, 16 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index ce15dd1421..91d891b284 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -52,7 +52,7 @@ _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] # Supported quantization recipes -_quantization_list: list[Optional[str]] = [None] +_quantization_list: list[Optional[str]] = [] if fp8_available: _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 29a5cc7e31..ee2bd7b746 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -374,6 +374,13 @@ def dtype(self) -> torch.dtype: """ return self._dtype + @dtype.setter + def dtype(self, value: torch.dtype) -> None: + """Set dtype property""" + # Update the cached value + self._dtype = value + warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") + @property def requires_grad(self) -> bool: """ @@ -387,17 +394,8 @@ def requires_grad(self) -> bool: @requires_grad.setter def requires_grad(self, value: bool) -> None: """Set requires_grad property so that autograd engine is aware of the change""" - # Update the cached value - self._requires_grad = value - # Call parent class to ensure autograd engine is aware of the change - torch.Tensor.requires_grad.fset(self, value) - - @dtype.setter - def dtype(self, value: torch.dtype) -> None: - """Set dtype property""" - # Update the cached value - self._dtype = value - warnings.warn("Dtype of QuantizedTensor has been changed. 
Ensure this is intended.") + # Update the cached value and call parent class method to ensure autograd engine is aware + self.requires_grad_(value) def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: """Cache requires_grad property and call parent class method""" diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8d6ebc0400..7ca3f59199 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -992,9 +992,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + self.dtype = tensor.dtype - self._dtype = tensor.dtype - self._requires_grad = tensor.requires_grad # Float8Tensor attributes self._data = tensor._data self._quantizer = tensor._quantizer.copy() diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 68527aee16..b954342ac4 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -785,9 +785,8 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) - # Cache the attributes - self._dtype = tensor.dtype - self._requires_grad = tensor.requires_grad + # Cache the attributes + self.dtype = tensor.dtype self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index afb212960d..cc8c348aa2 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -689,9 +689,9 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) - # Cache the attributes - self._dtype = tensor.dtype - self._requires_grad = tensor.requires_grad + # Cache the attributes + self.dtype = tensor.dtype + self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer From afb2f230c2897cadbccf06a2b0cd3e615a423d52 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 13 Jan 2026 11:01:59 +0000 Subject: [PATCH 16/31] get rid of the change Signed-off-by: Varun Thumbe --- tests/pytorch/test_fusible_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 91d891b284..ce15dd1421 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -52,7 +52,7 @@ _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] # Supported quantization recipes -_quantization_list: list[Optional[str]] = [] +_quantization_list: list[Optional[str]] = [None] if fp8_available: _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: From 3919cb844f64c52b40fb3fbfeb592ea786370bb8 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 13 Jan 2026 14:51:13 +0000 Subject: [PATCH 17/31] fix the transpose shape bug Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/tensor/float8_tensor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py 
b/transformer_engine/pytorch/tensor/float8_tensor.py index 7ca3f59199..e4662570c8 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -913,7 +913,11 @@ def fsdp_post_all_gather( @property def shape(self): """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" - return self._data.shape if self._data is not None else self._transpose.shape + if self._data is not None: + return self._data.shape + else: + transpose_shape = self._transpose.shape + return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) @property def is_cuda(self): From 46681339ae4182d80e61eef5176c774efdaf9cbe Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 13 Jan 2026 18:28:36 +0000 Subject: [PATCH 18/31] minor linter fix Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e4662570c8..5403a2f80c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -917,7 +917,7 @@ def shape(self): return self._data.shape else: transpose_shape = self._transpose.shape - return torch.Size(tuple(transpose_shape[1:]) + (transpose_shape[0],)) + return tuple(transpose_shape[1:]) + (transpose_shape[0],) @property def is_cuda(self): From 5a0065295f58a0dfeabd48058d85578ed16fe0ee Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Tue, 13 Jan 2026 18:35:04 +0000 Subject: [PATCH 19/31] fix lint Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/tensor/float8_tensor.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 5403a2f80c..c10aa82ce2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -915,9 +915,11 @@ def shape(self): """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" if self._data is not None: return self._data.shape - else: + elif self._transpose is not None: transpose_shape = self._transpose.shape return tuple(transpose_shape[1:]) + (transpose_shape[0],) + else: + raise RuntimeError("Both data and transpose are None") @property def is_cuda(self): From 739bbad08adbe314d48ca45e7dbc63bf55d41b87 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 16 Jan 2026 15:10:50 +0000 Subject: [PATCH 20/31] fix linting error Signed-off-by: Varun Thumbe --- transformer_engine/common/transformer_engine.cpp | 2 +- transformer_engine/pytorch/tensor/float8_tensor.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 82c50c4ebd..370d9723cf 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - static int num_devices = transformer_engine::cuda::num_devices(); + int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index c10aa82ce2..2ab72de92d 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -918,8 +918,7 @@ def shape(self): elif self._transpose is not None: transpose_shape = self._transpose.shape return tuple(transpose_shape[1:]) + (transpose_shape[0],) - else: - raise RuntimeError("Both data and transpose are None") + raise RuntimeError("Both data and transpose are None") @property def is_cuda(self): From e8042c1446fd8c0f3a3840e198bbadb219282976 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 16 Jan 2026 16:39:39 +0000 Subject: [PATCH 21/31] address copilot review comment regarding error check when both data and transpose are None Signed-off-by: Varun Thumbe --- .../pytorch/quantized_tensor.py | 6 ++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 17 +++++++++++++++ .../pytorch/tensor/float8_tensor.py | 6 +++++- .../pytorch/tensor/mxfp8_tensor.py | 21 +++++++++---------- 4 files changed, 38 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 9bd54decac..d15b07ee57 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -367,6 +367,9 @@ def dtype(self) -> torch.dtype: expensive Pyobject lookup. Since dtype for a tensor is never change after creation, we cache it in a member variable and return """ + # Lazy initialization for tensors created via alternate paths + if not hasattr(self, '_dtype'): + self._dtype = super(QuantizedTensor, self).__getattribute__('dtype') return self._dtype @dtype.setter @@ -384,6 +387,9 @@ def requires_grad(self) -> bool: expensive Pyobject lookup. Since requires_grad is set during initialization and may be updated, we cache it in a member variable. 
""" + # Fallback to parent if not cached yet + if not hasattr(self, "_requires_grad"): + self._requires_grad = super(QuantizedTensor, self).__getattribute__('requires_grad') return self._requires_grad @requires_grad.setter diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 03c16ebbed..2523856483 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -602,6 +602,23 @@ def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): # Cast to FP8 when setting Float8BlockwiseQTensor.data data = property(_get_data, _set_data) + @property + def shape(self): + """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" + if self._rowwise_data is not None: + return self._rowwise_data.shape + elif self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("Float8BlockwiseQTensor has no data!") + + @property + def is_cuda(self): + """Return whether the tensor is on a CUDA device.""" + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + elif self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("Float8BlockwiseQTensor has no data!") class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 2ab72de92d..da44aa1527 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -923,7 +923,11 @@ def shape(self): @property def is_cuda(self): """Return whether the tensor is on a CUDA device.""" - return self._data.is_cuda if self._data is not None else self._transpose.is_cuda + if self._data is not None: + return self._data.is_cuda + elif self._transpose is not None: + return self._transpose.is_cuda + raise RuntimeError("Both data and transpose are None") @classmethod def _make_in_reduce_ex( diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 56a8356cb0..63d1d4e009 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -809,21 +809,20 @@ def _set_data(self, tensor: torch.Tensor) -> None: @property def shape(self): """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" - return ( - self._rowwise_data.shape - if self._rowwise_data is not None - else self._columnwise_data.shape - ) + if self._rowwise_data is not None: + return self._rowwise_data.shape + elif self._columnwise_data is not None: + return self._columnwise_data.shape + raise RuntimeError("MXFP8Tensor has no data!") @property def is_cuda(self): """Return whether the tensor is on a CUDA device.""" - return ( - self._rowwise_data.is_cuda - if self._rowwise_data is not None - else self._columnwise_data.is_cuda - ) - + if self._rowwise_data is not None: + return self._rowwise_data.is_cuda + elif self._columnwise_data is not None: + return self._columnwise_data.is_cuda + raise RuntimeError("MXFP8Tensor has no data!") class _ViewFunc(torch.autograd.Function): """View function From 1d323d770e8268064dc206213d4f4cc7f8e21f19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jan 2026 16:40:29 +0000 Subject: [PATCH 22/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/quantized_tensor.py | 6 +++--- .../pytorch/tensor/float8_blockwise_tensor.py | 1 + transformer_engine/pytorch/tensor/mxfp8_tensor.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d15b07ee57..d5fbdc6498 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -368,8 +368,8 @@ def dtype(self) -> torch.dtype: change after creation, we cache it in a member variable and return """ # Lazy initialization for tensors created via alternate paths - if not hasattr(self, '_dtype'): - self._dtype = super(QuantizedTensor, self).__getattribute__('dtype') + if not hasattr(self, "_dtype"): + self._dtype = super(QuantizedTensor, self).__getattribute__("dtype") return self._dtype @dtype.setter @@ -389,7 +389,7 @@ def requires_grad(self) -> bool: """ # Fallback to parent if not cached yet if not hasattr(self, "_requires_grad"): - self._requires_grad = super(QuantizedTensor, self).__getattribute__('requires_grad') + self._requires_grad = super(QuantizedTensor, self).__getattribute__("requires_grad") return self._requires_grad @requires_grad.setter diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 2523856483..9cc069b74a 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -620,6 +620,7 @@ def is_cuda(self): return self._columnwise_data.is_cuda raise RuntimeError("Float8BlockwiseQTensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 63d1d4e009..5c5d4300ec 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -824,6 +824,7 @@ def is_cuda(self): return self._columnwise_data.is_cuda raise RuntimeError("MXFP8Tensor has no data!") + class _ViewFunc(torch.autograd.Function): """View function From e2c7435e8dacd26c238c7891fffbee107ab8c74a Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 16 Jan 2026 17:45:59 +0000 Subject: [PATCH 23/31] fix linting errors Signed-off-by: Varun Thumbe --- 
transformer_engine/pytorch/quantized_tensor.py | 8 ++++---- .../pytorch/tensor/float8_blockwise_tensor.py | 4 ++-- transformer_engine/pytorch/tensor/float8_tensor.py | 4 ++-- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d15b07ee57..d11177600f 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -355,8 +355,8 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - instance._requires_grad = requires_grad - instance._dtype = dtype + # instance._requires_grad = requires_grad + # instance._dtype = dtype return instance @property @@ -369,7 +369,7 @@ def dtype(self) -> torch.dtype: """ # Lazy initialization for tensors created via alternate paths if not hasattr(self, '_dtype'): - self._dtype = super(QuantizedTensor, self).__getattribute__('dtype') + self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) return self._dtype @dtype.setter @@ -389,7 +389,7 @@ def requires_grad(self) -> bool: """ # Fallback to parent if not cached yet if not hasattr(self, "_requires_grad"): - self._requires_grad = super(QuantizedTensor, self).__getattribute__('requires_grad') + self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) return self._requires_grad @requires_grad.setter diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 2523856483..cf71b6d7d1 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -607,7 +607,7 @@ def shape(self): """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" if self._rowwise_data is not None: return self._rowwise_data.shape - elif self._columnwise_data is not None: + if self._columnwise_data is not None: return self._columnwise_data.shape raise RuntimeError("Float8BlockwiseQTensor has no data!") @@ -616,7 +616,7 @@ def is_cuda(self): """Return whether the tensor is on a CUDA device.""" if self._rowwise_data is not None: return self._rowwise_data.is_cuda - elif self._columnwise_data is not None: + if self._columnwise_data is not None: return self._columnwise_data.is_cuda raise RuntimeError("Float8BlockwiseQTensor has no data!") diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index da44aa1527..6bc3e42a0a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -915,7 +915,7 @@ def shape(self): """Return the shape of the tensor. 
Define this to avoid expensive PyObject lookups.""" if self._data is not None: return self._data.shape - elif self._transpose is not None: + if self._transpose is not None: transpose_shape = self._transpose.shape return tuple(transpose_shape[1:]) + (transpose_shape[0],) raise RuntimeError("Both data and transpose are None") @@ -925,7 +925,7 @@ def is_cuda(self): """Return whether the tensor is on a CUDA device.""" if self._data is not None: return self._data.is_cuda - elif self._transpose is not None: + if self._transpose is not None: return self._transpose.is_cuda raise RuntimeError("Both data and transpose are None") diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 63d1d4e009..14da896ac2 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -811,7 +811,7 @@ def shape(self): """Return the shape of the tensor. Define this to avoid expensive PyObject lookups.""" if self._rowwise_data is not None: return self._rowwise_data.shape - elif self._columnwise_data is not None: + if self._columnwise_data is not None: return self._columnwise_data.shape raise RuntimeError("MXFP8Tensor has no data!") @@ -820,7 +820,7 @@ def is_cuda(self): """Return whether the tensor is on a CUDA device.""" if self._rowwise_data is not None: return self._rowwise_data.is_cuda - elif self._columnwise_data is not None: + if self._columnwise_data is not None: return self._columnwise_data.is_cuda raise RuntimeError("MXFP8Tensor has no data!") From beada368be0c43a64934e20d321eb34733d2fa52 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 16 Jan 2026 17:50:18 +0000 Subject: [PATCH 24/31] missed a merge conflict Signed-off-by: Varun Thumbe --- transformer_engine/pytorch/quantized_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 7880631018..d11177600f 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -389,11 +389,7 @@ def requires_grad(self) -> bool: """ # Fallback to parent if not cached yet if not hasattr(self, "_requires_grad"): -<<<<<<< HEAD self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) -======= - self._requires_grad = super(QuantizedTensor, self).__getattribute__("requires_grad") ->>>>>>> da7fbf53566d334cb3468e270cf0b197ff5c1c5b return self._requires_grad @requires_grad.setter From 06a72a27906f8b5aa45c05bb54363f998b42126d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jan 2026 17:51:04 +0000 Subject: [PATCH 25/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index d11177600f..e6c2f92dff 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -368,7 +368,7 @@ def dtype(self) -> torch.dtype: change after creation, we cache it in a member variable and return """ # Lazy initialization for tensors created via alternate paths - if not hasattr(self, '_dtype'): + if not hasattr(self, "_dtype"): self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) return self._dtype 
From 5d21db2e1d266426fe570302c18f9a43b2c66e4f Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Fri, 16 Jan 2026 21:43:10 +0000 Subject: [PATCH 26/31] final optimizations Signed-off-by: Varun Thumbe --- transformer_engine/common/util/cuda_driver.h | 21 +++- transformer_engine/pytorch/csrc/quantizer.cpp | 119 +++++++++++------- 2 files changed, 91 insertions(+), 49 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 2715d8e4e4..a038b9e1c2 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -10,6 +10,7 @@ #include #include +#include #include "../common.h" #include "../util/string.h" @@ -29,13 +30,31 @@ void *get_symbol(const char *symbol, int cuda_version = 12010); * without GPUs. Indirect function calls into a lazily-initialized * library ensures we are accessing the correct version. * + * Symbol pointers are cached to avoid repeated lookups. + * * \param[in] symbol Function name * \param[in] args Function arguments */ template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - FuncT *func = reinterpret_cast(get_symbol(symbol)); + + // Cache for symbol pointers + static std::unordered_map symbol_cache; + + // Check if symbol is already cached + auto it = symbol_cache.find(symbol); + FuncT *func; + + if (it != symbol_cache.end()) { + func = reinterpret_cast(it->second); + } else { + // Symbol not in cache, look it up and cache the result + void *ptr = get_symbol(symbol); + symbol_cache[symbol] = ptr; + func = reinterpret_cast(ptr); + } + return (*func)(args...); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f8251332f1..c7f6a927da 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -149,7 +149,8 @@ std::pair Float8Quantizer::create_tensor( } py::object scale_inv_py = py::cast(*scale_inv); at::Device device = - with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA); + with_data ? data->device() : (with_transpose ? 
transpose->device() : + at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { @@ -158,13 +159,15 @@ std::pair Float8Quantizer::create_tensor( PyObject* args = PyTuple_New(0); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(kwargs); Py_DECREF(args); NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); @@ -175,17 +178,19 @@ std::pair Float8Quantizer::create_tensor( // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); PyObject* args = PyTuple_New(0); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(kwargs); Py_DECREF(args); @@ -380,7 +385,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso scale_inv_tensor = at::empty(scale_inv_shape, opts); } at::Device device = with_data ? data_tensor.device() - : (with_transpose ? transpose_tensor.device() : torch::kCUDA); + : (with_transpose ? 
transpose_tensor.device() : + at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; py::object scale_inv_py = py::cast(scale_inv_tensor); @@ -391,35 +397,38 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyObject* kwargs = PyDict_New(); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); - NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); out_py = py::reinterpret_steal(result); } else { const std::vector shape_int64(shape.begin(), shape.end()); // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); @@ -688,19 +697,21 @@ std::pair Float8BlockQuantizer::create_tensor( if (internal) { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); - PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); - PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", 
py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.inc_ref().ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call( reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); @@ -709,22 +720,24 @@ std::pair Float8BlockQuantizer::create_tensor( } else { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); - PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); - PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); - PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); - NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); ret = py::reinterpret_steal(result); } @@ -1039,12 +1052,14 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); @@ -1053,19 +1068,21 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } else { // Use direct C API call 
bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); @@ -1349,12 +1366,16 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); - PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), args, kwargs); + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); @@ -1363,21 +1384,23 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve } else { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); - PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args, kwargs); - + if (result == nullptr) { + PyErr_Print(); + } Py_DECREF(args); Py_DECREF(kwargs); From 8c8dd20fe34e9b2eef99a85e34b1d4e59aac55c8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jan 2026 21:44:35 +0000 Subject: [PATCH 27/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci 
--- transformer_engine/common/util/cuda_driver.h | 8 +++---- transformer_engine/pytorch/csrc/quantizer.cpp | 22 +++++++++++-------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index a038b9e1c2..8de3d9ba5b 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -38,14 +38,14 @@ void *get_symbol(const char *symbol, int cuda_version = 12010); template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - + // Cache for symbol pointers static std::unordered_map symbol_cache; - + // Check if symbol is already cached auto it = symbol_cache.find(symbol); FuncT *func; - + if (it != symbol_cache.end()) { func = reinterpret_cast(it->second); } else { @@ -54,7 +54,7 @@ inline CUresult call(const char *symbol, ArgTs... args) { symbol_cache[symbol] = ptr; func = reinterpret_cast(ptr); } - + return (*func)(args...); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c7f6a927da..0e4a0ca355 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -149,8 +149,9 @@ std::pair Float8Quantizer::create_tensor( } py::object scale_inv_py = py::cast(*scale_inv); at::Device device = - with_data ? data->device() : (with_transpose ? transpose->device() : - at::Device(torch::kCUDA, c10::cuda::current_device())); + with_data ? data->device() + : (with_transpose ? transpose->device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; if (internal) { @@ -384,9 +385,10 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - at::Device device = with_data ? data_tensor.device() - : (with_transpose ? transpose_tensor.device() : - at::Device(torch::kCUDA, c10::cuda::current_device())); + at::Device device = + with_data ? data_tensor.device() + : (with_transpose ? 
transpose_tensor.device() + : at::Device(torch::kCUDA, c10::cuda::current_device())); // Construct Python FP8 tensor py::object out_py; py::object scale_inv_py = py::cast(scale_inv_tensor); @@ -700,7 +702,8 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", + py::cast(scale_inv_colwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.inc_ref().ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); @@ -725,9 +728,10 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", + py::cast(scale_inv_colwise).inc_ref().ptr()); PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); PyObject* args = PyTuple_New(0); @@ -1370,7 +1374,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); - + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), args, kwargs); if (result == nullptr) { From c1acd62c707d0be08315aa961ef11a255a18b5e7 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 18 Jan 2026 14:43:22 +0000 Subject: [PATCH 28/31] fix ci error Signed-off-by: Varun Thumbe --- transformer_engine/common/util/cuda_driver.h | 25 +++---- transformer_engine/pytorch/csrc/quantizer.cpp | 70 +++++++++---------- .../pytorch/quantized_tensor.py | 26 +++++-- .../pytorch/tensor/float8_tensor.py | 1 - .../pytorch/tensor/mxfp8_tensor.py | 2 - .../pytorch/tensor/nvfp4_tensor.py | 2 - 6 files changed, 68 insertions(+), 58 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index a038b9e1c2..c75d60e84a 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -11,7 +11,7 @@ #include #include - +#include #include "../common.h" #include "../util/string.h" @@ -39,23 +39,24 @@ template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - // Cache for symbol pointers static std::unordered_map symbol_cache; + static std::unordered_map init_flags; + static std::mutex init_mutex; - // Check if symbol is already cached - auto it = symbol_cache.find(symbol); - FuncT *func; + // Get or create the once_flag for this symbol. 
+ std::once_flag *flag_ptr; + { + std::lock_guard lock(init_mutex); + flag_ptr = &init_flags[symbol]; // Safe: mutex protects map insertion + } - if (it != symbol_cache.end()) { - func = reinterpret_cast(it->second); - } else { - // Symbol not in cache, look it up and cache the result + // Use call_once with the flag (lock-free on subsequent calls) + std::call_once(*flag_ptr, [&]() { void *ptr = get_symbol(symbol); symbol_cache[symbol] = ptr; - func = reinterpret_cast(ptr); - } + }); - return (*func)(args...); + return (*reinterpret_cast(symbol_cache[symbol]))(args...); } /*! \brief Ensure that the calling thread has a CUDA context diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c7f6a927da..91359e0589 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -159,7 +159,7 @@ std::pair Float8Quantizer::create_tensor( PyObject* args = PyTuple_New(0); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); @@ -178,14 +178,14 @@ std::pair Float8Quantizer::create_tensor( // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); PyObject* args = PyTuple_New(0); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); if (result == nullptr) { @@ -397,7 +397,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso PyObject* kwargs = PyDict_New(); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); @@ -415,14 +415,14 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const std::vector shape_int64(shape.begin(), shape.end()); // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + 
PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); PyDict_SetItemString(kwargs, "data", data_py.ptr()); PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "device", py::cast(device).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorPythonClass), args, kwargs); @@ -697,14 +697,14 @@ std::pair Float8BlockQuantizer::create_tensor( if (internal) { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.inc_ref().ptr()); - PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call( @@ -720,16 +720,16 @@ std::pair Float8BlockQuantizer::create_tensor( } else { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + 
PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); - PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), args, kwargs); @@ -1052,7 +1052,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* result = @@ -1068,13 +1068,13 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } else { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); - PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); @@ -1366,7 +1366,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); - PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); @@ -1384,15 +1384,15 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve } else { // Use direct C API call bypassing pybind11 overhead PyObject* kwargs = PyDict_New(); - PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).inc_ref().ptr()); - PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); PyDict_SetItemString(kwargs, 
"rowwise_scale_inv", rowwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); - PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).inc_ref().ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyObject* args = PyTuple_New(0); diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index e6c2f92dff..4b41ad89de 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -355,8 +355,8 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - # instance._requires_grad = requires_grad - # instance._dtype = dtype + instance._requires_grad = requires_grad + instance._dtype = dtype return instance @property @@ -369,15 +369,13 @@ def dtype(self) -> torch.dtype: """ # Lazy initialization for tensors created via alternate paths if not hasattr(self, "_dtype"): - self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) + self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) # pylint: disable=unnecessary-dunder-call return self._dtype @dtype.setter def dtype(self, value: torch.dtype) -> None: """Set dtype property""" - # Update the cached value self._dtype = value - warnings.warn("Dtype of QuantizedTensor has been changed. Ensure this is intended.") @property def requires_grad(self) -> bool: @@ -389,7 +387,7 @@ def requires_grad(self) -> bool: """ # Fallback to parent if not cached yet if not hasattr(self, "_requires_grad"): - self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) + self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) # pylint: disable=unnecessary-dunder-call return self._requires_grad @requires_grad.setter @@ -407,6 +405,22 @@ def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: super().requires_grad_(requires_grad) return self + def _get_data(self) -> torch.Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + Updates the underlying tensor data and syncs the dtype cache. 
+ """ + # Update the parent class's data descriptor + super(QuantizedTensor, type(self)).data.__set__(self, tensor) + # Update the dtype cache + self._dtype = tensor.dtype + + # Create the data property with getter and setter + data = property(_get_data, _set_data) + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 6bc3e42a0a..0589568c9c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1001,7 +1001,6 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) - self.dtype = tensor.dtype # Float8Tensor attributes self._data = tensor._data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 98fb59f387..344784a085 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -785,8 +785,6 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) - # Cache the attributes - self.dtype = tensor.dtype self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index cc8c348aa2..83c1721068 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -689,8 +689,6 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) - # Cache the attributes - self.dtype = tensor.dtype self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data From 1538fd9d26aea1b4e543685fc55e2fca843c874a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 14:49:00 +0000 Subject: [PATCH 29/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/cuda_driver.h | 13 +++++++------ transformer_engine/pytorch/csrc/quantizer.cpp | 2 +- transformer_engine/pytorch/quantized_tensor.py | 8 ++++++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index c75d60e84a..2bb187a33e 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -9,9 +9,10 @@ #include +#include #include #include -#include + #include "../common.h" #include "../util/string.h" @@ -38,24 +39,24 @@ void *get_symbol(const char *symbol, int cuda_version = 12010); template inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); - + static std::unordered_map symbol_cache; static std::unordered_map init_flags; static std::mutex init_mutex; - + // Get or create the once_flag for this symbol. 
std::once_flag *flag_ptr; { std::lock_guard lock(init_mutex); flag_ptr = &init_flags[symbol]; // Safe: mutex protects map insertion } - + // Use call_once with the flag (lock-free on subsequent calls) std::call_once(*flag_ptr, [&]() { void *ptr = get_symbol(symbol); symbol_cache[symbol] = ptr; - }); - + }); + return (*reinterpret_cast(symbol_cache[symbol]))(args...); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5b70a5121f..4e2d6b1640 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -718,7 +718,7 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); - PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 240e406a64..c2e244c6a0 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -384,7 +384,9 @@ def dtype(self) -> torch.dtype: """ # Lazy initialization for tensors created via alternate paths if not hasattr(self, "_dtype"): - self._dtype = torch._C.TensorBase.dtype.__get__(self, type(self)) # pylint: disable=unnecessary-dunder-call + self._dtype = torch._C.TensorBase.dtype.__get__( + self, type(self) + ) # pylint: disable=unnecessary-dunder-call return self._dtype @dtype.setter @@ -402,7 +404,9 @@ def requires_grad(self) -> bool: """ # Fallback to parent if not cached yet if not hasattr(self, "_requires_grad"): - self._requires_grad = torch._C.TensorBase.requires_grad.__get__(self, type(self)) # pylint: disable=unnecessary-dunder-call + self._requires_grad = torch._C.TensorBase.requires_grad.__get__( + self, type(self) + ) # pylint: disable=unnecessary-dunder-call return self._requires_grad @requires_grad.setter From 710b581635e290a4e06f79c86de2fedb2a85ee89 Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Sun, 18 Jan 2026 15:38:29 +0000 Subject: [PATCH 30/31] address review comment from greptile Signed-off-by: Varun Thumbe --- transformer_engine/common/util/cuda_driver.h | 26 +++++++++---------- transformer_engine/pytorch/csrc/quantizer.cpp | 2 -- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index c75d60e84a..471e9449c7 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -40,23 +40,21 @@ inline CUresult call(const char *symbol, ArgTs... args) { using FuncT = CUresult(ArgTs...); static std::unordered_map symbol_cache; - static std::unordered_map init_flags; - static std::mutex init_mutex; + static std::mutex cache_mutex; - // Get or create the once_flag for this symbol. 
- std::once_flag *flag_ptr; + FuncT* func; { - std::lock_guard lock(init_mutex); - flag_ptr = &init_flags[symbol]; // Safe: mutex protects map insertion + std::lock_guard lock(cache_mutex); + auto it = symbol_cache.find(symbol); + if (it == symbol_cache.end()) { + void *ptr = get_symbol(symbol); + symbol_cache[symbol] = ptr; + func = reinterpret_cast(ptr); + } else { + func = reinterpret_cast(it->second); + } } - - // Use call_once with the flag (lock-free on subsequent calls) - std::call_once(*flag_ptr, [&]() { - void *ptr = get_symbol(symbol); - symbol_cache[symbol] = ptr; - }); - - return (*reinterpret_cast(symbol_cache[symbol]))(args...); + return (*func)(args...); } /*! \brief Ensure that the calling thread has a CUDA context diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 5b70a5121f..f2d69626e6 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -695,7 +695,6 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); PyObject* result = PyObject_Call( @@ -720,7 +719,6 @@ std::pair Float8BlockQuantizer::create_tensor( PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); - PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); PyObject* args = PyTuple_New(0); PyObject* result = From 7e4f093260ba88e0294d936ba4a1c89df0ca0e23 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 15:40:51 +0000 Subject: [PATCH 31/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/cuda_driver.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index fe46afb163..16242347f1 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -42,7 +42,7 @@ inline CUresult call(const char *symbol, ArgTs... args) { static std::unordered_map symbol_cache; static std::mutex cache_mutex; - FuncT* func; + FuncT *func; { std::lock_guard lock(cache_mutex);
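Across the last two commits the driver-call wrapper converges on a simpler design than the std::call_once version tried first: a single static map from symbol name to resolved pointer, guarded by one mutex, with the driver call made after the lock is released. The following sketch restates that final shape in a self-contained form; call_driver and lookup_symbol are placeholder names (lookup_symbol stands in for the real get_symbol helper), and the dlsym-based resolver is only there to keep the example compilable.

// Self-contained sketch of the find-or-insert symbol cache the series settles on.
// call_driver and lookup_symbol are placeholder names; only the caching pattern
// mirrors the hunks above.
#include <dlfcn.h>
#include <mutex>
#include <string>
#include <unordered_map>

inline void* lookup_symbol(const char* name) {
  return dlsym(RTLD_DEFAULT, name);  // stand-in for the real driver-entry resolver
}

template <typename RetT, typename... ArgTs>
RetT call_driver(const char* symbol, ArgTs... args) {
  using FuncT = RetT(ArgTs...);
  // std::string keys make equal symbol names compare equal even when callers
  // pass different char pointers with the same contents.
  static std::unordered_map<std::string, void*> symbol_cache;
  static std::mutex cache_mutex;

  FuncT* func;
  {
    std::lock_guard<std::mutex> lock(cache_mutex);
    auto it = symbol_cache.find(symbol);
    if (it == symbol_cache.end()) {
      void* ptr = lookup_symbol(symbol);  // resolved once; later calls hit the cache
      symbol_cache.emplace(symbol, ptr);
      func = reinterpret_cast<FuncT*>(ptr);
    } else {
      func = reinterpret_cast<FuncT*>(it->second);
    }
  }
  return (*func)(args...);  // the driver call itself runs outside the critical section
}

The trade-off relative to the per-symbol std::once_flag approach is that every call still takes the mutex, but the critical section is a single hash lookup, which is small next to the driver call it guards, and it avoids maintaining a second map of once-flags.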