diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index d7784a0dfa..0fd9d4b0c1 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -190,6 +190,7 @@ def message_has_unions(self, message: Message) -> bool: PrimitiveKind.VAR_UINT64: "uint64", PrimitiveKind.TAGGED_UINT64: "uint64", PrimitiveKind.FLOAT16: "float16.Float16", + PrimitiveKind.BFLOAT16: "bfloat16.BFloat16", PrimitiveKind.FLOAT32: "float32", PrimitiveKind.FLOAT64: "float64", PrimitiveKind.STRING: "string", @@ -1090,6 +1091,8 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]): imports.add('"time"') elif field_type.kind == PrimitiveKind.FLOAT16: imports.add('float16 "github.com/apache/fory/go/fory/float16"') + elif field_type.kind == PrimitiveKind.BFLOAT16: + imports.add('bfloat16 "github.com/apache/fory/go/fory/bfloat16"') elif isinstance(field_type, ListType): self.collect_imports(field_type.element_type, imports) diff --git a/compiler/fory_compiler/ir/types.py b/compiler/fory_compiler/ir/types.py index 3dfc3d8edf..facc95ef67 100644 --- a/compiler/fory_compiler/ir/types.py +++ b/compiler/fory_compiler/ir/types.py @@ -39,6 +39,7 @@ class PrimitiveKind(PyEnum): VAR_UINT64 = "var_uint64" TAGGED_UINT64 = "tagged_uint64" FLOAT16 = "float16" + BFLOAT16 = "bfloat16" FLOAT32 = "float32" FLOAT64 = "float64" STRING = "string" @@ -67,6 +68,7 @@ class PrimitiveKind(PyEnum): "fixed_uint64": PrimitiveKind.UINT64, "tagged_uint64": PrimitiveKind.TAGGED_UINT64, "float16": PrimitiveKind.FLOAT16, + "bfloat16": PrimitiveKind.BFLOAT16, "float32": PrimitiveKind.FLOAT32, "float64": PrimitiveKind.FLOAT64, "string": PrimitiveKind.STRING, diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 7fda28f486..13dc99eab9 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -497,3 +497,33 @@ def test_generated_code_tree_ref_options_equivalent(): cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator)) assert "SharedWeak" in cpp_output + + +def test_go_bfloat16_generation(): + idl = dedent( + """ + package bfloat16_test; + + message BFloat16Message { + bfloat16 val = 1; + optional bfloat16 opt_val = 2; + list list_val = 3; + } + """ + ) + schema = parse_fdl(idl) + files = generate_files(schema, GoGenerator) + + assert len(files) == 1 + content = list(files.values())[0] + + # Check imports + assert 'bfloat16 "github.com/apache/fory/go/fory/bfloat16"' in content + + # Check fields + assert '\tVal bfloat16.BFloat16 `fory:"id=1"`' in content + assert ( + '\tOptVal optional.Optional[bfloat16.BFloat16] `fory:"id=2,nullable"`' + in content + ) + assert "\tListVal []bfloat16.BFloat16" in content diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go index 9777733590..27813060b7 100644 --- a/go/fory/array_primitive.go +++ b/go/fory/array_primitive.go @@ -21,6 +21,7 @@ import ( "reflect" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" ) @@ -871,3 +872,77 @@ func (s float16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// bfloat16ArraySerializer - optimized [N]bfloat16.BFloat16 serialization +// ============================================================================ + +type bfloat16ArraySerializer struct { + arrayType reflect.Type +} + +func (s bfloat16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) { + buf := ctx.Buffer() + length := value.Len() + size := length * 2 + buf.WriteLength(size) + if length > 0 { + if value.CanAddr() && isLittleEndian { + ptr := value.Addr().UnsafePointer() + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + for i := 0; i < length; i++ { + // We can't easily cast the whole array if not addressable/little-endian + // So we iterate. + val := value.Index(i).Interface().(bfloat16.BFloat16) + buf.WriteUint16(val.Bits()) + } + } + } +} + +func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + writeArrayRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY) + if ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + if length != value.Type().Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + return + } + + if length > 0 { + if isLittleEndian { + ptr := value.Addr().UnsafePointer() + raw := buf.ReadBinary(size, ctxErr) + copy(unsafe.Slice((*byte)(ptr), size), raw) + } else { + for i := 0; i < length; i++ { + value.Index(i).Set(reflect.ValueOf(bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)))) + } + } + } +} + +func (s bfloat16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done := readArrayRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/array_primitive_test.go b/go/fory/array_primitive_test.go index c2e684af4c..e8c99fb2f2 100644 --- a/go/fory/array_primitive_test.go +++ b/go/fory/array_primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -74,6 +75,21 @@ func TestPrimitiveArraySerializer(t *testing.T) { require.NoError(t, err) require.Equal(t, arr, result) }) + + t.Run("bfloat16_array", func(t *testing.T) { + arr := [3]bfloat16.BFloat16{ + bfloat16.BFloat16FromFloat32(1.0), + bfloat16.BFloat16FromFloat32(2.5), + bfloat16.BFloat16FromFloat32(-3.5), + } + data, err := f.Serialize(arr) + assert.NoError(t, err) + + var result [3]bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, arr, result) + }) } func TestArraySliceInteroperability(t *testing.T) { diff --git a/go/fory/bfloat16/bfloat16.go b/go/fory/bfloat16/bfloat16.go new file mode 100644 index 0000000000..6b31ffdb90 --- /dev/null +++ b/go/fory/bfloat16/bfloat16.go @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bfloat16 + +import ( + "fmt" + "math" +) + +// BFloat16 represents a brain floating point number (bfloat16). +// It is stored as a uint16. +type BFloat16 uint16 + +// BFloat16FromBits returns the BFloat16 corresponding to the given bit pattern. +func BFloat16FromBits(b uint16) BFloat16 { + return BFloat16(b) +} + +// Bits returns the raw bit pattern of the floating point number. +func (f BFloat16) Bits() uint16 { + return uint16(f) +} + +// BFloat16FromFloat32 converts a float32 to a BFloat16. +// Rounds to nearest, ties to even. +func BFloat16FromFloat32(f float32) BFloat16 { + u := math.Float32bits(f) + + // NaN check + if (u&0x7F800000) == 0x7F800000 && (u&0x007FFFFF) != 0 { + return BFloat16(0x7FC0) // Canonical NaN + } + + // Fast path for rounding + // We want to add a rounding bias and then truncate. + // For ties-to-even: + // If LSB of result (bit 16) is 0: Rounding bias is 0x7FFF + // If LSB of result (bit 16) is 1: Rounding bias is 0x8000 + // lsb is (u >> 16) & 1. + // bias = 0x7FFF + lsb + + lsb := (u >> 16) & 1 + roundingBias := uint32(0x7FFF) + lsb + u += roundingBias + return BFloat16(u >> 16) +} + +// Float32 returns the float32 representation of the BFloat16. +func (f BFloat16) Float32() float32 { + // Just shift left by 16 bits + return math.Float32frombits(uint32(f) << 16) +} + +// String returns the string representation of f. +func (f BFloat16) String() string { + return fmt.Sprintf("%g", f.Float32()) +} diff --git a/go/fory/bfloat16/bfloat16_test.go b/go/fory/bfloat16/bfloat16_test.go new file mode 100644 index 0000000000..44745ae916 --- /dev/null +++ b/go/fory/bfloat16/bfloat16_test.go @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package bfloat16_test + +import ( + "math" + "testing" + + "github.com/apache/fory/go/fory/bfloat16" + "github.com/stretchr/testify/assert" +) + +func TestBFloat16_Conversion(t *testing.T) { + tests := []struct { + name string + f32 float32 + want uint16 // bits + check bool // if true, check exact bits + }{ + {"Zero", 0.0, 0x0000, true}, + {"NegZero", float32(math.Copysign(0, -1)), 0x8000, true}, + {"One", 1.0, 0x3F80, true}, + {"MinusOne", -1.0, 0xBF80, true}, + {"Inf", float32(math.Inf(1)), 0x7F80, true}, + {"NegInf", float32(math.Inf(-1)), 0xFF80, true}, + // 1.5 -> 0x3FC0. (0x3FC00000 is 1.5) + {"OnePointFive", 1.5, 0x3FC0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf16 := bfloat16.BFloat16FromFloat32(tt.f32) + if tt.check { + assert.Equal(t, tt.want, bf16.Bits(), "Bits match") + } + + // Round trip check + roundTrip := bf16.Float32() + if math.IsInf(float64(tt.f32), 0) { + assert.True(t, math.IsInf(float64(roundTrip), 0)) + assert.Equal(t, math.Signbit(float64(tt.f32)), math.Signbit(float64(roundTrip))) + } else if math.IsNaN(float64(tt.f32)) { + assert.True(t, math.IsNaN(float64(roundTrip))) + } else { + if tt.check { + assert.Equal(t, tt.f32, roundTrip, "Round trip value match") + } + } + }) + } +} + +func TestBFloat16_Rounding(t *testing.T) { + // BFloat16 has 7 bits of mantissa. For 1.0, ULP is 2^-7, and half ULP is 2^-8. + // Values are rounded to nearest even. 1.0 + 2^-8 should round to 1.0 (even mantissa). + + // The float32 representation of 1.0 is 0x3F800000. + // Adding 2^-8 (1/256) means setting bit 15 (23-8). + // So, 1.0 + 2^-8 in float32 is 0x3F808000. + val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8 + bf1 := bfloat16.BFloat16FromFloat32(val1) + assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)") + + // For 1.0 + 3 * 2^-8 (1.5 ULP), bits 15 and 14 are set, + // making the float32 representation 0x3F80C000. This rounds up. + val2 := math.Float32frombits(0x3F80C000) + bf2 := bfloat16.BFloat16FromFloat32(val2) + assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up") + + // 1.0 + 2^-7 is the next representable number after 1.0. In float32, this is 0x3F810000. + val3 := math.Float32frombits(0x3F810000) + bf3 := bfloat16.BFloat16FromFloat32(val3) + assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact") + + // For 1.0 + 2^-7 + 2^-8 (0x3F818000), the LSB (bit 16) of 0x3F81 is 1 (odd), + // and the guard bit (bit 15) is 1. Rounding to nearest even means rounding up. + // Result: 0x3F82. + val4 := math.Float32frombits(0x3F818000) + bf4 := bfloat16.BFloat16FromFloat32(val4) + assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)") +} diff --git a/go/fory/primitive.go b/go/fory/primitive.go index 2c316d549e..998879042a 100644 --- a/go/fory/primitive.go +++ b/go/fory/primitive.go @@ -663,3 +663,56 @@ func (s float16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool func (s float16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { s.Read(ctx, refMode, false, false, value) } + +// ============================================================================ +// bfloat16Serializer - optimized bfloat16 serialization +// ============================================================================ + +// bfloat16Serializer handles bfloat16 type +type bfloat16Serializer struct{} + +var globalBFloat16Serializer = bfloat16Serializer{} + +func (s bfloat16Serializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Value is effectively uint16 (alias) + ctx.buffer.WriteUint16(uint16(value.Uint())) +} + +func (s bfloat16Serializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if refMode != RefModeNone { + ctx.buffer.WriteInt8(NotNullValueFlag) + } + if writeType { + ctx.buffer.WriteUint8(uint8(BFLOAT16)) + } + s.WriteData(ctx, value) +} + +func (s bfloat16Serializer) ReadData(ctx *ReadContext, value reflect.Value) { + err := ctx.Err() + bits := ctx.buffer.ReadUint16(err) + if ctx.HasError() { + return + } + value.SetUint(uint64(bits)) +} + +func (s bfloat16Serializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + err := ctx.Err() + if refMode != RefModeNone { + if ctx.buffer.ReadInt8(err) == NullFlag { + return + } + } + if readType { + _ = ctx.buffer.ReadUint8(err) + } + if ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16Serializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} diff --git a/go/fory/primitive_test.go b/go/fory/primitive_test.go index 978a81d46b..0f478590e0 100644 --- a/go/fory/primitive_test.go +++ b/go/fory/primitive_test.go @@ -20,6 +20,7 @@ package fory import ( "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/require" ) @@ -56,3 +57,36 @@ func TestFloat16PrimitiveSliceDirect(t *testing.T) { require.NoError(t, err) require.Equal(t, slice, resSlice) } + +func TestBFloat16Primitive(t *testing.T) { + f := New(WithXlang(true)) + bf16 := bfloat16.BFloat16FromFloat32(3.14) + + // Directly serialize a bfloat16 value + data, err := f.Serialize(bf16) + require.NoError(t, err) + + var res bfloat16.BFloat16 + err = f.Deserialize(data, &res) + require.NoError(t, err) + + require.Equal(t, bf16.Bits(), res.Bits()) + + // Value check (approximate because BF16 precision is low) + require.InDelta(t, 3.14, res.Float32(), 0.1) +} + +func TestBFloat16PrimitiveSliceDirect(t *testing.T) { + // Tests serializing a slice as a root object + f := New(WithXlang(true)) + bf16 := bfloat16.BFloat16FromFloat32(3.14) + + slice := []bfloat16.BFloat16{bf16, bfloat16.BFloat16(0)} + data, err := f.Serialize(slice) + require.NoError(t, err) + + var resSlice []bfloat16.BFloat16 + err = f.Deserialize(data, &resSlice) + require.NoError(t, err) + require.Equal(t, slice, resSlice) +} diff --git a/go/fory/skip.go b/go/fory/skip.go index b0b36c5af7..34005ad74d 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -576,6 +576,8 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo _ = ctx.buffer.ReadVarint64(err) // Floating point types + case BFLOAT16, FLOAT16: + _ = ctx.buffer.ReadUint16(err) case FLOAT32: _ = ctx.buffer.ReadFloat32(err) case FLOAT64: @@ -610,7 +612,7 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo return } _ = ctx.buffer.ReadBinary(length, err) - case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY: + case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY: length := ctx.buffer.ReadLength(err) if ctx.HasError() { return diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 200d144c63..e4daf990be 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -22,6 +22,7 @@ import ( "strconv" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" ) @@ -1330,3 +1331,81 @@ func ReadStringSlice(buf *ByteBuffer, err *Error) []string { } return result } + +// ============================================================================ +// bfloat16SliceSerializer - optimized []bfloat16.BFloat16 serialization +// ============================================================================ + +type bfloat16SliceSerializer struct{} + +func (s bfloat16SliceSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + // Cast to []bfloat16.BFloat16 + v := value.Interface().([]bfloat16.BFloat16) + buf := ctx.Buffer() + length := len(v) + size := length * 2 + buf.WriteLength(size) + if length > 0 { + ptr := unsafe.Pointer(&v[0]) + if isLittleEndian { + buf.WriteBinary(unsafe.Slice((*byte)(ptr), size)) + } else { + for i := 0; i < length; i++ { + buf.WriteUint16(v[i].Bits()) + } + } + } +} + +func (s bfloat16SliceSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + done := writeSliceRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY) + if done || ctx.HasError() { + return + } + s.WriteData(ctx, value) +} + +func (s bfloat16SliceSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + done, typeId := readSliceRefAndType(ctx, refMode, readType, value) + if done || ctx.HasError() { + return + } + if readType && typeId != uint32(BFLOAT16_ARRAY) { + ctx.SetError(DeserializationErrorf("slice type mismatch: expected BFLOAT16_ARRAY (%d), got %d", BFLOAT16_ARRAY, typeId)) + return + } + s.ReadData(ctx, value) +} + +func (s bfloat16SliceSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + +func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + buf := ctx.Buffer() + ctxErr := ctx.Err() + size := buf.ReadLength(ctxErr) + length := size / 2 + if ctx.HasError() { + return + } + + ptr := (*[]bfloat16.BFloat16)(value.Addr().UnsafePointer()) + if length == 0 { + *ptr = make([]bfloat16.BFloat16, 0) + return + } + + result := make([]bfloat16.BFloat16, length) + + if isLittleEndian { + raw := buf.ReadBinary(size, ctxErr) + targetPtr := unsafe.Pointer(&result[0]) + copy(unsafe.Slice((*byte)(targetPtr), size), raw) + } else { + for i := 0; i < length; i++ { + result[i] = bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)) + } + } + *ptr = result +} diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go index 82959d0daf..e03a4cb00c 100644 --- a/go/fory/slice_primitive_test.go +++ b/go/fory/slice_primitive_test.go @@ -21,6 +21,7 @@ import ( "math" "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/assert" ) @@ -67,6 +68,48 @@ func TestFloat16Slice(t *testing.T) { }) } +func TestBFloat16Slice(t *testing.T) { + f := NewFory() + + t.Run("bfloat16_slice", func(t *testing.T) { + slice := []bfloat16.BFloat16{ + bfloat16.BFloat16FromFloat32(1.0), + bfloat16.BFloat16FromFloat32(2.5), + bfloat16.BFloat16FromFloat32(-3.5), + } + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Equal(t, slice, result) + }) + + t.Run("bfloat16_slice_empty", func(t *testing.T) { + slice := []bfloat16.BFloat16{} + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Empty(t, result) + }) + + t.Run("bfloat16_slice_nil", func(t *testing.T) { + var slice []bfloat16.BFloat16 = nil + data, err := f.Serialize(slice) + assert.NoError(t, err) + + var result []bfloat16.BFloat16 + err = f.Deserialize(data, &result) + assert.NoError(t, err) + assert.Nil(t, result) + }) +} + func TestIntSlice(t *testing.T) { f := NewFory() diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index b10a134d80..19cd839978 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -29,6 +29,7 @@ import ( "time" "unsafe" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/apache/fory/go/fory/meta" ) @@ -77,6 +78,7 @@ var ( float32SliceType = reflect.TypeOf((*[]float32)(nil)).Elem() float64SliceType = reflect.TypeOf((*[]float64)(nil)).Elem() float16SliceType = reflect.TypeOf((*[]float16.Float16)(nil)).Elem() + bfloat16SliceType = reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem() interfaceSliceType = reflect.TypeOf((*[]any)(nil)).Elem() interfaceMapType = reflect.TypeOf((*map[any]any)(nil)).Elem() stringStringMapType = reflect.TypeOf((*map[string]string)(nil)).Elem() @@ -103,6 +105,7 @@ var ( float32Type = reflect.TypeOf((*float32)(nil)).Elem() float64Type = reflect.TypeOf((*float64)(nil)).Elem() float16Type = reflect.TypeOf((*float16.Float16)(nil)).Elem() + bfloat16Type = reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem() dateType = reflect.TypeOf((*Date)(nil)).Elem() timestampType = reflect.TypeOf((*time.Time)(nil)).Elem() genericSetType = reflect.TypeOf((*Set[any])(nil)).Elem() @@ -259,6 +262,7 @@ func newTypeResolver(fory *Fory) *TypeResolver { float32Type, float64Type, float16Type, + bfloat16Type, stringType, dateType, timestampType, @@ -416,6 +420,7 @@ func (r *TypeResolver) initialize() { {float32SliceType, FLOAT32_ARRAY, float32SliceSerializer{}}, {float64SliceType, FLOAT64_ARRAY, float64SliceSerializer{}}, {float16SliceType, FLOAT16_ARRAY, float16SliceSerializer{}}, + {bfloat16SliceType, BFLOAT16_ARRAY, bfloat16SliceSerializer{}}, // Register common map types for fast path with optimized serializers {stringStringMapType, MAP, stringStringMapSerializer{}}, {stringInt64MapType, MAP, stringInt64MapSerializer{}}, @@ -440,6 +445,7 @@ func (r *TypeResolver) initialize() { {float32Type, FLOAT32, float32Serializer{}}, {float64Type, FLOAT64, float64Serializer{}}, {float16Type, FLOAT16, float16Serializer{}}, + {bfloat16Type, BFLOAT16, bfloat16Serializer{}}, {dateType, DATE, dateSerializer{}}, {timestampType, TIMESTAMP, timeSerializer{}}, {genericSetType, SET, setSerializer{}}, @@ -1712,6 +1718,10 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s if elem == float16Type { return float16SliceSerializer{}, nil } + // Check for fory.BFloat16 (aliased to uint16) + if elem == bfloat16Type { + return bfloat16SliceSerializer{}, nil + } return uint16SliceSerializer{}, nil case reflect.Uint32: return uint32SliceSerializer{}, nil @@ -1763,6 +1773,10 @@ func (r *TypeResolver) createSerializer(type_ reflect.Type, mapInStruct bool) (s if elem == float16Type { return float16ArraySerializer{arrayType: type_}, nil } + // Check for fory.BFloat16 (aliased to uint16) + if elem == bfloat16Type { + return bfloat16ArraySerializer{arrayType: type_}, nil + } return uint16ArraySerializer{arrayType: type_}, nil case reflect.Uint32: return uint32ArraySerializer{arrayType: type_}, nil diff --git a/go/fory/type_test.go b/go/fory/type_test.go index 5b97669309..e8978e5930 100644 --- a/go/fory/type_test.go +++ b/go/fory/type_test.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/stretchr/testify/require" ) @@ -41,6 +42,9 @@ func TestTypeResolver(t *testing.T) { {reflect.TypeOf((*int)(nil)), "*int"}, {reflect.TypeOf((*[10]int)(nil)), "*[10]int"}, {reflect.TypeOf((*[10]int)(nil)).Elem(), "[10]int"}, + {reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem(), "bfloat16.BFloat16"}, + {reflect.TypeOf((*[]bfloat16.BFloat16)(nil)).Elem(), "[]bfloat16.BFloat16"}, + {reflect.TypeOf((*[10]bfloat16.BFloat16)(nil)).Elem(), "[10]bfloat16.BFloat16"}, {reflect.TypeOf((*[]map[string][]map[string]*any)(nil)).Elem(), "[]map[string][]map[string]*interface {}"}, {reflect.TypeOf((*A)(nil)), "*@example.A"}, @@ -84,6 +88,7 @@ func TestCreateSerializerSliceTypes(t *testing.T) { {reflect.TypeOf([]uint{}), reflect.TypeOf(uintSliceSerializer{})}, {reflect.TypeOf([]uint16{}), reflect.TypeOf(uint16SliceSerializer{})}, {reflect.TypeOf([]float16.Float16{}), reflect.TypeOf(float16SliceSerializer{})}, + {reflect.TypeOf([]bfloat16.BFloat16{}), reflect.TypeOf(bfloat16SliceSerializer{})}, {reflect.TypeOf([]uint32{}), reflect.TypeOf(uint32SliceSerializer{})}, {reflect.TypeOf([]uint64{}), reflect.TypeOf(uint64SliceSerializer{})}, {reflect.TypeOf([]string{}), reflect.TypeOf(stringSliceSerializer{})}, @@ -135,6 +140,7 @@ func TestCreateSerializerArrayTypes(t *testing.T) { {reflect.TypeOf([4]byte{}), reflect.TypeOf(uint8ArraySerializer{})}, {reflect.TypeOf([4]uint16{}), reflect.TypeOf(uint16ArraySerializer{})}, {reflect.TypeOf([4]float16.Float16{}), reflect.TypeOf(float16ArraySerializer{})}, + {reflect.TypeOf([4]bfloat16.BFloat16{}), reflect.TypeOf(bfloat16ArraySerializer{})}, {reflect.TypeOf([4]uint32{}), reflect.TypeOf(uint32ArraySerializer{})}, {reflect.TypeOf([4]uint64{}), reflect.TypeOf(uint64ArraySerializer{})}, }