Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ for (f!, f, f_full, pb, adj) in (
end
end

for (f, pb, adj) in (
(:eig_trunc, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
for (f, f_ne, pb, adj) in (
(:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
Expand All @@ -192,6 +192,29 @@ for (f, pb, adj) in (
end
return output_codual, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
output = $f_ne(A, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(::NoRData)
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
D, dD = arrayify(Dtrunc, dDtrunc_)
V, dV = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D, V), (dD, dV))
MatrixAlgebraKit.zero!(dD)
MatrixAlgebraKit.zero!(dV)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
end
end
end

Expand Down
8 changes: 4 additions & 4 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ export qr_compact, qr_full, qr_null, lq_compact, lq_full, lq_null
export qr_compact!, qr_full!, qr_null!, lq_compact!, lq_full!, lq_null!
export svd_compact, svd_full, svd_vals, svd_trunc, svd_trunc_no_error
export svd_compact!, svd_full!, svd_vals!, svd_trunc!, svd_trunc_no_error!
export eigh_full, eigh_vals, eigh_trunc
export eigh_full!, eigh_vals!, eigh_trunc!
export eig_full, eig_vals, eig_trunc
export eig_full!, eig_vals!, eig_trunc!
export eigh_full, eigh_vals, eigh_trunc, eigh_trunc_no_error
export eigh_full!, eigh_vals!, eigh_trunc!, eigh_trunc_no_error!
export eig_full, eig_vals, eig_trunc, eig_trunc_no_error
export eig_full!, eig_vals!, eig_trunc!, eig_trunc_no_error!
export gen_eig_full, gen_eig_vals
export gen_eig_full!, gen_eig_vals!
export schur_full, schur_vals
Expand Down
10 changes: 8 additions & 2 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function copy_input(::typeof(eig_full), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
copy_input(::typeof(eig_vals), A) = copy_input(eig_full, A)
copy_input(::typeof(eig_trunc), A) = copy_input(eig_full, A)
copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_input(eig_full, A)

copy_input(::typeof(eig_full), A::Diagonal) = copy(A)

Expand Down Expand Up @@ -65,7 +65,7 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg
D = similar(A, Tc, n)
return D
end
function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm)
function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error!)}, A, alg::TruncatedAlgorithm)
return initialize_output(eig_full!, A, alg.alg)
end

Expand Down Expand Up @@ -121,6 +121,12 @@ function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
return DVtrunc..., truncation_error!(diagview(D), ind)
end

function eig_trunc_no_error!(A, DV, alg::TruncatedAlgorithm)
D, V = eig_full!(A, DV, alg.alg)
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
return DVtrunc
end

# Diagonal logic
# --------------
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm)
Expand Down
10 changes: 8 additions & 2 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function copy_input(::typeof(eigh_full), A::AbstractMatrix)
return copy!(similar(A, float(eltype(A))), A)
end
copy_input(::typeof(eigh_vals), A) = copy_input(eigh_full, A)
copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)
copy_input(::Union{typeof(eigh_trunc), typeof(eigh_trunc_no_error)}, A) = copy_input(eigh_full, A)

copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)

Expand Down Expand Up @@ -74,7 +74,7 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl
D = similar(A, real(eltype(A)), n)
return D
end
function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm)
function initialize_output(::Union{typeof(eigh_trunc!), typeof(eigh_trunc_no_error!)}, A, alg::TruncatedAlgorithm)
return initialize_output(eigh_full!, A, alg.alg)
end

Expand Down Expand Up @@ -135,6 +135,12 @@ function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
return DVtrunc..., truncation_error!(diagview(D), ind)
end

function eigh_trunc_no_error!(A, DV, alg::TruncatedAlgorithm)
D, V = eigh_full!(A, DV, alg.alg)
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
return DVtrunc
end

# Diagonal logic
# --------------
function eigh_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
Expand Down
77 changes: 66 additions & 11 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ and the diagonal matrix `D` contains the associated eigenvalues.
!!! note
$(docs_eig_note)

