Tensorax is a deep learning framework written from scratch in C++/CUDA with a Python frontend. Every kernel — matmul, attention, elementwise ops, reductions — is hand-written. No PyTorch, no NumPy, no cuBLAS at runtime. The only dependency is pybind11 for the C++/Python bridge.
The goal is a clean, readable implementation of a DL framework from first principles that also runs fast on real hardware. The MMA attention kernel uses inline PTX assembly to hit Ampere Tensor Cores, and the best matmul variant runs at ~3x NumPy speed — all without calling into any external math library.
```shell
pip install tensorax
```

The API is intentionally PyTorch-like, so the learning curve is minimal:
```python
from tensorax import Tensor, nn, optim, lr_scheduler, functional as F

# define a model
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.GELU(),
    nn.LayerNorm(8),
    nn.Linear(8, 3),
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

# train
for epoch in range(100):
    loss = F.mse_loss(model(x_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
```

More examples in `examples/` and the full API reference in `docs/USAGE.md`.
**Tensor core.** CPU and CUDA backends with automatic fallback. Broadcasting arithmetic, reshape, transpose, sum, mean, exp, log, sqrt, pow. Reverse-mode autograd through 18+ operations. 13 dtype constants.
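Tensorax ships its own autograd engine. As an illustration of the general technique only (not Tensorax's actual internals), reverse-mode autograd over two ops can be sketched in plain Python:

```python
# Minimal reverse-mode autograd sketch -- illustrative only, not the
# Tensorax implementation. Each Value records its parents and a closure
# that pushes gradients backward through the op that produced it.
class Value:
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward = lambda grad: None  # propagates grad to parents

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other))
        def backward(grad):
            self.grad += grad * other.data   # d(xy)/dx = y
            other.grad += grad * self.data   # d(xy)/dy = x
        out._backward = backward
        return out

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))
        def backward(grad):
            self.grad += grad   # d(x+y)/dx = 1
            other.grad += grad  # d(x+y)/dy = 1
        out._backward = backward
        return out

    def backward(self):
        # Topologically order the graph, then apply the chain rule in reverse.
        order, seen = [], set()
        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward(v.grad)

x, y = Value(2.0), Value(3.0)
z = x * y + x        # z = xy + x
z.backward()
print(x.grad, y.grad)  # dz/dx = y + 1 = 4.0, dz/dy = x = 2.0
```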
**Layers.** Linear, Embedding, Sequential, Dropout. Activations: ReLU, Sigmoid, Tanh, Softmax, GELU, SiLU. Norms: LayerNorm, RMSNorm, BatchNorm.
**Attention.** Scaled dot-product attention, Multi-Head Attention, and Grouped Query Attention — each backed by 5 CUDA kernel variants (naive, tiled, flash, optimized flash, MMA Tensor Core). Causal and padding mask support.
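For reference, the math all five SDPA kernel variants compute is softmax(QKᵀ/√dₖ)V. A plain-Python sketch of that formula (illustrative, not the kernel code):

```python
import math

def sdpa(Q, K, V):
    """Scaled dot-product attention over nested lists, shapes [seq, dim].
    Illustrates the math the CUDA kernels implement, not their structure."""
    d_k = len(K[0])
    scale = 1.0 / math.sqrt(d_k)
    # scores[i][j] = <Q_i, K_j> / sqrt(d_k)
    scores = [[scale * sum(q * k for q, k in zip(qi, kj)) for kj in K] for qi in Q]
    # numerically stable row-wise softmax
    weights = []
    for row in scores:
        m = max(row)
        exps = [math.exp(s - m) for s in row]
        z = sum(exps)
        weights.append([e / z for e in exps])
    # output_i = sum_j weights[i][j] * V_j
    return [[sum(w * vj[c] for w, vj in zip(wi, V)) for c in range(len(V[0]))]
            for wi in weights]

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
out = sdpa(Q, K, V)  # each row is a softmax-weighted mix of the rows of V
```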
**Training.** SGD with momentum, Adam with bias correction. MSE, cross-entropy, and cross-entropy-from-logits losses. 5 LR schedulers: StepLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR.
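CosineAnnealingLR follows the standard cosine-annealing curve: the learning rate decays from the base LR to a floor over `T_max` steps along half a cosine period. A sketch of that formula (parameter names here are illustrative, not necessarily Tensorax's exact signature):

```python
import math

def cosine_annealing_lr(step, base_lr, T_max, eta_min=0.0):
    """Standard cosine annealing: lr falls from base_lr at step 0
    to eta_min at step T_max along half a cosine period."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / T_max))

# Smooth decay: full LR at step 0, eta_min at step T_max.
print(cosine_annealing_lr(0, 0.001, 100))    # 0.001
print(cosine_annealing_lr(100, 0.001, 100))  # 0.0
```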
**CUDA kernels.** 6 matmul implementations (naive through 2D block tiling), 5 attention kernels, 14 element-wise ops. Shared memory tiling, coalesced access patterns, and mma.sync Tensor Core instructions where it matters.
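The block-tiling idea behind the faster matmul variants — computing C one tile at a time so each tile of A and B is reused from fast memory — can be sketched in plain Python (on the GPU each tile maps to a thread block staging A/B in shared memory; here the same loop structure just improves cache locality; illustrative, not the CUDA source):

```python
def matmul_tiled(A, B, tile=2):
    """Blocked matrix multiply over nested lists. The (ii, jj) loops walk
    tiles of C; the kk loop walks tiles along the shared dimension, so each
    small sub-block of A and B is reused while it is hot in cache."""
    n, k, m = len(A), len(B), len(B[0])
    C = [[0.0] * m for _ in range(n)]
    for ii in range(0, n, tile):              # tile row of C
        for jj in range(0, m, tile):          # tile column of C
            for kk in range(0, k, tile):      # tiles along shared dim
                for i in range(ii, min(ii + tile, n)):
                    for j in range(jj, min(jj + tile, m)):
                        acc = C[i][j]
                        for p in range(kk, min(kk + tile, k)):
                            acc += A[i][p] * B[p][j]
                        C[i][j] = acc
    return C

A = [[1, 2, 3], [4, 5, 6]]
B = [[7, 8], [9, 10], [11, 12]]
print(matmul_tiled(A, B))  # [[58.0, 64.0], [139.0, 154.0]]
```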
Matmul — fp32, 3x1024x1024, 100 iterations:

| Implementation | Time | Speedup |
| --- | --- | --- |
| PyTorch CUDA (cuBLAS) | 0.08s | 22.24x |
| **Tensorax 2D Block Tiling** (best) | 0.58s | 2.97x |
| Tensorax 1D Block Tiling | 0.64s | 2.68x |
| Tensorax Tiled | 0.83s | 2.05x |
| Tensorax Cache Blocking | 0.98s | 1.75x |
| Tensorax SM Coalesced | 1.14s | 1.50x |
| Tensorax Default | 1.18s | 1.45x |
| NumPy CPU (baseline) | 1.71s | 1.00x |
Attention — fp32/fp16, B=4 H=8 S=256 Dk=512 Dv=512, 30 iterations:

| Implementation | Time | Speedup |
| --- | --- | --- |
| PyTorch SDPA | 0.04s | 2340x |
| **Tensorax MMA Tensor Core** (best) | 0.33s | 297x |
| Tensorax Optim. Flash | 0.52s | 187x |
| Tensorax Flash SDPA | 3.10s | 31x |
| NumPy CPU (baseline) | 7.06s | 14x |
| Tensorax Tiled SDPA | 32.91s | 3x |
| Tensorax Naive SDPA | 98.26s | 1x |
The MMA kernel achieves a 9.3x speedup over the flash kernel by dropping into inline PTX to use mma.sync Ampere Tensor Core instructions and SFU fast-math intrinsics. Still ~8x behind PyTorch's heavily optimized SDPA — closing that gap is ongoing work.
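The flash-style kernels hinge on the online-softmax trick: the softmax statistics (running max and running normalizer) are updated one block of scores at a time, so the full S×S score matrix never has to be materialized. A plain-Python sketch of that recurrence for a single row (illustrative of the technique, not the kernel code):

```python
import math

def softmax_online(scores):
    """Single-pass softmax statistics: keep a running max m and running
    normalizer z, rescaling z by exp(m - m_new) whenever a new max appears.
    Flash-attention kernels apply this recurrence per tile of scores."""
    m, z = float("-inf"), 0.0
    for s in scores:
        m_new = max(m, s)
        z = z * math.exp(m - m_new) + math.exp(s - m_new)
        m = m_new
    return [math.exp(s - m) / z for s in scores]

def softmax_full(scores):
    """Reference two-pass softmax for comparison."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

row = [0.5, 2.0, -1.0, 3.0]
a, b = softmax_online(row), softmax_full(row)
assert all(abs(x - y) < 1e-12 for x, y in zip(a, b))
```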
```
csrc/
  cuda/kernels/       elementwise, matmul (x6), reduction, attention (x5)
  cpu/                CPU fallback for all ops
  tensor_ops.cpp/.h   pybind11 bindings
tensorax/
  tensor.py           Tensor class + autograd engine
  functional.py       F.relu, F.gelu, F.softmax, F.sdpa, losses, ...
  nn/                 Linear, Embedding, norms, dropout, attention (MHA, GQA)
  optim.py            SGD, Adam
  lr_scheduler.py     StepLR, CosineAnnealingLR, ExponentialLR, LinearLR, MultiStepLR
```
**What's here now:** core tensor ops, autograd, all the layers/norms/activations listed above, two optimizers, five LR schedulers, three loss functions, five attention kernels, six matmul variants, MHA, GQA, embeddings.

**What's next:** Conv2D, MaxPool2D, AdamW, tensor indexing/slicing, model serialization, DataLoader, multi-GPU, mixed precision, DDP, ONNX export.
- Usage Guide — full API reference with code examples
- Architecture — system design, kernel strategy, autograd internals
- Development — building from source, testing, contributing
- Examples — runnable scripts
```bibtex
@software{tensorax2025,
  title  = {Tensorax: Pure C++/CUDA Tensor Library},
  author = {Shrirang Mahajan},
  year   = {2025},
  url    = {https://github.com/NotShrirang/tensorax}
}
```