A minimal LoRA fine-tuning library with hand-written Triton kernels for the forward and backward passes.
This is a personal learning project to get into kernel programming for ML. Kernels are written from scratch using Triton.
fastaf is designed to integrate easily into existing PyTorch and HuggingFace scripts:
```python
from fastaf import LoRAConfig, init

init(model)  # freeze base weights, inject adapter registry

model.add_new_adapter(
    'sql',
    LoRAConfig(r=16, alpha=32, targets=['q_proj', 'v_proj']),
)
model.set_active_adapter('sql')  # activate; deactivate with None
model.remove_adapter('sql')      # remove weights and registry entry
```

From here, use the model exactly as you normally would: training with SFTTrainer, inference with `model.generate()`. When running on CUDA, the Triton kernels are invoked automatically for both forward and backward passes. On CPU, the library falls back to plain PyTorch.
```python
from fastaf import save_adapter_to_local_dir, load_adapter, push_adapter_to_hub

save_adapter_to_local_dir(model, path='./checkpoints', name='sql')
load_adapter(model, path_or_repo_id='./checkpoints')
push_adapter_to_hub(model, name='sql', repo_id='username/my-sql-adapter')
```

That covers the full fastaf public API.
In standard PyTorch, a single LoRA layer's forward pass

```python
out = X @ W.T + (X @ A.T) @ B.T * scale
```

dispatches six separate CUDA kernels:

1. `X @ W.T` — base matrix multiply (cuBLAS)
2. `X @ A.T` — LoRA down-projection
3. A split-K reduction cleanup (cuBLAS internal, fired for rectangular A shapes)
4. `intermediate @ B.T` — LoRA up-projection
5. Element-wise multiply by `scale = alpha / r`
6. Element-wise add onto the base output
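For reference, the adapter math itself is tiny. The following pure-Python sketch (illustrative only, not fastaf code; shapes and values are made up) checks the standard LoRA identity that the two-step adapter path equals folding the scaled update `scale * B @ A` into the base weight:

```python
# Illustrative check of the LoRA forward math (not fastaf code):
# X @ W.T + (X @ A.T) @ B.T * scale  ==  X @ (W + scale * B @ A).T

def matmul(P, Q):
    """Naive [m,k] x [k,n] matrix multiply on nested lists."""
    return [[sum(P[i][t] * Q[t][j] for t in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def transpose(M):
    return [list(row) for row in zip(*M)]

def add(P, Q, s=1.0):
    return [[p + s * q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

# Tiny shapes: seq=2, d_in=3, d_out=2, rank=1; scale = alpha / r
X = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]   # [seq, d_in]
W = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]   # [d_out, d_in]
A = [[1.0, 0.0, -1.0]]                   # [rank, d_in]
B = [[0.5], [-0.5]]                      # [d_out, rank]
scale = 2.0

# Two-step path: down-projection, up-projection, scale, add.
lora = matmul(matmul(X, transpose(A)), transpose(B))
out_unfused = add(matmul(X, transpose(W)), lora, s=scale)

# Equivalent: fold the scaled update into the base weight.
out_folded = matmul(X, transpose(add(W, matmul(B, A), s=scale)))

assert all(abs(a - b) < 1e-9 for ra, rb in zip(out_unfused, out_folded)
           for a, b in zip(ra, rb))
```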
Each kernel launch carries overhead: the GPU receives the launch command, sets up execution state, and loads operands from VRAM. Between launches the GPU sits idle — these are the "pipeline bubbles" visible as gaps in a profiler trace.
The fastaf forward kernel collapses steps 2–6 into a single kernel launch. The key insight is that the intermediate result of the down-projection (X @ A.T, shape [seq, rank]) never needs to be written to VRAM. Each thread block owns a tile of the final output and keeps the intermediate result in registers across both matrix multiplications. Scaling and base addition happen as a cheap epilogue before the single store.
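In pseudocode, the per-tile work looks roughly like this (a simplified sketch of the strategy described above; block names are illustrative):

```
# each thread block owns one [BLOCK_M, BLOCK_N] tile of `out`
acc = zeros(BLOCK_M, BLOCK_N)        # base accumulator, in registers
mid = zeros(BLOCK_M, RANK)           # X @ A.T intermediate, in registers

for k in range(0, d_in, BLOCK_K):    # single pass over the input dim
    x = load X[tile_m, k : k+BLOCK_K]
    acc += x @ load(W.T[k : k+BLOCK_K, tile_n])
    mid += x @ load(A.T[k : k+BLOCK_K, :])   # rank is small: fits in registers

# epilogue: up-projection, scaling, base add -- no VRAM round-trip for `mid`
acc += (mid @ load(B.T[:, tile_n])) * scale
store out[tile_m, tile_n] = acc
```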
Note: run `fastaf-tune` before benchmarking. Numbers below were measured after tuning on an RTX 5060 Ti.
```
uv run python -m benchmarks.lora.bench_fwd
...
======================================================================
bench_fwd.py SUMMARY
======================================================================
Overall Win Rate: 75.0% (81/108 configs faster than unfused)

Average Speedup by Hidden Dimension:
  d_model=512  : 1.23x
  d_model=1024 : 1.34x
  d_model=2048 : 1.29x

Peak Performance:
  3.45x speedup at batch=16, seq=1024, rank=16
======================================================================
```

- Workload Scaling (The Overhead Threshold): The fused kernel shines in throughput-heavy scenarios. As batch size and sequence length increase (e.g., batch=16, seq=1024), the time saved by eliminating redundant memory reads/writes massively outweighs the kernel launch overhead, unlocking peak speedups of up to 3.45x.
- Small Workload Regression: For extremely small computations (e.g., batch=1, seq=256), standard PyTorch unfused cuBLAS is often faster. In these micro-workloads, the overhead of launching a custom Triton kernel simply takes longer than executing multiple tiny, highly optimized cuBLAS operations.
The three backward kernels (dA, dB, dX) follow the same register-resident fusion strategy, but currently still lose to unfused PyTorch:
```
uv run python -m benchmarks.lora.bench_bwd
...
======================================================================
bench_bwd.py SUMMARY (RTX 5060 Ti)
======================================================================
Overall Win Rate: 0.0% (0/108 configs faster than unfused)

Average Speedup by Hidden Dimension:
  d_model=512  : 0.60x
  d_model=1024 : 0.68x
  d_model=2048 : 0.73x

Peak Performance:
  0.92x speedup at batch=4, seq=1024, rank=8
======================================================================
```

The general problem: the register-resident intermediate computations in the backward kernels loop serially over the large sequence dimension, which starves the GPU of occupancy and lets the "fusion tax" outweigh the benefits. The fused forward kernel avoids this because it only loops over the small rank dimension.
Switching dA to a sequence-parallel architecture with atomic reductions already lifted the peak from ~0.5x to ~0.92x; dB and dX still need the same refactor to eliminate their remaining serial bottlenecks.
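The idea behind the dA refactor can be illustrated in plain Python (a sketch of the concept, not the Triton kernel): instead of one program looping serially over the whole sequence, each sequence chunk is processed independently and the partial results meet in the output via atomic adds.

```python
# Illustrative sketch of sequence-parallel reduction (not the fastaf kernel).
# dA has shape [rank, d_in]; every sequence position contributes to it, so
# parallelizing over seq requires a reduction -- per-chunk partial sums
# combined with (conceptually atomic) adds.

def grad_A_serial(contribs):
    # one "program" loops over the full sequence: low parallelism
    total = 0.0
    for c in contribs:
        total += c
    return total

def grad_A_seq_parallel(contribs, n_chunks=4):
    # each chunk is an independent "program"; results meet via atomic add
    out = [0.0]                      # stands in for one dA cell in VRAM
    chunk = (len(contribs) + n_chunks - 1) // n_chunks
    for p in range(n_chunks):        # these run concurrently on the GPU
        partial = sum(contribs[p * chunk : (p + 1) * chunk])
        out[0] += partial            # tl.atomic_add in a real Triton kernel
    return out[0]

seq_contribs = [0.1 * i for i in range(1024)]
assert abs(grad_A_serial(seq_contribs) - grad_A_seq_parallel(seq_contribs)) < 1e-6
```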
Triton provides an autotune feature that benchmarks a kernel over a given set of configurations and selects the best block sizes, number of warps, and number of prefetching stages.
fastaf includes this feature and lets you tune kernel configurations via:

```
fastaf-tune
```

It benchmarks candidate configurations against a representative set of input shapes, selects the winner for each shape, and writes the results to `~/.cache/fastaf/block_sizes.json`. You only need to do this once; all subsequent runs read from the cache.
How the tuner prunes the search space: rather than benchmarking every possible block size permutation (which would take many hours), fastaf-tune first filters candidates using device properties from torch.cuda.get_device_properties() — discarding configs that exceed the per-block thread count limit, estimated register budget, shared memory capacity, or a minimum occupancy threshold. The surviving configs are then passed to triton.autotune, which benchmarks them directly on the GPU and selects the fastest. On an RTX 5060 Ti this takes roughly 1.5 hours. (Still way too long, I need to work on that.)
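The pruning step amounts to a simple feasibility predicate over candidate configs. A hedged sketch of the idea (limits, register estimate, and names are illustrative, not fastaf's actual code):

```python
# Illustrative config pruning against device limits (not fastaf's code).

DEVICE = {                           # e.g. from torch.cuda.get_device_properties()
    "max_threads_per_block": 1024,
    "shared_mem_per_block": 101376,  # bytes, assumed
}

def is_feasible(cfg, dev=DEVICE):
    threads = cfg["num_warps"] * 32
    if threads > dev["max_threads_per_block"]:
        return False
    # pipelined tiles: stages * (A-tile + B-tile) * 2 bytes (fp16), assumed model
    smem = cfg["num_stages"] * 2 * (
        cfg["BLOCK_M"] * cfg["BLOCK_K"] + cfg["BLOCK_K"] * cfg["BLOCK_N"])
    if smem > dev["shared_mem_per_block"]:
        return False
    # crude register budget: accumulator elements per thread
    if cfg["BLOCK_M"] * cfg["BLOCK_N"] // threads > 256:
        return False
    return True

candidates = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 3},
    {"BLOCK_M": 256, "BLOCK_N": 256, "BLOCK_K": 64, "num_warps": 4, "num_stages": 4},
]
survivors = [c for c in candidates if is_feasible(c)]   # only the first survives
```

The survivors would then go to `triton.autotune` for on-GPU benchmarking.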
If a shape is not in the cache, the library uses the configuration of the nearest cached shape (nearest-neighbour lookup). If there are no cached entries at all, it falls back to a hardcoded default.
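The fallback logic can be sketched like this (the distance metric and names are my assumptions; fastaf may weigh dimensions differently):

```python
# Illustrative nearest-neighbour cache lookup (not fastaf's actual code).
import math

DEFAULT = {"BLOCK_M": 64, "BLOCK_N": 64, "num_warps": 4}  # hardcoded fallback

def lookup(cache, shape):
    """cache maps (batch, seq, rank) -> config; shape is the query tuple."""
    if shape in cache:
        return cache[shape]
    if not cache:
        return DEFAULT
    # nearest cached shape; log-scale distance is an assumed metric
    def dist(s):
        return sum((math.log2(a) - math.log2(b)) ** 2 for a, b in zip(s, shape))
    return cache[min(cache, key=dist)]

cache = {
    (4, 512, 16):   {"BLOCK_M": 64,  "BLOCK_N": 128, "num_warps": 4},
    (16, 1024, 16): {"BLOCK_M": 128, "BLOCK_N": 128, "num_warps": 8},
}
cfg = lookup(cache, (8, 1024, 16))   # not cached: nearest neighbour wins
```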
fastaf is fully compatible with torch.compile and introduces zero graph breaks in the fastaf layers themselves (verified by the test suite via torch._dynamo.explain).
This required registering each Triton kernel as a @torch.library.custom_op with a @register_fake implementation. Without this, torch.compile's tracer would encounter the kernel calls as opaque Python functions and insert a graph break at every LoRA layer. The @register_fake functions provide the compiler with output shape and dtype information so it can trace through the full model graph without interruption.
Note: graph breaks from HuggingFace Transformers internals (dynamic control flow, shape-dependent branches) are outside fastaf's control and may still appear when compiling a full model.
Requirements:
- Python ≥ 3.14
- CUDA toolkit 13.0
- NVIDIA GPU with compute capability ≥ 9.0 (Hopper or newer). Primary development target is Blackwell (RTX 50xx, sm_120).
Older CUDA versions: `pyproject.toml` pulls PyTorch from the `cu130` index. Adjust `[tool.uv.sources]` and `[[tool.uv.index]]` to point at the appropriate wheel index for your CUDA version (e.g. `cu121`).
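For example, switching to the cu121 index might look like this (a sketch using uv's index syntax; the index name is illustrative):

```toml
[[tool.uv.index]]
name = "pytorch-cu121"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

[tool.uv.sources]
torch = { index = "pytorch-cu121" }
```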
Install:
```
git clone https://github.com/your-username/fastaf
cd fastaf
make install
```

`make install` runs `uv sync --all-groups`, which creates a virtual environment and installs all dependencies including dev tools. If you don't have uv:

```
curl -LsSf https://astral.sh/uv/install.sh | sh
```

Format and lint:

```
make format   # runs ruff check --fix and ruff format
```

Under active development with known limitations:
- `bwd_x` and `bwd_b` kernels are slower than unfused PyTorch — needs a better tiling strategy for the dX memory traffic problem
- Tuning time (~1.5 hours) is too high — maybe interpolate between a few well-tuned shapes?
- Better CLI for cache management (target individual kernels, specify shapes)
- Tensor parallelism is not supported — kernels assume the full weight matrix is on one device. Pipeline parallelism via HuggingFace `accelerate` works.
Contributions very welcome, especially bug reports, correctness issues, and simplifications — if something is more complex than it needs to be, flag it.
The codebase is structured to make adding new kernel families straightforward. Each kernel has a consistent interface: a @triton.jit kernel, a launcher, a @torch.library.custom_op wrapper, a generate_kernel_configs_* function, and a tune_bs_* generator that plugs into fastaf-tune.
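In outline, a new kernel family would supply these five pieces (a hypothetical skeleton following the convention above; names are illustrative):

```
# hypothetical skeleton for a new kernel family "foo"

@triton.jit
def _foo_kernel(...):                # 1. the Triton kernel itself
    ...

def _launch_foo(x, ...):             # 2. launcher: grid computation + kernel call
    ...

@torch.library.custom_op("fastaf::foo", mutates_args=())
def foo(x: Tensor, ...) -> Tensor:   # 3. custom-op wrapper (+ @register_fake)
    return _launch_foo(x, ...)

def generate_kernel_configs_foo():   # 4. candidate block sizes / warps / stages
    ...

def tune_bs_foo():                   # 5. generator consumed by fastaf-tune
    ...
```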