
feat(compute): tiny-matrix batched GEMM kernel (ADR 075 L3)#150

Merged
dndungu merged 1 commit into main from fix/tiny-matrix-attention-gemm-l3
Jun 16, 2026

@dndungu dndungu commented Jun 16, 2026

Contributor

What

A custom small-matrix strided-batched GEMM CUDA kernel for f32, dispatched from the GPU engine's MatMul when m,n,k are all <= 64 and batch > 1. Replaces cuBLAS SgemmStridedBatched on tiny shapes where cuBLAS falls back to a GEMV + split-K reduction fan-out (gemvNSP_kernel + splitKreduce_kernel).
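The shape gate described above can be sketched in a few lines of Go. This is only an illustration of the stated rule (m, n, k all <= 64 and batch > 1); `useTinyGEMM` and `maxTinyDim` are hypothetical names, not the engine's actual identifiers, and the inner dimensions used in `main` are assumptions.

```go
package main

import "fmt"

// maxTinyDim is the per-dimension bound quoted in the description (m, n, k <= 64).
const maxTinyDim = 64

// useTinyGEMM is a hypothetical helper that mirrors the dispatch rule:
// take the custom kernel only for batched problems whose dimensions all fit the bound.
func useTinyGEMM(m, n, k, batch int) bool {
	if m <= 0 || n <= 0 || k <= 0 {
		return false // degenerate shapes are left to the existing path
	}
	return batch > 1 && m <= maxTinyDim && n <= maxTinyDim && k <= maxTinyDim
}

func main() {
	// weights@V style shape: 12x64 output, inner dimension assumed to be 12.
	fmt.Println(useTinyGEMM(12, 64, 12, 1024))
	// A single (non-batched) matmul stays on cuBLAS.
	fmt.Println(useTinyGEMM(12, 64, 12, 1))
	// One dimension past the bound stays on cuBLAS.
	fmt.Println(useTinyGEMM(65, 12, 12, 1024))
}
```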

Why (measured)

A fresh GB10 nsys kernel-time breakdown of Wolf CrossAsset batched-eager training (zerfoo v1.51.0 / ztensor v1.12.0, the 1.62 s/epoch binary) attributes ~17% of GPU kernel time to this fan-out on the 12x12 Q@K^T and 12x64 weights@V attention matmuls (batch = B*heads = 1024). The kernel fingerprint over the profiled run (gemvNSP 437 ms + gemvx 640 ms + splitKreduce 89 ms) is the signature of cuBLAS tiny-matrix routing. The custom kernel runs each logical GEMM in a single launch: one thread-block per batch element, A/B tiles staged in shared memory, and f32 accumulation matching cuBLAS Sgemm.
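For readers who want the semantics without the CUDA source, here is a minimal CPU sketch in Go of a strided-batched f32 GEMM. It is not the kernel from this PR: the row-major layout, the `batchedGEMMRef` name, and the stride convention (element offsets between batch entries, with a stride of 0 meaning a broadcast operand) are assumptions made for illustration. Each iteration of the outer loop corresponds to the work one thread-block would do for one batch element.

```go
package main

import "fmt"

// batchedGEMMRef computes C[b] = A[b] * B[b] for row-major f32 matrices,
// with A[b] of shape m x k, B[b] of shape k x n, and C[b] of shape m x n.
// strideA, strideB and strideC are element offsets between consecutive batch
// entries; strideB = 0 reuses one B for every batch element (broadcast).
func batchedGEMMRef(a, b, c []float32, m, n, k, batch, strideA, strideB, strideC int) {
	for bi := 0; bi < batch; bi++ { // one iteration ~ one thread-block in the kernel
		ab := a[bi*strideA:]
		bb := b[bi*strideB:]
		cb := c[bi*strideC:]
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				var acc float32 // f32 accumulator, as in the description
				for p := 0; p < k; p++ {
					acc += ab[i*k+p] * bb[p*n+j]
				}
				cb[i*n+j] = acc
			}
		}
	}
}

func main() {
	// Two 2x2 batches of A multiplied by one broadcast 2x2 identity B.
	a := []float32{1, 2, 3, 4, 5, 6, 7, 8}
	b := []float32{1, 0, 0, 1}
	c := make([]float32, 8)
	batchedGEMMRef(a, b, c, 2, 2, 2, 2, 4, 0, 4)
	fmt.Println(c) // prints [1 2 3 4 5 6 7 8]
}
```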

General-purpose

Any small batched matmul benefits, not just attention. The kernel is shape-gated, not Wolf-specific.

Safety

  • Falls back to cuBLAS on any kernel launch error — correctness never depends on the fast path.
  • Gated by ZERFOO_DISABLE_TINY_GEMM=1 for A/B comparison.
  • CPU path byte-unchanged.
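The fallback and gate behaviour listed above could look roughly like the following Go sketch. All names here (`runWithFallback`, `tinyGEMMDisabled`, `failingKernel`, `okKernel`) are hypothetical stand-ins, and reading the variable as an exact match on "1" is an assumption; only the environment variable name comes from the PR.

```go
package main

import (
	"errors"
	"fmt"
	"os"
)

// gemmFunc stands in for a launch of either the tiny kernel or cuBLAS.
type gemmFunc func() error

// tinyGEMMDisabled reports whether the A/B gate is set.
// Assumption: the gate is active only when the value is exactly "1".
func tinyGEMMDisabled() bool {
	return os.Getenv("ZERFOO_DISABLE_TINY_GEMM") == "1"
}

// runWithFallback tries the fast path unless it is disabled, and uses the
// fallback whenever the fast path is skipped or returns an error.
// It returns the name of the path that produced the result.
func runWithFallback(disabled bool, fast, fallback gemmFunc) (string, error) {
	if !disabled {
		if err := fast(); err == nil {
			return "tiny", nil
		}
	}
	if err := fallback(); err != nil {
		return "", err
	}
	return "cublas", nil
}

// failingKernel simulates a kernel launch error.
func failingKernel() error { return errors.New("simulated launch failure") }

// okKernel simulates a successful launch.
func okKernel() error { return nil }

func main() {
	// A failing fast path still yields a result through the fallback.
	path, err := runWithFallback(tinyGEMMDisabled(), failingKernel, okKernel)
	fmt.Println(path, err)
}
```

With this shape, a failure in the fast path costs one extra launch attempt but never changes the result, which is the property the first bullet describes.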

Tests (compute/gpu_tiny_batched_gemm_test.go, GPU-gated, skip without CUDA)

  • tiny-GEMM vs CPU reference parity on the exact attention shapes + broadcast-B + 64^3 tile boundary + asymmetric shapes; zero NaN.
  • tiny-GEMM vs cuBLAS SgemmStridedBatched equivalence (toggle the gate).
  • finite-difference gradcheck through the GPU MatMul.
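As an illustration of the finite-difference gradcheck idea, here is a small CPU-only Go sketch. It does not reproduce the test in compute/gpu_tiny_batched_gemm_test.go: it uses float64 and a plain matmul rather than the GPU MatMul, takes the loss to be the sum of all output entries, and every function name is hypothetical. For that loss the analytic gradient is dL/dA[i,p] = sum_j B[p,j], which the central-difference estimate should match.

```go
package main

import (
	"fmt"
	"math"
)

// matmul computes C = A*B for row-major A (m x k) and B (k x n).
func matmul(a, b []float64, m, n, k int) []float64 {
	c := make([]float64, m*n)
	for i := 0; i < m; i++ {
		for j := 0; j < n; j++ {
			for p := 0; p < k; p++ {
				c[i*n+j] += a[i*k+p] * b[p*n+j]
			}
		}
	}
	return c
}

// loss is the sum of all entries of A*B.
func loss(a, b []float64, m, n, k int) float64 {
	s := 0.0
	for _, v := range matmul(a, b, m, n, k) {
		s += v
	}
	return s
}

// analyticGradA returns dL/dA, where dL/dA[i,p] = sum_j B[p,j].
func analyticGradA(b []float64, m, n, k int) []float64 {
	g := make([]float64, m*k)
	for i := 0; i < m; i++ {
		for p := 0; p < k; p++ {
			for j := 0; j < n; j++ {
				g[i*k+p] += b[p*n+j]
			}
		}
	}
	return g
}

// numericGradA estimates dL/dA with central differences, restoring A afterwards.
func numericGradA(a, b []float64, m, n, k int, eps float64) []float64 {
	g := make([]float64, len(a))
	for idx := range a {
		orig := a[idx]
		a[idx] = orig + eps
		lp := loss(a, b, m, n, k)
		a[idx] = orig - eps
		lm := loss(a, b, m, n, k)
		a[idx] = orig
		g[idx] = (lp - lm) / (2 * eps)
	}
	return g
}

// maxAbsDiff returns the largest elementwise absolute difference.
func maxAbsDiff(x, y []float64) float64 {
	d := 0.0
	for i := range x {
		d = math.Max(d, math.Abs(x[i]-y[i]))
	}
	return d
}

func main() {
	m, n, k := 2, 3, 2
	a := []float64{0.5, -1, 2, 0.25}
	b := []float64{1, 2, 3, 4, 5, 6}
	diff := maxAbsDiff(analyticGradA(b, m, n, k), numericGradA(a, b, m, n, k, 1e-4))
	fmt.Println("max |analytic - numeric| =", diff)
}
```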

Full non-GPU suite green locally (go test ./..., 30 pkgs ok). GPU tests run on the DGX (sm_121); the libkernels.so rebuild with the new .cu is part of the Wolf L3 verification.

Wiring

Follows the merged fused_adamw kernel pattern: .cu/.h + cgo + purego bindings, purego symbol registration, KernelRunner interface method across all backends (CUDA real; ROCm/Metal/SYCL/FPGA/OpenCL stubs).

Plan: Wolf docs/plan-speed-to-parity.md T14L3.1 (ADR 075).

…apes (ADR 075 L3)

cuBLAS SgemmStridedBatched routes tiny matrices through a GEMV + split-K
reduction fan-out (gemvNSP_kernel + splitKreduce_kernel): 2-3 internal
kernels per logical GEMM, none tiling efficiently for m,n,k <= ~64. A fresh
GB10 nsys breakdown of CrossAsset batched-eager training attributes ~17% of
GPU kernel time to this fan-out on the 12x12 Q@K^T and 12x64 weights@V
attention matmuls (batch = B*heads = 1024).

Add a general-purpose custom kernel: one strided-batched GEMM in ONE launch,
one CUDA thread-block per batch element, A[m,k] and B[k,n] tiles staged in
shared memory, f32 accumulation matching cuBLAS Sgemm. The GPU MatMul
dispatches to it when m,n,k are all <= 64 and batch > 1, falling back to
cuBLAS on any kernel error (correctness never depends on the fast path).
Gated by ZERFOO_DISABLE_TINY_GEMM for A/B comparison. CPU path unchanged.

General mechanism: any small batched matmul benefits, not just attention.

Tests (compute/gpu_tiny_batched_gemm_test.go, GPU-gated):
- tiny-GEMM vs CPU reference parity on the exact attention shapes + broadcast-B
  + tile-boundary 64^3 + asymmetric shapes, zero NaN
- tiny-GEMM vs cuBLAS SgemmStridedBatched equivalence (toggle the gate)
- finite-difference gradcheck through the GPU MatMul

Wiring follows the fused_adamw kernel pattern: .cu/.h + cgo + purego bindings,
purego symbol registration, KernelRunner interface method across all backends.
@dndungu dndungu merged commit 6554b8c into main Jun 16, 2026
1 check passed
@dndungu dndungu deleted the fix/tiny-matrix-attention-gemm-l3 branch June 16, 2026 19:01