Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/common/matrixproperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function is_left_isometric(A::AbstractMatrix; atol::Real = 0, rtol::Real = defau
P = A' * A
nP = norm(P) # isapprox would use `rtol * max(norm(P), norm(I))`
diagview(P) .-= 1
return norm(P) <= max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
return norm(P) max(atol, rtol * nP) # assume that the norm of I is `sqrt(n)`
end

@doc """
Expand Down
31 changes: 31 additions & 0 deletions test/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,36 @@ function remove_eighgauge_dependence!(
mul!(ΔV, V, gaugepart, -1, 1)
return ΔV
end
function stabilize_eigvals!(D::AbstractVector)
absD = abs.(D)
p = invperm(sortperm(absD)) # rank of abs(D)
# account for exact degeneracies in absolute value when having complex conjugate pairs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! This will help us a lot if we have to revisit in 6 months, I think.

for i in 1:(length(D) - 1)
if absD[i] == absD[i + 1] # conjugate pairs will appear sequentially
p[p .>= p[i + 1]] .-= 1 # lower the rank of all higher ones
end
end
n = maximum(p)
# rescale eigenvalues so that they lie on distinct radii in the complex plane
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
for i in 1:length(D)
D[i] = sign(D[i]) * radii[p[i]]
end
return D
end
function make_eig_matrix(rng, T, n)
A = randn(rng, T, n, n)
D, V = eig_full(A)
stabilize_eigvals!(diagview(D))
Ac = V * D * inv(V)
return (T <: Real) ? real(Ac) : Ac
end
function make_eigh_matrix(rng, T, n)
A = project_hermitian!(randn(rng, T, n, n))
D, V = eigh_full(A)
stabilize_eigvals!(diagview(D))
return project_hermitian!(V * D * V')
end

precision(::Type{T}) where {T <: Number} = sqrt(eps(real(T)))
5 changes: 2 additions & 3 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ end
rng = StableRNG(12345)
m = 19
atol = rtol = m * m * precision(T)
A = randn(rng, T, m, m)
A = make_eig_matrix(rng, T, m)
D, V = eig_full(A)
Ddiag = diagview(D)
ΔV = randn(rng, complex(T), m, m)
Expand Down Expand Up @@ -297,8 +297,7 @@ end
rng = StableRNG(12345)
m = 19
atol = rtol = m * m * precision(T)
A = randn(rng, T, m, m)
A = A + A'
A = make_eigh_matrix(rng, T, m)
D, V = eigh_full(A)
Ddiag = diagview(D)
ΔV = randn(rng, T, m, m)
Expand Down
7 changes: 3 additions & 4 deletions test/mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ end
rng = StableRNG(12345)
m = 19
atol = rtol = m * m * precision(T)
A = randn(rng, T, m, m)
A = make_eig_matrix(rng, T, m)
DV = eig_full(A)
D, V = DV
Ddiag = diagview(D)
Expand Down Expand Up @@ -347,17 +347,16 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc_no_error), A) = MatrixAlgeb
rng = StableRNG(12345)
m = 19
atol = rtol = m * m * precision(T)
A = randn(rng, T, m, m)
A = A + A'
A = make_eigh_matrix(rng, T, m)
D, V = eigh_full(A)
Ddiag = diagview(D)
ΔV = randn(rng, T, m, m)
ΔV = remove_eighgauge_dependence!(ΔV, D, V; degeneracy_atol = atol)
ΔD = randn(rng, real(T), m, m)
ΔD2 = Diagonal(randn(rng, real(T), m))
dD = make_mooncake_tangent(ΔD2)
dV = make_mooncake_tangent(ΔV)
dDV = Mooncake.build_tangent(typeof((ΔD2, ΔV)), dD, dV)
Ddiag = diagview(D)
@testset for alg in (
LAPACK_QRIteration(),
#LAPACK_DivideAndConquer(),
Expand Down