From 1fcf21091bd2f56ecd5f6fee10d9375eb66de0df Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 14:05:13 +0100 Subject: [PATCH 1/7] Use TestSuite for eigh --- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 20 +- test/amd/eigh.jl | 105 ---------- test/cuda/eigh.jl | 118 ----------- test/eigh.jl | 184 +++++------------ test/genericlinearalgebra/eigh.jl | 93 --------- test/runtests.jl | 15 +- test/testsuite/TestSuite.jl | 1 + test/testsuite/eigh.jl | 192 ++++++++++++++++++ 8 files changed, 267 insertions(+), 461 deletions(-) delete mode 100644 test/amd/eigh.jl delete mode 100644 test/cuda/eigh.jl delete mode 100644 test/genericlinearalgebra/eigh.jl create mode 100644 test/testsuite/eigh.jl diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index e4637875..e319ddca 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -49,11 +49,27 @@ MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GL function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) eigval, eigvec = eigen!(Hermitian(A); sortby = real) - return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)} + D, V = DV + if isnothing(D) + D = Diagonal(eigval::AbstractVector{real(eltype(A))}) + else + copyto!(D, Diagonal(eigval::AbstractVector{real(eltype(A))})) + end + if isnothing(V) + V = eigvec::AbstractMatrix{eltype(A)} + else + copyto!(V, eigvec::AbstractMatrix{eltype(A)}) + end + return D, V end function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) - return eigvals!(Hermitian(A); sortby = real) + if isnothing(D) + D = eigvals!(Hermitian(A); sortby = real) + else + copyto!(D, eigvals!(Hermitian(A); sortby = real)) + end + return D end function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} diff --git a/test/amd/eigh.jl b/test/amd/eigh.jl deleted file mode 100644 index 4d23f128..00000000 --- a/test/amd/eigh.jl +++ /dev/null @@ -1,105 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using AMDGPU - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - ROCSOLVER_DivideAndConquer(), - ROCSOLVER_Jacobi(), - ROCSOLVER_Bisection(), - ROCSOLVER_QRIteration(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end - -#=@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (ROCSOLVER_QRIteration(), - ROCSOLVER_DivideAndConquer(), - ) - A = ROCArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - - trunc = trunctol(; atol=s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(ROCArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(ROCSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) -end=# - -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(ROCArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(ROCArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/cuda/eigh.jl b/test/cuda/eigh.jl deleted file mode 100644 index a8171615..00000000 --- a/test/cuda/eigh.jl +++ /dev/null @@ -1,118 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview -using CUDA - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_DivideAndConquer(), CUSOLVER_Jacobi()) - A = CuArray(randn(rng, T, m, m)) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test parent(D) ≈ D3 - end -end -#= #TODO mul! -@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in (CUSOLVER_QRIteration(), - CUSOLVER_DivideAndConquer(), - ) - A = CuArray(randn(rng, T, m, m)) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc=truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 - end -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - V = qr_compact(CuArray(randn(rng, T, m, m)))[1] - D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) - @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(CUSOLVER_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end -=# -@testset "eigh for Diagonal{$T}" for T in BLASFloats - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(CuArray(Ad)) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - # TODO partialsortperm - #=A2 = Diagonal(CuArray(T[0.9, 0.3, 0.1, 0.01])) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol=# -end diff --git a/test/eigh.jl b/test/eigh.jl index 3b711c5b..2eac3386 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -3,139 +3,61 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm +using CUDA, AMDGPU BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -GenericFloats = (Float16, BigFloat, Complex{BigFloat}) - -@testset "eigh_full! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_MultipleRelativelyRobustRepresentations(), - LAPACK_DivideAndConquer(), - LAPACK_QRIteration(), - LAPACK_Bisection(), - ) - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 === D - @test V2 === V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) +GenericFloats = (BigFloat, Complex{BigFloat}) + +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 54 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if T ∈ BLASFloats + if CUDA.functional() + CUSOLVER_EIGH_ALGS = ( + CUSOLVER_Jacobi(), + CUSOLVER_DivideAndConquer(), + CUSOLVER_QRIteration(), + CUSOLVER_Bisection(), + ) + TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) + TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m) + TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) + end + if AMDGPU.functional() + ROCSOLVER_EIGH_ALGS = ( + ROCSOLVER_Jacobi(), + ROCSOLVER_DivideAndConquer(), + ROCSOLVER_QRIteration(), + ROCSOLVER_Bisection(), + ) + TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) + TestSuite.test_eigh(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m) + TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) + end end -end - -@testset "eigh_trunc! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 54 - for alg in ( - LAPACK_QRIteration(), - LAPACK_Bisection(), - LAPACK_DivideAndConquer(), - LAPACK_MultipleRelativelyRobustRepresentations(), - ) - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D4, V4 = @constinferred eigh_trunc_no_error(A; alg, trunc) - @test length(diagview(D4)) == r - @test A * V4 ≈ V4 * D4 - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 + if !is_buildkite + TestSuite.test_eigh(T, (m, m)) + if T ∈ BLASFloats + LAPACK_EIGH_ALGS = ( + LAPACK_MultipleRelativelyRobustRepresentations(), + LAPACK_DivideAndConquer(), + LAPACK_QRIteration(), + LAPACK_Bisection(), + ) + TestSuite.test_eigh_algs(T, (m, m), LAPACK_EIGH_ALGS) + elseif T ∈ GenericFloats + GLA_EIGH_ALGS = (GLA_QRIteration(),) + TestSuite.test_eigh_algs(T, (m, m), GLA_EIGH_ALGS) + end + AT = Diagonal{T, Vector{T}} + TestSuite.test_eigh(AT, m) + TestSuite.test_eigh_algs(AT, m, (DiagonalAlgorithm(),)) end end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncerror(; atol = 0.2)) - D4, V4 = @constinferred eigh_trunc_no_error(A; alg) - @test diagview(D4) ≈ diagview(D)[1:2] -end - -@testset "eigh for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) - rng = StableRNG(123) - m = 54 - Ad = randn(rng, T, m) - Ad .+= conj.(Ad) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eigh_full(A) - @test D isa Diagonal{real(T)} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eigh_vals(A) - @test D2 isa AbstractVector{real(T)} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol - - A3 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(3)) - D3, V3 = @constinferred eigh_trunc_no_error(A3; alg) - @test diagview(D3) ≈ diagview(A3)[1:3] -end diff --git a/test/genericlinearalgebra/eigh.jl b/test/genericlinearalgebra/eigh.jl deleted file mode 100644 index 7e602026..00000000 --- a/test/genericlinearalgebra/eigh.jl +++ /dev/null @@ -1,93 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -using GenericLinearAlgebra - -const eltypes = (BigFloat, Complex{BigFloat}) - -@testset "eigh_full! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A; alg) - @test A * V ≈ V * D - @test isunitary(V) - @test all(isreal, D) - - D2, V2 = eigh_full!(copy(A), (D, V), alg) - @test D2 ≈ D - @test V2 ≈ V - - D3 = @constinferred eigh_vals(A, alg) - @test D ≈ Diagonal(D3) -end - -@testset "eigh_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 54 - alg = GLA_QRIteration() - A = randn(rng, T, m, m) - A = A * A' - A = (A + A') / 2 - Ac = similar(A) - D₀ = reverse(eigh_vals(A)) - - r = m - 2 - s = 1 + sqrt(eps(real(T))) - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eigh_trunc(A; alg, trunc = truncrank(r)) - Dfull, Vfull = eigh_full(A; alg) - @test length(diagview(D1)) == r - @test isisometric(V1) - @test A * V1 ≈ V1 * D1 - @test LinearAlgebra.opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - trunc = trunctol(; atol = s * D₀[r + 1]) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test isisometric(V2) - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # test for same subspace - @test V1 * (V1' * V2) ≈ V2 - @test V2 * (V2' * V1) ≈ V1 - @test V1 * (V1' * V3) ≈ V3 - @test V3 * (V3' * V1) ≈ V1 -end - -@testset "eigh_trunc! specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = qr_compact(randn(rng, T, m, m))[1] - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * V' - A = (A + A') / 2 - alg = TruncatedAlgorithm(GLA_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eigh_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - - alg = TruncatedAlgorithm(GLA_QRIteration(), truncerror(; atol = 0.2)) - D3, V3, ϵ3 = @constinferred eigh_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end diff --git a/test/runtests.jl b/test/runtests.jl index 062801c0..baeb77bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,9 +13,6 @@ if !is_buildkite @safetestset "Singular Value Decomposition" begin include("svd.jl") end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("eigh.jl") - end @safetestset "Generalized Eigenvalue Decomposition" begin include("gen_eig.jl") end @@ -45,9 +42,6 @@ if !is_buildkite @safetestset "Singular Value Decomposition" begin include("genericlinearalgebra/svd.jl") end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") - end end @safetestset "QR / LQ Decomposition" begin @@ -66,15 +60,15 @@ end @safetestset "General Eigenvalue Decomposition" begin include("eig.jl") end +@safetestset "Hermitian Eigenvalue Decomposition" begin + include("eigh.jl") +end using CUDA if CUDA.functional() @safetestset "CUDA SVD" begin include("cuda/svd.jl") end - @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin - include("cuda/eigh.jl") - end @safetestset "CUDA Image and Null Space" begin include("cuda/orthnull.jl") end @@ -85,9 +79,6 @@ if AMDGPU.functional() @safetestset "AMDGPU SVD" begin include("amd/svd.jl") end - @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin - include("amd/eigh.jl") - end @safetestset "AMDGPU Image and Null Space" begin include("amd/orthnull.jl") end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 6e13567c..25da777a 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -75,5 +75,6 @@ include("polar.jl") include("projections.jl") include("schur.jl") include("eig.jl") +include("eigh.jl") end diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl new file mode 100644 index 00000000..c79c3e68 --- /dev/null +++ b/test/testsuite/eigh.jl @@ -0,0 +1,192 @@ +using TestExtras +using GenericLinearAlgebra +using MatrixAlgebraKit: TruncatedAlgorithm +using LinearAlgebra: I, opnorm + +function test_eigh(T::Type, sz; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eigh $summary_str" begin + test_eigh_full(T, sz; kwargs...) + test_trunc && test_eigh_trunc(T, sz; kwargs...) + end +end + +function test_eigh_algs(T::Type, sz, algs; test_trunc = true, kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "eigh algorithms $summary_str" begin + test_eigh_full_algs(T, sz, algs; kwargs...) + test_trunc && test_eigh_trunc_algs(T, sz, algs; kwargs...) + end +end + +function test_eigh_full( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_full! $summary_str" begin + A = instantiate_matrix(T, sz) + A = (A + A') / 2 + Ac = deepcopy(A) + + D, V = @testinferred eigh_full(A) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V)) + @test D2 === D + @test V2 === V + + D3 = @testinferred eigh_vals(A) + @test D ≈ Diagonal(D3) + end +end + +function test_eigh_full_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_full! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + A = (A + A') / 2 + Ac = deepcopy(A) + + D, V = @testinferred eigh_full(A; alg) + @test A * V ≈ V * D + @test isunitary(V) + @test all(isreal, D) + + D2, V2 = eigh_full!(copy(A), (D, V); alg) + @test D2 === D + @test V2 === V + + D3 = @testinferred eigh_vals(A; alg) + @test D ≈ Diagonal(D3) + end +end + +function test_eigh_trunc( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_trunc! $summary_str" begin + A = instantiate_matrix(T, sz) + A = A * A' + A = (A + A') / 2 + Ac = deepcopy(A) + if !(T <: Diagonal) + + m = size(A, 1) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + atol = sqrt(eps(real(eltype(T)))) + # truncrank + D1, V1, ϵ1 = @testinferred eigh_trunc(A; trunc = truncrank(r)) + @test length(diagview(D1)) == r + @test isisometric(V1) + @test A * V1 ≈ V1 * D1 + @test opnorm(A - V1 * D1 * V1') ≈ D₀[r + 1] + @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + # trunctol + trunc = trunctol(; atol = s * D₀[r + 1]) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D2)) == r + @test isisometric(V2) + @test A * V2 ≈ V2 * D2 + @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + #truncerror + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; trunc) + @test length(diagview(D3)) == r + @test A * V3 ≈ V3 * D3 + @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol + + s = 1 - sqrt(eps(real(eltype(T)))) + trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) + D4, V4 = @testinferred eigh_trunc_no_error(A; trunc) + @test length(diagview(D4)) == r + @test A * V4 ≈ V4 * D4 + + # test for same subspace + @test V1 * (V1' * V2) ≈ V2 + @test V2 * (V2' * V1) ≈ V1 + @test V1 * (V1' * V3) ≈ V3 + @test V3 * (V3' * V1) ≈ V1 + @test V4 * (V4' * V1) ≈ V1 + end + @testset "specify truncation algorithm" begin + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + smallA = randn!(similar(A, (m4, m4))) + V = T <: Diagonal ? I : qr_compact(smallA)[1] + Ddiag = similar(A, real(eltype(T)), m4) + copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + D = Diagonal(Ddiag) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(2)) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + end + end +end + +function test_eigh_trunc_algs( + T::Type, sz, algs; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "eigh_trunc! algorithm $alg $summary_str" for alg in algs + A = instantiate_matrix(T, sz) + A = A * A' + A = (A + A') / 2 + Ac = deepcopy(A) + + m = size(A, 1) + D₀ = reverse(eigh_vals(A)) + r = m - 2 + s = 1 + sqrt(eps(real(eltype(T)))) + # truncrank + atol = sqrt(eps(real(eltype(T)))) + m4 = 4 + smallA = randn!(similar(A, (m4, m4))) + V = T <: Diagonal ? I : qr_compact(smallA)[1] + Ddiag = similar(A, real(eltype(T)), m4) + copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) + D = Diagonal(Ddiag) + A = V * D * V' + A = (A + A') / 2 + truncalg = TruncatedAlgorithm(alg, truncrank(2)) + D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg = truncalg) + @test diagview(D2) ≈ diagview(D)[1:2] + @test_throws ArgumentError eigh_trunc(A; alg = truncalg, trunc = (; maxrank = 2)) + @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol + + truncalg = TruncatedAlgorithm(alg, truncerror(; atol = 0.2)) + D3, V3, ϵ3 = @testinferred eigh_trunc(A; alg = truncalg) + @test diagview(D3) ≈ diagview(D)[1:2] + @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol + + truncalg = TruncatedAlgorithm(alg, truncerror(; atol = 0.2)) + D4, V4 = @testinferred eigh_trunc_no_error(A; alg = truncalg) + @test diagview(D4) ≈ diagview(D)[1:2] + end +end From e6845cb22c445621da08e4f66f2414cc28ef6bf1 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 14:43:37 +0100 Subject: [PATCH 2/7] No bisection for CUDA --- test/eigh.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/eigh.jl b/test/eigh.jl index 2eac3386..7560fd74 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -22,7 +22,6 @@ for T in (BLASFloats..., GenericFloats...) CUSOLVER_Jacobi(), CUSOLVER_DivideAndConquer(), CUSOLVER_QRIteration(), - CUSOLVER_Bisection(), ) TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) From 1efc623411fd122397846ab1cb000aaecc2440d1 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 09:12:08 -0500 Subject: [PATCH 3/7] No CUSOLVER_QRIteration --- test/eigh.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/eigh.jl b/test/eigh.jl index 7560fd74..f5608c07 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -21,7 +21,6 @@ for T in (BLASFloats..., GenericFloats...) CUSOLVER_EIGH_ALGS = ( CUSOLVER_Jacobi(), CUSOLVER_DivideAndConquer(), - CUSOLVER_QRIteration(), ) TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) From c9aa24b6f019ca026d23d09f93498078a0f8178f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 09:18:51 -0500 Subject: [PATCH 4/7] Don't trunc for GPU --- test/eigh.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/eigh.jl b/test/eigh.jl index f5608c07..2a1fc4d4 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -24,8 +24,8 @@ for T in (BLASFloats..., GenericFloats...) ) TestSuite.test_eigh(CuMatrix{T}, (m, m); test_trunc = false) TestSuite.test_eigh_algs(CuMatrix{T}, (m, m), CUSOLVER_EIGH_ALGS; test_trunc = false) - TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m) - TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),)) + TestSuite.test_eigh(Diagonal{T, CuVector{T}}, m; test_trunc = false) + TestSuite.test_eigh_algs(Diagonal{T, CuVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) end if AMDGPU.functional() ROCSOLVER_EIGH_ALGS = ( @@ -36,8 +36,8 @@ for T in (BLASFloats..., GenericFloats...) ) TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) TestSuite.test_eigh(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) - TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m) - TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),)) + TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false) + TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) end end if !is_buildkite From 9b891df17d8759de6d496e047be647673df24508 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 09:46:34 -0500 Subject: [PATCH 5/7] Typo --- test/eigh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/eigh.jl b/test/eigh.jl index 2a1fc4d4..8766ccc0 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -35,7 +35,7 @@ for T in (BLASFloats..., GenericFloats...) ROCSOLVER_Bisection(), ) TestSuite.test_eigh(ROCMatrix{T}, (m, m); test_trunc = false) - TestSuite.test_eigh(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) + TestSuite.test_eigh_algs(ROCMatrix{T}, (m, m), ROCSOLVER_EIGH_ALGS; test_trunc = false) TestSuite.test_eigh(Diagonal{T, ROCVector{T}}, m; test_trunc = false) TestSuite.test_eigh_algs(Diagonal{T, ROCVector{T}}, m, (DiagonalAlgorithm(),); test_trunc = false) end From 0cab2e86fea52c474e032b05f4786e6c9b8dc71e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Tue, 23 Dec 2025 13:18:19 -0500 Subject: [PATCH 6/7] Comments --- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 20 ++----------- test/testsuite/eigh.jl | 30 +++++++------------ 2 files changed, 12 insertions(+), 38 deletions(-) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index e319ddca..e4637875 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -49,27 +49,11 @@ MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GL function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) eigval, eigvec = eigen!(Hermitian(A); sortby = real) - D, V = DV - if isnothing(D) - D = Diagonal(eigval::AbstractVector{real(eltype(A))}) - else - copyto!(D, Diagonal(eigval::AbstractVector{real(eltype(A))})) - end - if isnothing(V) - V = eigvec::AbstractMatrix{eltype(A)} - else - copyto!(V, eigvec::AbstractMatrix{eltype(A)}) - end - return D, V + return Diagonal(eigval::AbstractVector{real(eltype(A))}), eigvec::AbstractMatrix{eltype(A)} end function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) - if isnothing(D) - D = eigvals!(Hermitian(A); sortby = real) - else - copyto!(D, eigvals!(Hermitian(A); sortby = real)) - end - return D + return eigvals!(Hermitian(A); sortby = real) end function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index c79c3e68..b1d8dadb 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -26,8 +26,7 @@ function test_eigh_full( ) summary_str = testargs_summary(T, sz) return @testset "eigh_full! $summary_str" begin - A = instantiate_matrix(T, sz) - A = (A + A') / 2 + A = project_hermitian!(instantiate_matrix(T, sz)) Ac = deepcopy(A) D, V = @testinferred eigh_full(A) @@ -35,9 +34,8 @@ function test_eigh_full( @test isunitary(V) @test all(isreal, D) - D2, V2 = eigh_full!(copy(A), (D, V)) - @test D2 === D - @test V2 === V + D2, V2 = eigh_full!(Ac, (D, V)) + @test A * V2 ≈ V2 * D2 D3 = @testinferred eigh_vals(A) @test D ≈ Diagonal(D3) @@ -51,8 +49,7 @@ function test_eigh_full_algs( ) summary_str = testargs_summary(T, sz) return @testset "eigh_full! algorithm $alg $summary_str" for alg in algs - A = instantiate_matrix(T, sz) - A = (A + A') / 2 + A = project_hermitian!(instantiate_matrix(T, sz)) Ac = deepcopy(A) D, V = @testinferred eigh_full(A; alg) @@ -60,9 +57,8 @@ function test_eigh_full_algs( @test isunitary(V) @test all(isreal, D) - D2, V2 = eigh_full!(copy(A), (D, V); alg) - @test D2 === D - @test V2 === V + D2, V2 = eigh_full!(Ac, (D, V); alg) + @test A * V2 ≈ V2 * D2 D3 = @testinferred eigh_vals(A; alg) @test D ≈ Diagonal(D3) @@ -76,9 +72,7 @@ function test_eigh_trunc( ) summary_str = testargs_summary(T, sz) return @testset "eigh_trunc! $summary_str" begin - A = instantiate_matrix(T, sz) - A = A * A' - A = (A + A') / 2 + A = project_hermitian!(instantiate_matrix(T, sz)) Ac = deepcopy(A) if !(T <: Diagonal) @@ -132,8 +126,7 @@ function test_eigh_trunc( Ddiag = similar(A, real(eltype(T)), m4) copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) D = Diagonal(Ddiag) - A = V * D * V' - A = (A + A') / 2 + A = project_hermitian!(V * D * V') alg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(2)) D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] @@ -155,9 +148,7 @@ function test_eigh_trunc_algs( ) summary_str = testargs_summary(T, sz) return @testset "eigh_trunc! algorithm $alg $summary_str" for alg in algs - A = instantiate_matrix(T, sz) - A = A * A' - A = (A + A') / 2 + A = project_hermitian!(instantiate_matrix(T, sz)) Ac = deepcopy(A) m = size(A, 1) @@ -172,8 +163,7 @@ function test_eigh_trunc_algs( Ddiag = similar(A, real(eltype(T)), m4) copyto!(Ddiag, real(eltype(T))[0.9, 0.3, 0.1, 0.01]) D = Diagonal(Ddiag) - A = V * D * V' - A = (A + A') / 2 + A = project_hermitian!(V * D * V') truncalg = TruncatedAlgorithm(alg, truncrank(2)) D2, V2, ϵ2 = @testinferred eigh_trunc(A; alg = truncalg) @test diagview(D2) ≈ diagview(D)[1:2] From 6bdb1d55421c35536975374fbd93d9d7c3c56f2b Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Sat, 27 Dec 2025 12:22:34 +0100 Subject: [PATCH 7/7] Fix eigh_trunc --- test/testsuite/eigh.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/testsuite/eigh.jl b/test/testsuite/eigh.jl index b1d8dadb..df6e4d6e 100644 --- a/test/testsuite/eigh.jl +++ b/test/testsuite/eigh.jl @@ -72,7 +72,9 @@ function test_eigh_trunc( ) summary_str = testargs_summary(T, sz) return @testset "eigh_trunc! $summary_str" begin - A = project_hermitian!(instantiate_matrix(T, sz)) + A = instantiate_matrix(T, sz) + A = A * A' + A = project_hermitian!(A) Ac = deepcopy(A) if !(T <: Diagonal) @@ -148,7 +150,9 @@ function test_eigh_trunc_algs( ) summary_str = testargs_summary(T, sz) return @testset "eigh_trunc! algorithm $alg $summary_str" for alg in algs - A = project_hermitian!(instantiate_matrix(T, sz)) + A = instantiate_matrix(T, sz) + A = A * A' + A = project_hermitian!(A) Ac = deepcopy(A) m = size(A, 1)