@@ -7,9 +7,7 @@ using Enzyme, EnzymeTestUtils
77using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
88using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
99
10- is_ci = get (ENV , " CI" , " false" ) == " true"
11-
12- ETs = is_ci ? (Float64, Float32) : (Float64, Float32, ComplexF32, ComplexF64) # Enzyme/#2631
10+ ETs = (Float32, ComplexF64)
1311include (" ad_utils.jl" )
1412function test_pullbacks_match (rng, f!, f, A, args, Δargs, alg = nothing ; ȳ = copy .(Δargs), return_act = Duplicated)
1513 ΔA = randn (rng, eltype (A), size (A)... )
188186 Vtrunc = V[:, ind]
189187 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
190188 ΔVtrunc = ΔV[:, ind]
191- # broken due to Enzyme
192- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
193- # broken due to Enzyme
194- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
189+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
190+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
195191 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
196192 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
197193 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
202198 Vtrunc = V[:, ind]
203199 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
204200 ΔVtrunc = ΔV[:, ind]
205- # broken due to Enzyme
206- # test_reverse(eig_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
207- # broken due to Enzyme
208- # test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
201+ test_reverse (eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
202+ test_pullbacks_match (rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
209203 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
210204 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
211205 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
@@ -253,24 +247,24 @@ function copy_eigh_vals!(A, D, alg; kwargs...)
253247 return eigh_vals! (A, D, alg; kwargs... )
254248end
255249
256- function copy_eigh_trunc (A; kwargs... )
250+ function copy_eigh_trunc_no_error (A; kwargs... )
257251 A = (A + A' ) / 2
258- return eigh_trunc (A; kwargs... )
252+ return eigh_trunc_no_error (A; kwargs... )
259253end
260254
261- function copy_eigh_trunc ! (A, DV; kwargs... )
255+ function copy_eigh_trunc_no_error ! (A, DV; kwargs... )
262256 A = (A + A' ) / 2
263- return eigh_trunc ! (A, DV; kwargs... )
257+ return eigh_trunc_no_error ! (A, DV; kwargs... )
264258end
265259
266- function copy_eigh_trunc (A, alg; kwargs... )
260+ function copy_eigh_trunc_no_error (A, alg; kwargs... )
267261 A = (A + A' ) / 2
268- return eigh_trunc (A ; kwargs... )
262+ return eigh_trunc_no_error (A, alg ; kwargs... )
269263end
270264
271- function copy_eigh_trunc ! (A, DV, alg; kwargs... )
265+ function copy_eigh_trunc_no_error ! (A, DV, alg; kwargs... )
272266 A = (A + A' ) / 2
273- return eigh_trunc ! (A, DV; kwargs... )
267+ return eigh_trunc_no_error ! (A, DV, alg ; kwargs... )
274268end
275269
276270@timedtestset " EIGH AD Rules with eltype $T " for T in ETs
307301 Vtrunc = V[:, ind]
308302 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
309303 ΔVtrunc = ΔV[:, ind]
310- # broken due to Enzyme
311- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
312- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
304+ test_reverse (copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
305+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
313306 end
314307 Ddiag = diagview (D)
315308 truncalg = TruncatedAlgorithm (alg, trunctol (; atol = maximum (abs, Ddiag) / 2 ))
318311 Vtrunc = V[:, ind]
319312 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
320313 ΔVtrunc = ΔV[:, ind]
321- # broken due to Enzyme
322- # test_reverse(copy_eigh_trunc, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc, zero(real(T))))
323- # test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg, ȳ=(ΔDtrunc, ΔVtrunc, zero(real(T))), return_act=RT)
314+ test_reverse (copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
315+ test_pullbacks_match (rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
324316 end
325317 end
326318end
@@ -373,21 +365,21 @@ end
373365 @testset " svd_trunc reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
374366 for r in 1 : 4 : minmn
375367 U, S, Vᴴ = svd_compact (A)
376- ΔU = randn (rng, T, m, minmn)
377- ΔS = randn (rng, real (T), minmn, minmn)
368+ ΔU = randn (rng, T, m, minmn)
369+ ΔS = randn (rng, real (T), minmn, minmn)
378370 ΔS2 = Diagonal (randn (rng, real (T), minmn))
379371 ΔVᴴ = randn (rng, T, minmn, n)
380372 ΔU, ΔVᴴ = remove_svdgauge_dependence! (ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol)
381373 truncalg = TruncatedAlgorithm (alg, truncrank (r))
382374 ind = MatrixAlgebraKit. findtruncated (diagview (S), truncalg. trunc)
383- Strunc = Diagonal (diagview (S)[ind])
384- Utrunc = U[:, ind]
385- Vᴴtrunc = Vᴴ[ind, :]
386- ΔStrunc = Diagonal (diagview (ΔS2)[ind])
387- ΔUtrunc = ΔU[:, ind]
375+ Strunc = Diagonal (diagview (S)[ind])
376+ Utrunc = U[:, ind]
377+ Vᴴtrunc = Vᴴ[ind, :]
378+ ΔStrunc = Diagonal (diagview (ΔS2)[ind])
379+ ΔUtrunc = ΔU[:, ind]
388380 ΔVᴴtrunc = ΔVᴴ[ind, :]
389- test_reverse (svd_trunc , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
390- test_pullbacks_match (rng, svd_trunc !, svd_trunc , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
381+ test_reverse (svd_trunc_no_error , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
382+ test_pullbacks_match (rng, svd_trunc_no_error !, svd_trunc_no_error , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
391383 end
392384 U, S, Vᴴ = svd_compact (A)
393385 ΔU = randn (rng, T, m, minmn)
403395 ΔStrunc = Diagonal (diagview (ΔS2)[ind])
404396 ΔUtrunc = ΔU[:, ind]
405397 ΔVᴴtrunc = ΔVᴴ[ind, :]
406- test_reverse (svd_trunc , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
407- test_pullbacks_match (rng, svd_trunc !, svd_trunc , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act= RT)
398+ test_reverse (svd_trunc_no_error , RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), fdm = fdm)
399+ test_pullbacks_match (rng, svd_trunc_no_error !, svd_trunc_no_error , A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg, ȳ = (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), return_act = RT)
408400 end
409401 end
410402 end
0 commit comments