diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index a6d85b3e80..29f52c3011 100644 --- a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -226,7 +226,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]: PrimitiveKind.UINT64: "long", PrimitiveKind.VAR_UINT64: "long", PrimitiveKind.TAGGED_UINT64: "long", - PrimitiveKind.FLOAT16: "float", + PrimitiveKind.FLOAT16: "Float16", PrimitiveKind.FLOAT32: "float", PrimitiveKind.FLOAT64: "double", PrimitiveKind.STRING: "String", @@ -255,7 +255,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]: PrimitiveKind.UINT64: "Long", PrimitiveKind.VAR_UINT64: "Long", PrimitiveKind.TAGGED_UINT64: "Long", - PrimitiveKind.FLOAT16: "Float", + PrimitiveKind.FLOAT16: "Float16", PrimitiveKind.FLOAT32: "Float", PrimitiveKind.FLOAT64: "Double", PrimitiveKind.ANY: "Object", @@ -278,7 +278,7 @@ def generate_bytes_methods(self, class_name: str) -> List[str]: PrimitiveKind.UINT64: "long[]", PrimitiveKind.VAR_UINT64: "long[]", PrimitiveKind.TAGGED_UINT64: "long[]", - PrimitiveKind.FLOAT16: "float[]", + PrimitiveKind.FLOAT16: "Float16[]", PrimitiveKind.FLOAT32: "float[]", PrimitiveKind.FLOAT64: "double[]", } @@ -1165,6 +1165,8 @@ def collect_type_imports( imports.add("java.time.LocalDate") elif field_type.kind == PrimitiveKind.TIMESTAMP: imports.add("java.time.Instant") + elif field_type.kind == PrimitiveKind.FLOAT16: + imports.add("org.apache.fory.type.Float16") elif isinstance(field_type, ListType): # Primitive arrays don't need List import @@ -1379,7 +1381,11 @@ def generate_equals_method(self, message: Message) -> List[str]: ) elif isinstance(field.field_type, PrimitiveType): kind = field.field_type.kind - if kind in (PrimitiveKind.FLOAT32,): + if kind in (PrimitiveKind.FLOAT16,): + comparisons.append( + f"{field_name}.equalsValue(that.{field_name})" + ) + elif kind in (PrimitiveKind.FLOAT32,): comparisons.append( f"Float.compare({field_name}, that.{field_name}) == 0" ) diff --git a/java/fory-core/src/main/java/org/apache/fory/collection/Float16List.java b/java/fory-core/src/main/java/org/apache/fory/collection/Float16List.java new file mode 100644 index 0000000000..bf9c75a24f --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/collection/Float16List.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.collection; + +import java.util.AbstractList; +import java.util.Arrays; +import java.util.Objects; +import java.util.RandomAccess; +import org.apache.fory.type.Float16; + +public final class Float16List extends AbstractList implements RandomAccess { + private static final int DEFAULT_CAPACITY = 10; + + private short[] array; + private int size; + + public Float16List() { + this(DEFAULT_CAPACITY); + } + + public Float16List(int initialCapacity) { + if (initialCapacity < 0) { + throw new IllegalArgumentException("Illegal capacity: " + initialCapacity); + } + this.array = new short[initialCapacity]; + this.size = 0; + } + + public Float16List(short[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public Float16 get(int index) { + checkIndex(index); + return Float16.fromBits(array[index]); + } + + @Override + public int size() { + return size; + } + + @Override + public Float16 set(int index, Float16 element) { + checkIndex(index); + Objects.requireNonNull(element, "element"); + short prev = array[index]; + array[index] = element.toBits(); + return Float16.fromBits(prev); + } + + public void set(int index, short bits) { + checkIndex(index); + array[index] = bits; + } + + public void set(int index, float value) { + checkIndex(index); + array[index] = Float16.valueOf(value).toBits(); + } + + @Override + public void add(int index, Float16 element) { + checkPositionIndex(index); + ensureCapacity(size + 1); + System.arraycopy(array, index, array, index + 1, size - index); + array[index] = element.toBits(); + size++; + modCount++; + } + + @Override + public boolean add(Float16 element) { + Objects.requireNonNull(element, "element"); + ensureCapacity(size + 1); + array[size++] = element.toBits(); + modCount++; + return true; + } + + public boolean add(short bits) { + ensureCapacity(size + 1); + array[size++] = bits; + modCount++; + return true; + } + + public boolean add(float value) { + ensureCapacity(size + 1); + array[size++] = Float16.valueOf(value).toBits(); + modCount++; + return true; + } + + public float getFloat(int index) { + checkIndex(index); + return Float16.fromBits(array[index]).toFloat(); + } + + public short getShort(int index) { + checkIndex(index); + return array[index]; + } + + public boolean hasArray() { + return array != null; + } + + public short[] getArray() { + return array; + } + + public short[] copyArray() { + return Arrays.copyOf(array, size); + } + + private void ensureCapacity(int minCapacity) { + if (array.length >= minCapacity) { + return; + } + int newCapacity = array.length + (array.length >> 1) + 1; + if (newCapacity < minCapacity) { + newCapacity = minCapacity; + } + array = Arrays.copyOf(array, newCapacity); + } + + private void checkIndex(int index) { + if (index < 0 || index >= size) { + throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size); + } + } + + private void checkPositionIndex(int index) { + if (index < 0 || index > size) { + throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size); + } + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index db9ba2b503..83aa5399c8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -33,6 +33,7 @@ import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.collection.CollectionFlags; import org.apache.fory.serializer.collection.ForyArrayAsListSerializer; +import org.apache.fory.type.Float16; import org.apache.fory.type.GenericType; import org.apache.fory.type.TypeUtils; import org.apache.fory.util.Preconditions; @@ -881,6 +882,50 @@ private void readFloat64BySwapEndian(MemoryBuffer buffer, double[] values, int n } } + public static final class Float16ArraySerializer extends PrimitiveArraySerializer { + + public Float16ArraySerializer(Fory fory) { + super(fory, Float16[].class); + } + + @Override + public void write(MemoryBuffer buffer, Float16[] value) { + int length = value.length; + int size = length * 2; + buffer.writeVarUint32Small7(size); + + int writerIndex = buffer.writerIndex(); + buffer.ensure(writerIndex + size); + + for (int i = 0; i < length; i++) { + buffer._unsafePutInt16(writerIndex + i * 2, value[i].toBits()); + } + buffer._unsafeWriterIndex(writerIndex + size); + } + + @Override + public Float16[] copy(Float16[] originArray) { + return Arrays.copyOf(originArray, originArray.length); + } + + @Override + public Float16[] read(MemoryBuffer buffer) { + int size = buffer.readVarUint32Small7(); + int numElements = size / 2; + Float16[] values = new Float16[numElements]; + + int readerIndex = buffer.readerIndex(); + buffer.checkReadableBytes(size); + + for (int i = 0; i < numElements; i++) { + values[i] = Float16.fromBits(buffer._unsafeGetInt16(readerIndex + i * 2)); + } + buffer._increaseReaderIndexUnsafe(size); + + return values; + } + } + public static final class StringArraySerializer extends Serializer { private final StringSerializer stringSerializer; private final ForyArrayAsListSerializer collectionSerializer; @@ -1008,6 +1053,7 @@ public static void registerDefaultSerializers(Fory fory) { resolver.registerInternalSerializer(double[].class, new DoubleArraySerializer(fory)); resolver.registerInternalSerializer( Double[].class, new ObjectArraySerializer<>(fory, Double[].class)); + resolver.registerInternalSerializer(Float16[].class, new Float16ArraySerializer(fory)); resolver.registerInternalSerializer(boolean[].class, new BooleanArraySerializer(fory)); resolver.registerInternalSerializer( Boolean[].class, new ObjectArraySerializer<>(fory, Boolean[].class)); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/Float16Serializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/Float16Serializer.java new file mode 100644 index 0000000000..cf34806ad6 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/Float16Serializer.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import org.apache.fory.Fory; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.type.Float16; + +public final class Float16Serializer extends ImmutableSerializer { + + public Float16Serializer(Fory fory) { + super(fory, Float16.class); + } + + @Override + public void write(MemoryBuffer buffer, Float16 value) { + buffer.writeInt16(value.toBits()); + } + + @Override + public Float16 read(MemoryBuffer buffer) { + return Float16.fromBits(buffer.readInt16()); + } + + @Override + public void xwrite(MemoryBuffer buffer, Float16 value) { + buffer.writeInt16(value.toBits()); + } + + @Override + public Float16 xread(MemoryBuffer buffer) { + return Float16.fromBits(buffer.readInt16()); + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java index 3c44fc5e1a..9cbfba043c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java @@ -29,6 +29,7 @@ import org.apache.fory.memory.Platform; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.Serializers.CrossLanguageCompatibleSerializer; +import org.apache.fory.type.Float16; import org.apache.fory.util.Preconditions; /** Serializers for java primitive types. */ @@ -342,6 +343,22 @@ public Double read(MemoryBuffer buffer) { } } + public static final class Float16Serializer extends CrossLanguageCompatibleSerializer { + public Float16Serializer(Fory fory, Class cls) { + super(fory, (Class) cls, false, true); + } + + @Override + public void write(MemoryBuffer buffer, Float16 value) { + buffer.writeInt16(value.toBits()); + } + + @Override + public Float16 read(MemoryBuffer buffer) { + return Float16.fromBits(buffer.readInt16()); + } + } + public static void registerDefaultSerializers(Fory fory) { // primitive types will be boxed. TypeResolver resolver = fory.getTypeResolver(); @@ -361,5 +378,6 @@ public static void registerDefaultSerializers(Fory fory) { resolver.registerInternalSerializer(Long.class, new LongSerializer(fory, Long.class)); resolver.registerInternalSerializer(Float.class, new FloatSerializer(fory, Float.class)); resolver.registerInternalSerializer(Double.class, new DoubleSerializer(fory, Double.class)); + resolver.registerInternalSerializer(Float16.class, new Float16Serializer(fory, Float16.class)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java index 8ade12701d..13b1bdab6e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java @@ -50,6 +50,7 @@ import org.apache.fory.meta.TypeDef; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.type.Float16; import org.apache.fory.util.ExceptionUtils; import org.apache.fory.util.GraalvmSupport; import org.apache.fory.util.GraalvmSupport.GraalvmSerializerHolder; @@ -595,6 +596,7 @@ public static void registerDefaultSerializers(Fory fory) { resolver.registerInternalSerializer(URI.class, new URISerializer(fory)); resolver.registerInternalSerializer(Pattern.class, new RegexSerializer(fory)); resolver.registerInternalSerializer(UUID.class, new UUIDSerializer(fory)); + resolver.registerInternalSerializer(Float16.class, new Float16Serializer(fory)); resolver.registerInternalSerializer(Object.class, new EmptyObjectSerializer(fory)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java index d86556bb1f..0ae18c5a9a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java @@ -21,6 +21,7 @@ import org.apache.fory.Fory; import org.apache.fory.collection.BoolList; +import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; import org.apache.fory.collection.Float64List; import org.apache.fory.collection.Int16List; @@ -539,6 +540,43 @@ public Float64List read(MemoryBuffer buffer) { } } + public static final class Float16ListSerializer + extends Serializers.CrossLanguageCompatibleSerializer { + public Float16ListSerializer(Fory fory) { + super(fory, Float16List.class, false, true); + } + + @Override + public void write(MemoryBuffer buffer, Float16List value) { + int size = value.size(); + int byteSize = size * 2; + buffer.writeVarUint32Small7(byteSize); + short[] array = value.getArray(); + if (Platform.IS_LITTLE_ENDIAN) { + buffer.writePrimitiveArray(array, Platform.SHORT_ARRAY_OFFSET, byteSize); + } else { + for (int i = 0; i < size; i++) { + buffer.writeInt16(array[i]); + } + } + } + + @Override + public Float16List read(MemoryBuffer buffer) { + int byteSize = buffer.readVarUint32Small7(); + int size = byteSize / 2; + short[] array = new short[size]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(array, Platform.SHORT_ARRAY_OFFSET, byteSize); + } else { + for (int i = 0; i < size; i++) { + array[i] = buffer.readInt16(); + } + } + return new Float16List(array); + } + } + public static void registerDefaultSerializers(Fory fory) { // Note: Classes are already registered in ClassResolver.initialize() // We only need to register serializers here @@ -554,5 +592,6 @@ public static void registerDefaultSerializers(Fory fory) { resolver.registerInternalSerializer(Uint64List.class, new Uint64ListSerializer(fory)); resolver.registerInternalSerializer(Float32List.class, new Float32ListSerializer(fory)); resolver.registerInternalSerializer(Float64List.class, new Float64ListSerializer(fory)); + resolver.registerInternalSerializer(Float16List.class, new Float16ListSerializer(fory)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java b/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java index d154e3352d..9719d8e755 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java @@ -49,15 +49,16 @@ public class DispatchId { public static final int UINT32 = 14; public static final int VAR_UINT32 = 15; public static final int UINT64 = 16; - public static final int VAR_UINT64 = 17; - public static final int TAGGED_UINT64 = 18; - public static final int EXT_UINT8 = 19; - public static final int EXT_UINT16 = 20; - public static final int EXT_UINT32 = 21; - public static final int EXT_VAR_UINT32 = 22; - public static final int EXT_UINT64 = 23; - public static final int EXT_VAR_UINT64 = 24; - public static final int STRING = 25; + public static final int FLOAT16 = 17; + public static final int VAR_UINT64 = 18; + public static final int TAGGED_UINT64 = 19; + public static final int EXT_UINT8 = 20; + public static final int EXT_UINT16 = 21; + public static final int EXT_UINT32 = 22; + public static final int EXT_VAR_UINT32 = 23; + public static final int EXT_UINT64 = 24; + public static final int EXT_VAR_UINT64 = 25; + public static final int STRING = 26; public static int getDispatchId(Fory fory, Descriptor d) { int typeId = Types.getDescriptorTypeId(fory, d); @@ -101,6 +102,8 @@ private static int xlangTypeIdToDispatchId(int typeId) { return VAR_UINT64; case Types.TAGGED_UINT64: return TAGGED_UINT64; + case Types.FLOAT16: + return FLOAT16; case Types.FLOAT32: return FLOAT32; case Types.FLOAT64: diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Float16.java b/java/fory-core/src/main/java/org/apache/fory/type/Float16.java new file mode 100644 index 0000000000..d7efd9fed6 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/type/Float16.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.type; + +import java.io.Serializable; + +public final class Float16 extends Number implements Comparable, Serializable { + private static final long serialVersionUID = 1L; + + private static final int SIGN_MASK = 0x8000; + private static final int EXP_MASK = 0x7C00; + private static final int MANT_MASK = 0x03FF; + + private static final short BITS_NAN = (short) 0x7E00; + private static final short BITS_POS_INF = (short) 0x7C00; // +Inf + private static final short BITS_NEG_INF = (short) 0xFC00; // -Inf + private static final short BITS_NEG_ZERO = (short) 0x8000; // -0 + private static final short BITS_MAX = (short) 0x7BFF; // 65504 + private static final short BITS_ONE = (short) 0x3C00; // 1.0 + private static final short BITS_MIN_NORMAL = (short) 0x0400; + private static final short BITS_MIN_VALUE = (short) 0x0001; + + public static final Float16 NaN = new Float16(BITS_NAN); + + public static final Float16 POSITIVE_INFINITY = new Float16(BITS_POS_INF); + + public static final Float16 NEGATIVE_INFINITY = new Float16(BITS_NEG_INF); + + public static final Float16 ZERO = new Float16((short) 0); + + public static final Float16 NEGATIVE_ZERO = new Float16(BITS_NEG_ZERO); + + public static final Float16 ONE = new Float16(BITS_ONE); + + public static final Float16 MAX_VALUE = new Float16(BITS_MAX); + + public static final Float16 MIN_NORMAL = new Float16(BITS_MIN_NORMAL); + + public static final Float16 MIN_VALUE = new Float16(BITS_MIN_VALUE); + + public static final int SIZE_BITS = 16; + + public static final int SIZE_BYTES = 2; + + private final short bits; + + private Float16(short bits) { + this.bits = bits; + } + + public static Float16 fromBits(short bits) { + return new Float16(bits); + } + + public static Float16 valueOf(float value) { + return new Float16(floatToFloat16Bits(value)); + } + + public short toBits() { + return bits; + } + + public float toFloat() { + return floatValue(); + } + + private static short floatToFloat16Bits(float f32) { + int bits32 = Float.floatToRawIntBits(f32); + int sign = (bits32 >>> 31) & 0x1; + int exp = (bits32 >>> 23) & 0xFF; + int mant = bits32 & 0x7FFFFF; + + int outSign = sign << 15; + int outExp; + int outMant; + + if (exp == 0xFF) { + outExp = 0x1F; + if (mant != 0) { + outMant = 0x200 | ((mant >>> 13) & 0x1FF); + if (outMant == 0x200) { + outMant = 0x201; + } + } else { + outMant = 0; + } + } else if (exp == 0) { + outExp = 0; + outMant = 0; + } else { + int newExp = exp - 127 + 15; + + if (newExp >= 31) { + outExp = 0x1F; + outMant = 0; + } else if (newExp <= 0) { + int fullMant = mant | 0x800000; + int shift = 1 - newExp; + int netShift = 13 + shift; + + if (netShift >= 24) { + outExp = 0; + outMant = 0; + } else { + outExp = 0; + int roundBit = (fullMant >>> (netShift - 1)) & 1; + int sticky = fullMant & ((1 << (netShift - 1)) - 1); + outMant = fullMant >>> netShift; + + if (roundBit == 1 && (sticky != 0 || (outMant & 1) == 1)) { + outMant++; + } + } + } else { + outExp = newExp; + outMant = mant >>> 13; + + int roundBit = (mant >>> 12) & 1; + int sticky = mant & 0xFFF; + + if (roundBit == 1 && (sticky != 0 || (outMant & 1) == 1)) { + outMant++; + if (outMant > 0x3FF) { + outMant = 0; + outExp++; + if (outExp >= 31) { + outExp = 0x1F; + } + } + } + } + } + + return (short) (outSign | (outExp << 10) | outMant); + } + + private static float float16BitsToFloat(short bits16) { + int bits = bits16 & 0xFFFF; + int sign = (bits >>> 15) & 0x1; + int exp = (bits >>> 10) & 0x1F; + int mant = bits & 0x3FF; + + int outBits = sign << 31; + + if (exp == 0x1F) { + outBits |= 0xFF << 23; + if (mant != 0) { + outBits |= mant << 13; + } + } else if (exp == 0) { + if (mant != 0) { + int shift = Integer.numberOfLeadingZeros(mant) - 22; + mant = (mant << shift) & 0x3FF; + int newExp = 1 - 15 - shift + 127; + outBits |= newExp << 23; + outBits |= mant << 13; + } + } else { + outBits |= (exp - 15 + 127) << 23; + outBits |= mant << 13; + } + + return Float.intBitsToFloat(outBits); + } + + public boolean isNaN() { + return (bits & EXP_MASK) == EXP_MASK && (bits & MANT_MASK) != 0; + } + + public boolean isInfinite() { + return (bits & EXP_MASK) == EXP_MASK && (bits & MANT_MASK) == 0; + } + + public boolean isFinite() { + return (bits & EXP_MASK) != EXP_MASK; + } + + public boolean isZero() { + return (bits & (EXP_MASK | MANT_MASK)) == 0; + } + + public boolean isNormal() { + int exp = bits & EXP_MASK; + return exp != 0 && exp != EXP_MASK; + } + + public boolean isSubnormal() { + return (bits & EXP_MASK) == 0 && (bits & MANT_MASK) != 0; + } + + public boolean signbit() { + return (bits & SIGN_MASK) != 0; + } + + public Float16 add(Float16 other) { + return valueOf(floatValue() + other.floatValue()); + } + + public Float16 subtract(Float16 other) { + return valueOf(floatValue() - other.floatValue()); + } + + public Float16 multiply(Float16 other) { + return valueOf(floatValue() * other.floatValue()); + } + + public Float16 divide(Float16 other) { + return valueOf(floatValue() / other.floatValue()); + } + + public Float16 negate() { + return fromBits((short) (bits ^ SIGN_MASK)); + } + + public Float16 abs() { + return fromBits((short) (bits & ~SIGN_MASK)); + } + + @Override + public float floatValue() { + return float16BitsToFloat(bits); + } + + @Override + public double doubleValue() { + return floatValue(); + } + + @Override + public int intValue() { + return (int) floatValue(); + } + + @Override + public long longValue() { + return (long) floatValue(); + } + + @Override + public byte byteValue() { + return (byte) floatValue(); + } + + @Override + public short shortValue() { + return (short) floatValue(); + } + + public boolean isNumericEqual(Float16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + if (isZero() && other.isZero()) { + return true; + } + return bits == other.bits; + } + + public boolean equalsValue(Float16 other) { + return isNumericEqual(other); + } + + public boolean lessThan(Float16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() < other.floatValue(); + } + + public boolean lessThanOrEqual(Float16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() <= other.floatValue(); + } + + public boolean greaterThan(Float16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() > other.floatValue(); + } + + public boolean greaterThanOrEqual(Float16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() >= other.floatValue(); + } + + public static int compare(Float16 a, Float16 b) { + return Float.compare(a.floatValue(), b.floatValue()); + } + + public static Float16 parse(String s) { + return valueOf(Float.parseFloat(s)); + } + + @Override + public int compareTo(Float16 other) { + return compare(this, other); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Float16)) { + return false; + } + Float16 other = (Float16) obj; + return bits == other.bits; + } + + @Override + public int hashCode() { + return Short.hashCode(bits); + } + + @Override + public String toString() { + return Float.toString(floatValue()); + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Types.java b/java/fory-core/src/main/java/org/apache/fory/type/Types.java index 2e7968a3e2..ed26f9dff2 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/Types.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/Types.java @@ -449,7 +449,7 @@ public static Class getClassForTypeId(int typeId) { return Long.class; case FLOAT8: case FLOAT16: - case BFLOAT16: + return Float16.class; case FLOAT32: return Float.class; case FLOAT64: diff --git a/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/native-image.properties b/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/native-image.properties index 8c7dd27afb..d485e25bdc 100644 --- a/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/native-image.properties +++ b/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/native-image.properties @@ -19,6 +19,8 @@ # The unsafe offset get on build time may be different from runtime Args=--initialize-at-build-time=org.apache.fory.memory.MemoryBuffer,\ org.apache.fory.serializer.struct.Fingerprint,\ + org.apache.fory.serializer.Float16Serializer,\ + org.apache.fory.serializer.PrimitiveSerializers$Float16Serializer,\ com.google.common.base.Equivalence$Equals,\ com.google.common.base.Equivalence$Identity,\ com.google.common.base.Equivalence,\ @@ -207,6 +209,7 @@ Args=--initialize-at-build-time=org.apache.fory.memory.MemoryBuffer,\ org.apache.fory.codegen.Expression$Variable,\ org.apache.fory.codegen.JaninoUtils,\ org.apache.fory.collection.ClassValueCache,\ + org.apache.fory.collection.Float16List,\ org.apache.fory.collection.ForyObjectMap,\ org.apache.fory.collection.IdentityMap,\ org.apache.fory.collection.IdentityObjectIntMap,\ @@ -311,6 +314,7 @@ Args=--initialize-at-build-time=org.apache.fory.memory.MemoryBuffer,\ org.apache.fory.serializer.ArraySerializers$ObjectArraySerializer,\ org.apache.fory.serializer.ArraySerializers$ShortArraySerializer,\ org.apache.fory.serializer.ArraySerializers$StringArraySerializer,\ + org.apache.fory.serializer.ArraySerializers$Float16ArraySerializer,\ org.apache.fory.serializer.ArraySerializers,\ org.apache.fory.serializer.UnsignedSerializers,\ org.apache.fory.serializer.UnsignedSerializers$Uint8Serializer,\ @@ -329,6 +333,7 @@ Args=--initialize-at-build-time=org.apache.fory.memory.MemoryBuffer,\ org.apache.fory.serializer.collection.PrimitiveListSerializers$Uint64ListSerializer,\ org.apache.fory.serializer.collection.PrimitiveListSerializers$Float32ListSerializer,\ org.apache.fory.serializer.collection.PrimitiveListSerializers$Float64ListSerializer,\ + org.apache.fory.serializer.collection.PrimitiveListSerializers$Float16ListSerializer,\ org.apache.fory.serializer.BufferSerializers$ByteBufferSerializer,\ org.apache.fory.serializer.MetaSharedSerializer,\ org.apache.fory.serializer.MetaSharedLayerSerializer,\ @@ -514,6 +519,7 @@ Args=--initialize-at-build-time=org.apache.fory.memory.MemoryBuffer,\ org.apache.fory.type.Descriptor$1,\ org.apache.fory.type.Descriptor,\ org.apache.fory.type.DescriptorGrouper,\ + org.apache.fory.type.Float16,\ org.apache.fory.type.GenericType,\ org.apache.fory.type.Generics,\ org.apache.fory.type.Type,\ diff --git a/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/reflection-config.json b/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/reflection-config.json index 505934bf94..97d1028ae4 100644 --- a/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/reflection-config.json +++ b/java/fory-core/src/main/resources/META-INF/native-image/org.apache.fory/fory-core/reflection-config.json @@ -5,5 +5,8 @@ "allPublicConstructors": true, "allDeclaredMethods": true, "allPublicMethods": true + }, } ] + + diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java new file mode 100644 index 0000000000..b9ab353e46 --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import static org.testng.Assert.*; + +import lombok.AllArgsConstructor; +import lombok.Data; +import org.apache.fory.Fory; +import org.apache.fory.ForyTestBase; +import org.apache.fory.config.Language; +import org.apache.fory.type.Float16; +import org.testng.annotations.Test; + +public class Float16SerializerTest extends ForyTestBase { + + @Test + public void testFloat16Serialization() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + byte[] bytes = fory.serialize(Float16.NaN); + Float16 result = (Float16) fory.deserialize(bytes); + assertTrue(result.isNaN()); + + bytes = fory.serialize(Float16.POSITIVE_INFINITY); + result = (Float16) fory.deserialize(bytes); + assertTrue(result.isInfinite() && !result.signbit()); + + bytes = fory.serialize(Float16.NEGATIVE_INFINITY); + result = (Float16) fory.deserialize(bytes); + assertTrue(result.isInfinite() && result.signbit()); + + bytes = fory.serialize(Float16.ZERO); + result = (Float16) fory.deserialize(bytes); + assertEquals(Float16.ZERO.toBits(), result.toBits()); + + bytes = fory.serialize(Float16.ONE); + result = (Float16) fory.deserialize(bytes); + assertEquals(Float16.ONE.toBits(), result.toBits()); + + bytes = fory.serialize(Float16.valueOf(1.5f)); + result = (Float16) fory.deserialize(bytes); + assertEquals(1.5f, result.floatValue(), 0.01f); + } + + @Test + public void testFloat16ArraySerialization() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + Float16[] array = + new Float16[] { + Float16.ZERO, + Float16.ONE, + Float16.valueOf(2.5f), + Float16.valueOf(-1.5f), + Float16.NaN, + Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, + Float16.MAX_VALUE, + Float16.MIN_VALUE + }; + + byte[] bytes = fory.serialize(array); + Float16[] result = (Float16[]) fory.deserialize(bytes); + + assertEquals(array.length, result.length); + for (int i = 0; i < array.length; i++) { + if (array[i].isNaN()) { + assertTrue(result[i].isNaN(), "Index " + i + " should be NaN"); + } else { + assertEquals( + array[i].toBits(), + result[i].toBits(), + "Index " + + i + + " bits should match: expected 0x" + + Integer.toHexString(array[i].toBits() & 0xFFFF) + + " but got 0x" + + Integer.toHexString(result[i].toBits() & 0xFFFF)); + } + } + } + + @Test + public void testFloat16EmptyArray() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + Float16[] empty = new Float16[0]; + byte[] bytes = fory.serialize(empty); + Float16[] result = (Float16[]) fory.deserialize(bytes); + assertEquals(0, result.length); + } + + @Test + public void testFloat16LargeArray() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + Float16[] large = new Float16[1000]; + for (int i = 0; i < large.length; i++) { + large[i] = Float16.valueOf((float) i); + } + + byte[] bytes = fory.serialize(large); + Float16[] result = (Float16[]) fory.deserialize(bytes); + + assertEquals(large.length, result.length); + for (int i = 0; i < large.length; i++) { + assertEquals( + large[i].floatValue(), result[i].floatValue(), 0.1f, "Index " + i + " should match"); + } + } + + @Data + @AllArgsConstructor + public static class StructWithFloat16 { + Float16 f16Field; + Float16 f16Field2; + } + + @Test + public void testStructWithFloat16Field() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + fory.register(StructWithFloat16.class); + + StructWithFloat16 obj = new StructWithFloat16(Float16.valueOf(1.5f), Float16.valueOf(-2.5f)); + + byte[] bytes = fory.serialize(obj); + StructWithFloat16 result = (StructWithFloat16) fory.deserialize(bytes); + + assertEquals(obj.f16Field.toBits(), result.f16Field.toBits()); + assertEquals(obj.f16Field2.toBits(), result.f16Field2.toBits()); + } + + @Data + @AllArgsConstructor + public static class StructWithFloat16Array { + Float16[] f16Array; + } + + @Test + public void testStructWithFloat16Array() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + fory.register(StructWithFloat16Array.class); + + Float16[] array = new Float16[] {Float16.ONE, Float16.valueOf(2.0f), Float16.valueOf(3.0f)}; + StructWithFloat16Array obj = new StructWithFloat16Array(array); + + byte[] bytes = fory.serialize(obj); + StructWithFloat16Array result = (StructWithFloat16Array) fory.deserialize(bytes); + + assertEquals(obj.f16Array.length, result.f16Array.length); + for (int i = 0; i < obj.f16Array.length; i++) { + assertEquals(obj.f16Array[i].toBits(), result.f16Array[i].toBits()); + } + } + + @Test + public void testFloat16WithNullableField() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + @Data + @AllArgsConstructor + class StructWithNullableFloat16 { + Float16 nullableField; + } + + fory.register(StructWithNullableFloat16.class); + + StructWithNullableFloat16 obj1 = new StructWithNullableFloat16(Float16.valueOf(1.5f)); + byte[] bytes = fory.serialize(obj1); + StructWithNullableFloat16 result = (StructWithNullableFloat16) fory.deserialize(bytes); + assertEquals(obj1.nullableField.toBits(), result.nullableField.toBits()); + + StructWithNullableFloat16 obj2 = new StructWithNullableFloat16(null); + bytes = fory.serialize(obj2); + result = (StructWithNullableFloat16) fory.deserialize(bytes); + assertNull(result.nullableField); + } + + @Test + public void testFloat16SpecialValuesInStruct() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + @Data + @AllArgsConstructor + class StructWithSpecialValues { + Float16 nan; + Float16 posInf; + Float16 negInf; + Float16 zero; + Float16 negZero; + } + + fory.register(StructWithSpecialValues.class); + + StructWithSpecialValues obj = + new StructWithSpecialValues( + Float16.NaN, + Float16.POSITIVE_INFINITY, + Float16.NEGATIVE_INFINITY, + Float16.ZERO, + Float16.NEGATIVE_ZERO); + + byte[] bytes = fory.serialize(obj); + StructWithSpecialValues result = (StructWithSpecialValues) fory.deserialize(bytes); + + assertTrue(result.nan.isNaN()); + assertTrue(result.posInf.isInfinite() && !result.posInf.signbit()); + assertTrue(result.negInf.isInfinite() && result.negInf.signbit()); + assertTrue(result.zero.isZero() && !result.zero.signbit()); + assertTrue(result.negZero.isZero() && result.negZero.signbit()); + } + + @Test + public void testFloat16BitPatternPreservation() { + Fory fory = Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + + short[] testBits = { + (short) 0x0000, + (short) 0x8000, + (short) 0x3c00, + (short) 0xbc00, + (short) 0x7bff, + (short) 0x0001, + (short) 0x0400, + (short) 0x7c00, + (short) 0xfc00, + (short) 0x7e00 + }; + + for (short bits : testBits) { + Float16 original = Float16.fromBits(bits); + byte[] bytes = fory.serialize(original); + Float16 result = (Float16) fory.deserialize(bytes); + + if (original.isNaN()) { + assertTrue(result.isNaN(), "NaN should remain NaN"); + } else { + assertEquals( + original.toBits(), + result.toBits(), + "Bit pattern should be preserved for 0x" + Integer.toHexString(bits & 0xFFFF)); + } + } + } +} diff --git a/java/fory-core/src/test/java/org/apache/fory/type/Float16Test.java b/java/fory-core/src/test/java/org/apache/fory/type/Float16Test.java new file mode 100644 index 0000000000..a3364bfe6d --- /dev/null +++ b/java/fory-core/src/test/java/org/apache/fory/type/Float16Test.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.type; + +import static org.testng.Assert.*; + +import org.testng.annotations.Test; + +public class Float16Test { + + @Test + public void testSpecialValues() { + assertTrue(Float16.NaN.isNaN()); + assertFalse(Float16.NaN.equalsValue(Float16.NaN)); + assertFalse(Float16.NaN.isFinite()); + assertFalse(Float16.NaN.isInfinite()); + + assertTrue(Float16.POSITIVE_INFINITY.isInfinite()); + assertTrue(Float16.NEGATIVE_INFINITY.isInfinite()); + assertFalse(Float16.POSITIVE_INFINITY.signbit()); + assertTrue(Float16.NEGATIVE_INFINITY.signbit()); + assertFalse(Float16.POSITIVE_INFINITY.isFinite()); + + assertTrue(Float16.ZERO.isZero()); + assertTrue(Float16.NEGATIVE_ZERO.isZero()); + assertTrue(Float16.ZERO.equalsValue(Float16.NEGATIVE_ZERO)); + assertFalse(Float16.ZERO.signbit()); + assertTrue(Float16.NEGATIVE_ZERO.signbit()); + + assertFalse(Float16.ONE.isZero()); + assertTrue(Float16.ONE.isNormal()); + assertEquals(1.0f, Float16.ONE.floatValue(), 0.0f); + } + + @Test + public void testConversionBasic() { + assertEquals((short) 0x0000, Float16.valueOf(0.0f).toBits()); + assertEquals((short) 0x8000, Float16.valueOf(-0.0f).toBits()); + + assertEquals((short) 0x3c00, Float16.valueOf(1.0f).toBits()); + assertEquals((short) 0xbc00, Float16.valueOf(-1.0f).toBits()); + assertEquals((short) 0x4000, Float16.valueOf(2.0f).toBits()); + + assertEquals(1.0f, Float16.valueOf(1.0f).floatValue(), 0.0f); + assertEquals(-1.0f, Float16.valueOf(-1.0f).floatValue(), 0.0f); + assertEquals(2.0f, Float16.valueOf(2.0f).floatValue(), 0.0f); + } + + @Test + public void testBoundaryValues() { + Float16 max = Float16.valueOf(65504.0f); + assertEquals(0x7bff, max.toBits()); + assertEquals(65504.0f, max.floatValue(), 0.0f); + assertTrue(max.isNormal()); + assertTrue(max.isFinite()); + + Float16 minNormal = Float16.valueOf((float) Math.pow(2, -14)); + assertEquals(0x0400, minNormal.toBits()); + assertTrue(minNormal.isNormal()); + assertFalse(minNormal.isSubnormal()); + + Float16 minSubnormal = Float16.valueOf((float) Math.pow(2, -24)); + assertEquals(0x0001, minSubnormal.toBits()); + assertTrue(minSubnormal.isSubnormal()); + assertFalse(minSubnormal.isNormal()); + } + + @Test + public void testOverflowUnderflow() { + Float16 overflow = Float16.valueOf(70000.0f); + assertTrue(overflow.isInfinite()); + assertFalse(overflow.signbit()); + + Float16 negOverflow = Float16.valueOf(-70000.0f); + assertTrue(negOverflow.isInfinite()); + assertTrue(negOverflow.signbit()); + + Float16 underflow = Float16.valueOf((float) Math.pow(2, -30)); + assertTrue(underflow.isZero()); + } + + @Test + public void testRoundingToNearestEven() { + Float16 rounded = Float16.valueOf(1.0009765625f); + assertTrue(Math.abs(rounded.floatValue() - 1.0009765625f) < 0.01f); + } + + @Test + public void testArithmetic() { + Float16 one = Float16.ONE; + Float16 two = Float16.valueOf(2.0f); + Float16 three = Float16.valueOf(3.0f); + + assertEquals(3.0f, one.add(two).floatValue(), 0.01f); + assertEquals(5.0f, two.add(three).floatValue(), 0.01f); + + assertEquals(-1.0f, one.subtract(two).floatValue(), 0.01f); + assertEquals(1.0f, three.subtract(two).floatValue(), 0.01f); + + assertEquals(2.0f, one.multiply(two).floatValue(), 0.01f); + assertEquals(6.0f, two.multiply(three).floatValue(), 0.01f); + + assertEquals(0.5f, one.divide(two).floatValue(), 0.01f); + assertEquals(1.5f, three.divide(two).floatValue(), 0.01f); + + assertEquals(-1.0f, one.negate().floatValue(), 0.0f); + assertEquals(1.0f, one.negate().negate().floatValue(), 0.0f); + + assertEquals(1.0f, one.abs().floatValue(), 0.0f); + assertEquals(1.0f, one.negate().abs().floatValue(), 0.0f); + } + + @Test + public void testComparison() { + Float16 one = Float16.ONE; + Float16 two = Float16.valueOf(2.0f); + Float16 nan = Float16.NaN; + + assertTrue(one.compareTo(two) < 0); + assertTrue(two.compareTo(one) > 0); + assertEquals(0, one.compareTo(Float16.ONE)); + + assertTrue(one.equalsValue(Float16.ONE)); + assertFalse(nan.equalsValue(nan)); + assertTrue(Float16.ZERO.equalsValue(Float16.NEGATIVE_ZERO)); + + assertTrue(one.equals(Float16.ONE)); + assertFalse(Float16.ZERO.equals(Float16.NEGATIVE_ZERO)); + + assertTrue(one.lessThan(two)); + assertFalse(two.lessThan(one)); + assertFalse(nan.lessThan(one)); + assertFalse(one.lessThan(nan)); + + assertTrue(one.lessThanOrEqual(two)); + assertTrue(one.lessThanOrEqual(Float16.ONE)); + assertFalse(nan.lessThanOrEqual(one)); + + assertTrue(two.greaterThan(one)); + assertFalse(one.greaterThan(two)); + assertFalse(nan.greaterThan(one)); + + assertTrue(two.greaterThanOrEqual(one)); + assertTrue(one.greaterThanOrEqual(Float16.ONE)); + assertFalse(nan.greaterThanOrEqual(one)); + + assertTrue(Float16.compare(one, two) < 0); + assertTrue(Float16.compare(two, one) > 0); + assertEquals(0, Float16.compare(one, Float16.ONE)); + } + + @Test + public void testParse() { + assertEquals(1.0f, Float16.parse("1.0").floatValue(), 0.0f); + assertEquals(2.5f, Float16.parse("2.5").floatValue(), 0.01f); + assertEquals(-3.14f, Float16.parse("-3.14").floatValue(), 0.01f); + assertTrue(Float16.parse("NaN").isNaN()); + assertTrue(Float16.parse("Infinity").isInfinite()); + assertTrue(Float16.parse("-Infinity").isInfinite()); + } + + @Test + public void testToFloat() { + Float16 f16 = Float16.valueOf(3.14f); + assertEquals(f16.floatValue(), f16.toFloat(), 0.0f); + } + + @Test + public void testClassification() { + assertTrue(Float16.ONE.isNormal()); + assertFalse(Float16.ZERO.isNormal()); + assertFalse(Float16.NaN.isNormal()); + assertFalse(Float16.POSITIVE_INFINITY.isNormal()); + + Float16 subnormal = Float16.valueOf((float) Math.pow(2, -24)); + assertTrue(subnormal.isSubnormal()); + assertFalse(Float16.ONE.isSubnormal()); + + assertTrue(Float16.ONE.isFinite()); + assertTrue(Float16.ZERO.isFinite()); + assertFalse(Float16.NaN.isFinite()); + assertFalse(Float16.POSITIVE_INFINITY.isFinite()); + + assertFalse(Float16.ONE.signbit()); + assertTrue(Float16.valueOf(-1.0f).signbit()); + assertFalse(Float16.ZERO.signbit()); + assertTrue(Float16.NEGATIVE_ZERO.signbit()); + } + + @Test + public void testToString() { + assertEquals("1.0", Float16.ONE.toString()); + assertEquals("2.0", Float16.valueOf(2.0f).toString()); + assertTrue(Float16.NaN.toString().contains("NaN")); + assertTrue(Float16.POSITIVE_INFINITY.toString().contains("Infinity")); + } + + @Test + public void testHashCode() { + assertEquals(Float16.ONE.hashCode(), Float16.valueOf(1.0f).hashCode()); + + assertNotEquals(Float16.ONE.hashCode(), Float16.valueOf(2.0f).hashCode()); + } + + @Test + public void testNumberConversions() { + Float16 f16 = Float16.valueOf(3.14f); + + assertEquals(3.14f, f16.floatValue(), 0.01f); + + assertEquals(3.14, f16.doubleValue(), 0.01); + + assertEquals(3, f16.intValue()); + + assertEquals(3L, f16.longValue()); + + assertEquals((byte) 3, f16.byteValue()); + + assertEquals((short) 3, f16.shortValue()); + } + + @Test + public void testAllBitPatterns() { + int nanCount = 0; + int normalCount = 0; + int subnormalCount = 0; + int zeroCount = 0; + int infCount = 0; + + for (int bits = 0; bits <= 0xFFFF; bits++) { + Float16 h = Float16.fromBits((short) bits); + float f = h.floatValue(); + Float16 h2 = Float16.valueOf(f); + + if (h.isNaN()) { + nanCount++; + assertTrue(h2.isNaN(), "NaN should remain NaN after round-trip"); + } else { + if (!h.equals(h2)) { + assertEquals( + h.floatValue(), + h2.floatValue(), + 0.0001f, + String.format("Round-trip failed for bits 0x%04x", bits)); + } + + if (h.isNormal()) { + normalCount++; + } else if (h.isSubnormal()) { + subnormalCount++; + } else if (h.isZero()) { + zeroCount++; + } else if (h.isInfinite()) { + infCount++; + } + } + } + + assertTrue(nanCount > 0, "Should have NaN values"); + assertTrue(normalCount > 0, "Should have normal values"); + assertTrue(subnormalCount > 0, "Should have subnormal values"); + assertEquals(2, zeroCount, "Should have exactly 2 zeros (+0 and -0)"); + assertEquals(2, infCount, "Should have exactly 2 infinities (+Inf and -Inf)"); + + int total = nanCount + normalCount + subnormalCount + zeroCount + infCount; + assertEquals(65536, total, "Total should be 65536"); + } + + @Test + public void testConstants() { + assertEquals((short) 0x7e00, Float16.NaN.toBits()); + assertEquals((short) 0x7c00, Float16.POSITIVE_INFINITY.toBits()); + assertEquals((short) 0xfc00, Float16.NEGATIVE_INFINITY.toBits()); + assertEquals((short) 0x0000, Float16.ZERO.toBits()); + assertEquals((short) 0x8000, Float16.NEGATIVE_ZERO.toBits()); + assertEquals((short) 0x3c00, Float16.ONE.toBits()); + assertEquals((short) 0x7bff, Float16.MAX_VALUE.toBits()); + assertEquals((short) 0x0400, Float16.MIN_NORMAL.toBits()); + assertEquals((short) 0x0001, Float16.MIN_VALUE.toBits()); + } + + @Test + public void testSizeConstants() { + assertEquals(16, Float16.SIZE_BITS); + assertEquals(2, Float16.SIZE_BYTES); + } +}