diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
index 258b46d6..ff150f24 100644
--- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
+++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
@@ -112,15 +112,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
 end
 # COV_EXCL_STOP
 
+const SupportedROCMatrix{T} = Union{AnyROCMatrix{T}, SubArray{T, 2, <:AnyROCMatrix{T}}}
+
 function MatrixAlgebraKit._project_hermitian_offdiag!(
-        Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
+        Au::SupportedROCMatrix, Al::SupportedROCMatrix, Bu::SupportedROCMatrix, Bl::SupportedROCMatrix, ::Val{anti}
     ) where {anti}
     thread_dim = 512
     block_dim = cld(size(Au, 2), thread_dim)
     @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
     return nothing
 end
-function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
+function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedROCMatrix, B::SupportedROCMatrix, ::Val{anti}) where {anti}
     thread_dim = 512
     block_dim = cld(size(A, 1), thread_dim)
     @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
index 189f5825..4d34dd9e 100644
--- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
+++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -136,15 +136,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false})
 end
 # COV_EXCL_STOP
 
+const SupportedCuMatrix{T} = Union{AnyCuMatrix{T}, SubArray{T, 2, <:AnyCuMatrix{T}}}
+
 function MatrixAlgebraKit._project_hermitian_offdiag!(
-        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
+        Au::SupportedCuMatrix, Al::SupportedCuMatrix, Bu::SupportedCuMatrix, Bl::SupportedCuMatrix, ::Val{anti}
     ) where {anti}
     thread_dim = 512
     block_dim = cld(size(Au, 2), thread_dim)
     @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
     return nothing
 end
-function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
+function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedCuMatrix, B::SupportedCuMatrix, ::Val{anti}) where {anti}
     thread_dim = 512
     block_dim = cld(size(A, 1), thread_dim)
     @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))