Skip to content
10 changes: 5 additions & 5 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <
return GLA_QRIteration()
end

for f! in (:svd_compact!, :svd_full!, :svd_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
for f! in (:svd_compact!, :svd_full!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing)
end
MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing

function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration)
F = svd!(A)
Expand Down Expand Up @@ -43,9 +44,8 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T
return GLA_QRIteration(; kwargs...)
end

for f! in (:eigh_full!, :eigh_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing
end
MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing

function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration)
eigval, eigvec = eigen!(Hermitian(A); sortby = real)
Expand Down
5 changes: 2 additions & 3 deletions ext/MatrixAlgebraKitGenericSchurExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@ function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T <
return GS_QRIteration(; kwargs...)
end

for f! in (:eig_full!, :eig_vals!)
@eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing
end
MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing)
MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing

function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration)
D, V = GenericSchur.eigen!(A)
Expand Down
32 changes: 24 additions & 8 deletions src/implementations/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,20 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm)
return S
end

function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm)
U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg)
function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ}
ϵ = similar(A, real(eltype(A)), compute_error)
(U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg)
return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ)))
end

function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg)
USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)
return USVᴴtrunc..., truncation_error!(diagview(S), ind)
if !isempty(ϵ)
ϵ .= truncation_error!(diagview(S), ind)
end
return USVᴴtrunc..., ϵ
end

# Diagonal logic
Expand Down Expand Up @@ -362,16 +372,22 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm)
return USVᴴ
end

function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized})
check_input(svd_trunc!, A, USVᴴ, alg.alg)
U, S, Vᴴ = USVᴴ
function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ}
U, S, Vᴴ, ϵ = USVᴴϵ
check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg)
_gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...)

# TODO: make sure that truncation is based on maxrank, otherwise this might be wrong
(Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc)

# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this?
if !isempty(ϵ)
# normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
normS = norm(diagview(Str))
normA = norm(A)
# equivalent to sqrt(normA^2 - normS^2)
# but may be more accurate
ϵ = sqrt((normA + normS) * (normA - normS))
end

do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool
do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr)
Expand Down
34 changes: 17 additions & 17 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,23 @@ if !is_buildkite
JET.test_package(MatrixAlgebraKit; target_defined_modules = true)
end
end

using GenericLinearAlgebra
@safetestset "QR / LQ Decomposition" begin
include("genericlinearalgebra/qr.jl")
include("genericlinearalgebra/lq.jl")
end
@safetestset "Singular Value Decomposition" begin
include("genericlinearalgebra/svd.jl")
end
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end
end

using CUDA
Expand Down Expand Up @@ -110,20 +127,3 @@ if AMDGPU.functional()
include("amd/orthnull.jl")
end
end

using GenericLinearAlgebra
@safetestset "QR / LQ Decomposition" begin
include("genericlinearalgebra/qr.jl")
include("genericlinearalgebra/lq.jl")
end
@safetestset "Singular Value Decomposition" begin
include("genericlinearalgebra/svd.jl")
end
@safetestset "Hermitian Eigenvalue Decomposition" begin
include("genericlinearalgebra/eigh.jl")
end

using GenericSchur
@safetestset "General Eigenvalue Decomposition" begin
include("genericschur/eig.jl")
end