From a480669a9203628d829d5e9f6252f9bfb36a5dd5 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 19 Dec 2025 01:53:58 +0100 Subject: [PATCH 1/3] stabilize eig(h) tests --- test/ad_utils.jl | 35 +++++++++++++++++++++++++++++++++++ test/chainrules.jl | 5 ++--- test/mooncake.jl | 7 +++---- 3 files changed, 40 insertions(+), 7 deletions(-) diff --git a/test/ad_utils.jl b/test/ad_utils.jl index 4c03e50c..442134ee 100644 --- a/test/ad_utils.jl +++ b/test/ad_utils.jl @@ -27,5 +27,40 @@ function remove_eighgauge_dependence!( mul!(ΔV, V, gaugepart, -1, 1) return ΔV end +function stabilize_eigvals!(D) + absD = abs.(D) + p = invperm(sortperm(absD)) # rank of abs(D) + for i in 1:(length(D) - 1) + if absD[i] == absD[i + 1] + p[p .>= p[i + 1]] .-= 1 + end + end + n = maximum(p) + radii = 1 / n * ((1:n) + rand(real(eltype(D)), n) / 2) + for i in 1:length(D) + D[i] = sign(D[i]) * radii[p[i]] + end + return D +end +function make_eig_matrix(rng, T, n) + A = randn(rng, T, n, n) + D, V = eig_full(A) + Ddiag = diagview(D) + stabilize_eigvals!(Ddiag) + if T <: Real + A = real(V * D * inv(V)) + else + A = V * D * inv(V) + end + return A +end +function make_eigh_matrix(rng, T, n) + A = project_hermitian!(randn(rng, T, n, n)) + D, V = eigh_full(A) + Ddiag = diagview(D) + stabilize_eigvals!(Ddiag) + A = project_hermitian!(V * D * V') + return A +end precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4be77380..a8b2fd3b 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -221,7 +221,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) + A = make_eig_matrix(rng, T, m) D, V = eig_full(A) Ddiag = diagview(D) ΔV = randn(rng, complex(T), m, m) @@ -297,8 +297,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' + A = make_eigh_matrix(rng, T, m) D, V = eigh_full(A) Ddiag = diagview(D) ΔV = randn(rng, T, m, m) diff --git a/test/mooncake.jl b/test/mooncake.jl index 0c1b5d56..760102b1 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -238,7 +238,7 @@ end rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) + A = make_eig_matrix(rng, T, m) DV = eig_full(A) D, V = DV Ddiag = diagview(D) @@ -347,9 +347,9 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb rng = StableRNG(12345) m = 19 atol = rtol = m * m * precision(T) - A = randn(rng, T, m, m) - A = A + A' + A = make_eigh_matrix(rng, T, m) D, V = eigh_full(A) + Ddiag = diagview(D) ΔV = randn(rng, T, m, m) ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol) ΔD = randn(rng, real(T), m, m) @@ -357,7 +357,6 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb dD = make_mooncake_tangent(ΔD2) dV = make_mooncake_tangent(ΔV) dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV) - Ddiag = diagview(D) @testset for alg in ( LAPACK_QRIteration(), #LAPACK_DivideAndConquer(), From 83a213001254f54a40631b433d20c726fed366a8 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 19 Dec 2025 11:48:51 +0100 Subject: [PATCH 2/3] simplify/clarify code --- test/ad_utils.jl | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/test/ad_utils.jl b/test/ad_utils.jl index 442134ee..7a7cf39a 100644 --- a/test/ad_utils.jl +++ b/test/ad_utils.jl @@ -27,16 +27,19 @@ function remove_eighgauge_dependence!( mul!(ΔV, V, gaugepart, -1, 1) return ΔV end -function stabilize_eigvals!(D) +function stabilize_eigvals!(D::AbstractVector) absD = abs.(D) p = invperm(sortperm(absD)) # rank of abs(D) + # account for exact degeneracies in absolute value when having complex conjugate pairs for i in 1:(length(D) - 1) - if absD[i] == absD[i + 1] - p[p .>= p[i + 1]] .-= 1 + if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially + p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones end end n = maximum(p) - radii = 1 / n * ((1:n) + rand(real(eltype(D)), n) / 2) + # rescale eigenvalues so that they lie on distinct radii in the complex plane + # that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n + radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n for i in 1:length(D) D[i] = sign(D[i]) * radii[p[i]] end @@ -45,22 +48,15 @@ end function make_eig_matrix(rng, T, n) A = randn(rng, T, n, n) D, V = eig_full(A) - Ddiag = diagview(D) - stabilize_eigvals!(Ddiag) - if T <: Real - A = real(V * D * inv(V)) - else - A = V * D * inv(V) - end - return A + stabilize_eigvals!(diagview(D)) + Ac = V * D * inv(V) + return (T <: Real) ? real(Ac) : Ac end function make_eigh_matrix(rng, T, n) A = project_hermitian!(randn(rng, T, n, n)) D, V = eigh_full(A) - Ddiag = diagview(D) - stabilize_eigvals!(Ddiag) - A = project_hermitian!(V * D * V') - return A + stabilize_eigvals!(diagview(D)) + return project_hermitian!(V * D * V') end precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T))) From a7dfef452247e9715f41bb43785bd9d5d466d5ae Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 19 Dec 2025 11:49:21 +0100 Subject: [PATCH 3/3] some random fix --- src/common/matrixproperties.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/matrixproperties.jl b/src/common/matrixproperties.jl index 25876b90..9180ae15 100644 --- a/src/common/matrixproperties.jl +++ b/src/common/matrixproperties.jl @@ -50,7 +50,7 @@ function is_left_isometric(A::AbstractMatrix; atol::Real = 0, rtol::Real = defau P = A' * A nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))` diagview(P) .-= 1 - return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` + return norm(P) ≤ max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)` end @doc """