diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 7dc44b72..69ee4171 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -4,13 +4,15 @@ using LinearAlgebra: LinearAlgebra using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl? using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv! using LinearAlgebra: sylvester -using LinearAlgebra: isposdef, ishermitian, issymmetric +using LinearAlgebra: isposdef, issymmetric using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt -export isisometry, isunitary +export isisometry, isunitary, ishermitian, isantihermitian +export project_hermitian, project_antihermitian +export project_hermitian!, project_antihermitian! export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null! export svd_compact, svd_full, svd_vals, svd_trunc @@ -33,6 +35,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_DivideAndConquer, LAPACK_Jacobi export LQViaTransposedQR export DiagonalAlgorithm +export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi, @@ -74,6 +77,7 @@ include("common/gauge.jl") include("yalapack.jl") include("algorithms.jl") +include("interface/projections.jl") include("interface/decompositions.jl") include("interface/truncation.jl") include("interface/qr.jl") @@ -86,6 +90,7 @@ include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("implementations/projections.jl") include("implementations/truncation.jl") include("implementations/qr.jl") include("implementations/lq.jl") diff --git a/src/common/matrixproperties.jl 
b/src/common/matrixproperties.jl index d9541cab..61ff73e9 100644 --- a/src/common/matrixproperties.jl +++ b/src/common/matrixproperties.jl @@ -31,6 +31,10 @@ function isunitary(A; isapprox_kwargs...) return is_left_isometry(A; isapprox_kwargs...) && is_right_isometry(A; isapprox_kwargs...) end +function isunitary(A::AbstractMatrix; isapprox_kwargs...) + size(A, 1) == size(A, 2) || return false + return is_left_isometry(A; isapprox_kwargs...) +end @doc """ is_left_isometry(A; isapprox_kwargs...) -> Bool @@ -41,8 +45,11 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality. See also [`isisometry`](@ref) and [`is_right_isometry`](@ref). """ is_left_isometry -function is_left_isometry(A::AbstractMatrix; isapprox_kwargs...) - return isapprox(A' * A, LinearAlgebra.I; isapprox_kwargs...) +function is_left_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm) + P = A' * A + nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))` + diagview(P) .-= 1 + return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` end @doc """ @@ -54,6 +61,86 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality. See also [`isisometry`](@ref) and [`is_left_isometry`](@ref). """ is_right_isometry -function is_right_isometry(A::AbstractMatrix; isapprox_kwargs...) - return isapprox(A * A', LinearAlgebra.I; isapprox_kwargs...) +function is_right_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm) + P = A * A' + nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))` + diagview(P) .-= 1 + return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` +end + +""" + ishermitian(A; isapprox_kwargs...) + +Test whether a linear map is Hermitian, i.e. `A = A'`. +The `isapprox_kwargs` can be used to control the tolerances of the equality. 
+"""
+function ishermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
+    if iszero(atol) && iszero(rtol)
+        return ishermitian_exact(A; kwargs...)
+    else
+        return 2 * norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
+    end
+end
+function ishermitian_exact(A)
+    return A == A'
+end
+function ishermitian_exact(A::StridedMatrix; kwargs...)
+    return strided_ishermitian_exact(A, Val(false); kwargs...)
+end
+
+"""
+    isantihermitian(A; isapprox_kwargs...)
+
+Test whether a linear map is anti-Hermitian, i.e. `A = -A'`.
+The `isapprox_kwargs` can be used to control the tolerances of the equality.
+"""
+function isantihermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
+    if iszero(atol) && iszero(rtol)
+        return isantihermitian_exact(A; kwargs...)
+    else
+        return 2 * norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
+    end
+end
+function isantihermitian_exact(A)
+    return A == -A'
+end
+function isantihermitian_exact(A::StridedMatrix; kwargs...)
+    return strided_ishermitian_exact(A, Val(true); kwargs...)
+end
+
+# blocked implementation of exact checks for strided matrices
+# -----------------------------------------------------------
+function strided_ishermitian_exact(A::AbstractMatrix, anti::Val; blocksize = 32)
+    m, n = size(A)
+    m == n || return false # a non-square matrix is never (anti-)Hermitian, same as `A == A'`
+    for j in 1:blocksize:n
+        jr = j:min(j + blocksize - 1, n) # the last diagonal block may be smaller than `blocksize`
+        _ishermitian_exact_diag(view(A, jr, jr), anti) || return false
+        # compare every block above the diagonal block with its mirror image below it
+        for i in 1:blocksize:(j - 1)
+            ir = i:(i + blocksize - 1) # these blocks end before index `j`, so they are always full
+            _ishermitian_exact_offdiag(
+                view(A, ir, jr), view(A, jr, ir), anti
+            ) || return false
+        end
+    end
+    return true
+end
+function _ishermitian_exact_diag(A, ::Val{anti}) where {anti}
+    n = size(A, 1)
+    @inbounds for j in 1:n
+        @simd for i in 1:j
+            A[i, j] == (anti ? 
-adjoint(A[j, i]) : adjoint(A[j, i])) || return false
+        end
+    end
+    return true
+end
+function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
+    m, n = size(Al) # == reverse(size(Au))
+    @inbounds for j in 1:n
+        @simd for i in 1:m
+            Al[i, j] == (anti ? -adjoint(Au[j, i]) : adjoint(Au[j, i])) || return false
+        end
+    end
+    return true
 end
