Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions graph/comprehensive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"testing"

"github.com/zerfoo/float16"
"github.com/zerfoo/ztensor/compute"
"github.com/zerfoo/ztensor/numeric"
"github.com/zerfoo/ztensor/tensor"
Expand Down Expand Up @@ -379,6 +380,37 @@ func TestParameter_AddGradient_Int(t *testing.T) {
testAddGradientType(t, "int", []int{1, 2}, []int{3, 4})
}

// TestParameter_AddGradient_BFloat16 passes bf16 values and gradients to the
// testAddGradientType helper under the "bf16" label.
func TestParameter_AddGradient_BFloat16(t *testing.T) {
	// Layer backward passes accumulate gradients through AddGradient, so
	// bf16 autograd training depends on this path being supported.
	toBF16 := float16.BFloat16FromFloat32
	testAddGradientType(
		t,
		"bf16",
		[]float16.BFloat16{toBF16(1.0), toBF16(2.0)},
		[]float16.BFloat16{toBF16(0.5), toBF16(0.25)},
	)
}

func TestParameter_AddGradient_BFloat16_Value(t *testing.T) {
value, _ := tensor.New[float16.BFloat16]([]int{2}, []float16.BFloat16{float16.BFloat16FromFloat32(0), float16.BFloat16FromFloat32(0)})
param, _ := NewParameter("bf16v", value, tensor.New[float16.BFloat16])
g1, _ := tensor.New[float16.BFloat16]([]int{2}, []float16.BFloat16{float16.BFloat16FromFloat32(0.5), float16.BFloat16FromFloat32(1.0)})
g2, _ := tensor.New[float16.BFloat16]([]int{2}, []float16.BFloat16{float16.BFloat16FromFloat32(0.25), float16.BFloat16FromFloat32(0.5)})
if err := param.AddGradient(g1); err != nil {
t.Fatalf("AddGradient g1: %v", err)
}
if err := param.AddGradient(g2); err != nil {
t.Fatalf("AddGradient g2: %v", err)
}
got := param.Gradient.Data()
// 0.5+0.25=0.75 and 1.0+0.5=1.5 are exactly representable in bf16.
if got[0].ToFloat32() != 0.75 || got[1].ToFloat32() != 1.5 {
t.Errorf("bf16 grad accumulation = [%g %g], want [0.75 1.5]", got[0].ToFloat32(), got[1].ToFloat32())
}
}

// TestParameter_AddGradient_Float16 passes fp16 values and gradients to the
// testAddGradientType helper under the "fp16" label.
func TestParameter_AddGradient_Float16(t *testing.T) {
	toFP16 := float16.FromFloat32
	testAddGradientType(
		t,
		"fp16",
		[]float16.Float16{toFP16(1.0), toFP16(2.0)},
		[]float16.Float16{toFP16(0.5), toFP16(0.25)},
	)
}

// TestParameter_AddGradient_Int8 passes int8 values and gradients to the
// testAddGradientType helper under the "int8" label.
func TestParameter_AddGradient_Int8(t *testing.T) {
	values := []int8{1, 2}
	grads := []int8{3, 4}
	testAddGradientType(t, "int8", values, grads)
}
Expand Down
23 changes: 23 additions & 0 deletions graph/parameter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package graph
import (
"errors"

"github.com/zerfoo/float16"
"github.com/zerfoo/ztensor/tensor"
)

Expand Down Expand Up @@ -90,6 +91,20 @@ func (p *Parameter[T]) AddGradient(grad *tensor.TensorNumeric[T]) error {
for i := range gdst {
gdst[i] = any(any(gdst[i]).(uint64) + any(gsrc[i]).(uint64)).(T)
}
case float16.BFloat16:
// bf16 accumulates through f32 (the type has no native add); the bf16
// round on store matches how every bf16 op publishes its result. This
// unblocks bf16 autograd training (the layer backwards accumulate grads
// here). bf16 shares f32's exponent range, so no overflow vs f32.
for i := range gdst {
sum := any(gdst[i]).(float16.BFloat16).ToFloat32() + any(gsrc[i]).(float16.BFloat16).ToFloat32()
gdst[i] = any(float16.BFloat16FromFloat32(sum)).(T)
}
case float16.Float16:
for i := range gdst {
sum := any(gdst[i]).(float16.Float16).ToFloat32() + any(gsrc[i]).(float16.Float16).ToFloat32()
gdst[i] = any(float16.FromFloat32(sum)).(T)
}
default:
return errors.New("AddGradient unsupported for this numeric type; use engine ops instead")
}
Expand Down Expand Up @@ -141,6 +156,14 @@ func (p *Parameter[T]) ClearGradient() {
for i := range gdst {
gdst[i] = any(uint64(0)).(T)
}
case float16.BFloat16:
for i := range gdst {
gdst[i] = any(float16.BFloat16FromFloat32(0)).(T)
}
case float16.Float16:
for i := range gdst {
gdst[i] = any(float16.FromFloat32(0)).(T)
}
default:
// Unsupported numeric types: set via copy from a zeroed slice of same length if possible
for i := range gdst {
Expand Down
Loading