diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 04bee40e..258b46d6 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -127,18 +127,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid return nothing end -MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A)) -MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = - all(A.diag .== adjoint(A.diag)) -MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) = - @invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...) - -MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = - all(A .== -adjoint(A)) -MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = - all(A.diag .== -adjoint(A.diag)) -MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) = - @invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...) +# avoids calling the `StridedMatrix` specialization to avoid scalar indexing, +# use (allocating) fallback instead until we write a dedicated kernel +MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = + norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) +MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; atol, rtol, kwargs...) = + norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index e4c77453..189f5825 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -151,19 +151,14 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride return nothing end -MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = - all(A .== adjoint(A)) -MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = - all(A.diag .== adjoint(A.diag)) -MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) = - @invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...) - -MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = - all(A .== -adjoint(A)) -MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = - all(A.diag .== -adjoint(A.diag)) -MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) = - @invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...) +# avoids calling the `StridedMatrix` specialization to avoid scalar indexing, +# use (allocating) fallback instead until we write a dedicated kernel +MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = A == A' +MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = + norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) +MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = A == -A' +MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; atol, rtol, kwargs...) = + norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix) axes(A) == axes(B) || throw(DimensionMismatch()) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index d3a04741..493f3a91 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -3,7 +3,7 @@ module MatrixAlgebraKit using LinearAlgebra: LinearAlgebra using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! -using LinearAlgebra: sylvester, lu! +using LinearAlgebra: sylvester, lu!, diagm using LinearAlgebra: isposdef, issymmetric using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular diff --git a/src/common/matrixproperties.jl b/src/common/matrixproperties.jl index 925f92e6..5a73176e 100644 --- a/src/common/matrixproperties.jl +++ b/src/common/matrixproperties.jl @@ -79,10 +79,13 @@ end ishermitian_exact(A) = A == A' ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...) +ishermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(false)) + function ishermitian_approx(A; atol, rtol, kwargs...) return norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) end ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...) +ishermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(false); kwargs...) """ isantihermitian(A; isapprox_kwargs...) @@ -97,16 +100,15 @@ function isantihermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...) return isantihermitian_approx(A; atol, rtol, kwargs...) end end -function isantihermitian_exact(A) - return A == -A' -end -function isantihermitian_exact(A::StridedMatrix; kwargs...) - return strided_ishermitian_exact(A, Val(true); kwargs...) -end +isantihermitian_exact(A) = A == -A' +isantihermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(true); kwargs...) +isantihermitian_exact(A::Diagonal) = diagonal_ishermitian_exact(A, Val(true)) + function isantihermitian_approx(A; atol, rtol, kwargs...) return norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A)) end isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...) +isantihermitian_approx(A::Diagonal; kwargs...) = diagonal_ishermitian_approx(A, Val(true); kwargs...) # blocked implementation of exact checks for strided matrices # ----------------------------------------------------------- @@ -145,7 +147,6 @@ function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti} return true end - function strided_ishermitian_approx( A::AbstractMatrix, anti::Val; blocksize = 32, atol::Real = default_hermitian_tol(A), rtol::Real = 0 @@ -192,3 +193,16 @@ function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti} end return ϵ² end + +diagonal_ishermitian_exact(A, ::Val{anti}) where {anti} = all(iszero ∘ (anti ? real : imag), diagview(A)) + +function diagonal_ishermitian_approx( + A, ::Val{anti}; atol::Real = default_hermitian_tol(A), rtol::Real = 0 + ) where {anti} + m, n = size(A) + m == n || throw(DimensionMismatch()) + init = abs2(zero(eltype(A))) + ϵ² = sum(abs2 ∘ (anti ? real : imag), diagview(A); init) + ϵ²max = oftype(ϵ², rtol > 0 ? max(atol, rtol * norm(A)) : atol)^2 + return ϵ² ≤ ϵ²max +end diff --git a/src/implementations/projections.jl b/src/implementations/projections.jl index 8baf09a5..857366ed 100644 --- a/src/implementations/projections.jl +++ b/src/implementations/projections.jl @@ -9,13 +9,15 @@ copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A) function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm) LinearAlgebra.checksquare(A) - n = Base.require_one_based_indexing(A) + Base.require_one_based_indexing(A) + n = size(A, 1) B === A || @check_size(B, (n, n)) return nothing end function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm) LinearAlgebra.checksquare(A) - n = Base.require_one_based_indexing(A) + Base.require_one_based_indexing(A) + n = size(A, 1) B === A || @check_size(B, (n, n)) return nothing end @@ -61,6 +63,15 @@ function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm) return W end +function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti} + if anti + diagview(A) .= _imimag.(diagview(B)) + else + diagview(A) .= real.(diagview(B)) + end + return A +end + function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32) n = size(A, 1) for j in 1:blocksize:n diff --git a/test/amd/projections.jl b/test/amd/projections.jl index b22148e2..a06b152c 100644 --- a/test/amd/projections.jl +++ b/test/amd/projections.jl @@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) m = 54 noisefactor = eps(real(T))^(3 / 4) for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - A = ROCArray(randn(rng, T, m, m)) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) + for A in (ROCArray(randn(rng, T, m, m)), Diagonal(ROCArray(randn(rng, T, m)))) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + Ac = copy(A) - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - @test !ishermitian(Bh_approx) - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) + Bh = project_hermitian(A, alg) + @test ishermitian(Bh) + @test Bh ≈ Ah + @test A == Ac + Bh_approx = Bh + noisefactor * Aa + # this is still hermitian for real Diagonal: |A - A'| == 0 + @test !ishermitian(Bh_approx) || norm(Aa) == 0 + @test ishermitian(Bh_approx; rtol = 10 * noisefactor) - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) + Ba = project_antihermitian(A, alg) + @test isantihermitian(Ba) + @test Ba ≈ Aa + @test A == Ac + Ba_approx = Ba + noisefactor * Ah + @test !isantihermitian(Ba_approx) + # this is never anti-hermitian for real Diagonal: |A - A'| == 0 + @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah + Bh = project_hermitian!(Ac, alg) + @test Bh === Ac + @test ishermitian(Bh) + @test Bh ≈ Ah - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa + copy!(Ac, A) + Ba = project_antihermitian!(Ac, alg) + @test Ba === Ac + @test isantihermitian(Ba) + @test Ba ≈ Aa + end end end @@ -68,10 +71,33 @@ end # test that W is closer to A then any other isometry for k in 1:10 - δA = ROCArray(randn(rng, T, m, n)) + δA = ROCArray(randn(rng, T, size(A)...)) + W = project_isometric(A, alg) + W2 = project_isometric(A + δA / 100, alg) + @test norm(A - W2) >= norm(A - W) + end + end + + m == n && @testset "DiagonalAlgorithm" begin + A = Diagonal(ROCArray(randn(rng, T, m))) + alg = PolarViaSVD(DiagonalAlgorithm()) + W = project_isometric(A, alg) + @test isisometric(W) + W2 = project_isometric(W, alg) + @test W2 ≈ W # stability of the projection + @test W * (W' * A) ≈ A + + Ac = similar(A) + W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) + @test W2 === W + @test isisometric(W) + + # test that W is closer to A then any other isometry + for k in 1:10 + δA = Diagonal(ROCArray(randn(rng, T, m))) W = project_isometric(A, alg) W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) > norm(A - W) + @test norm(A - W2) >= norm(A - W) end end end diff --git a/test/cuda/projections.jl b/test/cuda/projections.jl index d9f7bd7f..677ef520 100644 --- a/test/cuda/projections.jl +++ b/test/cuda/projections.jl @@ -12,37 +12,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) m = 54 noisefactor = eps(real(T))^(3 / 4) for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - A = CuArray(randn(rng, T, m, m)) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) + for A in (CuArray(randn(rng, T, m, m)), Diagonal(CuArray(randn(rng, T, m)))) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + Ac = copy(A) - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - @test !ishermitian(Bh_approx) - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) + Bh = project_hermitian(A, alg) + @test ishermitian(Bh) + @test Bh ≈ Ah + @test A == Ac + Bh_approx = Bh + noisefactor * Aa + # this is still hermitian for real Diagonal: |A - A'| == 0 + @test !ishermitian(Bh_approx) || norm(Aa) == 0 + @test ishermitian(Bh_approx; rtol = 10 * noisefactor) - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) + Ba = project_antihermitian(A, alg) + @test isantihermitian(Ba) + @test Ba ≈ Aa + @test A == Ac + Ba_approx = Ba + noisefactor * Ah + @test !isantihermitian(Ba_approx) + # this is never anti-hermitian for real Diagonal: |A - A'| == 0 + @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah + Bh = project_hermitian!(Ac, alg) + @test Bh === Ac + @test ishermitian(Bh) + @test Bh ≈ Ah - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa + copy!(Ac, A) + Ba = project_antihermitian!(Ac, alg) + @test Ba === Ac + @test isantihermitian(Ba) + @test Ba ≈ Aa + end end end @@ -68,10 +71,33 @@ end # test that W is closer to A then any other isometry for k in 1:10 - δA = CuArray(randn(rng, T, m, n)) + δA = CuArray(randn(rng, T, size(A)...)) + W = project_isometric(A, alg) + W2 = project_isometric(A + δA / 100, alg) + @test norm(A - W2) >= norm(A - W) + end + end + + m == n && @testset "DiagonalAlgorithm" begin + A = Diagonal(CuArray(randn(rng, T, m))) + alg = PolarViaSVD(DiagonalAlgorithm()) + W = project_isometric(A, alg) + @test isisometric(W) + W2 = project_isometric(W, alg) + @test W2 ≈ W # stability of the projection + @test W * (W' * A) ≈ A + + Ac = similar(A) + W2 = @constinferred project_isometric!(copy!(Ac, A), W, alg) + @test W2 === W + @test isisometric(W) + + # test that W is closer to A then any other isometry + for k in 1:10 + δA = Diagonal(CuArray(randn(rng, T, m))) W = project_isometric(A, alg) W2 = project_isometric(A + δA / 100, alg) - @test norm(A - W2) > norm(A - W) + @test norm(A - W2) >= norm(A - W) end end end diff --git a/test/projections.jl b/test/projections.jl index 56bc1a04..3923528e 100644 --- a/test/projections.jl +++ b/test/projections.jl @@ -11,37 +11,40 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) m = 54 noisefactor = eps(real(T))^(3 / 4) for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) - A = randn(rng, T, m, m) - Ah = (A + A') / 2 - Aa = (A - A') / 2 - Ac = copy(A) + for A in (randn(rng, T, m, m), Diagonal(randn(rng, T, m))) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + Ac = copy(A) - Bh = project_hermitian(A, alg) - @test ishermitian(Bh) - @test Bh ≈ Ah - @test A == Ac - Bh_approx = Bh + noisefactor * Aa - @test !ishermitian(Bh_approx) - @test ishermitian(Bh_approx; rtol = 10 * noisefactor) + Bh = project_hermitian(A, alg) + @test ishermitian(Bh) + @test Bh ≈ Ah + @test A == Ac + Bh_approx = Bh + noisefactor * Aa + # this is still hermitian for real Diagonal: |A - A'| == 0 + @test !ishermitian(Bh_approx) || norm(Aa) == 0 + @test ishermitian(Bh_approx; rtol = 10 * noisefactor) - Ba = project_antihermitian(A, alg) - @test isantihermitian(Ba) - @test Ba ≈ Aa - @test A == Ac - Ba_approx = Ba + noisefactor * Ah - @test !isantihermitian(Ba_approx) - @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) + Ba = project_antihermitian(A, alg) + @test isantihermitian(Ba) + @test Ba ≈ Aa + @test A == Ac + Ba_approx = Ba + noisefactor * Ah + @test !isantihermitian(Ba_approx) + # this is never anti-hermitian for real Diagonal: |A - A'| == 0 + @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) || norm(Aa) == 0 - Bh = project_hermitian!(Ac, alg) - @test Bh === Ac - @test ishermitian(Bh) - @test Bh ≈ Ah + Bh = project_hermitian!(Ac, alg) + @test Bh === Ac + @test ishermitian(Bh) + @test Bh ≈ Ah - copy!(Ac, A) - Ba = project_antihermitian!(Ac, alg) - @test Ba === Ac - @test isantihermitian(Ba) - @test Ba ≈ Aa + copy!(Ac, A) + Ba = project_antihermitian!(Ac, alg) + @test Ba === Ac + @test isantihermitian(Ba) + @test Ba ≈ Aa + end end # test approximate error calculation