Skip to content
Merged
7 changes: 4 additions & 3 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ module MatrixAlgebraKit
using LinearAlgebra: LinearAlgebra
using LinearAlgebra: norm # TODO: eleminate if we use VectorInterface.jl?
using LinearAlgebra: mul!, rmul!, lmul!, adjoint!, rdiv!, ldiv!
using LinearAlgebra: sylvester
using LinearAlgebra: sylvester, lu!
using LinearAlgebra: isposdef, issymmetric
using LinearAlgebra: Diagonal, diag, diagind, isdiag
using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

export isisometry, isunitary, ishermitian, isantihermitian

export project_hermitian, project_antihermitian
export project_hermitian!, project_antihermitian!
export project_hermitian, project_antihermitian, project_isometric
export project_hermitian!, project_antihermitian!, project_isometric!
Comment on lines +14 to +15
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm slightly unhappy with project_isometric vs isisometry. Is it at all reasonable to have project_isometry, or does that not really make sense? Happy to leave as is too, just noticing this here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good catch. But I would prefer isisometric, as all other checks are also using the adjective form, e.g. ishermitian, ...

I know this constitutes yet another breaking change 😄 .

I guess unitary is the odd duck, as that is both an adjective and a noun.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In case a breaking change is not warranted, I guess we can take comfort in the fact that the list of is... functions in Base has both nouns and adjectives in there.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For what it's worth, I slightly prefer isisometric and project_isometric (and I agree that it is nice if they match).

