diff --git a/compute/bulk_upload_test.go b/compute/bulk_upload_test.go
index d0eb0d0..7b6f542 100644
--- a/compute/bulk_upload_test.go
+++ b/compute/bulk_upload_test.go
@@ -4,6 +4,7 @@ import (
 	"math"
 	"testing"
 
+	"github.com/zerfoo/float16"
 	"github.com/zerfoo/ztensor/internal/cuda"
 	"github.com/zerfoo/ztensor/numeric"
 	"github.com/zerfoo/ztensor/tensor"
@@ -69,6 +70,82 @@ func TestGPUEngine_UploadWeights_BulkPath(t *testing.T) {
 	}
 }
 
+// TestGPUEngine_UploadWeightsT_BF16 verifies the generic (T-typed) bulk upload:
+// a GPUEngine[float16.BFloat16] can make its host-backed bf16 weight tensors
+// device-resident via UploadWeightsT, so they no longer stay host-backed and
+// pay a per-op H2D firehose. Asserts the tensors become *GPUStorage[BFloat16]
+// (2-byte device elements, sized by unsafe.Sizeof, not f32Size) and that values
+// round-trip across the bulk copy. CUDA-gated; skips without a GPU.
+// GPU-UNVERIFIED until run on the GB10 verify pod.
+func TestGPUEngine_UploadWeightsT_BF16(t *testing.T) {
+	if !cuda.Available() {
+		t.Skip("CUDA not available")
+	}
+
+	gpuEng, err := NewGPUEngine[float16.BFloat16](numeric.BFloat16Ops{})
+	if err != nil {
+		t.Fatalf("NewGPUEngine[BFloat16]: %v", err)
+	}
+	defer func() { _ = gpuEng.Close() }()
+
+	// N must be at least bulkUploadF32MinTensors (64) to exercise the bulk path.
+	const N = 128
+	const elemsPer = 17
+
+	tensors := make([]*tensor.TensorNumeric[float16.BFloat16], N)
+	wantVals := make([][]float32, N)
+	for i := range N {
+		data := make([]float16.BFloat16, elemsPer)
+		wantVals[i] = make([]float32, elemsPer)
+		for j := range elemsPer {
+			// pre-round through bf16 so wantVals holds the stored value and the round-trip is exact
+			f := float16.BFloat16FromFloat32(float32(i) + float32(j)/64.0)
+			data[j] = f
+			wantVals[i][j] = f.ToFloat32()
+		}
+		tt, errNew := tensor.New[float16.BFloat16]([]int{elemsPer}, data)
+		if errNew != nil {
+			t.Fatalf("tensor.New[BFloat16]: %v", errNew)
+		}
+		tensors[i] = tt
+	}
+
+	if got := len(gpuEng.bulkUploadBuffers); got != 0 {
+		t.Fatalf("bulkUploadBuffers before upload = %d, want 0", got)
+	}
+
+	if err := gpuEng.UploadWeightsT(tensors); err != nil {
+		t.Fatalf("UploadWeightsT: %v", err)
+	}
+
+	// Every tensor must now be device-resident as *GPUStorage[BFloat16].
+	for i, tt := range tensors {
+		if _, ok := tt.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
+			t.Fatalf("tensor[%d] storage = %T, want *GPUStorage[float16.BFloat16]", i, tt.GetStorage())
+		}
+	}
+
+	// Round-trip a sample to verify the bulk copy preserved bf16 bytes.
+	for _, i := range []int{0, 1, N / 2, N - 1} {
+		got := tensors[i].Data()
+		for j := range elemsPer {
+			if got[j].ToFloat32() != wantVals[i][j] {
+				t.Errorf("tensor[%d][%d] = %g, want %g", i, j, got[j].ToFloat32(), wantVals[i][j])
+			}
+		}
+	}
+
+	// A second UploadWeightsT must be a no-op: the tensors are already
+	// *GPUStorage[BFloat16] and must be skipped (no new bulk buffers).
+	before := len(gpuEng.bulkUploadBuffers)
+	if err := gpuEng.UploadWeightsT(tensors); err != nil {
+		t.Fatalf("UploadWeightsT (second call): %v", err)
+	}
+	if after := len(gpuEng.bulkUploadBuffers); after != before {
+		t.Errorf("second UploadWeightsT allocated %d new bulk buffers, want 0 (already-resident skip)", after-before)
+	}
+}
+
 // TestGPUEngine_UploadWeights_MultiChunk exercises the bounded-chunk upload
 // path on real hardware (zerfoo/ztensor#106). It uploads a payload large enough
 // to span several bulkUploadF32MaxChunkBytes (64 MiB) chunks, proving that (a) a
diff --git a/compute/gpu_bf16.go b/compute/gpu_bf16.go
index 053dbdb..fe0e121 100644
--- a/compute/gpu_bf16.go
+++ b/compute/gpu_bf16.go
@@ -309,6 +309,120 @@ func gpuFusedQKNormRoPEBF16[T tensor.Numeric](
 	return makeGPUResult[T](e, []int{totalHeads, headDim}, devOut, outElems)
 }
 
