From ebffb20fc2a05718fda50cfb4e20b6d723012644 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 18:32:44 +0100 Subject: [PATCH 1/6] Split svd_trunc Now we have a `svd_trunc_with_err` that returns epsilon for those who wish, and `svd_trunc` returns only the truncated USVh --- docs/src/user_interface/truncations.md | 6 +- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 27 ++++++- .../MatrixAlgebraKitMooncakeExt.jl | 36 ++++++++- src/MatrixAlgebraKit.jl | 4 +- src/implementations/svd.jl | 80 ++++++++++++++----- src/interface/svd.jl | 74 ++++++++++++++--- test/amd/svd.jl | 4 +- test/chainrules.jl | 32 ++++++-- test/cuda/svd.jl | 13 ++- test/genericlinearalgebra/svd.jl | 13 +-- test/mooncake.jl | 14 +++- test/svd.jl | 19 +++-- 12 files changed, 251 insertions(+), 71 deletions(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index ee730020..c57a869c 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -113,16 +113,16 @@ combined_trunc = truncrank(10) & trunctol(; atol = 1e-6); ## Truncation Error -When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. +When using truncated decompositions such as [`svd_trunc_with_err`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality. -For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. +For `svd_trunc_with_err` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality. For example: ```jldoctest truncations; output=false using LinearAlgebra: norm -U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2)) +U, S, Vᴴ, ϵ = svd_trunc_with_err(A; trunc=truncrank(2)) norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values # output diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 549f4a53..42cce8ab 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -170,15 +170,15 @@ for svd_f in (:svd_compact, :svd_full) end end -function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm) +function ChainRulesCore.rrule(::typeof(svd_trunc_with_err!), A, USVᴴ, alg::TruncatedAlgorithm) Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) ϵ = truncation_error(diagview(USVᴴ[2]), ind) - return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) + return (USVᴴ′..., ϵ), _make_svd_trunc_with_err_pullback(A, USVᴴ, ind) end -function _make_svd_trunc_pullback(A, USVᴴ, ind) - function svd_trunc_pullback(ΔUSVᴴϵ) +function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind) + function svd_trunc_with_err_pullback(ΔUSVᴴϵ) ΔA = zero(A) ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) @@ -187,6 +187,25 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind) MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end + function svd_trunc_with_err_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return svd_trunc_with_err_pullback +end + +function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm) + Ac = copy_input(svd_compact, A) + USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind) +end +function _make_svd_trunc_pullback(A, USVᴴ, ind) + function svd_trunc_pullback(ΔUSVᴴ) + ΔA = zero(A) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index aa16f61e..676bc1db 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -303,14 +303,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co return S_codual, svd_vals_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_with_err), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A_ = Mooncake.primal(A_dA) dA_ = Mooncake.tangent(A_dA) A, dA = arrayify(A_, dA_) alg = Mooncake.primal(alg_dalg) - output = svd_trunc(A, alg) + output = svd_trunc_with_err(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but @@ -319,7 +319,35 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc_with_err does not yet support non-zero tangent for the truncation error" + U, dU = arrayify(Utrunc, dUtrunc_) + S, dS = arrayify(Strunc, dStrunc_) + Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + alg = Mooncake.primal(alg_dalg) + output = svd_trunc(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(::NoRData) + Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index fd97497b..e1052fa3 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric export project_hermitian!, project_antihermitian!, project_isometric! export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null! -export svd_compact, svd_full, svd_vals, svd_trunc -export svd_compact!, svd_full!, svd_vals!, svd_trunc! +export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_with_err +export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_with_err! export eigh_full, eigh_vals, eigh_trunc export eigh_full!, eigh_vals!, eigh_trunc! export eig_full, eig_vals, eig_trunc diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 126e6a04..57365f6d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -4,6 +4,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltyp copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) +copy_input(::typeof(svd_trunc_with_err), A) = copy_input(svd_compact, A) copy_input(::typeof(svd_full), A::Diagonal) = copy(A) @@ -92,6 +93,9 @@ end function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end +function initialize_output(::typeof(svd_trunc_with_err!), A, alg::TruncatedAlgorithm) + return initialize_output(svd_compact!, A, alg.alg) +end function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm) TA = eltype(A) @@ -206,19 +210,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(A, real(eltype(A)), compute_error) - (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) +function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + return USVᴴtrunc end -function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} - U, S, Vᴴ, ϵ = USVᴴϵ - U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) +function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - if !isempty(ϵ) - ϵ .= truncation_error!(diagview(S), ind) - end + ϵ = truncation_error!(diagview(S), ind) return USVᴴtrunc..., ϵ end @@ -287,6 +288,22 @@ function check_input( return nothing end +function check_input( + ::typeof(svd_trunc_with_err!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized + ) + m, n = size(A) + minmn = min(m, n) + U, S, Vᴴ = USVᴴ + @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix + @check_size(U, (m, m)) + @check_scalar(U, A) + @check_size(S, (minmn, minmn)) + @check_scalar(S, A, real) + @check_size(Vᴴ, (n, n)) + @check_scalar(Vᴴ, A) + return nothing +end + function initialize_output( ::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} ) @@ -298,6 +315,17 @@ function initialize_output( return (U, S, Vᴴ) end +function initialize_output( + ::typeof(svd_trunc_with_err!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} + ) + m, n = size(A) + minmn = min(m, n) + U = similar(A, (m, m)) + S = Diagonal(similar(A, real(eltype(A)), (minmn,))) + Vᴴ = similar(A, (n, n)) + return (U, S, Vᴴ) +end + function _gpu_gesvd!( A::AbstractMatrix, S::AbstractVector, U::AbstractMatrix, Vᴴ::AbstractMatrix ) @@ -372,22 +400,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end -function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ} - U, S, Vᴴ, ϵ = USVᴴϵ +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) + U, S, Vᴴ = USVᴴ check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - if !isempty(ϵ) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - normS = norm(diagview(Str)) - normA = norm(A) - # equivalent to sqrt(normA^2 - normS^2) - # but may be more accurate - ϵ = sqrt((normA + normS) * (normA - normS)) - end + do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool + do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) + + return Utr, Str, Vᴴtr +end + +function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) + U, S, Vᴴ = USVᴴ + check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) + _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) + + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong + (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum + normS = norm(diagview(Str)) + normA = norm(A) + # equivalent to sqrt(normA^2 - normS^2) + # but may be more accurate + ϵ = sqrt((normA + normS) * (normA - normS)) do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 2ea26204..b60a5839 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -42,10 +42,10 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and @functiondef svd_compact """ - svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ - svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ - svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ - svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc_with_err(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc_with_err(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc_with_err!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc_with_err!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that `A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size @@ -81,6 +81,54 @@ for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. +!!! note + The bang method `svd_trunc!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `USVᴴ` as output. + +See also [`svd_trunc(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full), +[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals), +and [Truncations](@ref) for more information on truncation strategies. +""" +@functiondef svd_trunc_with_err + +""" + svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + +Compute a partial or truncated singular value decomposition (SVD) of `A`, such that +`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size +`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a +square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. + +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + !!! note The bang method `svd_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function @@ -125,13 +173,15 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end end -function select_algorithm(::typeof(svd_trunc!), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) +for f in (:svd_trunc!, :svd_trunc_with_err!) + @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + end end end diff --git a/test/amd/svd.jl b/test/amd/svd.jl index fcd5b490..d681acc1 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -140,14 +140,14 @@ end # minmn = min(m, n) # r = minmn - 2 # -# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc=truncrank(r)) # @test length(S1.diag) == r # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # # s = 1 + sqrt(eps(real(T))) # trunc2 = trunctol(; atol=s * S₀[r + 1]) # -# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) +# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/chainrules.jl b/test/chainrules.jl index 5258b839..7219b22a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -12,7 +12,7 @@ for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_vals, + :svd_compact, :svd_trunc, :svd_trunc_with_err, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:copy_, f) @@ -430,10 +430,15 @@ end ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + copy_svd_trunc_with_err, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) + test_rrule( + copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + atol = atol, rtol = rtol + ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -447,10 +452,15 @@ end ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + copy_svd_trunc_with_err, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) + test_rrule( + copy_svd_trunc, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + atol = atol, rtol = rtol + ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -475,20 +485,32 @@ end trunc = truncrank(r) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) test_rrule( - config, svd_trunc, A; + config, svd_trunc_with_err, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) end trunc = trunctol(; atol = S[1, 1] / 2) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) test_rrule( - config, svd_trunc, A; + config, svd_trunc_with_err, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, svd_trunc, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) end end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index fc564fec..14e26991 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -140,16 +140,25 @@ end S₀ = svd_vals(hA) r = k - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 + U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + @test length(S1.diag) == r + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T))) trunc2 = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + @test length(S2.diag) == r + @test U1 ≈ U2 + @test parent(S1) ≈ parent(S2) + @test V1ᴴ ≈ V2ᴴ + + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl index f7177e79..43feb30f 100644 --- a/test/genericlinearalgebra/svd.jl +++ b/test/genericlinearalgebra/svd.jl @@ -105,7 +105,7 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) @test length(diagview(S1)) == r @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -115,7 +115,7 @@ end s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc) @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @@ -123,7 +123,7 @@ end @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc_with_err(A; alg, trunc) @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @@ -145,11 +145,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end @@ -164,8 +164,9 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_with_err(A; alg, trunc = (; maxrank = 2)) end diff --git a/test/mooncake.jl b/test/mooncake.jl index 3e19e44d..f2b506ae 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -456,8 +456,11 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, svd_trunc_with_err, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_with_err!, svd_trunc_with_err, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end @testset "trunctol" begin U, S, Vᴴ = svd_compact(A) @@ -479,8 +482,11 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, svd_trunc_with_err, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_with_err!, svd_trunc_with_err, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end end end diff --git a/test/svd.jl b/test/svd.jl index d055f866..09d62acf 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -129,7 +129,7 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) @test length(diagview(S1)) == r @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -139,7 +139,7 @@ end s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc) @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @@ -147,7 +147,7 @@ end @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc_with_err(A; alg, trunc) @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @@ -177,11 +177,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ, ϵ1 = svd_trunc_with_err(A; alg, trunc = trunc_fun(0.2, 1)) @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end @@ -197,9 +197,12 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] + @test_throws ArgumentError svd_trunc_with_err(A; alg, trunc = (; maxrank = 2)) @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end @@ -233,8 +236,10 @@ end @test S2 ≈ diagview(S) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) + U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc_with_err(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol + U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg) + @test diagview(S3) ≈ S2[1:min(m, 2)] end end From 48dd5dd56af338a40a6ae5be16d5b8465a1d3c6e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 21:32:34 +0100 Subject: [PATCH 2/6] Fix coverage --- src/implementations/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 57365f6d..6a8c2c54 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -416,7 +416,7 @@ end function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ - check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) + check_input(svd_trunc_with_err!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong From d22d05aef6ad6a50f664db1b4962d4a38f36a5ce Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 23:12:01 +0100 Subject: [PATCH 3/6] S.diag to diagview(S) --- src/implementations/svd.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 6a8c2c54..bd89ed54 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -163,17 +163,17 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) if alg isa LAPACK_QRIteration isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, S.diag, U, Vᴴ) + YALAPACK.gesvd!(A, diagview(S), U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, S.diag, U, Vᴴ) + YALAPACK.gesdd!(A, diagview(S), U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg_kwargs...) + YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, S.diag, U, Vᴴ) + YALAPACK.gesvj!(A, diagview(S), U, Vᴴ) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -403,7 +403,7 @@ end function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) - _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) + _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -417,7 +417,7 @@ end function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ check_input(svd_trunc_with_err!, A, (U, S, Vᴴ), alg.alg) - _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) + _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -444,11 +444,11 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) if alg isa GPU_QRIteration isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) + _gpu_gesvd_maybe_transpose!(A, diagview(S), U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg_kwargs...) + _gpu_Xgesvdp!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S.diag, U, Vᴴ; alg_kwargs...) + _gpu_gesvdj!(A, diagview(S), U, Vᴴ; alg_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end From 1fb8b6aac8ce01f025cf709addf16155c2885ace Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 23:55:01 +0100 Subject: [PATCH 4/6] Reduce duplication --- src/implementations/svd.jl | 39 ++++---------------------------------- 1 file changed, 4 insertions(+), 35 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index bd89ed54..d1ac444c 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -3,8 +3,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A) copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) -copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) -copy_input(::typeof(svd_trunc_with_err), A) = copy_input(svd_compact, A) +copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_with_err)}, A) = copy_input(svd_compact, A) copy_input(::typeof(svd_full), A::Diagonal) = copy(A) @@ -90,10 +89,7 @@ end function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end -function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) - return initialize_output(svd_compact!, A, alg.alg) -end -function initialize_output(::typeof(svd_trunc_with_err!), A, alg::TruncatedAlgorithm) +function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end @@ -273,7 +269,7 @@ end ### function check_input( - ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized + ::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized ) m, n = size(A) minmn = min(m, n) @@ -288,35 +284,8 @@ function check_input( return nothing end -function check_input( - ::typeof(svd_trunc_with_err!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized - ) - m, n = size(A) - minmn = min(m, n) - U, S, Vᴴ = USVᴴ - @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix - @check_size(U, (m, m)) - @check_scalar(U, A) - @check_size(S, (minmn, minmn)) - @check_scalar(S, A, real) - @check_size(Vᴴ, (n, n)) - @check_scalar(Vᴴ, A) - return nothing -end - -function initialize_output( - ::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} - ) - m, n = size(A) - minmn = min(m, n) - U = similar(A, (m, m)) - S = Diagonal(similar(A, real(eltype(A)), (minmn,))) - Vᴴ = similar(A, (n, n)) - return (U, S, Vᴴ) -end - function initialize_output( - ::typeof(svd_trunc_with_err!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} + ::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} ) m, n = size(A) minmn = min(m, n) From 7d649de9b49529eeca89cb9c5cc3b069406dfa37 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 15 Dec 2025 19:40:23 +0100 Subject: [PATCH 5/6] Switch to svd_trunc and svd_trunc_no_error --- docs/src/user_interface/truncations.md | 6 ++-- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 20 ++++++------ .../MatrixAlgebraKitMooncakeExt.jl | 14 ++++----- src/MatrixAlgebraKit.jl | 4 +-- src/implementations/svd.jl | 20 ++++++------ src/interface/svd.jl | 31 ++++++++++--------- test/amd/svd.jl | 4 +-- test/chainrules.jl | 18 +++++------ test/cuda/svd.jl | 8 ++--- test/genericlinearalgebra/svd.jl | 10 +++--- test/mooncake.jl | 16 +++++----- test/svd.jl | 20 ++++++------ 12 files changed, 86 insertions(+), 85 deletions(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index c57a869c..ee730020 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -113,16 +113,16 @@ combined_trunc = truncrank(10) & trunctol(; atol = 1e-6); ## Truncation Error -When using truncated decompositions such as [`svd_trunc_with_err`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. +When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality. -For `svd_trunc_with_err` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. +For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality. For example: ```jldoctest truncations; output=false using LinearAlgebra: norm -U, S, Vᴴ, ϵ = svd_trunc_with_err(A; trunc=truncrank(2)) +U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2)) norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values # output diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 42cce8ab..c2de1758 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -170,15 +170,15 @@ for svd_f in (:svd_compact, :svd_full) end end -function ChainRulesCore.rrule(::typeof(svd_trunc_with_err!), A, USVᴴ, alg::TruncatedAlgorithm) +function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm) Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) ϵ = truncation_error(diagview(USVᴴ[2]), ind) - return (USVᴴ′..., ϵ), _make_svd_trunc_with_err_pullback(A, USVᴴ, ind) + return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end -function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind) - function svd_trunc_with_err_pullback(ΔUSVᴴϵ) +function _make_svd_trunc_pullback(A, USVᴴ, ind) + function svd_trunc_pullback(ΔUSVᴴϵ) ΔA = zero(A) ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) @@ -187,26 +187,26 @@ function _make_svd_trunc_with_err_pullback(A, USVᴴ, ind) MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_with_err_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end - return svd_trunc_with_err_pullback + return svd_trunc_pullback end -function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm) +function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm) Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind) + return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) end -function _make_svd_trunc_pullback(A, USVᴴ, ind) +function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) function svd_trunc_pullback(ΔUSVᴴ) ΔA = zero(A) ΔU, ΔS, ΔVᴴ = ΔUSVᴴ MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return svd_trunc_pullback diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 676bc1db..c2814ef6 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -303,14 +303,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co return S_codual, svd_vals_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_with_err), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, alg_dalg::CoDual) +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A_ = Mooncake.primal(A_dA) dA_ = Mooncake.tangent(A_dA) A, dA = arrayify(A_, dA_) alg = Mooncake.primal(alg_dalg) - output = svd_trunc_with_err(A, alg) + output = svd_trunc(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but @@ -319,7 +319,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, al function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc_with_err does not yet support non-zero tangent for the truncation error" + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) @@ -332,14 +332,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_with_err)}, A_dA::CoDual, al return output_codual, svd_trunc_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A_ = Mooncake.primal(A_dA) dA_ = Mooncake.tangent(A_dA) A, dA = arrayify(A_, dA_) alg = Mooncake.primal(alg_dalg) - output = svd_trunc(A, alg) + output = svd_trunc_no_error(A, alg) # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index e1052fa3..85cfc633 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric export project_hermitian!, project_antihermitian!, project_isometric! export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null! -export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_with_err -export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_with_err! +export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error +export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error! export eigh_full, eigh_vals, eigh_trunc export eigh_full!, eigh_vals!, eigh_trunc! export eig_full, eig_vals, eig_trunc diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index d1ac444c..4263c592 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -3,7 +3,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A) copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) -copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_with_err)}, A) = copy_input(svd_compact, A) +copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_no_error)}, A) = copy_input(svd_compact, A) copy_input(::typeof(svd_full), A::Diagonal) = copy(A) @@ -89,7 +89,7 @@ end function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end -function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A, alg::TruncatedAlgorithm) +function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end @@ -206,13 +206,13 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) +function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) return USVᴴtrunc end -function svd_trunc_with_err!(A, USVᴴ, alg::TruncatedAlgorithm) +function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) ϵ = truncation_error!(diagview(S), ind) @@ -269,7 +269,7 @@ end ### function check_input( - ::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized + ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized ) m, n = size(A) minmn = min(m, n) @@ -285,7 +285,7 @@ function check_input( end function initialize_output( - ::Union{typeof(svd_trunc!), typeof(svd_trunc_with_err!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} + ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} ) m, n = size(A) minmn = min(m, n) @@ -369,9 +369,9 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end -function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) +function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ - check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) + check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong @@ -383,9 +383,9 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran return Utr, Str, Vᴴtr end -function svd_trunc_with_err!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) U, S, Vᴴ = USVᴴ - check_input(svd_trunc_with_err!, A, (U, S, Vᴴ), alg.alg) + check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong diff --git a/src/interface/svd.jl b/src/interface/svd.jl index b60a5839..0a349901 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -42,10 +42,10 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and @functiondef svd_compact """ - svd_trunc_with_err(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ - svd_trunc_with_err(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ - svd_trunc_with_err!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ - svd_trunc_with_err!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that `A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size @@ -86,22 +86,23 @@ truncation strategy is already embedded in the algorithm. possibly destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `USVᴴ` as output. -See also [`svd_trunc(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full), +See also [`svd_trunc_no_error(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on truncation strategies. """ -@functiondef svd_trunc_with_err +@functiondef svd_trunc """ - svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ - svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc_no_error(A; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc_no_error(A, alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc_no_error!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc_no_error!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that `A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size `(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. +The truncation error is *not* returned. ## Truncation The truncation strategy can be controlled via the `trunc` keyword argument. This can be @@ -130,15 +131,15 @@ When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be spec truncation strategy is already embedded in the algorithm. !!! note - The bang method `svd_trunc!` optionally accepts the output structure and + The bang method `svd_trunc_no_error!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `USVᴴ` as output. See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact), -[`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on -truncation strategies. +[`svd_vals(!)`](@ref svd_vals), [`svd_trunc(!)`](@ref svd_trunc) and +[Truncations](@ref) for more information on truncation strategies. """ -@functiondef svd_trunc +@functiondef svd_trunc_no_error """ svd_vals(A; kwargs...) -> S @@ -173,7 +174,7 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end end -for f in (:svd_trunc!, :svd_trunc_with_err!) +for f in (:svd_trunc!, :svd_trunc_no_error!) @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) if alg isa TruncatedAlgorithm isnothing(trunc) || diff --git a/test/amd/svd.jl b/test/amd/svd.jl index d681acc1..fcd5b490 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -140,14 +140,14 @@ end # minmn = min(m, n) # r = minmn - 2 # -# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc=truncrank(r)) +# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) # @test length(S1.diag) == r # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # # s = 1 + sqrt(eps(real(T))) # trunc2 = trunctol(; atol=s * S₀[r + 1]) # -# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) +# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/chainrules.jl b/test/chainrules.jl index 7219b22a..4be77380 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -12,7 +12,7 @@ for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_trunc_with_err, :svd_vals, + :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:copy_, f) @@ -430,12 +430,12 @@ end ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( - copy_svd_trunc_with_err, A, truncalg ⊢ NoTangent(); + copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), atol = atol, rtol = rtol ) @@ -452,12 +452,12 @@ end ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( - copy_svd_trunc_with_err, A, truncalg ⊢ NoTangent(); + copy_svd_trunc, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) test_rrule( - copy_svd_trunc, A, truncalg ⊢ NoTangent(); + copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), atol = atol, rtol = rtol ) @@ -485,13 +485,13 @@ end trunc = truncrank(r) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) test_rrule( - config, svd_trunc_with_err, A; + config, svd_trunc, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) test_rrule( - config, svd_trunc, A; + config, svd_trunc_no_error, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false @@ -500,13 +500,13 @@ end trunc = trunctol(; atol = S[1, 1] / 2) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) test_rrule( - config, svd_trunc_with_err, A; + config, svd_trunc, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) test_rrule( - config, svd_trunc, A; + config, svd_trunc_no_error, A; fkwargs = (; trunc = trunc), output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 14e26991..8d931b3b 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -140,11 +140,11 @@ end S₀ = svd_vals(hA) r = k - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -152,13 +152,13 @@ end s = 1 + sqrt(eps(real(T))) trunc2 = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) @test V1ᴴ ≈ V2ᴴ - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl index 43feb30f..50effd6b 100644 --- a/test/genericlinearalgebra/svd.jl +++ b/test/genericlinearalgebra/svd.jl @@ -105,7 +105,7 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(S1)) == r @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -115,7 +115,7 @@ end s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @@ -123,7 +123,7 @@ end @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc_with_err(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @@ -164,9 +164,9 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) - @test_throws ArgumentError svd_trunc_with_err(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) end diff --git a/test/mooncake.jl b/test/mooncake.jl index f2b506ae..a47bbb8d 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -456,11 +456,11 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc_with_err, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_with_err!, svd_trunc_with_err, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end @testset "trunctol" begin U, S, Vᴴ = svd_compact(A) @@ -482,11 +482,11 @@ end dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc_with_err, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc_with_err!, svd_trunc_with_err, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end end end diff --git a/test/svd.jl b/test/svd.jl index 09d62acf..a41e075c 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -129,7 +129,7 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc_with_err(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(S1)) == r @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -139,7 +139,7 @@ end s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @@ -147,7 +147,7 @@ end @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc_with_err(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @@ -177,11 +177,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc_with_err(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end @@ -197,13 +197,13 @@ end Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc_with_err(A; alg) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) + U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] - @test_throws ArgumentError svd_trunc_with_err(A; alg, trunc = (; maxrank = 2)) @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) end @testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) @@ -236,10 +236,10 @@ end @test S2 ≈ diagview(S) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc_with_err(A; alg) + U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol - U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg) + U3, S3, Vᴴ3 = @constinferred svd_trunc_no_error(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] end end From 63542f002717b0a7f75a543e492f493f834360d6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 15 Dec 2025 23:11:42 +0100 Subject: [PATCH 6/6] update docstring and test --- src/interface/svd.jl | 2 +- test/genericlinearalgebra/svd.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 0a349901..6ca88b0c 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -86,7 +86,7 @@ truncation strategy is already embedded in the algorithm. possibly destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `USVᴴ` as output. -See also [`svd_trunc_no_error(!)`](@ref svd_trunc), [`svd_full(!)`](@ref svd_full), +See also [`svd_trunc_no_error(!)`](@ref svd_trunc_no_error), [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on truncation strategies. """ diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl index 50effd6b..9cfdabf7 100644 --- a/test/genericlinearalgebra/svd.jl +++ b/test/genericlinearalgebra/svd.jl @@ -145,11 +145,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 1)) @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end