Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions compute/gpu_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -1556,13 +1556,42 @@ func (e *GPUEngine[T]) MatMul(ctx context.Context, a, b *tensor.TensorNumeric[T]
// (supported by cuBLAS). This is critical for GQA attention where K/V
// have fewer heads than Q, replacing N individual Sgemm calls with 1.
if batchSize > 1 && isFloat32 {
if batched, ok := e.blas.(gpuapi.BLASBatched); ok {
strideA := int64(aMatSize)
strideBVal := int64(bMatSize)
if bBatchSize == 1 {
strideBVal = 0
strideA := int64(aMatSize)
strideBVal := int64(bMatSize)
if bBatchSize == 1 {
strideBVal = 0
}
strideC := int64(cMatSize)

// Tiny-matrix fast path (ADR 075 L3): when m,n,k are all small, cuBLAS
// SgemmStridedBatched routes through a GEMV + split-K reduction fan-out
// (2-3 internal kernels per logical GEMM). The custom kernel does it in
// one launch, one block per batch element with the tiles in shared
// memory. This is exactly the CrossAsset attention case (12x12 Q@K^T and
// 12x64 weights@V over batch=B*heads). Falls back to cuBLAS on any
// kernel error so correctness never depends on the fast path.
if !disableTinyGemm &&
m <= tinyGemmMaxDim && n <= tinyGemmMaxDim && k <= tinyGemmMaxDim {
if debugGPU {
e.logger.Debug("MatMul: TinyBatchedGemmF32 call",
"m", fmt.Sprintf("%d", m), "n", fmt.Sprintf("%d", n),
"k", fmt.Sprintf("%d", k), "batchSize", fmt.Sprintf("%d", batchSize))
}
strideC := int64(cMatSize)
if terr := e.kernels.TinyBatchedGemmF32(
devA, devB, devCTotal, m, n, k,
strideA, strideBVal, strideC, batchSize, e.stream); terr == nil {
if reusedC {
return finishReusedDst[T](dst[0], outShape), nil
}
return makeGPUResult[T](e, outShape, devCTotal, batchSize*cMatSize, dst...)
} else if debugGPU {
e.logger.Debug("MatMul: TinyBatchedGemmF32 fell back to cuBLAS",
"error", terr.Error())
}
// Fall through to cuBLAS on tiny-kernel error.
}

if batched, ok := e.blas.(gpuapi.BLASBatched); ok {
if debugGPU {
e.logger.Debug("MatMul: SgemmStridedBatched call",
"m", fmt.Sprintf("%d", m),
Expand Down
14 changes: 14 additions & 0 deletions compute/gpu_kernels.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ const f64Size = int(unsafe.Sizeof(float64(0)))
// Off by default to avoid any performance impact in normal operation.
var debugGPU = os.Getenv("ZERFOO_DEBUG_GPU") == "1"

// disableTinyGemm, when ZERFOO_DISABLE_TINY_GEMM=1, forces the batched MatMul
// path to use cuBLAS SgemmStridedBatched even for tiny matrices, bypassing the
// custom tiny-matrix batched-GEMM kernel (ADR 075 L3). The kernel is enabled by
// default because it is a strict win on tiny shapes (it avoids cuBLAS's
// GEMV + split-K fan-out) and falls back to cuBLAS on any launch error; this
// flag exists for A/B comparison and as a safety escape hatch.
var disableTinyGemm = os.Getenv("ZERFOO_DISABLE_TINY_GEMM") == "1"

// tinyGemmMaxDim mirrors the kernel's TINY_GEMM_MAX_DIM bound: the custom
// tiny-matrix batched GEMM is dispatched only when m, n, and k are all at most
// this value (and batch > 1). Above it, cuBLAS tiles efficiently and the custom
// kernel offers no benefit.
const tinyGemmMaxDim = 64

// traceH2D, when ZERFOO_TRACE_H2D=1, makes getDevicePtr emit one line per
// CPU-backed host->device upload, tagged with the operand shape. It exists to
// attribute the per-op H2D firehose (the [256,256]/[1024,256] weight-class
Expand Down
203 changes: 203 additions & 0 deletions compute/gpu_tiny_batched_gemm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
package compute

import (
"context"
"math"
"math/rand"
"testing"

"github.com/zerfoo/ztensor/numeric"
"github.com/zerfoo/ztensor/tensor"
)

// TestGPUEngine_TinyBatchedGemm_AttentionShapes verifies the custom tiny-matrix
// batched-GEMM kernel (ADR 075 L3), which the GPU MatMul dispatches to for small
// m,n,k, matches the CPU reference GEMM within f32 tolerance for the exact
// CrossAsset attention shapes: Q@K^T [12,64]@[64,12]->[12,12] and
// weights@V [12,12]@[12,64]->[12,64], batched over B*heads = 1024.
func TestGPUEngine_TinyBatchedGemm_AttentionShapes(t *testing.T) {
gpuEng := newTestGPUEngine(t)
cpuEng := NewCPUEngine[float32](numeric.Float32Ops{})
ctx := context.Background()

cases := []struct {
name string
batch int
m, k, n int
bBroadcast bool // B has batch dim 1 (strideB=0 broadcast)
}{
{"QKt_12x64x12_b1024", 1024, 12, 64, 12, false},
{"weightsV_12x12x64_b1024", 1024, 12, 12, 64, false},
{"QKt_broadcastB", 256, 12, 64, 12, true},
{"tile_boundary_64x64x64", 64, 64, 64, 64, false},
{"asym_3x7x5", 100, 3, 7, 5, false},
}

rng := rand.New(rand.NewSource(42))
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
aData := make([]float32, tc.batch*tc.m*tc.k)
for i := range aData {
aData[i] = float32(rng.NormFloat64())
}
bBatch := tc.batch
if tc.bBroadcast {
bBatch = 1
}
bData := make([]float32, bBatch*tc.k*tc.n)
for i := range bData {
bData[i] = float32(rng.NormFloat64())
}

a, _ := tensor.New[float32]([]int{tc.batch, tc.m, tc.k}, aData)
b, _ := tensor.New[float32]([]int{bBatch, tc.k, tc.n}, bData)

gpuRes, err := gpuEng.MatMul(ctx, a, b)
if err != nil {
t.Fatalf("GPU MatMul: %v", err)
}
// CPU reference uses the same broadcasting semantics.
aCPU, _ := tensor.New[float32]([]int{tc.batch, tc.m, tc.k}, aData)
bCPU, _ := tensor.New[float32]([]int{bBatch, tc.k, tc.n}, bData)
cpuRes, err := cpuEng.MatMul(ctx, aCPU, bCPU)
if err != nil {
t.Fatalf("CPU MatMul: %v", err)
}

gd := gpuRes.Data()
cd := cpuRes.Data()
if len(gd) != len(cd) {
t.Fatalf("length mismatch: GPU=%d CPU=%d", len(gd), len(cd))
}
var maxAbs, maxRel float64
for i := range gd {
if math.IsNaN(float64(gd[i])) {
t.Fatalf("NaN in GPU result at [%d]", i)
}
diff := math.Abs(float64(gd[i] - cd[i]))
if diff > maxAbs {
maxAbs = diff
}
denom := math.Abs(float64(cd[i]))
if denom > 1e-6 {
if rel := diff / denom; rel > maxRel {
maxRel = rel
}
}
}
// f32 GEMM with k up to 64: accumulation error ~ k*eps. Tolerate a
// small absolute and relative bound (the kernel and CPU sum in the
// same f32 precision; differences come only from summation order).
if maxAbs > 1e-3 || maxRel > 1e-4 {
t.Errorf("%s: tiny-GEMM vs CPU mismatch maxAbs=%e maxRel=%e", tc.name, maxAbs, maxRel)
}
})
}
}

// TestGPUEngine_TinyBatchedGemm_MatchesCublas verifies the tiny kernel result
// equals the cuBLAS SgemmStridedBatched result for the same inputs (the path
// it replaces), by toggling ZERFOO_DISABLE_TINY_GEMM. This guards against any
// silent divergence between the custom path and the framework's prior behavior.
func TestGPUEngine_TinyBatchedGemm_MatchesCublas(t *testing.T) {
eng := newTestGPUEngine(t)
ctx := context.Background()

batch, m, k, n := 1024, 12, 64, 12
rng := rand.New(rand.NewSource(7))
aData := make([]float32, batch*m*k)
for i := range aData {
aData[i] = float32(rng.NormFloat64())
}
bData := make([]float32, batch*k*n)
for i := range bData {
bData[i] = float32(rng.NormFloat64())
}

run := func(disable bool) []float32 {
t.Helper()
old := disableTinyGemm
disableTinyGemm = disable
defer func() { disableTinyGemm = old }()
a, _ := tensor.New[float32]([]int{batch, m, k}, aData)
b, _ := tensor.New[float32]([]int{batch, k, n}, bData)
res, err := eng.MatMul(ctx, a, b)
if err != nil {
t.Fatalf("MatMul (disable=%v): %v", disable, err)
}
out := make([]float32, len(res.Data()))
copy(out, res.Data())
return out
}

tiny := run(false)
cublas := run(true)
if len(tiny) != len(cublas) {
t.Fatalf("length mismatch tiny=%d cublas=%d", len(tiny), len(cublas))
}
var maxAbs float64
for i := range tiny {
d := math.Abs(float64(tiny[i] - cublas[i]))
if d > maxAbs {
maxAbs = d
}
}
// Both accumulate in f32; only summation order differs, so the gap is tiny.
if maxAbs > 1e-3 {
t.Errorf("tiny-GEMM vs cuBLAS divergence: maxAbs=%e", maxAbs)
}
}

// TestGPUEngine_TinyBatchedGemm_Gradcheck does a finite-difference gradient
// check through the GPU batched MatMul (which uses the tiny kernel) to confirm
// the forward result is the true matmul (so any autograd built on MatMul has
// correct gradients). dC/dA = upstream @ B^T, dC/dB = A^T @ upstream; we verify
// the forward against a numerical perturbation of a few entries.
func TestGPUEngine_TinyBatchedGemm_Gradcheck(t *testing.T) {
eng := newTestGPUEngine(t)
ctx := context.Background()

batch, m, k, n := 8, 12, 64, 12
rng := rand.New(rand.NewSource(99))
aData := make([]float32, batch*m*k)
for i := range aData {
aData[i] = float32(rng.NormFloat64()) * 0.1
}
bData := make([]float32, batch*k*n)
for i := range bData {
bData[i] = float32(rng.NormFloat64()) * 0.1
}

matmul := func(ad, bd []float32) []float32 {
a, _ := tensor.New[float32]([]int{batch, m, k}, ad)
b, _ := tensor.New[float32]([]int{batch, k, n}, bd)
res, err := eng.MatMul(ctx, a, b)
if err != nil {
t.Fatalf("MatMul: %v", err)
}
out := make([]float32, len(res.Data()))
copy(out, res.Data())
return out
}

base := matmul(aData, bData)

// Analytic dC[outIdx]/dA[aIdx]: pick batch 0, perturb A[0,i0,l0] and confirm
// the change in C[0,i0,j] equals B[0,l0,j] * eps for all j (linear in A).
const eps = float32(1e-2)
i0, l0 := 2, 5
aIdx := (0*m+i0)*k + l0
perturbed := make([]float32, len(aData))
copy(perturbed, aData)
perturbed[aIdx] += eps
after := matmul(perturbed, bData)

for j := 0; j < n; j++ {
outIdx := (0*m+i0)*n + j
got := (after[outIdx] - base[outIdx]) / eps
want := bData[(0*k+l0)*n+j] // B[0, l0, j]
if math.Abs(float64(got-want)) > 1e-2 {
t.Errorf("grad dC[0,%d,%d]/dA[0,%d,%d]: got=%f want=%f", i0, j, i0, l0, got, want)
}
}
}
2 changes: 1 addition & 1 deletion internal/cuda/kernels/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ifeq ($(CUDA_ARCH),sm_121)
NVCC_FLAGS += -DFLASH_BLOCK_SIZE=64
endif

SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_adamw.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_adamw.cu tiny_batched_gemm.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
OBJS = $(SRCS:.cu=.o)
PIC_OBJS = $(SRCS:.cu=.pic.o)
LIB = libkernels.a
Expand Down
5 changes: 5 additions & 0 deletions internal/cuda/kernels/purego.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ type KernelLib struct {
// fused_adamw (on-device AdamW mixed-precision optimizer step)
launchFusedAdamWF32 uintptr

// tiny_batched_gemm (small-matrix strided-batched GEMM, ADR 075 L3)
launchTinyBatchedGemmF32 uintptr

// fused_repeat_interleave (GQA KV head expansion)
launchRepeatInterleaveF32 uintptr

Expand Down Expand Up @@ -288,6 +291,8 @@ func openKernelLib() (*KernelLib, error) {
{"fused_swiglu_f32", &k.launchFusedSwiGLUF32},
// fused_adamw
{"fused_adamw_f32", &k.launchFusedAdamWF32},
// tiny_batched_gemm (ADR 075 L3)
{"tiny_batched_gemm_f32", &k.launchTinyBatchedGemmF32},
// fused_repeat_interleave (GQA KV head expansion)
{"launch_repeat_interleave_f32", &k.launchRepeatInterleaveF32},
// fused_add_rmsnorm
Expand Down
Loading
Loading