Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/user_interface/truncations.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ For example:
```jldoctest truncations; output=false
using LinearAlgebra: norm
U, S, Vᴴ, ϵ = svd_trunc(A; trunc=truncrank(2))
norm(A - U * S * Vᴴ) ≈ ϵ # ϵ is the 2-norm of the discarded singular values
norm(A - U * S * Vᴴ) ≈ only(ϵ) # ϵ is the 2-norm of the discarded singular values, stored as a Vector

# output
true
Expand Down
4 changes: 2 additions & 2 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ for eig in (:eig, :eigh)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
ϵ = truncation_error(diagview(DV[1]), ind)
ϵ = [truncation_error(diagview(DV[1]), ind)]
return (DV′..., ϵ), $(_make_eig_t_pb)(A, DV, ind)
end
function $(_make_eig_t_pb)(A, DV, ind)
Expand Down Expand Up @@ -174,7 +174,7 @@ function ChainRulesCore.rrule(::typeof(svd_trunc!), A, USVᴴ, alg::TruncatedAlg
Ac = copy_input(svd_compact, A)
USVᴴ = svd_compact!(Ac, USVᴴ, alg.alg)
USVᴴ′, ind = MatrixAlgebraKit.truncate(svd_trunc!, USVᴴ, alg.trunc)
ϵ = truncation_error(diagview(USVᴴ[2]), ind)
ϵ = [truncation_error(diagview(USVᴴ[2]), ind)]
return (USVᴴ′..., ϵ), _make_svd_trunc_pullback(A, USVᴴ, ind)
end
function _make_svd_trunc_pullback(A, USVᴴ, ind)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ for (f, pb, adj) in (
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
function $adj(::NoRData)
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
D, dD = arrayify(Dtrunc, dDtrunc_)
V, dV = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D, V), (dD, dV))
Expand Down Expand Up @@ -316,10 +315,9 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
function svd_trunc_adjoint(::NoRData)
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error"
U, dU = arrayify(Utrunc, dUtrunc_)
S, dS = arrayify(Strunc, dStrunc_)
Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_)
Expand Down
2 changes: 2 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ truncation through `trunc`.
struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm
# Wrapped decomposition algorithm; the `*_trunc!` implementations forward it as `alg.alg`.
alg::A
# Truncation strategy; the `*_trunc!` implementations pass it to `truncate` as `alg.trunc`.
trunc::T
# Whether the truncation error should be computed. The `initialize_output` methods for the
# truncated decompositions use this flag as the length of the error buffer `ϵ`
# (presumably 1 when `true`, 0 when `false` — confirm against `similar` for the array type).
compute_error::Bool
end
# Convenience constructor taking `compute_error` as a keyword (default `true`), so that
# two-argument calls `TruncatedAlgorithm(alg, trunc)` remain valid with the extra field.
TruncatedAlgorithm(alg::A, trunc::T; compute_error::Bool = true) where {A <: AbstractAlgorithm, T} = TruncatedAlgorithm{A, T}(alg, trunc, compute_error)

does_truncate(::TruncatedAlgorithm) = true

Expand Down
14 changes: 10 additions & 4 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ function initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::AbstractAlg
return D
end
# Allocate the outputs for `eig_trunc!`: the `eig_full!` outputs for the wrapped algorithm
# `alg.alg`, followed by a real-valued buffer `ϵ` for the truncation error.
#
# `alg.compute_error` (a `Bool`) is used directly as the length of `ϵ`, so the buffer holds
# one element when the error is requested and is empty otherwise; `eig_trunc!` checks
# `isempty(ϵ)` to decide whether to compute the error.
#
# The previous early `return initialize_output(eig_full!, A, alg.alg)` is removed: it made
# the `ϵ` allocation unreachable and returned only `(D, V)`, whereas `eig_trunc!`
# destructures three outputs.
function initialize_output(::typeof(eig_trunc!), A, alg::TruncatedAlgorithm)
    DV = initialize_output(eig_full!, A, alg.alg)
    ϵ = similar(A, real(eltype(A)), alg.compute_error)
    return (DV..., ϵ)
end

function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
Expand Down Expand Up @@ -115,10 +117,14 @@ function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm)
return D
end

# Truncated eigenvalue decomposition.
#
# Arguments:
# - `A`: input matrix, forwarded to `eig_full!`.
# - `DVϵ`: preallocated outputs `(D, V, ϵ)`; `ϵ` is the truncation-error buffer.
# - `alg`: `TruncatedAlgorithm` supplying the wrapped algorithm (`alg.alg`) and the
#   truncation strategy (`alg.trunc`).
#
# Returns the truncated `(D, V)` followed by `ϵ`. When `ϵ` is non-empty it is overwritten
# with the truncation error of the discarded eigenvalues; an empty `ϵ` is returned
# untouched, which skips the error computation.
#
# This replaces a block with two stacked `function` headers and a stale scalar `return`,
# which did not parse and left the `ϵ` handling unreachable.
function eig_trunc!(A, DVϵ, alg::TruncatedAlgorithm)
    D, V, ϵ = DVϵ
    D, V = eig_full!(A, (D, V), alg.alg)
    DVtrunc, ind = truncate(eig_trunc!, (D, V), alg.trunc)
    if !isempty(ϵ)
        # NOTE(review): `truncation_error!` is mutating by name and receives a view of `D`;
        # it runs after `truncate` has produced `DVtrunc` — confirm `DVtrunc` does not alias `D`.
        ϵ .= truncation_error!(diagview(D), ind)
    end
    return DVtrunc..., ϵ
end

# Diagonal logic
Expand Down
14 changes: 10 additions & 4 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ function initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::AbstractAl
return D
end
# Allocate the outputs for `eigh_trunc!`: the `eigh_full!` outputs for the wrapped algorithm
# `alg.alg`, followed by a real-valued buffer `ϵ` for the truncation error.
#
# `alg.compute_error` (a `Bool`) is used directly as the length of `ϵ`, so the buffer holds
# one element when the error is requested and is empty otherwise; `eigh_trunc!` checks
# `isempty(ϵ)` to decide whether to compute the error.
#
# The previous early `return initialize_output(eigh_full!, A, alg.alg)` is removed: it made
# the `ϵ` allocation unreachable and returned only `(D, V)`, whereas `eigh_trunc!`
# destructures three outputs.
function initialize_output(::typeof(eigh_trunc!), A, alg::TruncatedAlgorithm)
    DV = initialize_output(eigh_full!, A, alg.alg)
    ϵ = similar(A, real(eltype(A)), alg.compute_error)
    return (DV..., ϵ)
end

function initialize_output(::typeof(eigh_full!), A::Diagonal, ::DiagonalAlgorithm)
Expand Down Expand Up @@ -129,10 +131,14 @@ function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm)
return D
end

# Truncated Hermitian eigenvalue decomposition.
#
# Arguments:
# - `A`: input matrix, forwarded to `eigh_full!`.
# - `DVϵ`: preallocated outputs `(D, V, ϵ)`; `ϵ` is the truncation-error buffer.
# - `alg`: `TruncatedAlgorithm` supplying the wrapped algorithm (`alg.alg`) and the
#   truncation strategy (`alg.trunc`).
#
# Returns the truncated `(D, V)` followed by `ϵ`. When `ϵ` is non-empty it is overwritten
# with the truncation error of the discarded eigenvalues; an empty `ϵ` is returned
# untouched, which skips the error computation.
#
# This replaces a block with two stacked `function` headers and a stale scalar `return`,
# which did not parse and left the `ϵ` handling unreachable.
function eigh_trunc!(A, DVϵ, alg::TruncatedAlgorithm)
    D, V, ϵ = DVϵ
    D, V = eigh_full!(A, (D, V), alg.alg)
    DVtrunc, ind = truncate(eigh_trunc!, (D, V), alg.trunc)
    if !isempty(ϵ)
        # NOTE(review): `truncation_error!` is mutating by name and receives a view of `D`;
        # it runs after `truncate` has produced `DVtrunc` — confirm `DVtrunc` does not alias `D`.
        ϵ .= truncation_error!(diagview(D), ind)
    end
    return DVtrunc..., ϵ
end

# Diagonal logic
Expand Down
22 changes: 10 additions & 12 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ function initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::AbstractAlg
return similar(A, real(eltype(A)), (min(size(A)...),))
end
# Allocate the outputs for `svd_trunc!`: the `svd_compact!` outputs for the wrapped
# algorithm `alg.alg`, followed by a real-valued buffer `ϵ` for the truncation error.
#
# `alg.compute_error` (a `Bool`) is used directly as the length of `ϵ`, so the buffer holds
# one element when the error is requested and is empty otherwise.
#
# The previous early `return initialize_output(svd_compact!, A, alg.alg)` is removed: it
# made the `ϵ` allocation unreachable and returned only `(U, S, Vᴴ)`, which does not match
# the four-element `(U, S, Vᴴ, ϵ)` tuple that the `svd_trunc!` implementation dispatches on.
function initialize_output(::typeof(svd_trunc!), A, alg::TruncatedAlgorithm)
    USVᴴ = initialize_output(svd_compact!, A, alg.alg)
    ϵ = similar(A, real(eltype(A)), alg.compute_error)
    return (USVᴴ..., ϵ)
end

function initialize_output(::typeof(svd_full!), A::Diagonal, ::DiagonalAlgorithm)
Expand Down Expand Up @@ -206,12 +208,6 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
return S
end

# Convenience method of `svd_trunc!` for a three-element output tuple `(U, S, Vᴴ)`.
#
# Allocates the truncation-error buffer itself (length 1 if `compute_error`, otherwise
# empty), delegates to the four-element-tuple method, and returns the error as a scalar:
# `norm` of the buffer when `compute_error` is `true`, and `-1` in the buffer's element type
# when it is `false`.
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
    errbuf = similar(A, real(eltype(A)), compute_error)
    U, S, Vᴴ, errbuf = svd_trunc!(A, (USVᴴ..., errbuf), alg)
    err = if compute_error
        norm(errbuf)
    else
        -one(eltype(errbuf))
    end
    return (U, S, Vᴴ, err)
end

function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
Expand Down Expand Up @@ -272,18 +268,19 @@ end
###

# Validate the preallocated outputs of `svd_trunc!` for the randomized CUSOLVER algorithm.
#
# Arguments:
# - `A`: input matrix of size `(m, n)`.
# - `USVᴴϵ`: output tuple `(U, S, Vᴴ, ϵ)`.
# - `alg`: only used for dispatch.
#
# Checks that `U` is an `(m, m)` matrix and `Vᴴ` an `(n, n)` matrix with the scalar type of
# `A`, that `S` is a `Diagonal` of size `(min(m, n), min(m, n))` with the real scalar type
# of `A`, and that `ϵ` has the real scalar type of `A`. The length of `ϵ` is deliberately
# not checked here. Returns `nothing`.
#
# This replaces a block whose argument line was duplicated without a separator (so it did
# not parse) and which also carried a stale three-way destructuring of the output tuple.
function check_input(
        ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴϵ, alg::CUSOLVER_Randomized
    )
    m, n = size(A)
    minmn = min(m, n)
    U, S, Vᴴ, ϵ = USVᴴϵ
    @assert U isa AbstractMatrix && S isa Diagonal && Vᴴ isa AbstractMatrix
    @check_size(U, (m, m))
    @check_scalar(U, A)
    @check_size(S, (minmn, minmn))
    @check_scalar(S, A, real)
    @check_size(Vᴴ, (n, n))
    @check_scalar(Vᴴ, A)
    @check_scalar(ϵ, A, real)
    return nothing
end
end

Expand All @@ -295,7 +292,8 @@ function initialize_output(
U = similar(A, (m, m))
S = Diagonal(similar(A, real(eltype(A)), (minmn,)))
Vᴴ = similar(A, (n, n))
return (U, S, Vᴴ)
ϵ = similar(A, real(eltype(A)), alg.compute_error)
return (U, S, Vᴴ, ϵ)
end

function _gpu_gesvd!(
Expand Down Expand Up @@ -374,7 +372,7 @@ end

function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
check_input(svd_trunc!, A, USVᴴϵ, alg.alg)
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
Expand All @@ -386,7 +384,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg:
normA = norm(A)
# equivalent to sqrt(normA^2 - normS^2)
# but may be more accurate
ϵ = sqrt((normA + normS) * (normA - normS))
ϵ .= sqrt((normA + normS) * (normA - normS))
end

do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
Expand Down
20 changes: 10 additions & 10 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eig_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand All @@ -262,7 +262,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eig_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand Down Expand Up @@ -328,7 +328,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eigh_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand All @@ -343,7 +343,7 @@ end
ΔVtrunc = ΔV[:, ind]
test_rrule(
copy_eigh_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))),
output_tangent = (ΔDtrunc, ΔVtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
Expand Down Expand Up @@ -380,7 +380,7 @@ end
test_rrule(
config, eigh_trunc2, A;
fkwargs = (; trunc = trunc),
output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))),
output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]),
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
end
Expand All @@ -389,7 +389,7 @@ end
test_rrule(
config, eigh_trunc2, A;
fkwargs = (; trunc = trunc),
output_tangent = (ΔD[ind, ind], ΔV[:, ind], zero(real(T))),
output_tangent = (ΔD[ind, ind], ΔV[:, ind], [zero(real(T))]),
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
end
Expand Down Expand Up @@ -431,7 +431,7 @@ end
ΔVᴴtrunc = ΔVᴴ[ind, :]
test_rrule(
copy_svd_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
Expand All @@ -448,7 +448,7 @@ end
ΔVᴴtrunc = ΔVᴴ[ind, :]
test_rrule(
copy_svd_trunc, A, truncalg ⊢ NoTangent();
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))),
output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero(real(T))]),
atol = atol, rtol = rtol
)
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
Expand Down Expand Up @@ -477,7 +477,7 @@ end
test_rrule(
config, svd_trunc, A;
fkwargs = (; trunc = trunc),
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]),
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
end
Expand All @@ -486,7 +486,7 @@ end
test_rrule(
config, svd_trunc, A;
fkwargs = (; trunc = trunc),
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], zero(real(T))),
output_tangent = (ΔU[:, ind], ΔS[ind, ind], ΔVᴴ[ind, :], [zero(real(T))]),
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
end
Expand Down
2 changes: 1 addition & 1 deletion test/cuda/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ end
U1, S1, V1ᴴ, ϵ1 = @constinferred svd_trunc(A; alg, trunc = truncrank(r))
@test length(S1.diag) == r
@test opnorm(A - U1 * S1 * V1ᴴ) ≈ S₀[r + 1]
@test norm(A - U1 * S1 * V1ᴴ) ≈ ϵ1
@test norm(A - U1 * S1 * V1ᴴ) ≈ norm(ϵ1)

if !(alg isa CUSOLVER_Randomized)
s = 1 + sqrt(eps(real(T)))
Expand Down
12 changes: 6 additions & 6 deletions test/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,21 @@ end
D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r))
@test length(diagview(D1)) == r
@test A * V1 ≈ V1 * D1
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 + sqrt(eps(real(T)))
trunc = trunctol(; atol = s * abs(D₀[r + 1]))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol

# trunctol keeps order, truncrank might not
# test for same subspace
Expand All @@ -83,13 +83,13 @@ end
alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2]
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
@test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol
@test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2))

alg = TruncatedAlgorithm(LAPACK_Simple(), truncerror(; atol = 0.2, p = 1))
D3, V3, ϵ3 = @constinferred eig_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
@test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol
end

@testset "eig for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
Expand All @@ -112,5 +112,5 @@ end
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol
@test norm(ϵ2) ≈ norm(diagview(A2)[3:4]) atol = atol
end
12 changes: 6 additions & 6 deletions test/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ end
@test isisometric(V1)
@test A * V1 ≈ V1 * D1
@test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1]
@test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ1) ≈ norm(view(D₀, (r + 1):m)) atol = atol

trunc = trunctol(; atol = s * D₀[r + 1])
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc)
@test length(diagview(D2)) == r
@test isisometric(V2)
@test A * V2 ≈ V2 * D2
@test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ2) ≈ norm(view(D₀, (r + 1):m)) atol = atol

s = 1 - sqrt(eps(real(T)))
trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1)
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc)
@test length(diagview(D3)) == r
@test A * V3 ≈ V3 * D3
@test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol
@test norm(ϵ3) ≈ norm(view(D₀, (r + 1):m)) atol = atol

# test for same subspace
@test V1 * (V1' * V2) ≈ V2
Expand All @@ -93,12 +93,12 @@ end
D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg)
@test diagview(D2) ≈ diagview(D)[1:2]
@test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2))
@test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol
@test norm(ϵ2) ≈ norm(diagview(D)[3:4]) atol = atol

alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2))
D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg)
@test diagview(D3) ≈ diagview(D)[1:2]
@test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol
@test norm(ϵ3) ≈ norm(diagview(D)[3:4]) atol = atol
end

@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...)
Expand All @@ -122,5 +122,5 @@ end
alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2))
D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg)
@test diagview(D2) ≈ diagview(A2)[1:2]
@test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol
@test norm(ϵ2) ≈ norm(diagview(A2)[3:4]) atol = atol
end
Loading