From d6779ee22259129cf80ef82799d8dbf22f1e89a4 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 27 Mar 2025 12:53:11 -0700 Subject: [PATCH 1/3] add support for boolean dtypes for dpt.ceil, dpt.floor, and dpt.trunc --- .../include/kernels/elementwise_functions/ceil.hpp | 3 ++- .../include/kernels/elementwise_functions/floor.hpp | 3 ++- .../include/kernels/elementwise_functions/trunc.hpp | 3 ++- dpctl/tests/elementwise/test_floor_ceil_trunc.py | 10 +++++----- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp index 1328df3f4b..c587ba6767 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp @@ -99,7 +99,8 @@ using CeilStridedFunctor = elementwise_common:: template struct CeilOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp index aaa81b77b9..5bc6b888ef 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp @@ -99,7 +99,8 @@ using FloorStridedFunctor = elementwise_common:: template struct FloorOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp index 008c5f59b1..4c776e3560 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp @@ -96,7 +96,8 @@ using TruncStridedFunctor = elementwise_common:: template struct TruncOutputType { using value_type = - typename std::disjunction, + typename std::disjunction, + td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, td_ns::TypeMapResultEntry, diff --git a/dpctl/tests/elementwise/test_floor_ceil_trunc.py b/dpctl/tests/elementwise/test_floor_ceil_trunc.py index a6bf956a78..20bb739b2c 100644 --- a/dpctl/tests/elementwise/test_floor_ceil_trunc.py +++ b/dpctl/tests/elementwise/test_floor_ceil_trunc.py @@ -24,13 +24,13 @@ import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported -from .utils import _map_to_device_dtype, _real_value_dtypes +from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_value_dtypes _all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)] @pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc]) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_out_type(dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -69,7 +69,7 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_order(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -102,7 +102,7 @@ def test_floor_ceil_trunc_error_dtype(dpt_call, dtype): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) @@ -123,7 +123,7 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype): @pytest.mark.parametrize("np_call, dpt_call", _all_funcs) -@pytest.mark.parametrize("dtype", _real_value_dtypes) +@pytest.mark.parametrize("dtype", _no_complex_dtypes) def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype): q = get_queue_or_skip() skip_if_dtype_not_supported(dtype, q) From 2fd8eeb632b587579a14da983b96f44d722ada27 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Thu, 27 Mar 2025 12:58:17 -0700 Subject: [PATCH 2/3] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d2fdd39a8..193ecae948 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +* Support for Boolean data-type is added to `dpctl.tensor.ceil`, `dpctl.tensor.floor`, and `dpctl.tensor.trunc` [gh-2033](https://github.com/IntelPython/dpctl/pull/2033) + ### Fixed ## [0.19.0] - Feb. 26, 2025 From 9a79047824de8df4e147b8171bf0c5b28e24d56a Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 31 Mar 2025 08:24:40 -0700 Subject: [PATCH 3/3] update docstring --- dpctl/tensor/_elementwise_funcs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dpctl/tensor/_elementwise_funcs.py b/dpctl/tensor/_elementwise_funcs.py index 001a53ab35..4731ec8631 100644 --- a/dpctl/tensor/_elementwise_funcs.py +++ b/dpctl/tensor/_elementwise_funcs.py @@ -528,7 +528,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type. @@ -767,7 +767,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type. @@ -2017,7 +2017,7 @@ Args: x (usm_ndarray): - Input array, expected to have a real-valued data type. + Input array, expected to have a boolean or real-valued data type. out (Union[usm_ndarray, None], optional): Output array to populate. Array must have the correct shape and the expected data type.