diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 43154b75..9b558b11 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -9,9 +9,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return GLA_QRIteration() end -for f! in (:svd_compact!, :svd_full!, :svd_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +for f! in (:svd_compact!, :svd_full!) + @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing) end +MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) F = svd!(A) @@ -43,9 +44,8 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T return GLA_QRIteration(; kwargs...) end -for f! in (:eigh_full!, :eigh_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing -end +MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing) +MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) eigval, eigvec = eigen!(Hermitian(A); sortby = real) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index d278b5c5..0af53afb 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -9,9 +9,8 @@ function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T < return GS_QRIteration(; kwargs...) end -for f! in (:eig_full!, :eig_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing -end +MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing) +MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration) D, V = GenericSchur.eigen!(A) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 222c7aee..126e6a04 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -206,10 +206,20 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) - U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) +function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} + ϵ = similar(A, real(eltype(A)), compute_error) + (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) + return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) +end + +function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} + U, S, Vᴴ, ϵ = USVᴴϵ + U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - return USVᴴtrunc..., truncation_error!(diagview(S), ind) + if !isempty(ϵ) + ϵ .= truncation_error!(diagview(S), ind) + end + return USVᴴtrunc..., ϵ end # Diagonal logic @@ -362,16 +372,22 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end -function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) - check_input(svd_trunc!, A, USVᴴ, alg.alg) - U, S, Vᴴ = USVᴴ +function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ} + U, S, Vᴴ, ϵ = USVᴴϵ + check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this? + if !isempty(ϵ) + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum + normS = norm(diagview(Str)) + normA = norm(A) + # equivalent to sqrt(normA^2 - normS^2) + # but may be more accurate + ϵ = sqrt((normA + normS) * (normA - normS)) + end do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) diff --git a/test/runtests.jl b/test/runtests.jl index 4b69a3dc..1ed1f456 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,23 @@ if !is_buildkite JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end end + + using GenericLinearAlgebra + @safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") + end + + using GenericSchur + @safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") + end end using CUDA @@ -110,20 +127,3 @@ if AMDGPU.functional() include("amd/orthnull.jl") end end - -using GenericLinearAlgebra -@safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") -end -@safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") -end - -using GenericSchur -@safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") -end