Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions compute/bulk_upload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math"
"testing"

"github.com/zerfoo/float16"
"github.com/zerfoo/ztensor/internal/cuda"
"github.com/zerfoo/ztensor/numeric"
"github.com/zerfoo/ztensor/tensor"
Expand Down Expand Up @@ -69,6 +70,82 @@ func TestGPUEngine_UploadWeights_BulkPath(t *testing.T) {
}
}

// TestGPUEngine_UploadWeightsT_BF16 verifies the generic (T-typed) bulk upload:
// a GPUEngine[float16.BFloat16] can make its host-backed bf16 weight tensors
// device-resident via UploadWeightsT, so they no longer stay host-backed and
// pay a per-op H2D firehose.
//
// The test asserts three things:
//   - every tensor's storage becomes *GPUStorage[BFloat16] after the upload;
//   - a sample of tensors reads back with the expected element count and the
//     exact bf16 values that were uploaded;
//   - a second UploadWeightsT call over the same tensors adds no new entries to
//     the engine's bulkUploadBuffers (already-resident tensors are skipped).
//
// CUDA-gated; skips without a GPU.
// GPU-UNVERIFIED until run on the GB10 verify pod.
func TestGPUEngine_UploadWeightsT_BF16(t *testing.T) {
	if !cuda.Available() {
		t.Skip("CUDA not available")
	}

	gpuEng, err := NewGPUEngine[float16.BFloat16](numeric.BFloat16Ops{})
	if err != nil {
		t.Fatalf("NewGPUEngine[BFloat16]: %v", err)
	}
	// Best-effort teardown: a Close error cannot change the test outcome.
	defer func() { _ = gpuEng.Close() }()

	// N must exceed bulkUploadF32MinTensors (64) to exercise the bulk path.
	const N = 128
	const elemsPer = 17

	tensors := make([]*tensor.TensorNumeric[float16.BFloat16], N)
	wantVals := make([][]float32, N)
	for i := range N {
		data := make([]float16.BFloat16, elemsPer)
		wantVals[i] = make([]float32, elemsPer)
		for j := range elemsPer {
			// Not every i + j/64 fits bf16's 8 significant bits (e.g. 4 + 1/64
			// needs 9), so the expectation is taken from the value *after*
			// rounding to bf16. That keeps the round-trip comparison exact.
			f := float16.BFloat16FromFloat32(float32(i) + float32(j)/64.0)
			data[j] = f
			wantVals[i][j] = f.ToFloat32()
		}
		tt, errNew := tensor.New[float16.BFloat16]([]int{elemsPer}, data)
		if errNew != nil {
			t.Fatalf("tensor.New[BFloat16]: %v", errNew)
		}
		tensors[i] = tt
	}

	if got := len(gpuEng.bulkUploadBuffers); got != 0 {
		t.Fatalf("bulkUploadBuffers before upload = %d, want 0", got)
	}

	if err := gpuEng.UploadWeightsT(tensors); err != nil {
		t.Fatalf("UploadWeightsT: %v", err)
	}

	// Every tensor must now be device-resident as *GPUStorage[BFloat16].
	for i, tt := range tensors {
		if _, ok := tt.GetStorage().(*tensor.GPUStorage[float16.BFloat16]); !ok {
			t.Fatalf("tensor[%d] storage = %T, want *GPUStorage[float16.BFloat16]", i, tt.GetStorage())
		}
	}

	// Round-trip a sample to verify the bulk copy preserved bf16 bytes.
	for _, i := range []int{0, 1, N / 2, N - 1} {
		got := tensors[i].Data()
		// Guard the indexing below: a short read-back must fail the test
		// with a clear message rather than panic with index-out-of-range.
		if len(got) != elemsPer {
			t.Fatalf("tensor[%d] read back %d elements, want %d", i, len(got), elemsPer)
		}
		for j := range elemsPer {
			if got[j].ToFloat32() != wantVals[i][j] {
				t.Errorf("tensor[%d][%d] = %g, want %g", i, j, got[j].ToFloat32(), wantVals[i][j])
			}
		}
	}

	// A second UploadWeightsT must be a no-op: the tensors are already
	// *GPUStorage[BFloat16] and must be skipped (no new bulk buffers).
	before := len(gpuEng.bulkUploadBuffers)
	if err := gpuEng.UploadWeightsT(tensors); err != nil {
		t.Fatalf("UploadWeightsT (second call): %v", err)
	}
	if after := len(gpuEng.bulkUploadBuffers); after != before {
		t.Errorf("second UploadWeightsT allocated %d new bulk buffers, want 0 (already-resident skip)", after-before)
	}
}

