Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export left_orth!, right_orth!, left_null!, right_null!
export Native_HouseholderQR, Native_HouseholderLQ
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer
export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration
export LQViaTransposedQR
export PolarViaSVD, PolarNewton
Expand Down
12 changes: 12 additions & 0 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
elseif alg isa LAPACK_Bisection
throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
elseif alg isa LAPACK_Jacobi
Expand Down Expand Up @@ -172,6 +176,10 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
YALAPACK.gesdd!(A, diagview(S), U, Vᴴ)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ)
elseif alg isa LAPACK_Bisection
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
elseif alg isa LAPACK_Jacobi
Expand Down Expand Up @@ -207,6 +215,10 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer"))
YALAPACK.gesdd!(A, S, U, Vᴴ)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, S, U, Vᴴ)
elseif alg isa LAPACK_Bisection
YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
elseif alg isa LAPACK_Jacobi
Expand Down
17 changes: 17 additions & 0 deletions src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,22 @@ singular vectors, see also [`gaugefix!`](@ref).

# Singular Value Decomposition
# ----------------------------
"""
LAPACK_SafeDivideAndConquer(; fixgauge::Bool = true)

Algorithm type to denote the LAPACK driver for computing the singular value decomposition of
a general matrix using the Divide and Conquer algorithm, with an additional fallback to
the standard QR Iteration algorithm in case the former fails to converge.
The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors,
see also [`gaugefix!`](@ref).

!!! warning
This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy.
However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the
QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit.
"""
@algdef LAPACK_SafeDivideAndConquer

