Skip to content
9 changes: 7 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ using LinearAlgebra: LinearAlgebra
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester
using LinearAlgebra: isposdef, ishermitian, issymmetric
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

export isisometry, isunitary
export isisometry, isunitary, ishermitian, isantihermitian

export project_hermitian, project_antihermitian
export project_hermitian!, project_antihermitian!
export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
export svd_compact, svd_full, svd_vals, svd_trunc
Expand All @@ -33,6 +35,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export LQViaTransposedQR
export DiagonalAlgorithm
export NativeBlocked
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
CUSOLVER_Jacobi, CUSOLVER_Randomized, CUSOLVER_DivideAndConquer
export ROCSOLVER_HouseholderQR, ROCSOLVER_QRIteration, ROCSOLVER_Jacobi,
Expand Down Expand Up @@ -74,6 +77,7 @@ include("common/gauge.jl")

include("yalapack.jl")
include("algorithms.jl")
include("interface/projections.jl")
include("interface/decompositions.jl")
include("interface/truncation.jl")
include("interface/qr.jl")
Expand All @@ -86,6 +90,7 @@ include("interface/schur.jl")
include("interface/polar.jl")
include("interface/orthnull.jl")

include("implementations/projections.jl")
include("implementations/truncation.jl")
include("implementations/qr.jl")
include("implementations/lq.jl")
Expand Down
95 changes: 91 additions & 4 deletions src/common/matrixproperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ function isunitary(A; isapprox_kwargs...)
return is_left_isometry(A; isapprox_kwargs...) &&
is_right_isometry(A; isapprox_kwargs...)
end
function isunitary(A::AbstractMatrix; isapprox_kwargs...)
size(A, 1) == size(A, 2) || return false
return is_left_isometry(A; isapprox_kwargs...)
end

@doc """
is_left_isometry(A; isapprox_kwargs...) -> Bool
Expand All @@ -41,8 +45,11 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality.
See also [`isisometry`](@ref) and [`is_right_isometry`](@ref).
""" is_left_isometry

function is_left_isometry(A::AbstractMatrix; isapprox_kwargs...)
return isapprox(A' * A, LinearAlgebra.I; isapprox_kwargs...)
function is_left_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
P = A' * A
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
diagview(P) .-= 1
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
end

@doc """
Expand All @@ -54,6 +61,86 @@ The `isapprox_kwargs` can be used to control the tolerances of the equality.
See also [`isisometry`](@ref) and [`is_left_isometry`](@ref).
""" is_right_isometry

function is_right_isometry(A::AbstractMatrix; isapprox_kwargs...)
return isapprox(A * A', LinearAlgebra.I; isapprox_kwargs...)
function is_right_isometry(A::AbstractMatrix; atol::Real = 0, rtol::Real = defaulttol(A), norm = LinearAlgebra.norm)
P = A * A'
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
diagview(P) .-= 1
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
end

"""
ishermitian(A; isapprox_kwargs...)

Test whether a linear map is Hermitian, i.e. `A = A'`.
The `isapprox_kwargs` can be used to control the tolerances of the equality.
"""
function ishermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
if iszero(atol) && iszero(rtol)
return ishermitian_exact(A; kwargs...)
else
return 2 * norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
end
end
function ishermitian_exact(A)
return A == A'
end
function ishermitian_exact(A::StridedMatrix; kwargs...)
return strided_ishermitian_exact(A, Val(false); kwargs...)
end

"""
isantihermitian(A; isapprox_kwargs...)

Test whether a linear map is anti-Hermitian, i.e. `A = -A'`.
The `isapprox_kwargs` can be used to control the tolerances of the equality.
"""
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
if iszero(atol) && iszero(rtol)
return isantihermitian_exact(A; kwargs...)
else
return 2 * norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
end
end
function isantihermitian_exact(A)
return A == -A'
end
function isantihermitian_exact(A::StridedMatrix; kwargs...)
return strided_ishermitian_exact(A, Val(true); kwargs...)
end

# blocked implementation of exact checks for strided matrices
# -----------------------------------------------------------
function strided_ishermitian_exact(A::AbstractMatrix, anti::Val; blocksize = 32)
n = size(A, 1)
for j in 1:blocksize:n
jb = min(blocksize, n - j + 1)
_ishermitian_exact_diag(view(A, j:(j + jb - 1), j:(j + jb - 1)), anti) || return false
for i in 1:blocksize:(j - 1)
ib = blocksize
_ishermitian_exact_offdiag(
view(A, i:(i + ib - 1), j:(j + jb - 1)),
view(A, j:(j + jb - 1), i:(i + ib - 1)),
anti
) || return false
end
end
return true
end
function _ishermitian_exact_diag(A, ::Val{anti}) where {anti}
n = size(A, 1)
@inbounds for j in 1:n
@simd for i in 1:j
A[i, j] == (anti ? -adjoint(A[j, i]) : adjoint(A[j, i])) || return false
end
end
return true
end
function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
m, n = size(Al) # == reverse(size(Al))
@inbounds for j in 1:n
@simd for i in 1:m
Al[i, j] == (anti ? -adjoint(Au[j, i]) : adjoint(Au[j, i])) || return false
end
end
return true
end
95 changes: 95 additions & 0 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Inputs
# ------
function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)

function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
B === A || @check_size(B, (n, n))
return nothing
end
function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
B === A || @check_size(B, (n, n))
return nothing
end

# Outputs
# -------
function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
return A
end
function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix, ::NativeBlocked)
return A
end

