From e5cfa3e2e9af258506a1c50db60d04f8d042ca10 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 13 Mar 2026 09:06:50 -0400 Subject: [PATCH 1/5] rework gesdvd! to accept two copies of input but only modify one --- src/implementations/svd.jl | 6 +++--- src/yalapack.jl | 8 +++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index dfe1fa16..eef21988 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -134,7 +134,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, view(S, 1:minmn, 1), U, Vᴴ) + YALAPACK.gesdvd!(A, copy_input(svd_full, A), view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_Bisection throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) elseif alg isa LAPACK_Jacobi @@ -179,7 +179,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, diagview(S), U, Vᴴ) + YALAPACK.gesdvd!(A, copy_input(svd_compact, A), diagview(S), U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi @@ -218,7 +218,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, S, U, Vᴴ) + YALAPACK.gesdvd!(A, copy_input(svd_vals, A), S, U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi diff --git a/src/yalapack.jl b/src/yalapack.jl index 576fe3c5..4cd36be6 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2169,14 +2169,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end return (S, U, Vᴴ), info[] end - #! format: off function gesdvd!( # SafeSVD implementation - A::AbstractMatrix{$elty}, + A::AbstractMatrix{$elty}, Ac::AbstractMatrix{$elty} = copy(A), # only modifies Ac! S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)), Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2)) ) - #! format: on require_one_based_indexing(A, U, Vᴴ, S) chkstride1(A, U, Vᴴ, S) m, n = size(A) @@ -2192,10 +2190,10 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in else rwork = nothing end - Ac = copy(A) (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork) if info > 0 - (S, U, Vᴴ), info = _gesvd_body!(A, S, U, Vᴴ, work, rwork) + copy!(Ac, A) + (S, U, Vᴴ), info = _gesvd_body!(Ac, S, U, Vᴴ, work, rwork) end chklapackerror(info) return S, U, Vᴴ From cab77e19978ad7ab1f735beadbf6f4a8356ee5b5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 13 Mar 2026 09:07:15 -0400 Subject: [PATCH 2/5] avoid double-copy of svd functions --- src/implementations/svd.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index eef21988..2d492d8c 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -232,6 +232,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end +# avoid double allocation +for f in (:svd_compact, :svd_full, :svd_vals) + f! = Symbol(f, :!) + @eval $f(A, alg::LAPACK_SafeDivideAndConquer) = $f!(A, alg) +end +for f in (:svd_trunc, :svd_trunc_no_error) + f! = Symbol(f, :!) + @eval $f(A, alg::TruncatedAlgorithm{<:LAPACK_SafeDivideAndConquer}) = $f!(A, alg) +end + function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) From 2388634d35ec2ed5318e2b9262bac2ee72ba7ea4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 13 Mar 2026 09:52:41 -0400 Subject: [PATCH 3/5] fix eltype issue --- src/yalapack.jl | 63 ++++++++++++++++++++++++++----------------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/src/yalapack.jl b/src/yalapack.jl index 4cd36be6..00d0e202 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2169,35 +2169,6 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end return (S, U, Vᴴ), info[] end - function gesdvd!( # SafeSVD implementation - A::AbstractMatrix{$elty}, Ac::AbstractMatrix{$elty} = copy(A), # only modifies Ac! - S::AbstractVector{$relty} = similar(A, $relty, min(size(A)...)), - U::AbstractMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)), - Vᴴ::AbstractMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2)) - ) - require_one_based_indexing(A, U, Vᴴ, S) - chkstride1(A, U, Vᴴ, S) - m, n = size(A) - minmn = min(m, n) - work = Vector{$elty}(undef, 1) - if eltype(A) <: Complex - if length(U) == 0 && length(Vᴴ) == 0 - lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn - else - lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1) - end - rwork = Vector{$relty}(undef, lrwork) - else - rwork = nothing - end - (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork) - if info > 0 - copy!(Ac, A) - (S, U, Vᴴ), info = _gesvd_body!(Ac, S, U, Vᴴ, work, rwork) - end - chklapackerror(info) - return S, U, Vᴴ - end #! format: off function gesvdx!( A::AbstractMatrix{$elty}, @@ -2428,4 +2399,38 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end end +# SafeSVD implementation: +# attempts gesdd and falls back to gesvd, requiring two independent copies of A +# Here A is never modified, and Ac is the matrix that will be passed to LAPACK +function gesdvd!( + A::AbstractMatrix, Ac::AbstractMatrix{T}, # only modifies Ac! + S::AbstractVector{Tr} = similar(A, real(T), min(size(A)...)), + U::AbstractMatrix{T} = similar(A, T, size(A, 1), min(size(A)...)), + Vᴴ::AbstractMatrix{T} = similar(A, T, min(size(A)...), size(A, 2)) + ) where {T <: BlasFloat, Tr <: BlasReal} + @assert Tr == real(Tr) + require_one_based_indexing(A, U, Vᴴ, S) + chkstride1(A, U, Vᴴ, S) + m, n = size(A) + minmn = min(m, n) + work = Vector{T}(undef, 1) + if eltype(A) <: Complex + if length(U) == length(Vᴴ) == 0 + lrwork = (LAPACK.version() <= v"3.6") ? 7 * minmn : 5 * minmn + else + lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1) + end + rwork = Vector{T}(undef, lrwork) + else + rwork = nothing + end + (S, U, Vᴴ), info = _gesdd_body!(Ac, S, U, Vᴴ, work, rwork) + if info > 0 + copy!(Ac, A) + (S, U, Vᴴ), info = _gesvd_body!(Ac, S, U, Vᴴ, work, rwork) + end + chklapackerror(info) + return S, U, Vᴴ +end + end From 3f7d45d9f5e334195afc54751193e1ef985f9b43 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 13 Mar 2026 10:42:58 -0400 Subject: [PATCH 4/5] another attempt --- src/implementations/svd.jl | 89 ++++++++++++++++++++++++++++++++++---- src/yalapack.jl | 2 +- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 2d492d8c..f00823ff 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -134,7 +134,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, copy_input(svd_full, A), view(S, 1:minmn, 1), U, Vᴴ) + YALAPACK.gesdvd!(A, copy(A), view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_Bisection throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) elseif alg isa LAPACK_Jacobi @@ -179,7 +179,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, copy_input(svd_compact, A), diagview(S), U, Vᴴ) + YALAPACK.gesdvd!(A, copy(A), diagview(S), U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi @@ -218,7 +218,7 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) elseif alg isa LAPACK_SafeDivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) - YALAPACK.gesdvd!(A, copy_input(svd_vals, A), S, U, Vᴴ) + YALAPACK.gesdvd!(A, copy(A), S, U, Vᴴ) elseif alg isa LAPACK_Bisection YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi @@ -233,21 +233,92 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) end # avoid double allocation -for f in (:svd_compact, :svd_full, :svd_vals) - f! = Symbol(f, :!) - @eval $f(A, alg::LAPACK_SafeDivideAndConquer) = $f!(A, alg) +function svd_full(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer) + Ac = copy_input(svd_full, A) + USVᴴ = initialize_output(svd_full!, Ac, alg) + check_input(svd_full!, Ac, USVᴴ, alg) + + U, S, Vᴴ = USVᴴ + zero!(S) + + minmn = min(size(A)...) + minmn == 0 && return one!(U), S, one!(Vᴴ) + + do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + + + YALAPACK.gesdvd!(A, Ac, view(S, 1:minmn, 1), U, Vᴴ) + + for i in 2:minmn + S[i, i] = S[i, 1] + S[i, 1] = zero(eltype(S)) + end + + do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) + + return USVᴴ +end +function svd_compact(A::AbstractMatrix, alg::LAPACK_SafeDivideAndConquer) + Ac = copy_input(svd_compact, A) + USVᴴ = initialize_output(svd_compact!, Ac, alg) + check_input(svd_compact!, Ac, USVᴴ, alg) + + U, S, Vᴴ = USVᴴ + zero!(S) + + minmn = min(size(A)...) + minmn == 0 && return one!(U), S, one!(Vᴴ) + + do_gauge_fix = get(alg.kwargs, :fixgauge, default_fixgauge())::Bool + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + + YALAPACK.gesdvd!(A, Ac, diagview(S), U, Vᴴ) + + do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) + + return USVᴴ end -for f in (:svd_trunc, :svd_trunc_no_error) - f! = Symbol(f, :!) - @eval $f(A, alg::TruncatedAlgorithm{<:LAPACK_SafeDivideAndConquer}) = $f!(A, alg) +function svd_vals(A::AbstractMatrix, alg::LAPACK_SVDAlgorithm) + Ac = copy_input(svd_vals, A) + S = initialize_output(svd_vals!, Ac, alg) + check_input(svd_vals!, Ac, S, alg) + + minmn = min(size(A)...) + minmn == 0 && return zero!(S) + + U, Vᴴ = similar(Ac, (0, 0)), similar(Ac, (0, 0)) + + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:fixgauge,)}) + isempty(alg_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_SafeDivideAndConquer")) + + YALAPACK.gesdvd!(A, Ac, S, U, Vᴴ) + + return S end +function svd_trunc_no_error(A, alg::TruncatedAlgorithm) + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc) + return USVᴴtrunc +end function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) return USVᴴtrunc end +function svd_trunc(A, alg::TruncatedAlgorithm) + USVᴴ = svd_compact(A, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = truncation_error!(diagview(USVᴴ[2]), ind) + return USVᴴtrunc..., ϵ +end function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) diff --git a/src/yalapack.jl b/src/yalapack.jl index 00d0e202..d890100d 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -2420,7 +2420,7 @@ function gesdvd!( else lrwork = minmn * max(5 * minmn + 5, 2 * max(m, n) + 2 * minmn + 1) end - rwork = Vector{T}(undef, lrwork) + rwork = Vector{Tr}(undef, lrwork) else rwork = nothing end From fa27db187cf543df06d21c3a61c60d81ee39f000 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 13 Mar 2026 13:04:34 -0400 Subject: [PATCH 5/5] add chainrules support --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 79 +++++++++++++++++------- 1 file changed, 57 insertions(+), 22 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index acfe0d83..c25a79ad 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -188,6 +188,15 @@ for svd_f in (:svd_compact, :svd_full) end return USVᴴ, svd_pullback end + function ChainRulesCore.rrule(::typeof($svd_f), A, alg) + USVᴴ = $(svd_f)(A, alg) + function svd_pullback(ΔUSVᴴ) + ΔA = zero(A) + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ)) + return NoTangent(), ΔA, NoTangent() + end + return USVᴴ, svd_pullback + end end end @@ -196,43 +205,57 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) ϵ = truncation_error(diagview(USVᴴ[2]), ind) - return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) -end -function _make_svd_trunc_pullback(A, USVᴴ, ind) function svd_trunc_pullback(ΔUSVᴴϵ) - ΔA = zero(A) - ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ - if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) - throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error")) - end - MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) + ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return (USVᴴ′..., ϵ), svd_trunc_pullback +end +function ChainRulesCore.rrule(::typeof(svd_trunc), A, alg::TruncatedAlgorithm) + USVᴴ = svd_compact(A, alg.alg) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = truncation_error(diagview(USVᴴ[2]), ind) + function svd_trunc_pullback(ΔUSVᴴϵ) + ΔA = _svd_trunc_pullback(unthunk(ΔUSVᴴϵ), A, USVᴴ, ind) + return NoTangent(), ΔA, NoTangent() end - return svd_trunc_pullback + return (USVᴴ′..., ϵ), svd_trunc_pullback end function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm) Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) -end -function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) function svd_trunc_pullback(ΔUSVᴴ) - ΔA = zero(A) - ΔU, ΔS, ΔVᴴ = ΔUSVᴴ - MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) + ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return USVᴴ′, svd_trunc_pullback +end +function ChainRulesCore.rrule(::typeof(svd_trunc_no_error), A, alg::TruncatedAlgorithm) + USVᴴ = svd_compact(A, alg.alg) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + function svd_trunc_pullback(ΔUSVᴴ) + ΔA = _svd_trunc_no_error_pullback(unthunk(ΔUSVᴴ), A, USVᴴ, ind) + return NoTangent(), ΔA, NoTangent() end - return svd_trunc_pullback + return USVᴴ′, svd_trunc_pullback end +function _svd_trunc_pullback(ΔUSVᴴϵ, A, USVᴴ, ind) + Δϵ = last(ΔUSVᴴϵ) + !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) && + throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error")) + return _make_svd_trunc_no_error_pullback(Base.front(ΔUSVᴴ), A, USVᴴ, ind) +end +function _svd_trunc_no_error_pullback(ΔUSVᴴ, A, USVᴴ, ind) + ΔA = zero(A) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) + return ΔA +end +_svd_trunc_no_error_pullback(::NTuple{3, ZeroTangent}, A, USVᴴ, ind) = ZeroTangent() + function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg) USVᴴ = svd_compact(A, alg) function svd_vals_pullback(ΔS) @@ -240,11 +263,23 @@ function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg) MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS)) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_pullback(::ZeroTangent) # is this extra definition useful? + function svd_vals_pullback(::ZeroTangent) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return diagview(USVᴴ[2]), svd_vals_pullback end +function ChainRulesCore.rrule(::typeof(svd_vals), A, alg) + USVᴴ = svd_compact(A, alg) + function svd_vals_pullback(ΔS) + ΔA = zero(A) + MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS)) + return NoTangent(), ΔA, NoTangent() + end + function svd_vals_pullback(::ZeroTangent) # is this extra definition useful? + return NoTangent(), ZeroTangent(), NoTangent() + end + return diagview(USVᴴ[2]), svd_vals_pullback +end function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg) Ac = copy_input(left_polar, A)