Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full)
end
return USVᴴ, svd_pullback
end
function ChainRulesCore.rrule(::typeof($svd_f), A, alg)
USVᴴ = $(svd_f)(A, alg)
function svd_pullback(ΔUSVᴴ)
ΔA = zero(A)
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ))
return NoTangent(), ΔA, NoTangent()
end
return USVᴴ, svd_pullback
end
end
end

Expand All @@ -196,55 +205,81 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴϵ)
ΔA = zero(A)
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
end
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
return (USVᴴ′..., ϵ), svd_trunc_pullback
end
function ChainRulesCore.rrule(::typeof(svd_trunc), A, alg::TruncatedAlgorithm)
USVᴴ = svd_compact(A, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
function svd_trunc_pullback(ΔUSVᴴϵ)
ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind)
return NoTangent(), ΔA, NoTangent()
end
return svd_trunc_pullback
return (USVᴴ′..., ϵ), svd_trunc_pullback
end

function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm)
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴ)
ΔA = zero(A)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
return USVᴴ′, svd_trunc_pullback
end
function ChainRulesCore.rrule(::typeof(svd_trunc_no_error), A, alg::TruncatedAlgorithm)
USVᴴ = svd_compact(A, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
function svd_trunc_pullback(ΔUSVᴴ)
ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind)
return NoTangent(), ΔA, NoTangent()
end
return svd_trunc_pullback
return USVᴴ′, svd_trunc_pullback
end

function _svd_trunc_pullback(ΔUSVᴴϵ, A, USVᴴ, ind)
Δϵ = last(ΔUSVᴴϵ)
!MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) &&
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
return _make_svd_trunc_no_error_pullback(Base.front(ΔUSVᴴ), A, USVᴴ, ind)
end
function _svd_trunc_no_error_pullback(ΔUSVᴴ, A, USVᴴ, ind)
ΔA = zero(A)
ΔU, ΔS, ΔVᴴ = ΔUSVᴴ
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
return ΔA
end
_svd_trunc_no_error_pullback(::NTuple{3, ZeroTangent}, A, USVᴴ, ind) = ZeroTangent()

function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
USVᴴ = svd_compact(A, alg)
function svd_vals_pullback(ΔS)
ΔA = zero(A)
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_pullback(::ZeroTangent) # is this extra definition useful?
function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return diagview(USVᴴ[2]), svd_vals_pullback
end
function ChainRulesCore.rrule(::typeof(svd_vals), A, alg)
USVᴴ = svd_compact(A, alg)
function svd_vals_pullback(ΔS)
ΔA = zero(A)
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
return NoTangent(), ΔA, NoTangent()
end
function svd_vals_pullback(::ZeroTangent) # is this extra definition useful?
return NoTangent(), ZeroTangent(), NoTangent()
end
return diagview(USVᴴ[2]), svd_vals_pullback
end

function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
Ac = copy_input(left_polar, A)
Expand Down
87 changes: 84 additions & 3 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ)
YALAPACK.gesdvd!(A, copy(A), view(S, 1:minmn, 1), U, Vᴴ)
elseif alg isa LAPACK_Bisection
throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
elseif alg isa LAPACK_Jacobi
Expand Down Expand Up @@ -179,7 +179,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ)
YALAPACK.gesdvd!(A, copy(A), diagview(S), U, Vᴴ)
elseif alg isa LAPACK_Bisection
YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...)
elseif alg isa LAPACK_Jacobi
Expand Down Expand Up @@ -218,7 +218,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
elseif alg isa LAPACK_SafeDivideAndConquer
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))
YALAPACK.gesdvd!(A, S, U, Vᴴ)
YALAPACK.gesdvd!(A, copy(A), S, U, Vᴴ)
elseif alg isa LAPACK_Bisection
YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...)
elseif alg isa LAPACK_Jacobi
Expand All @@ -232,12 +232,93 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
return S
end

# avoid double allocation
function svd_full(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
Ac = copy_input(svd_full, A)
USVᴴ = initialize_output(svd_full!, Ac, alg)
check_input(svd_full!, Ac, USVᴴ, alg)

U, S, Vᴴ = USVᴴ
zero!(S)

minmn = min(size(A)...)
minmn == 0 && return one!(U), S, one!(Vᴴ)

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))


