Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/fory_compiler/generators/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def message_has_unions(self, message: Message) -> bool:
PrimitiveKind.VAR_UINT64: "uint64",
PrimitiveKind.TAGGED_UINT64: "uint64",
PrimitiveKind.FLOAT16: "float16.Float16",
PrimitiveKind.BFLOAT16: "bfloat16.BFloat16",
PrimitiveKind.FLOAT32: "float32",
PrimitiveKind.FLOAT64: "float64",
PrimitiveKind.STRING: "string",
Expand Down Expand Up @@ -1090,6 +1091,8 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]):
imports.add('"time"')
elif field_type.kind == PrimitiveKind.FLOAT16:
imports.add('float16 "github.com/apache/fory/go/fory/float16"')
elif field_type.kind == PrimitiveKind.BFLOAT16:
imports.add('bfloat16 "github.com/apache/fory/go/fory/bfloat16"')

elif isinstance(field_type, ListType):
self.collect_imports(field_type.element_type, imports)
Expand Down
2 changes: 2 additions & 0 deletions compiler/fory_compiler/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class PrimitiveKind(PyEnum):
VAR_UINT64 = "var_uint64"
TAGGED_UINT64 = "tagged_uint64"
FLOAT16 = "float16"
BFLOAT16 = "bfloat16"
FLOAT32 = "float32"
FLOAT64 = "float64"
STRING = "string"
Expand Down Expand Up @@ -67,6 +68,7 @@ class PrimitiveKind(PyEnum):
"fixed_uint64": PrimitiveKind.UINT64,
"tagged_uint64": PrimitiveKind.TAGGED_UINT64,
"float16": PrimitiveKind.FLOAT16,
"bfloat16": PrimitiveKind.BFLOAT16,
"float32": PrimitiveKind.FLOAT32,
"float64": PrimitiveKind.FLOAT64,
"string": PrimitiveKind.STRING,
Expand Down
30 changes: 30 additions & 0 deletions compiler/fory_compiler/tests/test_generated_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,33 @@ def test_generated_code_tree_ref_options_equivalent():

cpp_output = render_files(generate_files(schemas["fdl"], CppGenerator))
assert "SharedWeak<TreeNode>" in cpp_output


def test_go_bfloat16_generation():
idl = dedent(
"""
package bfloat16_test;

message BFloat16Message {
bfloat16 val = 1;
optional bfloat16 opt_val = 2;
list<bfloat16> list_val = 3;
}
"""
)
schema = parse_fdl(idl)
files = generate_files(schema, GoGenerator)

assert len(files) == 1
content = list(files.values())[0]

# Check imports
assert 'bfloat16 "github.com/apache/fory/go/fory/bfloat16"' in content

# Check fields
assert '\tVal bfloat16.BFloat16 `fory:"id=1"`' in content
assert (
'\tOptVal optional.Optional[bfloat16.BFloat16] `fory:"id=2,nullable"`'
in content
)
assert "\tListVal []bfloat16.BFloat16" in content
75 changes: 75 additions & 0 deletions go/fory/array_primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"reflect"
"unsafe"

"github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
)

Expand Down Expand Up @@ -871,3 +872,77 @@ func (s float16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType
func (s float16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}

// ============================================================================
// bfloat16ArraySerializer - optimized [N]bfloat16.BFloat16 serialization
// ============================================================================

type bfloat16ArraySerializer struct {
arrayType reflect.Type
}

func (s bfloat16ArraySerializer) WriteData(ctx *WriteContext, value reflect.Value) {
buf := ctx.Buffer()
length := value.Len()
size := length * 2
buf.WriteLength(size)
if length > 0 {
if value.CanAddr() && isLittleEndian {
ptr := value.Addr().UnsafePointer()
buf.WriteBinary(unsafe.Slice((*byte)(ptr), size))
} else {
for i := 0; i < length; i++ {
// We can't easily cast the whole array if not addressable/little-endian
// So we iterate.
val := value.Index(i).Interface().(bfloat16.BFloat16)
buf.WriteUint16(val.Bits())
}
}
}
}

func (s bfloat16ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) {
writeArrayRefAndType(ctx, refMode, writeType, value, BFLOAT16_ARRAY)
if ctx.HasError() {
return
}
s.WriteData(ctx, value)
}

func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) {
buf := ctx.Buffer()
ctxErr := ctx.Err()
size := buf.ReadLength(ctxErr)
length := size / 2
if ctx.HasError() {
return
}
if length != value.Type().Len() {
ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type()))
return
}

if length > 0 {
if isLittleEndian {
ptr := value.Addr().UnsafePointer()
raw := buf.ReadBinary(size, ctxErr)
copy(unsafe.Slice((*byte)(ptr), size), raw)
} else {
for i := 0; i < length; i++ {
value.Index(i).Set(reflect.ValueOf(bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr))))
}
}
}
}

func (s bfloat16ArraySerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) {
done := readArrayRefAndType(ctx, refMode, readType, value)
if done || ctx.HasError() {
return
}
s.ReadData(ctx, value)
}

func (s bfloat16ArraySerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) {
s.Read(ctx, refMode, false, false, value)
}
16 changes: 16 additions & 0 deletions go/fory/array_primitive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package fory
import (
"testing"

"github.com/apache/fory/go/fory/bfloat16"
"github.com/apache/fory/go/fory/float16"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -74,6 +75,21 @@ func TestPrimitiveArraySerializer(t *testing.T) {
require.NoError(t, err)
require.Equal(t, arr, result)
})

