Review notes (free text before the first "diff --git" header is ignored by patch tools):
  * ext: _svd_trunc_pullback now forwards Base.front(ΔUSVᴴϵ) to _svd_trunc_no_error_pullback;
    the previous line referenced a helper removed by this very patch and an undefined variable.
  * src/implementations/svd.jl: the allocation-saving svd_vals method is restricted to
    LAPACK_SafeDivideAndConquer, like its svd_full / svd_compact siblings, since it always
    calls gesdvd! and rejects every keyword other than fixgauge.
  * src/yalapack.jl: gesdvd! validates Tr against real(T) with an ArgumentError (the old
    assertion compared Tr with itself), sizes the real workspace from T (the element type
    handed to LAPACK), and includes Ac in the indexing / stride checks.
  * Line structure and indentation were reconstructed from a whitespace-collapsed copy;
    hunk line counts are unchanged. Blob hashes on the index lines were not recomputed.

diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl
index acfe0d83..c25a79ad 100644
--- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl
+++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl
@@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
             end
             return USVᴴ, svd_pullback
         end
+        function ChainRulesCore.rrule(::typeof($svd_f), A, alg)
+            USVᴴ = $(svd_f)(A, alg)
+            function svd_pullback(ΔUSVᴴ)
+                ΔA = zero(A)
+                MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ))
+                return NoTangent(), ΔA, NoTangent()
+            end
+            return USVᴴ, svd_pullback
+        end
     end
 end
 
@@ -196,43 +205,57 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
     USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
     USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
     ϵ = truncation_error(diagview(USVᴴ[2]), ind)
-    return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
-end
-function _make_svd_trunc_pullback(A, USVᴴ, ind)
     function svd_trunc_pullback(ΔUSVᴴϵ)
-        ΔA = zero(A)
-        ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
-        if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
-            throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
-        end
-        MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
+        ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
         return NoTangent(), ΔA, ZeroTangent(), NoTangent()
     end
-    function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
-        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
+    return (USVᴴ′..., ϵ), svd_trunc_pullback
+end
+function ChainRulesCore.rrule(::typeof(svd_trunc), A, alg::TruncatedAlgorithm)
+    USVᴴ = svd_compact(A, alg.alg)
+    USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
+    ϵ = truncation_error(diagview(USVᴴ[2]), ind)
+    function svd_trunc_pullback(ΔUSVᴴϵ)
+        ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
+        return NoTangent(), ΔA, NoTangent()
     end
-    return svd_trunc_pullback
+    return (USVᴴ′..., ϵ), svd_trunc_pullback
 end
 
 function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm)
     Ac = copy_input(svd_compact, A)
     USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
     USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
-    return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
-end
-function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
     function svd_trunc_pullback(ΔUSVᴴ)
-        ΔA = zero(A)
-        ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
-        MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
+        ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
         return NoTangent(), ΔA, ZeroTangent(), NoTangent()
     end
-    function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
-        return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
+    return USVᴴ′, svd_trunc_pullback
+end
+function ChainRulesCore.rrule(::typeof(svd_trunc_no_error), A, alg::TruncatedAlgorithm)
+    USVᴴ = svd_compact(A, alg.alg)
+    USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
+    function svd_trunc_pullback(ΔUSVᴴ)
+        ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
+        return NoTangent(), ΔA, NoTangent()
     end
-    return svd_trunc_pullback
+    return USVᴴ′, svd_trunc_pullback
 end
 
+function _svd_trunc_pullback(ΔUSVᴴϵ, A, USVᴴ, ind)
+    Δϵ = last(ΔUSVᴴϵ)
+    !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) &&
+        throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
+    return _svd_trunc_no_error_pullback(Base.front(ΔUSVᴴϵ), A, USVᴴ, ind)
+end
+function _svd_trunc_no_error_pullback(ΔUSVᴴ, A, USVᴴ, ind)
+    ΔA = zero(A)
+    ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
+    MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
+    return ΔA
+end
+_svd_trunc_no_error_pullback(::NTuple{3, ZeroTangent}, A, USVᴴ, ind) = ZeroTangent()
+
 function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
     USVᴴ = svd_compact(A, alg)
     function svd_vals_pullback(ΔS)