YALAPACK.gesdvd!(A, Ac, view(S, 1:minmn, 1), U, Vᴴ)

for i in 2:minmn
S[i, i] = S[i, 1]
S[i, 1] = zero(eltype(S))
end

do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ)

return USVᴴ
end
function svd_compact(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer)
Ac = copy_input(svd_compact, A)
USVᴴ = initialize_output(svd_compact!, Ac, alg)
check_input(svd_compact!, Ac, USVᴴ, alg)

U, S, Vᴴ = USVᴴ
zero!(S)

minmn = min(size(A)...)
minmn == 0 && return one!(U), S, one!(Vᴴ)

do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool
alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))

YALAPACK.gesdvd!(A, Ac, diagview(S), U, Vᴴ)

do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ)

return USVᴴ
end
function svd_vals(A::AbstractMatrix, alg::LAPACK_SVDAlgorithm)
Ac = copy_input(svd_vals, A)
S = initialize_output(svd_vals!, Ac, alg)
check_input(svd_vals!, Ac, S, alg)

minmn = min(size(A)...)
minmn == 0 && return zero!(S)

U, Vᴴ = similar(Ac, (0, 0)), similar(Ac, (0, 0))

alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)})
isempty(alg_kwargs) ||
throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer"))

YALAPACK.gesdvd!(A, Ac, S, U, Vᴴ)

return S
end

function svd_trunc_no_error(A, alg::TruncatedAlgorithm)
USVᴴ = svd_compact(A, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴtrunc
end
function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm)
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
return USVᴴtrunc
end

function svd_trunc(A, alg::TruncatedAlgorithm)
USVᴴ = svd_compact(A, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = truncation_error!(diagview(USVᴴ[2]), ind)
return USVᴴtrunc..., ϵ
end
function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
Expand Down
65 changes: 34 additions & 31 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2170,37 +2170,6 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
return (S, U, Vᴴ), info[]
end
#! format: off
function gesdvd!( # SafeSVD implementation
A::AbstractMatrix{$elty},
S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2))
)
#! format: on
require_one_based_indexing(A, U, Vᴴ, S)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
work = Vector{$elty}(undef, 1)
if eltype(A) <: Complex
if length(U) == 0 && length(Vᴴ) == 0
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
else
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
end
rwork = Vector{$relty}(undef, lrwork)
else
rwork = nothing
end
Ac = copy(A)
(S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
if info > 0
(S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork)
end
chklapackerror(info)
return S, U, Vᴴ
end
#! format: off
function gesvdx!(
A::AbstractMatrix{$elty},
S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)),
Expand Down Expand Up @@ -2430,4 +2399,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
end
end

# SafeSVD implementation:
# attempts gesdd and falls back to gesvd, requiring two independent copies of A
# Here A is never modified, and Ac is the matrix that will be passed to LAPACK
function gesdvd!(
A::AbstractMatrix, Ac::AbstractMatrix{T}, # only modifies Ac!
S::AbstractVector{Tr} = similar(A, real(T), min(size(A)...)),
U::AbstractMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)),
Vᴴ::AbstractMatrix{T} = similar(A, T, min(size(A)...), size(A, 2))
) where {T <: BlasFloat, Tr <: BlasReal}
@assert Tr == real(Tr)
require_one_based_indexing(A, U, Vᴴ, S)
chkstride1(A, U, Vᴴ, S)
m, n = size(A)
minmn = min(m, n)
work = Vector{T}(undef, 1)
if eltype(A) <: Complex
if length(U) == length(Vᴴ) == 0
lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn
else
lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1)
end
rwork = Vector{Tr}(undef, lrwork)
else
rwork = nothing
end
(S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork)
if info > 0
copy!(Ac, A)
(S, U, Vᴴ), info = _gesvd_body!(Ac, S, U, Vᴴ, work, rwork)
end
chklapackerror(info)
return S, U, Vᴴ
end

end
Loading