Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,15 @@ maxdim = 10
atol = 1e-6
combined_trunc = truncrank(maxdim) & trunctol(; atol)
```

## Truncation Error

When using truncated decompositions such as [`svd_trunc`](@ref), [`eig_trunc`](@ref), or [`eigh_trunc`](@ref),
the truncation error `ϵ` is returned as an additional, final output. It is defined as the 2-norm of the discarded
singular values or eigenvalues, i.e. `ϵ = sqrt(sum(abs2, discarded))`, and provides a measure of the approximation quality.

For example:
```julia
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(10))
# ϵ is the 2-norm of the discarded singular values
```
29 changes: 19 additions & 10 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,20 @@ for eig in (:eig, :eigh)
function ChainRulesCore.rrule(::typeof($eig_t!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_pb)(A, DV, ind)
DV′, ind, truncerr = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return (DV′..., truncerr), $(_make_eig_t_pb)(A, DV, ind)
end
function $(_make_eig_t_pb)(A, DV, ind)
function $eig_t_pb(ΔDV)
function $eig_t_pb(ΔDV_and_err)
# Extract the tangents for D and V, ignore the truncerr tangent
ΔD = ΔDV_and_err[1]
ΔV = ΔDV_and_err[2]
# Ignore ΔDV_and_err[3] which is the truncerr tangent
ΔA = zero(A)
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.(ΔDV), ind)
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, (unthunk(ΔD), unthunk(ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
function $eig_t_pb(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_pb
Expand Down Expand Up @@ -151,16 +155,21 @@ end
function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlgorithm)
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return USVᴴ′, _make_svd_trunc_pullback(A, USVᴴ, ind)
USVᴴ′, ind, truncerr = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
return (USVᴴ′..., truncerr), _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
function svd_trunc_pullback(ΔUSVᴴ)
function svd_trunc_pullback(ΔUSVᴴ_and_err)
# Extract the tangents for U, S, and Vᴴ, ignore the truncerr tangent
ΔU = ΔUSVᴴ_and_err[1]
ΔS = ΔUSVᴴ_and_err[2]
ΔVᴴ = ΔUSVᴴ_and_err[3]
# Ignore ΔUSVᴴ_and_err[4] which is the truncerr tangent
ΔA = zero(A)
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, unthunk.(ΔUSVᴴ), ind)
MatrixAlgebraKit.svd_pullback!(ΔA, A, USVᴴ, (unthunk(ΔU), unthunk(ΔS), unthunk(ΔVᴴ)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
function svd_trunc_pullback(::Tuple{ZeroTangent, ZeroTangent, ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return svd_trunc_pullback
Expand Down
3 changes: 2 additions & 1 deletion src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ end

function eig_trunc!(A, DV, alg::TruncatedAlgorithm)
D, V = eig_full!(A, DV, alg.alg)
return first(truncate(eig_trunc!, (D, V), alg.trunc))
result, _, truncerr = truncate(eig_trunc!, (D, V), alg.trunc)
return result..., truncerr
end

# Diagonal logic
Expand Down
3 changes: 2 additions & 1 deletion src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ end

function eigh_trunc!(A, DV, alg::TruncatedAlgorithm)
D, V = eigh_full!(A, DV, alg.alg)
return first(truncate(eigh_trunc!, (D, V), alg.trunc))
result, _, truncerr = truncate(eigh_trunc!, (D, V), alg.trunc)
return result..., truncerr
end

# Diagonal logic
Expand Down
8 changes: 4 additions & 4 deletions src/implementations/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ function left_orth_svd!(A, VC, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
check_input(left_orth!, A, VC, alg′)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
U, S, Vᴴ = svd_trunc!(A, alg_trunc)
U, S, Vᴴ, _ = svd_trunc!(A, alg_trunc)
V, C = VC
return copy!(V, U), mul!(C, S, Vᴴ)
end
Expand All @@ -138,7 +138,7 @@ function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
V, C = VC
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc)
U, S, Vᴴ, _ = svd_trunc!(A, (V, S, C), alg_trunc)
return U, lmul!(S, Vᴴ)
end

Expand Down Expand Up @@ -189,7 +189,7 @@ function right_orth_svd!(A, CVᴴ, alg, trunc)
alg′ = select_algorithm(svd_compact!, A, alg)
check_input(right_orth!, A, CVᴴ, alg′)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
U, S, Vᴴ′ = svd_trunc!(A, alg_trunc)
U, S, Vᴴ′, _ = svd_trunc!(A, alg_trunc)
C, Vᴴ = CVᴴ
return mul!(C, U, S), copy!(Vᴴ, Vᴴ′)
end
Expand All @@ -199,7 +199,7 @@ function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc)
alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc)
C, Vᴴ = CVᴴ
S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg))
U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc)
U, S, Vᴴ, _ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc)
return rmul!(U, S), Vᴴ
end

Expand Down
6 changes: 4 additions & 2 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ end

function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
USVᴴ′ = svd_compact!(A, USVᴴ, alg.alg)
return first(truncate(svd_trunc!, USVᴴ′, alg.trunc))
result, _, truncerr = truncate(svd_trunc!, USVᴴ′, alg.trunc)
return result..., truncerr
end

# Diagonal logic
Expand Down Expand Up @@ -381,7 +382,8 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)
# TODO: make this controllable using a `gaugefix` keyword argument
gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...)
return first(truncate(svd_trunc!, USVᴴ, alg.trunc))
result, _, truncerr = truncate(svd_trunc!, USVᴴ, alg.trunc)
return result..., truncerr
end

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
Expand Down
38 changes: 33 additions & 5 deletions src/implementations/truncation.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,57 @@
# Truncation error: the 2-norm of the values that a truncation discards.
#
# `values` holds all singular values / eigenvalues, and `ind` selects the ones that are
# kept. `ind` may be `:`, a Boolean mask, or any iterable of integer indices into `values`.
# Every method returns `sqrt` of a sum of `abs2` terms, so the result type is the same
# regardless of which kind of `ind` is passed.

# `:` keeps everything, so nothing is discarded. The zero is built through the same
# `sqrt ∘ abs2` chain as the other methods to keep the return type consistent
# (e.g. `Float64` for integer input instead of an integer zero).
function _compute_truncerr(values::AbstractVector, ::Colon)
    return sqrt(abs2(zero(eltype(values))))
end

# Boolean mask: the discarded values are exactly those where the mask is `false`.
function _compute_truncerr(values::AbstractVector, ind::AbstractVector{Bool})
    return sqrt(sum(abs2, view(values, .!ind)))
end

# Integer indices (vector, range, ...): mark the kept positions in a mask and sum over the
# complement. The mask keeps the discarded values in index order, so the floating-point
# summation order is reproducible, and repeated indices in `ind` are harmless.
function _compute_truncerr(values::AbstractVector, ind)
    discarded = fill(true, axes(values))
    for i in ind
        discarded[i] = false
    end
    return sqrt(sum(abs2, view(values, discarded)))
end

# truncate
# --------
# Generic implementation: `findtruncated` followed by indexing
function truncate(::typeof(svd_trunc!), (U, S, Vᴴ), strategy::TruncationStrategy)
ind = findtruncated_svd(diagview(S), strategy)
return (U[:, ind], Diagonal(diagview(S)[ind]), Vᴴ[ind, :]), ind
Svals = diagview(S)
truncerr = _compute_truncerr(Svals, ind)
return (U[:, ind], Diagonal(Svals[ind]), Vᴴ[ind, :]), ind, truncerr
end
function truncate(::typeof(eig_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return (Diagonal(diagview(D)[ind]), V[:, ind]), ind
Dvals = diagview(D)
truncerr = _compute_truncerr(Dvals, ind)
return (Diagonal(Dvals[ind]), V[:, ind]), ind, truncerr
end
function truncate(::typeof(eigh_trunc!), (D, V), strategy::TruncationStrategy)
ind = findtruncated(diagview(D), strategy)
return (Diagonal(diagview(D)[ind]), V[:, ind]), ind
Dvals = diagview(D)
truncerr = _compute_truncerr(Dvals, ind)
return (Diagonal(Dvals[ind]), V[:, ind]), ind, truncerr
end
function truncate(::typeof(left_null!), (U, S), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2))))
ind = findtruncated(extended_S, strategy)
return U[:, ind], ind
truncerr = _compute_truncerr(extended_S, ind)
return U[:, ind], ind, truncerr
end
function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy)
# TODO: avoid allocation?
extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1))))
ind = findtruncated(extended_S, strategy)
return Vᴴ[ind, :], ind
truncerr = _compute_truncerr(extended_S, ind)
return Vᴴ[ind, :], ind, truncerr
end

# findtruncated
Expand Down
11 changes: 7 additions & 4 deletions src/interface/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,19 @@ See also [`eig_vals(!)`](@ref eig_vals) and [`eig_trunc(!)`](@ref eig_trunc).
@functiondef eig_full

"""
eig_trunc(A; kwargs...) -> D, V
eig_trunc(A, alg::AbstractAlgorithm) -> D, V
eig_trunc!(A, [DV]; kwargs...) -> D, V
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
eig_trunc(A; kwargs...) -> D, V, ϵ
eig_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
eig_trunc!(A, [DV]; kwargs...) -> D, V, ϵ
eig_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ

Compute a partial or truncated eigenvalue decomposition of the matrix `A`,
such that `A * V ≈ V * D`, where the (possibly rectangular) matrix `V` contains
a subset of eigenvectors and the diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded eigenvalues.

!!! note
The bang method `eig_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
Expand Down
13 changes: 8 additions & 5 deletions src/interface/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,18 @@ See also [`eigh_vals(!)`](@ref eigh_vals) and [`eigh_trunc(!)`](@ref eigh_trunc)
@functiondef eigh_full