# Implementation
# --------------
function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
check_input(project_hermitian!, A, B, alg)
return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
end
function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
check_input(project_antihermitian!, A, B, alg)
return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
end

function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
n = size(A, 1)
for j in 1:blocksize:n
for i in 1:blocksize:(j - 1)
jb = min(blocksize, n - j + 1)
ib = blocksize
_project_hermitian_offdiag!(
view(A, i:(i + ib - 1), j:(j + jb - 1)),
view(A, j:(j + jb - 1), i:(i + ib - 1)),
view(B, i:(i + ib - 1), j:(j + jb - 1)),
view(B, j:(j + jb - 1), i:(i + ib - 1)),
anti
)
end
jb = min(blocksize, n - j + 1)
_project_hermitian_diag!(
view(A, j:(j + jb - 1), j:(j + jb - 1)),
view(B, j:(j + jb - 1), j:(j + jb - 1)),
anti
)
end
return B
end

function _project_hermitian_offdiag!(
Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
) where {anti}

m, n = size(Au) # == reverse(size(Au))
return @inbounds for j in 1:n
@simd for i in 1:m
val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
Bu[i, j] = val
aval = adjoint(val)
Bl[j, i] = anti ? -aval : aval
end
end
return nothing
end
function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{anti}) where {anti}
n = size(A, 1)
@inbounds for j in 1:n
@simd for i in 1:(j - 1)
val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
B[i, j] = val
aval = adjoint(val)
B[j, i] = anti ? -aval : aval
end
B[j, j] = anti ? _imimag(A[j, j]) : real(A[j, j])
end
return nothing
end

_imimag(x::Real) = zero(x)
_imimag(x::Complex) = im * imag(x)
45 changes: 45 additions & 0 deletions src/interface/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
@doc """
project_hermitian(A; kwargs...)
project_hermitian(A, alg)
project_hermitian!(A; kwargs...)
project_hermitian!(A, alg)

Compute the hermitian part of a (square) matrix `A`, defined as `(A + A') / 2`.
For real matrices this corresponds to the symmetric part of `A`.

See also [`project_antihermitian`](@ref).
"""
@functiondef project_hermitian

