From 932a23a3c8072a675e72be69abc1e3f48a4a81c9 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 18 Aug 2025 05:49:17 -0700 Subject: [PATCH 1/4] Extend getrs with trans_code argument --- dpnp/backend/extensions/lapack/getrs.cpp | 20 +++++++++++++++----- dpnp/backend/extensions/lapack/getrs.hpp | 1 + dpnp/backend/extensions/lapack/lapack_py.cpp | 3 ++- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index b7ac5311cb34..81aea21af18f 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -166,6 +166,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + const int trans_code, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -264,11 +265,20 @@ std::pair const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - // Use transpose::T if the LU-factorized array is passed as C-contiguous. - // For F-contiguous we use transpose::N. - oneapi::mkl::transpose trans = is_a_array_c_contig - ? oneapi::mkl::transpose::T - : oneapi::mkl::transpose::N; + oneapi::mkl::transpose trans; + switch (trans_code) { + case 0: + trans = oneapi::mkl::transpose::N; + break; + case 1: + trans = oneapi::mkl::transpose::T; + break; + case 2: + trans = oneapi::mkl::transpose::C; + break; + default: + throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)"); + } char *a_array_data = a_array.get_data(); char *b_array_data = b_array.get_data(); diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 8fa4889c99af..30db88c62fe4 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -37,6 +37,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, + const int trans_code, const std::vector &depends = {}); extern void init_getrs_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 4d5adfe09e4a..9dc22419e572 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -160,7 +160,8 @@ PYBIND11_MODULE(_lapack_impl, m) "the solves of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("b_array"), py::arg("depends") = py::list()); + py::arg("b_array"), py::arg("trans_code"), + py::arg("depends") = py::list()); m.def("_orgqr_batch", &lapack_ext::orgqr_batch, "Call `_orgqr_batch` from OneMKL LAPACK library to return " From 5b5a2a5f12f8bdf8c05b8e73e04448cf504c6d3b Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 18 Aug 2025 05:55:46 -0700 Subject: [PATCH 2/4] Pass trans_code to getrs in dpnp_solve() --- dpnp/linalg/dpnp_utils_linalg.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 55d140c5c88f..1a7d452935ab 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2632,6 +2632,12 @@ def dpnp_solve(a, b): _manager = dpu.SequentialOrderManager[exec_q] dev_evs = _manager.submitted_events + # TODO: remove after PR #2558 is merged + # Temporarily set trans_code=1 (transpose) because the LU-factorized + # array is C-contiguous. + # For F-contiguous arrays use 0 (non-transpose) + trans_code = 1 + # use DPCTL tensor function to fill the сopy of the input array # from the input array ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( @@ -2688,6 +2694,7 @@ def dpnp_solve(a, b): a_h.get_array(), ipiv_h.get_array(), b_h.get_array(), + trans_code, depends=[b_copy_ev, getrf_ev], ) _manager.add_event_pair(ht_ev, getrs_ev) From ce878e6079c57d553ba15163a5e2d7ce6949fe84 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 4 Sep 2025 07:22:10 -0700 Subject: [PATCH 3/4] Remove TODO --- dpnp/linalg/dpnp_utils_linalg.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 6e9037c3b356..fdf46174bfce 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -2875,12 +2875,6 @@ def dpnp_solve(a, b): _manager = dpu.SequentialOrderManager[exec_q] dep_evs = _manager.submitted_events - # TODO: remove after PR #2558 is merged - # Temporarily set trans_code=1 (transpose) because the LU-factorized - # array is C-contiguous. - # For F-contiguous arrays use 0 (non-transpose) - trans_code = 1 - # use DPCTL tensor function to fill the сopy of the input array # from the input array ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( From bfb82b69785bc7c47faaa3fd61877031679aa87f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 12 Sep 2025 03:23:32 -0700 Subject: [PATCH 4/4] Expose Transpose enum to Python via pybind11 --- dpnp/backend/extensions/lapack/getrs.cpp | 17 +---------------- dpnp/backend/extensions/lapack/getrs.hpp | 2 +- dpnp/backend/extensions/lapack/lapack_py.cpp | 9 ++++++++- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/dpnp/backend/extensions/lapack/getrs.cpp b/dpnp/backend/extensions/lapack/getrs.cpp index 81aea21af18f..8185f3a06a79 100644 --- a/dpnp/backend/extensions/lapack/getrs.cpp +++ b/dpnp/backend/extensions/lapack/getrs.cpp @@ -166,7 +166,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, - const int trans_code, + oneapi::mkl::transpose trans, const std::vector &depends) { const int a_array_nd = a_array.get_ndim(); @@ -265,21 +265,6 @@ std::pair const std::int64_t lda = std::max(1UL, n); const std::int64_t ldb = std::max(1UL, n); - oneapi::mkl::transpose trans; - switch (trans_code) { - case 0: - trans = oneapi::mkl::transpose::N; - break; - case 1: - trans = oneapi::mkl::transpose::T; - break; - case 2: - trans = oneapi::mkl::transpose::C; - break; - default: - throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)"); - } - char *a_array_data = a_array.get_data(); char *b_array_data = b_array.get_data(); char *ipiv_array_data = ipiv_array.get_data(); diff --git a/dpnp/backend/extensions/lapack/getrs.hpp b/dpnp/backend/extensions/lapack/getrs.hpp index 30db88c62fe4..d8952f3f0b3f 100644 --- a/dpnp/backend/extensions/lapack/getrs.hpp +++ b/dpnp/backend/extensions/lapack/getrs.hpp @@ -37,7 +37,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, const dpctl::tensor::usm_ndarray &b_array, - const int trans_code, + oneapi::mkl::transpose trans, const std::vector &depends = {}); extern void init_getrs_dispatch_vector(void); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index 2020db4b5a0e..46471cc2f366 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -76,6 +76,13 @@ void init_dispatch_tables(void) PYBIND11_MODULE(_lapack_impl, m) { + // Expose oneMKL transpose enum to Python + py::enum_(m, "Transpose") + .value("N", oneapi::mkl::transpose::N) + .value("T", oneapi::mkl::transpose::T) + .value("C", oneapi::mkl::transpose::C) + .export_values(); // Optional, allows access like `Transpose.N` + // Register a custom LinAlgError exception in the dpnp.linalg submodule py::module_ linalg_module = py::module_::import("dpnp.linalg"); py::register_exception( @@ -160,7 +167,7 @@ PYBIND11_MODULE(_lapack_impl, m) "the solves of linear equations with an LU-factored " "square coefficient matrix, with multiple right-hand sides", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("b_array"), py::arg("trans_code"), + py::arg("b_array"), py::arg("trans") = oneapi::mkl::transpose::N, py::arg("depends") = py::list()); m.def("_orgqr_batch", &lapack_ext::orgqr_batch,