|
9 | 9 | #include <cmath> |
10 | 10 | #include <cstring> |
11 | 11 | #include <limits> |
| 12 | +#include "GPUCommonDef.h" |
12 | 13 |
|
13 | 14 | namespace o2 |
14 | 15 | { |
@@ -43,7 +44,7 @@ static_assert( |
43 | 44 | /// Shared implementation between public and internal classes. CRTP pattern. |
44 | 45 | /// </summary> |
45 | 46 | template <class Derived> |
46 | | -struct Float16Impl { |
| 47 | +GPUd() struct Float16Impl { |
47 | 48 | protected: |
48 | 49 | /// <summary> |
49 | 50 | /// Converts from float to uint16_t float16 representation |
@@ -267,7 +268,7 @@ union float32_bits { |
267 | 268 | }; // namespace detail |
268 | 269 |
|
269 | 270 | template <class Derived> |
270 | | -inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept |
| 271 | +GPUd() inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept |
271 | 272 | { |
272 | 273 | detail::float32_bits f{}; |
273 | 274 | f.f = v; |
@@ -316,7 +317,7 @@ inline constexpr uint16_t Float16Impl<Derived>::ToUint16Impl(float v) noexcept |
316 | 317 | } |
317 | 318 |
|
318 | 319 | template <class Derived> |
319 | | -inline float Float16Impl<Derived>::ToFloatImpl() const noexcept |
| 320 | +GPUd() inline float Float16Impl<Derived>::ToFloatImpl() const noexcept |
320 | 321 | { |
321 | 322 | constexpr detail::float32_bits magic = {113 << 23}; |
322 | 323 | constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift |
@@ -349,7 +350,7 @@ inline float Float16Impl<Derived>::ToFloatImpl() const noexcept |
349 | 350 |
|
350 | 351 | /// Shared implementation between public and internal classes. CRTP pattern. |
351 | 352 | template <class Derived> |
352 | | -struct BFloat16Impl { |
| 353 | +GPUd() struct BFloat16Impl { |
353 | 354 | protected: |
354 | 355 | /// <summary> |
355 | 356 | /// Converts from float to uint16_t float16 representation |
@@ -520,7 +521,7 @@ struct BFloat16Impl { |
520 | 521 | }; |
521 | 522 |
|
522 | 523 | template <class Derived> |
523 | | -inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept |
| 524 | +GPUd() inline uint16_t BFloat16Impl<Derived>::ToUint16Impl(float v) noexcept |
524 | 525 | { |
525 | 526 | uint16_t result; |
526 | 527 | if (std::isnan(v)) { |
@@ -595,7 +596,7 @@ inline float BFloat16Impl<Derived>::ToFloatImpl() const noexcept |
595 | 596 | * |
596 | 597 | * \endcode |
597 | 598 | */ |
598 | | -struct Float16_t : OrtDataType::Float16Impl<Float16_t> { |
| 599 | +GPUd() struct Float16_t : OrtDataType::Float16Impl<Float16_t> { |
599 | 600 | private: |
600 | 601 | /// <summary> |
601 | 602 | /// Constructor from a 16-bit representation of a float16 value |
@@ -737,7 +738,7 @@ static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match"); |
737 | 738 | * |
738 | 739 | * \endcode |
739 | 740 | */ |
740 | | -struct BFloat16_t : OrtDataType::BFloat16Impl<BFloat16_t> { |
| 741 | +GPUd() struct BFloat16_t : OrtDataType::BFloat16Impl<BFloat16_t> { |
741 | 742 | private: |
742 | 743 | /// <summary> |
743 | 744 | /// Constructor from a uint16_t representation of bfloat16 |
|
0 commit comments