From c63886422789d36c918c02be1231bf207a6abb75 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Dec 2025 13:18:07 +0100 Subject: [PATCH 01/14] Add optional epsilon vector and keyword arg for error --- src/implementations/svd.jl | 41 +++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 222c7aee..a9b52668 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -206,10 +206,35 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -function svd_trunc!(A, USVᴴ, alg::TruncatedAlgorithm) +# nothing case here to handle GenericLinearAlgebra +function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} + U, S, Vᴴ, ϵ = USVᴴϵ + U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) + USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + if !isempty(ϵ) + ϵ[1] = truncation_error!(diagview(S), ind) + end + return USVᴴtrunc..., ϵ +end +function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) where {Tϵ} + USVᴴ, ϵ = USVᴴϵ U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - return USVᴴtrunc..., truncation_error!(diagview(S), ind) + if !isempty(ϵ) + ϵ[1] = truncation_error!(diagview(S), ind) + end + return USVᴴtrunc..., ϵ +end + +function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} + ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) + (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN) +end +function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) + ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) + U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN) end # Diagonal logic @@ -362,16 +387,18 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end -function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) - check_input(svd_trunc!, A, USVᴴ, alg.alg) - U, S, Vᴴ = USVᴴ +function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm{<:GPU_Randomized}) where {TU, TS, TVᴴ, Tϵ} + U, S, Vᴴ, ϵ = USVᴴϵ + check_input(svd_trunc!, A, (U, S, Vᴴ), alg.alg) _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this? + if !isempty(ϵ) + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum + ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this? + end do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) From e5245e69b131c8f4c72cbc8b4cc971eb4e0a538d Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Dec 2025 13:32:04 +0100 Subject: [PATCH 02/14] Make NaN the same type --- src/implementations/svd.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index a9b52668..6e942285 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -227,14 +227,15 @@ function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) w end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) + Tr = real(eltype(A)) + ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, Tr(NaN)) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, NaN) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, Tr(NaN)) end # Diagonal logic From 834c3c2e5ded5dfc87ea23579ccc34dec59d0758 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Dec 2025 13:56:26 +0100 Subject: [PATCH 03/14] GLA tests should not be run on GPU --- test/runtests.jl | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4b69a3dc..1ed1f456 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,23 @@ if !is_buildkite JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end end + + using GenericLinearAlgebra + @safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") + end + + using GenericSchur + @safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") + end end using CUDA @@ -110,20 +127,3 @@ if AMDGPU.functional() include("amd/orthnull.jl") end end - -using GenericLinearAlgebra -@safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") -end -@safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") -end - -using GenericSchur -@safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") -end From b7d9547306bf36e00afbae517a9cf893e4b59a36 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Dec 2025 14:38:18 +0100 Subject: [PATCH 04/14] Don't use NaN because of BigFloat --- src/implementations/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 6e942285..155cb156 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -230,12 +230,12 @@ function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; co Tr = real(eltype(A)) ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, Tr(NaN)) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, Tr(NaN)) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) end # Diagonal logic From 5b4f4e0ba019b4b95d2032d8b372a493b698ed0e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Wed, 10 Dec 2025 15:26:53 +0100 Subject: [PATCH 05/14] Fix type signature --- src/implementations/svd.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 155cb156..31fe0b8e 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -230,12 +230,13 @@ function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; co Tr = real(eltype(A)) ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, ϵ[1]::Tr) : (U, S, Vᴴ, -one(Tr)) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) - ϵ = compute_error ? zeros(real(eltype(A)), 1) : zeros(real(eltype(A)), 0) + Tr = real(eltype(A)) + ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, ϵ[1]::Tr) : (U, S, Vᴴ, -one(Tr)) end # Diagonal logic From ff23236a59abe3061294d4288e265ee596bc0ec9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 12:05:18 +0100 Subject: [PATCH 06/14] Update src/implementations/svd.jl Co-authored-by: Jutho --- src/implementations/svd.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 31fe0b8e..5b66504a 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -227,8 +227,7 @@ function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) w end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - Tr = real(eltype(A)) - ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) + ϵ = similar(S, compute_error) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) return compute_error ? (U, S, Vᴴ, ϵ[1]::Tr) : (U, S, Vᴴ, -one(Tr)) end From 9d2070b5ce810f2ba7188f5bb8f2ea66f87bd905 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 12:05:32 +0100 Subject: [PATCH 07/14] Update src/implementations/svd.jl Co-authored-by: Jutho --- src/implementations/svd.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 5b66504a..f86fdf9b 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -229,7 +229,7 @@ end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} ϵ = similar(S, compute_error) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]::Tr) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(eltype(ϵ))) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) Tr = real(eltype(A)) From f40c1411c39eb489d9a7bbbd8555e7eca0f238af Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 12:10:34 +0100 Subject: [PATCH 08/14] Fix undefined err --- src/implementations/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f86fdf9b..f04b0b43 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -227,7 +227,7 @@ function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) w end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(S, compute_error) + ϵ = similar(USVᴴ[2], compute_error) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(eltype(ϵ))) end @@ -235,7 +235,7 @@ function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error:: Tr = real(eltype(A)) ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]::Tr) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) end # Diagonal logic From 2c34bc43614891730c590f652aabfc45f7da3231 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 12:17:36 +0100 Subject: [PATCH 09/14] Use broadcasting to set epsilon --- src/implementations/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f04b0b43..30095df3 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -212,7 +212,7 @@ function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgori U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) if !isempty(ϵ) - ϵ[1] = truncation_error!(diagview(S), ind) + ϵ .= truncation_error!(diagview(S), ind) end return USVᴴtrunc..., ϵ end @@ -221,7 +221,7 @@ function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) w U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) if !isempty(ϵ) - ϵ[1] = truncation_error!(diagview(S), ind) + ϵ .= truncation_error!(diagview(S), ind) end return USVᴴtrunc..., ϵ end From 905647e73bd69db707f427753b07f84c06c99ebe Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 12:32:54 +0100 Subject: [PATCH 10/14] Force memory transfer --- src/implementations/svd.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 30095df3..7f0515a5 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -229,13 +229,13 @@ end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} ϵ = similar(USVᴴ[2], compute_error) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(eltype(ϵ))) + return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(eltype(ϵ))) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) Tr = real(eltype(A)) - ϵ = compute_error ? zeros(Tr, 1) : zeros(Tr, 0) + ϵ = zeros(Tr, compute_error) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, ϵ[1]) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(Tr)) end # Diagonal logic From f789ea6cf39cc448bafe06358849262ca470d9c6 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 14:09:15 +0100 Subject: [PATCH 11/14] Comments --- src/implementations/svd.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 7f0515a5..63b2aecb 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -229,13 +229,12 @@ end function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} ϵ = similar(USVᴴ[2], compute_error) (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(eltype(ϵ))) + return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) end function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) - Tr = real(eltype(A)) - ϵ = zeros(Tr, compute_error) + ϵ = similar(A, real(eltype(A)), compute_error) U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) - return compute_error ? (U, S, Vᴴ, collect(ϵ)[1]) : (U, S, Vᴴ, -one(Tr)) + return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) end # Diagonal logic From 4f2ad8851c83d564bff209cec6c72987177ac50f Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 14:14:53 +0100 Subject: [PATCH 12/14] New norm calculation --- src/implementations/svd.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 63b2aecb..f50e67eb 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -397,7 +397,11 @@ function svd_trunc!(A::AbstractMatrix, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg: if !isempty(ϵ) # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this? + normS = norm(diagview(Str)) + normA = norm(A) + # equivalent to sqrt(normA^2 - normS^2) + # but may be more accurate + ϵ = sqrt((normA + normS) * (normA - normS)) end do_gauge_fix = get(alg.alg.kwargs, :fixgauge, default_fixgauge())::Bool From 221e917b80f9660a51a7877d61754066170f3248 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 14:59:44 +0100 Subject: [PATCH 13/14] Move to tuples of Nothing --- ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl | 10 +++++----- ext/MatrixAlgebraKitGenericSchurExt.jl | 5 ++--- src/implementations/svd.jl | 17 +---------------- 3 files changed, 8 insertions(+), 24 deletions(-) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 43154b75..9b558b11 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -9,9 +9,10 @@ function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T < return GLA_QRIteration() end -for f! in (:svd_compact!, :svd_full!, :svd_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing +for f! in (:svd_compact!, :svd_full!) + @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing, nothing) end +MatrixAlgebraKit.initialize_output(::typeof(svd_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) F = svd!(A) @@ -43,9 +44,8 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T return GLA_QRIteration(; kwargs...) end -for f! in (:eigh_full!, :eigh_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing -end +MatrixAlgebraKit.initialize_output(::typeof(eigh_full!), A::AbstractMatrix, ::GLA_QRIteration) = (nothing, nothing) +MatrixAlgebraKit.initialize_output(::typeof(eigh_vals!), A::AbstractMatrix, ::GLA_QRIteration) = nothing function MatrixAlgebraKit.eigh_full!(A::AbstractMatrix, DV, ::GLA_QRIteration) eigval, eigvec = eigen!(Hermitian(A); sortby = real) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index d278b5c5..0af53afb 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -9,9 +9,8 @@ function MatrixAlgebraKit.default_eig_algorithm(::Type{T}; kwargs...) where {T < return GS_QRIteration(; kwargs...) end -for f! in (:eig_full!, :eig_vals!) - @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GS_QRIteration) = nothing -end +MatrixAlgebraKit.initialize_output(::typeof(eig_full!), A::AbstractMatrix, ::GS_QRIteration) = (nothing, nothing) +MatrixAlgebraKit.initialize_output(::typeof(eig_vals!), A::AbstractMatrix, ::GS_QRIteration) = nothing function MatrixAlgebraKit.eig_full!(A::AbstractMatrix, DV, ::GS_QRIteration) D, V = GenericSchur.eigen!(A) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f50e67eb..b67db32a 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -216,24 +216,9 @@ function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgori end return USVᴴtrunc..., ϵ end -function svd_trunc!(A, USVᴴϵ::Tuple{Nothing, Tϵ}, alg::TruncatedAlgorithm) where {Tϵ} - USVᴴ, ϵ = USVᴴϵ - U, S, Vᴴ = svd_compact!(A, USVᴴ, alg.alg) - USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - if !isempty(ϵ) - ϵ .= truncation_error!(diagview(S), ind) - end - return USVᴴtrunc..., ϵ -end - function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(USVᴴ[2], compute_error) - (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) -end -function svd_trunc!(A, USVᴴ::Nothing, alg::TruncatedAlgorithm; compute_error::Bool = true) ϵ = similar(A, real(eltype(A)), compute_error) - U, S, Vᴴ, ϵ = svd_trunc!(A, (USVᴴ, ϵ), alg) + (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) end From 74b4aca507802fd6b866478d1d48af4e11fcd72a Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 11 Dec 2025 15:16:36 +0100 Subject: [PATCH 14/14] Comment --- src/implementations/svd.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index b67db32a..126e6a04 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -206,7 +206,12 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) return S end -# nothing case here to handle GenericLinearAlgebra +function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} + ϵ = similar(A, real(eltype(A)), compute_error) + (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) + return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) +end + function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgorithm) where {TU, TS, TVᴴ, Tϵ} U, S, Vᴴ, ϵ = USVᴴϵ U, S, Vᴴ = svd_compact!(A, (U, S, Vᴴ), alg.alg) @@ -216,11 +221,6 @@ function svd_trunc!(A, USVᴴϵ::Tuple{TU, TS, TVᴴ, Tϵ}, alg::TruncatedAlgori end return USVᴴtrunc..., ϵ end -function svd_trunc!(A, USVᴴ::Tuple{TU, TS, TVᴴ}, alg::TruncatedAlgorithm; compute_error::Bool = true) where {TU, TS, TVᴴ} - ϵ = similar(A, real(eltype(A)), compute_error) - (U, S, Vᴴ, ϵ) = svd_trunc!(A, (USVᴴ..., ϵ), alg) - return compute_error ? (U, S, Vᴴ, norm(ϵ)) : (U, S, Vᴴ, -one(eltype(ϵ))) -end # Diagonal logic # --------------