export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
export svd_compact, svd_full, svd_vals, svd_trunc
Expand All @@ -34,6 +34,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
export DiagonalAlgorithm
export NativeBlocked
export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar,
Expand Down
153 changes: 141 additions & 12 deletions src/implementations/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function check_input(::typeof(left_polar!), A::AbstractMatrix, WP, ::AbstractAlg
@assert W isa AbstractMatrix && P isa AbstractMatrix
@check_size(W, (m, n))
@check_scalar(W, A)
@check_size(P, (n, n))
isempty(P) || @check_size(P, (n, n))
@check_scalar(P, A)
return nothing
end
Expand All @@ -21,7 +21,7 @@ function check_input(::typeof(right_polar!), A::AbstractMatrix, PWᴴ, ::Abstrac
n >= m ||
throw(ArgumentError("input matrix needs at least as many columns as rows"))
@assert P isa AbstractMatrix && Wᴴ isa AbstractMatrix
@check_size(P, (m, m))
isempty(P) || @check_size(P, (m, m))
@check_scalar(P, A)
@check_size(Wᴴ, (m, n))
@check_scalar(Wᴴ, A)
Expand All @@ -43,25 +43,154 @@ function initialize_output(::typeof(right_polar!), A::AbstractMatrix, ::Abstract
return (P, Wᴴ)
end

# Implementation
# --------------
# Implementation via SVD
# -----------------------
function left_polar!(A::AbstractMatrix, WP, alg::PolarViaSVD)
check_input(left_polar!, A, WP, alg)
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
W, P = WP
W = mul!(W, U, Vᴴ)
S .= sqrt.(S)
SsqrtVᴴ = lmul!(S, Vᴴ)
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
if !isempty(P)
S .= sqrt.(S)
SsqrtVᴴ = lmul!(S, Vᴴ)
P = mul!(P, SsqrtVᴴ', SsqrtVᴴ)
end
return (W, P)
end
function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarViaSVD)
check_input(right_polar!, A, PWᴴ, alg)
U, S, Vᴴ = svd_compact!(A, alg.svdalg)
U, S, Vᴴ = svd_compact!(A, alg.svd_alg)
P, Wᴴ = PWᴴ
Wᴴ = mul!(Wᴴ, U, Vᴴ)
S .= sqrt.(S)
USsqrt = rmul!(U, S)
P = mul!(P, USsqrt, USsqrt')
if !isempty(P)
S .= sqrt.(S)
USsqrt = rmul!(U, S)
P = mul!(P, USsqrt, USsqrt')
end
return (P, Wᴴ)
end

# Implementation via Newton
# --------------------------
function left_polar!(A::AbstractMatrix, WP, alg::PolarNewton)
check_input(left_polar!, A, WP, alg)
W, P = WP
if isempty(P)
W = _left_polarnewton!(A, W, P; alg.kwargs...)
return W, P
else
W = _left_polarnewton!(copy(A), W, P; alg.kwargs...)
# we still need `A` to compute `P`
P = project_hermitian!(mul!(P, W', A))
return W, P
end
end

function right_polar!(A::AbstractMatrix, PWᴴ, alg::PolarNewton)
check_input(right_polar!, A, PWᴴ, alg)
P, Wᴴ = PWᴴ
if isempty(P)
Wᴴ = _right_polarnewton!(A, Wᴴ, P; alg.kwargs...)
return P, Wᴴ
else
Wᴴ = _right_polarnewton!(copy(A), Wᴴ, P; alg.kwargs...)
# we still need `A` to compute `P`
P = project_hermitian!(mul!(P, A, Wᴴ'))
return P, Wᴴ
end
end

# these methods only compute W and destroy A in the process
function _left_polarnewton!(A::AbstractMatrix, W, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
m, n = size(A) # we must have m >= n
Rᴴinv = isempty(P) ? similar(P, (n, n)) : P # use P as workspace when available
if m > n # initial QR
Q, R = qr_compact!(A)
Rc = view(A, 1:n, 1:n)
copy!(Rc, R)
Rᴴinv = ldiv!(UpperTriangular(Rc)', one!(Rᴴinv))
else # m == n
R = A
Rc = view(W, 1:n, 1:n)
copy!(Rc, R)
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
end
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
rmul!(R, γ)
rmul!(Rᴴinv, 1 / γ)
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
copy!(Rc, R)
i = 1
conv = norm(Rᴴinv, Inf)
while i < maxiter && conv > tol
Rᴴinv = ldiv!(lu!(Rc)', one!(Rᴴinv))
γ = sqrt(norm(Rᴴinv) / norm(R)) # scaling factor
rmul!(R, γ)
rmul!(Rᴴinv, 1 / γ)
R, Rᴴinv = _avgdiff!(R, Rᴴinv)
copy!(Rc, R)
conv = norm(Rᴴinv, Inf)
i += 1
end
if conv > tol
@warn "`left_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
end
if m > n
return mul!(W, Q, Rc)
end
return W
end

function _right_polarnewton!(A::AbstractMatrix, Wᴴ, P = similar(A, (0, 0)); tol = defaulttol(A), maxiter = 10)
m, n = size(A) # we must have m <= n
Lᴴinv = isempty(P) ? similar(P, (m, m)) : P # use P as workspace when available
if m < n # initial QR
L, Q = lq_compact!(A)
Lc = view(A, 1:m, 1:m)
copy!(Lc, L)
Lᴴinv = ldiv!(LowerTriangular(Lc)', one!(Lᴴinv))
else # m == n
L = A
Lc = view(Wᴴ, 1:m, 1:m)
copy!(Lc, L)
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
end
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
rmul!(L, γ)
rmul!(Lᴴinv, 1 / γ)
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
copy!(Lc, L)
i = 1
conv = norm(Lᴴinv, Inf)
while i < maxiter && conv > tol
Lᴴinv = ldiv!(lu!(Lc)', one!(Lᴴinv))
γ = sqrt(norm(Lᴴinv) / norm(L)) # scaling factor
rmul!(L, γ)
rmul!(Lᴴinv, 1 / γ)
L, Lᴴinv = _avgdiff!(L, Lᴴinv)
copy!(Lc, L)
conv = norm(Lᴴinv, Inf)
i += 1
end
if conv > tol
@warn "`right_polar!` via Newton iteration did not converge within $maxiter iterations (final residual: $conv)"
end
if m < n
return mul!(Wᴴ, Lc, Q)
end
return Wᴴ
end

# in place computation of the average and difference of two arrays
function _avgdiff!(A::AbstractArray, B::AbstractArray)
axes(A) == axes(B) || throw(DimensionMismatch())
@simd for I in eachindex(A, B)
@inbounds begin
a = A[I]
b = B[I]
A[I] = (a + b) / 2
B[I] = b - a
end
end
return A, B
end
35 changes: 29 additions & 6 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ function copy_input(::typeof(project_hermitian), A::AbstractMatrix)
end
copy_input(::typeof(project_antihermitian), A) = copy_input(project_hermitian, A)

copy_input(::typeof(project_isometric), A) = copy_input(left_polar, A)

function check_input(::typeof(project_hermitian!), A::AbstractMatrix, B::AbstractMatrix, ::AbstractAlgorithm)
LinearAlgebra.checksquare(A)
n = Base.require_one_based_indexing(A)
Expand All @@ -18,6 +20,16 @@ function check_input(::typeof(project_antihermitian!), A::AbstractMatrix, B::Abs
return nothing
end

function check_input(::typeof(project_isometric!), A::AbstractMatrix, W::AbstractMatrix, ::AbstractAlgorithm)
m, n = size(A)
m >= n ||
throw(ArgumentError("input matrix needs at least as many rows as columns"))
@assert W isa AbstractMatrix
@check_size(W, (m, n))
@check_scalar(W, A)
return nothing
end

# Outputs
# -------
function initialize_output(::typeof(project_hermitian!), A::AbstractMatrix, ::NativeBlocked)
Expand All @@ -27,15 +39,26 @@ function initialize_output(::typeof(project_antihermitian!), A::AbstractMatrix,
return A
end

function initialize_output(::typeof(project_isometric!), A::AbstractMatrix, ::AbstractAlgorithm)
return similar(A)
end

# Implementation
# --------------
function project_hermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
check_input(project_hermitian!, A, B, alg)
return project_hermitian_native!(A, B, Val(false); alg.kwargs...)
function project_hermitian!(A::AbstractMatrix, Aₕ, alg::NativeBlocked)
check_input(project_hermitian!, A, Aₕ, alg)
return project_hermitian_native!(A, Aₕ, Val(false); alg.kwargs...)
end
function project_antihermitian!(A::AbstractMatrix, B, alg::NativeBlocked)
check_input(project_antihermitian!, A, B, alg)
return project_hermitian_native!(A, B, Val(true); alg.kwargs...)
function project_antihermitian!(A::AbstractMatrix, Aₐ, alg::NativeBlocked)
check_input(project_antihermitian!, A, Aₐ, alg)
return project_hermitian_native!(A, Aₐ, Val(true); alg.kwargs...)
end

function project_isometric!(A::AbstractMatrix, W, alg::AbstractAlgorithm)
check_input(project_isometric!, A, W, alg)
noP = similar(W, (0, 0))
W, _ = left_polar!(A, (W, noP), alg)
return W
end

function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
Expand Down
22 changes: 22 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,28 @@ const LAPACK_SVDAlgorithm = Union{
LAPACK_Jacobi,
}

# =========================
# Polar decompositions
# =========================
"""
PolarViaSVD(svd_alg)

Algorithm for computing the polar decomposition of a matrix `A` via the singular value
decomposition (SVD) of `A`. The `svd_alg` argument specifies the SVD algorithm to use.
"""
struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm
svd_alg::SVDAlg
end

"""
PolarNewton(; maxiter = 10, tol = defaulttol(A))

Algorithm for computing the polar decomposition of a matrix `A` via
scaled Newton iteration, with a maximum of `maxiter` iterations and
until convergence up to tolerance `tol`.
"""
@algdef PolarNewton

# =========================
# DIAGONAL ALGORITHMS
# =========================
Expand Down
10 changes: 0 additions & 10 deletions src/interface/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,6 @@ See also [`left_polar(!)`](@ref left_polar).
"""
@functiondef right_polar

"""
PolarViaSVD(svdalg)

Algorithm for computing the polar decomposition of a matrix `A` via the singular value
decomposition (SVD) of `A`. The `svdalg` argument specifies the SVD algorithm to use.
"""
struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm
svdalg::SVDAlg
end

# Algorithm selection
# -------------------
default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...)
Expand Down
37 changes: 31 additions & 6 deletions src/interface/projections.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
@doc """
project_hermitian(A; kwargs...)
project_hermitian(A, alg)
project_hermitian!(A; kwargs...)
project_hermitian!(A, alg)
project_hermitian!(A, [Aₕ]; kwargs...)
project_hermitian!(A, [Aₕ], alg)

Compute the hermitian part of a (square) matrix `A`, defined as `(A + A') / 2`.
For real matrices this corresponds to the symmetric part of `A`.
For real matrices this corresponds to the symmetric part of `A`. In the bang method,
the output storage can be provided via the optional argument `Aₕ`; by default it is
equal to `A` and so the input matrix `A` is replaced by its hermitian projection.

See also [`project_antihermitian`](@ref).
"""
Expand All @@ -14,16 +16,36 @@ See also [`project_antihermitian`](@ref).
@doc """
project_antihermitian(A; kwargs...)
project_antihermitian(A, alg)
project_antihermitian!(A; kwargs...)
project_antihermitian!(A, alg)
project_antihermitian!(A, [Aₐ]; kwargs...)
project_antihermitian!(A, [Aₐ], alg)

Compute the anti-hermitian part of a (square) matrix `A`, defined as `(A - A') / 2`.
For real matrices this corresponds to the antisymmetric part of `A`.
For real matrices this corresponds to the antisymmetric part of `A`. In the bang method,
the output storage can be provided via the optional argument `Aₐ``; by default it is
equal to `A` and so the input matrix `A` is replaced by its antihermitian projection.

See also [`project_hermitian`](@ref).
"""
@functiondef project_antihermitian

@doc """
project_isometric(A; kwargs...)
project_isometric(A, alg)
project_isometric!(A, [W]; kwargs...)
project_isometric!(A, [W], alg)

Compute the projection of `A` onto the manifold of isometric matrices, i.e. matrices
satisfying `A' * A ≈ I`. This projection is computed via the polar decomposition, i.e.
`W` corresponds to the first return value of `left_polar!`, but avoids computing the
positive definite factor explicitly.

!!! note
The bang method `project_isometric!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
as it may not always be possible to use the provided `W` as output.
"""
@functiondef project_isometric

"""
NativeBlocked(; blocksize = 32)

Expand All @@ -43,3 +65,6 @@ for f in (:project_hermitian!, :project_antihermitian!)
return default_hermitian_algorithm(A; kwargs...)
end
end

default_algorithm(::typeof(project_isometric!), ::Type{A}; kwargs...) where {A} =
default_polar_algorithm(A; kwargs...)
Loading