From d87d681b59a3f25371da003c47f7e3d97232a2c1 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Sun, 9 Mar 2025 19:08:45 -0700 Subject: [PATCH] simplify dot --- dpnp/backend/extensions/blas/blas_py.cpp | 9 +++------ dpnp/backend/extensions/blas/dot_common.hpp | 11 ++++------- dpnp/backend/extensions/window/common.hpp | 3 +-- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/dpnp/backend/extensions/blas/blas_py.cpp b/dpnp/backend/extensions/blas/blas_py.cpp index 06f614818f7..0321ff6fc6b 100644 --- a/dpnp/backend/extensions/blas/blas_py.cpp +++ b/dpnp/backend/extensions/blas/blas_py.cpp @@ -62,8 +62,7 @@ PYBIND11_MODULE(_blas_impl, m) using event_vecT = std::vector; { - dot_ns::init_dot_dispatch_vector( + dot_ns::init_dot_dispatch_vector( dot_dispatch_vector); auto dot_pyapi = [&](sycl::queue &exec_q, const arrayT &src1, @@ -81,8 +80,7 @@ PYBIND11_MODULE(_blas_impl, m) } { - dot_ns::init_dot_dispatch_vector( + dot_ns::init_dot_dispatch_vector( dotc_dispatch_vector); auto dotc_pyapi = [&](sycl::queue &exec_q, const arrayT &src1, @@ -101,8 +99,7 @@ PYBIND11_MODULE(_blas_impl, m) } { - dot_ns::init_dot_dispatch_vector( + dot_ns::init_dot_dispatch_vector( dotu_dispatch_vector); auto dotu_pyapi = [&](sycl::queue &exec_q, const arrayT &src1, diff --git a/dpnp/backend/extensions/blas/dot_common.hpp b/dpnp/backend/extensions/blas/dot_common.hpp index 9e35e50dab8..fb9a1f078c5 100644 --- a/dpnp/backend/extensions/blas/dot_common.hpp +++ b/dpnp/backend/extensions/blas/dot_common.hpp @@ -50,14 +50,13 @@ typedef sycl::event (*dot_impl_fn_ptr_t)(sycl::queue &, namespace dpctl_td_ns = dpctl::tensor::type_dispatch; namespace py = pybind11; -template std::pair dot_func(sycl::queue &exec_q, const dpctl::tensor::usm_ndarray &vectorX, const dpctl::tensor::usm_ndarray &vectorY, const dpctl::tensor::usm_ndarray &result, const std::vector &depends, - const dispatchT &dot_dispatch_vector) + const dot_impl_fn_ptr_t *dot_dispatch_vector) { const int vectorX_nd = vectorX.get_ndim(); const int vectorY_nd = vectorY.get_ndim(); @@ -166,12 +165,10 @@ std::pair return std::make_pair(args_ev, dot_ev); } -template - typename factoryT> -void init_dot_dispatch_vector(dispatchT dot_dispatch_vector[]) +template