From af31c36a1d4bd2b690b2e83ed9714b236a5bfc26 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 12 Mar 2026 01:04:39 +0100 Subject: [PATCH 1/5] SafeDivideAndConquer --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/svd.jl | 12 ++++ src/interface/decompositions.jl | 17 ++++++ src/interface/svd.jl | 2 +- src/yalapack.jl | 101 ++++++++++++++++++++++++++------ test/algorithms.jl | 4 +- test/polar.jl | 2 +- test/svd.jl | 1 + 8 files changed, 118 insertions(+), 23 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index ac95ae651..d0666c943 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -34,7 +34,7 @@ export left_orth!, right_orth!, left_null!, right_null! export Native_HouseholderQR, Native_HouseholderLQ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, - LAPACK_DivideAndConquer, LAPACK_Jacobi + LAPACK_DivideAndConquer, LAPACK_Jacobi, LAPACK_SafeDivideAndConquer export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 7dc19c8df..dfe1fa16d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -131,6 +131,10 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ) + elseif alg isa LAPACK_SafeDivideAndConquer + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_Bisection throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) elseif alg isa LAPACK_Jacobi @@ -172,6 +176,10 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, diagview(S), U, Vᴴ) + elseif alg isa LAPACK_SafeDivideAndConquer + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi @@ -207,6 +215,10 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, S, U, Vᴴ) + elseif alg isa LAPACK_SafeDivideAndConquer + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + YALAPACK.gesdvd!(A, S, U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 253cd30e7..448d51851 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -190,6 +190,22 @@ singular vectors, see also [`gaugefix!`](@ref). # Singular Value Decomposition # ---------------------------- +""" + LAPACK_SafeDivideAndConquer(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the singular value decomposition of +a general matrix using the Divide and Conquer algorithm, with an additional fallback to +the standard QR Iteration algorithm in case the former fails to converge. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, +see also [`gaugefix!`](@ref). + +!!! warning + This approach requires a copy of the input matrix, and is thus the most memory intensive SVD strategy. + However, as it combines the speed of the Divide and Conquer algorithm with the robustness of the + QR Iteration algorithm, it is the default SVD strategy for LAPACK-based implementations in MatrixAlgebraKit. +""" +@algdef LAPACK_SafeDivideAndConquer + """ LAPACK_Jacobi(; fixgauge::Bool = true) @@ -205,6 +221,7 @@ const LAPACK_SVDAlgorithm = Union{ LAPACK_Bisection, LAPACK_DivideAndConquer, LAPACK_Jacobi, + LAPACK_SafeDivideAndConquer, } # ========================= diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 24611e65a..b973f6c48 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -162,7 +162,7 @@ function default_svd_algorithm(T::Type; kwargs...) throw(MethodError(default_svd_algorithm, (T,))) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasVecOrMat} - return LAPACK_DivideAndConquer(; kwargs...) + return LAPACK_SafeDivideAndConquer(; kwargs...) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} return DiagonalAlgorithm(; kwargs...) diff --git a/src/yalapack.jl b/src/yalapack.jl index fded1ed33..d0e6aeb5d 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -1967,6 +1967,27 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in chkstride1(A, U, Vᴴ, S) m, n = size(A) minmn = min(m, n) + work = Vector{$elty}(undef, 1) + cmplx = eltype(A) <: Complex + if cmplx + rwork = Vector{$relty}(undef, 5 * minmn) + else + rwork = nothing + end + (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) + chklapackerror(info) + return S, U, Vᴴ + end + function _gesvd_body!( + A::AbstractMatrix{$elty}, + S::AbstractVector{$relty}, + U::AbstractMatrix{$elty}, + Vᴴ::AbstractMatrix{$elty}, + work::Vector{$elty}, + rwork::Union{Vector{$relty}, Nothing} + ) + m, n = size(A) + minmn = min(m, n) if length(U) == 0 jobu = 'N' else @@ -2007,16 +2028,11 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in lda = max(1, stride(A, 2)) ldu = max(1, stride(U, 2)) ldv = max(1, stride(Vᴴ, 2)) - work = Vector{$elty}(undef, 1) lwork = BlasInt(-1) - cmplx = eltype(A) <: Complex - if cmplx - rwork = Vector{$relty}(undef, 5 * minmn) - end info = Ref{BlasInt}() for i in 1:2 # first call returns lwork as work[1] #! format: off - if cmplx + if eltype(A) <: Complex ccall((@blasfunc($gesvd), libblastrampoline), Cvoid, (Ref{UInt8}, Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, @@ -2038,13 +2054,13 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in info, 1, 1) end #! format: on - chklapackerror(info[]) if i == 1 + chklapackerror(info[]) # bail out early if even the workspace query failed lwork = BlasInt(real(work[1])) resize!(work, lwork) end end - return (S, U, Vᴴ) + return (S, U, Vᴴ), info[] end #! format: off function gesdd!( @@ -2058,6 +2074,33 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in chkstride1(A, U, Vᴴ, S) m, n = size(A) minmn = min(m, n) + work = Vector{$elty}(undef, 1) + if eltype(A) <: Complex + if length(U) == 0 && length(Vᴴ) == 0 + lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn + else + lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1) + end + rwork = Vector{$relty}(undef, lrwork) + else + rwork = nothing + end + (S, U, Vᴴ), info = _gesdd_body!(A, S, U, Vᴴ, work, rwork) + chklapackerror(info) + return S, U, Vᴴ + end + #! format: off + function _gesdd_body!( + A::AbstractMatrix{$elty}, + S::AbstractVector{$relty}, + U::AbstractMatrix{$elty}, + Vᴴ::AbstractMatrix{$elty}, + work::Vector{$elty}, + rwork::Union{Vector{$relty}, Nothing} + ) + #! format: on + m, n = size(A) + minmn = min(m, n) if length(U) == 0 && length(Vᴴ) == 0 job = 'N' @@ -2086,19 +2129,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in lda = max(1, stride(A, 2)) ldu = max(1, stride(U, 2)) ldv = max(1, stride(Vᴴ, 2)) - work = Vector{$elty}(undef, 1) lwork = BlasInt(-1) - cmplx = eltype(A) <: Complex - if cmplx - lrwork = job == 'N' ? 7 * minmn : - minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1) - rwork = Vector{$relty}(undef, lrwork) - end iwork = Vector{BlasInt}(undef, 8 * minmn) info = Ref{BlasInt}() for i in 1:2 # first call returns lwork as work[1] #! format: off - if cmplx + if eltype(A) <: Complex ccall((@blasfunc($gesdd), libblastrampoline), Cvoid, (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{$relty}, Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, @@ -2120,8 +2156,8 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in info, 1) end #! format: on - chklapackerror(info[]) if i == 1 + chklapackerror(info[]) # bail out if even the workspace query failed # Work around issue with truncated Float32 representation of lwork in # sgesdd by using nextfloat. See # http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036 @@ -2131,9 +2167,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in resize!(work, lwork) end end - return (S, U, Vᴴ) + return (S, U, Vᴴ), info[] end #! format: off + function gesdvd!( # SafeSVD implementation + A::AbstractMatrix{$elty}, + S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), + U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)), + Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2)) + ) + #! format: on + require_one_based_indexing(A, U, Vᴴ, S) + chkstride1(A, U, Vᴴ, S) + m, n = size(A) + minmn = min(m, n) + work = Vector{$elty}(undef, 1) + if eltype(A) <: Complex + if length(U) == 0 && length(Vᴴ) == 0 + lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn + else + lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1) + end + rwork = Vector{$relty}(undef, lrwork) + else + rwork = nothing + end + (S, U, Vᴴ), info = _gesdd_body!(copy(A), S, U, Vᴴ, work, rwork) + if info != 0 + (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) + end + return S, U, Vᴴ + end + #! format: off function gesvdx!( A::AbstractMatrix{$elty}, S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), diff --git a/test/algorithms.jl b/test/algorithms.jl index 078ed2a08..83384dab8 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -7,7 +7,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, @testset "default_algorithm" begin A = randn(3, 3) for f in (svd_compact!, svd_compact, svd_full!, svd_full) - @test @constinferred(default_algorithm(f, A)) === LAPACK_DivideAndConquer() + @test @constinferred(default_algorithm(f, A)) === LAPACK_SafeDivideAndConquer() end for f in (eig_full!, eig_full, eig_vals!, eig_vals) @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() @@ -21,7 +21,7 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, end for f in (left_polar!, left_polar, right_polar!, right_polar) @test @constinferred(default_algorithm(f, A)) == - PolarViaSVD(LAPACK_DivideAndConquer()) + PolarViaSVD(LAPACK_SafeDivideAndConquer()) end for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) @test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK()) diff --git a/test/polar.jl b/test/polar.jl index 4f1df5a96..6fa0e8d97 100644 --- a/test/polar.jl +++ b/test/polar.jl @@ -31,7 +31,7 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63) end if !is_buildkite if T ∈ BLASFloats - LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_DivideAndConquer()))..., PolarNewton()) + LAPACK_POLAR_ALGS = (PolarViaSVD.((LAPACK_QRIteration(), LAPACK_Bisection(), LAPACK_SafeDivideAndConquer()))..., PolarNewton()) TestSuite.test_polar(T, (m, n), LAPACK_POLAR_ALGS) if LAPACK.version() ≥ v"3.12.0" LAPACK_JACOBI = (PolarViaSVD(LAPACK_Jacobi()),) diff --git a/test/svd.jl b/test/svd.jl index bc9e8c170..800f191b6 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -51,6 +51,7 @@ for T in (BLASFloats..., GenericFloats...), m in (0, 54), n in (0, 37, m, 63) LAPACK_SVD_ALGS = ( LAPACK_QRIteration(), LAPACK_DivideAndConquer(), + LAPACK_SafeDivideAndConquer(; fixgauge = true), ) TestSuite.test_svd(T, (m, n)) TestSuite.test_svd_algs(T, (m, n), LAPACK_SVD_ALGS) From 8307ad4bdd8e58dc82395b255ac9bfc7f97279a9 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 12 Mar 2026 01:05:50 +0100 Subject: [PATCH 2/5] more algorithm test changes --- test/algorithms.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/algorithms.jl b/test/algorithms.jl index 83384dab8..4c4883c9a 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -38,7 +38,7 @@ end A = randn(3, 3) for f in (svd_trunc!, svd_trunc) @test @constinferred(select_algorithm(f, A)) === - TruncatedAlgorithm(LAPACK_DivideAndConquer(), notrunc()) + TruncatedAlgorithm(LAPACK_SafeDivideAndConquer(), notrunc()) end for f in (eig_trunc!, eig_trunc) @test @constinferred(select_algorithm(f, A)) === @@ -55,8 +55,8 @@ end @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc = (; maxrank = 2)) end - @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer() - @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_DivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_SafeDivideAndConquer() + @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_SafeDivideAndConquer() for alg in (:LAPACK_QRIteration, LAPACK_QRIteration, LAPACK_QRIteration()) @test @constinferred(select_algorithm(svd_compact!, A, $alg)) === LAPACK_QRIteration() end From 93c9fe32ceaee80359bdf68d0e9c894d4a25699c Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 12 Mar 2026 13:40:55 +0100 Subject: [PATCH 3/5] Apply suggestions from code review Co-authored-by: Lukas Devos --- src/yalapack.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/yalapack.jl b/src/yalapack.jl index d0e6aeb5d..7181eeb52 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2193,8 +2193,9 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in rwork = nothing end (S, U, Vᴴ), info = _gesdd_body!(copy(A), S, U, Vᴴ, work, rwork) - if info != 0 + if info > 0 (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) + chklapackerror(info) end return S, U, Vᴴ end From 5b23753711f07dffebd5f3f279ad88978ae5cfa6 Mon Sep 17 00:00:00 2001 From: Jutho Date: Thu, 12 Mar 2026 22:08:46 +0100 Subject: [PATCH 4/5] Update src/yalapack.jl Co-authored-by: Lukas Devos --- src/yalapack.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/yalapack.jl b/src/yalapack.jl index 7181eeb52..9e4bc50b7 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2195,8 +2195,8 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in (S, U, Vᴴ), info = _gesdd_body!(copy(A), S, U, Vᴴ, work, rwork) if info > 0 (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) - chklapackerror(info) - end + end + chklapackerror(info) return S, U, Vᴴ end #! format: off From 0d9989cddc84d23276612b604a2584630a341775 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Thu, 12 Mar 2026 22:21:43 +0100 Subject: [PATCH 5/5] fix formatting --- src/yalapack.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/yalapack.jl b/src/yalapack.jl index 9e4bc50b7..576fe3c5a 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2157,7 +2157,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end #! format: on if i == 1 - chklapackerror(info[]) # bail out if even the workspace query failed + chklapackerror(info[]) # bail out if even the workspace query failed # Work around issue with truncated Float32 representation of lwork in # sgesdd by using nextfloat. See # http://icl.cs.utk.edu/lapack-forum/viewtopic.php?f=13&t=4587&p=11036&hilit=sgesdd#p11036 @@ -2192,13 +2192,14 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in else rwork = nothing end - (S, U, Vᴴ), info = _gesdd_body!(copy(A), S, U, Vᴴ, work, rwork) + Ac = copy(A) + (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork) if info > 0 (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) - end - chklapackerror(info) + end + chklapackerror(info) return S, U, Vᴴ - end + end #! format: off function gesvdx!( A::AbstractMatrix{$elty},