Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module MatrixAlgebraKitGenericSchurExt

using MatrixAlgebraKit
using MatrixAlgebraKit: check_input
using LinearAlgebra: Diagonal
using LinearAlgebra: Diagonal, sorteig!
using GenericSchur

function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}}
function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return GS_QRIteration(; kwargs...)
end

Expand All @@ -21,4 +21,21 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration)
return GenericSchur.eigvals!(A)
end

function MatrixAlgebraKit.schur_full!(A::AbstractMatrix, TZv, alg::GS_QRIteration)
check_input(schur_full!, A, TZv, alg)
T, Z, vals = TZv
S = GenericSchur.gschur(A)
copyto!(T, S.T)
copyto!(Z, S.Z)
copyto!(vals, S.values)
return T, Z, vals
end

function MatrixAlgebraKit.schur_vals!(A::AbstractMatrix, vals, alg::GS_QRIteration)
check_input(schur_vals!, A, vals, alg)
S = GenericSchur.gschur(A)
copyto!(vals, sorteig!(S.values))
return vals
end

end
116 changes: 0 additions & 116 deletions test/genericschur/eig.jl

This file was deleted.

10 changes: 3 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@ if !is_buildkite
@safetestset "Generalized Eigenvalue Decomposition" begin
include("gen_eig.jl")
end
@safetestset "Schur Decomposition" begin
include("schur.jl")
end
@safetestset "Image and Null Space" begin
include("orthnull.jl")
end
Expand Down Expand Up @@ -55,10 +52,6 @@ if !is_buildkite
include("genericlinearalgebra/eigh.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end
end

@safetestset "QR / LQ Decomposition" begin
Expand All @@ -71,6 +64,9 @@ end
@safetestset "Projections" begin
include("projections.jl")
end
@safetestset "Schur Decomposition" begin
include("schur.jl")
end

using CUDA
if CUDA.functional()
Expand Down
44 changes: 23 additions & 21 deletions test/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,32 @@ using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: I
using LinearAlgebra: I, Diagonal

@testset "schur_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
rng = StableRNG(123)
m = 54
for alg in (LAPACK_Simple(), LAPACK_Expert())
A = randn(rng, T, m, m)
Tc = complex(T)
BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
GenericFloats = (BigFloat, Complex{BigFloat})

TA, Z, vals = @constinferred schur_full(A; alg)
@test eltype(TA) == eltype(Z) == T
@test eltype(vals) == Tc
@test isisometric(Z)
@test A * Z ≈ Z * TA
@isdefined(TestSuite) || include("testsuite/TestSuite.jl")
using .TestSuite

Ac = similar(A)
TA2, Z2, vals2 = @constinferred schur_full!(copy!(Ac, A), (TA, Z, vals), alg)
@test TA2 === TA
@test Z2 === Z
@test vals2 === vals
@test A * Z ≈ Z * TA
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

valsc = @constinferred schur_vals(A, alg)
@test eltype(valsc) == Tc
@test valsc ≈ eig_vals(A, alg)
m = 54
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if T ∈ BLASFloats
#=if CUDA.functional()
TestSuite.test_schur(CuMatrix{T}, (m, m); test_blocksize = false)
TestSuite.test_schur(Diagonal{T, CuVector{T}}, m; test_blocksize = false)
end
if AMDGPU.functional()
TestSuite.test_schur(ROCMatrix{T}, (m, m); test_blocksize = false)
TestSuite.test_schur(Diagonal{T, ROCVector{T}}, m; test_blocksize = false)
end=# # not yet supported
end
if !is_buildkite
TestSuite.test_schur(T, (m, m))
#AT = Diagonal{T, Vector{T}}
#TestSuite.test_schur(AT, m) # not supported yet
end
end
1 change: 1 addition & 0 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,6 @@ include("qr.jl")
include("lq.jl")
include("polar.jl")
include("projections.jl")
include("schur.jl")

end
57 changes: 57 additions & 0 deletions test/testsuite/schur.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using TestExtras
using GenericSchur

function test_schur(T::Type, sz; kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "schur $summary_str" begin
test_schur_full(T, sz; kwargs...)
test_schur_vals(T, sz; kwargs...)
end
end

function test_schur_full(
T::Type, sz;
atol::Real = 0, rtol::Real = precision(T),
kwargs...
)
summary_str = testargs_summary(T, sz)
return @testset "schur_full! $summary_str" begin
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))

TA, Z, vals = @testinferred schur_full(A)
@test eltype(TA) == eltype(Z) == eltype(T)
@test eltype(vals) == Tc
@test isisometric(Z)
@test A * Z ≈ Z * TA

TA2, Z2, vals2 = @testinferred schur_full!(Ac, (TA, Z, vals))
@test TA2 === TA
@test Z2 === Z
@test vals2 === vals
@test A * Z ≈ Z * TA
end
end

function test_schur_vals(
T::Type, sz;
atol::Real = 0, rtol::Real = precision(T),
kwargs...
)
summary_str = testargs_summary(T, sz)
return @testset "schur_vals! $summary_str" begin
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))

valsc = @testinferred schur_vals(A)
@test eltype(valsc) == Tc
@test valsc ≈ eig_vals(A)

valsc = similar(A, Tc, size(A, 1))
valsc = @testinferred schur_vals!(Ac, valsc)
@test eltype(valsc) == Tc
@test valsc ≈ eig_vals(A)
end
end