diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 0af53afb..686ca87b 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -2,10 +2,10 @@ module MatrixAlgebraKitGenericSchurExt using MatrixAlgebraKit using MatrixAlgebraKit: check_input -using LinearAlgebra: Diagonal +using LinearAlgebra: Diagonal, sorteig! using GenericSchur -function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} +function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} return GS_QRIteration(; kwargs...) end @@ -21,4 +21,21 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) return GenericSchur.eigvals!(A) end +function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration) + check_input(schur_full!, A, TZv, alg) + T, Z, vals = TZv + S = GenericSchur.gschur(A) + copyto!(T, S.T) + copyto!(Z, S.Z) + copyto!(vals, S.values) + return T, Z, vals +end + +function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration) + check_input(schur_vals!, A, vals, alg) + S = GenericSchur.gschur(A) + copyto!(vals, sorteig!(S.values)) + return vals +end + end diff --git a/test/genericschur/eig.jl b/test/genericschur/eig.jl deleted file mode 100644 index ce1e8f1b..00000000 --- a/test/genericschur/eig.jl +++ /dev/null @@ -1,116 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using LinearAlgebra: Diagonal -using MatrixAlgebraKit: TruncatedAlgorithm, diagview, norm -using GenericSchur - -const eltypes = (BigFloat, Complex{BigFloat}) - -@testset "eig_full! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 24 - alg = GS_QRIteration() - A = randn(rng, T, m, m) - Tc = complex(T) - - D, V = @constinferred eig_full(A; alg = ($alg)) - @test eltype(D) == eltype(V) == Tc - @test A * V ≈ V * D - - alg′ = @constinferred MatrixAlgebraKit.select_algorithm(eig_full!, A, $alg) - - Ac = similar(A) - D2, V2 = @constinferred eig_full!(copy!(Ac, A), (D, V), alg′) - @test D2 ≈ D - @test V2 ≈ V - @test A * V ≈ V * D - - Dc = @constinferred eig_vals(A, alg′) - @test eltype(Dc) == Tc - @test D ≈ Diagonal(Dc) -end - -@testset "eig_trunc! for T = $T" for T in eltypes - rng = StableRNG(123) - m = 6 - alg = GS_QRIteration() - A = randn(rng, T, m, m) - A *= A' # TODO: deal with eigenvalue ordering etc - # eigenvalues are sorted by ascending real component... - D₀ = sort!(eig_vals(A); by = abs, rev = true) - rmin = findfirst(i -> abs(D₀[end - i]) != abs(D₀[end - i - 1]), 1:(m - 2)) - r = length(D₀) - rmin - atol = sqrt(eps(real(T))) - - D1, V1, ϵ1 = @constinferred eig_trunc(A; alg, trunc = truncrank(r)) - D1base, V1base = @constinferred eig_full(A; alg) - - @test length(diagview(D1)) == r - @test A * V1 ≈ V1 * D1 - @test ϵ1 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 + sqrt(eps(real(T))) - trunc = trunctol(; atol = s * abs(D₀[r + 1])) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D2)) == r - @test A * V2 ≈ V2 * D2 - @test ϵ2 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - s = 1 - sqrt(eps(real(T))) - trunc = truncerror(; atol = s * norm(@view(D₀[r:end]), 1), p = 1) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg, trunc) - @test length(diagview(D3)) == r - @test A * V3 ≈ V3 * D3 - @test ϵ3 ≈ norm(view(D₀, (r + 1):m)) atol = atol - - # trunctol keeps order, truncrank might not - # test for same subspace - @test V1 * ((V1' * V1) \ (V1' * V2)) ≈ V2 - @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 - @test V1 * ((V1' * V1) \ (V1' * V3)) ≈ V3 - @test V3 * ((V3' * V3) \ (V3' * V1)) ≈ V1 -end - -@testset "eig_trunc! specify truncation algorithm T = $T" for T in eltypes - rng = StableRNG(123) - m = 4 - atol = sqrt(eps(real(T))) - V = randn(rng, T, m, m) - D = Diagonal(real(T)[0.9, 0.3, 0.1, 0.01]) - A = V * D * inv(V) - alg = TruncatedAlgorithm(GS_QRIteration(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A; alg) - @test diagview(D2) ≈ diagview(D)[1:2] - @test ϵ2 ≈ norm(diagview(D)[3:4]) atol = atol - @test_throws ArgumentError eig_trunc(A; alg, trunc = (; maxrank = 2)) - - alg = TruncatedAlgorithm(GS_QRIteration(), truncerror(; atol = 0.2, p = 1)) - D3, V3, ϵ3 = @constinferred eig_trunc(A; alg) - @test diagview(D3) ≈ diagview(D)[1:2] - @test ϵ3 ≈ norm(diagview(D)[3:4]) atol = atol -end - -@testset "eig for Diagonal{$T}" for T in eltypes - rng = StableRNG(123) - m = 24 - Ad = randn(rng, T, m) - A = Diagonal(Ad) - atol = sqrt(eps(real(T))) - - D, V = @constinferred eig_full(A) - @test D isa Diagonal{T} && size(D) == size(A) - @test V isa Diagonal{T} && size(V) == size(A) - @test A * V ≈ V * D - - D2 = @constinferred eig_vals(A) - @test D2 isa AbstractVector{T} && length(D2) == m - @test diagview(D) ≈ D2 - - A2 = Diagonal(T[0.9, 0.3, 0.1, 0.01]) - alg = TruncatedAlgorithm(DiagonalAlgorithm(), truncrank(2)) - D2, V2, ϵ2 = @constinferred eig_trunc(A2; alg) - @test diagview(D2) ≈ diagview(A2)[1:2] - @test ϵ2 ≈ norm(diagview(A2)[3:4]) atol = atol -end diff --git a/test/runtests.jl b/test/runtests.jl index a2613823..7325d410 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,9 +22,6 @@ if !is_buildkite @safetestset "Generalized Eigenvalue Decomposition" begin include("gen_eig.jl") end - @safetestset "Schur Decomposition" begin - include("schur.jl") - end @safetestset "Image and Null Space" begin include("orthnull.jl") end @@ -55,10 +52,6 @@ if !is_buildkite include("genericlinearalgebra/eigh.jl") end - using GenericSchur - @safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") - end end @safetestset "QR / LQ Decomposition" begin @@ -71,6 +64,9 @@ end @safetestset "Projections" begin include("projections.jl") end +@safetestset "Schur Decomposition" begin + include("schur.jl") +end using CUDA if CUDA.functional() diff --git a/test/schur.jl b/test/schur.jl index e24de579..d306e3d8 100644 --- a/test/schur.jl +++ b/test/schur.jl @@ -2,30 +2,32 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: I +using LinearAlgebra: I, Diagonal -@testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) - rng = StableRNG(123) - m = 54 - for alg in (LAPACK_Simple(), LAPACK_Expert()) - A = randn(rng, T, m, m) - Tc = complex(T) +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (BigFloat, Complex{BigFloat}) - TA, Z, vals = @constinferred schur_full(A; alg) - @test eltype(TA) == eltype(Z) == T - @test eltype(vals) == Tc - @test isisometric(Z) - @test A * Z ≈ Z * TA +@isdefined(TestSuite) || include("testsuite/TestSuite.jl") +using .TestSuite - Ac = similar(A) - TA2, Z2, vals2 = @constinferred schur_full!(copy!(Ac, A), (TA, Z, vals), alg) - @test TA2 === TA - @test Z2 === Z - @test vals2 === vals - @test A * Z ≈ Z * TA +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" - valsc = @constinferred schur_vals(A, alg) - @test eltype(valsc) == Tc - @test valsc ≈ eig_vals(A, alg) +m = 54 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + if T ∈ BLASFloats + #=if CUDA.functional() + TestSuite.test_schur(CuMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_schur(Diagonal{T, CuVector{T}}, m; test_blocksize = false) + end + if AMDGPU.functional() + TestSuite.test_schur(ROCMatrix{T}, (m, m); test_blocksize = false) + TestSuite.test_schur(Diagonal{T, ROCVector{T}}, m; test_blocksize = false) + end=# # not yet supported + end + if !is_buildkite + TestSuite.test_schur(T, (m, m)) + #AT = Diagonal{T, Vector{T}} + #TestSuite.test_schur(AT, m) # not supported yet end end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index f2f55bbc..deee4e12 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -73,5 +73,6 @@ include("qr.jl") include("lq.jl") include("polar.jl") include("projections.jl") +include("schur.jl") end diff --git a/test/testsuite/schur.jl b/test/testsuite/schur.jl new file mode 100644 index 00000000..572b9501 --- /dev/null +++ b/test/testsuite/schur.jl @@ -0,0 +1,57 @@ +using TestExtras +using GenericSchur + +function test_schur(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "schur $summary_str" begin + test_schur_full(T, sz; kwargs...) + test_schur_vals(T, sz; kwargs...) + end +end + +function test_schur_full( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "schur_full! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + + TA, Z, vals = @testinferred schur_full(A) + @test eltype(TA) == eltype(Z) == eltype(T) + @test eltype(vals) == Tc + @test isisometric(Z) + @test A * Z ≈ Z * TA + + TA2, Z2, vals2 = @testinferred schur_full!(Ac, (TA, Z, vals)) + @test TA2 === TA + @test Z2 === Z + @test vals2 === vals + @test A * Z ≈ Z * TA + end +end + +function test_schur_vals( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "schur_vals! $summary_str" begin + A = instantiate_matrix(T, sz) + Ac = deepcopy(A) + Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T)) + + valsc = @testinferred schur_vals(A) + @test eltype(valsc) == Tc + @test valsc ≈ eig_vals(A) + + valsc = similar(A, Tc, size(A, 1)) + valsc = @testinferred schur_vals!(Ac, valsc) + @test eltype(valsc) == Tc + @test valsc ≈ eig_vals(A) + end +end