diff --git a/test/mooncake.jl b/test/mooncake.jl index c3917847..0c1b5d56 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -328,9 +328,20 @@ function copy_eigh_trunc!(A, DV, alg; kwargs...) return eigh_trunc!(A, DV, alg; kwargs...) end +function copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) +end + MatrixAlgebraKit.copy_input(::typeof(copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) MatrixAlgebraKit.copy_input(::typeof(copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) @timedtestset "EIGH AD Rules with eltype $T" for T in ETs rng = StableRNG(12345) @@ -374,6 +385,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -386,6 +400,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) + dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc)), dDtrunc, dVtrunc) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg) end end end