diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
index fc3bd7ca..e1d24de9 100644
--- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
+++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
@@ -52,4 +52,119 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::
     return MatrixAlgebraKit.findtruncated(values, strategy)
 end
 
+# COV_EXCL_START
+# Device kernels: each work-item handles one column index `j`; work-items whose
+# index lies past the last column return immediately.
+# Val{true}: antihermitian part, Bu = (Au - Al') / 2 and Bl = -Bu'.
+function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
+    m, n = size(Au)
+    j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
+    j > n && return
+    for i in 1:m
+        @inbounds begin
+            val = (Au[i, j] - adjoint(Al[j, i])) / 2
+            Bu[i, j] = val
+            Bl[j, i] = -adjoint(val)
+        end
+    end
+    return
+end
+
+# Val{false}: hermitian part, Bu = (Au + Al') / 2 and Bl = Bu'.
+function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
+    m, n = size(Au)
+    j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
+    j > n && return
+    for i in 1:m
+        @inbounds begin
+            val = (Au[i, j] + adjoint(Al[j, i])) / 2
+            Bu[i, j] = val
+            Bl[j, i] = adjoint(val)
+        end
+    end
+    return
+end
+
+# Diagonal-block kernels: work-item j writes the pairs (i, j), (j, i) for i < j and
+# the entry (j, j), so every entry is written by exactly one work-item.
+# Val{true}: antihermitian part; the diagonal entry is set via `MatrixAlgebraKit._imimag`.
+function _project_hermitian_diag_kernel(A, B, ::Val{true})
+    n = size(A, 1)
+    j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
+    j > n && return
+    @inbounds begin
+        for i in 1:(j - 1)
+            val = (A[i, j] - adjoint(A[j, i])) / 2
+            B[i, j] = val
+            B[j, i] = -adjoint(val)
+        end
+        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
+    end
+    return
+end
+
+# Val{false}: hermitian part; the diagonal entry is replaced by its real part.
+function _project_hermitian_diag_kernel(A, B, ::Val{false})
+    n = size(A, 1)
+    j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
+    j > n && return
+    @inbounds begin
+        for i in 1:(j - 1)
+            val = (A[i, j] + adjoint(A[j, i])) / 2
+            B[i, j] = val
+            B[j, i] = adjoint(val)
+        end
+        B[j, j] = real(A[j, j])
+    end
+    return
+end
+# COV_EXCL_STOP
+
+# Launch the off-diagonal kernel: 512 work-items per group, one work-item per column of `Au`.
+function MatrixAlgebraKit._project_hermitian_offdiag!(
+        Au::StridedROCMatrix, Al::StridedROCMatrix, Bu::StridedROCMatrix, Bl::StridedROCMatrix, ::Val{anti}
+    ) where {anti}
+    thread_dim = 512
+    block_dim = cld(size(Au, 2), thread_dim)
+    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
+    return nothing
+end
+# Launch the diagonal kernel: 512 work-items per group, one work-item per index in `1:size(A, 1)`.
+function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::StridedROCMatrix, ::Val{anti}) where {anti}
+    thread_dim = 512
+    block_dim = cld(size(A, 1), thread_dim)
+    @roc groupsize = thread_dim gridsize = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
+    return nothing
+end
+
+# Exact (anti)hermiticity checks. Non-square matrices are reported as not (anti)hermitian.
+# For a `Diagonal`, the diagonal is compared elementwise with its elementwise adjoint.
+MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = size(A, 1) == size(A, 2) && all(A .== adjoint(A))
+MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint.(A.diag))
+
+MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = size(A, 1) == size(A, 2) && all(A .== -adjoint(A))
+MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== .-adjoint.(A.diag))
+
+# In place: A <- (A + B) / 2 and B <- B - A (using the old A), one work-item per entry.
+function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
+    axes(A) == axes(B) || throw(DimensionMismatch())
+    # COV_EXCL_START
+    function _avgdiff_kernel(A, B)
+        j = workitemIdx().x + (workgroupIdx().x - 1) * workgroupDim().x
+        j > length(A) && return
+        @inbounds begin
+            a = A[j]
+            b = B[j]
+            A[j] = (a + b) / 2
+            B[j] = b - a
+        end
+        return
+    end
+    # COV_EXCL_STOP
+    thread_dim = 512
+    block_dim = cld(length(A), thread_dim)
+    @roc groupsize = thread_dim gridsize = block_dim _avgdiff_kernel(A, B)
+    return A, B
+end
+
 end
diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
index c2200ddb..4a2c77e3 100644
--- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
+++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
@@ -9,6 +9,7 @@ using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_
 import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
 import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
 using CUDA
+using CUDA: i32
 using LinearAlgebra
 using LinearAlgebra: BlasFloat
 
@@ -58,4 +59,119 @@ function MatrixAlgebraKit.findtruncated_svd(values::StridedCuVector, strategy::T
     return MatrixAlgebraKit.findtruncated(values, strategy)
 end
 
+# COV_EXCL_START
+# Device kernels: each thread handles one column index `j`; threads whose index lies
+# past the last column return immediately. `1i32` keeps the index arithmetic in Int32.
+# Val{true}: antihermitian part, Bu = (Au - Al') / 2 and Bl = -Bu'.
+function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{true})
+    m, n = size(Au)
+    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    j > n && return
+    for i in 1:m
+        @inbounds begin
+            val = (Au[i, j] - adjoint(Al[j, i])) / 2
+            Bu[i, j] = val
+            Bl[j, i] = -adjoint(val)
+        end
+    end
+    return
+end
+
+# Val{false}: hermitian part, Bu = (Au + Al') / 2 and Bl = Bu'.
+function _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, ::Val{false})
+    m, n = size(Au)
+    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    j > n && return
+    for i in 1:m
+        @inbounds begin
+            val = (Au[i, j] + adjoint(Al[j, i])) / 2
+            Bu[i, j] = val
+            Bl[j, i] = adjoint(val)
+        end
+    end
+    return
+end
+
+# Diagonal-block kernels: thread j writes the pairs (i, j), (j, i) for i < j and the
+# entry (j, j), so every entry is written by exactly one thread.
+# Val{true}: antihermitian part; the diagonal entry is set via `MatrixAlgebraKit._imimag`.
+function _project_hermitian_diag_kernel(A, B, ::Val{true})
+    n = size(A, 1)
+    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    j > n && return
+    @inbounds begin
+        for i in 1i32:(j - 1i32)
+            val = (A[i, j] - adjoint(A[j, i])) / 2
+            B[i, j] = val
+            B[j, i] = -adjoint(val)
+        end
+        B[j, j] = MatrixAlgebraKit._imimag(A[j, j])
+    end
+    return
+end
+
+# Val{false}: hermitian part; the diagonal entry is replaced by its real part.
+function _project_hermitian_diag_kernel(A, B, ::Val{false})
+    n = size(A, 1)
+    j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+    j > n && return
+    @inbounds begin
+        for i in 1i32:(j - 1i32)
+            val = (A[i, j] + adjoint(A[j, i])) / 2
+            B[i, j] = val
+            B[j, i] = adjoint(val)
+        end
+        B[j, j] = real(A[j, j])
+    end
+    return
+end
+# COV_EXCL_STOP
+
+# Launch the off-diagonal kernel: 512 threads per block, one thread per column of `Au`.
+function MatrixAlgebraKit._project_hermitian_offdiag!(
+        Au::StridedCuMatrix, Al::StridedCuMatrix, Bu::StridedCuMatrix, Bl::StridedCuMatrix, ::Val{anti}
+    ) where {anti}
+    thread_dim = 512
+    block_dim = cld(size(Au, 2), thread_dim)
+    @cuda threads = thread_dim blocks = block_dim _project_hermitian_offdiag_kernel(Au, Al, Bu, Bl, Val(anti))
+    return nothing
+end
+# Launch the diagonal kernel: 512 threads per block, one thread per index in `1:size(A, 1)`.
+function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::StridedCuMatrix, ::Val{anti}) where {anti}
+    thread_dim = 512
+    block_dim = cld(size(A, 1), thread_dim)
+    @cuda threads = thread_dim blocks = block_dim _project_hermitian_diag_kernel(A, B, Val(anti))
+    return nothing
+end
+
+# Exact (anti)hermiticity checks. Non-square matrices are reported as not (anti)hermitian.
+# For a `Diagonal`, the diagonal is compared elementwise with its elementwise adjoint.
+MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = size(A, 1) == size(A, 2) && all(A .== adjoint(A))
+MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint.(A.diag))
+
+MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = size(A, 1) == size(A, 2) && all(A .== -adjoint(A))
+MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== .-adjoint.(A.diag))
+
+# In place: A <- (A + B) / 2 and B <- B - A (using the old A), one thread per entry.
+function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
+    axes(A) == axes(B) || throw(DimensionMismatch())
+    # COV_EXCL_START
+    function _avgdiff_kernel(A, B)
+        j = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
+        j > length(A) && return
+        @inbounds begin
+            a = A[j]
+            b = B[j]
+            A[j] = (a + b) / 2
+            B[j] = b - a
+        end
+        return
+    end
+    # COV_EXCL_STOP
+    thread_dim = 512
+    block_dim = cld(length(A), thread_dim)
+    @cuda threads = thread_dim blocks = block_dim _avgdiff_kernel(A, B)
+    return A, B
+end
+
 end
