diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl deleted file mode 100644 index 4d23f128..00000000 --- a/test/amd/eigh.jl +++ /dev/null @@ -1,105 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using AMDGPU - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - ROCSOLVER_DivideAndConquer(), - ROCSOLVER_Jacobi(), - ROCSOLVER_Bisection(), - ROCSOLVER_QRIteration(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end - -#=@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (ROCSOLVER_QRIteration(), - ROCSOLVER_DivideAndConquer(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - - trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(ROCArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) -end=# - -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(ROCArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl deleted file mode 100644 index a8171615..00000000 --- a/test/cuda/eigh.jl +++ /dev/null @@ -1,118 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using CUDA - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_DivideAndConquer(), CUSOLVER_Jacobi()) - A = CuArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end -#= #TODO mul! -@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_QRIteration(), - CUSOLVER_DivideAndConquer(), - ) - A = CuArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(CuArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end -=# -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(CuArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(CuArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/eigh.jl b/test/eigh.jl index 3b711c5b..8766ccc0 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -3,139 +3,59 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_MultipleRelativelyRobustRepresentations(), - LAPACK_DivideAndConquer(), - LAPACK_QRIteration(), - LAPACK_Bisection(), - ) - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) +GenericFloats = (BigFloat, Complex{BigFloat}) + +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 54 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if T ∈ BLASFloats + if CUDA.functional() + CUSOLVER_EIGH_ALGS = ( + CUSOLVER_Jacobi(), + CUSOLVER_DivideAndConquer(), + ) + TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) + TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false) + TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + end + if AMDGPU.functional() + ROCSOLVER_EIGH_ALGS = ( + ROCSOLVER_Jacobi(), + ROCSOLVER_DivideAndConquer(), + ROCSOLVER_QRIteration(), + ROCSOLVER_Bisection(), + ) + TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) + TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false) + TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) + end end -end - -@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_QRIteration(), - LAPACK_Bisection(), - LAPACK_DivideAndConquer(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D4, V4 = @constinferred eigh_trunc_no_error(A; alg, trunc) - @test length(diagview(D4)) == r - @test A * V4 ≈ V4 * D4 - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 + if !is_buildkite + TestSuite.test_eigh(T, (m, m)) + if T ∈ BLASFloats + LAPACK_EIGH_ALGS = ( + LAPACK_MultipleRelativelyRobustRepresentations(), + LAPACK_DivideAndConquer(), + LAPACK_QRIteration(), + LAPACK_Bisection(), + ) + TestSuite.test_eigh_algs(T, (m, m), LAPACK_EIGH_ALGS) + elseif T ∈ GenericFloats + GLA_EIGH_ALGS = (GLA_QRIteration(),) + TestSuite.test_eigh_algs(T, (m, m), GLA_EIGH_ALGS) + end + AT = Diagonal{T, Vector{T}} + TestSuite.test_eigh(AT, m) + TestSuite.test_eigh_algs(AT, m, (DiagonalAlgorithm(),)) end end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D4, V4 = @constinferred eigh_trunc_no_error(A; alg) - @test diagview(D4) ≈ diagview(D)[1:2] -end - -@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol - - A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(3)) - D3, V3 = @constinferred eigh_trunc_no_error(A3; alg) - @test diagview(D3) ≈ diagview(A3)[1:3] -end diff --git a/test/genericlinearalgebra/eigh.jl b/test/genericlinearalgebra/eigh.jl deleted file mode 100644 index 7e602026..00000000 --- a/test/genericlinearalgebra/eigh.jl +++ /dev/null @@ -1,93 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -using GenericLinearAlgebra - -const eltypes = (BigFloat, Complex{BigFloat}) - -@testset "eigh_full! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 ≈ D - @test V2 ≈ V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) -end - -@testset "eigh_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - Dfull, Vfull = eigh_full(A; alg) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end diff --git a/test/runtests.jl b/test/runtests.jl index 062801c0..baeb77bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,9 +13,6 @@ if !is_buildkite @safetestset "Singular Value Decomposition" begin include("svd.jl") end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("eigh.jl") - end @safetestset "Generalized Eigenvalue Decomposition" begin include("gen_eig.jl") end @@ -45,9 +42,6 @@ if !is_buildkite @safetestset "Singular Value Decomposition" begin include("genericlinearalgebra/svd.jl") end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") - end end @safetestset "QR / LQ Decomposition" begin @@ -66,15 +60,15 @@ end @safetestset "General Eigenvalue Decomposition" begin include("eig.jl") end +@safetestset "Hermitian Eigenvalue Decomposition" begin + include("eigh.jl") +end using CUDA if CUDA.functional() @safetestset "CUDA SVD" begin include("cuda/svd.jl") end - @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin - include("cuda/eigh.jl") - end @safetestset "CUDA Image and Null Space" begin include("cuda/orthnull.jl") end @@ -85,9 +79,6 @@ if AMDGPU.functional() @safetestset "AMDGPU SVD" begin include("amd/svd.jl") end - @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin - include("amd/eigh.jl") - end @safetestset "AMDGPU Image and Null Space" begin include("amd/orthnull.jl") end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 6e13567c..25da777a 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -75,5 +75,6 @@ include("polar.jl") include("projections.jl") include("schur.jl") include("eig.jl") +include("eigh.jl") end diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl new file mode 100644 index 00000000..df6e4d6e --- /dev/null +++ b/test/testsuite/eigh.jl @@ -0,0 +1,186 @@ +using TestExtras +using GenericLinearAlgebra +using MatrixAlgebraKit: TruncatedAlgorithm +using LinearAlgebra: I, opnorm + +function test_eigh(T::Type, sz; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eigh $summary_str" begin + test_eigh_full(T, sz; kwargs...) + test_trunc && test_eigh_trunc(T, sz; kwargs...) + end +end + +function test_eigh_algs(T::Type, sz, algs; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eigh algorithms $summary_str" begin + test_eigh_full_algs(T, sz, algs; kwargs...) + test_trunc && test_eigh_trunc_algs(T, sz, algs; kwargs...) + end +end + +function test_eigh_full( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_full! $summary_str" begin + A = project_hermitian!(instantiate_matrix(T, sz)) + Ac = deepcopy(A) + + D, V = @testinferred eigh_full(A) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(Ac, (D, V)) + @test A * V2 ≈ V2 * D2 + + D3 = @testinferred eigh_vals(A) + @test D ≈ Diagonal(D3) + end +end + +function test_eigh_full_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_full! algorithm $alg $summary_str" for alg in algs + A = project_hermitian!(instantiate_matrix(T, sz)) + Ac = deepcopy(A) + + D, V = @testinferred eigh_full(A; alg) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(Ac, (D, V); alg) + @test A * V2 ≈ V2 * D2 + + D3 = @testinferred eigh_vals(A; alg) + @test D ≈ Diagonal(D3) + end +end + +function test_eigh_trunc( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_trunc! $summary_str" begin + A = instantiate_matrix(T, sz) + A = A * A' + A = project_hermitian!(A) + Ac = deepcopy(A) + if !(T <: Diagonal) + + m = size(A, 1) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + atol = sqrt(eps(real(eltype(T)))) + # truncrank + D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r)) + @test length(diagview(D1)) == r + @test isisometric(V1) + @test A * V1 ≈ V1 * D1 + @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol + trunc = trunctol(; atol = s * D₀[r + 1]) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D2)) == r + @test isisometric(V2) + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + #truncerror + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D4, V4 = @testinferred eigh_trunc_no_error(A; trunc) + @test length(diagview(D4)) == r + @test A * V4 ≈ V4 * D4 + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + @test V1 * (V1' * V3) ≈ V3 + @test V3 * (V3' * V1) ≈ V1 + @test V4 * (V4' * V1) ≈ V1 + end + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + smallA = randn!(similar(A, (m4, m4))) + V = T <: Diagonal ? I : qr_compact(smallA)[1] + Ddiag = similar(A, real(eltype(T)), m4) + copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + D = Diagonal(Ddiag) + A = project_hermitian!(V * D * V') + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(2)) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + end + end +end + +function test_eigh_trunc_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_trunc! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + A = A * A' + A = project_hermitian!(A) + Ac = deepcopy(A) + + m = size(A, 1) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + # truncrank + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + smallA = randn!(similar(A, (m4, m4))) + V = T <: Diagonal ? I : qr_compact(smallA)[1] + Ddiag = similar(A, real(eltype(T)), m4) + copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + D = Diagonal(Ddiag) + A = project_hermitian!(V * D * V') + truncalg = TruncatedAlgorithm(alg, truncrank(2)) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg = truncalg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg = truncalg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + truncalg = TruncatedAlgorithm(alg, truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg = truncalg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + + truncalg = TruncatedAlgorithm(alg, truncerror(; atol = 0.2)) + D4, V4 = @testinferred eigh_trunc_no_error(A; alg = truncalg) + @test diagview(D4) ≈ diagview(D)[1:2] + end +end