diff --git a/Project.toml b/Project.toml index d7044a33..a4e1c820 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" GenericLinearAlgebra = "14197337-ba66-59df-a3e3-ca00e7dcff7a" GenericSchur = "c145ed77-6b09-5dd9-b285-bf645a82121e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore" @@ -19,6 +20,7 @@ MatrixAlgebraKitAMDGPUExt = "AMDGPU" MatrixAlgebraKitCUDAExt = "CUDA" MatrixAlgebraKitGenericLinearAlgebraExt = "GenericLinearAlgebra" MatrixAlgebraKitGenericSchurExt = "GenericSchur" +MatrixAlgebraKitMooncakeExt = "Mooncake" [compat] AMDGPU = "2" @@ -30,6 +32,7 @@ GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" JET = "0.9, 0.10" LinearAlgebra = "1" +Mooncake = "0.4.174" SafeTestsets = "0.1" StableRNGs = "1" Test = "1" @@ -43,6 +46,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -50,4 +54,4 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl new file mode 
100644 index 00000000..d845dd97 --- /dev/null +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -0,0 +1,333 @@ +module MatrixAlgebraKitMooncakeExt + +using Mooncake +using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive +using MatrixAlgebraKit +using MatrixAlgebraKit: inv_safe, diagview, copy_input +using MatrixAlgebraKit: qr_pullback!, lq_pullback! +using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! +using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! +using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback! +using LinearAlgebra + + +Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.NoTangent + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any} +function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual) + Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA)) + dAc = Mooncake.zero_tangent(Ac) + function copy_input_pb(::NoRData) + Mooncake.increment!!(Mooncake.tangent(A_dA), dAc) + return NoRData(), NoRData(), NoRData() + end + return CoDual(Ac, dAc), copy_input_pb +end + +# two-argument in-place factorizations like LQ, QR, EIG +for (f!, f, pb, adj) in ( + (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), + (:lq_full!, :lq_full, :lq_pullback!, :lq_adjoint), + (:qr_compact!, :qr_compact, :qr_pullback!, :qr_adjoint), + (:lq_compact!, :lq_compact, :lq_pullback!, :lq_adjoint), + (:eig_full!, :eig_full, :eig_pullback!, :eig_adjoint), + (:eigh_full!, :eigh_full, :eigh_pullback!, :eigh_adjoint), + (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint), + (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint), + ) + + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, 
MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + Ac = copy(A) + arg1c = copy(arg1) + arg2c = copy(arg2) + $f!(A, args, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) + copy!(A, Ac) + $pb(dA, A, (arg1, arg2), (darg1, darg2)) + copy!(arg1, arg1c) + copy!(arg2, arg2c) + MatrixAlgebraKit.zero!(darg1) + MatrixAlgebraKit.zero!(darg2) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return args_dargs, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). 
For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(::NoRData) + arg1, arg2 = Mooncake.primal(output_codual) + darg1_, darg2_ = Mooncake.tangent(output_codual) + arg1, darg1 = arrayify(arg1, darg1_) + arg2, darg2 = arrayify(arg2, darg2_) + $pb(dA, A, (arg1, arg2), (darg1, darg2)) + MatrixAlgebraKit.zero!(darg1) + MatrixAlgebraKit.zero!(darg2) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +for (f!, f, pb, adj) in ( + (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint), + (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + Ac = copy(A) + arg, darg = arrayify(arg_darg) + argc = copy(arg) + $f!(A, arg, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) + copy!(A, Ac) + $pb(dA, A, arg, darg) + copy!(arg, argc) + MatrixAlgebraKit.zero!(darg) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return arg_darg, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_codual = CoDual(output, Mooncake.zero_tangent(output)) + function $adj(::NoRData) + arg, darg = arrayify(output_codual) + $pb(dA, A, arg, darg) + MatrixAlgebraKit.zero!(darg) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +for 
(f!, f, f_full, pb, adj) in ( + (:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + D, dD = arrayify(D_dD) + # update primal + DV = $f_full(A, Mooncake.primal(alg_dalg)) + copy!(D, diagview(DV[1])) + V = DV[2] + function $adj(::NoRData) + $pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + MatrixAlgebraKit.zero!(dD) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return D_dD, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + # update primal + DV = $f_full(A, Mooncake.primal(alg_dalg)) + V = DV[2] + output = diagview(DV[1]) + output_codual = CoDual(output, Mooncake.zero_tangent(output)) + function $adj(::NoRData) + D, dD = arrayify(output_codual) + $pb(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + MatrixAlgebraKit.zero!(dD) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +for (f, pb, adj) in ( + (:eig_trunc, :eig_trunc_pullback!, :eig_trunc_adjoint), + (:eigh_trunc, :eigh_trunc_pullback!, :eigh_trunc_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = $f(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. 
of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real} + Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual) + dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $($(string(f))) does not yet support non-zero tangent for the truncation error" + D, dD = arrayify(Dtrunc, dDtrunc_) + V, dV = arrayify(Vtrunc, dVtrunc_) + $pb(dA, A, (D, V), (dD, dV)) + MatrixAlgebraKit.zero!(dD) + MatrixAlgebraKit.zero!(dV) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +for (f!, f) in ( + (:svd_full!, :svd_full), + (:svd_compact!, :svd_compact), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) + A, dA = arrayify(A_dA) + Ac = copy(A) + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + U, dU = arrayify(USVᴴ[1], dUSVᴴ[1]) + S, dS = arrayify(USVᴴ[2], dUSVᴴ[2]) + Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3]) + output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + function svd_adjoint(::NoRData) + copy!(A, Ac) + if $(f! == :svd_compact!) + svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + minmn = min(size(A)...) 
+ vU = view(U, :, 1:minmn) + vS = Diagonal(view(diagview(S), 1:minmn)) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = Diagonal(view(diagview(dS), 1:minmn)) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return CoDual(output, dUSVᴴ), svd_adjoint + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + USVᴴ_codual = CoDual(USVᴴ, Mooncake.fdata(Mooncake.zero_tangent(USVᴴ))) + function svd_adjoint(::NoRData) + U, S, Vᴴ = Mooncake.primal(USVᴴ_codual) + dU_, dS_, dVᴴ_ = Mooncake.tangent(USVᴴ_codual) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + if $(f == :svd_compact) + svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + minmn = min(size(A)...) 
+ vU = view(U, :, 1:minmn) + vS = Diagonal(view(diagview(S), 1:minmn)) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = Diagonal(view(diagview(dS), 1:minmn)) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) + end + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return USVᴴ_codual, svd_adjoint + end + end +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + S, dS = arrayify(S_dS) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + copy!(S, diagview(nS)) + function svd_vals_adjoint(::NoRData) + svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing)) + MatrixAlgebraKit.zero!(dS) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return S_dS, svd_vals_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). 
For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + S_codual = CoDual(diagview(S), Mooncake.fdata(Mooncake.zero_tangent(diagview(S)))) + function svd_vals_adjoint(::NoRData) + S, dS = arrayify(S_codual) + svd_pullback!(dA, A, (U, Diagonal(S), Vᴴ), (nothing, Diagonal(dS), nothing)) + MatrixAlgebraKit.zero!(dS) + return NoRData(), NoRData(), NoRData() + end + return S_codual, svd_vals_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A_ = Mooncake.primal(A_dA) + dA_ = Mooncake.tangent(A_dA) + A, dA = arrayify(A_, dA_) + alg = Mooncake.primal(alg_dalg) + output = svd_trunc(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} + Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! 
does not yet support non-zero tangent for the truncation error" + U, dU = arrayify(Utrunc, dUtrunc_) + S, dS = arrayify(Strunc, dStrunc_) + Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + MatrixAlgebraKit.zero!(dU) + MatrixAlgebraKit.zero!(dS) + MatrixAlgebraKit.zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end + +end diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index fabc2c2e..1c6de509 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -4,7 +4,7 @@ Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and cotangent `ΔWP` of `left_polar(A)`. """ -function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP) +function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) # Extract the Polar components W, P = WP @@ -34,7 +34,7 @@ end Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ` and cotangent `ΔPWᴴ` of `right_polar(A)`. """ -function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ) +function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...) 
# Extract the Polar components P, Wᴴ = PWᴴ diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index c0353a3a..a85b7165 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -26,14 +26,13 @@ function svd_pullback!( degeneracy_atol::Real = default_pullback_rank_atol(USVᴴ[2]), gauge_atol::Real = default_pullback_gauge_atol(ΔUSVᴴ[1], ΔUSVᴴ[3]) ) - # Extract the SVD components U, Smat, Vᴴ = USVᴴ m, n = size(U, 1), size(Vᴴ, 2) - (m, n) == size(ΔA) || throw(DimensionMismatch()) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) minmn = min(m, n) S = diagview(Smat) - length(S) == minmn || throw(DimensionMismatch()) + length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not match minimum dimension of U, Vᴴ ($minmn)")) r = searchsortedlast(S, rank_atol; rev = true) # rank Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) @@ -44,22 +43,22 @@ function svd_pullback!( UΔU = fill!(similar(U, (r, r)), 0) VΔV = fill!(similar(Vᴴ, (r, r)), 0) if !iszerotangent(ΔU) - m == size(ΔU, 1) || throw(DimensionMismatch()) + m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)")) pU = size(ΔU, 2) - pU > r && throw(DimensionMismatch()) + pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)")) indU = axes(U, 2)[ind] - length(indU) == pU || throw(DimensionMismatch()) + length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))")) UΔUp = view(UΔU, :, indU) mul!(UΔUp, Ur', ΔU) # ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1) end if !iszerotangent(ΔVᴴ) - n == size(ΔVᴴ, 2) || throw(DimensionMismatch()) + n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)")) pV 
= size(ΔVᴴ, 1) - pV > r && throw(DimensionMismatch()) + pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)")) indV = axes(Vᴴ, 1)[ind] - length(indV) == pV || throw(DimensionMismatch()) + length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))")) VΔVp = view(VΔV, :, indV) mul!(VΔVp, Vᴴr, ΔVᴴ') # ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ @@ -82,7 +81,7 @@ function svd_pullback!( ΔS = diagview(ΔSmat) pS = length(ΔS) indS = axes(S, 1)[ind] - length(indS) == pS || throw(DimensionMismatch()) + length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))")) view(diagview(UdΔAV), indS) .+= real.(ΔS) end ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA diff --git a/test/ad_utils.jl b/test/ad_utils.jl new file mode 100644 index 00000000..4c03e50c --- /dev/null +++ b/test/ad_utils.jl @@ -0,0 +1,31 @@ +function remove_svdgauge_dependence!( + ΔU, ΔVᴴ, U, S, Vᴴ; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) + ) + gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true) + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 + mul!(ΔU, U, gaugepart, -1, 1) + return ΔU, ΔVᴴ +end +function remove_eiggauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) + ) + gaugepart = V' * ΔV + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 + mul!(ΔV, V / (V' * V), gaugepart, -1, 1) + return ΔV +end +function remove_eighgauge_dependence!( + ΔV, D, V; + degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) + ) + gaugepart = V' * ΔV + gaugepart = project_antihermitian!(gaugepart) + gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 
0 + mul!(ΔV, V, gaugepart, -1, 1) + return ΔV +end + +precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index ba3f0681..76eb84c8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! -function remove_svdgauge_dependence!( - ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S) - ) - gaugepart = U' * ΔU + Vᴴ * ΔVᴴ' - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0 - mul!(ΔU, U, gaugepart, -1, 1) - return ΔU, ΔVᴴ -end -function remove_eiggauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V / (V' * V), gaugepart, -1, 1) - return ΔV -end -function remove_eighgauge_dependence!( - ΔV, D, V; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D) - ) - gaugepart = V' * ΔV - gaugepart = (gaugepart - gaugepart') / 2 - gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0 - mul!(ΔV, V, gaugepart, -1, 1) - return ΔV -end - -precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32)) -precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64)) +include("ad_utils.jl") for f in ( diff --git a/test/mooncake.jl b/test/mooncake.jl new file mode 100644 index 00000000..554b35d9 --- /dev/null +++ b/test/mooncake.jl @@ -0,0 +1,560 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! 
+using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc +using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul! + +include("ad_utils.jl") + +make_mooncake_tangent(ΔAelem::T) where {T <: Complex} = Mooncake.build_tangent(T, real(ΔAelem), imag(ΔAelem)) +make_mooncake_tangent(ΔA::Matrix{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Vector{<:Real}) = ΔA +make_mooncake_tangent(ΔA::Matrix{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔA::Vector{T}) where {T <: Complex} = map(make_mooncake_tangent, ΔA) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Real} = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) +make_mooncake_tangent(ΔD::Diagonal{T}) where {T <: Complex} = Mooncake.build_tangent(typeof(ΔD), map(make_mooncake_tangent, diagview(ΔD))) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) + +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) + +ETs = (Float64, Float32, ComplexF64, ComplexF32) + +# no `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy)) + copy_pb!!(rdata) + return dA_copy +end + +# `alg` argument +function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_tangent(copy(ΔA)) + A_copy = copy(A) + dargs_copy = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f_c, Mooncake.NoFData()), Mooncake.CoDual(A_copy, dA_copy), Mooncake.CoDual(args, dargs_copy), Mooncake.CoDual(alg, Mooncake.NoFData())) + copy_pb!!(rdata) + return dA_copy +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace)) + end + inplace_pb!!(rdata) + return dA_inplace +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = make_mooncake_tangent(copy(ΔA)) + A_inplace = copy(A) + dargs_inplace = Δargs isa Tuple ? make_mooncake_fdata.(deepcopy(Δargs)) : make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + if has_handwritten_rule + inplace_out, inplace_pb!! 
= Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), Mooncake.CoDual(A_inplace, dA_inplace), Mooncake.CoDual(args, dargs_inplace), Mooncake.CoDual(alg, Mooncake.NoFData())) + end + inplace_pb!!(rdata) + return dA_inplace +end + +""" + test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(rng, f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + f_c = isnothing(alg) ? 
(A, args) -> f!(MatrixAlgebraKit.copy_input(f, A), args) : (A, args, alg) -> f!(MatrixAlgebraKit.copy_input(f, A), args, alg) + sig = isnothing(alg) ? Tuple{typeof(f_c), typeof(A), typeof(args)} : Tuple{typeof(f_c), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = randn(rng, eltype(A), size(A)) + + dA_copy = _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_inplace = _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + return +end + +@timedtestset "QR AD Rules with eltype $T" for T in ETs + rng = StableRNG(12345) + m = 19 + @testset "size ($m, $n)" for n in (17, m, 23) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) + minmn = min(m, n) + @testset for alg in ( + LAPACK_HouseholderQR(), + LAPACK_HouseholderQR(; positive = true), + ) + @testset "qr_compact" begin + QR = qr_compact(A, alg) + Q = randn(rng, T, m, minmn) + R = randn(rng, T, minmn, n) + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) + end + @testset "qr_null" begin + Q, R = qr_compact(A, alg) + ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) + N = qr_null(A, alg) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) + test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) + end + @testset "qr_full" begin + Q, R = qr_full(A, alg) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) + mul!(ΔQ2, Q1, Q1' * ΔQ2) + ΔR = randn(rng, T, m, n) + dQ = 
make_mooncake_tangent(copy(ΔQ))
+                dR = make_mooncake_tangent(copy(ΔR))
+                dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
+                Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg)
+            end
+            @testset "qr_compact - rank-deficient A" begin
+                r = minmn - 5
+                Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
+                Q, R = qr_compact(Ard, alg)
+                QR = (Q, R)
+                ΔQ = randn(rng, T, m, minmn)
+                Q1 = view(Q, 1:m, 1:r)
+                Q2 = view(Q, 1:m, (r + 1):minmn)
+                ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn)
+                ΔQ2 .= 0
+                ΔR = randn(rng, T, minmn, n)
+                view(ΔR, (r + 1):minmn, :) .= 0
+                dQ = make_mooncake_tangent(copy(ΔQ))
+                dR = make_mooncake_tangent(copy(ΔR))
+                dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR)
+                Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg)
+            end
+        end
+    end
+end
+
+# LQ: checks the reverse-mode rules of `lq_compact`, `lq_null` and `lq_full` with
+# `Mooncake.TestUtils.test_rule`, and compares the in-place and out-of-place pullbacks
+# with `test_pullbacks_match`, for tall (n = 17), square and wide (n = 23) inputs.
+@timedtestset "LQ AD Rules with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        minmn = min(m, n)
+        @testset for alg in (
+                LAPACK_HouseholderLQ(),
+                LAPACK_HouseholderLQ(; positive = true),
+            )
+            @testset "lq_compact" begin
+                L, Q = lq_compact(A, alg)
+                Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg)
+            end
+            @testset "lq_null" begin
+                L, Q = lq_compact(A, alg)
+                # the output tangent is built as (random matrix) * Q, i.e. its rows are
+                # linear combinations of the rows of Q
+                ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q
+                Nᴴ = randn(rng, T, max(0, n - minmn), n)
+                dNᴴ = make_mooncake_tangent(ΔNᴴ)
+                Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg)
+            end
+            @testset "lq_full" begin
+                L, Q = lq_full(A, alg)
+                Q1 = view(Q, 1:minmn, 1:n)
+                ΔQ = randn(rng, T, n, n)
+                ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n)
+                # overwrite the trailing rows of ΔQ with (ΔQ2 * Q1') * Q1
+                # NOTE(review): presumably removes the gauge-dependent part of the
+                # cotangent of the extra rows of Q -- confirm against `lq_pullback!`
+                mul!(ΔQ2, ΔQ2 * Q1', Q1)
+                ΔL = randn(rng, T, m, n)
+                dL = make_mooncake_tangent(ΔL)
+                dQ = make_mooncake_tangent(ΔQ)
+                dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
+                Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg)
+            end
+            @testset "lq_compact - rank-deficient A" begin
+                # Ard is a product of (m x r) and (r x n) factors, hence has rank at most r
+                r = minmn - 5
+                Ard = randn(rng, T, m, r) * randn(rng, T, r, n)
+                L, Q = lq_compact(Ard, alg)
+                ΔL = randn(rng, T, m, minmn)
+                ΔQ = randn(rng, T, minmn, n)
+                # zero the cotangent blocks belonging to rows/columns beyond index r
+                ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n)
+                ΔQ2 .= 0
+                view(ΔL, :, (r + 1):minmn) .= 0
+                dL = make_mooncake_tangent(ΔL)
+                dQ = make_mooncake_tangent(ΔQ)
+                dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ)
+                Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg)
+            end
+        end
+    end
+end
+
+# EIG: checks the reverse-mode rules of `eig_full`, `eig_vals` and `eig_trunc` on a
+# random square matrix; the cotangent of V is passed through
+# `remove_eiggauge_dependence!` before use.
+@timedtestset "EIG AD Rules with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    atol = rtol = m * m * precision(T)
+    A = randn(rng, T, m, m)
+    D, V = eig_full(A)
+    Ddiag = diagview(D)
+    ΔV = randn(rng, complex(T), m, m)
+    ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
+    ΔD = randn(rng, complex(T), m, m) # unused; kept so later draws from rng are unchanged
+    ΔD2 = Diagonal(randn(rng, complex(T), m))
+
+    dD = make_mooncake_tangent(ΔD2)
+    dV = make_mooncake_tangent(ΔV)
+    dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV)
+    # the output tangents built above are shared by every algorithm tested below
+    @testset for alg in (LAPACK_Simple(), LAPACK_Expert())
+        @testset "eig_full" begin
+            Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
+            test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg)
+        end
+        @testset "eig_vals" begin
+            Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+            test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg)
+        end
+        @testset "eig_trunc" begin
+            # rank truncation ordered by absolute value, for several target ranks
+            for r in 1:4:m
+                truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
+                ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
+                ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
+                ΔVtrunc = ΔV[:, ind]
+                dDtrunc = make_mooncake_tangent(ΔDtrunc)
+                dVtrunc = make_mooncake_tangent(ΔVtrunc)
+                # third tuple entry: a zero tangent of type real(T) for the third output
+                dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
+                Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
+                test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+            end
+            # rank truncation ordered by real part
+            truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real))
+            ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
+            ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
+            ΔVtrunc = ΔV[:, ind]
+            dDtrunc = make_mooncake_tangent(ΔDtrunc)
+            dVtrunc = make_mooncake_tangent(ΔVtrunc)
+            dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
+            Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
+            test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+        end
+    end
+end
+
+# Wrappers around the `eigh_*` functions that first replace `A` by `(A + A') / 2`.
+# NOTE(review): presumably needed so that the perturbed inputs used by `test_rule`
+# are still Hermitian when they reach `eigh_*` -- confirm.
+function copy_eigh_full(A, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_full(A, alg; kwargs...)
+end
+
+function copy_eigh_full!(A, DV, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_full!(A, DV, alg; kwargs...)
+end
+
+function copy_eigh_vals(A, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_vals(A, alg; kwargs...)
+end
+
+function copy_eigh_vals!(A, D, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_vals!(A, D, alg; kwargs...)
+end
+
+function copy_eigh_trunc(A, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_trunc(A, alg; kwargs...)
+end
+
+function copy_eigh_trunc!(A, DV, alg; kwargs...)
+    A = (A + A') / 2
+    return eigh_trunc!(A, DV, alg; kwargs...)
+end
+
+# `copy_input` for a wrapper forwards to the method of the function it wraps
+MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A)
+MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A)
+MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A)
+
+# EIGH: same structure as the EIG tests, but on the Hermitian matrix `A + A'` and going
+# through the symmetrizing `copy_eigh_*` wrappers defined above.
+@timedtestset "EIGH AD Rules with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    atol = rtol = m * m * precision(T)
+    A = randn(rng, T, m, m)
+    A = A + A'
+    D, V = eigh_full(A)
+    ΔV = randn(rng, T, m, m)
+    ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
+    ΔD = randn(rng, real(T), m, m) # unused; kept so later draws from rng are unchanged
+    ΔD2 = Diagonal(randn(rng, real(T), m))
+    dD = make_mooncake_tangent(ΔD2)
+    dV = make_mooncake_tangent(ΔV)
+    dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV)
+    Ddiag = diagview(D)
+    @testset for alg in (
+            LAPACK_QRIteration(),
+            LAPACK_DivideAndConquer(),
+            LAPACK_Bisection(),
+            LAPACK_MultipleRelativelyRobustRepresentations(),
+        )
+        @testset "eigh_full" begin
+            Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol)
+            test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
+        end
+        @testset "eigh_vals" begin
+            Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
+            test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
+        end
+        @testset "eigh_trunc" begin
+            # rank truncation ordered by absolute value, for several target ranks
+            for r in 1:4:m
+                truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs))
+                ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
+                ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
+                ΔVtrunc = ΔV[:, ind]
+                dDtrunc = make_mooncake_tangent(ΔDtrunc)
+                dVtrunc = make_mooncake_tangent(ΔVtrunc)
+                dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
+                Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
+                test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+            end
+            # tolerance-based truncation at half the largest absolute eigenvalue
+            truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
+            ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc)
+            ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
+            ΔVtrunc = ΔV[:, ind]
+            dDtrunc = make_mooncake_tangent(ΔDtrunc)
+            dVtrunc = make_mooncake_tangent(ΔVtrunc)
+            dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T)))
+            Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false)
+            test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+        end
+    end
+end
+
+# SVD: checks the reverse-mode rules of `svd_compact`, `svd_full`, `svd_vals` and
+# `svd_trunc`; the cotangents of U and Vᴴ are passed through
+# `remove_svdgauge_dependence!` before use.
+@timedtestset "SVD AD Rules with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        minmn = min(m, n)
+        @testset for alg in (
+                LAPACK_QRIteration(),
+                LAPACK_DivideAndConquer(),
+            )
+            @testset "svd_compact" begin
+                ΔU = randn(rng, T, m, minmn)
+                ΔS = randn(rng, real(T), minmn, minmn) # unused; kept so later draws from rng are unchanged
+                ΔS2 = Diagonal(randn(rng, real(T), minmn))
+                ΔVᴴ = randn(rng, T, minmn, n)
+                U, S, Vᴴ = svd_compact(A)
+                ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                dS = make_mooncake_tangent(ΔS2)
+                dU = make_mooncake_tangent(ΔU)
+                dVᴴ = make_mooncake_tangent(ΔVᴴ)
+                dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ)
+                Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg)
+            end
+            @testset "svd_full" begin
+                ΔU = randn(rng, T, m, minmn)
+                ΔS = randn(rng, real(T), minmn, minmn) # unused; kept so later draws from rng are unchanged
+                ΔS2 = Diagonal(randn(rng, real(T), minmn))
+                ΔVᴴ = randn(rng, T, minmn, n)
+                U, S, Vᴴ = svd_compact(A)
+                ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                # embed the compact cotangents into zero-initialised full-size arrays
+                ΔUfull = zeros(T, m, m)
+                ΔSfull = zeros(real(T), m, n)
+                ΔVᴴfull = zeros(T, n, n)
+                U, S, Vᴴ = svd_full(A)
+                view(ΔUfull, :, 1:minmn) .= ΔU
+                view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ
+                diagview(ΔSfull)[1:minmn] .= diagview(ΔS2)
+                dS = make_mooncake_tangent(ΔSfull)
+                dU = make_mooncake_tangent(ΔUfull)
+                dVᴴ = make_mooncake_tangent(ΔVᴴfull)
+                dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ)
+                Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg)
+            end
+            @testset "svd_vals" begin
+                Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol)
+                S = svd_vals(A, alg)
+                test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg)
+            end
+            @testset "svd_trunc" begin
+                # rank truncation for several target ranks
+                @testset for r in 1:4:minmn
+                    U, S, Vᴴ = svd_compact(A)
+                    ΔU = randn(rng, T, m, minmn)
+                    ΔS = randn(rng, real(T), minmn, minmn) # unused; kept so later draws from rng are unchanged
+                    ΔS2 = Diagonal(randn(rng, real(T), minmn))
+                    ΔVᴴ = randn(rng, T, minmn, n)
+                    ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                    truncalg = TruncatedAlgorithm(alg, truncrank(r))
+                    ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+                    ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+                    ΔUtrunc = ΔU[:, ind]
+                    ΔVᴴtrunc = ΔVᴴ[ind, :]
+                    dStrunc = make_mooncake_tangent(ΔStrunc)
+                    dUtrunc = make_mooncake_tangent(ΔUtrunc)
+                    dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
+                    # zero tangent of type real(T) for the fourth output
+                    ϵ = zero(real(T))
+                    dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
+                    Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
+                    test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+                end
+                # tolerance-based truncation at half the first diagonal entry of S
+                @testset "trunctol" begin
+                    U, S, Vᴴ = svd_compact(A)
+                    ΔU = randn(rng, T, m, minmn)
+                    ΔS = randn(rng, real(T), minmn, minmn) # unused; kept so later draws from rng are unchanged
+                    ΔS2 = Diagonal(randn(rng, real(T), minmn))
+                    ΔVᴴ = randn(rng, T, minmn, n)
+                    ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
+                    truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2))
+                    ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc)
+                    ΔStrunc = Diagonal(diagview(ΔS2)[ind])
+                    ΔUtrunc = ΔU[:, ind]
+                    ΔVᴴtrunc = ΔVᴴ[ind, :]
+                    dStrunc = make_mooncake_tangent(ΔStrunc)
+                    dUtrunc = make_mooncake_tangent(ΔUtrunc)
+                    dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc)
+                    ϵ = zero(real(T))
+                    dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ)
+                    Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol)
+                    test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T))))
+                end
+            end
+        end
+    end
+end
+
+# Polar: `left_polar` is tested when m >= n and `right_polar` when m <= n.
+@timedtestset "Polar AD Rules with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer()))
+            # Two independent checks rather than `if`/`elseif`: a square input (m == n)
+            # satisfies both conditions and must exercise both decompositions.
+            if m >= n
+                WP = left_polar(A, alg)
+                Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg)
+            end
+            if m <= n
+                PWᴴ = right_polar(A, alg)
+                Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol)
+                test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg)
+            end
+        end
+    end
+end
+
+# Wrappers that fix the `alg` keyword of the orth/null functions. They live at top level
+# so that `copy_input` can be extended by dispatching on their types, exactly like the
+# `copy_eigh_*` wrappers above. Writing `copy_input(left_orth_qr, A) = ...` instead would
+# declare an untyped argument that merely shares the wrapper's name, i.e. a catch-all
+# `copy_input(::Any, ::Any)` method.
+left_orth_qr(X) = left_orth(X; alg = :qr)
+left_orth_polar(X) = left_orth(X; alg = :polar)
+left_null_qr(X) = left_null(X; alg = :qr)
+right_orth_lq(X) = right_orth(X; alg = :lq)
+right_orth_polar(X) = right_orth(X; alg = :polar)
+right_null_lq(X) = right_null(X; alg = :lq)
+
+MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A)
+MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A)
+MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A)
+
+# Orth/null: checks `left_orth`, `right_orth`, `left_null` and `right_null`, both with
+# their default algorithm and with the `alg` keyword fixed through the wrappers above.
+@timedtestset "Orth and null with eltype $T" for T in ETs
+    rng = StableRNG(12345)
+    m = 19
+    @testset "size ($m, $n)" for n in (17, m, 23)
+        atol = rtol = m * n * precision(T)
+        A = randn(rng, T, m, n)
+        VC = left_orth(A)
+        CVᴴ = right_orth(A)
+        Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+        test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
+        Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+        test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
+
+        Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+        test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
+        if m >= n
+            Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+            test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...)))
+        end
+
+        # N and ΔN are both built as (first factor of the QR-based left_orth) * random
+        N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
+        ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n))
+        dN = make_mooncake_tangent(ΔN)
+        Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN)
+        test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN)
+
+        Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+        test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
+
+        if m <= n
+            Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false)
+            test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...)))
+        end
+
+        # Nᴴ and ΔNᴴ are both built as random * (second factor of the LQ-based right_orth)
+        Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
+        ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2]
+        dNᴴ = make_mooncake_tangent(ΔNᴴ)
+        Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ)
+        test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index ec255538..4b69a3dc 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -38,6 +38,9 @@ if !is_buildkite
     @safetestset "Image and Null Space" begin
         include("orthnull.jl")
     end
+    @safetestset "Mooncake" begin
+        include("mooncake.jl")
+    end
     @safetestset "ChainRules" begin
         include("chainrules.jl")
     end