@doc """
project_antihermitian(A; kwargs...)
project_antihermitian(A, alg)
project_antihermitian!(A; kwargs...)
project_antihermitian!(A, alg)

Compute the anti-hermitian part of a (square) matrix `A`, defined as `(A - A') / 2`.
For real matrices this corresponds to the antisymmetric part of `A`.

See also [`project_hermitian`](@ref).
"""
@functiondef project_antihermitian

"""
NativeBlocked(; blocksize = 32)

Algorithm type to denote a native blocked algorithm with given `blocksize` for computing
the hermitian or anti-hermitian part of a matrix.
"""
@algdef NativeBlocked
# TODO: multithreaded? numthreads keyword?

default_hermitian_algorithm(A; kwargs...) = default_hermitian_algorithm(typeof(A); kwargs...)
function default_hermitian_algorithm(::Type{A}; kwargs...) where {A <: AbstractMatrix}
return NativeBlocked(; kwargs...)
end

for f in (:project_hermitian!, :project_antihermitian!)
@eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A}
return default_hermitian_algorithm(A; kwargs...)
end
end
6 changes: 3 additions & 3 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function eigh_pullback!(
indV = axes(V, 2)[ind]
length(indV) == pV || throw(DimensionMismatch())
mul!(view(VᴴΔV, :, indV), V', ΔV)
aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2)
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Expand All @@ -58,7 +58,7 @@ function eigh_pullback!(
length(indD) == pD || throw(DimensionMismatch())
view(diagview(aVᴴΔV), indD) .+= real.(ΔDvec)
end
# recylce VdΔV space
# recycle VdΔV space
ΔA = mul!(ΔA, mul!(VᴴΔV, V, aVᴴΔV), V', 1, 1)
elseif !iszerotangent(ΔDmat)
ΔDvec = diagview(ΔDmat)
Expand Down Expand Up @@ -112,7 +112,7 @@ function eigh_trunc_pullback!(
if !iszerotangent(ΔV)
(n, p) == size(ΔV) || throw(DimensionMismatch())
VᴴΔV = V' * ΔV
aVᴴΔV = rmul!(VᴴΔV - VᴴΔV', 1 / 2)
aVᴴΔV = project_antihermitian!(VᴴΔV)

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ function lq_null_pullback!(
gauge_atol::Real = tol
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
NᴴΔN = Nᴴ * ΔNᴴ'
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
# Extract and check the cotangents
ΔW, ΔP = ΔWP
if !iszerotangent(ΔP)
ΔP = (ΔP + ΔP') / 2
ΔP = project_hermitian(ΔP)
end
M = zero(P)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
Expand Down Expand Up @@ -41,7 +41,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
# Extract and check the cotangents
ΔP, ΔWᴴ = ΔPWᴴ
if !iszerotangent(ΔP)
ΔP = (ΔP + ΔP') / 2
ΔP = project_hermitian(ΔP)
end
M = zero(P)
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ function qr_null_pullback!(
gauge_atol::Real = tol
)
if !iszerotangent(ΔN) && size(N, 2) > 0
NᴴΔN = N' * ΔN
Δgauge = norm((NᴴΔN .- NᴴΔN') ./ 2)
aNᴴΔN = project_antihermitian!(N' * ΔN)
Δgauge = norm(aNᴴΔN)
Δgauge < tol ||
@warn "`qr_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"

Expand Down
8 changes: 4 additions & 4 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ function svd_pullback!(
end

# Project onto antihermitian part; hermitian part outside of Grassmann tangent space
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
aUΔU = project_antihermitian!(UΔU)
aVΔV = project_antihermitian!(VΔV)

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(Sr' .- Sr) .< degeneracy_atol
Expand Down Expand Up @@ -159,8 +159,8 @@ function svd_trunc_pullback!(
end

# Project onto antihermitian part; hermitian part outside of Grassmann tangent space
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
aUΔU = project_antihermitian!(UΔU)
aVΔV = project_antihermitian!(VΔV)

# check whether cotangents arise from gauge-invariance objective function
mask = abs.(S' .- S) .< degeneracy_atol
Expand Down
Loading