Skip to content

Commit 0cab2e8

Browse files
kshyattKatharine Hyatt
authored andcommitted
Comments
1 parent 9b891df commit 0cab2e8

File tree

2 files changed

+12
-38
lines changed

2 files changed

+12
-38
lines changed

ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,27 +49,11 @@ MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GL
4949

5050
function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
5151
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
52-
D, V = DV
53-
if isnothing(D)
54-
D = Diagonal(eigval::AbstractVector{real(eltype(A))})
55-
else
56-
copyto!(D, Diagonal(eigval::AbstractVector{real(eltype(A))}))
57-
end
58-
if isnothing(V)
59-
V = eigvec::AbstractMatrix{eltype(A)}
60-
else
61-
copyto!(V, eigvec::AbstractMatrix{eltype(A)})
62-
end
63-
return D, V
52+
return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)}
6453
end
6554

6655
function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
67-
if isnothing(D)
68-
D = eigvals!(Hermitian(A); sortby = real)
69-
else
70-
copyto!(D, eigvals!(Hermitian(A); sortby = real))
71-
end
72-
return D
56+
return eigvals!(Hermitian(A); sortby = real)
7357
end
7458

7559
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}

test/testsuite/eigh.jl

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,16 @@ function test_eigh_full(
2626
)
2727
summary_str = testargs_summary(T, sz)
2828
return @testset "eigh_full! $summary_str" begin
29-
A = instantiate_matrix(T, sz)
30-
A = (A + A') / 2
29+
A = project_hermitian!(instantiate_matrix(T, sz))
3130
Ac = deepcopy(A)
3231

3332
D, V = @testinferred eigh_full(A)
3433
@test A * V V * D
3534
@test isunitary(V)
3635
@test all(isreal, D)
3736

38-
D2, V2 = eigh_full!(copy(A), (D, V))
39-
@test D2 === D
40-
@test V2 === V
37+
D2, V2 = eigh_full!(Ac, (D, V))
38+
@test A * V2 V2 * D2
4139

4240
D3 = @testinferred eigh_vals(A)
4341
@test D Diagonal(D3)
@@ -51,18 +49,16 @@ function test_eigh_full_algs(
5149
)
5250
summary_str = testargs_summary(T, sz)
5351
return @testset "eigh_full! algorithm $alg $summary_str" for alg in algs
54-
A = instantiate_matrix(T, sz)
55-
A = (A + A') / 2
52+
A = project_hermitian!(instantiate_matrix(T, sz))
5653
Ac = deepcopy(A)
5754

5855
D, V = @testinferred eigh_full(A; alg)
5956
@test A * V V * D
6057
@test isunitary(V)
6158
@test all(isreal, D)
6259

63-
D2, V2 = eigh_full!(copy(A), (D, V); alg)
64-
@test D2 === D
65-
@test V2 === V
60+
D2, V2 = eigh_full!(Ac, (D, V); alg)
61+
@test A * V2 V2 * D2
6662

6763
D3 = @testinferred eigh_vals(A; alg)
6864
@test D Diagonal(D3)
@@ -76,9 +72,7 @@ function test_eigh_trunc(
7672
)
7773
summary_str = testargs_summary(T, sz)
7874
return @testset "eigh_trunc! $summary_str" begin
79-
A = instantiate_matrix(T, sz)
80-
A = A * A'
81-
A = (A + A') / 2
75+
A = project_hermitian!(instantiate_matrix(T, sz))
8276
Ac = deepcopy(A)
8377
if !(T <: Diagonal)
8478

@@ -132,8 +126,7 @@ function test_eigh_trunc(
132126
Ddiag = similar(A, real(eltype(T)), m4)
133127
copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
134128
D = Diagonal(Ddiag)
135-
A = V * D * V'
136-
A = (A + A') / 2
129+
A = project_hermitian!(V * D * V')
137130
alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(2))
138131
D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg)
139132
@test diagview(D2) diagview(D)[1:2]
@@ -155,9 +148,7 @@ function test_eigh_trunc_algs(
155148
)
156149
summary_str = testargs_summary(T, sz)
157150
return @testset "eigh_trunc! algorithm $alg $summary_str" for alg in algs
158-
A = instantiate_matrix(T, sz)
159-
A = A * A'
160-
A = (A + A') / 2
151+
A = project_hermitian!(instantiate_matrix(T, sz))
161152
Ac = deepcopy(A)
162153

163154
m = size(A, 1)
@@ -172,8 +163,7 @@ function test_eigh_trunc_algs(
172163
Ddiag = similar(A, real(eltype(T)), m4)
173164
copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01])
174165
D = Diagonal(Ddiag)
175-
A = V * D * V'
176-
A = (A + A') / 2
166+
A = project_hermitian!(V * D * V')
177167
truncalg = TruncatedAlgorithm(alg, truncrank(2))
178168
D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg = truncalg)
179169
@test diagview(D2) diagview(D)[1:2]

0 commit comments

Comments
 (0)