diff --git a/Project.toml b/Project.toml index e1d0a218..be0bee32 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" CUDA = "5" -JET = "0.9" +JET = "0.9, 0.10" LinearAlgebra = "1" SafeTestsets = "0.1" StableRNGs = "1" diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index e6038d32..ee730020 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -12,17 +12,15 @@ Truncation strategies allow you to control which eigenvalues or singular values Truncation strategies can be used with truncated decomposition functions in two ways, as illustrated below. For concreteness, we use the following matrix as an example: -```jldoctest truncations +```jldoctest truncations; output=false using MatrixAlgebraKit using MatrixAlgebraKit: diagview A = [2 1 0; 1 3 1; 0 1 4]; D, V = eigh_full(A); - diagview(D) ≈ [3 - √3, 3, 3 + √3] # output - true ``` @@ -31,38 +29,35 @@ true The simplest approach is to pass a `NamedTuple` with the truncation parameters. For example, keeping only the largest 2 eigenvalues: -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2,)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2,)); size(Dtrunc, 1) <= 2 # output - true ``` Note however that there are no guarantees on the order of the output values: -```jldoctest truncations +```jldoctest truncations; output=false diagview(Dtrunc) ≈ diagview(D)[[3, 2]] # output - true ``` You can also use tolerance-based truncation or combine multiple criteria: -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (atol = 2.9,)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 2.9,)); all(>(2.9), diagview(Dtrunc)) # output - true ``` -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9)); size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) # output @@ -72,7 +67,7 @@ true In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring: ```@docs; canonical = false -TruncationStrategy +TruncationStrategy() ``` @@ -81,33 +76,22 @@ TruncationStrategy For more control, you can construct [`TruncationStrategy`](@ref) objects directly. This is also what the previous syntax will end up calling. -```jldoctest truncations +```jldoctest truncations; output=false Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2)) size(Dtrunc, 1) <= 2 # output - true ``` -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9)) +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9)) size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) # output true ``` -## Truncation with SVD vs Eigenvalue Decompositions - -When using truncations with different decomposition types, keep in mind: - -- **`svd_trunc`**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. - -- **`eigh_trunc`**: Eigenvalues are real but can be negative for symmetric matrices. By default, `truncrank` sorts by absolute value, so `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). - -- **`eig_trunc`**: For general (non-symmetric) matrices, eigenvalues can be complex. Truncation by absolute value considers the complex magnitude. - ## Truncation Strategies MatrixAlgebraKit provides several built-in truncation strategies: @@ -127,3 +111,31 @@ When strategies are combined, only the values that satisfy all conditions are ke combined_trunc = truncrank(10) & trunctol(; atol = 1e-6); ``` +## Truncation Error + +When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. +This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality. +For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. +For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality. + + +For example: +```jldoctest truncations; output=false +using LinearAlgebra: norm +U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2)) +norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values + +# output +true +``` + +### Truncation with SVD vs Eigenvalue Decompositions + +When using truncations with different decomposition types, keep in mind: + +- **[`svd_trunc`](@ref)**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. The truncation error gives the 2-norm difference between the original and the truncated matrix. + +- **[`eigh_trunc`](@ref)**: Eigenvalues are real but can be negative for symmetric matrices. By default, eigenvalues are treated by absolute value, e.g. `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). The truncation error gives the 2-norm difference between the original and the truncated matrix. + +- **[`eig_trunc`](@ref)**: For general (non-symmetric) matrices, eigenvalues can be complex. By default, eigenvalues are treated by absolute value. The truncation error gives an indication of the magnitude of discarded values, but is not directly related to the 2-norm difference between the original and the truncated matrix. + diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 8e6b930a..26e63570 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt using MatrixAlgebraKit using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview, - TruncatedAlgorithm, findtruncated, findtruncated_svd + TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error using ChainRulesCore using LinearAlgebra @@ -113,15 +113,20 @@ for eig in (:eig, :eigh) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - return DV′, $(_make_eig_t_pb)(A, DV, ind) + ϵ = truncation_error(diagview(DV[1]), ind) + return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) - function $eig_t_pb(ΔDV) + function $eig_t_pb(ΔDVϵ) ΔA = zero(A) - MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind) + ΔD, ΔV, Δϵ = ΔDVϵ + if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) + throw(ArgumentError("Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error")) + end + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return $eig_t_pb @@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind) + ϵ = truncation_error(diagview(USVᴴ[2]), ind) + return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) - function svd_trunc_pullback(ΔUSVᴴ) + function svd_trunc_pullback(ΔUSVᴴϵ) ΔA = zero(A) - MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind) + ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ + if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) + throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error")) + end + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return svd_trunc_pullback diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 48f1deb4..d491ee94 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -108,7 +108,8 @@ end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) - return first(truncate(eig_trunc!, (D, V), alg.trunc)) + DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc) + return DVtrunc..., truncation_error!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index c7dc4dab..7c2a2010 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -111,7 +111,8 @@ end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) - return first(truncate(eigh_trunc!, (D, V), alg.trunc)) + DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc) + return DVtrunc..., truncation_error!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 1a735d57..e8ec7e21 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -237,8 +237,9 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) end function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) - USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) - return first(truncate(svd_trunc!, USVᴴ′, alg.trunc)) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + return USVᴴtrunc..., truncation_error!(diagview(S), ind) end # Diagonal logic @@ -381,7 +382,12 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make this controllable using a `gaugefix` keyword argument gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) - return first(truncate(svd_trunc!, USVᴴ, alg.trunc)) + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + Strunc = diagview(USVᴴtrunc[2]) + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum + ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this? + return USVᴴtrunc..., ϵ end function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 239b27dc..f6201b07 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -116,3 +116,13 @@ end _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B _ind_intersect(A, B) = intersect(A, B) + +# Truncation error +# ---------------- +truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind) +# destroys input in order to maximize accuracy: +# sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error +function truncation_error!(values::AbstractVector, ind) + values[ind] .= zero(eltype(values)) + return norm(values) +end diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 62dc713e..867d69eb 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -32,16 +32,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc). @functiondef eig_full """ - eig_trunc(A; [trunc], kwargs...) -> D, V - eig_trunc(A, alg::AbstractAlgorithm) -> D, V - eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V - eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eig_trunc(A; [trunc], kwargs...) -> D, V, ϵ + eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ + eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the matrix `A`, such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues, selected according to a truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. + ## Keyword arguments The behavior of this function is controlled by the following keyword arguments: diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index f1fe1936..314cb934 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -12,15 +12,18 @@ docs_eigh_note = """ """ """ - eigh_full(A; kwargs...) -> D, V - eigh_full(A, alg::AbstractAlgorithm) -> D, V - eigh_full!(A, [DV]; kwargs...) -> D, V - eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V + eigh_full(A; kwargs...) -> D, V, ϵ + eigh_full(A, alg::AbstractAlgorithm) -> D, V, ϵ + eigh_full!(A, [DV]; kwargs...) -> D, V, ϵ + eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute the full eigenvalue decomposition of the symmetric or hermitian matrix `A`, such that `A * V = V * D`, where the unitary matrix `V` contains the orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. + !!! note The bang method `eigh_full!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function @@ -34,16 +37,19 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc) @functiondef eigh_full """ - eigh_trunc(A; [trunc], kwargs...) -> D, V - eigh_trunc(A, alg::AbstractAlgorithm) -> D, V - eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V - eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eigh_trunc(A; [trunc], kwargs...) -> D, V, ϵ + eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ + eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix `A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues, selected according to a truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded +eigenvalues. + ## Keyword arguments The behavior of this function is controlled by the following keyword arguments: diff --git a/src/interface/svd.jl b/src/interface/svd.jl index e36b0462..606a1c4e 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -42,16 +42,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and @functiondef svd_compact """ - svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ - svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that -`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size +`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size `(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded singular values. + ## Keyword arguments The behavior of this function is controlled by the following keyword arguments: diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 10ace743..db417edb 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -181,3 +181,8 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc() # disambiguate Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc + +@doc """ + truncation_error(values, ind) +Compute the truncation error as the 2-norm of the values that are not kept by `ind`. +""" truncation_error, truncation_error! diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl index d3cd25f4..a9bd4c0c 100644 --- a/test/amd/eigh.jl +++ b/test/amd/eigh.jl @@ -46,14 +46,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 @@ -75,7 +75,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/amd/svd.jl b/test/amd/svd.jl index 2cb15b9f..893b41b2 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -94,7 +94,7 @@ end # algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), # LAPACK_Jacobi()) # end - +# # @testset "size ($m, $n)" for n in (37, m, 63) # @testset "algorithm $alg" for alg in algs # n > m && alg isa LAPACK_Jacobi && continue # not supported @@ -102,15 +102,15 @@ end # S₀ = svd_vals(A) # minmn = min(m, n) # r = minmn - 2 - -# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# +# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) # @test length(S1.diag) == r # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - +# # s = 1 + sqrt(eps(real(T))) # trunc2 = trunctol(; atol=s * S₀[r + 1]) - -# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) +# +# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/chainrules.jl b/test/chainrules.jl index 99f8007e..441a17fb 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -275,7 +275,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -290,7 +290,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -351,7 +351,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -366,7 +366,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -399,7 +399,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind]), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -408,7 +408,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind]), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -446,7 +446,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -463,7 +463,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -488,7 +488,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -497,7 +497,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end diff --git a/test/cuda/eig.jl b/test/cuda/eig.jl index 5cbdaecf..6a44ea73 100644 --- a/test/cuda/eig.jl +++ b/test/cuda/eig.jl @@ -44,13 +44,13 @@ end rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin - D1, V1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol=s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl index 0a9be80d..d37722e1 100644 --- a/test/cuda/eigh.jl +++ b/test/cuda/eigh.jl @@ -41,14 +41,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 @@ -70,7 +70,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index dfc95fe6..36f923e3 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -1,6 +1,6 @@ using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef, opnorm +using LinearAlgebra: Diagonal, isposdef, norm, opnorm using Test using TestExtras using StableRNGs @@ -103,15 +103,16 @@ end minmn = min(m, n) r = k - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T))) trunc2 = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) diff --git a/test/eig.jl b/test/eig.jl index ad708f04..0ece1b35 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -42,22 +42,26 @@ end D₀ = sort!(eig_vals(A); by = abs, rev = true) rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin + atol = sqrt(eps(real(T))) - D1, V1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) - @test length(D1.diag) == r + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) + @test length(diagview(D1)) == r @test A * V1 ≈ V1 * D1 + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eig_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol # trunctol keeps order, truncrank might not # test for same subspace @@ -71,17 +75,20 @@ end @testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 + atol = sqrt(eps(real(T))) V = randn(rng, T, m, m) - D = Diagonal([0.9, 0.3, 0.1, 0.01]) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) - D3, V3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eig for Diagonal{$T}" for T in BLASFloats @@ -89,6 +96,7 @@ end m = 54 Ad = randn(rng, T, m) A = Diagonal(Ad) + atol = sqrt(eps(real(T))) D, V = @constinferred eig_full(A) @test D isa Diagonal{T} && size(D) == size(A) @@ -101,6 +109,7 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] + @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/eigh.jl b/test/eigh.jl index 74f50604..bc04b057 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -49,24 +49,28 @@ end D₀ = reverse(eigh_vals(A)) r = m - 2 s = 1 + sqrt(eps(real(T))) + atol = sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(D1)) == r @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eigh_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol # test for same subspace @test V1 * (V1' * V2) ≈ V2 @@ -79,18 +83,21 @@ end @testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 + atol = sqrt(eps(real(T))) V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eigh for Diagonal{$T}" for T in BLASFloats @@ -99,6 +106,7 @@ end Ad = randn(rng, T, m) Ad .+= conj.(Ad) A = Diagonal(Ad) + atol = sqrt(eps(real(T))) D, V = @constinferred eigh_full(A) @test D isa Diagonal{real(T)} && size(D) == size(A) @@ -111,6 +119,7 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] + @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/svd.jl b/test/svd.jl index a5daad96..d9016b4e 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -97,6 +97,7 @@ end @testset "svd_trunc! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 + atol = sqrt(eps(real(T))) if LinearAlgebra.LAPACK.version() < v"3.12.0" algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) else @@ -113,25 +114,30 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(S1.diag) == r + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc) - @test length(S2.diag) == r + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) + @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ = @constinferred svd_trunc(A; alg, trunc) - @test length(S3.diag) == r + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) + @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol end end end @@ -148,7 +154,7 @@ end m = 4 @testset "algorithm $alg" for alg in algs U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal([0.9, 0.3, 0.1, 0.01]) + S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ @@ -156,32 +162,35 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(S1.diag) == 1 - @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(S2.diag) == 2 - @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] end end end @testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) + atol = sqrt(eps(real(T))) m = 4 U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal([0.9, 0.3, 0.1, 0.01]) + S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end @testset "svd for Diagonal{$T}" for T in BLASFloats rng = StableRNG(123) + atol = sqrt(eps(real(T))) for m in (54, 0) Ad = randn(T, m) A = Diagonal(Ad) @@ -209,7 +218,8 @@ end @test S2 ≈ diagview(S) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg) + U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] + @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol end end