From 288ea0a74e52fb797d36f7065a77bb5fbaf51994 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 16:54:21 -0600 Subject: [PATCH 01/11] feat: add bfloat16 array primitive --- go/fory/array_primitive.go | 75 +++++++++++++++++++++++++++++++++ go/fory/array_primitive_test.go | 16 +++++++ 2 files changed, 91 insertions(+) diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go index 9777733590..27813060b7 100644 --- a/go/fory/array_primitive.go +++ b/go/fory/array_primitive.go @@ -21,6 +21,7 @@ import ( "reflect" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" ) @@ -871,3 +872,77 @@ func (s float16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// bfloat16ArraySerializer - optimized [N]bfloat16.BFloat16 serialization +// ============================================================================ + +type bfloat16ArraySerializer struct { + arrayType reflect.Type +} + +func (s bfloat16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) { + buf := ctx.Buffer() + length := value.Len() + size := length * 2 + buf.WriteLength(size) + if length > 0 { + if value.CanAddr() && isLittleEndian { + ptr := value.Addr().UnsafePointer() + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + for i := 0; i < length; i++ { + // We can't easily cast the whole array if not addressable/little-endian + // So we iterate. 
+ val := value.Index(i).Interface().(bfloat16.BFloat16) + buf.WriteUint16(val.Bits()) + } + } + } +} + +func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + writeArrayRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY) + if ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + if length != value.Type().Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + return + } + + if length > 0 { + if isLittleEndian { + ptr := value.Addr().UnsafePointer() + raw := buf.ReadBinary(size, ctxErr) + copy(unsafe.Slice((*byte)(ptr), size), raw) + } else { + for i := 0; i < length; i++ { + value.Index(i).Set(reflect.ValueOf(bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)))) + } + } + } +} + +func (s bfloat16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done := readArrayRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/array_primitive_test.go b/go/fory/array_primitive_test.go index c2e684af4c..e8c99fb2f2 100644 --- a/go/fory/array_primitive_test.go +++ b/go/fory/array_primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -74,6 +75,21 @@ func TestPrimitiveArraySerializer(t *testing.T) { require.NoError(t, err) 
require.Equal(t, arr, result) }) + + t.Run("bfloat16_array", func(t *testing.T) { + arr := [3]bfloat16.BFloat16{ + bfloat16.BFloat16FromFloat32(1.0), + bfloat16.BFloat16FromFloat32(2.5), + bfloat16.BFloat16FromFloat32(-3.5), + } + data, err := f.Serialize(arr) + assert.NoError(t, err) + + var result [3]bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, arr, result) + }) } func TestArraySliceInteroperability(t *testing.T) { From 166e94288dc6f40d06de7c5acfa217b293eda8cb Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 16:54:59 -0600 Subject: [PATCH 02/11] feat: add bfloat16 slice primitive --- go/fory/slice_primitive.go | 79 +++++++++++++++++++++++++++++++++ go/fory/slice_primitive_test.go | 43 ++++++++++++++++++ 2 files changed, 122 insertions(+) diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 200d144c63..e4daf990be 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -22,6 +22,7 @@ import ( "strconv" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" ) @@ -1330,3 +1331,81 @@ func ReadStringSlice(buf *ByteBuffer, err *Error) []string { } return result } + +// ============================================================================ +// bfloat16SliceSerializer - optimized []bfloat16.BFloat16 serialization +// ============================================================================ + +type bfloat16SliceSerializer struct{} + +func (s bfloat16SliceSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Cast to []bfloat16.BFloat16 + v := value.Interface().([]bfloat16.BFloat16) + buf := ctx.Buffer() + length := len(v) + size := length * 2 + buf.WriteLength(size) + if length > 0 { + ptr := unsafe.Pointer(&v[0]) + if isLittleEndian { + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + for i := 0; i < length; i++ { + buf.WriteUint16(v[i].Bits()) + } + } + } +} + +func (s 
bfloat16SliceSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + done := writeSliceRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY) + if done || ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s bfloat16SliceSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done, typeId := readSliceRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + if readType && typeId != uint32(BFLOAT16_ARRAY) { + ctx.SetError(DeserializationErrorf("slice type mismatch: expected BFLOAT16_ARRAY (%d), got %d", BFLOAT16_ARRAY, typeId)) + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + +func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + + ptr := (*[]bfloat16.BFloat16)(value.Addr().UnsafePointer()) + if length == 0 { + *ptr = make([]bfloat16.BFloat16, 0) + return + } + + result := make([]bfloat16.BFloat16, length) + + if isLittleEndian { + raw := buf.ReadBinary(size, ctxErr) + targetPtr := unsafe.Pointer(&result[0]) + copy(unsafe.Slice((*byte)(targetPtr), size), raw) + } else { + for i := 0; i < length; i++ { + result[i] = bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)) + } + } + *ptr = result +} diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go index 61ff7b89c8..f376cd4f87 100644 --- a/go/fory/slice_primitive_test.go +++ b/go/fory/slice_primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/assert" ) @@ -65,3 +66,45 @@ func 
TestFloat16Slice(t *testing.T) { assert.Nil(t, result) }) } + +func TestBFloat16Slice(t *testing.T) { + f := NewFory() + + t.Run("bfloat16_slice", func(t *testing.T) { + slice := []bfloat16.BFloat16{ + bfloat16.BFloat16FromFloat32(1.0), + bfloat16.BFloat16FromFloat32(2.5), + bfloat16.BFloat16FromFloat32(-3.5), + } + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, slice, result) + }) + + t.Run("bfloat16_slice_empty", func(t *testing.T) { + slice := []bfloat16.BFloat16{} + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result) + }) + + t.Run("bfloat16_slice_nil", func(t *testing.T) { + var slice []bfloat16.BFloat16 = nil + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Nil(t, result) + }) +} From 405941ba380116075ba05b1a347e613729aab592 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 16:56:28 -0600 Subject: [PATCH 03/11] feat: add bfloat16 primitive --- go/fory/primitive.go | 53 +++++++++++++++++++++++++++++++++++++++ go/fory/primitive_test.go | 34 +++++++++++++++++++++++++ 2 files changed, 87 insertions(+) diff --git a/go/fory/primitive.go b/go/fory/primitive.go index 2c316d549e..998879042a 100644 --- a/go/fory/primitive.go +++ b/go/fory/primitive.go @@ -663,3 +663,56 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool func (s float16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// bfloat16Serializer - optimized bfloat16 serialization +// 
============================================================================ + +// bfloat16Serializer handles bfloat16 type +type bfloat16Serializer struct{} + +var globalBFloat16Serializer = bfloat16Serializer{} + +func (s bfloat16Serializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Value is effectively uint16 (alias) + ctx.buffer.WriteUint16(uint16(value.Uint())) +} + +func (s bfloat16Serializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if refMode != RefModeNone { + ctx.buffer.WriteInt8(NotNullValueFlag) + } + if writeType { + ctx.buffer.WriteUint8(uint8(BFLOAT16)) + } + s.WriteData(ctx, value) +} + +func (s bfloat16Serializer) ReadData(ctx *ReadContext, value reflect.Value) { + err := ctx.Err() + bits := ctx.buffer.ReadUint16(err) + if ctx.HasError() { + return + } + value.SetUint(uint64(bits)) +} + +func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + err := ctx.Err() + if refMode != RefModeNone { + if ctx.buffer.ReadInt8(err) == NullFlag { + return + } + } + if readType { + _ = ctx.buffer.ReadUint8(err) + } + if ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/primitive_test.go b/go/fory/primitive_test.go index 978a81d46b..0f478590e0 100644 --- a/go/fory/primitive_test.go +++ b/go/fory/primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/require" ) @@ -56,3 +57,36 @@ func TestFloat16PrimitiveSliceDirect(t *testing.T) { require.NoError(t, err) require.Equal(t, slice, resSlice) } + +func TestBFloat16Primitive(t *testing.T) { + f := New(WithXlang(true)) + bf16 := 
bfloat16.BFloat16FromFloat32(3.14) + + // Directly serialize a bfloat16 value + data, err := f.Serialize(bf16) + require.NoError(t, err) + + var res bfloat16.BFloat16 + err = f.Deserialize(data, &res) + require.NoError(t, err) + + require.Equal(t, bf16.Bits(), res.Bits()) + + // Value check (approximate because BF16 precision is low) + require.InDelta(t, 3.14, res.Float32(), 0.1) +} + +func TestBFloat16PrimitiveSliceDirect(t *testing.T) { + // Tests serializing a slice as a root object + f := New(WithXlang(true)) + bf16 := bfloat16.BFloat16FromFloat32(3.14) + + slice := []bfloat16.BFloat16{bf16, bfloat16.BFloat16(0)} + data, err := f.Serialize(slice) + require.NoError(t, err) + + var resSlice []bfloat16.BFloat16 + err = f.Deserialize(data, &resSlice) + require.NoError(t, err) + require.Equal(t, slice, resSlice) +} From ddf67f0356347cbc04ce87ee007450b0ce888dac Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 16:57:51 -0600 Subject: [PATCH 04/11] feat: bfloat16 implementation --- go/fory/bfloat16/bfloat16.go | 75 ++++++++++++++++++++ go/fory/bfloat16/bfloat16_test.go | 112 ++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 go/fory/bfloat16/bfloat16.go create mode 100644 go/fory/bfloat16/bfloat16_test.go diff --git a/go/fory/bfloat16/bfloat16.go b/go/fory/bfloat16/bfloat16.go new file mode 100644 index 0000000000..035b66031b --- /dev/null +++ b/go/fory/bfloat16/bfloat16.go @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bfloat16 + +import ( + "fmt" + "math" +) + +// BFloat16 represents a brain floating point number (bfloat16). +// It is stored as a uint16. +type BFloat16 uint16 + +// BFloat16Array is a slice of BFloat16 values. +type BFloat16Array []BFloat16 + +// BFloat16FromBits returns the BFloat16 corresponding to the given bit pattern. +func BFloat16FromBits(b uint16) BFloat16 { + return BFloat16(b) +} + +// Bits returns the raw bit pattern of the floating point number. +func (f BFloat16) Bits() uint16 { + return uint16(f) +} + +// BFloat16FromFloat32 converts a float32 to a BFloat16. +// Rounds to nearest, ties to even. +func BFloat16FromFloat32(f float32) BFloat16 { + u := math.Float32bits(f) + + // NaN check + if (u&0x7F800000) == 0x7F800000 && (u&0x007FFFFF) != 0 { + return BFloat16(0x7FC0) // Canonical NaN + } + + // Fast path for rounding + // We want to add a rounding bias and then truncate. + // For ties-to-even: + // If LSB of result (bit 16) is 0: Rounding bias is 0x7FFF + // If LSB of result (bit 16) is 1: Rounding bias is 0x8000 + // lsb is (u >> 16) & 1. + // bias = 0x7FFF + lsb + + lsb := (u >> 16) & 1 + roundingBias := uint32(0x7FFF) + lsb + u += roundingBias + return BFloat16(u >> 16) +} + +// Float32 returns the float32 representation of the BFloat16. +func (f BFloat16) Float32() float32 { + // Just shift left by 16 bits + return math.Float32frombits(uint32(f) << 16) +} + +// String returns the string representation of f. 
+func (f BFloat16) String() string { + return fmt.Sprintf("%g", f.Float32()) +} diff --git a/go/fory/bfloat16/bfloat16_test.go b/go/fory/bfloat16/bfloat16_test.go new file mode 100644 index 0000000000..cb8b4b8fce --- /dev/null +++ b/go/fory/bfloat16/bfloat16_test.go @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bfloat16_test + +import ( + "math" + "testing" + + "github.com/apache/fory/go/fory/bfloat16" + "github.com/stretchr/testify/assert" +) + +func TestBFloat16_Conversion(t *testing.T) { + tests := []struct { + name string + f32 float32 + want uint16 // bits + check bool // if true, check exact bits + }{ + {"Zero", 0.0, 0x0000, true}, + {"NegZero", float32(math.Copysign(0, -1)), 0x8000, true}, + {"One", 1.0, 0x3F80, true}, + {"MinusOne", -1.0, 0xBF80, true}, + {"Inf", float32(math.Inf(1)), 0x7F80, true}, + {"NegInf", float32(math.Inf(-1)), 0xFF80, true}, + // 1.5 -> 0x3FC0. 
(0x3FC00000 is 1.5) + {"OnePointFive", 1.5, 0x3FC0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf16 := bfloat16.BFloat16FromFloat32(tt.f32) + if tt.check { + assert.Equal(t, tt.want, bf16.Bits(), "Bits match") + } + + // Round trip check + roundTrip := bf16.Float32() + if math.IsInf(float64(tt.f32), 0) { + assert.True(t, math.IsInf(float64(roundTrip), 0)) + assert.Equal(t, math.Signbit(float64(tt.f32)), math.Signbit(float64(roundTrip))) + } else if math.IsNaN(float64(tt.f32)) { + assert.True(t, math.IsNaN(float64(roundTrip))) + } else { + if tt.check { + assert.Equal(t, tt.f32, roundTrip, "Round trip value match") + } + } + }) + } +} + +func TestBFloat16_Rounding(t *testing.T) { + // 1 + 2^-8. 2^-8 is half of ULP for 1.0 in BFloat16? + // BFloat16 mantissa has 7 bits. 1.0 is 1.0000000 * 2^0. + // ULP is 2^-7. + // Half ULP is 2^-8. + // 1.0 + 2^-8 should round to even (1.0) because 1.0 mantissa is even (0). + // 1.0 + 2^-8 + epsilon => round up. + + // float32 bits for 1.0 is 0x3F800000. + // 1 + 2^-8 + // 2^-8 = 1/256. + // We set bit index (23-8) = 15. + // 0x3F800000 | (1<<15) = 0x3F808000. + val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8 + bf1 := bfloat16.BFloat16FromFloat32(val1) + assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)") + + // 1.0 + 3 * 2^-8. (1.0 + 1.5 ULP) -> Round up. + // 3 * 2^-8 = 1.5 * 2^-7. + // 11 at bits 15,14. + // 0x3F800000 | (1<<15) | (1<<14) = 0x3F80C000. + val2 := math.Float32frombits(0x3F80C000) + bf2 := bfloat16.BFloat16FromFloat32(val2) + assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up") + + // 1.0 + 2^-7 (Next representable). + // 0x3F800000 | (1<<16) = 0x3F810000. + val3 := math.Float32frombits(0x3F810000) + bf3 := bfloat16.BFloat16FromFloat32(val3) + assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact") + + // 1.0 + 2^-7 + 2^-8 -> 1.5 ULP relative to 1.0? No. + // 1.0 mant 0. + // 1.0 + 2^-7: mant 1 (at bit 16). + // Add 2^-8: bit 15 set. + // 0x3F818000. 
+ // LSB (bit 16) is 1 (odd). + // Guard (bit 15) is 1. + // Sticky is 0. + // Round to even -> Round up. (1 -> 2) + // Result 0x3F82. + val4 := math.Float32frombits(0x3F818000) + bf4 := bfloat16.BFloat16FromFloat32(val4) + assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)") +} From 7cd9b951c4112a197ea6421d5ecf7aa4e806d9ba Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 17:31:51 -0600 Subject: [PATCH 05/11] feat: add bfloat16 resolvers --- go/fory/type_resolver.go | 16 ++++++++++++++++ go/fory/type_test.go | 7 ++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index bd49a03e15..f9e10ffc36 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -29,6 +29,7 @@ import ( "time" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/apache/fory/go/fory/meta" ) @@ -77,6 +78,7 @@ var ( float32SliceType = reflect.TypeOf((*[]float32)(nil)).Elem() float64SliceType = reflect.TypeOf((*[]float64)(nil)).Elem() float16SliceType = reflect.TypeOf((*[]float16.Float16)(nil)).Elem() + bfloat16SliceType = reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem() interfaceSliceType = reflect.TypeOf((*[]any)(nil)).Elem() interfaceMapType = reflect.TypeOf((*map[any]any)(nil)).Elem() stringStringMapType = reflect.TypeOf((*map[string]string)(nil)).Elem() @@ -103,6 +105,7 @@ var ( float32Type = reflect.TypeOf((*float32)(nil)).Elem() float64Type = reflect.TypeOf((*float64)(nil)).Elem() float16Type = reflect.TypeOf((*float16.Float16)(nil)).Elem() + bfloat16Type = reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem() dateType = reflect.TypeOf((*Date)(nil)).Elem() timestampType = reflect.TypeOf((*time.Time)(nil)).Elem() genericSetType = reflect.TypeOf((*Set[any])(nil)).Elem() @@ -256,6 +259,7 @@ func newTypeResolver(fory *Fory) *TypeResolver { float32Type, float64Type, float16Type, + bfloat16Type, stringType, dateType, timestampType, @@ -413,6 
+417,7 @@ func (r *TypeResolver) initialize() { {float32SliceType, FLOAT32_ARRAY, float32SliceSerializer{}}, {float64SliceType, FLOAT64_ARRAY, float64SliceSerializer{}}, {float16SliceType, FLOAT16_ARRAY, float16SliceSerializer{}}, + {bfloat16SliceType, BFLOAT16_ARRAY, bfloat16SliceSerializer{}}, // Register common map types for fast path with optimized serializers {stringStringMapType, MAP, stringStringMapSerializer{}}, {stringInt64MapType, MAP, stringInt64MapSerializer{}}, @@ -437,6 +442,7 @@ func (r *TypeResolver) initialize() { {float32Type, FLOAT32, float32Serializer{}}, {float64Type, FLOAT64, float64Serializer{}}, {float16Type, FLOAT16, float16Serializer{}}, + {bfloat16Type, BFLOAT16, bfloat16Serializer{}}, {dateType, DATE, dateSerializer{}}, {timestampType, TIMESTAMP, timeSerializer{}}, {genericSetType, SET, setSerializer{}}, @@ -1704,6 +1710,12 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s case reflect.Uint8: // []byte uses byteSliceSerializer return byteSliceSerializer{}, nil + case reflect.Uint16: + // Check for fory.BFloat16 (aliased to uint16) + if elem == bfloat16Type { + return bfloat16SliceSerializer{}, nil + } + return uint16SliceSerializer{}, nil case reflect.String: return stringSliceSerializer{}, nil } @@ -1751,6 +1763,10 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s if elem.Name() == "Float16" && (elem.PkgPath() == "github.com/apache/fory/go/fory/float16" || strings.HasSuffix(elem.PkgPath(), "/float16")) { return float16ArraySerializer{arrayType: type_}, nil } + // Check for fory.BFloat16 (aliased to uint16) + if elem == bfloat16Type { + return bfloat16ArraySerializer{arrayType: type_}, nil + } return uint16ArraySerializer{arrayType: type_}, nil case reflect.Uint32: return uint32ArraySerializer{arrayType: type_}, nil diff --git a/go/fory/type_test.go b/go/fory/type_test.go index bbb890be57..001813be6a 100644 --- a/go/fory/type_test.go +++ b/go/fory/type_test.go @@ -18,9 
+18,11 @@ package fory import ( - "github.com/stretchr/testify/require" "reflect" "testing" + + "github.com/apache/fory/go/fory/bfloat16" + "github.com/stretchr/testify/require" ) func TestTypeResolver(t *testing.T) { @@ -39,6 +41,9 @@ func TestTypeResolver(t *testing.T) { {reflect.TypeOf((*int)(nil)), "*int"}, {reflect.TypeOf((*[10]int)(nil)), "*[10]int"}, {reflect.TypeOf((*[10]int)(nil)).Elem(), "[10]int"}, + {reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem(), "bfloat16.BFloat16"}, + {reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem(), "[]bfloat16.BFloat16"}, + {reflect.TypeOf((*[10]bfloat16.BFloat16)(nil)).Elem(), "[10]bfloat16.BFloat16"}, {reflect.TypeOf((*[]map[string][]map[string]*any)(nil)).Elem(), "[]map[string][]map[string]*interface {}"}, {reflect.TypeOf((*A)(nil)), "*@example.A"}, From de635dc6beb58c1a8af5be61eb3cc1b256778473 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Sun, 8 Feb 2026 17:45:42 -0600 Subject: [PATCH 06/11] feat: add bfloat16 type to compiler --- compiler/fory_compiler/generators/go.py | 3 ++ compiler/fory_compiler/ir/types.py | 2 ++ .../tests/test_generated_code.py | 30 +++++++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index d7784a0dfa..0fd9d4b0c1 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -190,6 +190,7 @@ def message_has_unions(self, message: Message) -> bool: PrimitiveKind.VAR_UINT64: "uint64", PrimitiveKind.TAGGED_UINT64: "uint64", PrimitiveKind.FLOAT16: "float16.Float16", + PrimitiveKind.BFLOAT16: "bfloat16.BFloat16", PrimitiveKind.FLOAT32: "float32", PrimitiveKind.FLOAT64: "float64", PrimitiveKind.STRING: "string", @@ -1090,6 +1091,8 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]): imports.add('"time"') elif field_type.kind == PrimitiveKind.FLOAT16: imports.add('float16 "github.com/apache/fory/go/fory/float16"') + elif field_type.kind == 
PrimitiveKind.BFLOAT16: + imports.add('bfloat16 "github.com/apache/fory/go/fory/bfloat16"') elif isinstance(field_type, ListType): self.collect_imports(field_type.element_type, imports) diff --git a/compiler/fory_compiler/ir/types.py b/compiler/fory_compiler/ir/types.py index 3dfc3d8edf..facc95ef67 100644 --- a/compiler/fory_compiler/ir/types.py +++ b/compiler/fory_compiler/ir/types.py @@ -39,6 +39,7 @@ class PrimitiveKind(PyEnum): VAR_UINT64 = "var_uint64" TAGGED_UINT64 = "tagged_uint64" FLOAT16 = "float16" + BFLOAT16 = "bfloat16" FLOAT32 = "float32" FLOAT64 = "float64" STRING = "string" @@ -67,6 +68,7 @@ class PrimitiveKind(PyEnum): "fixed_uint64": PrimitiveKind.UINT64, "tagged_uint64": PrimitiveKind.TAGGED_UINT64, "float16": PrimitiveKind.FLOAT16, + "bfloat16": PrimitiveKind.BFLOAT16, "float32": PrimitiveKind.FLOAT32, "float64": PrimitiveKind.FLOAT64, "string": PrimitiveKind.STRING, diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 7fda28f486..13dc99eab9 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -497,3 +497,33 @@ def test_generated_code_tree_ref_options_equivalent(): cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator)) assert "SharedWeak" in cpp_output + + +def test_go_bfloat16_generation(): + idl = dedent( + """ + package bfloat16_test; + + message BFloat16Message { + bfloat16 val = 1; + optional bfloat16 opt_val = 2; + list<bfloat16> list_val = 3; + } + """ + ) + schema = parse_fdl(idl) + files = generate_files(schema, GoGenerator) + + assert len(files) == 1 + content = list(files.values())[0] + + # Check imports + assert 'bfloat16 "github.com/apache/fory/go/fory/bfloat16"' in content + + # Check fields + assert '\tVal bfloat16.BFloat16 `fory:"id=1"`' in content + assert ( + '\tOptVal optional.Optional[bfloat16.BFloat16] `fory:"id=2,nullable"`' + in content + ) + assert "\tListVal 
[]bfloat16.BFloat16" in content From 9d3bb9f1dd5de4646d3b44d4b26e7c8e5a7088c5 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Mon, 9 Feb 2026 00:05:39 -0600 Subject: [PATCH 07/11] refactor: refine the comment for bfloat_test --- go/fory/bfloat16/bfloat16_test.go | 38 +++++++++---------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/go/fory/bfloat16/bfloat16_test.go b/go/fory/bfloat16/bfloat16_test.go index cb8b4b8fce..44745ae916 100644 --- a/go/fory/bfloat16/bfloat16_test.go +++ b/go/fory/bfloat16/bfloat16_test.go @@ -66,46 +66,30 @@ func TestBFloat16_Conversion(t *testing.T) { } func TestBFloat16_Rounding(t *testing.T) { - // 1 + 2^-8. 2^-8 is half of ULP for 1.0 in BFloat16? - // BFloat16 mantissa has 7 bits. 1.0 is 1.0000000 * 2^0. - // ULP is 2^-7. - // Half ULP is 2^-8. - // 1.0 + 2^-8 should round to even (1.0) because 1.0 mantissa is even (0). - // 1.0 + 2^-8 + epsilon => round up. + // BFloat16 has 7 bits of mantissa. For 1.0, ULP is 2^-7, and half ULP is 2^-8. + // Values are rounded to nearest even. 1.0 + 2^-8 should round to 1.0 (even mantissa). - // float32 bits for 1.0 is 0x3F800000. - // 1 + 2^-8 - // 2^-8 = 1/256. - // We set bit index (23-8) = 15. - // 0x3F800000 | (1<<15) = 0x3F808000. + // The float32 representation of 1.0 is 0x3F800000. + // Adding 2^-8 (1/256) means setting bit 15 (23-8). + // So, 1.0 + 2^-8 in float32 is 0x3F808000. val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8 bf1 := bfloat16.BFloat16FromFloat32(val1) assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)") - // 1.0 + 3 * 2^-8. (1.0 + 1.5 ULP) -> Round up. - // 3 * 2^-8 = 1.5 * 2^-7. - // 11 at bits 15,14. - // 0x3F800000 | (1<<15) | (1<<14) = 0x3F80C000. + // For 1.0 + 3 * 2^-8 (1.5 ULP), bits 15 and 14 are set, + // making the float32 representation 0x3F80C000. This rounds up. 
val2 := math.Float32frombits(0x3F80C000) bf2 := bfloat16.BFloat16FromFloat32(val2) assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up") - // 1.0 + 2^-7 (Next representable). - // 0x3F800000 | (1<<16) = 0x3F810000. + // 1.0 + 2^-7 is the next representable number after 1.0. In float32, this is 0x3F810000. val3 := math.Float32frombits(0x3F810000) bf3 := bfloat16.BFloat16FromFloat32(val3) assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact") - // 1.0 + 2^-7 + 2^-8 -> 1.5 ULP relative to 1.0? No. - // 1.0 mant 0. - // 1.0 + 2^-7: mant 1 (at bit 16). - // Add 2^-8: bit 15 set. - // 0x3F818000. - // LSB (bit 16) is 1 (odd). - // Guard (bit 15) is 1. - // Sticky is 0. - // Round to even -> Round up. (1 -> 2) - // Result 0x3F82. + // For 1.0 + 2^-7 + 2^-8 (0x3F818000), the LSB (bit 16) of 0x3F81 is 1 (odd), + // and the guard bit (bit 15) is 1. Rounding to nearest even means rounding up. + // Result: 0x3F82. val4 := math.Float32frombits(0x3F818000) bf4 := bfloat16.BFloat16FromFloat32(val4) assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)") From 1c6910c72bf19523beb622cc80dc3f11c9416ec7 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Tue, 10 Feb 2026 11:22:56 -0600 Subject: [PATCH 08/11] feat: make the order of bfloat and float consistent in type_resolver --- go/fory/type_resolver.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index b120b21357..19cd839978 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -1714,14 +1714,14 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s // []byte uses byteSliceSerializer return byteSliceSerializer{}, nil case reflect.Uint16: - // Check for fory.BFloat16 (aliased to uint16) - if elem == bfloat16Type { - return bfloat16SliceSerializer{}, nil - } // Check for fory.Float16 (aliased to uint16) if elem == float16Type { return float16SliceSerializer{}, nil } + // Check for fory.BFloat16 (aliased 
to uint16) + if elem == bfloat16Type { + return bfloat16SliceSerializer{}, nil + } return uint16SliceSerializer{}, nil case reflect.Uint32: return uint32SliceSerializer{}, nil From c619f2d3f6cfaf5322c477544711297a2c4b7b33 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Tue, 10 Feb 2026 12:38:45 -0600 Subject: [PATCH 09/11] feat: remove redundant BFloat16Array type --- go/fory/bfloat16/bfloat16.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/go/fory/bfloat16/bfloat16.go b/go/fory/bfloat16/bfloat16.go index 035b66031b..6b31ffdb90 100644 --- a/go/fory/bfloat16/bfloat16.go +++ b/go/fory/bfloat16/bfloat16.go @@ -26,9 +26,6 @@ import ( // It is stored as a uint16. type BFloat16 uint16 -// BFloat16Array is a slice of BFloat16 values. -type BFloat16Array []BFloat16 - // BFloat16FromBits returns the BFloat16 corresponding to the given bit pattern. func BFloat16FromBits(b uint16) BFloat16 { return BFloat16(b) From 0f2bd06054ad2a305ce274c511dc32323cb4bc64 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Wed, 11 Feb 2026 14:42:09 -0600 Subject: [PATCH 10/11] feat(go): add serializer tests for bfloat --- go/fory/type_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/fory/type_test.go b/go/fory/type_test.go index ddba169ad6..e8978e5930 100644 --- a/go/fory/type_test.go +++ b/go/fory/type_test.go @@ -88,6 +88,7 @@ func TestCreateSerializerSliceTypes(t *testing.T) { {reflect.TypeOf([]uint{}), reflect.TypeOf(uintSliceSerializer{})}, {reflect.TypeOf([]uint16{}), reflect.TypeOf(uint16SliceSerializer{})}, {reflect.TypeOf([]float16.Float16{}), reflect.TypeOf(float16SliceSerializer{})}, + {reflect.TypeOf([]bfloat16.BFloat16{}), reflect.TypeOf(bfloat16SliceSerializer{})}, {reflect.TypeOf([]uint32{}), reflect.TypeOf(uint32SliceSerializer{})}, {reflect.TypeOf([]uint64{}), reflect.TypeOf(uint64SliceSerializer{})}, {reflect.TypeOf([]string{}), reflect.TypeOf(stringSliceSerializer{})}, @@ -139,6 +140,7 @@ func TestCreateSerializerArrayTypes(t *testing.T) { 
{reflect.TypeOf([4]byte{}), reflect.TypeOf(uint8ArraySerializer{})}, {reflect.TypeOf([4]uint16{}), reflect.TypeOf(uint16ArraySerializer{})}, {reflect.TypeOf([4]float16.Float16{}), reflect.TypeOf(float16ArraySerializer{})}, + {reflect.TypeOf([4]bfloat16.BFloat16{}), reflect.TypeOf(bfloat16ArraySerializer{})}, {reflect.TypeOf([4]uint32{}), reflect.TypeOf(uint32ArraySerializer{})}, {reflect.TypeOf([4]uint64{}), reflect.TypeOf(uint64ArraySerializer{})}, } From eeeb6074f3eb205cf8e8e0d11a9b8fbfdb50e2a2 Mon Sep 17 00:00:00 2001 From: Brian Li Date: Wed, 11 Feb 2026 15:19:08 -0600 Subject: [PATCH 11/11] feat(go): add BFLOAT16 type check in skip --- go/fory/skip.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/fory/skip.go b/go/fory/skip.go index b0b36c5af7..34005ad74d 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -576,6 +576,8 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo _ = ctx.buffer.ReadVarint64(err) // Floating point types + case BFLOAT16, FLOAT16: + _ = ctx.buffer.ReadUint16(err) case FLOAT32: _ = ctx.buffer.ReadFloat32(err) case FLOAT64: @@ -610,7 +612,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo return } _ = ctx.buffer.ReadBinary(length, err) - case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY: + case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY: length := ctx.buffer.ReadLength(err) if ctx.HasError() { return