+// gpuSumAxisBF16 runs a native bf16 axis reduction (sum along `axis`) using the
+// FP32-accumulating SumAxisBF16 kernel. It is the bf16 analogue of the f32
+// gpuSum body and mirrors its axis-normalization, keepDims/squeeze output-shape
+// logic, and CPU fallbacks exactly. The kernel accumulates each axis stripe in
+// FP32 and rounds the result to bf16; for axis-sized reductions this is more
+// accurate than pairwise bf16 addition but still only carries bf16's 7-bit
+// mantissa, so callers must tolerate a few bf16 steps vs an f64 reference.
+//
+// invDivisor scales the per-stripe FP32 sum before the bf16 round: pass 1.0 for
+// a plain sum (gpuSum/gpuReduceSum) or 1/axisSize for a mean (gpuReduceMean).
+// On every CPU fallback the divide is reapplied (cpu.ReduceMean) so the contract
+// holds identically for both sum and mean.
+func gpuSumAxisBF16[T tensor.Numeric]( + e *GPUEngine[T], + ctx context.Context, + a *tensor.TensorNumeric[T], + axis int, + keepDims bool, + invDivisor float32, + dst ...*tensor.TensorNumeric[T], +) (*tensor.TensorNumeric[T], error) { + if a == nil { + return nil, fmt.Errorf("Sum: input tensor must not be nil") + } + + e.setDevice() + + // cpuFallback computes the same reduction on the CPU engine: a plain sum when + // invDivisor == 1, otherwise the mean (so the divide-folding contract holds + // identically on every fallback path). + cpuFallback := func() (*tensor.TensorNumeric[T], error) { + if invDivisor == 1.0 { + return e.cpu.Sum(ctx, a, axis, keepDims, dst...) + } + return e.cpu.ReduceMean(ctx, a, axis, keepDims, dst...) + } + + // Negative axis falls back to CPU, matching the f32 gpuSum contract. + if axis < 0 { + return cpuFallback() + } + + shape := a.Shape() + rank := len(shape) + + if axis >= rank { + return nil, fmt.Errorf("Sum: axis %d out of bounds for %d dimensions", axis, rank) + } + + inner := 1 + for i := axis + 1; i < rank; i++ { + inner *= shape[i] + } + + outer := 1 + for i := 0; i < axis; i++ { + outer *= shape[i] + } + + axisSize := shape[axis] + numStripes := outer * inner + + var newShape []int + if keepDims { + newShape = make([]int, rank) + for i, d := range shape { + if i == axis { + newShape[i] = 1 + } else { + newShape[i] = d + } + } + } else { + for i, d := range shape { + if i != axis { + newShape = append(newShape, d) + } + } + if len(newShape) == 0 { + newShape = []int{1} + } + } + + devIn, cleanupIn, err := getDevicePtr(e, a) + if err != nil { + return cpuFallback() + } + defer cleanupIn() + + outByteSize := numStripes * bf16Size + + // Reuse dst's existing GPU memory when possible (mirrors f32 gpuSum #84). 
+	devOut, reused := tryReuseDstPtr[T](numStripes, dst)
+	if !reused {
+		devOut, err = e.pool.Alloc(e.deviceID, outByteSize)
+		if err != nil {
+			return cpuFallback()
+		}
+	}
+
+	if err := e.kernels.SumAxisBF16(devIn, devOut, outer, inner, axisSize, invDivisor, e.stream); err != nil {
+		if !reused {
+			e.pool.Free(e.deviceID, devOut, outByteSize)
+		}
+
+		return nil, err
+	}
+
+	if reused {
+		return finishReusedDst[T](dst[0], newShape), nil
+	}
+	return makeGPUResult[T](e, newShape, devOut, numStripes, dst...)
+}
+
 // gpuUnaryOpBF16 runs a native bf16 unary kernel (c = op(a)).
 func gpuUnaryOpBF16[T tensor.Numeric](
 	e *GPUEngine[T],
diff --git a/compute/gpu_bf16_parity_test.go b/compute/gpu_bf16_parity_test.go
index 73e2a22..f8ede1e 100644
--- a/compute/gpu_bf16_parity_test.go
+++ b/compute/gpu_bf16_parity_test.go
@@ -150,6 +150,9 @@ func TestGPUBF16_UnaryParity(t *testing.T) {
 		{"Sqrt", pos, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
 			return eng.Sqrt(ctx, in)
 		}, func(v float32) float32 { return float32(math.Sqrt(float64(v))) }, 2.0},
+		{"Rsqrt", pos, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
+			return eng.Rsqrt(ctx, in)
+		}, func(v float32) float32 { return float32(1.0 / math.Sqrt(float64(v))) }, 2.0},
 		{"Exp", x, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
 			return eng.Exp(ctx, in)
 		}, func(v float32) float32 { return float32(math.Exp(float64(v))) }, 2.0},
@@ -220,6 +223,115 @@ func TestGPUBF16_SoftmaxParity(t *testing.T) {
 	}
 }
 
