From 5378c2eb2c4d542a1cbb4a82ca007470c46aeecc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 2 Dec 2025 13:03:27 -0500 Subject: [PATCH 1/2] Support Subarray{<:Adjoint{<:GPUMatrix}} --- .../MatrixAlgebraKitAMDGPUExt.jl | 18 ++++++++++-------- .../MatrixAlgebraKitCUDAExt.jl | 16 +++++++++------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 258b46d6..3576e456 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -112,31 +112,33 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false}) end # COV_EXCL_STOP +const SupportedROCMatrix{T} = Union{AnyROCMatrix{T}, SubArray{T, 2, <:AnyROCMatrix{T}}} + function MatrixAlgebraKit._project_hermitian_offdiag!( - Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti} + Au::SupportedROCMatrix, Al::SupportedROCMatrix, Bu::SupportedROCMatrix, Bl::SupportedROCMatrix, ::Val{anti} ) where {anti} thread_dim = 512 block_dim = cld(size(Au, 2), thread_dim) @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti)) return nothing end -function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti} +function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedROCMatrix, B::SupportedROCMatrix, ::Val{anti}) where {anti} thread_dim = 512 block_dim = cld(size(A, 1), thread_dim) @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti)) return nothing end -# avoids calling the `StridedMatrix` specialization to avoid scalar indexing, +# avoids calling the `SupportedMatrix` specialization to avoid scalar indexing, # use (allocating) fallback instead until we write a dedicated kernel -MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A' -MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.ishermitian_exact(A::SupportedROCMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) = norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A' -MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.isantihermitian_exact(A::SupportedROCMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) = norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) +function MatrixAlgebraKit._avgdiff!(A::SupportedROCMatrix, B::SupportedROCMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) # COV_EXCL_START function _avgdiff_kernel(A, B) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 189f5825..827ecd59 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -136,15 +136,17 @@ function _project_hermitian_diag_kernel(A, B, ::Val{false}) end # COV_EXCL_STOP +const SupportedCuMatrix{T} = Union{AnyCuMatrix{T}, SubArray{T, 2, <:AnyCuMatrix{T}}} + function MatrixAlgebraKit._project_hermitian_offdiag!( - Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti} + Au::SupportedCuMatrix, Al::SupportedCuMatrix, Bu::SupportedCuMatrix, Bl::SupportedCuMatrix, ::Val{anti} ) where {anti} thread_dim = 512 block_dim = cld(size(Au, 2), thread_dim) @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti)) return nothing end -function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti} +function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedCuMatrix, B::SupportedCuMatrix, ::Val{anti}) where {anti} thread_dim = 512 block_dim = cld(size(A, 1), thread_dim) @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti)) @@ -153,14 +155,14 @@ end # avoids calling the `StridedMatrix` specialization to avoid scalar indexing, # use (allocating) fallback instead until we write a dedicated kernel -MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A' -MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.ishermitian_exact(A::SupportedCuMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) = norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A' -MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.isantihermitian_exact(A::SupportedCuMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) = norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) +function MatrixAlgebraKit._avgdiff!(A::SupportedCuMatrix, B::SupportedCuMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) # COV_EXCL_START function _avgdiff_kernel(A, B) From d3b130045c94812d9a1117d8fafedf9a1a971f76 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 2 Dec 2025 13:24:12 -0500 Subject: [PATCH 2/2] Restore Strided --- .../MatrixAlgebraKitAMDGPUExt.jl | 12 ++++++------ .../MatrixAlgebraKitCUDAExt.jl | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 3576e456..ff150f24 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -129,16 +129,16 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::SupportedROCMatrix, B::Sup return nothing end -# avoids calling the `SupportedMatrix` specialization to avoid scalar indexing, +# avoids calling the `StridedMatrix` specialization to avoid scalar indexing, # use (allocating) fallback instead until we write a dedicated kernel -MatrixAlgebraKit.ishermitian_exact(A::SupportedROCMatrix) = A == A' -MatrixAlgebraKit.ishermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -MatrixAlgebraKit.isantihermitian_exact(A::SupportedROCMatrix) = A == -A' -MatrixAlgebraKit.isantihermitian_approx(A::SupportedROCMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -function MatrixAlgebraKit._avgdiff!(A::SupportedROCMatrix, B::SupportedROCMatrix) +function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) # COV_EXCL_START function _avgdiff_kernel(A, B) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 827ecd59..4d34dd9e 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -155,14 +155,14 @@ end # avoids calling the `StridedMatrix` specialization to avoid scalar indexing, # use (allocating) fallback instead until we write a dedicated kernel -MatrixAlgebraKit.ishermitian_exact(A::SupportedCuMatrix) = A == A' -MatrixAlgebraKit.ishermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -MatrixAlgebraKit.isantihermitian_exact(A::SupportedCuMatrix) = A == -A' -MatrixAlgebraKit.isantihermitian_approx(A::SupportedCuMatrix; atol, rtol, kwargs...) = +MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) -function MatrixAlgebraKit._avgdiff!(A::SupportedCuMatrix, B::SupportedCuMatrix) +function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) # COV_EXCL_START function _avgdiff_kernel(A, B)