diff --git a/test/amd/projections.jl b/test/amd/projections.jl
new file mode 100644
index 00000000..b22148e2
--- /dev/null
+++ b/test/amd/projections.jl
@@ -0,0 +1,90 @@
+using MatrixAlgebraKit
+using Test
+using TestExtras
+using StableRNGs
+using LinearAlgebra: LinearAlgebra, Diagonal, norm
+using AMDGPU
+
+const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
+
+@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
+    rng = StableRNG(123)
+    m = 54
+    noisefactor = eps(real(T))^(3 / 4)
+    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
+        A = ROCArray(randn(rng, T, m, m))
+        Ah = (A + A') / 2
+        Aa = (A - A') / 2
+        Ac = copy(A)
+
+        Bh = project_hermitian(A, alg)
+        @test ishermitian(Bh)
+        @test Bh ≈ Ah
+        @test A == Ac
+        Bh_approx = Bh + noisefactor * Aa
+        @test !ishermitian(Bh_approx)
+        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)
+
+        Ba = project_antihermitian(A, alg)
+        @test isantihermitian(Ba)
+        @test Ba ≈ Aa
+        @test A == Ac
+        Ba_approx = Ba + noisefactor * Ah
+        @test !isantihermitian(Ba_approx)
+        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
+
+        Bh = project_hermitian!(Ac, alg)
+        @test Bh === Ac
+        @test ishermitian(Bh)
+        @test Bh ≈ Ah
+
+        copy!(Ac, A)
+        Ba = project_antihermitian!(Ac, alg)
+        @test Ba === Ac
+        @test isantihermitian(Ba)
+        @test Ba ≈ Aa
+    end
+end
+
+@testset "exact (anti)hermitian checks on Diagonal for T = $T" for T in BLASFloats
+    # distinct real diagonal: hermitian, not antihermitian
+    Dh = Diagonal(ROCArray(T[1, 2, 3]))
+    @test MatrixAlgebraKit.ishermitian_exact(Dh)
+    @test !MatrixAlgebraKit.isantihermitian_exact(Dh)
+    if T <: Complex
+        # distinct purely imaginary diagonal: antihermitian, not hermitian
+        Da = Diagonal(ROCArray(T[1im, 2im, 3im]))
+        @test MatrixAlgebraKit.isantihermitian_exact(Da)
+        @test !MatrixAlgebraKit.ishermitian_exact(Da)
+    end
+end
+
+@testset "project_isometric! for T = $T" for T in BLASFloats
+    rng = StableRNG(123)
+    m = 54
+    @testset "size ($m, $n)" for n in (37, m)
+        svdalgs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
+        algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
+        @testset "algorithm $alg" for alg in algs
+            A = ROCArray(randn(rng, T, m, n))
+            W = project_isometric(A, alg)
+            @test isisometric(W)
+            W2 = project_isometric(W, alg)
+            @test W2 ≈ W # stability of the projection
+            @test W * (W' * A) ≈ A
+
+            Ac = similar(A)
+            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
+            @test W2 === W
+            @test isisometric(W)
+
+            # test that W is closer to A than any other isometry
+            for _ in 1:10
+                δA = ROCArray(randn(rng, T, m, n))
+                W = project_isometric(A, alg)
+                W2 = project_isometric(A + δA / 100, alg)
+                @test norm(A - W2) > norm(A - W)
+            end
+        end
+    end
+end
diff --git a/test/cuda/projections.jl b/test/cuda/projections.jl
new file mode 100644
index 00000000..d9f7bd7f
--- /dev/null
+++ b/test/cuda/projections.jl
@@ -0,0 +1,90 @@
+using MatrixAlgebraKit
+using Test
+using TestExtras
+using StableRNGs
+using LinearAlgebra: LinearAlgebra, Diagonal, norm
+using CUDA
+
+const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
+
+@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats
+    rng = StableRNG(123)
+    m = 54
+    noisefactor = eps(real(T))^(3 / 4)
+    for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64))
+        A = CuArray(randn(rng, T, m, m))
+        Ah = (A + A') / 2
+        Aa = (A - A') / 2
+        Ac = copy(A)
+
+        Bh = project_hermitian(A, alg)
+        @test ishermitian(Bh)
+        @test Bh ≈ Ah
+        @test A == Ac
+        Bh_approx = Bh + noisefactor * Aa
+        @test !ishermitian(Bh_approx)
+        @test ishermitian(Bh_approx; rtol = 10 * noisefactor)
+
+        Ba = project_antihermitian(A, alg)
+        @test isantihermitian(Ba)
+        @test Ba ≈ Aa
+        @test A == Ac
+        Ba_approx = Ba + noisefactor * Ah
+        @test !isantihermitian(Ba_approx)
+        @test isantihermitian(Ba_approx; rtol = 10 * noisefactor)
+
+        Bh = project_hermitian!(Ac, alg)
+        @test Bh === Ac
+        @test ishermitian(Bh)
+        @test Bh ≈ Ah
+
+        copy!(Ac, A)
+        Ba = project_antihermitian!(Ac, alg)
+        @test Ba === Ac
+        @test isantihermitian(Ba)
+        @test Ba ≈ Aa
+    end
+end
+
+@testset "exact (anti)hermitian checks on Diagonal for T = $T" for T in BLASFloats
+    # distinct real diagonal: hermitian, not antihermitian
+    Dh = Diagonal(CuArray(T[1, 2, 3]))
+    @test MatrixAlgebraKit.ishermitian_exact(Dh)
+    @test !MatrixAlgebraKit.isantihermitian_exact(Dh)
+    if T <: Complex
+        # distinct purely imaginary diagonal: antihermitian, not hermitian
+        Da = Diagonal(CuArray(T[1im, 2im, 3im]))
+        @test MatrixAlgebraKit.isantihermitian_exact(Da)
+        @test !MatrixAlgebraKit.ishermitian_exact(Da)
+    end
+end
+
+@testset "project_isometric! for T = $T" for T in BLASFloats
+    rng = StableRNG(123)
+    m = 54
+    @testset "size ($m, $n)" for n in (37, m)
+        svdalgs = (CUSOLVER_SVDPolar(), CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
+        algs = (PolarViaSVD.(svdalgs)...,) # PolarNewton()) # TODO
+        @testset "algorithm $alg" for alg in algs
+            A = CuArray(randn(rng, T, m, n))
+            W = project_isometric(A, alg)
+            @test isisometric(W)
+            W2 = project_isometric(W, alg)
+            @test W2 ≈ W # stability of the projection
+            @test W * (W' * A) ≈ A
+
+            Ac = similar(A)
+            W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg)
+            @test W2 === W
+            @test isisometric(W)
+
+            # test that W is closer to A than any other isometry
+            for _ in 1:10
+                δA = CuArray(randn(rng, T, m, n))
+                W = project_isometric(A, alg)
+                W2 = project_isometric(A + δA / 100, alg)
+                @test norm(A - W2) > norm(A - W)
+            end
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 10f82214..16d5fb5b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -63,6 +63,9 @@ if CUDA.functional()
     @safetestset "CUDA LQ" begin
         include("cuda/lq.jl")
     end
+    @safetestset "CUDA Projections" begin
+        include("cuda/projections.jl")
+    end
     @safetestset "CUDA SVD" begin
         include("cuda/svd.jl")
     end
@@ -82,6 +85,9 @@ if AMDGPU.functional()
     @safetestset "AMDGPU LQ" begin
         include("amd/lq.jl")
     end
+    @safetestset "AMDGPU Projections" begin
+        include("amd/projections.jl")
+    end
     @safetestset "AMDGPU SVD" begin
         include("amd/svd.jl")
     end