From 9a0959d60ed2602f6491fe4fa633aebb4ee92f0e Mon Sep 17 00:00:00 2001 From: charlie Date: Tue, 30 Sep 2025 17:18:43 -0500 Subject: [PATCH 1/3] initial --- src/include/migraphx/generic_float.hpp | 199 ++++++++++++++++++++----- 1 file changed, 165 insertions(+), 34 deletions(-) diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp index b077d16631b..7cbf98e9740 100644 --- a/src/include/migraphx/generic_float.hpp +++ b/src/include/migraphx/generic_float.hpp @@ -71,6 +71,53 @@ struct unsigned_type<8> using type = std::uint64_t; }; +// CRTP base for operators +template +struct generic_float_operators +{ + +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ + friend constexpr Derived& operator op(Derived & lhs, const Derived & rhs) \ + { \ + float self = lhs; \ + float frhs = rhs; \ + self op frhs; \ + lhs = self; \ + return lhs; \ + } + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) + MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) + +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ + friend constexpr Derived operator op(const Derived& x, const Derived& y) \ + { \ + return Derived(float(x) op float(y)); \ + } + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) + MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) + +// NOLINTNEXTLINE +#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ + friend constexpr bool operator op(const Derived& x, const Derived& y) \ + { \ + return float(x) op float(y); \ + } + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) + MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) + + protected: + // prohibit creation of this base object + generic_float_operators() = default; +}; + struct float32_parts { unsigned int mantissa : 23; @@ -92,6 +139,7 @@ constexpr float32_parts get_parts(float f) { return migraphx::bit_cast struct 
__attribute__((packed, may_alias)) generic_float + : generic_float_operators> { using type = typename unsigned_type::type; @@ -296,40 +344,6 @@ struct __attribute__((packed, may_alias)) generic_float x.mantissa++; return generic_float{x.to_float() - 1.0f}; } -// NOLINTNEXTLINE -#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ - constexpr generic_float& operator op(const generic_float & rhs) \ - { \ - float self = *this; \ - float frhs = rhs; \ - self op frhs; \ - *this = generic_float(self); \ - return *this; \ - } - MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) - MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) - MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) - MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) -// NOLINTNEXTLINE -#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ - friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \ - { \ - return generic_float(float(x) op float(y)); \ - } - MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) - MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) - MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) - MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) -// NOLINTNEXTLINE -#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ - friend constexpr bool operator op(const generic_float& x, const generic_float& y) \ - { \ - return float(x) op float(y); \ - } - MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) - MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) - MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) - MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) friend constexpr bool operator==(const generic_float& x, const generic_float& y) { @@ -363,6 +377,123 @@ struct __attribute__((packed, may_alias)) generic_float } }; +template +struct __attribute__((packed, may_alias)) generic_float<0, 8, Flags> + : generic_float_operators> +{ + uint8_t exponent; + + static constexpr int exponent_bias() { return all_ones<7>(); } + + explicit constexpr generic_float(float f = 1.0) noexcept { from_float(get_parts(f)); } + + constexpr generic_float& operator=(float f) noexcept + { + from_float(get_parts(f)); + return *this; + } + + // No sign for this type + 
constexpr generic_float operator-() const noexcept { return snan(); } + + constexpr generic_float operator+() const noexcept { return *this; } + + constexpr float to_float() const noexcept + { + float32_parts f{}; + f.sign = 0; + f.mantissa = 0; + f.exponent = exponent; + return f.to_float(); + } + + // Extracts only exponent bits from float + constexpr void from_float(float32_parts f) noexcept { exponent = f.exponent; } + + // No subnormal numbers + constexpr bool is_normal() const noexcept { return not is_nan(); } + + constexpr bool is_inf() const noexcept { return false; } + + constexpr bool is_nan() const noexcept { return exponent == all_ones<8>(); } + + constexpr bool is_finite() const noexcept { return not is_nan(); } + + constexpr operator float() const noexcept { return this->to_float(); } + + // doesn't have infinity, returning 2**0 + static constexpr generic_float infinity() + { + generic_float x{}; + x.exponent = all_ones<8>() >> 1u; + return x; + } + + // only one NaN value + static constexpr generic_float snan() + { + generic_float x{}; + x.exponent = all_ones<8>(); + return x; + } + + // only one NaN value + static constexpr generic_float qnan() { return snan(); } + + // min value = 2**(-127) + static constexpr generic_float min() + { + generic_float x{}; + x.exponent = 0; + return x; + } + + // No subnormal numbers + static constexpr generic_float denorm_min() { return min(); } + + static constexpr generic_float lowest() { return min(); } + + // max value = 2**(127) + static constexpr generic_float max() + { + generic_float x{}; + x.exponent = all_ones<8>() - 1; + return x; + } + + // next number from 2**0 is 2**1 so epsilon is 2**0 + static constexpr generic_float epsilon() + { + generic_float x{}; + x.exponent = all_ones<8>() >> 1u; + return x; + } + + friend constexpr bool operator==(const generic_float& x, const generic_float& y) + { + + return x.exponent == y.exponent; + } + + friend constexpr bool operator!=(const generic_float& x, const 
generic_float& y) + { + return not(x == y); + } + + constexpr generic_float& operator++() noexcept + { + ++exponent; + return *this; + } + + const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type) + { + generic_float temp = *this; + ++(*this); // use pre-increment; operator++(this->exponent) re-entered operator++(int) and never returned + return temp; + } +}; + } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx From 23ae414906e0f4501309d5c5c943099eae11e84a Mon Sep 17 00:00:00 2001 From: charlie Date: Fri, 10 Oct 2025 17:00:38 -0500 Subject: [PATCH 2/3] progress --- src/api/include/migraphx/migraphx.h | 3 +- src/include/migraphx/fp8e8m0.hpp | 39 +++++++++++++++++++++++++ src/include/migraphx/shape.hpp | 5 +++- src/include/migraphx/type_traits.hpp | 4 +++ src/netron_output.cpp | 3 +- test/fp8e8m0.cpp | 43 ++++++++++++++++++++++++++++ tools/api/migraphx.h | 3 +- 7 files changed, 96 insertions(+), 4 deletions(-) create mode 100644 src/include/migraphx/fp8e8m0.hpp create mode 100644 test/fp8e8m0.cpp diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index d1fb5ed316b..32f9f27a709 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -49,7 +49,8 @@ m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \ m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \ m(bf16_type, bf16) \ - m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) + m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) \ + m(fp8e8m0_type, fp8e8m0) // clang-format on #ifdef __cplusplus diff --git a/src/include/migraphx/fp8e8m0.hpp b/src/include/migraphx/fp8e8m0.hpp new file mode 100644 index 00000000000..27b835dd900 --- /dev/null +++ b/src/include/migraphx/fp8e8m0.hpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP +#define MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +using fp8e8m0 = migraphx::generic_float<0, 8, 0>; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 9448dac30d1..68044b84196 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -68,7 +69,9 @@ struct MIGRAPHX_EXPORT shape m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \ m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \ m(bf16_type, bf16) \ - m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) // clang-format on + m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) \ + m(fp8e8m0_type, fp8e8m0) + // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, enum type_t diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index ea42f83af9b..c7e708f7961 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -30,6 +30,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -74,6 +75,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e5m2) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e5m2) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e5m2) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e8m0) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e8m0) + template using accumulator_type = std::conditional_t{}, diff --git a/src/netron_output.cpp b/src/netron_output.cpp index 8323556f169..1d07ce248ba 100644 --- a/src/netron_output.cpp +++ b/src/netron_output.cpp @@ -57,7 +57,8 @@ int get_onnx_type(shape::type_t s_type) case shape::fp8e5m2_type: return 19; case shape::fp8e5m2fnuz_type: return 20; case 
shape::tuple_type: return 0; - case shape::fp4x2_type: return 21; // TODO update this when the type is added + case shape::fp4x2_type: return 23; + case shape::fp8e8m0_type: return 24; } MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported"); } diff --git a/test/fp8e8m0.cpp b/test/fp8e8m0.cpp new file mode 100644 index 00000000000..dcc2e44b59f --- /dev/null +++ b/test/fp8e8m0.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include +#include "test.hpp" + +#include +#include +#include +#include +#include +#include + +template +static bool bit_equal(const T& x, const U& y) +{ + static_assert(sizeof(T) == sizeof(U)); + using type = std::array; + return migraphx::bit_cast(x) == migraphx::bit_cast(y); +} diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h index 90834fc7dc9..8e017115702 100644 --- a/tools/api/migraphx.h +++ b/tools/api/migraphx.h @@ -49,7 +49,8 @@ m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \ m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \ m(bf16_type, bf16) \ - m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) + m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) \ + m(fp8e8m0_type, fp8e8m0) // clang-format on #ifdef __cplusplus From cb257bd7d40e41c512bb7564439089d2fc6f19f9 Mon Sep 17 00:00:00 2001 From: charlie Date: Thu, 23 Oct 2025 15:59:29 -0500 Subject: [PATCH 3/3] Make E8M0 tests and fix bugs --- src/include/migraphx/generic_float.hpp | 37 ++++- src/targets/gpu/gemm_impl.cpp | 1 + src/targets/gpu/hip_gemm_impl.cpp | 1 + test/fp8e8m0.cpp | 182 +++++++++++++++++++++++++ 4 files changed, 216 insertions(+), 5 deletions(-) diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp index 7cbf98e9740..54a94b40054 100644 --- a/src/include/migraphx/generic_float.hpp +++ b/src/include/migraphx/generic_float.hpp @@ -276,6 +276,8 @@ struct __attribute__((packed, may_alias)) generic_float return exponent == all_ones() and mantissa != 0; } + constexpr bool has_infinity() const noexcept { return true; } + constexpr bool is_finite() const noexcept { return exponent != all_ones(); } constexpr operator float() const noexcept { return this->to_float(); } @@ -402,23 +404,48 @@ struct __attribute__((packed, may_alias)) generic_float<0, 8, Flags> { float32_parts f{}; f.sign = 0; - f.mantissa = 0; + if(exponent == 0) + { + // 2^(-127) is a fp32 denormal number + f.mantissa = 1; + f.mantissa = f.mantissa << 
(float32_parts::mantissa_width() - 1); + } + else if(exponent == all_ones<8>()) + { + // setting to fp32 qNaN + f.mantissa = (1 << (float32_parts::mantissa_width() - 1)) + 1; + } + else + { + f.mantissa = 0; + } f.exponent = exponent; return f.to_float(); } - // Extracts only exponent bits from float + /** + * Extracts only exponent bits from float; sign and mantissa bits are dropped. + * fp32 zero and all fp32 denorm numbers will go to fp8e8m0{2^(-127)}. + * All fp32 NaN and infinity go to fp8e8m0{NaN}. + */ constexpr void from_float(float32_parts f) noexcept { exponent = f.exponent; } - // No subnormal numbers + // No denorm numbers in fp8e8m0. constexpr bool is_normal() const noexcept { return not is_nan(); } + // No infinity numbers in fp8e8m0. constexpr bool is_inf() const noexcept { return false; } constexpr bool is_nan() const noexcept { return exponent == all_ones<8>(); } constexpr bool is_finite() const noexcept { return not is_nan(); } + constexpr bool has_infinity() const noexcept + { + // fp8e8m0 has no infinity encoding; the all-ones exponent is NaN. + return false; + } + constexpr operator float() const noexcept { return this->to_float(); } // doesn't have infinity, returning 2**0 @@ -448,7 +475,7 @@ struct __attribute__((packed, may_alias)) generic_float<0, 8, Flags> return x; } - // No subnormal numbers + // No subnormal numbers in FP8E8M0 static constexpr generic_float denorm_min() { return min(); } static constexpr generic_float lowest() { return min(); } @@ -504,7 +531,7 @@ template class numeric_limits> { public: - static constexpr bool has_infinity = true; + static constexpr bool has_infinity = not(M == 0 and E == 8); /* E8M0 (M == 0, E == 8) has no infinity */ static constexpr migraphx::generic_float epsilon() { return migraphx::generic_float::epsilon(); diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index 8968616a8de..3f8e7dc77a8 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -75,6 +75,7 @@ static rocblas_datatype get_type(shape::type_t type) case shape::int64_type: case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not
supported!"); case shape::bf16_type: return rocblas_datatype_bf16_r; + case shape::fp8e8m0_type: break; /* unsupported: reaches the throw after the switch */ } MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!"); diff --git a/src/targets/gpu/hip_gemm_impl.cpp b/src/targets/gpu/hip_gemm_impl.cpp index 7dda04c4e9a..be82e569dc4 100644 --- a/src/targets/gpu/hip_gemm_impl.cpp +++ b/src/targets/gpu/hip_gemm_impl.cpp @@ -73,6 +73,7 @@ static hipDataType get_type_hipblas(shape::type_t type) case shape::int64_type: case shape::uint64_type: MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!"); case shape::bf16_type: return HIP_R_16BF; + case shape::fp8e8m0_type: break; /* unsupported: reaches the throw after the switch */ } MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!"); diff --git a/test/fp8e8m0.cpp b/test/fp8e8m0.cpp index dcc2e44b59f..c1422938da6 100644 --- a/test/fp8e8m0.cpp +++ b/test/fp8e8m0.cpp @@ -41,3 +41,185 @@ static bool bit_equal(const T& x, const U& y) using type = std::array; return migraphx::bit_cast(x) == migraphx::bit_cast(y); } + +TEST_CASE(check_numeric_limits) +{ + CHECK(bit_equal(std::numeric_limits::min(), uint8_t{0x00})); + CHECK(bit_equal(std::numeric_limits::lowest(), uint8_t{0x00})); + CHECK(bit_equal(std::numeric_limits::max(), uint8_t{0xfe})); + CHECK(bit_equal(std::numeric_limits::epsilon(), uint8_t{0x7f})); + CHECK(bit_equal(std::numeric_limits::denorm_min(), uint8_t{0x00})); + CHECK(bit_equal(std::numeric_limits::infinity(), uint8_t{0x7f})); + CHECK(bit_equal(std::numeric_limits::quiet_NaN(), uint8_t{0xff})); + CHECK(bit_equal(std::numeric_limits::signaling_NaN(), uint8_t{0xff})); +} + +const static std::map& fp8e8m0_lut() +{ + static const std::map result = { + {5.877471754111437539843683E-39, 0x00}, + {-5.877471754111437539843683E-39, 0x00}, + {1.175494350822287507968737E-38, 0x01}, + {-1.175494350822287507968737E-38, 0x01}, + {7.888609052210118054117286E-31, 0x1b}, + {-7.888609052210118054117286E-31, 0x1b}, + {1.355252715606880542509316E-20, 0x3d}, + {-1.355252715606880542509316E-20, 0x3d}, + {0.0625, 0x7b}, + {-0.0625, 0x7b}, + {0.125,
0x7c}, + {-0.125, 0x7c}, + {0.25, 0x7d}, + {-0.25, 0x7d}, + {0.5, 0x7e}, + {-0.5, 0x7e}, + {1., 0x7f}, + {-1., 0x7f}, + {2., 0x80}, + {-2., 0x80}, + {4., 0x81}, + {-4., 0x81}, + {8., 0x82}, + {-8., 0x82}, + {16., 0x83}, + {-16., 0x83}, + {268435456, 0x9b}, + {-268435456, 0x9b}, + {4611686018427387904, 0xbd}, + {-4611686018427387904, 0xbd}, + {1.701411834604692317316873E+38, 0xfe}, + {-1.701411834604692317316873E+38, 0xfe}, + {std::numeric_limits::quiet_NaN(), 0xff}}; + return result; +} +TEST_CASE(check_fp8e8m0_values) +{ + for(auto [f, x] : fp8e8m0_lut()) + { + + auto h = migraphx::bit_cast(x); + if(std::isnan(f)) + { + CHECK(std::isnan(h)); + } + else + { + CHECK(bit_equal(x, migraphx::fp8e8m0(f))); + // absolute value because fp8e8m0 only takes exponent bits + // and has no negative numbers + CHECK(migraphx::float_equal(float(h), std::abs(f))); + } + } +} + +TEST_CASE(check_flows) +{ + // underflow, lowest possible value is 2^(-127) + CHECK(bit_equal(std::numeric_limits::min() * + std::numeric_limits::min(), + std::numeric_limits::min())); + // overflow + CHECK(bit_equal(std::numeric_limits::max() + + std::numeric_limits::max(), + std::numeric_limits::signaling_NaN())); + CHECK(bit_equal(std::numeric_limits::max() + + std::numeric_limits::min(), + std::numeric_limits::max())); +} + +TEST_CASE(test_nan) +{ + float f_qnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8e8m0 fp8e8m0_qnan(f_qnan); + EXPECT(fp8e8m0_qnan.is_nan()); + EXPECT(std::isnan(fp8e8m0_qnan)); + + float f_snan = std::numeric_limits::signaling_NaN(); + migraphx::fp8e8m0 fp8e8m0_snan(f_snan); + EXPECT(fp8e8m0_snan.is_nan()); + EXPECT(std::isnan(fp8e8m0_snan)); +} + +TEST_CASE(test_bool) +{ + // there is no zero value in E8M0 + float two = 2.0; + float other = 0.125; + migraphx::fp8e8m0 fp8e8m0_two(two); + migraphx::fp8e8m0 fp8e8m0_other(other); + EXPECT(static_cast(std::numeric_limits::min())); + EXPECT(static_cast(fp8e8m0_two)); + EXPECT(static_cast(fp8e8m0_other)); +} + 
+TEST_CASE(test_pos_infinity) +{ + float finf = std::numeric_limits::infinity(); + CHECK(bit_equal(migraphx::fp8e8m0(finf), std::numeric_limits::quiet_NaN())); +} + +TEST_CASE(test_neg_infinity) +{ + float finf = -1.0 * std::numeric_limits::infinity(); + CHECK(bit_equal(migraphx::fp8e8m0(finf), std::numeric_limits::quiet_NaN())); +} + +TEST_CASE(test_f32_max) +{ + float fmax = std::numeric_limits::max(); + CHECK(bit_equal(migraphx::fp8e8m0(fmax), std::numeric_limits::max())); +} + +TEST_CASE(test_f32_denorm_min) +{ + float fmin = std::numeric_limits::denorm_min(); + CHECK(bit_equal(migraphx::fp8e8m0(fmin), std::numeric_limits::min())); +} + +TEST_CASE(test_f32_lowest) +{ + float flowest = std::numeric_limits::lowest(); + CHECK(bit_equal(migraphx::fp8e8m0(flowest), std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8e8m0(0.0))); + EXPECT(std::isfinite(migraphx::fp8e8m0(-0.0))); + EXPECT( + not std::isfinite(migraphx::fp8e8m0(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8e8m0(1.0); + auto b = migraphx::fp8e8m0(2.0); + auto c = migraphx::fp8e8m0(4.0); + EXPECT(migraphx::float_equal((a + a), b)); + EXPECT(migraphx::float_equal((b + b), c)); + EXPECT(migraphx::float_equal((a + b), b)); + + auto e = migraphx::fp8e8m0(8.0); + auto f = migraphx::fp8e8m0(0.125); + EXPECT(e > f); + EXPECT(f < e); + EXPECT(f <= e); + EXPECT(e >= f); + EXPECT(e <= e); + EXPECT(f >= f); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8e8m0(4.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("4") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); }