See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc).
See also [`eig_vals(!)`](@ref eig_vals), [`eig_trunc_no_error`](@ref eig_trunc_no_error)
and [`eig_trunc(!)`](@ref eig_trunc).
"""
@functiondef eig_full

Expand Down Expand Up @@ -79,11 +80,63 @@ truncation strategy is already embedded in the algorithm.
!!! note
$docs_eig_note

See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), and
[Truncations](@ref) for more information on truncation strategies.
See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals),
[`eig_trunc_no_error!`](@ref eig_trunc_no_error) and [Truncations](@ref)
for more information on truncation strategies.
"""
@functiondef eig_trunc

"""
eig_trunc_no_error(A; [trunc], kwargs...) -> D, V
eig_trunc_no_error(A, alg::AbstractAlgorithm) -> D, V
eig_trunc_no_error!(A, [DV]; [trunc], kwargs...) -> D, V
eig_trunc_no_error!(A, [DV], alg::AbstractAlgorithm) -> D, V

Compute a partial or truncated eigenvalue decomposition of the matrix `A`,
such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains
a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy. The truncation error is *not* returned.

## Truncation
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or
nothing, all values will be kept.

### `trunc::NamedTuple`
The supported truncation keyword arguments are:

$docs_truncation_kwargs

### `trunc::TruncationStrategy`
For more control, a truncation strategy can be supplied directly.
By default, MatrixAlgebraKit supplies the following:

$docs_truncation_strategies

## Keyword Arguments
Other keyword arguments are passed to the algorithm selection procedure. If no explicit
`alg` is provided, these keywords are used to select and configure the algorithm through
[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm
selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref)
for the default algorithm selection behavior.

When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
truncation strategy is already embedded in the algorithm.

!!! note
The bang method `eig_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
as it may not always be possible to use the provided `DV` as output.

!!! note
$docs_eig_note

See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals),
[`eig_trunc(!)`](@ref eig_trunc) and [Truncations](@ref) for more
information on truncation strategies.
"""
@functiondef eig_trunc_no_error

"""
eig_vals(A; kwargs...) -> D
eig_vals(A, alg::AbstractAlgorithm) -> D
Expand Down Expand Up @@ -121,13 +174,15 @@ for f in (:eig_full!, :eig_vals!)
end
end

function select_algorithm(::typeof(eig_trunc!), A, alg; trunc = nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
for f in (:eig_trunc!, :eig_trunc_no_error!)
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eig_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end
end
end
74 changes: 64 additions & 10 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,63 @@ truncation strategy is already embedded in the algorithm.
!!! note
$(docs_eigh_note)

See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals), and
[Truncations](@ref) for more information on truncation strategies.
See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals),
[`eigh_trunc_no_error(!)`](@ref eigh_trunc_no_error) and [Truncations](@ref)
for more information on truncation strategies.
"""
@functiondef eigh_trunc

"""
eigh_trunc_no_error(A; [trunc], kwargs...) -> D, V
eigh_trunc_no_error(A, alg::AbstractAlgorithm) -> D, V
eigh_trunc_no_error!(A, [DV]; [trunc], kwargs...) -> D, V
eigh_trunc_no_error!(A, [DV], alg::AbstractAlgorithm) -> D, V

Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix
`A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the
orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy. The function does *not* returns the truncation error.

## Truncation
The truncation strategy can be controlled via the `trunc` keyword argument. This can be
either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or
nothing, all values will be kept.

### `trunc::NamedTuple`
The supported truncation keyword arguments are:

$docs_truncation_kwargs

### `trunc::TruncationStrategy`
For more control, a truncation strategy can be supplied directly.
By default, MatrixAlgebraKit supplies the following:

$docs_truncation_strategies

## Keyword arguments
Other keyword arguments are passed to the algorithm selection procedure. If no explicit
`alg` is provided, these keywords are used to select and configure the algorithm through
[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm
selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref)
for the default algorithm selection behavior.

When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the
truncation strategy is already embedded in the algorithm.

!!! note
The bang method `eigh_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
as it may not always be possible to use the provided `DV` as output.

!!! note
$(docs_eigh_note)

See also [`eigh_full(!)`](@ref eigh_full), [`eigh_vals(!)`](@ref eigh_vals),
[`eigh_trunc(!)`](@ref eig_trunc), and [Truncations](@ref) for more information
on truncation strategies.
"""
@functiondef eigh_trunc_no_error