@@ -240,11 +263,23 @@ function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
         MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
         return NoTangent(), ΔA, ZeroTangent(), NoTangent()
     end
-    function svd_pullback(::ZeroTangent) # is this extra definition useful?
+    function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
         return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
     end
     return diagview(USVᴴ[2]), svd_vals_pullback
 end
+function ChainRulesCore.rrule(::typeof(svd_vals), A, alg)
+    USVᴴ = svd_compact(A, alg)
+    function svd_vals_pullback(ΔS)
+        ΔA = zero(A)
+        MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
+        return NoTangent(), ΔA, NoTangent()
+    end
+    function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
+        return NoTangent(), ZeroTangent(), NoTangent()
+    end
+    return diagview(USVᴴ[2]), svd_vals_pullback
+end
 
 function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
     Ac = copy_input(left_polar, A)
diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl
index dfe1fa16..f00823ff 100644
--- a/src/implementations/svd.jl
+++ b/src/implementations/svd.jl
@@ -134,7 +134,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
     elseif alg isa LAPACK_SafeDivideAndConquer
         isempty(alg_kwargs) ||
             throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
-        YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
+        YALAPACK.gesdvd!(A, copy(A), view(S, 1:minmn, 1), U, Vᴴ)
     elseif alg isa LAPACK_Bisection
         throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
     elseif alg isa LAPACK_Jacobi
@@ -179,7 +179,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
     elseif alg isa LAPACK_SafeDivideAndConquer
         isempty(alg_kwargs) ||
             throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
-        YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ)
+        YALAPACK.gesdvd!(A, copy(A), diagview(S), U, Vᴴ)
     elseif alg isa LAPACK_Bisection
         YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
     elseif alg isa LAPACK_Jacobi
@@ -218,7 +218,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
     elseif alg isa LAPACK_SafeDivideAndConquer
         isempty(alg_kwargs) ||
             throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
-        YALAPACK.gesdvd!(A, S, U, Vᴴ)
+        YALAPACK.gesdvd!(A, copy(A), S, U, Vᴴ)
     elseif alg isa LAPACK_Bisection
         YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
     elseif alg isa LAPACK_Jacobi
@@ -232,12 +232,93 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
     return S
 end
 
