diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index c2814ef6..f6feda8b 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -163,9 +163,9 @@ for (f!, f, f_full, pb, adj) in ( end end -for (f, pb, adj) in ( - (:eig_trunc, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc, :eigh_trunc_pullback!, :eigh_trunc_adjoint), +for (f, f_ne, pb, adj) in ( + (:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint), + (:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint), ) @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} @@ -192,6 +192,29 @@ for (f, pb, adj) in ( end return output_codual, $adj end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = $f_ne(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). 
For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(::NoRData) + Dtrunc, Vtrunc = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual) + D, dD = arrayify(Dtrunc, dDtrunc_) + V, dV = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D, V), (dD, dV)) + MatrixAlgebraKit.zero!(dD) + MatrixAlgebraKit.zero!(dV) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end end end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 85cfc633..2178d41a 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -18,10 +18,10 @@ export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null! export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error! -export eigh_full, eigh_vals, eigh_trunc -export eigh_full!, eigh_vals!, eigh_trunc! -export eig_full, eig_vals, eig_trunc -export eig_full!, eig_vals!, eig_trunc! +export eigh_full, eigh_vals, eigh_trunc, eigh_trunc_no_error +export eigh_full!, eigh_vals!, eigh_trunc!, eigh_trunc_no_error! +export eig_full, eig_vals, eig_trunc, eig_trunc_no_error +export eig_full!, eig_vals!, eig_trunc!, eig_trunc_no_error! export gen_eig_full, gen_eig_vals export gen_eig_full!, gen_eig_vals! 
export schur_full, schur_vals diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 9b14167c..5a0dd679 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -4,7 +4,7 @@ function copy_input(::typeof(eig_full), A::AbstractMatrix) return copy!(similar(A, float(eltype(A))), A) end copy_input(::typeof(eig_vals), A) = copy_input(eig_full, A) -copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A) +copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_input(eig_full, A) copy_input(::typeof(eig_full), A::Diagonal) = copy(A) @@ -65,7 +65,7 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg D = similar(A, Tc, n) return D end -function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm) +function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error!)}, A, alg::TruncatedAlgorithm) return initialize_output(eig_full!, A, alg.alg) end @@ -121,6 +121,12 @@ function eig_trunc!(A, DV, alg::TruncatedAlgorithm) return DVtrunc..., truncation_error!(diagview(D), ind) end +function eig_trunc_no_error!(A, DV, alg::TruncatedAlgorithm) + D, V = eig_full!(A, DV, alg.alg) + DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc) + return DVtrunc +end + # Diagonal logic # -------------- function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index a45300dc..40f2c557 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -4,7 +4,7 @@ function copy_input(::typeof(eigh_full), A::AbstractMatrix) return copy!(similar(A, float(eltype(A))), A) end copy_input(::typeof(eigh_vals), A) = copy_input(eigh_full, A) -copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A) +copy_input(::Union{typeof(eigh_trunc), typeof(eigh_trunc_no_error)}, A) = copy_input(eigh_full, A) copy_input(::typeof(eigh_full), A::Diagonal) = copy(A) @@ -74,7 
+74,7 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl D = similar(A, real(eltype(A)), n) return D end -function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm) +function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_error!)}, A, alg::TruncatedAlgorithm) return initialize_output(eigh_full!, A, alg.alg) end @@ -135,6 +135,12 @@ function eigh_trunc!(A, DV, alg::TruncatedAlgorithm) return DVtrunc..., truncation_error!(diagview(D), ind) end +function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm) + D, V = eigh_full!(A, DV, alg.alg) + DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc) + return DVtrunc +end + # Diagonal logic # -------------- function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 28b5c69c..bb111c01 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -27,7 +27,8 @@ and the diagonal matrix `D` contains the associated eigenvalues. !!! note $(docs_eig_note) -See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc). +See also [`eig_vals(!)`](@ref eig_vals), [`eig_trunc_no_error`](@ref eig_trunc_no_error) +and [`eig_trunc(!)`](@ref eig_trunc). """ @functiondef eig_full @@ -79,11 +80,63 @@ truncation strategy is already embedded in the algorithm. !!! note $docs_eig_note -See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), and -[Truncations](@ref) for more information on truncation strategies. +See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), +[`eig_trunc_no_error!`](@ref eig_trunc_no_error) and [Truncations](@ref) +for more information on truncation strategies. """ @functiondef eig_trunc +""" + eig_trunc_no_error(A; [trunc], kwargs...) -> D, V + eig_trunc_no_error(A, alg::AbstractAlgorithm) -> D, V + eig_trunc_no_error!(A, [DV]; [trunc], kwargs...) 
-> D, V + eig_trunc_no_error!(A, [DV], alg::AbstractAlgorithm) -> D, V + +Compute a partial or truncated eigenvalue decomposition of the matrix `A`, +such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains +a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues, +selected according to a truncation strategy. The truncation error is *not* returned. + +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword Arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + +!!! note + The bang method `eig_trunc_no_error!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `DV` as output. + +!!! note +$docs_eig_note + +See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), +[`eig_trunc(!)`](@ref eig_trunc) and [Truncations](@ref) for more +information on truncation strategies. 
+""" +@functiondef eig_trunc_no_error + """ eig_vals(A; kwargs...) -> D eig_vals(A, alg::AbstractAlgorithm) -> D @@ -121,13 +174,15 @@ for f in (:eig_full!, :eig_vals!) end end -function select_algorithm(::typeof(eig_trunc!), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) +for f in (:eig_trunc!, :eig_trunc_no_error!) + @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 97f8f95c..42c7f9f3 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -84,11 +84,63 @@ truncation strategy is already embedded in the algorithm. !!! note $(docs_eigh_note) -See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals), and -[Truncations](@ref) for more information on truncation strategies. +See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals), +[`eigh_trunc_no_error(!)`](@ref eigh_trunc_no_error) and [Truncations](@ref) +for more information on truncation strategies. """ @functiondef eigh_trunc +""" + eigh_trunc_no_error(A; [trunc], kwargs...) -> D, V + eigh_trunc_no_error(A, alg::AbstractAlgorithm) -> D, V + eigh_trunc_no_error!(A, [DV]; [trunc], kwargs...) 
-> D, V + eigh_trunc_no_error!(A, [DV], alg::AbstractAlgorithm) -> D, V + +Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix +`A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the +orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues, +selected according to a truncation strategy. The function does *not* return the truncation error. + +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. + +When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the +truncation strategy is already embedded in the algorithm. + +!!! note + The bang method `eigh_trunc_no_error!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `DV` as output. + +!!! 
note +$(docs_eigh_note) + +See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals), +[`eigh_trunc(!)`](@ref eigh_trunc), and [Truncations](@ref) for more information +on truncation strategies. +""" +@functiondef eigh_trunc_no_error + """ eigh_vals(A; kwargs...) -> D eigh_vals(A, alg::AbstractAlgorithm) -> D @@ -128,13 +180,15 @@ for f in (:eigh_full!, :eigh_vals!) end end -function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) +for f in (:eigh_trunc!, :eigh_trunc_no_error!) + @eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...) 
+ return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end end diff --git a/test/eig.jl b/test/eig.jl index 6da6d72c..d709bd20 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -64,6 +64,11 @@ end @test A * V3 ≈ V3 * D3 @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + s = 1 - sqrt(eps(real(T))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D4, V4 = @constinferred eig_trunc_no_error(A; alg, trunc) + @test length(diagview(D4)) == r + @test A * V4 ≈ V4 * D4 # trunctol keeps order, truncrank might not # test for same subspace @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 @@ -90,6 +95,10 @@ end D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1)) + D4, V4 = @constinferred eig_trunc_no_error(A; alg) + @test diagview(D4) ≈ diagview(D)[1:2] end @testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) 
@@ -113,4 +122,9 @@ end D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol + + A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) + D3, V3 = @constinferred eig_trunc_no_error(A3; alg) + @test diagview(D3) ≈ diagview(A3)[1:2] end diff --git a/test/eigh.jl b/test/eigh.jl index 92b0f3a0..3b711c5b 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -73,6 +73,12 @@ end @test A * V3 ≈ V3 * D3 @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + s = 1 - sqrt(eps(real(T))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D4, V4 = @constinferred eigh_trunc_no_error(A; alg, trunc) + @test length(diagview(D4)) == r + @test A * V4 ≈ V4 * D4 + # test for same subspace @test V1 * (V1' * V2) ≈ V2 @test V2 * (V2' * V1) ≈ V1 @@ -99,6 +105,10 @@ end D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) @test diagview(D3) ≈ diagview(D)[1:2] @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) + D4, V4 = @constinferred eigh_trunc_no_error(A; alg) + @test diagview(D4) ≈ diagview(D)[1:2] end @testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) 
@@ -123,4 +133,9 @@ end D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) @test diagview(D2) ≈ diagview(A2)[1:2] @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol + + A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) + alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(3)) + D3, V3 = @constinferred eigh_trunc_no_error(A3; alg) + @test diagview(D3) ≈ diagview(A3)[1:3] end diff --git a/test/mooncake.jl b/test/mooncake.jl index a47bbb8d..c3917847 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -276,6 +276,9 @@ end dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -288,6 +291,9 @@ end dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, 
truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end end end