From f0a0045c09c510822f7f863591dbc2a351208fb8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 09:27:37 -0500 Subject: [PATCH 1/6] add and export `diagonal` and `diagview` --- src/MatrixAlgebraKit.jl | 1 + src/common/view.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 493f3a91..30341c49 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -10,6 +10,7 @@ using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt export isisometric, isunitary, ishermitian, isantihermitian +export diagview, diagonal export project_hermitian, project_antihermitian, project_isometric export project_hermitian!, project_antihermitian!, project_isometric! diff --git a/src/common/view.jl b/src/common/view.jl index c8ae1aa5..e03bfb88 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -1,7 +1,25 @@ # diagind: provided by LinearAlgebra.jl +@doc """ + diagview(D) + +Return a view of the diagonal elements of a matrix `D`. + +See also [`diagonal`](@ref). +""" diagview + diagview(D::Diagonal) = D.diag diagview(D::AbstractMatrix) = view(D, diagind(D)) +@doc """ + diagonal(v) + +Construct a diagonal matrix view for the given diagonal vector. + +See also [`diagview`](@ref). 
+""" diagonal + +diagonal(v::AbstractVector) = Diagonal(v) + # triangularind function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) From 4cf443b4ec6ae0e2e8fb2e46610f6180c524fbc4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 09:36:56 -0500 Subject: [PATCH 2/6] add `pullback`, `rrule` and test for `svd_vals` --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 13 +++++++++++ src/pullbacks/svd.jl | 28 +++++++++++++++++++++++- test/chainrules.jl | 7 +++++- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 26e63570..480472dc 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -176,6 +176,19 @@ function _make_svd_trunc_pullback(A, USVᴴ, ind) return svd_trunc_pullback end +function ChainRulesCore.rrule(::typeof(svd_vals!), A, S, alg) + USVᴴ = svd_compact(A, alg) + function svd_vals_pullback(ΔS) + ΔA = zero(A) + MatrixAlgebraKit.svd_vals_pullback!(ΔA, A, USVᴴ, unthunk(ΔS)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function svd_vals_pullback(::ZeroTangent) # is this extra definition useful? + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return diagview(USVᴴ[2]), svd_vals_pullback +end + function ChainRulesCore.rrule(::typeof(left_polar!), A, WP, alg) Ac = copy_input(left_polar, A) WP = left_polar!(Ac, WP, alg) diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 1667be65..a8f8b70c 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -6,7 +6,7 @@ gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) -Adds the pullback from the SVD of `A` to `ΔA` given the output USVᴴ of `svd_compact` or +Adds the pullback from the SVD of `A` to `ΔA` given the output `USVᴴ` of `svd_compact` or `svd_full` and the cotangent `ΔUSVᴴ` of `svd_compact`, `svd_full` or `svd_trunc`. 
In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with @@ -201,3 +201,29 @@ function svd_trunc_pullback!( ΔA = mul!(ΔA, U, Y' * Ṽᴴ, 1, 1) return ΔA end + +""" + svd_vals_pullback!( + ΔA, A, USVᴴ, ΔS, [ind]; + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]) + ) + + +Adds the pullback from the singular values of `A` to `ΔA`, given the output +`USVᴴ` of `svd_compact`, and the cotangent `ΔS` of `svd_vals`. + +In particular, it is assumed that `A ≈ U * S * Vᴴ`, or thus, that no singular values with +magnitude less than `rank_atol` are missing from `S`. For the cotangents, an arbitrary +number of singular values can be missing, i.e. for a matrix `A` with +size `(m, n)`, the vector `ΔS` can have length `pS`. In those cases, additionally `ind` is required to +specify which singular values are present in `ΔS`. +""" +function svd_vals_pullback!( + ΔA, A, USVᴴ, ΔS, ind = Colon(); + rank_atol::Real = default_pullback_rank_atol(USVᴴ[2]), + degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]) + ) + ΔUSVᴴ = (nothing, diagonal(ΔS), nothing) + return svd_pullback!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol, degeneracy_atol) +end diff --git a/test/chainrules.jl b/test/chainrules.jl index 76eb84c8..e158d129 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -11,7 +11,8 @@ include("ad_utils.jl") for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eigh_full, :eigh_trunc, :svd_compact, :svd_trunc, + :eig_full, :eig_trunc, :eigh_full, :eigh_trunc, + :svd_compact, :svd_trunc, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:copy_, f) @@ -404,6 +405,10 @@ end copy_svd_compact, A, alg ⊢ NoTangent(); output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol ) + test_rrule( + copy_svd_vals, A, alg ⊢ NoTangent(); + output_tangent = diagview(ΔS), atol, rtol + ) for r in 1:4:minmn truncalg = 
TruncatedAlgorithm(alg, truncrank(r)) ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) From b38a916b734cc2945f2f80877b70bb6abda7e884 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 09:53:24 -0500 Subject: [PATCH 3/6] add `pullback`, `rrule` and test for `eig(h)_vals` --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 17 ++++++++++++++++ src/pullbacks/eig.jl | 24 +++++++++++++++++++++++ src/pullbacks/eigh.jl | 24 +++++++++++++++++++++++ test/chainrules.jl | 25 +++++++++++++----------- 4 files changed, 79 insertions(+), 11 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 480472dc..549f4a53 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -95,6 +95,11 @@ for eig in (:eig, :eigh) eig_t! = Symbol(eig, "_trunc!") eig_t_pb = Symbol(eig, "_trunc_pullback") _make_eig_t_pb = Symbol("_make_", eig_t_pb) + eig_v = Symbol(eig, "_vals") + eig_v! = Symbol(eig_v, "!") + eig_v_pb = Symbol(eig_v, "_pullback") + eig_v_pb! = Symbol(eig_v_pb, "!") + @eval begin function ChainRulesCore.rrule(::typeof($eig_f!), A, DV, alg) Ac = copy_input($eig_f, A) @@ -131,6 +136,18 @@ for eig in (:eig, :eigh) end return $eig_t_pb end + function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg) + DV = $eig_f(A, alg) + function $eig_v_pb(ΔD) + ΔA = zero(A) + MatrixAlgebraKit.$eig_v_pb!(ΔA, A, DV, unthunk(ΔD)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function $eig_v_pb(::ZeroTangent) # is this extra definition useful? 
+ return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return diagview(DV[1]), $eig_v_pb + end end end diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index ff0de512..4a203f64 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -151,3 +151,27 @@ function eig_trunc_pullback!( end return ΔA end + +""" + eig_vals_pullback!( + ΔA, A, DV, ΔD, [ind]; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + ) + +Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output +`DV` of `eig_full` and the cotangent `ΔD` of `eig_vals`. + +In particular, it is assumed that `A ≈ V * D * inv(V)` with thus `size(A) == size(V) == size(D)` +and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e. +for a matrix `A` of size `(n, n)`, the vector `ΔD` can have length `pD`. In those cases, +additionally `ind` is required to specify which eigenvalues are present in `ΔD`. +By default, it is assumed that all eigenvalues are present. +""" +function eig_vals_pullback!( + ΔA, A, DV, ΔD, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + ) + + ΔDV = (diagonal(ΔD), nothing) + return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) +end diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index d303faa7..195539cf 100644 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -141,3 +141,27 @@ function eigh_trunc_pullback!( end return ΔA end + +""" + eigh_vals_pullback!( + ΔA, A, DV, ΔD, [ind]; + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + ) + +Adds the pullback from the eigenvalues of `A` to `ΔA`, given the output +`DV` of `eigh_full` and the cotangent `ΔD` of `eigh_vals`. + +In particular, it is assumed that `A ≈ V * D * inv(V)` with thus `size(A) == size(V) == size(D)` +and `D` diagonal. For the cotangents, an arbitrary number of eigenvalues can be missing, i.e. 
+for a matrix `A` of size `(n, n)`, the vector `ΔD` can have length `pD`. In those cases, +additionally `ind` is required to specify which eigenvalues are present in `ΔD`. +By default, it is assumed that all eigenvalues are present. +""" +function eigh_vals_pullback!( + ΔA, A, DV, ΔD, ind = Colon(); + degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), + ) + + ΔDV = (diagonal(ΔD), nothing) + return eigh_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) +end diff --git a/test/chainrules.jl b/test/chainrules.jl index e158d129..63d77829 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -11,22 +11,23 @@ include("ad_utils.jl") for f in ( :qr_compact, :qr_full, :qr_null, :lq_compact, :lq_full, :lq_null, - :eig_full, :eig_trunc, :eigh_full, :eigh_trunc, + :eig_full, :eig_trunc, :eig_vals, :eigh_full, :eigh_trunc, :eigh_vals, :svd_compact, :svd_trunc, :svd_vals, :left_polar, :right_polar, ) copy_f = Symbol(:copy_, f) f! = Symbol(f, '!') + _hermitian = startswith(string(f), "eigh") @eval begin function $copy_f(input, alg) - if $f === eigh_full || $f === eigh_trunc + if $_hermitian input = (input + input') / 2 end return $f(input, alg) end function ChainRulesCore.rrule(::typeof($copy_f), input, alg) output = MatrixAlgebraKit.initialize_output($f!, input, alg) - if $f === eigh_full || $f === eigh_trunc + if $_hermitian input = (input + input') / 2 else input = copy(input) @@ -229,12 +230,13 @@ end ΔD2 = Diagonal(randn(rng, complex(T), m)) for alg in (LAPACK_Simple(), LAPACK_Expert()) test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔD, ΔV), atol = atol, rtol = rtol + copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol ) test_rrule( - copy_eig_full, A, alg ⊢ NoTangent(); - output_tangent = (ΔD2, ΔV), atol = atol, rtol = rtol + copy_eig_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol + ) + test_rrule( + copy_eig_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, 
rtol ) for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) @@ -305,12 +307,13 @@ end ) # copy_eigh_full includes a projector onto the Hermitian part of the matrix test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), - atol = atol, rtol = rtol + copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD, ΔV), atol, rtol ) test_rrule( - copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), - atol = atol, rtol = rtol + copy_eigh_full, A, alg ⊢ NoTangent(); output_tangent = (ΔD2, ΔV), atol, rtol + ) + test_rrule( + copy_eigh_vals, A, alg ⊢ NoTangent(); output_tangent = diagview(ΔD), atol, rtol ) for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) From b0eec92bdb55bd77f8af290bddaa5d39cfbd230d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 09:56:44 -0500 Subject: [PATCH 4/6] add zygote tests --- test/chainrules.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/chainrules.jl b/test/chainrules.jl index 63d77829..5258b839 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -287,6 +287,10 @@ end config, last ∘ eig_full, A; output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, eig_vals, A; + output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) end @timedtestset "EIGH AD Rules with eltype $T" for T in (Float64, ComplexF64, Float32) @@ -365,6 +369,10 @@ end config, last ∘ eigh_full ∘ Matrix ∘ Hermitian, A; output_tangent = ΔV, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, eigh_vals ∘ Matrix ∘ Hermitian, A; + output_tangent = diagview(ΔD), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) eigh_trunc2(A; kwargs...) = eigh_trunc(Matrix(Hermitian(A)); kwargs...) 
for r in 1:4:m trunc = truncrank(r; by = real) @@ -459,6 +467,10 @@ end output_tangent = (ΔU, ΔS2, ΔVᴴ), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) + test_rrule( + config, svd_vals, A; + output_tangent = diagview(ΔS), atol, rtol, rrule_f = rrule_via_ad, check_inferred = false + ) for r in 1:4:minmn trunc = truncrank(r) ind = MatrixAlgebraKit.findtruncated(diagview(S), trunc) From 29113c1dc70b69f81770e02f68e8e7a45c9a9f13 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 10:05:22 -0500 Subject: [PATCH 5/6] also update mooncake rules --- .../MatrixAlgebraKitMooncakeExt.jl | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index d845dd97..aa16f61e 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -6,9 +6,10 @@ using MatrixAlgebraKit using MatrixAlgebraKit: inv_safe, diagview, copy_input using MatrixAlgebraKit: qr_pullback!, lq_pullback! using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! -using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback! +using MatrixAlgebraKit: eig_trunc_pullback!, eigh_trunc_pullback!, eigh_vals_pullback! using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! -using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback! +using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_vals_pullback! 
using LinearAlgebra @@ -122,8 +123,8 @@ for (f!, f, pb, adj) in ( end for (f!, f, f_full, pb, adj) in ( - (:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint), - (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint), + (:eig_vals!, :eig_vals, :eig_full, :eig_vals_pullback!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_vals_pullback!, :eigh_vals_adjoint), ) @eval begin @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} @@ -136,7 +137,7 @@ for (f!, f, f_full, pb, adj) in ( copy!(D, diagview(DV[1])) V = DV[2] function $adj(::NoRData) - $pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + $pb(dA, A, DV, dD) MatrixAlgebraKit.zero!(dD) return NoRData(), NoRData(), NoRData(), NoRData() end @@ -153,7 +154,7 @@ for (f!, f, f_full, pb, adj) in ( output_codual = CoDual(output, Mooncake.zero_tangent(output)) function $adj(::NoRData) D, dD = arrayify(output_codual) - $pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + $pb(dA, A, DV, dD) MatrixAlgebraKit.zero!(dD) return NoRData(), NoRData(), NoRData() end @@ -272,10 +273,10 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua # compute primal A, dA = arrayify(A_dA) S, dS = arrayify(S_dS) - U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) - copy!(S, diagview(nS)) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + copy!(S, diagview(USVᴴ[2])) function svd_vals_adjoint(::NoRData) - svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing)) + svd_vals_pullback!(dA, A, USVᴴ, dS) MatrixAlgebraKit.zero!(dS) return NoRData(), NoRData(), NoRData(), NoRData() end @@ -286,15 +287,16 @@ end function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) - U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + USVᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) # fdata call here is necessary to 
convert complicated Tangent type (e.g. of a Diagonal # of ComplexF32) into the correct **forwards** data type (since we are now in the forward # pass). For many types this is done automatically when the forward step returns, but # not for nested structs with various fields (like Diagonal{Complex}) - S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S)))) + S = diagview(USVᴴ[2]) + S_codual = CoDual(S, Mooncake.fdata(Mooncake.zero_tangent(S))) function svd_vals_adjoint(::NoRData) S, dS = arrayify(S_codual) - svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing)) + svd_vals_pullback!(dA, A, USVᴴ, dS) MatrixAlgebraKit.zero!(dS) return NoRData(), NoRData(), NoRData() end From d9ac8fd1d92e4b0d180a9ed5dad787fe2c5499b6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 2 Dec 2025 10:05:27 -0500 Subject: [PATCH 6/6] add exports --- src/MatrixAlgebraKit.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 30341c49..fd97497b 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -63,8 +63,9 @@ export notrunc, truncrank, trunctol, truncerror, truncfilter Expr( :public, :left_polar_pullback!, :right_polar_pullback!, :qr_pullback!, :qr_null_pullback!, :lq_pullback!, :lq_null_pullback!, - :eig_pullback!, :eig_trunc_pullback!, :eigh_pullback!, :eigh_trunc_pullback!, - :svd_pullback!, :svd_trunc_pullback! + :eig_pullback!, :eig_trunc_pullback!, :eig_vals_pullback!, + :eigh_pullback!, :eigh_trunc_pullback!, :eigh_vals_pullback!, + :svd_pullback!, :svd_trunc_pullback!, :svd_vals_pullback! ) ) eval(Expr(:public, :is_left_isometric, :is_right_isometric))