diff --git a/src/implementations/projections.jl b/src/implementations/projections.jl
new file mode 100644
index 00000000..c905180b
--- /dev/null
+++ b/src/implementations/projections.jl
@@ -0,0 +1,95 @@
+# Inputs
+# ------
+function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
+    return copy!(similar(A, float(eltype(A))), A)
+end
+copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)
+
+function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
+    Base.require_one_based_indexing(A, B) # the blocked kernels index both arrays from 1
+    n = LinearAlgebra.checksquare(A) # throws for non-square `A`, otherwise returns its dimension
+    B === A || @check_size(B, (n, n))
+    return nothing
+end
+function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
+    Base.require_one_based_indexing(A, B) # the blocked kernels index both arrays from 1
+    n = LinearAlgebra.checksquare(A) # throws for non-square `A`, otherwise returns its dimension
+    B === A || @check_size(B, (n, n))
+    return nothing
+end
+
+# Outputs
+# -------
+function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
+    return A
+end
+function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix, ::NativeBlocked)
+    return A
+end
+
+# Implementation
+# --------------
+function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
+    check_input(project_hermitian!, A, B, alg)
+    return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
+end
+function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
+    check_input(project_antihermitian!, A, B, alg)
+    return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
+end
+
+function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
+    n = size(A, 1)
+    for j in 1:blocksize:n
+        # width of the current block column; only the last one can be narrower than `blocksize`
+        jb = min(blocksize, n - j + 1)
+        for i in 1:blocksize:(j - 1)
+            ib = blocksize # blocks above the diagonal block end before index `j`, so they are full
+            _project_hermitian_offdiag!(
+                view(A, i:(i + ib - 1), j:(j + jb - 1)),
+                view(A, j:(j + jb - 1), i:(i + ib - 1)),
+                view(B, i:(i + ib - 1), j:(j + jb - 1)),
+                view(B, j:(j + jb - 1), i:(i + ib - 1)),
+                anti
+            )
+        end
+        _project_hermitian_diag!(
+            view(A, j:(j + jb - 1), j:(j + jb - 1)),
+            view(B, j:(j + jb - 1), j:(j + jb - 1)),
+            anti
+        )
+    end
+    return B
+end
+
+function _project_hermitian_offdiag!(
+        Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
+    ) where {anti}
+    # `Au`/`Bu` are blocks above the diagonal, `Al`/`Bl` the mirrored blocks below it
+    m, n = size(Au) # == reverse(size(Al))
+    @inbounds for j in 1:n
+        @simd for i in 1:m
+            val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
+            Bu[i, j] = val
+            aval = adjoint(val)
+            Bl[j, i] = anti ? -aval : aval
+        end
+    end
+    return nothing
+end
+function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{anti}) where {anti}
+    n = size(A, 1)
+    @inbounds for j in 1:n
+        @simd for i in 1:(j - 1)
+            val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
+            B[i, j] = val
+            aval = adjoint(val)
+            B[j, i] = anti ? -aval : aval
+        end
+        B[j, j] = anti ? _imimag(A[j, j]) : real(A[j, j])
+    end
+    return nothing
+end
+
+_imimag(x::Real) = zero(x)
+_imimag(x::Complex) = im * imag(x)
diff --git a/src/interface/projections.jl b/src/interface/projections.jl
new file mode 100644
index 00000000..dbf62b92
--- /dev/null
+++ b/src/interface/projections.jl
@@ -0,0 +1,45 @@
+@doc """
+    project_hermitian(A; kwargs...)
+    project_hermitian(A, alg)
+    project_hermitian!(A; kwargs...)
+    project_hermitian!(A, alg)
+
+Compute the hermitian part of a (square) matrix `A`, defined as `(A + A') / 2`.
+For real matrices this corresponds to the symmetric part of `A`. + +See also [`project_antihermitian`](@ref). +""" +@functiondef project_hermitian + +@doc """ + project_antihermitian(A; kwargs...) + project_antihermitian(A, alg) + project_antihermitian!(A; kwargs...) + project_antihermitian!(A, alg) + +Compute the anti-hermitian part of a (square) matrix `A`, defined as `(A - A') / 2`. +For real matrices this corresponds to the antisymmetric part of `A`. + +See also [`project_hermitian`](@ref). +""" +@functiondef project_antihermitian + +""" +NativeBlocked(; blocksize = 32) + +Algorithm type to denote a native blocked algorithm with given `blocksize` for computing +the hermitian or anti-hermitian part of a matrix. +""" +@algdef NativeBlocked +# TODO: multithreaded? numthreads keyword? + +default_hermitian_algorithm(A; kwargs...) = default_hermitian_algorithm(typeof(A); kwargs...) +function default_hermitian_algorithm(::Type{A}; kwargs...) where {A <: AbstractMatrix} + return NativeBlocked(; kwargs...) +end + +for f in (:project_hermitian!, :project_antihermitian!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_hermitian_algorithm(A; kwargs...) 
+ end +end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 243d476b..b15f912e 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -42,7 +42,7 @@ function eigh_pullback!( indV = axes(V, 2)[ind] length(indV) == pV || throw(DimensionMismatch()) mul!(view(VᴴΔV, :, indV), V', ΔV) - aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2) + aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work mask = abs.(D' .- D) .< degeneracy_atol Δgauge = norm(view(aVᴴΔV, mask)) @@ -58,7 +58,7 @@ function eigh_pullback!( length(indD) == pD || throw(DimensionMismatch()) view(diagview(aVᴴΔV), indD) .+= real.(ΔDvec) end - # recylce VdΔV space + # recycle VdΔV space ΔA = mul!(ΔA, mul!(VᴴΔV, V, aVᴴΔV), V', 1, 1) elseif !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) @@ -112,7 +112,7 @@ function eigh_trunc_pullback!( if !iszerotangent(ΔV) (n, p) == size(ΔV) || throw(DimensionMismatch()) VᴴΔV = V' * ΔV - aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2) + aVᴴΔV = project_antihermitian!(VᴴΔV) mask = abs.(D' .- D) .< degeneracy_atol Δgauge = norm(view(aVᴴΔV, mask)) diff --git a/src/pullbacks/lq.jl b/src/pullbacks/lq.jl index 5b4db9f3..05e23c38 100644 --- a/src/pullbacks/lq.jl +++ b/src/pullbacks/lq.jl @@ -118,8 +118,8 @@ function lq_null_pullback!( gauge_atol::Real = tol ) if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0 - NᴴΔN = Nᴴ * ΔNᴴ' - Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2) + aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ') + Δgauge = norm(aNᴴΔN) Δgauge < tol || @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here? 
diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 522b3201..fabc2c2e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -11,7 +11,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP) # Extract and check the cotangents ΔW, ΔP = ΔWP if !iszerotangent(ΔP) - ΔP = (ΔP + ΔP') / 2 + ΔP = project_hermitian(ΔP) end M = zero(P) !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1) @@ -41,7 +41,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ) # Extract and check the cotangents ΔP, ΔWᴴ = ΔPWᴴ if !iszerotangent(ΔP) - ΔP = (ΔP + ΔP') / 2 + ΔP = project_hermitian(ΔP) end M = zero(P) !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1) diff --git a/src/pullbacks/qr.jl b/src/pullbacks/qr.jl index e1e8ac6c..bc49d6af 100644 --- a/src/pullbacks/qr.jl +++ b/src/pullbacks/qr.jl @@ -117,8 +117,8 @@ function qr_null_pullback!( gauge_atol::Real = tol ) if !iszerotangent(ΔN) && size(N, 2) > 0 - NᴴΔN = N' * ΔN - Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2) + aNᴴΔN = project_antihermitian!(N' * ΔN) + Δgauge = norm(aNᴴΔN) Δgauge < tol || @warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)" diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index c5f6e98b..effceaa3 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -69,8 +69,8 @@ function svd_pullback!( end # Project onto antihermitian part; hermitian part outside of Grassmann tangent space - aUΔU = rmul!(UΔU - UΔU', 1 / 2) - aVΔV = rmul!(VΔV - VΔV', 1 / 2) + aUΔU = project_antihermitian!(UΔU) + aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from gauge-invariance objective function mask = abs.(Sr' .- Sr) .< degeneracy_atol @@ -159,8 +159,8 @@ function svd_trunc_pullback!( end # Project onto antihermitian part; hermitian part outside of Grassmann tangent space - aUΔU = rmul!(UΔU - UΔU', 1 / 2) - aVΔV = rmul!(VΔV - VΔV', 1 / 2) + aUΔU = project_antihermitian!(UΔU) + aVΔV = project_antihermitian!(VΔV) # check whether cotangents arise from 
gauge-invariance objective function mask = abs.(S' .- S) .< degeneracy_atol diff --git a/test/projections.jl b/test/projections.jl new file mode 100644 index 00000000..2eb2e28e --- /dev/null +++ b/test/projections.jl @@ -0,0 +1,46 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using LinearAlgebra: Diagonal + +const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "project_(anti)hermitian! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + noisefactor = eps(real(T))^(3 / 4) + for alg in (NativeBlocked(blocksize = 16), NativeBlocked(blocksize = 32), NativeBlocked(blocksize = 64)) + A = randn(rng, T, m, m) + Ah = (A + A') / 2 + Aa = (A - A') / 2 + Ac = copy(A) + + Bh = project_hermitian(A, alg) + @test ishermitian(Bh) + @test Bh ≈ Ah + @test A == Ac + Bh_approx = Bh + noisefactor * Aa + @test !ishermitian(Bh_approx) + @test ishermitian(Bh_approx; rtol = 10 * noisefactor) + + Ba = project_antihermitian(A, alg) + @test isantihermitian(Ba) + @test Ba ≈ Aa + @test A == Ac + Ba_approx = Ba + noisefactor * Ah + @test !isantihermitian(Ba_approx) + @test isantihermitian(Ba_approx; rtol = 10 * noisefactor) + + Bh = project_hermitian!(Ac, alg) + @test Bh === Ac + @test ishermitian(Bh) + @test Bh ≈ Ah + + copy!(Ac, A) + Ba = project_antihermitian!(Ac, alg) + @test Ba === Ac + @test isantihermitian(Ba) + @test Ba ≈ Aa + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 207215ec..10f82214 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,9 @@ if !is_buildkite @safetestset "Algorithms" begin include("algorithms.jl") end + @safetestset "Projections" begin + include("projections.jl") + end @safetestset "Truncate" begin include("truncate.jl") end