@@ -171,14 +171,15 @@ end
171171 ΔV = remove_eiggauge_dependence! (ΔV, D, V; degeneracy_atol = atol)
172172 ΔD = randn (rng, complex (T), m, m)
173173 ΔD2 = Diagonal (randn (rng, complex (T), m))
174+ fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
174175 @testset for alg in (
175176 LAPACK_Simple (),
176177 # LAPACK_Expert(), # expensive on CI
177178 )
178179 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
179- test_reverse (eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
180+ test_reverse (eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)), fdm = fdm )
180181 test_pullbacks_match (rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg)
181- test_reverse (eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag))
182+ test_reverse (eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag), fdm = fdm )
182183 test_pullbacks_match (rng, eig_vals!, eig_vals, A, D. diag, ΔD2. diag, alg)
183184 end
184185 @testset " eig_trunc reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
189190 Vtrunc = V[:, ind]
190191 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
191192 ΔVtrunc = ΔV[:, ind]
192- test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
193+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
193194 test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
194195 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
195196 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
201202 Vtrunc = V[:, ind]
202203 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
203204 ΔVtrunc = ΔV[:, ind]
204- test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
205+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
205206 test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
206207 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
207208 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
@@ -275,26 +276,26 @@ end
275276 m = 19
276277 atol = rtol = m * m * precision (T)
277278 A = make_eigh_matrix (rng, T, m)
278- Ac = copy (A)
279- A = (A + A' ) / 2
279+ # A = (A + A') / 2
280280 D, V = eigh_full (A)
281281 D2 = Diagonal (D)
282282 ΔV = randn (rng, T, m, m)
283283 ΔV = remove_eighgauge_dependence! (ΔV, D, V; degeneracy_atol = atol)
284284 ΔD = randn (rng, real (T), m, m)
285285 ΔD2 = Diagonal (randn (rng, real (T), m))
286+ fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range = 1.0e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
286287 @testset for alg in (
287288 LAPACK_QRIteration (),
288289 # LAPACK_DivideAndConquer(),
289290 # LAPACK_Bisection(),
290291 # LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
291292 )
292293 @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
293- test_reverse (copy_eigh_full, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
294- test_reverse (copy_eigh_full!, RT, (copy (Ac ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)))
295- test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, Ac , (D, V), (ΔD2, ΔV), alg)
296- test_reverse (copy_eigh_vals, RT, (Ac , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag))
297- test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, Ac , D. diag, ΔD2. diag, alg)
294+ test_reverse (copy_eigh_full, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)), fdm = fdm )
295+ test_reverse (copy_eigh_full!, RT, (copy (A ), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy (ΔD2), copy (ΔV)), fdm = fdm )
296+ test_pullbacks_match (rng, copy_eigh_full!, copy_eigh_full, A , (D, V), (ΔD2, ΔV), alg)
297+ test_reverse (copy_eigh_vals, RT, (A , TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy (ΔD2. diag), fdm = fdm )
298+ test_pullbacks_match (rng, copy_eigh_vals!, copy_eigh_vals, A , D. diag, ΔD2. diag, alg)
298299 end
299300 @testset " eigh_trunc reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
300301 for r in 1 : 4 : m
305306 Vtrunc = V[:, ind]
306307 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
307308 ΔVtrunc = ΔV[:, ind]
308- test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
309+ test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
310+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
310311 end
311312 Ddiag = diagview (D)
312313 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
315316 Vtrunc = V[:, ind]
316317 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
317318 ΔVtrunc = ΔV[:, ind]
318- test_reverse (copy_eigh_trunc_no_error, RT, (Ac , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319- test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
319+ test_reverse (copy_eigh_trunc_no_error, RT, (A , TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm )
320+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A , (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
320321 end
321322 end
322323end
0 commit comments