Skip to content

Commit 52a26e7

Browse files
committed
Working eigh_trunc
1 parent 29f6e8e commit 52a26e7

File tree

3 files changed

+21
-21
lines changed

3 files changed

+21
-21
lines changed

src/common/safemethods.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,4 @@ sign_safe(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
1818
Compute the inverse of a number `a`, but return zero if `a` is smaller than `tol`.
1919
"""
2020
inv_safe(a::Number, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a)
21-
function inv_safe(a::ComplexF32, tol = defaulttol(a))
22-
str = string(a) # WHY does this fix the NaN issues??????
23-
return abs(a) < tol ? zero(a) : inv(a)
24-
end
21+
@noinline inv_safe(a::ComplexF32, tol = defaulttol(a)) = abs(a) < tol ? zero(a) : inv(a)

test/enzyme.jl

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,15 @@ end
171171
ΔV = remove_eiggauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
172172
ΔD = randn(rng, complex(T), m, m)
173173
ΔD2 = Diagonal(randn(rng, complex(T), m))
174+
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
174175
@testset for alg in (
175176
LAPACK_Simple(),
176177
#LAPACK_Expert(), # expensive on CI
177178
)
178179
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
179-
test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
180+
test_reverse(eig_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm)
180181
test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg)
181-
test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
182+
test_reverse(eig_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm)
182183
test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg)
183184
end
184185
@testset "eig_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
@@ -189,7 +190,7 @@ end
189190
Vtrunc = V[:, ind]
190191
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
191192
ΔVtrunc = ΔV[:, ind]
192-
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
193+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
193194
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
194195
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
195196
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
@@ -201,7 +202,7 @@ end
201202
Vtrunc = V[:, ind]
202203
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
203204
ΔVtrunc = ΔV[:, ind]
204-
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
205+
test_reverse(eig_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
205206
test_pullbacks_match(rng, eig_trunc_no_error!, eig_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg; ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
206207
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
207208
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
@@ -275,26 +276,26 @@ end
275276
m = 19
276277
atol = rtol = m * m * precision(T)
277278
A = make_eigh_matrix(rng, T, m)
278-
Ac = copy(A)
279-
A = (A + A') / 2
279+
#A = (A + A') / 2
280280
D, V = eigh_full(A)
281281
D2 = Diagonal(D)
282282
ΔV = randn(rng, T, m, m)
283283
ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
284284
ΔD = randn(rng, real(T), m, m)
285285
ΔD2 = Diagonal(randn(rng, real(T), m))
286+
fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1)
286287
@testset for alg in (
287288
LAPACK_QRIteration(),
288289
#LAPACK_DivideAndConquer(),
289290
#LAPACK_Bisection(),
290291
#LAPACK_MultipleRelativelyRobustRepresentations(), # expensive on CI
291292
)
292293
@testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
293-
test_reverse(copy_eigh_full, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
294-
test_reverse(copy_eigh_full!, RT, (copy(Ac), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)))
295-
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, Ac, (D, V), (ΔD2, ΔV), alg)
296-
test_reverse(copy_eigh_vals, RT, (Ac, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag))
297-
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, Ac, D.diag, ΔD2.diag, alg)
294+
test_reverse(copy_eigh_full, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm)
295+
test_reverse(copy_eigh_full!, RT, (copy(A), TA), ((D, V), TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = (copy(ΔD2), copy(ΔV)), fdm = fdm)
296+
test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg)
297+
test_reverse(copy_eigh_vals, RT, (A, TA); fkwargs = (alg = alg,), atol = atol, rtol = rtol, output_tangent = copy(ΔD2.diag), fdm = fdm)
298+
test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg)
298299
end
299300
@testset "eigh_trunc reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,)
300301
for r in 1:4:m
@@ -305,8 +306,8 @@ end
305306
Vtrunc = V[:, ind]
306307
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
307308
ΔVtrunc = ΔV[:, ind]
308-
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
309-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
309+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
310+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
310311
end
311312
Ddiag = diagview(D)
312313
truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2))
@@ -315,8 +316,8 @@ end
315316
Vtrunc = V[:, ind]
316317
ΔDtrunc = Diagonal(diagview(ΔD2)[ind])
317318
ΔVtrunc = ΔV[:, ind]
318-
test_reverse(copy_eigh_trunc_no_error, RT, (Ac, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc))
319-
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, Ac, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
319+
test_reverse(copy_eigh_trunc_no_error, RT, (A, TA); fkwargs = (alg = truncalg,), atol = atol, rtol = rtol, output_tangent = (ΔDtrunc, ΔVtrunc), fdm = fdm)
320+
test_pullbacks_match(rng, copy_eigh_trunc_no_error!, copy_eigh_trunc_no_error, A, (D, V), (ΔD2, ΔV), truncalg, ȳ = (ΔDtrunc, ΔVtrunc), return_act = RT)
320321
end
321322
end
322323
end

test/runtests.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ if !is_buildkite
3434
@safetestset "Image and Null Space" begin
3535
include("orthnull.jl")
3636
end
37-
@safetestset "Enzyme" begin
38-
include("enzyme.jl")
37+
if VERSION < v"1.12.0" # reconsider when Enzyme works on 1.12
38+
@safetestset "Enzyme" begin
39+
include("enzyme.jl")
40+
end
3941
end
4042
@safetestset "Mooncake" begin
4143
include("mooncake.jl")

0 commit comments

Comments
 (0)