diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index 40799890..831dffc9 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -23,3 +23,15 @@ maxdim = 10 atol = 1e-6 combined_trunc = truncrank(maxdim) & trunctol(; atol) ``` + +## Truncation Error + +When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), +an additional truncation error value is returned. This error is defined as the 2-norm of the discarded +singular values or eigenvalues, providing a measure of the approximation quality. + +For example: +```julia +U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(10)) +# ϵ is the 2-norm of the discarded singular values +``` diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index d8664d80..f98a933a 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -112,16 +112,20 @@ for eig in (:eig, :eigh) function ChainRulesCore.rrule(::typeof($eig_t!), A, DV, alg::TruncatedAlgorithm) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) - DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - return DV′, $(_make_eig_t_pb)(A, DV, ind) + DV′, ind, truncerr = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) + return (DV′..., truncerr), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) - function $eig_t_pb(ΔDV) + function $eig_t_pb(ΔDV_and_err) + # Extract the tangents for D and V, ignore the truncerr tangent + ΔD = ΔDV_and_err[1] + ΔV = ΔDV_and_err[2] + # Ignore ΔDV_and_err[3] which is the truncerr tangent ΔA = zero(A) - MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind) + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, (unthunk(ΔD), unthunk(ΔV)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return $eig_t_pb @@ -151,16 +155,21 @@ end function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm) Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) - USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind) + USVᴴ′, ind, truncerr = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) + return (USVᴴ′..., truncerr), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) - function svd_trunc_pullback(ΔUSVᴴ) + function svd_trunc_pullback(ΔUSVᴴ_and_err) + # Extract the tangents for U, S, and Vᴴ, ignore the truncerr tangent + ΔU = ΔUSVᴴ_and_err[1] + ΔS = ΔUSVᴴ_and_err[2] + ΔVᴴ = ΔUSVᴴ_and_err[3] + # Ignore ΔUSVᴴ_and_err[4] which is the truncerr tangent ΔA = zero(A) - MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind) + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, (unthunk(ΔU), unthunk(ΔS), unthunk(ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return svd_trunc_pullback diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 48f1deb4..9d298a10 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -108,7 +108,8 @@ end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) - return first(truncate(eig_trunc!, (D, V), alg.trunc)) + result, _, truncerr = truncate(eig_trunc!, (D, V), alg.trunc) + return result..., truncerr end # Diagonal logic diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index c7dc4dab..1c15f53d 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -111,7 +111,8 @@ end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) - return first(truncate(eigh_trunc!, (D, V), alg.trunc)) + result, _, truncerr = truncate(eigh_trunc!, (D, V), alg.trunc) + return result..., truncerr end # Diagonal logic diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 76e32ce1..0d2e0baf 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -128,7 +128,7 @@ function left_orth_svd!(A, VC, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) check_input(left_orth!, A, VC, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ = svd_trunc!(A, alg_trunc) + U, S, Vᴴ, _ = svd_trunc!(A, alg_trunc) V, C = VC return copy!(V, U), mul!(C, S, Vᴴ) end @@ -138,7 +138,7 @@ function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) V, C = VC S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) + U, S, Vᴴ, _ = svd_trunc!(A, (V, S, C), alg_trunc) return U, lmul!(S, Vᴴ) end @@ -189,7 +189,7 @@ function right_orth_svd!(A, CVᴴ, alg, trunc) alg′ = select_algorithm(svd_compact!, A, alg) check_input(right_orth!, A, CVᴴ, alg′) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) + U, S, Vᴴ′, _ = svd_trunc!(A, alg_trunc) C, Vᴴ = CVᴴ return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) end @@ -199,7 +199,7 @@ function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) C, Vᴴ = CVᴴ S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) + U, S, Vᴴ, _ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) return rmul!(U, S), Vᴴ end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 1a735d57..9076d290 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -238,7 +238,8 @@ end function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) - return first(truncate(svd_trunc!, USVᴴ′, alg.trunc)) + result, _, truncerr = truncate(svd_trunc!, USVᴴ′, alg.trunc) + return result..., truncerr end # Diagonal logic @@ -381,7 +382,8 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make this controllable using a `gaugefix` keyword argument gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) - return first(truncate(svd_trunc!, USVᴴ, alg.trunc)) + result, _, truncerr = truncate(svd_trunc!, USVᴴ, alg.trunc) + return result..., truncerr end function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 239b27dc..6beddafa 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -1,29 +1,57 @@ +# Compute truncation error as 2-norm of discarded values +function _compute_truncerr(values::AbstractVector, ind) + # Find indices that are NOT in ind (i.e., discarded values) + if ind isa Colon + # No truncation, all values kept + return zero(real(eltype(values))) + elseif ind isa AbstractVector{Bool} + # Boolean indexing: discarded values are where ind is false + discarded_vals = view(values, .!ind) + else + # Integer indexing: need to find complement + all_inds = Set(eachindex(values)) + kept_inds = Set(ind) + discarded_inds = collect(setdiff(all_inds, kept_inds)) + discarded_vals = view(values, discarded_inds) + end + # Compute 2-norm of discarded values + return sqrt(sum(abs2, discarded_vals)) +end + # truncate # -------- # Generic implementation: `findtruncated` followed by indexing function truncate(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy) ind = findtruncated_svd(diagview(S), strategy) - return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]), ind + Svals = diagview(S) + truncerr = _compute_truncerr(Svals, ind) + return (U[:, ind], Diagonal(Svals[ind]), Vᴴ[ind, :]), ind, truncerr end function truncate(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) - return (Diagonal(diagview(D)[ind]), V[:, ind]), ind + Dvals = diagview(D) + truncerr = _compute_truncerr(Dvals, ind) + return (Diagonal(Dvals[ind]), V[:, ind]), ind, truncerr end function truncate(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy) ind = findtruncated(diagview(D), strategy) - return (Diagonal(diagview(D)[ind]), V[:, ind]), ind + Dvals = diagview(D) + truncerr = _compute_truncerr(Dvals, ind) + return (Diagonal(Dvals[ind]), V[:, ind]), ind, truncerr end function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) ind = findtruncated(extended_S, strategy) - return U[:, ind], ind + truncerr = _compute_truncerr(extended_S, ind) + return U[:, ind], ind, truncerr end function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy) # TODO: avoid allocation? extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) ind = findtruncated(extended_S, strategy) - return Vᴴ[ind, :], ind + truncerr = _compute_truncerr(extended_S, ind) + return Vᴴ[ind, :], ind, truncerr end # findtruncated diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 5a74f79f..8a0107f3 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -31,16 +31,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc). @functiondef eig_full """ - eig_trunc(A; kwargs...) -> D, V - eig_trunc(A, alg::AbstractAlgorithm) -> D, V - eig_trunc!(A, [DV]; kwargs...) -> D, V - eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eig_trunc(A; kwargs...) -> D, V, ϵ + eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eig_trunc!(A, [DV]; kwargs...) -> D, V, ϵ + eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the matrix `A`, such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues, selected according to a truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. + !!! note The bang method `eig_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index c51b2c01..d221b88e 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -31,15 +31,18 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc) @functiondef eigh_full """ - eigh_trunc(A; kwargs...) -> D, V - eigh_trunc(A, alg::AbstractAlgorithm) -> D, V - eigh_trunc!(A, [DV]; kwargs...) -> D, V - eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eigh_trunc(A; kwargs...) -> D, V, ϵ + eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eigh_trunc!(A, [DV]; kwargs...) -> D, V, ϵ + eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix `A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues, -selected according to a truncation strategy. +selected according to a truncation strategy. + +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. !!! note The bang method `eigh_trunc!` optionally accepts the output structure and diff --git a/src/interface/svd.jl b/src/interface/svd.jl index ee6cd9af..edd6a3b9 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -43,16 +43,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and # TODO: decide if we should have `svd_trunc!!` instead """ - svd_trunc(A; kwargs...) -> U, S, Vᴴ - svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc(A; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that -`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size +`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size `(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded singular values. + !!! note The bang method `svd_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl index 0e2cb5be..cdad6f95 100644 --- a/test/amd/eigh.jl +++ b/test/amd/eigh.jl @@ -46,14 +46,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 @@ -75,7 +75,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/amd/svd.jl b/test/amd/svd.jl index 2cb15b9f..d49e4eed 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -103,14 +103,14 @@ end # minmn = min(m, n) # r = minmn - 2 -# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) # @test length(S1.diag) == r # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # s = 1 + sqrt(eps(real(T))) # trunc2 = trunctol(; atol=s * S₀[r + 1]) -# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) +# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/chainrules.jl b/test/chainrules.jl index 99f8007e..ca102bd3 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -275,7 +275,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -290,7 +290,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -351,7 +351,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -366,7 +366,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -446,7 +446,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -463,7 +463,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, ZeroTangent()), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) diff --git a/test/cuda/eig.jl b/test/cuda/eig.jl index 5cbdaecf..6a44ea73 100644 --- a/test/cuda/eig.jl +++ b/test/cuda/eig.jl @@ -44,13 +44,13 @@ end rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin - D1, V1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol=s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl index eece6be3..a12763dc 100644 --- a/test/cuda/eigh.jl +++ b/test/cuda/eigh.jl @@ -41,14 +41,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 @@ -70,7 +70,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index dfc95fe6..f9dc1232 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -103,7 +103,7 @@ end minmn = min(m, n) r = k - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -111,7 +111,7 @@ end s = 1 + sqrt(eps(real(T))) trunc2 = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) diff --git a/test/eig.jl b/test/eig.jl index ad708f04..c6f60af2 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -43,19 +43,21 @@ end rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin - D1, V1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 + # Test truncation error + @test ϵ1 ≈ norm(@view(D₀[(r + 1):end])) s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eig_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 @@ -75,12 +77,12 @@ end D = Diagonal([0.9, 0.3, 0.1, 0.01]) A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) - D3, V3 = @constinferred eig_trunc(A; alg) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) end @@ -101,6 +103,6 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] end diff --git a/test/eigh.jl b/test/eigh.jl index fa766abb..19824e3c 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -50,21 +50,23 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(@view(D₀[(r + 1):end])) trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eigh_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 @@ -84,12 +86,12 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3 = @constinferred eigh_trunc(A; alg) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) end @@ -111,6 +113,6 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] end diff --git a/test/svd.jl b/test/svd.jl index 21d34337..e62185bb 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -113,25 +113,29 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(@view(S₀[(r + 1):end])) s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) @test length(S2.diag) == r @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ ϵ1 trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ = @constinferred svd_trunc(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @test length(S3.diag) == r @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ ϵ1 end end end