Merged
2 changes: 1 addition & 1 deletion Project.toml
@@ -22,7 +22,7 @@ Aqua = "0.6, 0.7, 0.8"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
CUDA = "5"
JET = "0.9"
JET = "0.9, 0.10"
LinearAlgebra = "1"
SafeTestsets = "0.1"
StableRNGs = "1"
68 changes: 40 additions & 28 deletions docs/src/user_interface/truncations.md
@@ -12,17 +12,15 @@ Truncation strategies allow you to control which eigenvalues or singular values
Truncation strategies can be used with truncated decomposition functions in two ways, as illustrated below.
For concreteness, we use the following matrix as an example:

```jldoctest truncations
```jldoctest truncations; output=false
using MatrixAlgebraKit
using MatrixAlgebraKit: diagview

A = [2 1 0; 1 3 1; 0 1 4];
D, V = eigh_full(A);

diagview(D) ≈ [3 - √3, 3, 3 + √3]

# output

true
```

@@ -31,38 +29,35 @@ true
The simplest approach is to pass a `NamedTuple` with the truncation parameters.
For example, keeping only the largest 2 eigenvalues:

```jldoctest truncations
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2,));
```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2,));
size(Dtrunc, 1) <= 2

# output

true
```

Note however that there are no guarantees on the order of the output values:

```jldoctest truncations
```jldoctest truncations; output=false
diagview(Dtrunc) ≈ diagview(D)[[3, 2]]

# output

true
```

You can also use tolerance-based truncation or combine multiple criteria:

```jldoctest truncations
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (atol = 2.9,));
```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (atol = 2.9,));
all(>(2.9), diagview(Dtrunc))

# output

true
```

```jldoctest truncations
Dtrunc, Vtrunc = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = (maxrank = 2, atol = 2.9));
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))

# output
@@ -72,7 +67,7 @@ true
In general, the keyword arguments that are supported can be found in the `TruncationStrategy` docstring:

```@docs; canonical = false
TruncationStrategy
TruncationStrategy()
```


@@ -81,33 +76,22 @@ TruncationStrategy
For more control, you can construct [`TruncationStrategy`](@ref) objects directly.
This is also what the previous syntax will end up calling.

```jldoctest truncations
```jldoctest truncations; output=false
Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2))
size(Dtrunc, 1) <= 2

# output

true
```

```jldoctest truncations
Dtrunc, Vtrunc = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
```jldoctest truncations; output=false
Dtrunc, Vtrunc, ϵ = eigh_trunc(A; trunc = truncrank(2) & trunctol(; atol = 2.9))
size(Dtrunc, 1) <= 2 && all(>(2.9), diagview(Dtrunc))

# output
true
```

## Truncation with SVD vs Eigenvalue Decompositions

When using truncations with different decomposition types, keep in mind:

- **`svd_trunc`**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values.

- **`eigh_trunc`**: Eigenvalues are real but can be negative for symmetric matrices. By default, `truncrank` sorts by absolute value, so `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative).

- **`eig_trunc`**: For general (non-symmetric) matrices, eigenvalues can be complex. Truncation by absolute value considers the complex magnitude.

## Truncation Strategies

MatrixAlgebraKit provides several built-in truncation strategies:
@@ -127,3 +111,31 @@ When strategies are combined, only the values that satisfy all conditions are ke
combined_trunc = truncrank(10) & trunctol(; atol = 1e-6);
```
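
As a rough illustration of how combined criteria behave, the sketch below intersects a rank criterion with a tolerance criterion using only Base Julia. The helpers `keep_by_rank` and `keep_by_tol` are hypothetical names introduced here for illustration and are not part of MatrixAlgebraKit; whether the library's tolerance comparison is strict is likewise an assumption.

```julia
# Illustrative helpers only; these names are hypothetical and not part of MatrixAlgebraKit.

# Indices of the `k` values with the largest magnitude.
keep_by_rank(vals, k) = sortperm(vals; by = abs, rev = true)[1:min(k, length(vals))]

# Indices of the values whose magnitude exceeds `atol` (strict comparison is an assumption).
keep_by_tol(vals, atol) = findall(v -> abs(v) > atol, vals)

vals = [4.7, -3.0, 1.3, 1.0e-8]

by_rank = keep_by_rank(vals, 3)     # the three largest magnitudes
by_tol = keep_by_tol(vals, 2.9)     # magnitudes above the tolerance
kept = intersect(by_rank, by_tol)   # only indices that satisfy both criteria

vals[kept]
```

Only the two values that pass both the rank and the tolerance test survive, which is the behaviour the `&` combination describes.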

## Truncation Error

When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref), an additional truncation error value is returned.
This error is defined as the 2-norm of the discarded singular values or eigenvalues, and measures the quality of the approximation.
For `svd_trunc` and `eigh_trunc`, it equals the Frobenius norm of the difference between the original and the truncated matrix, which is what `norm` computes for a matrix.
For `eig_trunc`, this interpretation does not hold, because the norms of the non-unitary eigenvector matrix and its inverse also influence the approximation quality.


For example:
```jldoctest truncations; output=false
using LinearAlgebra: norm
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values

# output
true
```
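
The same relation can be checked without MatrixAlgebraKit. The following is a minimal sketch that assumes only the standard library `LinearAlgebra`; the names `A_k` and `k` are introduced here for illustration, and the truncation is done by hand rather than through `svd_trunc`.

```julia
using LinearAlgebra

A = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
F = svd(A)                      # F.S holds the singular values in descending order
k = 2                           # number of singular values to keep

# Rank-k approximation built from the leading singular triplets.
A_k = F.U[:, 1:k] * Diagonal(F.S[1:k]) * F.Vt[1:k, :]

# 2-norm of the discarded singular values.
ϵ = norm(F.S[(k + 1):end])

# `norm` of a matrix is the Frobenius norm, so both sides agree.
norm(A - A_k) ≈ ϵ
```

For this matrix the single discarded singular value is `3 - √3`, so `ϵ` reduces to that value.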

### Truncation with SVD vs Eigenvalue Decompositions

When using truncations with different decomposition types, keep in mind:

- **[`svd_trunc`](@ref)**: Singular values are always real and non-negative, sorted in descending order. Truncation by value typically keeps the largest singular values. The truncation error equals the Frobenius norm of the difference between the original and the truncated matrix.

- **[`eigh_trunc`](@ref)**: Eigenvalues are real but can be negative for symmetric matrices. By default, eigenvalues are compared by absolute value, e.g. `truncrank(k)` keeps the `k` eigenvalues with largest magnitude (positive or negative). The truncation error equals the Frobenius norm of the difference between the original and the truncated matrix.

- **[`eig_trunc`](@ref)**: For general (non-symmetric) matrices, eigenvalues can be complex. By default, eigenvalues are compared by absolute value. The truncation error indicates the magnitude of the discarded values, but is not directly related to the norm of the difference between the original and the truncated matrix.
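
To make the `eig_trunc` caveat concrete, here is a small sketch that truncates a non-normal matrix by hand with the standard library `LinearAlgebra`. It does not call `eig_trunc`; the threshold `0.75` and the names `keep`, `discard`, `W` and `A_k` are assumptions chosen for illustration.

```julia
using LinearAlgebra

A = [1.0 100.0; 0.0 0.5]        # non-normal: A * A' differs from A' * A
F = eigen(A)                    # eigenvalues are 0.5 and 1.0

keep = findall(λ -> abs(λ) > 0.75, F.values)     # keeps only the eigenvalue 1.0
discard = setdiff(eachindex(F.values), keep)

V = F.vectors
W = inv(V)                      # rows of W are the left eigenvectors
A_k = V[:, keep] * Diagonal(F.values[keep]) * W[keep, :]

ϵ = norm(F.values[discard])     # 2-norm of the discarded eigenvalues
err = norm(A - A_k)             # actual Frobenius-norm reconstruction error

(ϵ, err)
```

Here `ϵ` is `0.5` while the reconstruction error is roughly two hundred times larger, because the eigenvector matrix is far from unitary.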

28 changes: 19 additions & 9 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
@@ -2,7 +2,7 @@ module MatrixAlgebraKitChainRulesCoreExt

using MatrixAlgebraKit
using MatrixAlgebraKit: copy_input, initialize_output, zero!, diagview,
TruncatedAlgorithm, findtruncated, findtruncated_svd
TruncatedAlgorithm, findtruncated, findtruncated_svd, truncation_error
using ChainRulesCore
using LinearAlgebra

@@ -113,15 +113,20 @@ for eig in (:eig, :eigh)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_pb)(A, DV, ind)
ϵ = truncation_error(diagview(DV[1]), ind)
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
end
function $(_make_eig_t_pb)(A, DV, ind)
function $eig_t_pb(ΔDV)
function $eig_t_pb(ΔDVϵ)
ΔA = zero(A)
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
ΔD, ΔV, Δϵ = ΔDVϵ
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
throw(ArgumentError("Pullback for eig_trunc! does not yet support non-zero tangent for the truncation error"))
end
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_pb
@@ -152,15 +157,20 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴ)
function svd_trunc_pullback(ΔUSVᴴϵ)
ΔA = zero(A)
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
ΔU, ΔS, ΔVᴴ, Δϵ = ΔUSVᴴϵ
if !MatrixAlgebraKit.iszerotangent(Δϵ) && !iszero(unthunk(Δϵ))
throw(ArgumentError("Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"))
end
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.((ΔU, ΔS, ΔVᴴ)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return svd_trunc_pullback
3 changes: 2 additions & 1 deletion src/implementations/eig.jl
@@ -108,7 +108,8 @@ end

function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
D, V = eig_full!(A, DV, alg.alg)
return first(truncate(eig_trunc!, (D, V), alg.trunc))
DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
return DVtrunc..., truncation_error!(diagview(D), ind)
end

# Diagonal logic
3 changes: 2 additions & 1 deletion src/implementations/eigh.jl
@@ -111,7 +111,8 @@ end

function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
D, V = eigh_full!(A, DV, alg.alg)
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
return DVtrunc..., truncation_error!(diagview(D), ind)
end

# Diagonal logic
12 changes: 9 additions & 3 deletions src/implementations/svd.jl
@@ -237,8 +237,9 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
end

function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
end

# Diagonal logic
@@ -381,7 +382,12 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
# TODO: make this controllable using a `gaugefix` keyword argument
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
return first(truncate(svd_trunc!, USVᴴ, alg.trunc))
# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
Strunc = diagview(USVᴴtrunc[2])
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this?
return USVᴴtrunc..., ϵ
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
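
The randomized GPU branch above only has the kept singular values available, so it recovers the error from the identity that the squared Frobenius norm of `A` equals the sum of all squared singular values. The sketch below checks that identity on the CPU with the standard library `LinearAlgebra`; the names `ϵ_direct` and `ϵ_indirect` are introduced here for illustration, and the full spectrum is computed only as a reference.

```julia
using LinearAlgebra

A = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0]
S = svdvals(A)                  # full spectrum, used here only as a reference
k = 2
Strunc = S[1:k]                 # the kept singular values

ϵ_direct = norm(S[(k + 1):end])                    # needs the discarded values
ϵ_indirect = sqrt(norm(A)^2 - norm(Strunc)^2)      # needs only A and the kept values

ϵ_direct ≈ ϵ_indirect
```

The indirect form subtracts two numbers of similar size, so it can lose accuracy when the discarded values are small relative to `norm(A)`, which is the concern raised in the code comment.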
10 changes: 10 additions & 0 deletions src/implementations/truncation.jl
@@ -116,3 +116,13 @@ end
_ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A)
_ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B
_ind_intersect(A, B) = intersect(A, B)

# Truncation error
# ----------------
truncation_error(values::AbstractVector, ind) = truncation_error!(copy(values), ind)
# destroys input in order to maximize accuracy:
# sqrt(norm(values)^2 - norm(values[ind])^2) might suffer from floating point error
function truncation_error!(values::AbstractVector, ind)
    values[ind] .= zero(eltype(values))
    return norm(values)
end
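
The comment above about floating-point error can be reproduced with a two-element vector. This is a minimal sketch using only the standard library `LinearAlgebra`; `subtractive_error` and `zeroing_error` are hypothetical stand-ins written for this illustration, with `zeroing_error` following the same idea as `truncation_error!` but on a copy.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the two formulas compared in the comment.
subtractive_error(vals, ind) = sqrt(norm(vals)^2 - norm(vals[ind])^2)

function zeroing_error(vals, ind)
    discarded = copy(vals)                  # leave the caller's vector untouched
    discarded[ind] .= zero(eltype(vals))    # zero out the kept entries
    return norm(discarded)                  # norm of what was discarded
end

vals = [1.0, 1.0e-9]
ind = [1]                       # keep only the first value

subtractive_error(vals, ind), zeroing_error(vals, ind)
```

In double precision the squared small value is absorbed when added to `1.0`, so the subtractive formula loses the discarded value, while the zeroing variant recovers it.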
11 changes: 7 additions & 4 deletions src/interface/eig.jl
@@ -32,16 +32,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc).
@functiondef eig_full

"""
eig_trunc(A; [trunc], kwargs...) -> D, V
eig_trunc(A, alg::AbstractAlgorithm) -> D, V
eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
eig_trunc(A; [trunc], kwargs...) -> D, V, ϵ
eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
eig_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ

Compute a partial or truncated eigenvalue decomposition of the matrix `A`,
such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains
a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded eigenvalues.

## Keyword arguments
The behavior of this function is controlled by the following keyword arguments:

22 changes: 14 additions & 8 deletions src/interface/eigh.jl
@@ -12,15 +12,18 @@ docs_eigh_note = """
"""

"""
eigh_full(A; kwargs...) -> D, V
eigh_full(A, alg::AbstractAlgorithm) -> D, V
eigh_full!(A, [DV]; kwargs...) -> D, V
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V
eigh_full(A; kwargs...) -> D, V, ϵ
eigh_full(A, alg::AbstractAlgorithm) -> D, V, ϵ
eigh_full!(A, [DV]; kwargs...) -> D, V, ϵ
eigh_full!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ

Compute the full eigenvalue decomposition of the symmetric or hermitian matrix `A`,
such that `A * V = V * D`, where the unitary matrix `V` contains the orthogonal eigenvectors
and the real diagonal matrix `D` contains the associated eigenvalues.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded eigenvalues.

!!! note
The bang method `eigh_full!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
@@ -34,16 +37,19 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc)
@functiondef eigh_full

"""
eigh_trunc(A; [trunc], kwargs...) -> D, V
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V
eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
eigh_trunc(A; [trunc], kwargs...) -> D, V, ϵ
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
eigh_trunc!(A, [DV]; [trunc], kwargs...) -> D, V, ϵ
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ

Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix
`A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the
orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded
eigenvalues.

## Keyword arguments
The behavior of this function is controlled by the following keyword arguments:

13 changes: 8 additions & 5 deletions src/interface/svd.jl
@@ -42,16 +42,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and
@functiondef svd_compact

"""
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
svd_trunc(A; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ]; [trunc], kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ

Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
square diagonal matrix of size `(k, k)`, where `k` is set by the truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded singular values.

## Keyword arguments
The behavior of this function is controlled by the following keyword arguments:

5 changes: 5 additions & 0 deletions src/interface/truncation.jl
@@ -181,3 +181,8 @@ Base.:&(::NoTruncation, ::NoTruncation) = notrunc()
# disambiguate
Base.:&(::NoTruncation, trunc::TruncationIntersection) = trunc
Base.:&(trunc::TruncationIntersection, ::NoTruncation) = trunc

@doc """
truncation_error(values, ind)
Compute the truncation error as the 2-norm of the values that are not kept by `ind`.
""" truncation_error, truncation_error!