diff --git a/CHANGELOG.md b/CHANGELOG.md index b53ae6ef8a26..e2bf08ff9af8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Improved performance of `dpnp.isclose` function by implementing a dedicated kernel for scalar `rtol` and `atol` arguments [#2540](https://github.com/IntelPython/dpnp/pull/2540) * Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535) * Redesigned `dpnp.erf` function through pybind11 extension of OneMKL call or dedicated kernel in `ufunc` namespace [#2551](https://github.com/IntelPython/dpnp/pull/2551) +* Improved performance of batched implementation of `dpnp.linalg.det` and `dpnp.linalg.slogdet` [#2572](https://github.com/IntelPython/dpnp/pull/2572) ### Deprecated diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index fdf46174bfce..bb2920e3a999 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -297,26 +297,27 @@ def _batched_lu_factor(a, res_type): batch_size = a.shape[0] a_usm_arr = dpnp.get_usm_ndarray(a) + # `a` must be copied because getrf/getrf_batch destroys the input matrix + a_h = dpnp.empty_like(a, order="C", dtype=res_type) + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + ipiv_h = dpnp.empty( + (batch_size, n), + dtype=dpnp.int64, + order="C", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + if use_batch: - # `a` must be copied because getrf_batch destroys the input matrix - a_h = dpnp.empty_like(a, order="C", dtype=res_type) - ipiv_h = dpnp.empty( - (batch_size, n), - dtype=dpnp.int64, - order="C", - usm_type=a_usm_type, - sycl_queue=a_sycl_queue, - ) dev_info_h = [0] * batch_size - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr, - dst=a_h.get_array(), - sycl_queue=a_sycl_queue, - depends=_manager.submitted_events, - ) - _manager.add_event_pair(ht_ev, copy_ev) - ipiv_stride = n a_stride = a_h.strides[0] @@ -336,63 +337,25 @@ def _batched_lu_factor(a, res_type): ) _manager.add_event_pair(ht_ev, getrf_ev) - dev_info_array = dpnp.array( - dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue - ) - - # Reshape the results back to their original shape - a_h = a_h.reshape(orig_shape) - ipiv_h = ipiv_h.reshape(orig_shape[:-1]) - dev_info_array = dev_info_array.reshape(orig_shape[:-2]) - - return (a_h, ipiv_h, dev_info_array) - - # Initialize lists for storing arrays and events for each batch - a_vecs = [None] * batch_size - ipiv_vecs = [None] * batch_size - dev_info_vecs = [None] * batch_size - - dep_evs = _manager.submitted_events - - # Process each batch - for i in range(batch_size): - # Copy each 2D slice to a new array because getrf will destroy - # the input matrix - a_vecs[i] = dpnp.empty_like(a[i], order="C", dtype=res_type) - - ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( - src=a_usm_arr[i], - dst=a_vecs[i].get_array(), - sycl_queue=a_sycl_queue, - depends=dep_evs, - ) - _manager.add_event_pair(ht_ev, copy_ev) - - ipiv_vecs[i] = dpnp.empty( - (n,), - dtype=dpnp.int64, - order="C", - usm_type=a_usm_type, - sycl_queue=a_sycl_queue, - ) - dev_info_vecs[i] = [0] + else: + dev_info_h = [[0] for _ in range(batch_size)] - # Call the LAPACK extension function _getrf - # to perform LU decomposition on each batch in 'a_vecs[i]' - ht_ev, getrf_ev = li._getrf( - a_sycl_queue, - a_vecs[i].get_array(), - ipiv_vecs[i].get_array(), - dev_info_vecs[i], - depends=[copy_ev], - ) - _manager.add_event_pair(ht_ev, getrf_ev) + # Sequential LU factorization using getrf per slice + for i in range(batch_size): + ht_ev, getrf_ev = li._getrf( + a_sycl_queue, + a_h[i].get_array(), + ipiv_h[i].get_array(), + dev_info_h[i], + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) # Reshape the results back to their original shape - out_a = dpnp.array(a_vecs, order="C").reshape(orig_shape) - out_ipiv = dpnp.array(ipiv_vecs).reshape(orig_shape[:-1]) + out_a = a_h.reshape(orig_shape) + out_ipiv = ipiv_h.reshape(orig_shape[:-1]) out_dev_info = dpnp.array( - dev_info_vecs, usm_type=a_usm_type, sycl_queue=a_sycl_queue + dev_info_h, usm_type=a_usm_type, sycl_queue=a_sycl_queue ).reshape(orig_shape[:-2]) return (out_a, out_ipiv, out_dev_info)