Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
eig_v_pb! = Symbol(eig_v_pb, "!")

@eval begin
function ChainRulesCore.rrule(::typeof($eig_f!), A, DV, alg)
Ac = copy_input($eig_f, A)
Expand Down Expand Up @@ -131,6 +136,18 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
ΔA = zero(A)
MatrixAlgebraKit.$eig_v_pb!(ΔA, A, DV, unthunk(ΔD))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_v_pb(::ZeroTangent) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return diagview(DV[1]), $eig_v_pb
end
end
end

Expand Down Expand Up @@ -176,6 +193,19 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind)
return svd_trunc_pullback
end

function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg)
USVᴴ = svd_compact(A, alg)
function svd_vals_pullback(ΔS)
ΔA = zero(A)
MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS))
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function svd_pullback(::ZeroTangent) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return diagview(USVᴴ[2]), svd_vals_pullback
end

function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg)
Ac = copy_input(left_polar, A)
WP = left_polar!(Ac, WP, alg)
Expand Down
26 changes: 14 additions & 12 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback!
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!
using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback!
using LinearAlgebra


Expand Down Expand Up @@ -122,8 +123,8 @@ for (f!, f, pb, adj) in (
end

for (f!, f, f_full, pb, adj) in (
(:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint),
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint),
(:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint),
(:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
Expand All @@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in (
copy!(D, diagview(DV[1]))
V = DV[2]
function $adj(::NoRData)
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
$pb(dA, A, DV, dD)
MatrixAlgebraKit.zero!(dD)
return NoRData(), NoRData(), NoRData(), NoRData()
end
Expand All @@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in (
output_codual = CoDual(output, Mooncake.zero_tangent(output))
function $adj(::NoRData)
D, dD = arrayify(output_codual)
$pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing))
$pb(dA, A, DV, dD)
MatrixAlgebraKit.zero!(dD)
return NoRData(), NoRData(), NoRData()
end
Expand Down Expand Up @@ -272,10 +273,10 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
# compute primal
A, dA = arrayify(A_dA)
S, dS = arrayify(S_dS)
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
copy!(S, diagview(nS))
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
copy!(S, diagview(USVᴴ[2]))
function svd_vals_adjoint(::NoRData)
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
svd_vals_pullback!(dA, A, USVᴴ, dS)
MatrixAlgebraKit.zero!(dS)
return NoRData(), NoRData(), NoRData(), NoRData()
end
Expand All @@ -286,15 +287,16 @@ end
function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg))
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S))))
S = diagview(USVᴴ[2])
S_codual = CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S)))
function svd_vals_adjoint(::NoRData)
S, dS = arrayify(S_codual)
svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing))
svd_vals_pullback!(dA, A, USVᴴ, dS)
MatrixAlgebraKit.zero!(dS)
return NoRData(), NoRData(), NoRData()
end
Expand Down
6 changes: 4 additions & 2 deletions src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using LinearAlgebra: UpperTriangular, LowerTriangular
using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt

export isisometric, isunitary, ishermitian, isantihermitian
export diagview, diagonal

export project_hermitian, project_antihermitian, project_isometric
export project_hermitian!, project_antihermitian!, project_isometric!
Expand Down Expand Up @@ -62,8 +63,9 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter
Expr(
:public, :left_polar_pullback!, :right_polar_pullback!,
:qr_pullback!, :qr_null_pullback!, :lq_pullback!, :lq_null_pullback!,
:eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!,
:svd_pullback!, :svd_trunc_pullback!
:eig_pullback!, :eig_trunc_pullback!, :eig_vals_pullback!,
:eigh_pullback!, :eigh_trunc_pullback!, :eigh_vals_pullback!,
:svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback!
)
)
eval(Expr(:public, :is_left_isometric, :is_right_isometric))
Expand Down
18 changes: 18 additions & 0 deletions src/common/view.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
# diagind: provided by LinearAlgebra.jl
@doc """
diagview(D)

Return a view of the diagonal elements of a matrix `D`.

See also [`diagonal`](@ref).
""" diagview

diagview(D::Diagonal) = D.diag
diagview(D::AbstractMatrix) = view(D, diagind(D))

@doc """
diagonal(v)

Construct a diagonal matrix view for the given diagonal vector.

See also [`diagview`](@ref).
""" diagonal

diagonal(v::AbstractVector) = Diagonal(v)