// TestGPUEngine_UploadWeights_MultiChunk exercises the bounded-chunk upload
// path on real hardware (zerfoo/ztensor#106). It uploads a payload large enough
// to span several bulkUploadF32MaxChunkBytes (64 MiB) chunks, proving that (a) a
Expand Down
114 changes: 114 additions & 0 deletions compute/gpu_bf16.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,120 @@ func gpuFusedQKNormRoPEBF16[T tensor.Numeric](
return makeGPUResult[T](e, []int{totalHeads, headDim}, devOut, outElems)
}

// gpuSumAxisBF16 reduces tensor a along a single non-negative axis on the GPU
// by launching the SumAxisBF16 kernel, which (per its contract) accumulates
// each axis stripe in FP32 and rounds the result to bf16. It is the bf16
// counterpart of the f32 gpuSum body. The result still carries only bf16's
// 7-bit mantissa, so callers must tolerate a few bf16 steps versus an f64
// reference.
//
// Parameters:
//   - axis: the dimension to reduce. A negative axis is delegated to the CPU
//     engine; an axis >= rank is an error.
//   - keepDims: when true the reduced dimension is kept with size 1; otherwise
//     it is removed, and a fully-reduced result gets shape [1].
//   - invDivisor: factor passed to the kernel to scale each stripe's sum. Pass
//     1.0 for a plain sum, or 1/axisSize for a mean.
//   - dst: optional destination; its device memory is reused when
//     tryReuseDstPtr accepts it.
//
// CPU fallbacks (negative axis, no device pointer for a, or a failed output
// allocation) call e.cpu.Sum when invDivisor == 1.0 and e.cpu.ReduceMean
// otherwise, so sum and mean callers get the same contract on every path.
// A kernel launch failure is returned as an error, after freeing any output
// buffer this call allocated.
func gpuSumAxisBF16[T tensor.Numeric](
	e *GPUEngine[T],
	ctx context.Context,
	a *tensor.TensorNumeric[T],
	axis int,
	keepDims bool,
	invDivisor float32,
	dst ...*tensor.TensorNumeric[T],
) (*tensor.TensorNumeric[T], error) {
	if a == nil {
		return nil, fmt.Errorf("Sum: input tensor must not be nil")
	}

	e.setDevice()

	// onCPU performs the equivalent reduction on the CPU engine. Anything other
	// than a unit divisor is treated as a mean request.
	onCPU := func() (*tensor.TensorNumeric[T], error) {
		if invDivisor != 1.0 {
			return e.cpu.ReduceMean(ctx, a, axis, keepDims, dst...)
		}
		return e.cpu.Sum(ctx, a, axis, keepDims, dst...)
	}

	// Negative axes are not handled on the device (same as the f32 gpuSum).
	if axis < 0 {
		return onCPU()
	}

	dims := a.Shape()
	nDims := len(dims)
	if axis >= nDims {
		return nil, fmt.Errorf("Sum: axis %d out of bounds for %d dimensions", axis, nDims)
	}

	// Split the shape around the reduced axis: `outer` is the product of the
	// dimensions before it, `inner` the product of the dimensions after it.
	outer, inner := 1, 1
	for i, d := range dims {
		switch {
		case i < axis:
			outer *= d
		case i > axis:
			inner *= d
		}
	}
	axisSize := dims[axis]
	numStripes := outer * inner

	// Build the output shape: collapse the axis to 1 (keepDims) or drop it.
	var outShape []int
	if keepDims {
		outShape = make([]int, nDims)
		copy(outShape, dims)
		outShape[axis] = 1
	} else {
		outShape = append(outShape, dims[:axis]...)
		outShape = append(outShape, dims[axis+1:]...)
		if len(outShape) == 0 {
			// Reducing the only dimension yields a one-element tensor.
			outShape = []int{1}
		}
	}

	srcPtr, releaseSrc, err := getDevicePtr(e, a)
	if err != nil {
		return onCPU()
	}
	defer releaseSrc()

	outBytes := numStripes * bf16Size

	// Reuse dst's existing GPU memory when possible (mirrors f32 gpuSum #84);
	// otherwise take a fresh buffer from the pool.
	outPtr, reused := tryReuseDstPtr[T](numStripes, dst)
	if !reused {
		outPtr, err = e.pool.Alloc(e.deviceID, outBytes)
		if err != nil {
			return onCPU()
		}
	}

	launchErr := e.kernels.SumAxisBF16(srcPtr, outPtr, outer, inner, axisSize, invDivisor, e.stream)
	if launchErr != nil {
		// Only release memory this call allocated; a reused dst buffer is
		// still owned by dst.
		if !reused {
			e.pool.Free(e.deviceID, outPtr, outBytes)
		}
		return nil, launchErr
	}

	if !reused {
		return makeGPUResult[T](e, outShape, outPtr, numStripes, dst...)
	}
	return finishReusedDst[T](dst[0], outShape), nil
}