"""
eigh_trunc(A; kwargs...) -> D, V
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V
eigh_trunc!(A, [DV]; kwargs...) -> D, V
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V
eigh_trunc(A; kwargs...) -> D, V, ϵ
eigh_trunc(A, alg::AbstractAlgorithm) -> D, V, ϵ
eigh_trunc!(A, [DV]; kwargs...) -> D, V, ϵ
eigh_trunc!(A, [DV], alg::AbstractAlgorithm) -> D, V, ϵ

Compute a partial or truncated eigenvalue decomposition of the symmetric or hermitian matrix
`A`, such that `A * V ≈ V * D`, where the isometric matrix `V` contains a subset of the
orthogonal eigenvectors and the real diagonal matrix `D` contains the associated eigenvalues,
selected according to a truncation strategy.
selected according to a truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded eigenvalues.

!!! note
The bang method `eigh_trunc!` optionally accepts the output structure and
Expand Down
13 changes: 8 additions & 5 deletions src/interface/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,19 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_vals(!)`](@ref svd_vals) and

# TODO: decide if we should have `svd_trunc!!` instead
"""
svd_trunc(A; kwargs...) -> U, S, Vᴴ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ
svd_trunc(A; kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc(A, alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ]; kwargs...) -> U, S, Vᴴ, ϵ
svd_trunc!(A, [USVᴴ], alg::AbstractAlgorithm) -> U, S, Vᴴ, ϵ