+// TestGPUBF16_ReductionParity validates the native bf16 axis reductions (Sum,
+// ReduceSum, ReduceMean) against an f64 reference. The kernel accumulates each
+// axis stripe in FP32 and rounds the result to bf16 once at the end, while the
+// f64 reference accumulates in double precision and rounds once -- so the only
+// expected divergence is bf16's 7-bit mantissa, amplified by accumulation over
+// axisSize terms. We use a generous, axisSize-scaled tolerance: summing N bf16
+// values in FP32 then rounding can differ from the f64-then-bf16 reference by a
+// few bf16 steps as N grows, so this is an order-of-magnitude correctness gate,
+// not a bit-exact one. GPU-UNVERIFIED until run on the GB10 verify pod.
+func TestGPUBF16_ReductionParity(t *testing.T) {
+	eng := newTestGPUBF16Engine(t)
+	ctx := context.Background()
+	rng := rand.New(rand.NewSource(1234))
+
+	// 3D so the reduced axis (axis=1) has a nontrivial inner stride: this
+	// exercises the (outer, inner, axisSize) stripe addressing in the kernel.
+	const outer, axisSize, inner = 4, 24, 5
+	vals := make([]float32, outer*axisSize*inner)
+	for i := range vals {
+		// keep values modest so the FP32 partial sums stay well within bf16 range
+		vals[i] = float16.BFloat16FromFloat32(rng.Float32()*2 - 1).ToFloat32()
+	}
+	shape := []int{outer, axisSize, inner}
+
+	// f64 reference reduced along axis=1: ref[o][in] = sum_k vals[o][k][in].
+	refSum := make([]float64, outer*inner)
+	for o := 0; o < outer; o++ {
+		for in := 0; in < inner; in++ {
+			var s float64
+			for k := 0; k < axisSize; k++ {
+				s += float64(vals[o*axisSize*inner+k*inner+in])
+			}
+			refSum[o*inner+in] = s
+		}
+	}
+
+	// accumulating axisSize bf16 values in FP32 then rounding can drift a few
+	// bf16 steps from the f64 reference; scale tolerance with sqrt(axisSize).
+	tolUlps := 2.0 + math.Sqrt(float64(axisSize))
+
+	t.Run("Sum", func(t *testing.T) {
+		in := bf16Tensor(t, shape, vals)
+		got, err := eng.Sum(ctx, in, 1, false)
+		if err != nil {
+			t.Fatalf("Sum: %v", err)
+		}
+		gd := bf16ToF32(got.Data())
+		if len(gd) != outer*inner {
+			t.Fatalf("Sum produced %d elements, want %d (shape=%v)", len(gd), outer*inner, got.Shape())
+		}
+		for i := range gd {
+			want := float16.BFloat16FromFloat32(float32(refSum[i])).ToFloat32()
+			assertBF16Close(t, "Sum", i, gd[i], want, tolUlps)
+		}
+	})
+
+	t.Run("ReduceSum", func(t *testing.T) {
+		in := bf16Tensor(t, shape, vals)
+		got, err := eng.ReduceSum(ctx, in, 1, false)
+		if err != nil {
+			t.Fatalf("ReduceSum: %v", err)
+		}
+		gd := bf16ToF32(got.Data())
+		// Check the length first: a short result would pass vacuously, a long one would index past refSum.
+		if len(gd) != outer*inner {
+			t.Fatalf("ReduceSum produced %d elements, want %d (shape=%v)", len(gd), outer*inner, got.Shape())
+		}
+		for i := range gd {
+			want := float16.BFloat16FromFloat32(float32(refSum[i])).ToFloat32()
+			assertBF16Close(t, "ReduceSum", i, gd[i], want, tolUlps)
+		}
+	})
+
+	t.Run("ReduceMean", func(t *testing.T) {
+		in := bf16Tensor(t, shape, vals)
+		got, err := eng.ReduceMean(ctx, in, 1, false)
+		if err != nil {
+			t.Fatalf("ReduceMean: %v", err)
+		}
+		gd := bf16ToF32(got.Data())
+		// Check the length first: a short result would pass vacuously, a long one would index past refSum.
+		if len(gd) != outer*inner {
+			t.Fatalf("ReduceMean produced %d elements, want %d (shape=%v)", len(gd), outer*inner, got.Shape())
+		}
+		for i := range gd {
+			want := float16.BFloat16FromFloat32(float32(refSum[i] / float64(axisSize))).ToFloat32()
+			assertBF16Close(t, "ReduceMean", i, gd[i], want, tolUlps)
+		}
+	})
+
+	t.Run("SumKeepDims", func(t *testing.T) {
+		in := bf16Tensor(t, shape, vals)
+		got, err := eng.Sum(ctx, in, 1, true)
+		if err != nil {
+			t.Fatalf("Sum keepDims: %v", err)
+		}
+		want := []int{outer, 1, inner}
+		gs := got.Shape()
+		if len(gs) != len(want) {
+			t.Fatalf("Sum keepDims shape = %v, want %v", gs, want)
+		}
+		for i := range want {
+			if gs[i] != want[i] {
+				t.Fatalf("Sum keepDims shape = %v, want %v", gs, want)
+			}
+		}
+	})
+}
+
 // TestGPUBF16_AdamWParity validates the full gradient-consuming update path:
 // the on-device bf16 AdamW step (param/grad bf16, m f32, v f64) must match an
 // f64 reference AdamW step with the published parameter rounded to bf16.
