From cf31cb4e3cfe23c2439fedab14950e32dbac052d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 14:12:46 +0100 Subject: [PATCH 1/3] Pass compute_error and epsilon as arguments This PR adds a `compute_error` field to `TruncatedAlgorithm` and changes `eig_trunc!`, `eigh_trunc!`, and `svd_trunc!` (and their non-mutating counterparts) to accept an array `epsilon` as part of their arguments. The purpose of this is twofold: to make handling the truncation error easier in AD, and to avoid forcing GPU synchronization. --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 4 +-- .../MatrixAlgebraKitMooncakeExt.jl | 6 ++-- src/algorithms.jl | 2 ++ src/implementations/eig.jl | 14 +++++++--- src/implementations/eigh.jl | 14 +++++++--- src/implementations/svd.jl | 18 ++++++------ test/chainrules.jl | 20 ++++++------- test/eig.jl | 12 ++++---- test/eigh.jl | 12 ++++---- test/genericlinearalgebra/eigh.jl | 10 +++---- test/genericlinearalgebra/svd.jl | 8 +++--- test/genericschur/eig.jl | 12 ++++---- test/mooncake.jl | 28 +++++++++---------- test/svd.jl | 10 +++---- 14 files changed, 90 insertions(+), 80 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 549f4a53..9cee4d70 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -118,7 +118,7 @@ for eig in (:eig, :eigh) Ac = copy_input($eig_f, A) DV = $(eig_f!)(Ac, DV, alg.alg) DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc) - ϵ = truncation_error(diagview(DV[1]), ind) + ϵ = [truncation_error(diagview(DV[1]), ind)] return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind) end function $(_make_eig_t_pb)(A, DV, ind) @@ -174,7 +174,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg Ac = copy_input(svd_compact, A) USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg) USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc) - ϵ = truncation_error(diagview(USVᴴ[2]), ind) + ϵ = [truncation_error(diagview(USVᴴ[2]), ind)] return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind) end function _make_svd_trunc_pullback(A, USVᴴ, ind) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index aa16f61e..8ccca109 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -179,10 +179,9 @@ for (f, pb, adj) in ( # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + function $adj(::NoRData) Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error" D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) @@ -316,10 +315,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} + function svd_trunc_adjoint(::NoRData) Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" U, dU = arrayify(Utrunc, dUtrunc_) S, dS = arrayify(Strunc, dStrunc_) Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) diff --git a/src/algorithms.jl b/src/algorithms.jl index e9e7b8e8..0826bba4 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -228,7 +228,9 @@ truncation through `trunc`. struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm alg::A trunc::T + compute_error::Bool end +TruncatedAlgorithm(alg::A, trunc::T; compute_error::Bool = true) where {A <: AbstractAlgorithm, T} = TruncatedAlgorithm{A, T}(alg, trunc, compute_error) does_truncate(::TruncatedAlgorithm) = true diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 9b14167c..ed7c04c7 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -66,7 +66,9 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg return D end function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm) - return initialize_output(eig_full!, A, alg.alg) + DV = initialize_output(eig_full!, A, alg.alg) + ϵ = similar(A, real(eltype(A)), alg.compute_error) + return (DV..., ϵ) end function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm) @@ -115,10 +117,14 @@ function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) return D end -function eig_trunc!(A, DV, alg::TruncatedAlgorithm) - D, V = eig_full!(A, DV, alg.alg) +function eig_trunc!(A, DVϵ, alg::TruncatedAlgorithm) + D, V, ϵ = DVϵ + D, V = eig_full!(A, (D, V), alg.alg) DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc) - return DVtrunc..., truncation_error!(diagview(D), ind) + if !isempty(ϵ) + ϵ .= truncation_error!(diagview(D), ind) + end + return DVtrunc..., ϵ end # Diagonal logic diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index a45300dc..07c15bc1 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -75,7 +75,9 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl return D end function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm) - return initialize_output(eigh_full!, A, alg.alg) + DV = initialize_output(eigh_full!, A, alg.alg) + ϵ = similar(A, real(eltype(A)), alg.compute_error) + return (DV..., ϵ) end function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm) @@ -129,10 +131,14 @@ function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) return D end -function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) - D, V = eigh_full!(A, DV, alg.alg) +function eigh_trunc!(A, DVϵ, alg::TruncatedAlgorithm) + D, V, ϵ = DVϵ + D, V = eigh_full!(A, (D, V), alg.alg) DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc) - return DVtrunc..., truncation_error!(diagview(D), ind) + if !isempty(ϵ) + ϵ .= truncation_error!(diagview(D), ind) + end + return DVtrunc..., ϵ end # Diagonal logic diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 126e6a04..faa9dc18 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -90,7 +90,9 @@ function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlg return similar(A, real(eltype(A)), (min(size(A)...),)) end function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm) - return initialize_output(svd_compact!, A, alg.alg) + USVᴴ = initialize_output(svd_compact!, A, alg.alg) + ϵ = similar(A, real(eltype(A)), alg.compute_error) + return (USVᴴ..., ϵ) end function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm) @@ -206,12 +208,6 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(A, real(eltype(A)), compute_error) - (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) -end - function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} U, S, Vᴴ, ϵ = USVᴴϵ U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) @@ -272,11 +268,11 @@ end ### function check_input( - ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized + ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴϵ, alg::CUSOLVER_Randomized ) m, n = size(A) minmn = min(m, n) - U, S, Vᴴ = USVᴴ + U, S, Vᴴ, ϵ = USVᴴϵ @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix @check_size(U, (m, m)) @check_scalar(U, A) @@ -284,6 +280,7 @@ function check_input( @check_scalar(S, A, real) @check_size(Vᴴ, (n, n)) @check_scalar(Vᴴ, A) + @check_scalar(ϵ, A, real) return nothing end @@ -295,7 +292,8 @@ function initialize_output( U = similar(A, (m, m)) S = Diagonal(similar(A, real(eltype(A)), (minmn,))) Vᴴ = similar(A, (n, n)) - return (U, S, Vᴴ) + ϵ = similar(A, real(eltype(A)), alg.compute_error) + return (U, S, Vᴴ, ϵ) end function _gpu_gesvd!( diff --git a/test/chainrules.jl b/test/chainrules.jl index 5258b839..64b498f7 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -247,7 +247,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -262,7 +262,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eig_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -328,7 +328,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -343,7 +343,7 @@ end ΔVtrunc = ΔV[:, ind] test_rrule( copy_eigh_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))), + output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind) @@ -380,7 +380,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -389,7 +389,7 @@ end test_rrule( config, eigh_trunc2, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))), + output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -431,7 +431,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -448,7 +448,7 @@ end ΔVᴴtrunc = ΔVᴴ[ind, :] test_rrule( copy_svd_trunc, A, truncalg ⊢ NoTangent(); - output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), + output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]), atol = atol, rtol = rtol ) dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind) @@ -477,7 +477,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end @@ -486,7 +486,7 @@ end test_rrule( config, svd_trunc, A; fkwargs = (; trunc = trunc), - output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))), + output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end diff --git a/test/eig.jl b/test/eig.jl index 6da6d72c..3db9590c 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -48,21 +48,21 @@ end D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) @test length(diagview(D1)) == r @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol # trunctol keeps order, truncrank might not # test for same subspace @@ -83,13 +83,13 @@ end alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) @@ -112,5 +112,5 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/eigh.jl b/test/eigh.jl index 92b0f3a0..c6c9a1b8 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -57,21 +57,21 @@ end @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol trunc = trunctol(; atol = s * D₀[r + 1]) D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol # test for same subspace @test V1 * (V1' * V2) ≈ V2 @@ -93,12 +93,12 @@ end D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) @@ -122,5 +122,5 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/genericlinearalgebra/eigh.jl b/test/genericlinearalgebra/eigh.jl index 7e602026..668aa8aa 100644 --- a/test/genericlinearalgebra/eigh.jl +++ b/test/genericlinearalgebra/eigh.jl @@ -49,21 +49,21 @@ end @test isisometric(V1) @test A * V1 ≈ V1 * D1 @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol trunc = trunctol(; atol = s * D₀[r + 1]) D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test isisometric(V2) @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol # test for same subspace @test V1 * (V1' * V2) ≈ V2 @@ -84,10 +84,10 @@ end D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2)) D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol end diff --git a/test/genericlinearalgebra/svd.jl b/test/genericlinearalgebra/svd.jl index f7177e79..cf08020f 100644 --- a/test/genericlinearalgebra/svd.jl +++ b/test/genericlinearalgebra/svd.jl @@ -110,7 +110,7 @@ end @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ1) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) @@ -120,7 +120,7 @@ end @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ2) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @@ -128,7 +128,7 @@ end @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ3) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol end end @@ -166,6 +166,6 @@ end alg = TruncatedAlgorithm(GLA_QRIteration(), trunctol(; atol = 0.2)) U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end diff --git a/test/genericschur/eig.jl b/test/genericschur/eig.jl index ce1e8f1b..84715888 100644 --- a/test/genericschur/eig.jl +++ b/test/genericschur/eig.jl @@ -49,21 +49,21 @@ end @test length(diagview(D1)) == r @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * abs(D₀[r + 1])) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D2)) == r @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol s = 1 - sqrt(eps(real(T))) trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) @test length(diagview(D3)) == r @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + @test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol # trunctol keeps order, truncrank might not # test for same subspace @@ -83,13 +83,13 @@ end alg = TruncatedAlgorithm(GS_QRIteration(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) alg = TruncatedAlgorithm(GS_QRIteration(), truncerror(; atol = 0.2, p = 1)) D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + @test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol end @testset "eig for Diagonal{$T}" for T in eltypes @@ -112,5 +112,5 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(A2)[3:4]) atol = atol end diff --git a/test/mooncake.jl b/test/mooncake.jl index 3e19e44d..e9e263af 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -273,9 +273,9 @@ end ΔVtrunc = ΔV[:, ind] dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, [zero(real(T))])), dDtrunc, dVtrunc, [zero(real(T))]) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V, [zero(real(T))]), (ΔD2, ΔV, [zero(real(T))]), truncalg) end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -285,9 +285,9 @@ end ΔVtrunc = ΔV[:, ind] dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, [zero(real(T))])), dDtrunc, dVtrunc, [zero(real(T))]) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V, [zero(real(T))]), (ΔD2, ΔV, [zero(real(T))]), truncalg) end end end @@ -317,9 +317,9 @@ function copy_eigh_trunc(A, alg; kwargs...) return eigh_trunc(A, alg; kwargs...) end -function copy_eigh_trunc!(A, DV, alg; kwargs...) +function copy_eigh_trunc!(A, DVϵ, alg; kwargs...) A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) + return eigh_trunc!(A, DVϵ, alg; kwargs...) end MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) @@ -365,9 +365,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop ΔVtrunc = ΔV[:, ind] dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, [zero(real(T))])), dDtrunc, dVtrunc, [zero(real(T))]) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V, [zero(real(T))]), (ΔD2, ΔV, [zero(real(T))]), truncalg) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -377,9 +377,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop ΔVtrunc = ΔV[:, ind] dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) - dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, [zero(real(T))])), dDtrunc, dVtrunc, [zero(real(T))]) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) - test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V, [zero(real(T))]), (ΔD2, ΔV, [zero(real(T))]), truncalg) end end end @@ -454,10 +454,10 @@ end dStrunc = make_mooncake_tangent(ΔStrunc) dUtrunc = make_mooncake_tangent(ΔUtrunc) dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) + ϵ = [zero(real(T))] dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ, ϵ), (ΔU, ΔS2, ΔVᴴ, copy(ϵ)), truncalg) end @testset "trunctol" begin U, S, Vᴴ = svd_compact(A) @@ -477,10 +477,10 @@ end dStrunc = make_mooncake_tangent(ΔStrunc) dUtrunc = make_mooncake_tangent(ΔUtrunc) dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) + ϵ = [zero(real(T))] dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) - test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ, ϵ), (ΔU, ΔS2, ΔVᴴ, ϵ), truncalg) end end end diff --git a/test/svd.jl b/test/svd.jl index d055f866..b15ae1d3 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -134,7 +134,7 @@ end @test diagview(S1) ≈ S₀[1:r] @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] # Test truncation error - @test ϵ1 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ1) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol s = 1 + sqrt(eps(real(T))) trunc = trunctol(; atol = s * S₀[r + 1]) @@ -144,7 +144,7 @@ end @test U1 ≈ U2 @test S1 ≈ S2 @test V1ᴴ ≈ V2ᴴ - @test ϵ2 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ2) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol trunc = truncerror(; atol = s * norm(@view(S₀[(r + 1):end]))) U3, S3, V3ᴴ, ϵ3 = @constinferred svd_trunc(A; alg, trunc) @@ -152,7 +152,7 @@ end @test U1 ≈ U3 @test S1 ≈ S3 @test V1ᴴ ≈ V3ᴴ - @test ϵ3 ≈ norm(view(S₀, (r + 1):minmn)) atol = atol + @test norm(ϵ3) ≈ norm(view(S₀, (r + 1):minmn)) atol = atol end end end @@ -199,7 +199,7 @@ end alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), trunctol(; atol = 0.2)) U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg) @test diagview(S2) ≈ diagview(S)[1:2] - @test ϵ2 ≈ norm(diagview(S)[3:4]) atol = atol + @test norm(ϵ2) ≈ norm(diagview(S)[3:4]) atol = atol @test_throws ArgumentError svd_trunc(A; alg, trunc = (; maxrank = 2)) end @@ -235,6 +235,6 @@ end alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) U3, S3, Vᴴ3, ϵ3 = @constinferred svd_trunc(A; alg) @test diagview(S3) ≈ S2[1:min(m, 2)] - @test ϵ3 ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol + @test norm(ϵ3) ≈ norm(S2[(min(m, 2) + 1):m]) atol = atol end end From cc90916e83f25e3079fdda233fd78dc92597ff6d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 14:19:55 +0100 Subject: [PATCH 2/3] Fix doc --- docs/src/user_interface/truncations.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/user_interface/truncations.md b/docs/src/user_interface/truncations.md index ee730020..9116b7eb 100644 --- a/docs/src/user_interface/truncations.md +++ b/docs/src/user_interface/truncations.md @@ -123,7 +123,7 @@ For example: ```jldoctest truncations; output=false using LinearAlgebra: norm U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2)) -norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values +norm(A - U * S * Vᴴ) ≈ only(ϵ) # ϵ is the 2-norm of the discarded singular values, stored as a Vector # output true From 4905fc4bfd49ff66be6f956ebee0fcdc0f0756bd Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Fri, 12 Dec 2025 14:30:09 +0100 Subject: [PATCH 3/3] CUDA fix --- src/implementations/svd.jl | 4 ++-- test/cuda/svd.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index faa9dc18..470214be 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -372,7 +372,7 @@ end function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ} U, S, Vᴴ, ϵ = USVᴴϵ - check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) + check_input(svd_trunc!, A, USVᴴϵ, alg.alg) _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong @@ -384,7 +384,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg: normA = norm(A) # equivalent to sqrt(normA^2 - normS^2) # but may be more accurate - ϵ = sqrt((normA + normS) * (normA - normS)) + ϵ .= sqrt((normA + normS) * (normA - normS)) end do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index fc564fec..ca3f8edd 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -143,7 +143,7 @@ end U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r)) @test length(S1.diag) == r @test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1] - @test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1 + @test norm(A - U1 * S1 * V1ᴴ) ≈ norm(ϵ1) if !(alg isa CUSOLVER_Randomized) s = 1 + sqrt(eps(real(T)))