+# avoid double allocation
+function svd_full(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
+    Ac = copy_input(svd_full, A)
+    USVᴴ = initialize_output(svd_full!, Ac, alg)
+    check_input(svd_full!, Ac, USVᴴ, alg)
+
+    U, S, Vᴴ = USVᴴ
+    zero!(S)
+
+    minmn = min(size(A)...)
+    minmn == 0 && return one!(U), S, one!(Vᴴ)
+
+    do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
+    alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
+    isempty(alg_kwargs) ||
+        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
+
+
+    YALAPACK.gesdvd!(A, Ac, view(S, 1:minmn, 1), U, Vᴴ)
+
+    for i in 2:minmn
+        S[i, i] = S[i, 1]
+        S[i, 1] = zero(eltype(S))
+    end
+
+    do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)
+
+    return USVᴴ
+end
+function svd_compact(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
+    Ac = copy_input(svd_compact, A)
+    USVᴴ = initialize_output(svd_compact!, Ac, alg)
+    check_input(svd_compact!, Ac, USVᴴ, alg)
+
+    U, S, Vᴴ = USVᴴ
+    zero!(S)
+
+    minmn = min(size(A)...)
+    minmn == 0 && return one!(U), S, one!(Vᴴ)
+
+    do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
+    alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
+    isempty(alg_kwargs) ||
+        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
+
+    YALAPACK.gesdvd!(A, Ac, diagview(S), U, Vᴴ)
+
+    do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)
+
+    return USVᴴ
+end
+function svd_vals(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
+    Ac = copy_input(svd_vals, A)
+    S = initialize_output(svd_vals!, Ac, alg)
+    check_input(svd_vals!, Ac, S, alg)
+
+    minmn = min(size(A)...)
+    minmn == 0 && return zero!(S)
+
+    U, Vᴴ = similar(Ac, (0, 0)), similar(Ac, (0, 0))
+
+    alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
+    isempty(alg_kwargs) ||
+        throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
+
+    YALAPACK.gesdvd!(A, Ac, S, U, Vᴴ)
+
+    return S
+end
+
+function svd_trunc_no_error(A, alg::TruncatedAlgorithm)
+    USVᴴ = svd_compact(A, alg.alg)
+    USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
+    return USVᴴtrunc
+end
 function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
     U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
     USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
     return USVᴴtrunc
 end
 
+function svd_trunc(A, alg::TruncatedAlgorithm)
+    USVᴴ = svd_compact(A, alg.alg)
+    USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
+    ϵ = truncation_error!(diagview(USVᴴ[2]), ind)
+    return USVᴴtrunc..., ϵ
+end
 function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
     U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
     USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
diff --git a/src/yalapack.jl b/src/yalapack.jl
index 576fe3c5..d890100d 100644
--- a/src/yalapack.jl
+++ b/src/yalapack.jl
@@ -2170,37 +2170,6 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
             return (S, U, Vᴴ), info[]
         end
         #! format: off
-        function gesdvd!( # SafeSVD implementation
-                A::AbstractMatrix{$elty},
-                S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
-                U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
-                Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2))
-            )
-            #! format: on
-            require_one_based_indexing(A, U, Vᴴ, S)
-            chkstride1(A, U, Vᴴ, S)
-            m, n = size(A)
-            minmn = min(m, n)
-            work = Vector{$elty}(undef, 1)
-            if eltype(A) <: Complex
-                if length(U) == 0 && length(Vᴴ) == 0
-                    lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
-                else
-                    lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
-                end
-                rwork = Vector{$relty}(undef, lrwork)
-            else
-                rwork = nothing
-            end
-            Ac = copy(A)
-            (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
-            if info > 0
-                (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
-            end
-            chklapackerror(info)
-            return S, U, Vᴴ
-        end
-        #! format: off
         function gesvdx!(
                 A::AbstractMatrix{$elty},
                 S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
@@ -2430,4 +2399,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
     end
 end
 
+# SafeSVD implementation:
+# attempts gesdd and falls back to gesvd, requiring two independent copies of A
+# Here A is never modified, and Ac is the matrix that will be passed to LAPACK
+function gesdvd!(
+        A::AbstractMatrix, Ac::AbstractMatrix{T}, # only modifies Ac!
+        S::AbstractVector{Tr} = similar(A, real(T), min(size(A)...)),
+        U::AbstractMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
+        Vᴴ::AbstractMatrix{T} = similar(A, T, min(size(A)...), size(A, 2))
+    ) where {T <: BlasFloat, Tr <: BlasReal}
+    Tr === real(T) || throw(ArgumentError("eltype of S must be real(eltype(Ac))"))
+    require_one_based_indexing(A, Ac, U, Vᴴ, S)
+    chkstride1(A, Ac, U, Vᴴ, S)
+    m, n = size(A)
+    minmn = min(m, n)
+    work = Vector{T}(undef, 1)
+    if T <: Complex
+        if length(U) == length(Vᴴ) == 0
+            lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
+        else
+            lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
+        end
+        rwork = Vector{Tr}(undef, lrwork)
+    else
+        rwork = nothing
+    end
+    (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
+    if info > 0
+        copy!(Ac, A)
+        (S, U, Vᴴ), info = _gesvd_body!(Ac, S, U, Vᴴ, work, rwork)
+    end
+    chklapackerror(info)
+    return S, U, Vᴴ
+end
+
 end