This
diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go
index 83be9b6..23451a9 100644
--- a/compute/gpu_engine.go
+++ b/compute/gpu_engine.go
@@ -659,6 +659,136 @@ func (e *GPUEngine[T]) bulkUploadF32(tensors []*tensor.TensorNumeric[float32]) (
 	return len(eligible), nil
 }
 
+// bulkUploadT is the generic (T-typed) analogue of bulkUploadF32: it makes a
+// GPUEngine[T]'s own host-backed weight tensors device-resident in bounded
+// chunks, so a non-f32 engine (e.g. GPUEngine[float16.BFloat16]) no longer
+// leaves its weights host-backed and pays a per-op H2D firehose. It follows
+// bulkUploadF32's steps -- capture-skip, MinTensors threshold, chunk bounds via
+// bulkUploadChunkRanges, managedMem direct-copy vs staged Memcpy, non-owning
+// views via NewGPUStorageViewFromPtr -- but sizes everything by the actual
+// element size of T (unsafe.Sizeof) instead of f32Size, and skips tensors that
+// are already *tensor.GPUStorage[T] (already on device). The f32 fast path
+// keeps its dedicated bulkUploadF32 (and its quantized-storage skip set); this
+// path is the universal fallback for any other element type.
+//
+// On the staged (non-managed) path a single host staging buffer is grown on
+// demand and reused for every chunk rather than reallocated per chunk.
+//
+// Returns the number of tensors that were bulk-uploaded. On error the count is
+// 0, but chunks uploaded before the failing one remain device-resident.
+func (e *GPUEngine[T]) bulkUploadT(tensors []*tensor.TensorNumeric[T]) (int, error) {
+	if ca, ok := e.pool.(gpuapi.CaptureAwareAllocator); ok && ca.IsCapturing() {
+		return 0, nil
+	}
+
+	elemSize := int(unsafe.Sizeof(*new(T)))
+
+	type entry struct {
+		t     *tensor.TensorNumeric[T]
+		nelem int
+	}
+	eligible := make([]entry, 0, len(tensors))
+	for _, t := range tensors {
+		if t == nil {
+			continue
+		}
+		// Skip tensors already device-resident as GPUStorage[T].
+		if _, ok := any(t.GetStorage()).(*tensor.GPUStorage[T]); ok {
+			continue
+		}
+		n := len(t.Data())
+		if n == 0 {
+			continue
+		}
+		eligible = append(eligible, entry{t: t, nelem: n})
+	}
+	if len(eligible) < bulkUploadF32MinTensors {
+		return 0, nil
+	}
+	if err := e.ensureNotCapturing(); err != nil {
+		return 0, err
+	}
+
+	// pack lays the raw bytes of each tensor in chunk end-to-end into dst,
+	// which must hold at least the chunk's total byte size.
+	pack := func(dst []byte, chunk []entry) {
+		off := 0
+		for _, en := range chunk {
+			nbytes := en.nelem * elemSize
+			src := unsafe.Slice((*byte)(unsafe.Pointer(&en.t.Data()[0])), nbytes)
+			copy(dst[off:off+nbytes], src)
+			off += nbytes
+		}
+	}
+
+	nelems := make([]int, len(eligible))
+	for i, en := range eligible {
+		nelems[i] = en.nelem
+	}
+
+	// staging is the host-side pack buffer for the non-managed path; it is
+	// grown on demand and reused so each chunk does not allocate afresh.
+	var staging []byte
+	for _, r := range bulkUploadChunkRanges(nelems, elemSize,
+		bulkUploadF32MaxChunkBytes, bulkUploadF32MaxChunkTensors) {
+		chunk := eligible[r[0]:r[1]]
+		chunkBytes := 0
+		for _, en := range chunk {
+			chunkBytes += en.nelem * elemSize
+		}
+
+		var devPtr unsafe.Pointer
+		var err error
+		if e.managedMem {
+			devPtr, err = mallocManagedFn(chunkBytes)
+		} else {
+			devPtr, err = e.runtime.Malloc(chunkBytes)
+		}
+		if err != nil {
+			return 0, fmt.Errorf("bulk alloc T chunk (%d tensors, %d bytes): %w",
+				len(chunk), chunkBytes, err)
+		}
+
+		if e.managedMem {
+			// The managed allocation is written directly from the host, so pack straight into it.
+			pack(unsafe.Slice((*byte)(devPtr), chunkBytes), chunk)
+		} else {
+			if cap(staging) < chunkBytes {
+				staging = make([]byte, chunkBytes)
+			}
+			host := staging[:chunkBytes]
+			pack(host, chunk)
+			if err := e.runtime.Memcpy(devPtr, unsafe.Pointer(&host[0]), chunkBytes, gpuapi.MemcpyHostToDevice); err != nil {
+				_ = e.runtime.Free(devPtr)
+				return 0, fmt.Errorf("bulk H2D T chunk (%d bytes): %w", chunkBytes, err)
+			}
+		}
+
+		// Record the chunk base pointer on the engine; each tensor gets a view at its byte offset.
+		e.bulkUploadBuffers = append(e.bulkUploadBuffers, devPtr)
+		off := 0
+		for _, en := range chunk {
+			sub := unsafe.Add(devPtr, off)
+			view := tensor.NewGPUStorageViewFromPtr[T](sub, en.nelem, e.deviceID)
+			en.t.SetStorage(view)
+			off += en.nelem * elemSize
+		}
+	}
+	return len(eligible), nil
+}
+
+// UploadWeightsT is the T-typed analogue of UploadWeights: it makes this
+// engine's own bf16 (or any-T) weight tensors device-resident via bulkUploadT.
+// Unlike UploadWeights it does NOT handle the float32 quantized-storage paths
+// (Q4/Q4_K/...); those remain the f32 inference engine's responsibility. Use
+// this from a GPUEngine[T] (e.g. T == float16.BFloat16) to lift host-backed
+// weights onto the device once at load time and avoid the per-op H2D firehose.
+func (e *GPUEngine[T]) UploadWeightsT(tensors []*tensor.TensorNumeric[T]) error {
+	e.setDevice()
+	_, err := e.bulkUploadT(tensors)
+	return err
+}
+
 func (e *GPUEngine[T]) UploadWeights(tensors []*tensor.TensorNumeric[float32]) error {
 	e.setDevice()
 	uploaded, err := e.bulkUploadF32(tensors)
diff --git a/compute/gpu_kernels.go b/compute/gpu_kernels.go
index be113f4..60252c2 100644
--- a/compute/gpu_kernels.go
+++ b/compute/gpu_kernels.go
@@ -1009,6 +1009,9 @@ func (e *GPUEngine[T]) gpuSqrt(ctx context.Context, a *tensor.TensorNumeric[T],
 }
 
 func (e *GPUEngine[T]) gpuRsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
+	if isBFloat16[T]() {
+		return gpuUnaryOpBF16(e, a, e.kernels.RsqrtBF16, dst...)
+	}
 	if !isFloat32[T]() {
 		return e.cpu.Rsqrt(ctx, a, dst...)
 	}
@@ -1105,6 +1108,9 @@ func (e *GPUEngine[T]) gpuFill(ctx context.Context, t *tensor.TensorNumeric[T],
 }
 
 func (e *GPUEngine[T]) gpuSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) {
+	if isBFloat16[T]() {
+		return gpuSumAxisBF16(e, ctx, a, axis, keepDims, 1.0, dst...)
+	}
 	if !isFloat32[T]() {
 		return e.cpu.Sum(ctx, a, axis, keepDims, dst...)
} @@ -1203,6 +1209,26 @@ func (e *GPUEngine[T]) gpuReduceSum(ctx context.Context, a *tensor.TensorNumeric } func (e *GPUEngine[T]) gpuReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if isBFloat16[T]() { + if a == nil { + return nil, fmt.Errorf("ReduceMean: input tensor must not be nil") + } + // Negative axis falls back to CPU ReduceMean (matches the f32 contract and + // avoids gpuSumAxisBF16's sum-only negative-axis fallback). + if axis < 0 { + return e.cpu.ReduceMean(ctx, a, axis, keepDims, dst...) + } + // Native bf16 mean: fold 1/axisSize into the FP32-accumulating sum kernel + // so the result stays on-device and rounds to bf16 exactly once. axis >= + // rank surfaces gpuSumAxisBF16's "axis out of bounds" error (matches f32). + shape := a.Shape() + rank := len(shape) + invDivisor := float32(1.0) + if axis < rank && shape[axis] != 0 { + invDivisor = 1.0 / float32(shape[axis]) + } + return gpuSumAxisBF16(e, ctx, a, axis, keepDims, invDivisor, dst...) + } if !isFloat32[T]() { return e.cpu.ReduceMean(ctx, a, axis, keepDims, dst...) 
} diff --git a/internal/cuda/kernels/elementwise_bf16.cu b/internal/cuda/kernels/elementwise_bf16.cu index 76ac19a..5ee3f4d 100644 --- a/internal/cuda/kernels/elementwise_bf16.cu +++ b/internal/cuda/kernels/elementwise_bf16.cu @@ -86,6 +86,13 @@ __global__ void kernel_sqrt_bf16(const __nv_bfloat16* a, __nv_bfloat16* c, int n } } +__global__ void kernel_rsqrt_bf16(const __nv_bfloat16* a, __nv_bfloat16* c, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) { + c[idx] = __float2bfloat16(rsqrtf(__bfloat162float(a[idx]))); + } +} + __global__ void kernel_exp_bf16(const __nv_bfloat16* a, __nv_bfloat16* c, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { @@ -179,6 +186,53 @@ __global__ void kernel_scaled_softmax_bf16(const __nv_bfloat16* input, __nv_bflo } } +// ---------- Sum along an axis bf16 ---------- +// output[outer][inner] = sum(input[outer][k][inner], k=0..axisSize-1) * invDivisor. +// Each block handles one (outer,inner) stripe; the per-axis sum is accumulated +// in FP32 (matching SumAxis f32 semantics and the no-fast-math convention) and +// the final FP32 result is rounded to bf16. invDivisor scales the sum before +// rounding: pass 1.0f for a plain sum or 1/axisSize for a mean (folding the +// divide into the FP32 accumulation keeps the mean on-device and rounds exactly +// once). A sum of N bf16 values accumulated in FP32 is more accurate than +// pairwise bf16 addition, so the result equals round_to_bf16(sum_fp32(input) * +// invDivisor) -- the parity oracle the gate checks (with a generous tolerance, +// since the f64 reference rounds only once at the end). 
+
+__global__ void kernel_sum_axis_bf16(const __nv_bfloat16* input, __nv_bfloat16* output,
+                                     int outer, int inner, int axisSize,
+                                     float invDivisor) {
+    int stripe = blockIdx.x;
+    int o = stripe / inner;
+    int in_ = stripe % inner;
+    int base = o * axisSize * inner + in_;
+    int step = inner;
+
+    extern __shared__ float sdata[];
+
+    float local_sum = 0.0f;
+    for (int k = threadIdx.x; k < axisSize; k += blockDim.x) {
+        local_sum += __bfloat162float(input[base + k * step]);
+    }
+    sdata[threadIdx.x] = local_sum;
+    __syncthreads();
+
+    for (int s = blockDim.x / 2; s >= 32; s >>= 1) {
+        if (threadIdx.x < s) {
+            sdata[threadIdx.x] += sdata[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    if (threadIdx.x < 32) {
+        float val = sdata[threadIdx.x];
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            val += __shfl_down_sync(0xffffffff, val, offset);
+        }
+        if (threadIdx.x == 0) {
+            output[stripe] = __float2bfloat16(val * invDivisor);
+        }
+    }
+}
+
 // ---------- F32 <-> BF16 conversion kernels ----------
 
 __global__ void kernel_f32_to_bf16(const float* src, __nv_bfloat16* dst, int n) {
@@ -266,6 +320,14 @@ cudaError_t launch_sqrt_bf16(const void* a, void* c, int n, cudaStream_t stream)
     return cudaGetLastError();
 }
 
+cudaError_t launch_rsqrt_bf16(const void* a, void* c, int n, cudaStream_t stream) {
+    int grid, block;
+    grid_config_bf16_unary(n, &grid, &block);
+    kernel_rsqrt_bf16<<<grid, block, 0, stream>>>(
+        (const __nv_bfloat16*)a, (__nv_bfloat16*)c, n);
+    return cudaGetLastError();
+}
+
 cudaError_t launch_exp_bf16(const void* a, void* c, int n, cudaStream_t stream) {
     int grid, block;
     grid_config_bf16_unary(n, &grid, &block);
@@ -296,6 +358,23 @@ cudaError_t launch_scaled_softmax_bf16(const void* input, void* output,
     return cudaGetLastError();
 }
 
+cudaError_t launch_sum_axis_bf16(const void* input, void* output,
+                                 int outer, int inner, int axisSize,
+                                 unsigned int inv_divisor_bits,
+                                 cudaStream_t stream) {
+    float invDivisor = bits_to_float_bf16(inv_divisor_bits);
+    // Mirror launch_scaled_softmax_bf16's launch config but floor the block at
+    // one warp: the reduction ends with a full-warp __shfl_down_sync, so every
+    // launched lane (incl. those past axisSize, which contribute 0) must exist.
+    int block = 32;
+    while (block < axisSize && block < 256) block <<= 1;
+    int numStripes = outer * inner;
+    size_t smem = block * sizeof(float);
+    kernel_sum_axis_bf16<<<numStripes, block, smem, stream>>>(
+        (const __nv_bfloat16*)input, (__nv_bfloat16*)output, outer, inner, axisSize, invDivisor);
+    return cudaGetLastError();
+}
+
 cudaError_t launch_f32_to_bf16(const void* src, void* dst, int n, cudaStream_t stream) {
     int block = 256;
     int grid = (n + block - 1) / block;
diff --git a/internal/cuda/kernels/elementwise_bf16_purego.go b/internal/cuda/kernels/elementwise_bf16_purego.go
index 5557350..60b20d4 100644
--- a/internal/cuda/kernels/elementwise_bf16_purego.go
+++ b/internal/cuda/kernels/elementwise_bf16_purego.go
@@ -69,6 +69,16 @@ func SqrtBF16(a, c unsafe.Pointer, n int, s unsafe.Pointer) error {
 	return checkKernel(ret, "sqrt_bf16")
 }
 
+// RsqrtBF16 launches the bf16 elementwise rsqrt kernel (FP32 transcendental): c = 1/sqrt(a).
+func RsqrtBF16(a, c unsafe.Pointer, n int, s unsafe.Pointer) error {
+	k := klib()
+	if k == nil {
+		return fmt.Errorf("rsqrt_bf16 kernel: kernels not available")
+	}
+	ret := cuda.Ccall(k.launchRsqrtBF16, uintptr(a), uintptr(c), uintptr(n), uintptr(s))
+	return checkKernel(ret, "rsqrt_bf16")
+}
+
 // ExpBF16 launches the bf16 elementwise exp kernel (FP32 transcendental): c = exp(a).
 func ExpBF16(a, c unsafe.Pointer, n int, s unsafe.Pointer) error {
 	k := klib()
@@ -109,6 +119,24 @@ func BF16ToF32(src, dst unsafe.Pointer, n int, s unsafe.Pointer) error {
 	return checkKernel(ret, "bf16_to_f32")
 }
 
+// SumAxisBF16 reduces bf16 data along one axis with FP32 accumulation:
+// output[outer][inner] = sum(input[outer][k][inner], k=0..axisSize-1) * invDivisor,
+// rounded to bf16.
invDivisor is 1.0 for a plain sum and 1/axisSize for a mean +// (folding the divide into the FP32 accumulation so the mean stays on-device and +// rounds exactly once). +func SumAxisBF16(input, output unsafe.Pointer, outer, inner, axisSize int, invDivisor float32, s unsafe.Pointer) error { + k := klib() + if k == nil { + return fmt.Errorf("sum_axis_bf16 kernel: kernels not available") + } + ret := cuda.Ccall(k.launchSumAxisBF16, + uintptr(input), uintptr(output), + uintptr(outer), uintptr(inner), uintptr(axisSize), + floatBits(invDivisor), + uintptr(s)) + return checkKernel(ret, "sum_axis_bf16") +} + // ScaledSoftmaxBF16 applies fused scaled softmax on bf16 data with FP32 accumulation. func ScaledSoftmaxBF16( input, output unsafe.Pointer, diff --git a/internal/cuda/kernels/purego.go b/internal/cuda/kernels/purego.go index 0623b2b..cbba14c 100644 --- a/internal/cuda/kernels/purego.go +++ b/internal/cuda/kernels/purego.go @@ -131,7 +131,10 @@ type KernelLib struct { launchAddBF16, launchSubBF16, launchMulBF16, launchDivBF16 uintptr // bf16 unary (FP32 transcendental) - launchTanhBF16, launchSqrtBF16, launchExpBF16, launchLogBF16 uintptr + launchTanhBF16, launchSqrtBF16, launchRsqrtBF16, launchExpBF16, launchLogBF16 uintptr + + // bf16 reductions (FP32 accumulation) + launchSumAxisBF16 uintptr // bf16 scaled_softmax launchScaledSoftmaxBF16 uintptr @@ -350,8 +353,11 @@ func openKernelLib() (*KernelLib, error) { // bf16 unary (FP32 transcendental) {"launch_tanh_bf16", &k.launchTanhBF16}, {"launch_sqrt_bf16", &k.launchSqrtBF16}, + {"launch_rsqrt_bf16", &k.launchRsqrtBF16}, {"launch_exp_bf16", &k.launchExpBF16}, {"launch_log_bf16", &k.launchLogBF16}, + // bf16 reductions + {"launch_sum_axis_bf16", &k.launchSumAxisBF16}, // bf16 scaled_softmax {"launch_scaled_softmax_bf16", &k.launchScaledSoftmaxBF16}, // bf16 conversion @@ -464,8 +470,10 @@ func openKernelLib() (*KernelLib, error) { "launch_div_bf16": true, "launch_tanh_bf16": true, "launch_sqrt_bf16": true, + 
"launch_rsqrt_bf16": true, "launch_exp_bf16": true, "launch_log_bf16": true, + "launch_sum_axis_bf16": true, "launch_scaled_softmax_bf16": true, "launch_f32_to_bf16": true, "launch_bf16_to_f32": true, diff --git a/internal/gpuapi/cuda_kernels.go b/internal/gpuapi/cuda_kernels.go index 1eb74e2..6a0dbcf 100644 --- a/internal/gpuapi/cuda_kernels.go +++ b/internal/gpuapi/cuda_kernels.go @@ -304,6 +304,14 @@ func (k *CUDAKernels) SqrtBF16(a, c unsafe.Pointer, n int, s Stream) error { return kernels.SqrtBF16(a, c, n, streamPtr(s)) } +func (k *CUDAKernels) RsqrtBF16(a, c unsafe.Pointer, n int, s Stream) error { + return kernels.RsqrtBF16(a, c, n, streamPtr(s)) +} + +func (k *CUDAKernels) SumAxisBF16(input, output unsafe.Pointer, outer, inner, axisSize int, invDivisor float32, s Stream) error { + return kernels.SumAxisBF16(input, output, outer, inner, axisSize, invDivisor, streamPtr(s)) +} + func (k *CUDAKernels) ExpBF16(a, c unsafe.Pointer, n int, s Stream) error { return kernels.ExpBF16(a, c, n, streamPtr(s)) } diff --git a/internal/gpuapi/fpga_kernels.go b/internal/gpuapi/fpga_kernels.go index b20115f..42f8cf8 100644 --- a/internal/gpuapi/fpga_kernels.go +++ b/internal/gpuapi/fpga_kernels.go @@ -344,6 +344,14 @@ func (k *FPGAKernels) SqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("SqrtBF16: not implemented for FPGA") } +func (k *FPGAKernels) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { + return fmt.Errorf("RsqrtBF16: not implemented for FPGA") +} + +func (k *FPGAKernels) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { + return fmt.Errorf("SumAxisBF16: not implemented for FPGA") +} + func (k *FPGAKernels) ExpBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("ExpBF16: not implemented for FPGA") } diff --git a/internal/gpuapi/gpuapi_test.go b/internal/gpuapi/gpuapi_test.go index d2a11e1..48d2bf6 100644 --- a/internal/gpuapi/gpuapi_test.go +++ b/internal/gpuapi/gpuapi_test.go @@ -286,6 
+286,12 @@ func (stubKernelRunner) TanhBF16(_, _ unsafe.Pointer, _ int, _ gpuapi.Stream) er func (stubKernelRunner) SqrtBF16(_, _ unsafe.Pointer, _ int, _ gpuapi.Stream) error { return nil } +func (stubKernelRunner) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ gpuapi.Stream) error { + return nil +} +func (stubKernelRunner) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ gpuapi.Stream) error { + return nil +} func (stubKernelRunner) ExpBF16(_, _ unsafe.Pointer, _ int, _ gpuapi.Stream) error { return nil } diff --git a/internal/gpuapi/kernels.go b/internal/gpuapi/kernels.go index 0c54b64..a76789a 100644 --- a/internal/gpuapi/kernels.go +++ b/internal/gpuapi/kernels.go @@ -198,9 +198,16 @@ type KernelRunner interface { // bf16 unary operations (FP32 transcendental, bf16 in/out). TanhBF16(a, c unsafe.Pointer, n int, stream Stream) error SqrtBF16(a, c unsafe.Pointer, n int, stream Stream) error + RsqrtBF16(a, c unsafe.Pointer, n int, stream Stream) error ExpBF16(a, c unsafe.Pointer, n int, stream Stream) error LogBF16(a, c unsafe.Pointer, n int, stream Stream) error + // SumAxisBF16 reduces bf16 data along one axis with FP32 accumulation: + // output[outer][inner] = sum(input[outer][k][inner], k=0..axisSize-1) * invDivisor. + // invDivisor is 1.0 for a plain sum and 1/axisSize for a mean (the divide is + // folded into the FP32 accumulation, keeping the mean on-device). + SumAxisBF16(input, output unsafe.Pointer, outer, inner, axisSize int, invDivisor float32, stream Stream) error + // ScaledSoftmaxBF16 computes softmax(input * scale) on bf16 data with FP32 accumulation. 
ScaledSoftmaxBF16(input, output unsafe.Pointer, outer, inner, axisSize int, scale float32, stream Stream) error diff --git a/internal/gpuapi/metal_kernels.go b/internal/gpuapi/metal_kernels.go index ef951c4..1bab6c9 100644 --- a/internal/gpuapi/metal_kernels.go +++ b/internal/gpuapi/metal_kernels.go @@ -634,6 +634,14 @@ func (k *MetalKernels) SqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("SqrtBF16: not implemented for Metal") } +func (k *MetalKernels) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { + return fmt.Errorf("RsqrtBF16: not implemented for Metal") +} + +func (k *MetalKernels) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { + return fmt.Errorf("SumAxisBF16: not implemented for Metal") +} + func (k *MetalKernels) ExpBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("ExpBF16: not implemented for Metal") } diff --git a/internal/gpuapi/opencl_kernels.go b/internal/gpuapi/opencl_kernels.go index a88cd9a..5724239 100644 --- a/internal/gpuapi/opencl_kernels.go +++ b/internal/gpuapi/opencl_kernels.go @@ -361,6 +361,14 @@ func (k *OpenCLKernels) SqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("SqrtBF16: not implemented for OpenCL") } +func (k *OpenCLKernels) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { + return fmt.Errorf("RsqrtBF16: not implemented for OpenCL") +} + +func (k *OpenCLKernels) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { + return fmt.Errorf("SumAxisBF16: not implemented for OpenCL") +} + func (k *OpenCLKernels) ExpBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("ExpBF16: not implemented for OpenCL") } diff --git a/internal/gpuapi/rocm_kernels.go b/internal/gpuapi/rocm_kernels.go index 8b50d1d..75486d0 100644 --- a/internal/gpuapi/rocm_kernels.go +++ b/internal/gpuapi/rocm_kernels.go @@ -351,6 +351,14 @@ func (k *ROCmKernels) SqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return 
fmt.Errorf("SqrtBF16: not implemented for ROCm") } +func (k *ROCmKernels) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { + return fmt.Errorf("RsqrtBF16: not implemented for ROCm") +} + +func (k *ROCmKernels) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { + return fmt.Errorf("SumAxisBF16: not implemented for ROCm") +} + func (k *ROCmKernels) ExpBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("ExpBF16: not implemented for ROCm") } diff --git a/internal/gpuapi/sycl_kernels.go b/internal/gpuapi/sycl_kernels.go index 8a0054d..e26abd6 100644 --- a/internal/gpuapi/sycl_kernels.go +++ b/internal/gpuapi/sycl_kernels.go @@ -352,6 +352,14 @@ func (k *SYCLKernels) SqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("SqrtBF16: not implemented for SYCL") } +func (k *SYCLKernels) RsqrtBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { + return fmt.Errorf("RsqrtBF16: not implemented for SYCL") +} + +func (k *SYCLKernels) SumAxisBF16(_, _ unsafe.Pointer, _, _, _ int, _ float32, _ Stream) error { + return fmt.Errorf("SumAxisBF16: not implemented for SYCL") +} + func (k *SYCLKernels) ExpBF16(_, _ unsafe.Pointer, _ int, _ Stream) error { return fmt.Errorf("ExpBF16: not implemented for SYCL") }