diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 549f4a53..c2de1758 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -193,6 +193,25 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind) return svd_trunc_pullback end +function ChainRulesCore.rrule(::typeof(svd_trunc_no_error!), A, USVᴴ, alg::TruncatedAlgorithm) + Ac = copy_input(svd_compact, A) + USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) + USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + return USVᴴ′, _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) +end +function _make_svd_trunc_no_error_pullback(A, USVᴴ, ind) + function svd_trunc_pullback(ΔUSVᴴ) + ΔA = zero(A) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return svd_trunc_pullback +end + function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg) USVᴴ = svd_compact(A, alg) function svd_vals_pullback(ΔS) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index aa16f61e..c2814ef6 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -319,7 +319,35 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error" + U, dU = arrayify(Utrunc, dUtrunc_) + S, dS = arrayify(Strunc, dStrunc_) + Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + alg = Mooncake.primal(alg_dalg) + output = svd_trunc_no_error(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(::NoRData) + Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual) U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index fd97497b..85cfc633 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -16,8 +16,8 @@ export project_hermitian, project_antihermitian, project_isometric export project_hermitian!, project_antihermitian!, project_isometric! export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null! -export svd_compact, svd_full, svd_vals, svd_trunc -export svd_compact!, svd_full!, svd_vals!, svd_trunc! +export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error +export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error! export eigh_full, eigh_vals, eigh_trunc export eigh_full!, eigh_vals!, eigh_trunc! export eig_full, eig_vals, eig_trunc diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 126e6a04..4263c592 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -3,7 +3,7 @@ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A) copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) copy_input(::typeof(svd_vals), A) = copy_input(svd_full, A) -copy_input(::typeof(svd_trunc), A) = copy_input(svd_compact, A) +copy_input(::Union{typeof(svd_trunc), typeof(svd_trunc_no_error)}, A) = copy_input(svd_compact, A) copy_input(::typeof(svd_full), A::Diagonal) = copy(A) @@ -89,7 +89,7 @@ end function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlgorithm) return similar(A, real(eltype(A)), (min(size(A)...),)) end -function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) +function initialize_output(::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A, alg::TruncatedAlgorithm) return initialize_output(svd_compact!, A, alg.alg) end @@ -159,17 +159,17 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) if alg isa LAPACK_QRIteration isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) - YALAPACK.gesvd!(A, S.diag, U, Vᴴ) + YALAPACK.gesvd!(A, diagview(S), U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) - YALAPACK.gesdd!(A, S.diag, U, Vᴴ) + YALAPACK.gesdd!(A, diagview(S), U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg_kwargs...) + YALAPACK.gesvdx!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) - YALAPACK.gesvj!(A, S.diag, U, Vᴴ) + YALAPACK.gesvj!(A, diagview(S), U, Vᴴ) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -206,19 +206,16 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(A, real(eltype(A)), compute_error) - (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) +function svd_trunc_no_error!(A, USVᴴ, alg::TruncatedAlgorithm) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + return USVᴴtrunc end -function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} - U, S, Vᴴ, ϵ = USVᴴϵ - U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) +function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - if !isempty(ϵ) - ϵ .= truncation_error!(diagview(S), ind) - end + ϵ = truncation_error!(diagview(S), ind) return USVᴴtrunc..., ϵ end @@ -272,7 +269,7 @@ end ### function check_input( - ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized + ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized ) m, n = size(A) minmn = min(m, n) @@ -288,7 +285,7 @@ function check_input( end function initialize_output( - ::typeof(svd_trunc!), A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} + ::Union{typeof(svd_trunc!), typeof(svd_trunc_no_error!)}, A::AbstractMatrix, alg::TruncatedAlgorithm{<:CUSOLVER_Randomized} ) m, n = size(A) minmn = min(m, n) @@ -372,22 +369,34 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end -function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ} - U, S, Vᴴ, ϵ = USVᴴϵ +function svd_trunc_no_error!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) + U, S, Vᴴ = USVᴴ + check_input(svd_trunc_no_error!, A, (U, S, Vᴴ), alg.alg) + _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) + + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong + (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + + do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool + do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) + + return Utr, Str, Vᴴtr +end + +function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) + U, S, Vᴴ = USVᴴ check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) - _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) + _gpu_Xgesvdr!(A, diagview(S), U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - if !isempty(ϵ) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - normS = norm(diagview(Str)) - normA = norm(A) - # equivalent to sqrt(normA^2 - normS^2) - # but may be more accurate - ϵ = sqrt((normA + normS) * (normA - normS)) - end + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum + normS = norm(diagview(Str)) + normA = norm(A) + # equivalent to sqrt(normA^2 - normS^2) + # but may be more accurate + ϵ = sqrt((normA + normS) * (normA - normS)) do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) @@ -404,11 +413,11 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) if alg isa GPU_QRIteration isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" - _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) + _gpu_gesvd_maybe_transpose!(A, diagview(S), U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg_kwargs...) + _gpu_Xgesvdp!(A, diagview(S), U, Vᴴ; alg_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S.diag, U, Vᴴ; alg_kwargs...) + _gpu_gesvdj!(A, diagview(S), U, Vᴴ; alg_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 2ea26204..6ca88b0c 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -86,12 +86,61 @@ truncation strategy is already embedded in the algorithm. possibly destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `USVᴴ` as output. -See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact), -[`svd_vals(!)`](@ref svd_vals), and [Truncations](@ref) for more information on -truncation strategies. +See also [`svd_trunc_no_error(!)`](@ref svd_trunc_no_error), [`svd_full(!)`](@ref svd_full), +[`svd_compact(!)`](@ref svd_compact), [`svd_vals(!)`](@ref svd_vals), +and [Truncations](@ref) for more information on truncation strategies. """ @functiondef svd_trunc +""" + svd_trunc_no_error(A; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc_no_error(A, alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc_no_error!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ + svd_trunc_no_error!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + +Compute a partial or truncated singular value decomposition (SVD) of `A`, such that +`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size +`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a +square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. +The truncation error is *not* returned. + +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + +!!! note + The bang method `svd_trunc_no_error!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `USVᴴ` as output. + +See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact), +[`svd_vals(!)`](@ref svd_vals), [`svd_trunc(!)`](@ref svd_trunc) and +[Truncations](@ref) for more information on truncation strategies. +""" +@functiondef svd_trunc_no_error + """ svd_vals(A; kwargs...) -> S svd_vals(A, alg::AbstractAlgorithm) -> S @@ -125,13 +174,15 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end end -function select_algorithm(::typeof(svd_trunc!), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) +for f in (:svd_trunc!, :svd_trunc_no_error!) + @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + end end end diff --git a/test/chainrules.jl b/test/chainrules.jl index 5258b839..4be77380 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -12,7 +12,7 @@ for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, - :svd_compact, :svd_trunc, :svd_vals, + :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:copy_, f) @@ -434,6 +434,11 @@ end output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) + test_rrule( + copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + atol = atol, rtol = rtol + ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -451,6 +456,11 @@ end output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) + test_rrule( + copy_svd_trunc_no_error, A, truncalg ⊢ NoTangent(); + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + atol = atol, rtol = rtol + ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc)) @test isapprox(dA1, dA2; atol = atol, rtol = rtol) @@ -480,6 +490,12 @@ end output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) end trunc = trunctol(; atol = S[1, 1] / 2) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) @@ -489,6 +505,12 @@ end output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, svd_trunc_no_error, A; + fkwargs = (; trunc = trunc), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) end end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index fc564fec..8d931b3b 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -144,6 +144,9 @@ end @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 + U1, S1, V1ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = truncrank(r)) + @test length(S1.diag) == r + @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T))) @@ -154,6 +157,12 @@ end @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) @test V1ᴴ ≈ V2ᴴ + + U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + @test length(S2.diag) == r + @test U1 ≈ U2 + @test parent(S1) ≈ parent(S2) + @test V1ᴴ ≈ V2ᴴ end end end diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl index f7177e79..9cfdabf7 100644 --- a/test/genericlinearalgebra/svd.jl +++ b/test/genericlinearalgebra/svd.jl @@ -145,11 +145,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 1)) @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end @@ -168,4 +168,5 @@ end @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) end diff --git a/test/mooncake.jl b/test/mooncake.jl index 3e19e44d..a47bbb8d 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -458,6 +458,9 @@ end dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end @testset "trunctol" begin U, S, Vᴴ = svd_compact(A) @@ -481,6 +484,9 @@ end dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dUtrunc, dStrunc, dVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + test_pullbacks_match(rng, svd_trunc_no_error!, svd_trunc_no_error, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg) end end end diff --git a/test/svd.jl b/test/svd.jl index d055f866..a41e075c 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -181,7 +181,7 @@ end @test length(diagview(S1)) == 1 @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ = svd_trunc_no_error(A; alg, trunc = trunc_fun(0.2, 3)) @test length(diagview(S2)) == 2 @test diagview(S2) ≈ diagview(S)[1:2] end @@ -200,7 +200,10 @@ end U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + U2, S2, V2ᴴ = @constinferred svd_trunc_no_error(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) + @test_throws ArgumentError svd_trunc_no_error(A; alg, trunc = (; maxrank = 2)) end @testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) @@ -236,5 +239,7 @@ end U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol + U3, S3, Vᴴ3 = @constinferred svd_trunc_no_error(A; alg) + @test diagview(S3) ≈ S2[1:min(m, 2)] end end