"""
eigh_vals(A; kwargs...) -> D
eigh_vals(A, alg::AbstractAlgorithm) -> D
Expand Down Expand Up @@ -128,13 +180,15 @@ for f in (:eigh_full!, :eigh_vals!)
end
end

function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc = nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
for f in (:eigh_trunc!, :eigh_trunc_no_error!)
@eval function select_algorithm(::typeof($f), A, alg; trunc = nothing, kwargs...)
if alg isa TruncatedAlgorithm
isnothing(trunc) ||
throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`"))
return alg
else
alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...)
return TruncatedAlgorithm(alg_eig, select_truncation(trunc))
end
end
end
14 changes: 14 additions & 0 deletions test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ end
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D4, V4 = @constinferred eig_trunc_no_error(A; alg, trunc)
@test length(diagview(D4)) == r
Comment on lines +67 to +70
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused by this logic. You want an error that is slightly smaller than the norm of D₀[r:end], which means D₀[r] still needs to be kept. I would find it more intuitive to say that you want a truncation error that is allowed to be slightly bigger than D₀[r+1:end]. But maybe that is just me 🙂 .

Copy link
Member

@Jutho Jutho Dec 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically, my reasoning would be: I want D₀[r+1:end] to be truncated, so I admit an error atol = norm(D₀[r+1:end]). But then I multiply this with a factor slightly bigger than 1 to account for finite precision errors.

That seems one less mental step than saying, I want an error that is slightly smaller than norm(D₀[r:end]), so actually, I do still want to keep D₀[r] and only start to truncate from r+1 onwards.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see what you mean. But I suppose this is the "already existing" test so maybe we should open an issue about this to deal with it separately?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, now that I get the logic, it is also fine with me as is. Not sure why I didn't stumble on this before.

@test A * V4 ≈ V4 * D4
# trunctol keeps order, truncrank might not
# test for same subspace
@test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2
Expand All @@ -90,6 +95,10 @@ end
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol

alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
D4, V4 = @constinferred eig_trunc_no_error(A; alg)
@test diagview(D4) ≈ diagview(D)[1:2]
end

@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
Expand All @@ -113,4 +122,9 @@ end
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol

A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D3, V3 = @constinferred eig_trunc_no_error(A3; alg)
@test diagview(D3) ≈ diagview(A3)[1:2]
end
15 changes: 15 additions & 0 deletions test/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ end
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D4, V4 = @constinferred eigh_trunc_no_error(A; alg, trunc)
@test length(diagview(D4)) == r
@test A * V4 ≈ V4 * D4

# test for same subspace
@test V1 * (V1' * V2) ≈ V2
@test V2 * (V2' * V1) ≈ V1
Expand All @@ -99,6 +105,10 @@ end
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol

alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2))
D4, V4 = @constinferred eigh_trunc_no_error(A; alg)
@test diagview(D4) ≈ diagview(D)[1:2]
end

@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
Expand All @@ -123,4 +133,9 @@ end
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol

A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01])
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(3))
D3, V3 = @constinferred eigh_trunc_no_error(A3; alg)
@test diagview(D3) ≈ diagview(A3)[1:3]
end
6 changes: 6 additions & 0 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ end
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc)
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg)
end
truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real))
ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
Expand All @@ -288,6 +291,9 @@ end
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc)
Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg)
end
end
end
Expand Down