t.Run("bfloat16_array", func(t *testing.T) {
arr := [3]bfloat16.BFloat16{
bfloat16.BFloat16FromFloat32(1.0),
bfloat16.BFloat16FromFloat32(2.5),
bfloat16.BFloat16FromFloat32(-3.5),
}
data, err := f.Serialize(arr)
assert.NoError(t, err)

var result [3]bfloat16.BFloat16
err = f.Deserialize(data, &result)
assert.NoError(t, err)
assert.Equal(t, arr, result)
})
}

func TestArraySliceInteroperability(t *testing.T) {
Expand Down
72 changes: 72 additions & 0 deletions go/fory/bfloat16/bfloat16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package bfloat16

import (
"fmt"
"math"
)

// BFloat16 represents a brain floating point number (bfloat16).
// It is stored as a uint16.
type BFloat16 uint16

// BFloat16FromBits returns the BFloat16 corresponding to the given bit pattern.
func BFloat16FromBits(b uint16) BFloat16 {
return BFloat16(b)
}

// Bits returns the raw bit pattern of the floating point number.
func (f BFloat16) Bits() uint16 {
return uint16(f)
}

// BFloat16FromFloat32 converts a float32 to a BFloat16.
// Rounds to nearest, ties to even.
func BFloat16FromFloat32(f float32) BFloat16 {
u := math.Float32bits(f)

// NaN check
if (u&0x7F800000) == 0x7F800000 && (u&0x007FFFFF) != 0 {
return BFloat16(0x7FC0) // Canonical NaN
}

// Fast path for rounding
// We want to add a rounding bias and then truncate.
// For ties-to-even:
// If LSB of result (bit 16) is 0: Rounding bias is 0x7FFF
// If LSB of result (bit 16) is 1: Rounding bias is 0x8000
// lsb is (u >> 16) & 1.
// bias = 0x7FFF + lsb

lsb := (u >> 16) & 1
roundingBias := uint32(0x7FFF) + lsb
u += roundingBias
return BFloat16(u >> 16)
}

// Float32 returns the float32 representation of the BFloat16.
func (f BFloat16) Float32() float32 {
// Just shift left by 16 bits
return math.Float32frombits(uint32(f) << 16)
}

// String returns the string representation of f.
func (f BFloat16) String() string {
return fmt.Sprintf("%g", f.Float32())
}
96 changes: 96 additions & 0 deletions go/fory/bfloat16/bfloat16_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package bfloat16_test

import (
"math"
"testing"

"github.com/apache/fory/go/fory/bfloat16"
"github.com/stretchr/testify/assert"
)

func TestBFloat16_Conversion(t *testing.T) {
tests := []struct {
name string
f32 float32
want uint16 // bits
check bool // if true, check exact bits
}{
{"Zero", 0.0, 0x0000, true},
{"NegZero", float32(math.Copysign(0, -1)), 0x8000, true},
{"One", 1.0, 0x3F80, true},
{"MinusOne", -1.0, 0xBF80, true},
{"Inf", float32(math.Inf(1)), 0x7F80, true},
{"NegInf", float32(math.Inf(-1)), 0xFF80, true},
// 1.5 -> 0x3FC0. (0x3FC00000 is 1.5)
{"OnePointFive", 1.5, 0x3FC0, true},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
bf16 := bfloat16.BFloat16FromFloat32(tt.f32)
if tt.check {
assert.Equal(t, tt.want, bf16.Bits(), "Bits match")
}

// Round trip check
roundTrip := bf16.Float32()
if math.IsInf(float64(tt.f32), 0) {
assert.True(t, math.IsInf(float64(roundTrip), 0))
assert.Equal(t, math.Signbit(float64(tt.f32)), math.Signbit(float64(roundTrip)))
} else if math.IsNaN(float64(tt.f32)) {
assert.True(t, math.IsNaN(float64(roundTrip)))
} else {
if tt.check {
assert.Equal(t, tt.f32, roundTrip, "Round trip value match")
}
}
})
}
}

func TestBFloat16_Rounding(t *testing.T) {
// BFloat16 has 7 bits of mantissa. For 1.0, ULP is 2^-7, and half ULP is 2^-8.
// Values are rounded to nearest even. 1.0 + 2^-8 should round to 1.0 (even mantissa).

// The float32 representation of 1.0 is 0x3F800000.
// Adding 2^-8 (1/256) means setting bit 15 (23-8).
// So, 1.0 + 2^-8 in float32 is 0x3F808000.
val1 := math.Float32frombits(0x3F808000) // 1.0 + 2^-8
bf1 := bfloat16.BFloat16FromFloat32(val1)
assert.Equal(t, uint16(0x3F80), bf1.Bits(), "Round to even (down)")

// For 1.0 + 3 * 2^-8 (1.5 ULP), bits 15 and 14 are set,
// making the float32 representation 0x3F80C000. This rounds up.
val2 := math.Float32frombits(0x3F80C000)
bf2 := bfloat16.BFloat16FromFloat32(val2)
assert.Equal(t, uint16(0x3F81), bf2.Bits(), "Round up")

// 1.0 + 2^-7 is the next representable number after 1.0. In float32, this is 0x3F810000.
val3 := math.Float32frombits(0x3F810000)
bf3 := bfloat16.BFloat16FromFloat32(val3)
assert.Equal(t, uint16(0x3F81), bf3.Bits(), "Exact")

// For 1.0 + 2^-7 + 2^-8 (0x3F818000), the LSB (bit 16) of 0x3F81 is 1 (odd),
// and the guard bit (bit 15) is 1. Rounding to nearest even means rounding up.
// Result: 0x3F82.
val4 := math.Float32frombits(0x3F818000)
bf4 := bfloat16.BFloat16FromFloat32(val4)
assert.Equal(t, uint16(0x3F82), bf4.Bits(), "Round to even (up)")
}
Loading
Loading