"""
LAPACK_Jacobi(; fixgauge::Bool = true)

Expand All @@ -205,6 +221,7 @@ const LAPACK_SVDAlgorithm = Union{
LAPACK_Bisection,
LAPACK_DivideAndConquer,
LAPACK_Jacobi,
LAPACK_SafeDivideAndConquer,
}

# =========================
Expand Down
2 changes: 1 addition & 1 deletion src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ function default_svd_algorithm(T::Type; kwargs...)
throw(MethodError(default_svd_algorithm, (T,)))
end
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat}
return LAPACK_DivideAndConquer(; kwargs...)
return LAPACK_SafeDivideAndConquer(; kwargs...)
end
function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal}
return DiagonalAlgorithm(; kwargs...)
Expand Down
103 changes: 85 additions & 18 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1967,6 +1967,27 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
work = Vector{$elty}(undef, 1)
cmplx = eltype(A) <: Complex
if cmplx
rwork = Vector{$relty}(undef, 5 * minmn)
else
rwork = nothing
end
(S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
chklapackerror(info)
return S, U, Vᴴ
end
function _gesvd_body!(
A::AbstractMatrix{$elty},
S::AbstractVector{$relty},
U::AbstractMatrix{$elty},
Vᴴ::AbstractMatrix{$elty},
work::Vector{$elty},
rwork::Union{Vector{$relty}, Nothing}
)
m, n = size(A)
minmn = min(m, n)
if length(U) == 0
jobu = 'N'
else
Expand Down Expand Up @@ -2007,16 +2028,11 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
lda = max(1, stride(A, 2))
ldu = max(1, stride(U, 2))
ldv = max(1, stride(Vᴴ, 2))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
cmplx = eltype(A) <: Complex
if cmplx
rwork = Vector{$relty}(undef, 5 * minmn)
end
info = Ref{BlasInt}()
for i in 1:2 # first call returns lwork as work[1]
#! format: off
if cmplx
if eltype(A) <: Complex
ccall((@blasfunc($gesvd), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Expand All @@ -2038,13 +2054,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
info, 1, 1)
end
#! format: on
chklapackerror(info[])
if i == 1
chklapackerror(info[]) # bail out early if even the workspace query failed
lwork = BlasInt(real(work[1]))
resize!(work, lwork)
end
end
return (S, U, Vᴴ)
return (S, U, Vᴴ), info[]
end
#! format: off
function gesdd!(
Expand All @@ -2058,6 +2074,33 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
work = Vector{$elty}(undef, 1)
if eltype(A) <: Complex
if length(U) == 0 && length(Vᴴ) == 0
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
else
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
end
rwork = Vector{$relty}(undef, lrwork)
else
rwork = nothing
end
(S, U, Vᴴ), info = _gesdd_body!(A, S, U, Vᴴ, work, rwork)
chklapackerror(info)
return S, U, Vᴴ
end
#! format: off
function _gesdd_body!(
A::AbstractMatrix{$elty},
S::AbstractVector{$relty},
U::AbstractMatrix{$elty},
Vᴴ::AbstractMatrix{$elty},
work::Vector{$elty},
rwork::Union{Vector{$relty}, Nothing}
)
#! format: on
m, n = size(A)
minmn = min(m, n)

if length(U) == 0 && length(Vᴴ) == 0
job = 'N'
Expand Down Expand Up @@ -2086,19 +2129,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
lda = max(1, stride(A, 2))
ldu = max(1, stride(U, 2))
ldv = max(1, stride(Vᴴ, 2))
work = Vector{$elty}(undef, 1)
lwork = BlasInt(-1)
cmplx = eltype(A) <: Complex
if cmplx
lrwork = job == 'N' ? 7 * minmn :
minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1)
rwork = Vector{$relty}(undef, lrwork)
end
iwork = Vector{BlasInt}(undef, 8 * minmn)
info = Ref{BlasInt}()
for i in 1:2 # first call returns lwork as work[1]
#! format: off
if cmplx
if eltype(A) <: Complex
ccall((@blasfunc($gesdd), libblastrampoline), Cvoid,
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Expand All @@ -2120,8 +2156,8 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
info, 1)
end
#! format: on
chklapackerror(info[])
if i == 1
chklapackerror(info[]) # bail out if even the workspace query failed
# Work around issue with truncated Float32 representation of lwork in
# sgesdd by using nextfloat. See
# http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036
Expand All @@ -2131,7 +2167,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
resize!(work, lwork)
end
end
return (S, U, Vᴴ)
return (S, U, Vᴴ), info[]
end
#! format: off
function gesdvd!( # SafeSVD implementation
A::AbstractMatrix{$elty},
S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2))
)
#! format: on
require_one_based_indexing(A, U, Vᴴ, S)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
work = Vector{$elty}(undef, 1)
if eltype(A) <: Complex
if length(U) == 0 && length(Vᴴ) == 0
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
else
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
end
rwork = Vector{$relty}(undef, lrwork)
else
rwork = nothing
end
Ac = copy(A)
(S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
if info > 0
(S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
end
chklapackerror(info)
return S, U, Vᴴ
end
#! format: off
function gesvdx!(
Expand Down
10 changes: 5 additions & 5 deletions test/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
@testset "default_algorithm" begin
A = randn(3, 3)
for f in (svd_compact!, svd_compact, svd_full!, svd_full)
@test @constinferred(default_algorithm(f, A)) === LAPACK_DivideAndConquer()
@test @constinferred(default_algorithm(f, A)) === LAPACK_SafeDivideAndConquer()
end
for f in (eig_full!, eig_full, eig_vals!, eig_vals)
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
Expand All @@ -21,7 +21,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
end
for f in (left_polar!, left_polar, right_polar!, right_polar)
@test @constinferred(default_algorithm(f, A)) ==
PolarViaSVD(LAPACK_DivideAndConquer())
PolarViaSVD(LAPACK_SafeDivideAndConquer())
end
for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null)
@test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK())
Expand All @@ -38,7 +38,7 @@ end
A = randn(3, 3)
for f in (svd_trunc!, svd_trunc)
@test @constinferred(select_algorithm(f, A)) ===
TruncatedAlgorithm(LAPACK_DivideAndConquer(), notrunc())
TruncatedAlgorithm(LAPACK_SafeDivideAndConquer(), notrunc())
end
for f in (eig_trunc!, eig_trunc)
@test @constinferred(select_algorithm(f, A)) ===
Expand All @@ -55,8 +55,8 @@ end
@test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2))
end

@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer()
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_DivideAndConquer()
@test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_SafeDivideAndConquer()
@test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_SafeDivideAndConquer()
for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration())
@test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration()
end
Expand Down
2 changes: 1 addition & 1 deletion test/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
end
if !is_buildkite
if T ∈ BLASFloats
LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_DivideAndConquer()))..., PolarNewton())
LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_SafeDivideAndConquer()))..., PolarNewton())
TestSuite.test_polar(T, (m, n), LAPACK_POLAR_ALGS)
if LAPACK.version() ≥ v"3.12.0"
LAPACK_JACOBI = (PolarViaSVD(LAPACK_Jacobi()),)
Expand Down
1 change: 1 addition & 0 deletions test/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63)
LAPACK_SVD_ALGS = (
LAPACK_QRIteration(),
LAPACK_DivideAndConquer(),
LAPACK_SafeDivideAndConquer(; fixgauge = true),
)
TestSuite.test_svd(T, (m, n))
TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS)
Expand Down
Loading