From 0a447add03b5194f83505f31bb31ec9e7fa4f175 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 25 Apr 2025 15:22:49 -0700 Subject: [PATCH 1/2] Update implementations of `real` and `imag` `imag` uses static constant value of 0 for real inputs and both use sycl complex extension --- .../include/kernels/elementwise_functions/imag.hpp | 13 ++++++++----- .../include/kernels/elementwise_functions/real.hpp | 5 ++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 89adabff41..2b40d1d0aa 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -56,11 +57,11 @@ using dpctl::tensor::type_utils::is_complex; template struct ImagFunctor { - // is function constant for given argT - using is_constant = typename std::false_type; + using is_constant = + typename std::is_same, std::false_type>; // constant value, if constant - // constexpr resT constant_value = resT{}; + static constexpr resT constant_value = resT{0}; // is function defined for sycl::vec using supports_vec = typename std::false_type; // do both argTy and resTy support sugroup store/load operation @@ -70,11 +71,13 @@ template struct ImagFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - return std::imag(in); + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::imag(sycl_complexT(in)); } else { static_assert(std::is_same_v); - return resT{0}; + return constant_value; } } }; diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index bb22352907..2b661357b7 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -31,6 +31,7 @@ #include #include +#include "sycl_complex.hpp" #include "vec_size_util.hpp" #include "kernels/dpctl_tensor_types.hpp" @@ -70,7 +71,9 @@ template struct RealFunctor resT operator()(const argT &in) const { if constexpr (is_complex::value) { - return std::real(in); + using realT = typename argT::value_type; + using sycl_complexT = typename exprm_ns::complex; + return exprm_ns::real(sycl_complexT(in)); } else { static_assert(std::is_same_v); From 7564d1c4734c5068aeeebee2adcedebddce6928a Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sun, 27 Apr 2025 09:28:04 -0700 Subject: [PATCH 2/2] Use `is_complex_v` in real and imag --- .../libtensor/include/kernels/elementwise_functions/imag.hpp | 3 ++- .../libtensor/include/kernels/elementwise_functions/real.hpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp index 2b40d1d0aa..0fa432546e 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/imag.hpp @@ -54,6 +54,7 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct ImagFunctor { @@ -70,7 +71,7 @@ template struct ImagFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { + if constexpr (is_complex_v) { using realT = typename argT::value_type; using sycl_complexT = typename exprm_ns::complex; return exprm_ns::imag(sycl_complexT(in)); diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp index 2b661357b7..04ed3a6e49 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/real.hpp @@ -54,6 +54,7 @@ using dpctl::tensor::ssize_t; namespace td_ns = dpctl::tensor::type_dispatch; using dpctl::tensor::type_utils::is_complex; +using dpctl::tensor::type_utils::is_complex_v; template struct RealFunctor { @@ -70,7 +71,7 @@ template struct RealFunctor resT operator()(const argT &in) const { - if constexpr (is_complex::value) { + if constexpr (is_complex_v) { using realT = typename argT::value_type; using sycl_complexT = typename exprm_ns::complex; return exprm_ns::real(sycl_complexT(in));