Compute a partial or truncated singular value decomposition (SVD) of `A`, such that
`A * (Vᴴ)' = U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
`A * (Vᴴ)' ≈ U * S`. Here, `U` is an isometric matrix (orthonormal columns) of size
`(m, k)`, whereas `Vᴴ` is a matrix of size `(k, n)` with orthonormal rows and `S` is a
square diagonal matrix of size `(k, k)`, where `k` is set by the truncation strategy.

The function also returns `ϵ`, the truncation error defined as the 2-norm of the
discarded singular values.

!!! note
The bang method `svd_trunc!` optionally accepts the output structure and
possibly destroys the input matrix `A`. Always use the return value of the function
Expand Down
6 changes: 3 additions & 3 deletions test/amd/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ end
r = m - 2
s = 1 + sqrt(eps(real(T)))

D1, V1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r))
@test length(diagview(D1)) == r
@test isisometry(V1)
@test A * V1 ≈ V1 * D1
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]

trunc = trunctol(; atol=s * D₀[r + 1])
D2, V2 = @constinferred eigh_trunc(A; alg, trunc)
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test isisometry(V2)
@test A * V2 ≈ V2 * D2
Expand All @@ -75,7 +75,7 @@ end
A = V * D * V'
A = (A + A') / 2
alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2))
D2, V2 = @constinferred eigh_trunc(A; alg)
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))
@test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2))
end=#
4 changes: 2 additions & 2 deletions test/amd/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,14 @@ end
# minmn = min(m, n)
# r = minmn - 2

# U1, S1, V1ᴴ = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
# U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc=truncrank(r))
# @test length(S1.diag) == r
# @test LinearAlgebra.opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]

# s = 1 + sqrt(eps(real(T)))
# trunc2 = trunctol(; atol=s * S₀[r + 1])

# U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
# U2, S2, V2ᴴ, ϵ2 = @constinferred svd_trunc(A; alg, trunc=trunctol(; atol=s * S₀[r + 1]))
# @test length(S2.diag) == r
# @test U1 ≈ U2
# @test S1 ≈ S2
Expand Down
12 changes: 6 additions & 6 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eig_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc),
output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand All @@ -290,7 +290,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eig_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc),
output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand Down Expand Up @@ -351,7 +351,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eigh_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc),
output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand All @@ -366,7 +366,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eigh_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc),
output_tangent = (ΔDtrunc, ΔVtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand Down Expand Up @@ -446,7 +446,7 @@ end
ΔVᴴtrunc = ΔVᴴ[ind, :]
test_rrule(
copy_svd_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
Expand All @@ -463,7 +463,7 @@ end
ΔVᴴtrunc = ΔVᴴ[ind, :]
test_rrule(
copy_svd_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc),
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, ZeroTangent()),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
Expand Down
4 changes: 2 additions & 2 deletions test/cuda/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ end
rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2))
r = length(D₀) - rmin

D1, V1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r))
D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc=truncrank(r))
@test length(D1.diag) == r
@test A * V1 ≈ V1 * D1

s = 1 + sqrt(eps(real(T)))
trunc = trunctol(; atol=s * abs(D₀[r + 1]))
D2, V2 = @constinferred eig_trunc(A; alg, trunc)
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test A * V2 ≈ V2 * D2

Expand Down
Loading