From aec6c917de9811794e85c6c533b294dde00d8a37 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 10 Oct 2025 01:05:21 +0200 Subject: [PATCH 01/10] add truncerr --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 28 ++++++++++++++++-------- src/implementations/eig.jl | 3 ++- src/implementations/eigh.jl | 3 ++- src/implementations/svd.jl | 5 +++-- src/implementations/truncation.jl | 7 ++++++ src/interface/eig.jl | 11 ++++++---- src/interface/eigh.jl | 11 ++++++---- src/interface/svd.jl | 13 ++++++----- test/amd/eigh.jl | 6 ++--- test/amd/svd.jl | 12 +++++----- test/chainrules.jl | 20 ++++++++--------- test/cuda/eig.jl | 4 ++-- test/cuda/eigh.jl | 6 ++--- test/cuda/svd.jl | 4 ++-- test/eig.jl | 20 +++++++++++------ test/eigh.jl | 20 +++++++++++------ test/svd.jl | 24 ++++++++++++-------- 17 files changed, 122 insertions(+), 75 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index d8664d80..473846f6 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt using MatrixAlgebraKit using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview, - TruncatedAlgorithm, findtruncated, findtruncated_svd + TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr! using ChainRulesCore using LinearAlgebra @@ -113,15 +113,20 @@ for eig in (:eig, :eigh) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - return DV′, $(_make_eig_t_pb)(A, DV, ind) + ϵ = compute_truncerr!(diagview(copy(DV[1])), ind) + return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) - function $eig_t_pb(ΔDV) + function $eig_t_pb(ΔDVϵ) ΔA = zero(A) - MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind) + ΔD, ΔV, Δϵ = ΔDVϵ + if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) + throw(ArgumentError("Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error")) + end + MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful? + function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return $eig_t_pb @@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind) + ϵ = compute_truncerr!(diagview(copy(USVᴴ[2])), ind) + return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) - function svd_trunc_pullback(ΔUSVᴴ) + function svd_trunc_pullback(ΔUSVᴴϵ) ΔA = zero(A) - MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind) + ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ + if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ)) + throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error")) + end + MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind) return NoTangent(), ΔA, ZeroTangent(), NoTangent() end - function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? + function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful? return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() end return svd_trunc_pullback diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 48f1deb4..0815c9fd 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -108,7 +108,8 @@ end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) - return first(truncate(eig_trunc!, (D, V), alg.trunc)) + DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc) + return DVtrunc..., compute_truncerr!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index c7dc4dab..2fbdf7ba 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -111,7 +111,8 @@ end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) - return first(truncate(eigh_trunc!, (D, V), alg.trunc)) + DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc) + return DVtrunc..., compute_truncerr!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 1a735d57..baa6e88e 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -237,8 +237,9 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) end function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) - USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg) - return first(truncate(svd_trunc!, USVᴴ′, alg.trunc)) + U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + return USVᴴtrunc..., compute_truncerr!(diagview(S), ind) end # Diagonal logic diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 239b27dc..4437d918 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -116,3 +116,10 @@ end _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B _ind_intersect(A, B) = intersect(A, B) + +# Compute truncation error as 2-norm of discarded values +# by destroying original values +function compute_truncerr!(values::AbstractVector, ind) + values[ind] .= zero(eltype(values)) + return norm(values) +end diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 5a74f79f..8a0107f3 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -31,16 +31,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc). @functiondef eig_full """ - eig_trunc(A; kwargs...) -> D, V - eig_trunc(A, alg::AbstractAlgorithm) -> D, V - eig_trunc!(A, [DV]; kwargs...) -> D, V - eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eig_trunc(A; kwargs...) -> D, V, ϵ + eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eig_trunc!(A, [DV]; kwargs...) -> D, V, ϵ + eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the matrix `A`, such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues, selected according to a truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. + !!! note The bang method `eig_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index c51b2c01..87c24d5d 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -9,15 +9,18 @@ For generic eigenvalue decompositions, see [`eig_full`](@ref). """ """ - eigh_full(A; kwargs...) -> D, V - eigh_full(A, alg::AbstractAlgorithm) -> D, V - eigh_full!(A, [DV]; kwargs...) -> D, V - eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V + eigh_full(A; kwargs...) -> D, V, ϵ + eigh_full(A, alg::AbstractAlgorithm) -> D, V, ϵ + eigh_full!(A, [DV]; kwargs...) -> D, V, ϵ + eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute the full eigenvalue decomposition of the symmetric or hermitian matrix `A`, such that `A * V = V * D`, where the unitary matrix `V` contains the orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded eigenvalues. + !!! note The bang method `eigh_full!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/src/interface/svd.jl b/src/interface/svd.jl index ee6cd9af..92af73e3 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -43,16 +43,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and # TODO: decide if we should have `svd_trunc!!` instead """ - svd_trunc(A; kwargs...) -> U, S, Vᴴ - svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ - svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ + svd_trunc(A; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that -`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size +`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size `(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the +discarded singular values. + !!! note The bang method `svd_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl index 0e2cb5be..cdad6f95 100644 --- a/test/amd/eigh.jl +++ b/test/amd/eigh.jl @@ -46,14 +46,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 @@ -75,7 +75,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/amd/svd.jl b/test/amd/svd.jl index 2cb15b9f..893b41b2 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -94,7 +94,7 @@ end # algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection(), # LAPACK_Jacobi()) # end - +# # @testset "size ($m, $n)" for n in (37, m, 63) # @testset "algorithm $alg" for alg in algs # n > m && alg isa LAPACK_Jacobi && continue # not supported @@ -102,15 +102,15 @@ end # S₀ = svd_vals(A) # minmn = min(m, n) # r = minmn - 2 - -# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) +# +# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r)) # @test length(S1.diag) == r # @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - +# # s = 1 + sqrt(eps(real(T))) # trunc2 = trunctol(; atol=s * S₀[r + 1]) - -# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) +# +# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1])) # @test length(S2.diag) == r # @test U1 ≈ U2 # @test S1 ≈ S2 diff --git a/test/chainrules.jl b/test/chainrules.jl index 99f8007e..441a17fb 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -275,7 +275,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -290,7 +290,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -351,7 +351,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -366,7 +366,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc), + output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -399,7 +399,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind]), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -408,7 +408,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind]), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -446,7 +446,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -463,7 +463,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -488,7 +488,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -497,7 +497,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :]), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end diff --git a/test/cuda/eig.jl b/test/cuda/eig.jl index 5cbdaecf..6a44ea73 100644 --- a/test/cuda/eig.jl +++ b/test/cuda/eig.jl @@ -44,13 +44,13 @@ end rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin - D1, V1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol=s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl index eece6be3..a12763dc 100644 --- a/test/cuda/eigh.jl +++ b/test/cuda/eigh.jl @@ -41,14 +41,14 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 @@ -70,7 +70,7 @@ end A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) end=# diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index dfc95fe6..f9dc1232 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -103,7 +103,7 @@ end minmn = min(m, n) r = k - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] @@ -111,7 +111,7 @@ end s = 1 + sqrt(eps(real(T))) trunc2 = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc = trunctol(; atol = s * S₀[r + 1])) @test length(S2.diag) == r @test U1 ≈ U2 @test parent(S1) ≈ parent(S2) diff --git a/test/eig.jl b/test/eig.jl index ad708f04..acc63b24 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -43,21 +43,24 @@ end rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin - D1, V1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) + D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2 = @constinferred eig_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eig_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) # trunctol keeps order, truncrank might not # test for same subspace @@ -72,16 +75,18 @@ end rng = StableRNG(123) m = 4 V = randn(rng, T, m, m) - D = Diagonal([0.9, 0.3, 0.1, 0.01]) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test ϵ2 ≈ norm(diagview(D)[3:4]) @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) - D3, V3 = @constinferred eig_trunc(A; alg) + D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test ϵ3 ≈ norm(diagview(D)[3:4]) end @testset "eig for Diagonal{$T}" for T in BLASFloats @@ -101,6 +106,7 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eig_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] + @test ϵ2 ≈ norm(diagview(A2)[3:4]) end diff --git a/test/eigh.jl b/test/eigh.jl index fa766abb..b78b1213 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -50,23 +50,26 @@ end r = m - 2 s = 1 + sqrt(eps(real(T))) - D1, V1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) + D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(D1)) == r @test isisometry(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2 = @constinferred eigh_trunc(A; alg, trunc) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometry(V2) @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3 = @constinferred eigh_trunc(A; alg, trunc) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) # test for same subspace @test V1 * (V1' * V2) ≈ V2 @@ -80,17 +83,19 @@ end rng = StableRNG(123) m = 4 V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) + D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * V' A = (A + A') / 2 alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3 = @constinferred eigh_trunc(A; alg) + D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test ϵ3 ≈ norm(diagview(D)[3:4]) end @testset "eigh for Diagonal{$T}" for T in BLASFloats @@ -111,6 +116,7 @@ end A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2 = @constinferred eigh_trunc(A2; alg) + D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] + @test ϵ2 ≈ norm(diagview(A2)[3:4]) end diff --git a/test/svd.jl b/test/svd.jl index 21d34337..1b38b174 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -113,25 +113,29 @@ end minmn = min(m, n) r = minmn - 2 - U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) + U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + # Test truncation error + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) @test length(S2.diag) == r @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ + @test ϵ2 ≈ ϵ1 trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) - U3, S3, V3ᴴ = @constinferred svd_trunc(A; alg, trunc) + U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @test length(S3.diag) == r @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ + @test ϵ3 ≈ ϵ1 end end end @@ -148,7 +152,7 @@ end m = 4 @testset "algorithm $alg" for alg in algs U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal([0.9, 0.3, 0.1, 0.01]) + S = Diagonal(T[0.9, 0.3, 0.1, 0.01]) Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ @@ -156,11 +160,11 @@ end (rtol, maxrank) -> (; rtol, maxrank), (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) - U1, S1, V1ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) + U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) @test length(S1.diag) == 1 @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) - U2, S2, V2ᴴ = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) + U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) @test length(S2.diag) == 2 @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) end @@ -171,12 +175,13 @@ end rng = StableRNG(123) m = 4 U = qr_compact(randn(rng, T, m, m))[1] - S = Diagonal([0.9, 0.3, 0.1, 0.01]) + S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) Vᴴ = qr_compact(randn(rng, T, m, m))[1] A = U * S * Vᴴ alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) - U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) + U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) + @test ϵ2 ≈ norm(diagview(S)[3:4]) @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end @@ -209,7 +214,8 @@ end @test S2 ≈ diagview(S) alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - U3, S3, Vᴴ3 = @constinferred svd_trunc(A; alg) + U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] + @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) end end From d63707cbf1fe444e7a4bdf8e05a811c99edaf6d9 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sat, 11 Oct 2025 00:30:45 +0200 Subject: [PATCH 02/10] some final fixes --- docs/src/user_interface/truncations.md | 13 +++++++++++++ ext/MatrixAlgebraKitChainRulesCoreExt.jl | 4 ++-- src/implementations/svd.jl | 7 ++++++- test/cuda/svd.jl | 1 + 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index e6038d32..cf165a3a 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -127,3 +127,16 @@ When strategies are combined, only the values that satisfy all conditions are ke combined_trunc = truncrank(10) & trunctol(; atol = 1e-6); ``` +## Truncation Error + +When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned. +This error is defined as the 2-norm of the discarded singular values or eigenvalues, providing a measure of the approximation quality. +For `svd_trunc` and `eigh_trunc`, this corresponds to the 2-norm difference between the original and the truncated matrix. +For the case of `eig_trunc`, this interpretation does not hold because the norm of the non-unitary matrix of eigenvectors and its inverse also influence the approximation quality. + + +For example: +```julia +U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(10)) +# ϵ is the 2-norm of the discarded singular values +``` diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 26190929..91878296 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -113,7 +113,7 @@ for eig in (:eig, :eigh) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - ϵ = compute_truncerr!(diagview(copy(DV[1])), ind) + ϵ = compute_truncerr!(copy(diagview(DV[1])), ind) return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) @@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - ϵ = compute_truncerr!(diagview(copy(USVᴴ[2])), ind) + ϵ = compute_truncerr!(copy(diagview(USVᴴ[2])), ind) return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index baa6e88e..b73d5983 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -382,7 +382,12 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make this controllable using a `gaugefix` keyword argument gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) - return first(truncate(svd_trunc!, USVᴴ, alg.trunc)) + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + Strunc = diagview(USVᴴtrunc[2]) + # normal `compute_truncerr!` does not work here since `S` is not the full singular value spectrum + ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this? + return USVᴴtrunc..., ϵ end function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index f9dc1232..1847ca0b 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -106,6 +106,7 @@ end U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] + @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T))) From 5ce6c2cb6d1f2cd5b368bc8a5a4722eac4c4c5b1 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sat, 11 Oct 2025 00:37:45 +0200 Subject: [PATCH 03/10] fix cuda test --- test/cuda/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index 1847ca0b..36f923e3 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -1,6 +1,6 @@ using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, isposdef, opnorm +using LinearAlgebra: Diagonal, isposdef, norm, opnorm using Test using TestExtras using StableRNGs From 8bb35d9b9a79ef7439aedd4959344802a05c696e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 08:53:31 -0400 Subject: [PATCH 04/10] make `trunc` kwarg explicit in docstrings --- src/interface/eig.jl | 4 ++-- src/interface/svd.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 26ae0c36..6eb26941 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -32,9 +32,9 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc). @functiondef eig_full """ - eig_trunc(A; kwargs...) -> D, V, ϵ + eig_trunc(A; [trunc], kwargs...) -> D, V, ϵ eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ - eig_trunc!(A, [DV]; kwargs...) -> D, V, ϵ + eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the matrix `A`, diff --git a/src/interface/svd.jl b/src/interface/svd.jl index a5a6c6f6..7cc4d20e 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -42,9 +42,9 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and @functiondef svd_compact """ - svd_trunc(A; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ - svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ, ϵ + svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ Compute a partial or truncated singular value decomposition (SVD) of `A`, such that From 02f3394a2a932e7e26f6f7392bee4d940c5957e8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 09:34:50 -0400 Subject: [PATCH 05/10] update docs --- docs/src/user_interface/truncations.md | 61 +++++++++++++------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index cf165a3a..ee730020 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -12,17 +12,15 @@ Truncation strategies allow you to control which eigenvalues or singular values Truncation strategies can be used with truncated decomposition functions in two ways, as illustrated below. For concreteness, we use the following matrix as an example: -```jldoctest truncations +```jldoctest truncations; output=false using MatrixAlgebraKit using MatrixAlgebraKit: diagview A = [2 1 0; 1 3 1; 0 1 4]; D, V = eigh_full(A); - diagview(D) ≈ [3 - √3, 3, 3 + √3] # output - true ``` @@ -31,38 +29,35 @@ true The simplest approach is to pass a `NamedTuple` with the truncation parameters. For example, keeping only the largest 2 eigenvalues: -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2,)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2,)); size(Dtrunc, 1) <= 2 # output - true ``` Note however that there are no guarantees on the order of the output values: -```jldoctest truncations +```jldoctest truncations; output=false diagview(Dtrunc) ≈ diagview(D)[[3, 2]] # output - true ``` You can also use tolerance-based truncation or combine multiple criteria: -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (atol = 2.9,)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 2.9,)); all(>(2.9), diagview(Dtrunc)) # output - true ``` -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9)); +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9)); size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) # output @@ -72,7 +67,7 @@ true In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring: ```@docs; canonical = false -TruncationStrategy +TruncationStrategy() ``` @@ -81,33 +76,22 @@ TruncationStrategy For more control, you can construct [`TruncationStrategy`](@ref) objects directly. This is also what the previous syntax will end up calling. -```jldoctest truncations +```jldoctest truncations; output=false Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2)) size(Dtrunc, 1) <= 2 # output - true ``` -```jldoctest truncations -Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9)) +```jldoctest truncations; output=false +Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9)) size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc)) # output true ``` -## Truncation with SVD vs Eigenvalue Decompositions - -When using truncations with different decomposition types, keep in mind: - -- **`svd_trunc`**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. - -- **`eigh_trunc`**: Eigenvalues are real but can be negative for symmetric matrices. By default, `truncrank` sorts by absolute value, so `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). - -- **`eig_trunc`**: For general (non-symmetric) matrices, eigenvalues can be complex. Truncation by absolute value considers the complex magnitude. - ## Truncation Strategies MatrixAlgebraKit provides several built-in truncation strategies: @@ -136,7 +120,22 @@ For the case of `eig_trunc`, this interpretation does not hold because the norm For example: -```julia -U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(10)) -# ϵ is the 2-norm of the discarded singular values +```jldoctest truncations; output=false +using LinearAlgebra: norm +U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2)) +norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values + +# output +true ``` + +### Truncation with SVD vs Eigenvalue Decompositions + +When using truncations with different decomposition types, keep in mind: + +- **[`svd_trunc`](@ref)**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. The truncation error gives the 2-norm difference between the original and the truncated matrix. + +- **[`eigh_trunc`](@ref)**: Eigenvalues are real but can be negative for symmetric matrices. By default, eigenvalues are treated by absolute value, e.g. `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). The truncation error gives the 2-norm difference between the original and the truncated matrix. + +- **[`eig_trunc`](@ref)**: For general (non-symmetric) matrices, eigenvalues can be complex. By default, eigenvalues are treated by absolute value. The truncation error gives an indication of the magnitude of discarded values, but is not directly related to the 2-norm difference between the original and the truncated matrix. + From 3a0c54449be0c78c3d80c0b3ecf8c9f2d6c8f143 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 17:57:17 -0400 Subject: [PATCH 06/10] bump JET compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e1d0a218..be0bee32 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ Aqua = "0.6, 0.7, 0.8" ChainRulesCore = "1" ChainRulesTestUtils = "1" CUDA = "5" -JET = "0.9" +JET = "0.9, 0.10" LinearAlgebra = "1" SafeTestsets = "0.1" StableRNGs = "1" From f1147f513c43e804cd333541ff60b92227344187 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 13 Oct 2025 12:01:11 -0400 Subject: [PATCH 07/10] rename `truncation_error(!)` --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 6 +++--- src/implementations/eig.jl | 2 +- src/implementations/eigh.jl | 2 +- src/implementations/svd.jl | 4 ++-- src/implementations/truncation.jl | 17 ++++++++++++++--- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 91878296..26e63570 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt using MatrixAlgebraKit using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview, - TruncatedAlgorithm, findtruncated, findtruncated_svd, compute_truncerr! + TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error using ChainRulesCore using LinearAlgebra @@ -113,7 +113,7 @@ for eig in (:eig, :eigh) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - ϵ = compute_truncerr!(copy(diagview(DV[1])), ind) + ϵ = truncation_error(diagview(DV[1]), ind) return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) @@ -157,7 +157,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - ϵ = compute_truncerr!(copy(diagview(USVᴴ[2])), ind) + ϵ = truncation_error(diagview(USVᴴ[2]), ind) return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 0815c9fd..d491ee94 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -109,7 +109,7 @@ end function eig_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eig_full!(A, DV, alg.alg) DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc) - return DVtrunc..., compute_truncerr!(diagview(D), ind) + return DVtrunc..., truncation_error!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 2fbdf7ba..7c2a2010 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -112,7 +112,7 @@ end function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) D, V = eigh_full!(A, DV, alg.alg) DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc) - return DVtrunc..., compute_truncerr!(diagview(D), ind) + return DVtrunc..., truncation_error!(diagview(D), ind) end # Diagonal logic diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index b73d5983..e8ec7e21 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -239,7 +239,7 @@ end function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - return USVᴴtrunc..., compute_truncerr!(diagview(S), ind) + return USVᴴtrunc..., truncation_error!(diagview(S), ind) end # Diagonal logic @@ -385,7 +385,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) Strunc = diagview(USVᴴtrunc[2]) - # normal `compute_truncerr!` does not work here since `S` is not the full singular value spectrum + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this? return USVᴴtrunc..., ϵ end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 4437d918..0fa3909b 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -117,9 +117,20 @@ _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B _ind_intersect(A, B) = intersect(A, B) -# Compute truncation error as 2-norm of discarded values -# by destroying original values -function compute_truncerr!(values::AbstractVector, ind) +# Truncation error +# ---------------- +@doc """ + truncation_error(values, ind) + truncation_error!(values, ind) + +Determine the truncation error of selecting `ind` out of the `values`. +This is defined as the 2-norm of the discarded values. +""" truncation_error, truncation_error! + +truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind) +# destroys input in order to maximize accuracy: +# sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error +function truncation_error!(values::AbstractVector, ind) values[ind] .= zero(eltype(values)) return norm(values) end From 22495c3a429701b8d5e9938ca5cf85d462e4fe3c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 13 Oct 2025 12:03:49 -0400 Subject: [PATCH 08/10] docstring fixes --- src/implementations/truncation.jl | 8 -------- src/interface/eig.jl | 20 ++++++++++++++++++++ src/interface/eigh.jl | 11 +++++++---- src/interface/svd.jl | 22 +++++++++++++++++++++- src/interface/truncation.jl | 5 +++++ 5 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 0fa3909b..f6201b07 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -119,14 +119,6 @@ _ind_intersect(A, B) = intersect(A, B) # Truncation error # ---------------- -@doc """ - truncation_error(values, ind) - truncation_error!(values, ind) - -Determine the truncation error of selecting `ind` out of the `values`. -This is defined as the 2-norm of the discarded values. -""" truncation_error, truncation_error! - truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind) # destroys input in order to maximize accuracy: # sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 6eb26941..867d69eb 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -45,6 +45,26 @@ selected according to a truncation strategy. The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded eigenvalues. +## Keyword arguments +The behavior of this function is controlled by the following keyword arguments: + +- `trunc`: Specifies the truncation strategy. This can be: + - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to + a [`TruncationStrategy`](@ref). For details on available truncation strategies, see + [Truncations](@ref). + - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or + combinations using `&`). + - `nothing` (default), which keeps all eigenvalues. + +- Other keyword arguments are passed to the algorithm selection procedure. If no explicit + `alg` is provided, these keywords are used to select and configure the algorithm through + [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm + selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) + for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + !!! note The bang method `eig_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index e5f47b0d..314cb934 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -37,16 +37,19 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc) @functiondef eigh_full """ - eigh_trunc(A; [trunc], kwargs...) -> D, V - eigh_trunc(A, alg::AbstractAlgorithm) -> D, V - eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V - eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V + eigh_trunc(A; [trunc], kwargs...) -> D, V, ϵ + eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ + eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ + eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix `A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues, selected according to a truncation strategy. +The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded +eigenvalues. + ## Keyword arguments The behavior of this function is controlled by the following keyword arguments: diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 7cc4d20e..606a1c4e 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -54,7 +54,27 @@ square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strat The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded singular values. - + +## Keyword arguments +The behavior of this function is controlled by the following keyword arguments: + +- `trunc`: Specifies the truncation strategy. This can be: + - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to + a [`TruncationStrategy`](@ref). For details on available truncation strategies, see + [Truncations](@ref). + - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or + combinations using `&`). + - `nothing` (default), which keeps all singular values. + +- Other keyword arguments are passed to the algorithm selection procedure. If no explicit + `alg` is provided, these keywords are used to select and configure the algorithm through + [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm + selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) + for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + !!! note The bang method `svd_trunc!` optionally accepts the output structure and possibly destroys the input matrix `A`. Always use the return value of the function diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 10ace743..db417edb 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -181,3 +181,8 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc() # disambiguate Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc + +@doc """ + truncation_error(values, ind) +Compute the truncation error as the 2-norm of the values that are not kept by `ind`. +""" truncation_error, truncation_error! From 125a55a4d6d89822ae53e6ac4b402713fcc40149 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 13 Oct 2025 18:18:46 -0400 Subject: [PATCH 09/10] relax truncation error tolerances --- test/eig.jl | 15 +++++++++------ test/eigh.jl | 15 +++++++++------ test/svd.jl | 6 ++++-- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/test/eig.jl b/test/eig.jl index acc63b24..318242c0 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -42,25 +42,26 @@ end D₀ = sort!(eig_vals(A); by = abs, rev = true) rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) r = length(D₀) - rmin + atol = sqrt(eps(real(T))) D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) @test length(D1.diag) == r @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol # trunctol keeps order, truncrank might not # test for same subspace @@ -74,19 +75,20 @@ end @testset "eig_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 + atol = sqrt(eps(real(T))) V = randn(rng, T, m, m) D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test ϵ2 ≈ norm(diagview(D)[3:4]) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test ϵ3 ≈ norm(diagview(D)[3:4]) + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eig for Diagonal{$T}" for T in BLASFloats @@ -94,6 +96,7 @@ end m = 54 Ad = randn(rng, T, m) A = Diagonal(Ad) + atol = sqrt(eps(real(T))) D, V = @constinferred eig_full(A) @test D isa Diagonal{T} && size(D) == size(A) @@ -108,5 +111,5 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) + @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/eigh.jl b/test/eigh.jl index 60ab2f1f..ab10c2e5 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -49,27 +49,28 @@ end D₀ = reverse(eigh_vals(A)) r = m - 2 s = 1 + sqrt(eps(real(T))) + atol = sqrt(eps(real(T))) D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(D1)) == r @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol trunc = trunctol(; atol = s * D₀[r + 1]) D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol # test for same subspace @test V1 * (V1' * V2) ≈ V2 @@ -82,6 +83,7 @@ end @testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) m = 4 + atol = sqrt(eps(real(T))) V = qr_compact(randn(rng, T, m, m))[1] D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) A = V * D * V' @@ -90,12 +92,12 @@ end D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test ϵ3 ≈ norm(diagview(D)[3:4]) + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eigh for Diagonal{$T}" for T in BLASFloats @@ -104,6 +106,7 @@ end Ad = randn(rng, T, m) Ad .+= conj.(Ad) A = Diagonal(Ad) + atol = sqrt(eps(real(T))) D, V = @constinferred eigh_full(A) @test D isa Diagonal{real(T)} && size(D) == size(A) @@ -118,5 +121,5 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) + @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/svd.jl b/test/svd.jl index 62817461..82cb3005 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -97,6 +97,7 @@ end @testset "svd_trunc! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 54 + atol = sqrt(eps(real(T))) if LinearAlgebra.LAPACK.version() < v"3.12.0" algs = (LAPACK_DivideAndConquer(), LAPACK_QRIteration(), LAPACK_Bisection()) else @@ -117,7 +118,7 @@ end @test length(S1.diag) == r @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) + @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) @@ -187,6 +188,7 @@ end @testset "svd for Diagonal{$T}" for T in BLASFloats rng = StableRNG(123) + atol = sqrt(eps(real(T))) for m in (54, 0) Ad = randn(T, m) A = Diagonal(Ad) @@ -216,6 +218,6 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] - @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) + @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol end end From 6c4fdb1b91fee0a47528b0c4f263bcb9159925dc Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Tue, 14 Oct 2025 12:03:09 +0200 Subject: [PATCH 10/10] some test streamline --- test/eig.jl | 6 +++--- test/eigh.jl | 4 ++-- test/svd.jl | 24 +++++++++++++----------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/test/eig.jl b/test/eig.jl index 318242c0..0ece1b35 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -45,7 +45,7 @@ end atol = sqrt(eps(real(T))) D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) - @test length(D1.diag) == r + @test length(diagview(D1)) == r @test A * V1 ≈ V1 * D1 @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol @@ -81,13 +81,13 @@ end A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test diagview(D2) ≈ diagview(D)[1:2] @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test diagview(D3) ≈ diagview(D)[1:2] @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end diff --git a/test/eigh.jl b/test/eigh.jl index ab10c2e5..bc04b057 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -90,13 +90,13 @@ end A = (A + A') / 2 alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test diagview(D2) ≈ diagview(D)[1:2] @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test diagview(D3) ≈ diagview(D)[1:2] @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol end diff --git a/test/svd.jl b/test/svd.jl index 82cb3005..d9016b4e 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -115,7 +115,8 @@ end r = minmn - 2 U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) - @test length(S1.diag) == r + @test length(diagview(S1)) == r + @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # Test truncation error @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol @@ -124,19 +125,19 @@ end trunc = trunctol(; atol = s * S₀[r + 1]) U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc) - @test length(S2.diag) == r + @test length(diagview(S2)) == r @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ ϵ1 + @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) - @test length(S3.diag) == r + @test length(diagview(S3)) == r @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ ϵ1 + @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol end end end @@ -162,18 +163,19 @@ end (rtol, maxrank) -> truncrank(maxrank) & trunctol(; rtol), ) U1, S1, V1ᴴ, ϵ1 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 1)) - @test length(S1.diag) == 1 - @test S1.diag ≈ S.diag[1:1] rtol = sqrt(eps(real(T))) + @test length(diagview(S1)) == 1 + @test diagview(S1) ≈ diagview(S)[1:1] U2, S2, V2ᴴ, ϵ2 = svd_trunc(A; alg, trunc = trunc_fun(0.2, 3)) - @test length(S2.diag) == 2 - @test S2.diag ≈ S.diag[1:2] rtol = sqrt(eps(real(T))) + @test length(diagview(S2)) == 2 + @test diagview(S2) ≈ diagview(S)[1:2] end end end @testset "svd_trunc! specify truncation algorithm T = $T" for T in BLASFloats rng = StableRNG(123) + atol = sqrt(eps(real(T))) m = 4 U = qr_compact(randn(rng, T, m, m))[1] S = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) @@ -181,8 +183,8 @@ end A = U * S * Vᴴ alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) - @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) - @test ϵ2 ≈ norm(diagview(S)[3:4]) + @test diagview(S2) ≈ diagview(S)[1:2] + @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end