diff --git a/src/common/matrixproperties.jl b/src/common/matrixproperties.jl index 25876b90..9180ae15 100644 --- a/src/common/matrixproperties.jl +++ b/src/common/matrixproperties.jl @@ -50,7 +50,7 @@ function is_left_isometric(A::AbstractMatrix; atol::Real = 0, rtol::Real = defau P = A' * A nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))` diagview(P) .-= 1 - return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` + return norm(P) ≤ max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` end @doc """ diff --git a/test/ad_utils.jl b/test/ad_utils.jl index 4c03e50c..7a7cf39a 100644 --- a/test/ad_utils.jl +++ b/test/ad_utils.jl @@ -27,5 +27,36 @@ function remove_eighgauge_dependence!( mul!(ΔV, V, gaugepart, -1, 1) return ΔV end +function stabilize_eigvals!(D::AbstractVector) + absD = abs.(D) + p = invperm(sortperm(absD)) # rank of abs(D) + # account for exact degeneracies in absolute value when having complex conjugate pairs + for i in 1:(length(D) - 1) + if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially + p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones + end + end + n = maximum(p) + # rescale eigenvalues so that they lie on distinct radii in the complex plane + # that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n + radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n + for i in 1:length(D) + D[i] = sign(D[i]) * radii[p[i]] + end + return D +end +function make_eig_matrix(rng, T, n) + A = randn(rng, T, n, n) + D, V = eig_full(A) + stabilize_eigvals!(diagview(D)) + Ac = V * D * inv(V) + return (T <: Real) ? real(Ac) : Ac +end +function make_eigh_matrix(rng, T, n) + A = project_hermitian!(randn(rng, T, n, n)) + D, V = eigh_full(A) + stabilize_eigvals!(diagview(D)) + return project_hermitian!(V * D * V') +end precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4be77380..a8b2fd3b 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -221,7 +221,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) + A = make_eig_matrix(rng, T, m) D, V = eig_full(A) Ddiag = diagview(D) ΔV = randn(rng, complex(T), m, m) @@ -297,8 +297,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' + A = make_eigh_matrix(rng, T, m) D, V = eigh_full(A) Ddiag = diagview(D) ΔV = randn(rng, T, m, m) diff --git a/test/mooncake.jl b/test/mooncake.jl index 0c1b5d56..760102b1 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -238,7 +238,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) + A = make_eig_matrix(rng, T, m) DV = eig_full(A) D, V = DV Ddiag = diagview(D) @@ -347,9 +347,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' + A = make_eigh_matrix(rng, T, m) D, V = eigh_full(A) + Ddiag = diagview(D) ΔV = randn(rng, T, m, m) ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) ΔD = randn(rng, real(T), m, m) @@ -357,7 +357,6 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb dD = make_mooncake_tangent(ΔD2) dV = make_mooncake_tangent(ΔV) dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - Ddiag = diagview(D) @testset for alg in ( LAPACK_QRIteration(), #LAPACK_DivideAndConquer(),