// gpuUnaryOpBF16 runs a native bf16 unary kernel (c = op(a)).
func gpuUnaryOpBF16[T tensor.Numeric](
e *GPUEngine[T],
Expand Down
104 changes: 104 additions & 0 deletions compute/gpu_bf16_parity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ func TestGPUBF16_UnaryParity(t *testing.T) {
{"Sqrt", pos, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
return eng.Sqrt(ctx, in)
}, func(v float32) float32 { return float32(math.Sqrt(float64(v))) }, 2.0},
{"Rsqrt", pos, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
return eng.Rsqrt(ctx, in)
}, func(v float32) float32 { return float32(1.0 / math.Sqrt(float64(v))) }, 2.0},
{"Exp", x, func(in *tensor.TensorNumeric[float16.BFloat16]) (*tensor.TensorNumeric[float16.BFloat16], error) {
return eng.Exp(ctx, in)
}, func(v float32) float32 { return float32(math.Exp(float64(v))) }, 2.0},
Expand Down Expand Up @@ -220,6 +223,107 @@ func TestGPUBF16_SoftmaxParity(t *testing.T) {
}
}

// TestGPUBF16_ReductionParity validates the native bf16 axis reductions (Sum,
// ReduceSum, ReduceMean) against an f64 reference. The kernel accumulates each
// axis stripe in FP32 and rounds the result to bf16 once at the end, while the
// f64 reference accumulates in double precision and rounds once -- so the only
// expected divergence is bf16's 7-bit mantissa, amplified by accumulation over
// axisSize terms. We use a generous, axisSize-scaled tolerance: summing N bf16
// values in FP32 then rounding can differ from the f64-then-bf16 reference by a
// few bf16 steps as N grows, so this is an order-of-magnitude correctness gate,
// not a bit-exact one.
//
// Every value-checking subtest first asserts that the result has exactly
// outer*inner elements, so an empty or mis-sized result fails loudly instead
// of passing vacuously or panicking while indexing the reference.
// GPU-UNVERIFIED until run on the GB10 verify pod.
func TestGPUBF16_ReductionParity(t *testing.T) {
	eng := newTestGPUBF16Engine(t)
	ctx := context.Background()
	rng := rand.New(rand.NewSource(1234))

	// 3D so the reduced axis (axis=1) has a nontrivial inner stride: this
	// exercises the (outer, inner, axisSize) stripe addressing in the kernel.
	const outer, axisSize, inner = 4, 24, 5
	vals := make([]float32, outer*axisSize*inner)
	for i := range vals {
		// keep values modest so the FP32 partial sums stay well within bf16 range
		vals[i] = float16.BFloat16FromFloat32(rng.Float32()*2 - 1).ToFloat32()
	}
	shape := []int{outer, axisSize, inner}

	// f64 reference reduced along axis=1: ref[o][in] = sum_k vals[o][k][in].
	refSum := make([]float64, outer*inner)
	for o := 0; o < outer; o++ {
		for in := 0; in < inner; in++ {
			var s float64
			for k := 0; k < axisSize; k++ {
				s += float64(vals[o*axisSize*inner+k*inner+in])
			}
			refSum[o*inner+in] = s
		}
	}

	// accumulating axisSize bf16 values in FP32 then rounding can drift a few
	// bf16 steps from the f64 reference; scale tolerance with sqrt(axisSize).
	tolUlps := 2.0 + math.Sqrt(float64(axisSize))

	// checkAgainstRef compares a reduced tensor with refSum[i]/divisor rounded
	// to bf16. divisor is 1 for a sum and axisSize for a mean. The element
	// count is verified first so the per-element loop can never be skipped or
	// run past the end of refSum.
	checkAgainstRef := func(t *testing.T, name string, got *tensor.TensorNumeric[float16.BFloat16], divisor float64) {
		t.Helper()
		gd := bf16ToF32(got.Data())
		if len(gd) != len(refSum) {
			t.Fatalf("%s produced %d elements, want %d (shape=%v)", name, len(gd), len(refSum), got.Shape())
		}
		for i := range gd {
			want := float16.BFloat16FromFloat32(float32(refSum[i] / divisor)).ToFloat32()
			assertBF16Close(t, name, i, gd[i], want, tolUlps)
		}
	}

	t.Run("Sum", func(t *testing.T) {
		in := bf16Tensor(t, shape, vals)
		got, err := eng.Sum(ctx, in, 1, false)
		if err != nil {
			t.Fatalf("Sum: %v", err)
		}
		checkAgainstRef(t, "Sum", got, 1)
	})

	t.Run("ReduceSum", func(t *testing.T) {
		in := bf16Tensor(t, shape, vals)
		got, err := eng.ReduceSum(ctx, in, 1, false)
		if err != nil {
			t.Fatalf("ReduceSum: %v", err)
		}
		checkAgainstRef(t, "ReduceSum", got, 1)
	})

	t.Run("ReduceMean", func(t *testing.T) {
		in := bf16Tensor(t, shape, vals)
		got, err := eng.ReduceMean(ctx, in, 1, false)
		if err != nil {
			t.Fatalf("ReduceMean: %v", err)
		}
		checkAgainstRef(t, "ReduceMean", got, float64(axisSize))
	})

	t.Run("SumKeepDims", func(t *testing.T) {
		in := bf16Tensor(t, shape, vals)
		got, err := eng.Sum(ctx, in, 1, true)
		if err != nil {
			t.Fatalf("Sum keepDims: %v", err)
		}
		want := []int{outer, 1, inner}
		gs := got.Shape()
		if len(gs) != len(want) {
			t.Fatalf("Sum keepDims shape = %v, want %v", gs, want)
		}
		for i := range want {
			if gs[i] != want[i] {
				t.Fatalf("Sum keepDims shape = %v, want %v", gs, want)
			}
		}
		// keepDims only changes the shape: the flat element order is the same
		// as the squeezed result, so the values must match the same reference.
		checkAgainstRef(t, "Sum keepDims", got, 1)
	})
}

// TestGPUBF16_AdamWParity validates the full gradient-consuming update path:
// the on-device bf16 AdamW step (param/grad bf16, m f32, v f64) must match an
// f64 reference AdamW step with the published parameter rounded to bf16. This
Expand Down
Loading
Loading