feat(compute): tiny-matrix batched GEMM kernel (ADR 075 L3)#150
Merged
Conversation
…apes (ADR 075 L3) cuBLAS SgemmStridedBatched routes tiny matrices through a GEMV + split-K reduction fan-out (gemvNSP_kernel + splitKreduce_kernel): 2-3 internal kernels per logical GEMM, none tiling efficiently for m,n,k <= ~64. A fresh GB10 nsys breakdown of CrossAsset batched-eager training attributes ~17% of GPU kernel time to this fan-out on the 12x12 Q@K^T and 12x64 weights@V attention matmuls (batch = B*heads = 1024). Add a general-purpose custom kernel: one strided-batched GEMM in ONE launch, one CUDA thread-block per batch element, A[m,k] and B[k,n] tiles staged in shared memory, f32 accumulation matching cuBLAS Sgemm. The GPU MatMul dispatches to it when m,n,k are all <= 64 and batch > 1, falling back to cuBLAS on any kernel error (correctness never depends on the fast path). Gated by ZERFOO_DISABLE_TINY_GEMM for A/B comparison. CPU path unchanged. General mechanism: any small batched matmul benefits, not just attention. Tests (compute/gpu_tiny_batched_gemm_test.go, GPU-gated): - tiny-GEMM vs CPU reference parity on the exact attention shapes + broadcast-B + tile-boundary 64^3 + asymmetric shapes, zero NaN - tiny-GEMM vs cuBLAS SgemmStridedBatched equivalence (toggle the gate) - finite-difference gradcheck through the GPU MatMul Wiring follows the fused_adamw kernel pattern: .cu/.h + cgo + purego bindings, purego symbol registration, KernelRunner interface method across all backends.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
A custom small-matrix strided-batched GEMM CUDA kernel for f32, dispatched from the GPU engine's
MatMulwhenm,n,kare all<= 64andbatch > 1. Replaces cuBLASSgemmStridedBatchedon tiny shapes where cuBLAS falls back to a GEMV + split-K reduction fan-out (gemvNSP_kernel+splitKreduce_kernel).Why (measured)
A fresh GB10 nsys kernel-time breakdown of Wolf CrossAsset batched-eager training (zerfoo v1.51.0 / ztensor v1.12.0, the 1.62 s/epoch binary) attributes ~17% of GPU kernel time to this fan-out on the 12x12 Q@K^T and 12x64 weights@V attention matmuls (batch = B*heads = 1024). The fingerprint (
gemvNSP437 ms +gemvx640 ms +splitKreduce89 ms over the profiled run) is exactly the cuBLAS tiny-matrix routing. The custom kernel does each logical GEMM in one launch (one thread-block per batch element, A/B tiles in shared memory, f32 accumulation matching cuBLAS Sgemm).General-purpose
Any small batched matmul benefits, not just attention. The kernel is shape-gated, not Wolf-specific.
Safety
ZERFOO_DISABLE_TINY_GEMM=1for A/B comparison.Tests (
compute/gpu_tiny_batched_gemm_test.go, GPU-gated, skip without CUDA)SgemmStridedBatchedequivalence (toggle the gate).Full non-GPU suite green locally (
go test ./..., 30 pkgs ok). GPU tests run on the DGX (sm_121); thelibkernels.sorebuild with the new.cuis part of the Wolf L3 verification.Wiring
Follows the merged
fused_adamwkernel pattern:.cu/.h+ cgo + purego bindings, purego symbol registration,KernelRunnerinterface method across all backends (CUDA real; ROCm/Metal/SYCL/FPGA/OpenCL stubs).Plan: Wolf
docs/plan-speed-to-parity.mdT14L3.1 (ADR 075).