# triangularind
function lowertriangularind(A::AbstractMatrix)
Base.require_one_based_indexing(A)
Expand Down
24 changes: 24 additions & 0 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,27 @@ function eig_trunc_pullback!(
end
return ΔA
end

"""
eig_vals_pullback!(
ΔA, A, DV, ΔD, [ind];
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
)

Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
`DV` of `eig_full` and the cotangent `ΔD` of `eig_vals`.

In particular, it is assumed that `A ≈ V * D * inv(V)` with thus `size(A) == size(V) == size(D)`
and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e.
for a matrix `A` of size `(n, n)`, `diagview(ΔD)` can have length `pD`. In those cases,
additionally `ind` is required to specify which eigenvalues are present in `ΔV` or `ΔD`.
By default, it is assumed that all eigenvectors and eigenvalues are present.
"""
function eig_vals_pullback!(
ΔA, A, DV, ΔD, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
)

ΔDV = (diagonal(ΔD), nothing)
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end
24 changes: 24 additions & 0 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,27 @@ function eigh_trunc_pullback!(
end
return ΔA
end

"""
eigh_vals_pullback!(
ΔA, A, DV, ΔD, [ind];
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
)

Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output
`DV` of `eigh_full` and the cotangent `ΔD` of `eig_vals`.

In particular, it is assumed that `A ≈ V * D * inv(V)` with thus `size(A) == size(V) == size(D)`
and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e.
for a matrix `A` of size `(n, n)`, `diagview(ΔD)` can have length `pD`. In those cases,
additionally `ind` is required to specify which eigenvalues are present in `ΔV` or `ΔD`.
By default, it is assumed that all eigenvectors and eigenvalues are present.
"""
function eigh_vals_pullback!(
ΔA, A, DV, ΔD, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
)

ΔDV = (diagonal(ΔD), nothing)
return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end
28 changes: 27 additions & 1 deletion src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3])
)

Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or
Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd_compact` or
`svd_full` and the cotangent `ΔUSVᴴ` of `svd_compact`, `svd_full` or `svd_trunc`.

In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
Expand Down Expand Up @@ -201,3 +201,29 @@ function svd_trunc_pullback!(
ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1)
return ΔA
end

"""
svd_vals_pullback!(
ΔA, A, USVᴴ, ΔS, [ind];
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
)


Adds the pullback from the singular values of `A` to `ΔA`, given the output
`USVᴴ` of `svd_compact`, and the cotangent `ΔS` of `svd_vals`.

In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with
magnitude less than `rank_atol` are missing from `S`. For the cotangents, an arbitrary
number of singular vectors or singular values can be missing, i.e. for a matrix `A` with
size `(m, n)`, `diagview(ΔS)` can have length `pS`. In those cases, additionally `ind` is required to
specify which singular vectors and values are present in `ΔS`.
"""
function svd_vals_pullback!(
ΔA, A, USVᴴ, ΔS, ind = Colon();
rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]),
degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2])
)
ΔUSVᴴ = (nothing, diagonal(ΔS), nothing)
return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol)
end
42 changes: 31 additions & 11 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,23 @@ include("ad_utils.jl")
for f in
(
:qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null,
:eig_full, :eig_trunc, :eigh_full, :eigh_trunc, :svd_compact, :svd_trunc,
:eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals,
:svd_compact, :svd_trunc, :svd_vals,
:left_polar, :right_polar,
)
copy_f = Symbol(:copy_, f)
f! = Symbol(f, '!')
_hermitian = startswith(string(f), "eigh")
@eval begin
function $copy_f(input, alg)
if $f === eigh_full || $f === eigh_trunc
if $_hermitian
input = (input + input') / 2
end
return $f(input, alg)
end
function ChainRulesCore.rrule(::typeof($copy_f), input, alg)
output = MatrixAlgebraKit.initialize_output($f!, input, alg)
if $f === eigh_full || $f === eigh_trunc
if $_hermitian
input = (input + input') / 2
else
input = copy(input)
Expand Down Expand Up @@ -228,12 +230,13 @@ end
ΔD2 = Diagonal(randn(rng, complex(T), m))
for alg in (LAPACK_Simple(), LAPACK_Expert())
test_rrule(
copy_eig_full, A, alg ⊢ NoTangent();
output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
)
test_rrule(
copy_eig_full, A, alg ⊢ NoTangent();
output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol
copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
)
test_rrule(
copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
)
for r in 1:4:m
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
Expand Down Expand Up @@ -284,6 +287,10 @@ end
config, last ∘ eig_full, A;
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
test_rrule(
config, eig_vals, A;
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
)
end

@timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32)
Expand All @@ -304,12 +311,13 @@ end
)
# copy_eigh_full includes a projector onto the Hermitian part of the matrix
test_rrule(
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV),
atol = atol, rtol = rtol
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol
)
test_rrule(
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV),
atol = atol, rtol = rtol
copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol
)
test_rrule(
copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol
)
for r in 1:4:m
truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
Expand Down Expand Up @@ -361,6 +369,10 @@ end
config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A;
output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
test_rrule(
config, eigh_vals ∘ Matrix ∘ Hermitian, A;
output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
)
eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...)
for r in 1:4:m
trunc = truncrank(r; by = real)
Expand Down Expand Up @@ -404,6 +416,10 @@ end
copy_svd_compact, A, alg ⊢ NoTangent();
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol
)
test_rrule(
copy_svd_vals, A, alg ⊢ NoTangent();
output_tangent = diagview(ΔS), atol, rtol
)
for r in 1:4:minmn
truncalg = TruncatedAlgorithm(alg, truncrank(r))
ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
Expand Down Expand Up @@ -451,6 +467,10 @@ end
output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol,
rrule_f = rrule_via_ad, check_inferred = false
)
test_rrule(
config, svd_vals, A;
output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false
)
for r in 1:4:minmn
trunc = truncrank(r)
ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc)
Expand Down