Commits (46)

- 93ee022: add all the optimizations (vthumbe1503, Jan 5, 2026)
- 06338bc: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 5, 2026)
- 50de9cd: requires_grad optimization (vthumbe1503, Jan 6, 2026)
- 5fee841: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 6, 2026)
- 4c79ac7: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 6, 2026)
- 62b88e1: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 6, 2026)
- 99494d7: test if commenting out requires_grad works (vthumbe1503, Jan 7, 2026)
- b157f85: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 7, 2026)
- 2a7b627: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 7, 2026)
- b61a6a8: fix minor bug (vthumbe1503, Jan 7, 2026)
- 938651e: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 7, 2026)
- 88dfdbd: fix ci (vthumbe1503, Jan 11, 2026)
- 1526eea: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 11, 2026)
- 5809dcc: missed a bug (vthumbe1503, Jan 11, 2026)
- b3bd748: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 11, 2026)
- 30fecf2: Update transformer_engine/pytorch/csrc/quantizer.cpp (vthumbe1503, Jan 11, 2026)
- 1b0d497: fix some bugs pointed to by copilot (vthumbe1503, Jan 11, 2026)
- 138b7bf: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 11, 2026)
- eec1e86: linting error (vthumbe1503, Jan 11, 2026)
- 8169d9c: fix the error (vthumbe1503, Jan 12, 2026)
- 6fefaf2: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 12, 2026)
- a5feaf9: fix the bug (vthumbe1503, Jan 13, 2026)
- 285dbff: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 13, 2026)
- afb2f23: get rid of the change (vthumbe1503, Jan 13, 2026)
- 3919cb8: fix the transpose shape bug (vthumbe1503, Jan 13, 2026)
- fd36424: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 13, 2026)
- 4668133: minor linter fix (vthumbe1503, Jan 13, 2026)
- 5a00652: fix lint (vthumbe1503, Jan 13, 2026)
- 739bbad: fix linting error (vthumbe1503, Jan 16, 2026)
- e8042c1: address copilot review comment regarding error check when both data a… (vthumbe1503, Jan 16, 2026)
- 1d323d7: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 16, 2026)
- da7fbf5: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 16, 2026)
- e2c7435: fix linting errors (vthumbe1503, Jan 16, 2026)
- f4e2492: fix merge conflict (vthumbe1503, Jan 16, 2026)
- beada36: missed a merge conflict (vthumbe1503, Jan 16, 2026)
- 06a72a2: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 16, 2026)
- 5d21db2: final optimizations (vthumbe1503, Jan 16, 2026)
- 1dfd6fe: Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf… (vthumbe1503, Jan 16, 2026)
- 8c8dd20: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 16, 2026)
- c1acd62: fix ci error (vthumbe1503, Jan 18, 2026)
- 7f35b0b: fix merge conflixt (vthumbe1503, Jan 18, 2026)
- ca177ae: Merge branch 'main' into cpu_fp8_optimizations (vthumbe1503, Jan 18, 2026)
- 1538fd9: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 18, 2026)
- 710b581: address review comment from greptile (vthumbe1503, Jan 18, 2026)
- 8a57a75: fix merge conflixt (vthumbe1503, Jan 18, 2026)
- 7e4f093: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jan 18, 2026)

12 changes: 8 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
+  int is_nvte_non_tn_fp8_gemm_supported = 0;  // needed only for per tensor scaling
+  if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
+    is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  }

// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
@@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
@@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
-    if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+    if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
@@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
-    } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
+    } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
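
The hunks above hoist the `nvte_is_non_tn_fp8_gemm_supported()` query out of the four per-operand branches: it is now evaluated at most once per `CanonicalizeGemmInput` call, and only when one of the inputs actually uses per-tensor scaling. Below is a minimal sketch of the same hoist-and-reuse pattern, with a hypothetical `query_capability()` standing in for the real NVTE query (names and fallback behavior here are illustrative, not the library's):

```cpp
#include <iostream>

// Hypothetical stand-in for a capability query such as
// nvte_is_non_tn_fp8_gemm_supported(); assume it is cheap but not free.
int query_capability() {
  std::cout << "querying device capability\n";
  return 0;  // pretend the non-TN FP8 layout is not supported (Hopper-like)
}

// Before: each branch called the query directly, so one configure() call
// could query several times. After: query once, reuse the cached local.
void configure(bool a_uses_tensor_scaling, bool b_uses_tensor_scaling) {
  int supported = 0;  // only meaningful when some operand uses per-tensor scaling
  if (a_uses_tensor_scaling || b_uses_tensor_scaling) {
    supported = query_capability();
  }
  if (a_uses_tensor_scaling && !supported) {
    std::cout << "A: fall back to column-wise (transposed) data\n";
  }
  if (b_uses_tensor_scaling && !supported) {
    std::cout << "B: fall back to column-wise (transposed) data\n";
  }
}

int main() {
  configure(true, true);  // the capability is queried exactly once
  return 0;
}
```
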
21 changes: 20 additions & 1 deletion transformer_engine/common/util/cuda_driver.h
@@ -9,7 +9,9 @@

#include <cuda.h>

+#include <mutex>
#include <string>
+#include <unordered_map>

#include "../common.h"
#include "../util/string.h"
@@ -29,13 +31,30 @@ void *get_symbol(const char *symbol, int cuda_version = 12010);
* without GPUs. Indirect function calls into a lazily-initialized
* library ensures we are accessing the correct version.
*
+ * Symbol pointers are cached to avoid repeated lookups.
+ *
* \param[in] symbol Function name
* \param[in] args Function arguments
*/
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
using FuncT = CUresult(ArgTs...);
-  FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
+
+  static std::unordered_map<std::string, void *> symbol_cache;
+  static std::mutex cache_mutex;
+  FuncT *func;
+
+  {
+    std::lock_guard<std::mutex> lock(cache_mutex);
+    auto it = symbol_cache.find(symbol);
+    if (it == symbol_cache.end()) {
+      void *ptr = get_symbol(symbol);
+      symbol_cache[symbol] = ptr;
+      func = reinterpret_cast<FuncT *>(ptr);
+    } else {
+      func = reinterpret_cast<FuncT *>(it->second);
+    }
+  }
return (*func)(args...);
}

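With the cache in place, only the first `call()` for a given symbol goes through `get_symbol()`; later calls pay for a mutex-guarded map lookup instead of a full dynamic-library symbol resolution. The cache is keyed by symbol name rather than, say, a function-local static, because one template instantiation of `call` can be invoked with many different symbol names that share the same signature. A hedged usage sketch follows; it assumes the wrapper lives in the `transformer_engine::cuda_driver` namespace and that the header's include path is as shown, neither of which is confirmed by this diff:

```cpp
#include <cuda.h>

#include <iostream>

#include "common/util/cuda_driver.h"  // assumed include path within Transformer Engine

int main() {
  int version = 0;
  // First call resolves and caches the "cuDriverGetVersion" symbol;
  // repeated calls reuse the cached pointer.
  // Assumes the call() wrapper sits in transformer_engine::cuda_driver.
  CUresult status = transformer_engine::cuda_driver::call("cuDriverGetVersion", &version);
  if (status == CUDA_SUCCESS) {
    std::cout << "CUDA driver version: " << version << "\n";
  }
  return 0;
}
```
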
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
The order of attributes checked is important to also
minimize overhead.
"""
if hasattr(tensor, "device"):
return tensor.device.index
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()


15 changes: 7 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
+std::once_flag extension_init_flag;

void init_float8_extension() {
-  if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
@@ -54,7 +54,6 @@ void init_float8_extension() {
}

void init_mxfp8_extension() {
-  if (MXFP8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
MXFP8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
@@ -69,7 +68,6 @@ void init_mxfp8_extension() {
}

void init_float8blockwise_extension() {
-  if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
@@ -90,7 +88,6 @@ void init_float8blockwise_extension() {
}

void init_nvfp4_extensions() {
-  if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
@@ -105,10 +102,12 @@ void init_nvfp4_extensions() {
}

void init_extension() {
-  init_float8_extension();
-  init_mxfp8_extension();
-  init_float8blockwise_extension();
-  init_nvfp4_extensions();
+  std::call_once(extension_init_flag, []() {
+    init_float8_extension();
+    init_mxfp8_extension();
+    init_float8blockwise_extension();
+    init_nvfp4_extensions();
+  });
}

} // namespace transformer_engine::pytorch
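
The per-class `if (X) return;` early-exit guards are replaced by a single `std::once_flag`: the module imports now run exactly once, the check is thread-safe, and repeat calls to `init_extension()` reduce to one flag test instead of four pointer checks plus four function calls. Below is a minimal standalone sketch of the pattern, with placeholder initializers (`init_a`, `init_b`) standing in for the real per-tensor-class setup:

```cpp
#include <iostream>
#include <mutex>

std::once_flag extension_init_flag;

// Placeholder initializers standing in for init_float8_extension,
// init_mxfp8_extension, and the other per-tensor-class setup functions.
void init_a() { std::cout << "init_a\n"; }
void init_b() { std::cout << "init_b\n"; }

void init_extension() {
  // The lambda runs exactly once, even if init_extension() is invoked
  // concurrently from several threads; later calls return immediately.
  std::call_once(extension_init_flag, []() {
    init_a();
    init_b();
  });
}

int main() {
  init_extension();
  init_extension();  // no output the second time
  return 0;
}
```
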