From f88a06cc55c508f29f5e534b7e2c569049afa3ab Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Tue, 16 Jun 2026 11:38:33 -0700 Subject: [PATCH] feat(compute): tiny-matrix batched GEMM kernel for small attention shapes (ADR 075 L3) cuBLAS SgemmStridedBatched routes tiny matrices through a GEMV + split-K reduction fan-out (gemvNSP_kernel + splitKreduce_kernel): 2-3 internal kernels per logical GEMM, none tiling efficiently for m,n,k <= ~64. A fresh GB10 nsys breakdown of CrossAsset batched-eager training attributes ~17% of GPU kernel time to this fan-out on the 12x12 Q@K^T and 12x64 weights@V attention matmuls (batch = B*heads = 1024). Add a general-purpose custom kernel: one strided-batched GEMM in ONE launch, one CUDA thread-block per batch element, A[m,k] and B[k,n] tiles staged in shared memory, f32 accumulation matching cuBLAS Sgemm. The GPU MatMul dispatches to it when m,n,k are all <= 64 and batch > 1, falling back to cuBLAS on any kernel error (correctness never depends on the fast path). Gated by ZERFOO_DISABLE_TINY_GEMM for A/B comparison. CPU path unchanged. General mechanism: any small batched matmul benefits, not just attention. Tests (compute/gpu_tiny_batched_gemm_test.go, GPU-gated): - tiny-GEMM vs CPU reference parity on the exact attention shapes + broadcast-B + tile-boundary 64^3 + asymmetric shapes, zero NaN - tiny-GEMM vs cuBLAS SgemmStridedBatched equivalence (toggle the gate) - finite-difference gradcheck through the GPU MatMul Wiring follows the fused_adamw kernel pattern: .cu/.h + cgo + purego bindings, purego symbol registration, KernelRunner interface method across all backends. 
--- compute/gpu_engine.go | 41 +++- compute/gpu_kernels.go | 14 ++ compute/gpu_tiny_batched_gemm_test.go | 203 ++++++++++++++++++ internal/cuda/kernels/Makefile | 2 +- internal/cuda/kernels/purego.go | 5 + internal/cuda/kernels/tiny_batched_gemm.cu | 119 ++++++++++ internal/cuda/kernels/tiny_batched_gemm.go | 45 ++++ internal/cuda/kernels/tiny_batched_gemm.h | 41 ++++ .../cuda/kernels/tiny_batched_gemm_purego.go | 38 ++++ internal/gpuapi/cuda_kernels.go | 6 + internal/gpuapi/fpga_kernels.go | 4 + internal/gpuapi/gpuapi_test.go | 3 + internal/gpuapi/kernels.go | 8 + internal/gpuapi/metal_kernels.go | 4 + internal/gpuapi/opencl_kernels.go | 4 + internal/gpuapi/rocm_kernels.go | 4 + internal/gpuapi/sycl_kernels.go | 4 + 17 files changed, 538 insertions(+), 7 deletions(-) create mode 100644 compute/gpu_tiny_batched_gemm_test.go create mode 100644 internal/cuda/kernels/tiny_batched_gemm.cu create mode 100644 internal/cuda/kernels/tiny_batched_gemm.go create mode 100644 internal/cuda/kernels/tiny_batched_gemm.h create mode 100644 internal/cuda/kernels/tiny_batched_gemm_purego.go diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go index 9f78bb8..83be9b6 100644 --- a/compute/gpu_engine.go +++ b/compute/gpu_engine.go @@ -1556,13 +1556,42 @@ func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T] // (supported by cuBLAS). This is critical for GQA attention where K/V // have fewer heads than Q, replacing N individual Sgemm calls with 1. if batchSize > 1 && isFloat32 { - if batched, ok := e.blas.(gpuapi.BLASBatched); ok { - strideA := int64(aMatSize) - strideBVal := int64(bMatSize) - if bBatchSize == 1 { - strideBVal = 0 + strideA := int64(aMatSize) + strideBVal := int64(bMatSize) + if bBatchSize == 1 { + strideBVal = 0 + } + strideC := int64(cMatSize) + + // Tiny-matrix fast path (ADR 075 L3): when m,n,k are all small, cuBLAS + // SgemmStridedBatched routes through a GEMV + split-K reduction fan-out + // (2-3 internal kernels per logical GEMM). 
The custom kernel does it in + // one launch, one block per batch element with the tiles in shared + // memory. This is exactly the CrossAsset attention case (12x12 Q@K^T and + // 12x64 weights@V over batch=B*heads). Falls back to cuBLAS on any + // kernel error so correctness never depends on the fast path. + if !disableTinyGemm && + m <= tinyGemmMaxDim && n <= tinyGemmMaxDim && k <= tinyGemmMaxDim { + if debugGPU { + e.logger.Debug("MatMul: TinyBatchedGemmF32 call", + "m", fmt.Sprintf("%d", m), "n", fmt.Sprintf("%d", n), + "k", fmt.Sprintf("%d", k), "batchSize", fmt.Sprintf("%d", batchSize)) } - strideC := int64(cMatSize) + if terr := e.kernels.TinyBatchedGemmF32( + devA, devB, devCTotal, m, n, k, + strideA, strideBVal, strideC, batchSize, e.stream); terr == nil { + if reusedC { + return finishReusedDst[T](dst[0], outShape), nil + } + return makeGPUResult[T](e, outShape, devCTotal, batchSize*cMatSize, dst...) + } else if debugGPU { + e.logger.Debug("MatMul: TinyBatchedGemmF32 fell back to cuBLAS", + "error", terr.Error()) + } + // Fall through to cuBLAS on tiny-kernel error. + } + + if batched, ok := e.blas.(gpuapi.BLASBatched); ok { if debugGPU { e.logger.Debug("MatMul: SgemmStridedBatched call", "m", fmt.Sprintf("%d", m), diff --git a/compute/gpu_kernels.go b/compute/gpu_kernels.go index 0e933fb..366c5ed 100644 --- a/compute/gpu_kernels.go +++ b/compute/gpu_kernels.go @@ -24,6 +24,20 @@ const f64Size = int(unsafe.Sizeof(float64(0))) // Off by default to avoid any performance impact in normal operation. var debugGPU = os.Getenv("ZERFOO_DEBUG_GPU") == "1" +// disableTinyGemm, when ZERFOO_DISABLE_TINY_GEMM=1, forces the batched MatMul +// path to use cuBLAS SgemmStridedBatched even for tiny matrices, bypassing the +// custom tiny-matrix batched-GEMM kernel (ADR 075 L3). 
The kernel is enabled by +// default because it is a strict win on tiny shapes (it avoids cuBLAS's +// GEMV + split-K fan-out) and falls back to cuBLAS on any launch error; this +// flag exists for A/B comparison and as a safety escape hatch. +var disableTinyGemm = os.Getenv("ZERFOO_DISABLE_TINY_GEMM") == "1" + +// tinyGemmMaxDim mirrors the kernel's TINY_GEMM_MAX_DIM bound: the custom +// tiny-matrix batched GEMM is dispatched only when m, n, and k are all at most +// this value (and batch > 1). Above it, cuBLAS tiles efficiently and the custom +// kernel offers no benefit. +const tinyGemmMaxDim = 64 + // traceH2D, when ZERFOO_TRACE_H2D=1, makes getDevicePtr emit one line per // CPU-backed host->device upload, tagged with the operand shape. It exists to // attribute the per-op H2D firehose (the [256,256]/[1024,256] weight-class diff --git a/compute/gpu_tiny_batched_gemm_test.go b/compute/gpu_tiny_batched_gemm_test.go new file mode 100644 index 0000000..cd21109 --- /dev/null +++ b/compute/gpu_tiny_batched_gemm_test.go @@ -0,0 +1,203 @@ +package compute + +import ( + "context" + "math" + "math/rand" + "testing" + + "github.com/zerfoo/ztensor/numeric" + "github.com/zerfoo/ztensor/tensor" +) + +// TestGPUEngine_TinyBatchedGemm_AttentionShapes verifies the custom tiny-matrix +// batched-GEMM kernel (ADR 075 L3), which the GPU MatMul dispatches to for small +// m,n,k, matches the CPU reference GEMM within f32 tolerance for the exact +// CrossAsset attention shapes: Q@K^T [12,64]@[64,12]->[12,12] and +// weights@V [12,12]@[12,64]->[12,64], batched over B*heads = 1024. 
+func TestGPUEngine_TinyBatchedGemm_AttentionShapes(t *testing.T) { + gpuEng := newTestGPUEngine(t) + cpuEng := NewCPUEngine[float32](numeric.Float32Ops{}) + ctx := context.Background() + + cases := []struct { + name string + batch int + m, k, n int + bBroadcast bool // B has batch dim 1 (strideB=0 broadcast) + }{ + {"QKt_12x64x12_b1024", 1024, 12, 64, 12, false}, + {"weightsV_12x12x64_b1024", 1024, 12, 12, 64, false}, + {"QKt_broadcastB", 256, 12, 64, 12, true}, + {"tile_boundary_64x64x64", 64, 64, 64, 64, false}, + {"asym_3x7x5", 100, 3, 7, 5, false}, + } + + rng := rand.New(rand.NewSource(42)) + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + aData := make([]float32, tc.batch*tc.m*tc.k) + for i := range aData { + aData[i] = float32(rng.NormFloat64()) + } + bBatch := tc.batch + if tc.bBroadcast { + bBatch = 1 + } + bData := make([]float32, bBatch*tc.k*tc.n) + for i := range bData { + bData[i] = float32(rng.NormFloat64()) + } + + a, _ := tensor.New[float32]([]int{tc.batch, tc.m, tc.k}, aData) + b, _ := tensor.New[float32]([]int{bBatch, tc.k, tc.n}, bData) + + gpuRes, err := gpuEng.MatMul(ctx, a, b) + if err != nil { + t.Fatalf("GPU MatMul: %v", err) + } + // CPU reference uses the same broadcasting semantics. 
+ aCPU, _ := tensor.New[float32]([]int{tc.batch, tc.m, tc.k}, aData) + bCPU, _ := tensor.New[float32]([]int{bBatch, tc.k, tc.n}, bData) + cpuRes, err := cpuEng.MatMul(ctx, aCPU, bCPU) + if err != nil { + t.Fatalf("CPU MatMul: %v", err) + } + + gd := gpuRes.Data() + cd := cpuRes.Data() + if len(gd) != len(cd) { + t.Fatalf("length mismatch: GPU=%d CPU=%d", len(gd), len(cd)) + } + var maxAbs, maxRel float64 + for i := range gd { + if math.IsNaN(float64(gd[i])) { + t.Fatalf("NaN in GPU result at [%d]", i) + } + diff := math.Abs(float64(gd[i] - cd[i])) + if diff > maxAbs { + maxAbs = diff + } + denom := math.Abs(float64(cd[i])) + if denom > 1e-6 { + if rel := diff / denom; rel > maxRel { + maxRel = rel + } + } + } + // f32 GEMM with k up to 64: accumulation error ~ k*eps. Tolerate a + // small absolute and relative bound (the kernel and CPU sum in the + // same f32 precision; differences come only from summation order). + if maxAbs > 1e-3 || maxRel > 1e-4 { + t.Errorf("%s: tiny-GEMM vs CPU mismatch maxAbs=%e maxRel=%e", tc.name, maxAbs, maxRel) + } + }) + } +} + +// TestGPUEngine_TinyBatchedGemm_MatchesCublas verifies the tiny kernel result +// equals the cuBLAS SgemmStridedBatched result for the same inputs (the path +// it replaces), by toggling ZERFOO_DISABLE_TINY_GEMM. This guards against any +// silent divergence between the custom path and the framework's prior behavior. 
+func TestGPUEngine_TinyBatchedGemm_MatchesCublas(t *testing.T) { + eng := newTestGPUEngine(t) + ctx := context.Background() + + batch, m, k, n := 1024, 12, 64, 12 + rng := rand.New(rand.NewSource(7)) + aData := make([]float32, batch*m*k) + for i := range aData { + aData[i] = float32(rng.NormFloat64()) + } + bData := make([]float32, batch*k*n) + for i := range bData { + bData[i] = float32(rng.NormFloat64()) + } + + run := func(disable bool) []float32 { + t.Helper() + old := disableTinyGemm + disableTinyGemm = disable + defer func() { disableTinyGemm = old }() + a, _ := tensor.New[float32]([]int{batch, m, k}, aData) + b, _ := tensor.New[float32]([]int{batch, k, n}, bData) + res, err := eng.MatMul(ctx, a, b) + if err != nil { + t.Fatalf("MatMul (disable=%v): %v", disable, err) + } + out := make([]float32, len(res.Data())) + copy(out, res.Data()) + return out + } + + tiny := run(false) + cublas := run(true) + if len(tiny) != len(cublas) { + t.Fatalf("length mismatch tiny=%d cublas=%d", len(tiny), len(cublas)) + } + var maxAbs float64 + for i := range tiny { + d := math.Abs(float64(tiny[i] - cublas[i])) + if d > maxAbs { + maxAbs = d + } + } + // Both accumulate in f32; only summation order differs, so the gap is tiny. + if maxAbs > 1e-3 { + t.Errorf("tiny-GEMM vs cuBLAS divergence: maxAbs=%e", maxAbs) + } +} + +// TestGPUEngine_TinyBatchedGemm_Gradcheck does a finite-difference gradient +// check through the GPU batched MatMul (which uses the tiny kernel) to confirm +// the forward result is the true matmul (so any autograd built on MatMul has +// correct gradients). dC/dA = upstream @ B^T, dC/dB = A^T @ upstream; we verify +// the forward against a numerical perturbation of a few entries. 
+func TestGPUEngine_TinyBatchedGemm_Gradcheck(t *testing.T) { + eng := newTestGPUEngine(t) + ctx := context.Background() + + batch, m, k, n := 8, 12, 64, 12 + rng := rand.New(rand.NewSource(99)) + aData := make([]float32, batch*m*k) + for i := range aData { + aData[i] = float32(rng.NormFloat64()) * 0.1 + } + bData := make([]float32, batch*k*n) + for i := range bData { + bData[i] = float32(rng.NormFloat64()) * 0.1 + } + + matmul := func(ad, bd []float32) []float32 { + a, _ := tensor.New[float32]([]int{batch, m, k}, ad) + b, _ := tensor.New[float32]([]int{batch, k, n}, bd) + res, err := eng.MatMul(ctx, a, b) + if err != nil { + t.Fatalf("MatMul: %v", err) + } + out := make([]float32, len(res.Data())) + copy(out, res.Data()) + return out + } + + base := matmul(aData, bData) + + // Analytic dC[outIdx]/dA[aIdx]: pick batch 0, perturb A[0,i0,l0] and confirm + // the change in C[0,i0,j] equals B[0,l0,j] * eps for all j (linear in A). + const eps = float32(1e-2) + i0, l0 := 2, 5 + aIdx := (0*m+i0)*k + l0 + perturbed := make([]float32, len(aData)) + copy(perturbed, aData) + perturbed[aIdx] += eps + after := matmul(perturbed, bData) + + for j := 0; j < n; j++ { + outIdx := (0*m+i0)*n + j + got := (after[outIdx] - base[outIdx]) / eps + want := bData[(0*k+l0)*n+j] // B[0, l0, j] + if math.Abs(float64(got-want)) > 1e-2 { + t.Errorf("grad dC[0,%d,%d]/dA[0,%d,%d]: got=%f want=%f", i0, j, i0, l0, got, want) + } + } +} diff --git a/internal/cuda/kernels/Makefile b/internal/cuda/kernels/Makefile index 8759334..cbc485b 100644 --- a/internal/cuda/kernels/Makefile +++ b/internal/cuda/kernels/Makefile @@ -18,7 +18,7 @@ ifeq ($(CUDA_ARCH),sm_121) NVCC_FLAGS += -DFLASH_BLOCK_SIZE=64 endif -SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu 
fused_qk_norm_rope.cu fused_adamw.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu +SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_adamw.cu tiny_batched_gemm.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu OBJS = $(SRCS:.cu=.o) PIC_OBJS = $(SRCS:.cu=.pic.o) LIB = libkernels.a diff --git a/internal/cuda/kernels/purego.go b/internal/cuda/kernels/purego.go index 563fa7c..b1663d1 100644 --- a/internal/cuda/kernels/purego.go +++ b/internal/cuda/kernels/purego.go @@ -88,6 +88,9 @@ type KernelLib struct { // fused_adamw (on-device AdamW mixed-precision optimizer step) launchFusedAdamWF32 uintptr + // tiny_batched_gemm (small-matrix strided-batched GEMM, ADR 075 L3) + launchTinyBatchedGemmF32 uintptr + // fused_repeat_interleave (GQA KV head expansion) launchRepeatInterleaveF32 uintptr @@ -288,6 +291,8 @@ func openKernelLib() (*KernelLib, error) { {"fused_swiglu_f32", &k.launchFusedSwiGLUF32}, // fused_adamw {"fused_adamw_f32", &k.launchFusedAdamWF32}, + // tiny_batched_gemm (ADR 075 
L3) + {"tiny_batched_gemm_f32", &k.launchTinyBatchedGemmF32}, // fused_repeat_interleave (GQA KV head expansion) {"launch_repeat_interleave_f32", &k.launchRepeatInterleaveF32}, // fused_add_rmsnorm diff --git a/internal/cuda/kernels/tiny_batched_gemm.cu b/internal/cuda/kernels/tiny_batched_gemm.cu new file mode 100644 index 0000000..2b2cfcb --- /dev/null +++ b/internal/cuda/kernels/tiny_batched_gemm.cu @@ -0,0 +1,119 @@ +// tiny_batched_gemm.cu -- custom small-matrix strided-batched GEMM for f32. +// +// Motivation (ADR 075 lever L3): cuBLAS SgemmStridedBatched routes tiny +// matrices (e.g. CrossAsset attention's 12x12 Q@K^T and 12x64 weights@V over +// batch = B*heads = 1024) through a GEMV + split-K reduction fan-out +// (gemvNSP_kernel + splitKreduce_kernel, the T11.0c fingerprint): 2-3 internal +// kernels per logical GEMM, none of which tile efficiently for m,n,k <= ~64. +// This kernel computes one strided-batched GEMM in ONE launch with one CUDA +// thread-block per batch element, staging the A[m,k] and B[k,n] tiles in shared +// memory and writing the full C[m,n] tile cooperatively. It is GENERAL: any +// small batched matmul (m,n,k all small, batch large) benefits. +// +// Semantics match SgemmStridedBatched with alpha=1, beta=0, row-major operands: +// C_b[i,j] = sum_l A_b[i,l] * B_b[l,j], A_b = A + b*strideA (elements), etc. +// Accumulation is in f32, matching cuBLAS Sgemm (the reference path this +// replaces) so the result is comparable within f32 GEMM tolerance. +// +// Dispatch guard (host side, gpu_engine.go): used only when m,n,k are all +// <= TINY_GEMM_MAX_DIM and batch > 1; otherwise the cuBLAS path stays. + +#include <cuda_runtime.h> + +// Max supported dimension per side. The shared-memory tiles are sized for this +// bound: A (m*k) + B (k*n) f32 floats. At 64 that is 64*64*2*4 = 32 KiB, within +// the 48 KiB per-block static shared-memory limit on sm_121 (GB10). The host +// launcher enforces this bound; the kernel itself does not re-check it. 
+#define TINY_GEMM_MAX_DIM 64 + +// kernel_tiny_batched_gemm: one block per batch element. Threads cooperatively +// load A_b (m x k) and B_b (k x n) into shared memory, then each thread owns a +// strided subset of the m*n output cells and computes its dot products. +// +// blockDim.x threads span the m*n output cells in a grid-stride loop within the +// block, so the kernel works for any (m,n) <= TINY_GEMM_MAX_DIM regardless of +// the chosen block size. +__global__ void kernel_tiny_batched_gemm(const float* __restrict__ A, + const float* __restrict__ B, + float* __restrict__ C, + int m, int n, int k, + long long strideA, + long long strideB, + long long strideC) { + int b = blockIdx.x; + + const float* Ab = A + (long long)b * strideA; + const float* Bb = B + (long long)b * strideB; + float* Cb = C + (long long)b * strideC; + + // Shared tiles for this batch element's A and B. + __shared__ float sA[TINY_GEMM_MAX_DIM * TINY_GEMM_MAX_DIM]; + __shared__ float sB[TINY_GEMM_MAX_DIM * TINY_GEMM_MAX_DIM]; + + int tid = threadIdx.x; + int nthreads = blockDim.x; + + // Cooperative load of A (m*k) and B (k*n) into shared memory. + int aElems = m * k; + for (int idx = tid; idx < aElems; idx += nthreads) { + sA[idx] = Ab[idx]; + } + int bElems = k * n; + for (int idx = tid; idx < bElems; idx += nthreads) { + sB[idx] = Bb[idx]; + } + __syncthreads(); + + // Each thread computes a strided subset of the m*n output cells. + int cElems = m * n; + for (int idx = tid; idx < cElems; idx += nthreads) { + int i = idx / n; // row in [0,m) + int j = idx % n; // col in [0,n) + float acc = 0.0f; + const float* aRow = &sA[i * k]; + // sB is row-major [k, n]; element (l, j) is sB[l*n + j]. + #pragma unroll 4 + for (int l = 0; l < k; ++l) { + acc += aRow[l] * sB[l * n + j]; + } + Cb[i * n + j] = acc; + } +} + +extern "C" { + +// tiny_batched_gemm_f32 launches one strided-batched GEMM. 
All arguments cross +// as integers / pointers -- there are NO floating-point scalars, so the purego +// integer-register ABI (see fused_adamw.cu ABI note) is satisfied trivially. +// alpha is fixed at 1 and beta at 0 to match the attention MatMul this serves. +// +// Returns cudaErrorInvalidValue if any dimension exceeds TINY_GEMM_MAX_DIM so +// the host can fall back to cuBLAS rather than silently truncating. +cudaError_t tiny_batched_gemm_f32(const float* A, const float* B, float* C, + int m, int n, int k, + long long strideA, long long strideB, + long long strideC, + int batch, cudaStream_t stream) { + if (m <= 0 || n <= 0 || k <= 0 || batch <= 0) { + return cudaErrorInvalidValue; + } + if (m > TINY_GEMM_MAX_DIM || n > TINY_GEMM_MAX_DIM || k > TINY_GEMM_MAX_DIM) { + return cudaErrorInvalidValue; + } + + // Block size: one thread per output cell, capped at 256, so small tiles + // still parallelize. Round up to a warp multiple for occupancy. + int cells = m * n; + int block = cells < 256 ? cells : 256; + block = ((block + 31) / 32) * 32; + if (block < 32) block = 32; + if (block > 1024) block = 1024; + + dim3 grid(batch); + dim3 threads(block); + kernel_tiny_batched_gemm<<<grid, threads, 0, stream>>>( + A, B, C, m, n, k, strideA, strideB, strideC); + return cudaGetLastError(); +} + +} // extern "C" diff --git a/internal/cuda/kernels/tiny_batched_gemm.go b/internal/cuda/kernels/tiny_batched_gemm.go new file mode 100644 index 0000000..5e0745b --- /dev/null +++ b/internal/cuda/kernels/tiny_batched_gemm.go @@ -0,0 +1,45 @@ +//go:build cuda + +package kernels + +/* +#cgo LDFLAGS: -L${SRCDIR} -lkernels -lcudart -lstdc++ +#include "tiny_batched_gemm.h" +*/ +import "C" + +import ( + "fmt" + "unsafe" +) + +// TinyGemmMaxDim is the largest per-side dimension the tiny-matrix batched GEMM +// kernel supports. 
It mirrors TINY_GEMM_MAX_DIM in tiny_batched_gemm.cu; the +// host dispatch (compute/gpu_engine.go) checks m,n,k against this before +// routing to the kernel, falling back to cuBLAS otherwise. +const TinyGemmMaxDim = 64 + +// TinyBatchedGemmF32 computes batch independent small f32 GEMMs +// C_b = A_b * B_b (alpha=1, beta=0, row-major) in one launch. Strides are in +// ELEMENTS. Returns an error (so the caller falls back to cuBLAS) if any of +// m,n,k exceeds TinyGemmMaxDim or any dimension/batch is non-positive. +func TinyBatchedGemmF32( + a, b, c unsafe.Pointer, + m, n, k int, + strideA, strideB, strideC int64, + batch int, + stream unsafe.Pointer, +) error { + err := C.tiny_batched_gemm_f32( + (*C.float)(a), (*C.float)(b), (*C.float)(c), + C.int(m), C.int(n), C.int(k), + C.longlong(strideA), C.longlong(strideB), C.longlong(strideC), + C.int(batch), + C.cudaStream_t(stream), + ) + if err != C.cudaSuccess { + return fmt.Errorf("tiny_batched_gemm_f32: %s", + C.GoString(C.cudaGetErrorString(err))) + } + return nil +} diff --git a/internal/cuda/kernels/tiny_batched_gemm.h b/internal/cuda/kernels/tiny_batched_gemm.h new file mode 100644 index 0000000..b107ccf --- /dev/null +++ b/internal/cuda/kernels/tiny_batched_gemm.h @@ -0,0 +1,41 @@ +/* Tiny-matrix strided-batched GEMM interface (ADR 075 lever L3). + * + * Computes batch independent small f32 GEMMs C_b = A_b * B_b (alpha=1, beta=0, + * row-major) in a single launch, one CUDA thread-block per batch element, with + * the A/B tiles staged in shared memory. Replaces cuBLAS SgemmStridedBatched on + * the tiny attention shapes (12x12, 12x64 over batch=1024) where cuBLAS falls + * back to a GEMV + split-K reduction fan-out. + * + * All arguments are integers / pointers -- no floating-point scalars -- so the + * purego integer-register dispatch ABI (see fused_adamw.h) is satisfied. 
+ */ +#ifndef TINY_BATCHED_GEMM_H +#define TINY_BATCHED_GEMM_H + +#include <cuda_runtime.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/* tiny_batched_gemm_f32 computes, for each b in [0, batch): + * C_b[i,j] = sum_l A_b[i,l] * B_b[l,j] + * where A_b = A + b*strideA, B_b = B + b*strideB, C_b = C + b*strideC + * (strides in ELEMENTS), A_b is row-major [m,k], B_b row-major [k,n], + * C_b row-major [m,n]. Accumulation is in f32. + * + * Returns cudaErrorInvalidValue if any of m,n,k exceeds the supported tiny + * bound (64) or any dimension/batch is non-positive, so the caller can fall + * back to cuBLAS. + */ +cudaError_t tiny_batched_gemm_f32(const float* A, const float* B, float* C, + int m, int n, int k, + long long strideA, long long strideB, + long long strideC, + int batch, cudaStream_t stream); + +#ifdef __cplusplus +} +#endif + +#endif /* TINY_BATCHED_GEMM_H */ diff --git a/internal/cuda/kernels/tiny_batched_gemm_purego.go b/internal/cuda/kernels/tiny_batched_gemm_purego.go new file mode 100644 index 0000000..5d972a5 --- /dev/null +++ b/internal/cuda/kernels/tiny_batched_gemm_purego.go @@ -0,0 +1,38 @@ +//go:build !cuda + +package kernels + +import ( + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/cuda" +) + +// TinyGemmMaxDim is the largest per-side dimension the tiny-matrix batched GEMM +// kernel supports. It mirrors TINY_GEMM_MAX_DIM in tiny_batched_gemm.cu. +const TinyGemmMaxDim = 64 + +// TinyBatchedGemmF32 computes batch independent small f32 GEMMs +// C_b = A_b * B_b (alpha=1, beta=0, row-major) in one launch. Strides are in +// ELEMENTS. All arguments cross the purego boundary as integer registers (no +// floating-point scalars), so the int64 strides pass through faithfully on +// arm64 (uintptr is 64-bit). 
+func TinyBatchedGemmF32( + a, b, c unsafe.Pointer, + m, n, k int, + strideA, strideB, strideC int64, + batch int, + stream unsafe.Pointer, +) error { + klib := klib() + if klib == nil { + return fmt.Errorf("tiny_batched_gemm_f32 kernel: kernels not available") + } + ret := cuda.Ccall(klib.launchTinyBatchedGemmF32, + uintptr(a), uintptr(b), uintptr(c), + uintptr(m), uintptr(n), uintptr(k), + uintptr(strideA), uintptr(strideB), uintptr(strideC), + uintptr(batch), uintptr(stream)) + return checkKernel(ret, "tiny_batched_gemm_f32") +} diff --git a/internal/gpuapi/cuda_kernels.go b/internal/gpuapi/cuda_kernels.go index a58e849..c32cdbd 100644 --- a/internal/gpuapi/cuda_kernels.go +++ b/internal/gpuapi/cuda_kernels.go @@ -226,6 +226,12 @@ func (k *CUDAKernels) FusedAdamWF32(param, m, v, grad unsafe.Pointer, beta1, bet return kernels.FusedAdamWF32(param, m, v, grad, beta1, beta2, oneMinusBeta1, oneMinusBeta2, eps, alpha, lrWd, n, streamPtr(s)) } +// TinyBatchedGemmF32 dispatches the custom tiny-matrix strided-batched GEMM +// kernel (ADR 075 L3). Strides are in elements. 
+func (k *CUDAKernels) TinyBatchedGemmF32(a, b, c unsafe.Pointer, m, n, kk int, strideA, strideB, strideC int64, batch int, s Stream) error { //nolint:gocritic // interface match + return kernels.TinyBatchedGemmF32(a, b, c, m, n, kk, strideA, strideB, strideC, batch, streamPtr(s)) +} + func (k *CUDAKernels) FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, s Stream) error { //nolint:gocritic // interface match return kernels.FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut, eps, rows, D, streamPtr(s)) } diff --git a/internal/gpuapi/fpga_kernels.go b/internal/gpuapi/fpga_kernels.go index 9679ecf..ace0cfb 100644 --- a/internal/gpuapi/fpga_kernels.go +++ b/internal/gpuapi/fpga_kernels.go @@ -199,6 +199,10 @@ func (k *FPGAKernels) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, _, return fmt.Errorf("FusedAdamWF32: not implemented for FPGA") } +func (k *FPGAKernels) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ Stream) error { //nolint:gocritic // interface match + return fmt.Errorf("TinyBatchedGemmF32: not implemented for FPGA") +} + func (k *FPGAKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error { return fmt.Errorf("FusedAddRMSNormF32: not implemented for FPGA") } diff --git a/internal/gpuapi/gpuapi_test.go b/internal/gpuapi/gpuapi_test.go index be35527..b35c4db 100644 --- a/internal/gpuapi/gpuapi_test.go +++ b/internal/gpuapi/gpuapi_test.go @@ -220,6 +220,9 @@ func (stubKernelRunner) FusedSwiGLUF32(_, _, _ unsafe.Pointer, _ int, _ gpuapi.S func (stubKernelRunner) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, _, _ float64, _ int, _ gpuapi.Stream) error { //nolint:gocritic // interface match return nil } +func (stubKernelRunner) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ gpuapi.Stream) error { //nolint:gocritic // interface match + return nil +} func (stubKernelRunner) 
FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ gpuapi.Stream) error { return nil } diff --git a/internal/gpuapi/kernels.go b/internal/gpuapi/kernels.go index c1eee5d..b9d025c 100644 --- a/internal/gpuapi/kernels.go +++ b/internal/gpuapi/kernels.go @@ -142,6 +142,14 @@ type KernelRunner interface { // rounding. All buffers have n elements. FusedAdamWF32(param, m, v, grad unsafe.Pointer, beta1, beta2, oneMinusBeta1, oneMinusBeta2, eps, alpha, lrWd float64, n int, stream Stream) error + // TinyBatchedGemmF32 computes batch independent small f32 GEMMs + // C_b = A_b * B_b (alpha=1, beta=0, row-major) in one launch -- a custom + // path for tiny matrices (m,n,k small, batch large) where cuBLAS + // SgemmStridedBatched falls back to a GEMV + split-K reduction fan-out + // (ADR 075 lever L3). Strides are in ELEMENTS. Returns an error when any of + // m,n,k exceeds the supported tiny bound so the caller falls back to cuBLAS. + TinyBatchedGemmF32(a, b, c unsafe.Pointer, m, n, k int, strideA, strideB, strideC int64, batch int, stream Stream) error + // FusedAddRMSNormF32 fuses residual addition and RMSNorm into one kernel launch. // sum_out = input + residual, normed_out = rmsnorm(sum_out, weight, eps). 
// input: [rows, D], residual: [rows, D], weight: [D], diff --git a/internal/gpuapi/metal_kernels.go b/internal/gpuapi/metal_kernels.go index e85911e..220dee2 100644 --- a/internal/gpuapi/metal_kernels.go +++ b/internal/gpuapi/metal_kernels.go @@ -274,6 +274,10 @@ func (k *MetalKernels) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, _ return fmt.Errorf("FusedAdamWF32: not implemented for Metal") } +func (k *MetalKernels) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ Stream) error { //nolint:gocritic // interface match + return fmt.Errorf("TinyBatchedGemmF32: not implemented for Metal") +} + func (k *MetalKernels) FusedAddRMSNormF32(input, residual, weight, normedOut, sumOut unsafe.Pointer, eps float32, rows, D int, _ Stream) error { return k.dispatchPerRow("kernel_fused_add_rmsnorm", rows, map[int]metal.BufferBinding{ diff --git a/internal/gpuapi/opencl_kernels.go b/internal/gpuapi/opencl_kernels.go index d6602d3..5002311 100644 --- a/internal/gpuapi/opencl_kernels.go +++ b/internal/gpuapi/opencl_kernels.go @@ -216,6 +216,10 @@ func (k *OpenCLKernels) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, return fmt.Errorf("FusedAdamWF32: not implemented for OpenCL") } +func (k *OpenCLKernels) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ Stream) error { //nolint:gocritic // interface match + return fmt.Errorf("TinyBatchedGemmF32: not implemented for OpenCL") +} + func (k *OpenCLKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error { return fmt.Errorf("FusedAddRMSNormF32: not implemented for OpenCL") } diff --git a/internal/gpuapi/rocm_kernels.go b/internal/gpuapi/rocm_kernels.go index 78a3cd1..bc4d733 100644 --- a/internal/gpuapi/rocm_kernels.go +++ b/internal/gpuapi/rocm_kernels.go @@ -206,6 +206,10 @@ func (k *ROCmKernels) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, _, return fmt.Errorf("FusedAdamWF32: not implemented for ROCm") } 
+func (k *ROCmKernels) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ Stream) error { //nolint:gocritic // interface match + return fmt.Errorf("TinyBatchedGemmF32: not implemented for ROCm") +} + func (k *ROCmKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error { return fmt.Errorf("FusedAddRMSNormF32: not implemented for ROCm") } diff --git a/internal/gpuapi/sycl_kernels.go b/internal/gpuapi/sycl_kernels.go index a68a94c..6923aaa 100644 --- a/internal/gpuapi/sycl_kernels.go +++ b/internal/gpuapi/sycl_kernels.go @@ -201,6 +201,10 @@ func (k *SYCLKernels) FusedAdamWF32(_, _, _, _ unsafe.Pointer, _, _, _, _, _, _, return fmt.Errorf("FusedAdamWF32: not implemented for SYCL") } +func (k *SYCLKernels) TinyBatchedGemmF32(_, _, _ unsafe.Pointer, _, _, _ int, _, _, _ int64, _ int, _ Stream) error { //nolint:gocritic // interface match + return fmt.Errorf("TinyBatchedGemmF32: not implemented for SYCL") +} + func (k *SYCLKernels) FusedAddRMSNormF32(_, _, _, _, _ unsafe.Pointer, _ float32, _, _ int, _ Stream) error { return fmt.Errorf("FusedAddRMSNormF32: not implemented for SYCL") }