From a041be9d4c2a0a98671fe8a694c5fa1fad674099 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:03:11 -0500 Subject: [PATCH 01/18] move gaugefix code to gaugefix section --- src/implementations/svd.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index ba54f233..aef64d8d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,6 +105,8 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end +# Gauge fixation +# -------------- function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int) for j in 1:max(m, n) if j <= min(m, n) @@ -126,8 +128,6 @@ function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int) return (U, S, Vᴴ) end -# Gauge fixation -# -------------- function gaugefix!(::typeof(svd_compact!), U, S, Vᴴ, m::Int, n::Int) for j in 1:size(U, 2) u = view(U, :, j) From 4e2a298bc7175c571e31ba57e9cef9fa1d4fd65f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:21:18 -0500 Subject: [PATCH 02/18] add gaugefix keyword argument for enabling/disabling `svd` gauge fixing --- src/implementations/svd.jl | 108 +++++++++++++++++++++++-------------- 1 file changed, 68 insertions(+), 40 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index aef64d8d..f56b6d32 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -164,13 +164,17 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) one!(Vᴴ) return USVᴴ end + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_Bisection throw(ArgumentError("LAPACK_Bisection is not supported for full SVD")) @@ -179,60 +183,71 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) else throw(ArgumentError("Unsupported SVD algorithm")) end + for i in 2:minmn S[i, i] = S[i, 1] S[i, 1] = zero(eltype(S)) end - # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_full!, U, S, Vᴴ, m, n) + + dogaugefix && gaugefix!(svd_full!, U, S, Vᴴ, m, n) + return USVᴴ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, S.diag, U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, S.diag, U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg.kwargs...) + YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; lapack_kwargs...) elseif alg isa LAPACK_Jacobi - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) YALAPACK.gesvj!(A, S.diag, U, Vᴴ) else throw(ArgumentError("Unsupported SVD algorithm")) end - # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + + dogaugefix && gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + return USVᴴ end function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_QRIteration - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_QRIteration does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, S, U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_DivideAndConquer does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, S, U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S, U, Vᴴ; alg.kwargs...) + YALAPACK.gesvdx!(A, S, U, Vᴴ; lapack_kwargs...) elseif alg isa LAPACK_Jacobi - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Jacobi does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) YALAPACK.gesvj!(A, S, U, Vᴴ) else throw(ArgumentError("Unsupported SVD algorithm")) end + return S end @@ -365,21 +380,25 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) one!(Vᴴ) return USVᴴ end + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_QRIteration - isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" + isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) + _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; lapack_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg.kwargs...) + _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; lapack_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end diagview(S) .= view(S, 1:minmn, 1) view(S, 2:minmn, 1) .= zero(eltype(S)) - # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_full!, U, S, Vᴴ, m, n) + + dogaugefix && gaugefix!(svd_full!, U, S, Vᴴ, m, n) + return USVᴴ end @@ -387,8 +406,10 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran check_input(svd_trunc!, A, USVᴴ, alg.alg) U, S, Vᴴ = USVᴴ _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) - # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) + + dogaugefix = get(alg.alg.kwargs, :gaugefix, true)::Bool + dogaugefix && gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) + # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) Strunc = diagview(USVᴴtrunc[2]) @@ -400,19 +421,23 @@ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_QRIteration - isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" + isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg.kwargs...) + _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; lapack_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S.diag, U, Vᴴ; alg.kwargs...) + _gpu_gesvdj!(A, S.diag, U, Vᴴ; lapack_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end - # TODO: make this controllable using a `gaugefix` keyword argument - gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + + dogaugefix && gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + return USVᴴ end _argmaxabs(x) = reduce(_largest, x; init = zero(eltype(x))) @@ -421,16 +446,19 @@ _largest(x, y) = abs(x) < abs(y) ? y : x function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_QRIteration - isempty(alg.kwargs) || - @warn "GPU_QRIteration does not accept any keyword arguments" + isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S, U, Vᴴ; alg.kwargs...) + _gpu_Xgesvdp!(A, S, U, Vᴴ; lapack_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S, U, Vᴴ; alg.kwargs...) + _gpu_gesvdj!(A, S, U, Vᴴ; lapack_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end + return S end From 4475012934432ae65b050f41480e6b74625a4ea1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:40:23 -0500 Subject: [PATCH 03/18] add gaugefix keyword argument for enabling/disabling `eig` gauge fixing --- src/implementations/eig.jl | 44 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 8532c29a..8bcbbeb7 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -81,28 +81,37 @@ end function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.geev!(A, D.diag, V) else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D.diag, V; alg.kwargs...) + YALAPACK.geevx!(A, D.diag, V; lapack_kwargs...) end - # TODO: make this controllable using a `gaugefix` keyword argument - V = gaugefix!(V) + + dogaugefix && (V = gaugefix!(V)) + return D, V end function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (geev) does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.geev!(A, D, V) else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D, V; alg.kwargs...) + YALAPACK.geevx!(A, D, V; lapack_kwargs...) end + return D end @@ -135,23 +144,30 @@ _gpu_geev!(A, D, V) = throw(MethodError(_gpu_geev!, (A, D, V))) function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_Simple - isempty(alg.kwargs) || - @warn "GPU_Simple (geev) does not accept any keyword arguments" + isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_Simple" _gpu_geev!(A, D.diag, V) end - # TODO: make this controllable using a `gaugefix` keyword argument - V = gaugefix!(V) + + dogaugefix && (V = gaugefix!(V)) + return D, V end function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm) check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_Simple - isempty(alg.kwargs) || - @warn "GPU_Simple (geev) does not accept any keyword arguments" + isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_Simple" _gpu_geev!(A, D, V) end + return D end From 94ece0e870c5708e49440812e70d4c72e7508826 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:40:36 -0500 Subject: [PATCH 04/18] add gaugefix kewyord argument for enabling/disabling `eigh` gauge fixing --- src/implementations/eigh.jl | 58 ++++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 13d0b9d3..e37e7503 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -91,32 +91,41 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV Dd = D.diag + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, Dd, V; alg.kwargs...) + YALAPACK.heevr!(A, Dd, V; lapack_kwargs...) elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, Dd, V; alg.kwargs...) + YALAPACK.heevd!(A, Dd, V; lapack_kwargs...) elseif alg isa LAPACK_Simple - YALAPACK.heev!(A, Dd, V; alg.kwargs...) + YALAPACK.heev!(A, Dd, V; lapack_kwargs...) else # alg isa LAPACK_Expert - YALAPACK.heevx!(A, Dd, V; alg.kwargs...) + YALAPACK.heevx!(A, Dd, V; lapack_kwargs...) end - # TODO: make this controllable using a `gaugefix` keyword argument - V = gaugefix!(V) + + dogaugefix && (V = gaugefix!(V)) + return D, V end function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) check_input(eigh_vals!, A, D, alg) V = similar(A, (size(A, 1), 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, D, V; alg.kwargs...) + YALAPACK.heevr!(A, D, V; lapack_kwargs...) elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, D, V; alg.kwargs...) + YALAPACK.heevd!(A, D, V; lapack_kwargs...) elseif alg isa LAPACK_QRIteration # == LAPACK_Simple - YALAPACK.heev!(A, D, V; alg.kwargs...) + YALAPACK.heev!(A, D, V; lapack_kwargs...) else # alg isa LAPACK_Bisection == LAPACK_Expert - YALAPACK.heevx!(A, D, V; alg.kwargs...) + YALAPACK.heevx!(A, D, V; lapack_kwargs...) end + return D end @@ -158,35 +167,44 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) check_input(eigh_full!, A, DV, alg) D, V = DV Dd = D.diag + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_Jacobi - _gpu_heevj!(A, Dd, V; alg.kwargs...) + _gpu_heevj!(A, Dd, V; lapack_kwargs...) elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, Dd, V; alg.kwargs...) + _gpu_heevd!(A, Dd, V; lapack_kwargs...) elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple - _gpu_heev!(A, Dd, V; alg.kwargs...) + _gpu_heev!(A, Dd, V; lapack_kwargs...) elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert - _gpu_heevx!(A, Dd, V; alg.kwargs...) + _gpu_heevx!(A, Dd, V; lapack_kwargs...) else throw(ArgumentError("Unsupported eigh algorithm")) end - # TODO: make this controllable using a `gaugefix` keyword argument - V = gaugefix!(V) + + dogaugefix && (V = gaugefix!(V)) + return D, V end function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm) check_input(eigh_vals!, A, D, alg) V = similar(A, (size(A, 1), 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa GPU_Jacobi - _gpu_heevj!(A, D, V; alg.kwargs...) + _gpu_heevj!(A, D, V; lapack_kwargs...) elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, D, V; alg.kwargs...) + _gpu_heevd!(A, D, V; lapack_kwargs...) elseif alg isa GPU_QRIteration - _gpu_heev!(A, D, V; alg.kwargs...) + _gpu_heev!(A, D, V; lapack_kwargs...) elseif alg isa GPU_Bisection - _gpu_heevx!(A, D, V; alg.kwargs...) + _gpu_heevx!(A, D, V; lapack_kwargs...) else throw(ArgumentError("Unsupported eigh algorithm")) end + return D end From fa673bb26464f7ad9136ed579c441e9e6b8fbde8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:44:16 -0500 Subject: [PATCH 05/18] add gaugefix keyword argument for enabling/disabling `gen_eig` gauge fixing --- src/implementations/gen_eig.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 07bf6e0c..e796ecfa 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -57,27 +57,36 @@ end function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_EigAlgorithm) check_input(gen_eig_full!, A, B, WV, alg) W, V = WV + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A))) else # alg isa LAPACK_Expert throw(ArgumentError("LAPACK_Expert is not supported for ggev")) end - # TODO: make this controllable using a `gaugefix` keyword argument - V = gaugefix!(V) + + dogaugefix && (V = gaugefix!(V)) + return W, V end function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigAlgorithm) check_input(gen_eig_vals!, A, B, W, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) + + lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + if alg isa LAPACK_Simple - isempty(alg.kwargs) || - throw(ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments")) + isempty(lapack_kwargs) || + throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A))) else # alg isa LAPACK_Expert throw(ArgumentError("LAPACK_Expert is not supported for ggev")) end + return W end From 8d8cc21dbaa4e47e8410e34e617c90dd854e1670 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 10:53:34 -0500 Subject: [PATCH 06/18] update docstrings --- src/interface/decompositions.jl | 66 ++++++++++++++++++++++++--------- 1 file changed, 49 insertions(+), 17 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 1bdf1534..92937e00 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -49,18 +49,22 @@ of `R` are non-negative. # General Eigenvalue Decomposition # ------------------------------- """ - LAPACK_Simple() + LAPACK_Simple(; gaugefix::Bool = true) Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). """ @algdef LAPACK_Simple """ - LAPACK_Expert() + LAPACK_Expert(; gaugefix::Bool = true) Algorithm type to denote the expert LAPACK driver for computing the Schur or non-Hermitian eigenvalue decomposition of a matrix. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). """ @algdef LAPACK_Expert @@ -77,37 +81,45 @@ eigenvalue decomposition of a non-Hermitian matrix. # Hermitian Eigenvalue Decomposition # ---------------------------------- """ - LAPACK_QRIteration() + LAPACK_QRIteration(; gaugefix::Bool = true) Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef LAPACK_QRIteration """ - LAPACK_Bisection() + LAPACK_Bisection(; gaugefix::Bool = true) Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Bisection algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef LAPACK_Bisection """ - LAPACK_DivideAndConquer() + LAPACK_DivideAndConquer(; gaugefix::Bool = true) Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef LAPACK_DivideAndConquer """ - LAPACK_MultipleRelativelyRobustRepresentations() + LAPACK_MultipleRelativelyRobustRepresentations(; gaugefix::Bool = true) Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a Hermitian matrix using the Multiple Relatively Robust Representations algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). """ @algdef LAPACK_MultipleRelativelyRobustRepresentations @@ -130,10 +142,12 @@ a general matrix. # Singular Value Decomposition # ---------------------------- """ - LAPACK_Jacobi() + LAPACK_Jacobi(; gaugefix::Bool = true) Algorithm type to denote the LAPACK driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, +see also [`gaugefix!`](@ref). """ @algdef LAPACK_Jacobi @@ -199,34 +213,40 @@ end CUSOLVER_HouseholderQR(; positive = false) Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of -a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that +a matrix using Householder reflectors. The keyword `positive = true` can be used to ensure that the diagonal elements of `R` are non-negative. """ @algdef CUSOLVER_HouseholderQR """ - CUSOLVER_QRIteration() + CUSOLVER_QRIteration(; gaugefix::Bool = true) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef CUSOLVER_QRIteration """ - CUSOLVER_SVDPolar() + CUSOLVER_SVDPolar(; gaugefix::Bool = true) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix by using Halley's iterative algorithm to compute the polar decompositon, followed by the hermitian eigenvalue decomposition of the positive definite factor. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). """ @algdef CUSOLVER_SVDPolar """ - CUSOLVER_Jacobi() + CUSOLVER_Jacobi(; gaugefix::Bool = true) Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). """ @algdef CUSOLVER_Jacobi @@ -248,21 +268,25 @@ for more information. does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true """ - CUSOLVER_Simple() + CUSOLVER_Simple(; gaugefix::Bool = true) Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian eigenvalue decomposition of a matrix. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). """ @algdef CUSOLVER_Simple const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} """ - CUSOLVER_DivideAndConquer() + CUSOLVER_DivideAndConquer(; gaugefix::Bool = true) Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef CUSOLVER_DivideAndConquer @@ -283,37 +307,45 @@ the diagonal elements of `R` are non-negative. @algdef ROCSOLVER_HouseholderQR """ - ROCSOLVER_QRIteration() + ROCSOLVER_QRIteration(; gaugefix::Bool = true) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the QR Iteration algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef ROCSOLVER_QRIteration """ - ROCSOLVER_Jacobi() + ROCSOLVER_Jacobi(; gaugefix::Bool = true) Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of a general matrix using the Jacobi algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). """ @algdef ROCSOLVER_Jacobi """ - ROCSOLVER_Bisection() + ROCSOLVER_Bisection(; gaugefix::Bool = true) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Bisection algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef ROCSOLVER_Bisection """ - ROCSOLVER_DivideAndConquer() + ROCSOLVER_DivideAndConquer(; gaugefix::Bool = true) Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix using the Divide and Conquer algorithm. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef ROCSOLVER_DivideAndConquer From 364deb691e1d9b1ce01c0187b955cd02f746d9c0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:09:52 -0500 Subject: [PATCH 07/18] centralize gaugefix function --- src/common/gauge.jl | 53 ++++++++++++++++++++++++++++++++++++++ src/implementations/svd.jl | 53 +++----------------------------------- 2 files changed, 57 insertions(+), 49 deletions(-) diff --git a/src/common/gauge.jl b/src/common/gauge.jl index a9bf985b..571d62e4 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -1,3 +1,13 @@ +""" + gaugefix!(f_eig, V) + gaugefix!(f_svd, U, Vᴴ) + +Fix the residual gauge degrees of freedom in the eigen or singular vectors, that are +obtained from the decomposition performed by `f_eig` or `f_svd`. +This is achieved by ensuring that the entry with the largest magnitude in `V` or `U` +is real and positive. +""" gaugefix! + function gaugefix!(V::AbstractMatrix) for j in axes(V, 2) v = view(V, :, j) @@ -6,3 +16,46 @@ function gaugefix!(V::AbstractMatrix) end return V end + +function gaugefix!(::typeof(svd_full!), U, Vᴴ, m::Int, n::Int) + for j in 1:max(m, n) + if j <= min(m, n) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + elseif j <= m + u = view(U, :, j) + s = conj(sign(_argmaxabs(u))) + u .*= s + else + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(v))) + v .*= s + end + end + return (U, Vᴴ) +end + +function gaugefix!(::typeof(svd_compact!), U, Vᴴ, m::Int, n::Int) + for j in 1:size(U, 2) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return (U, Vᴴ) +end + +function gaugefix!(::typeof(svd_trunc!), U, Vᴴ, m::Int, n::Int) + for j in 1:min(m, n) + u = view(U, :, j) + v = view(Vᴴ, j, :) + s = conj(sign(_argmaxabs(u))) + u .*= s + v .*= conj(s) + end + return (U, Vᴴ) +end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index f56b6d32..e9ec1aaf 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -105,51 +105,6 @@ function initialize_output(::typeof(svd_vals!), A::Diagonal, ::DiagonalAlgorithm return eltype(A) <: Real ? diagview(A) : similar(A, real(eltype(A)), size(A, 1)) end -# Gauge fixation -# -------------- -function gaugefix!(::typeof(svd_full!), U, S, Vᴴ, m::Int, n::Int) - for j in 1:max(m, n) - if j <= min(m, n) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - elseif j <= m - u = view(U, :, j) - s = conj(sign(_argmaxabs(u))) - u .*= s - else - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(v))) - v .*= s - end - end - return (U, S, Vᴴ) -end - -function gaugefix!(::typeof(svd_compact!), U, S, Vᴴ, m::Int, n::Int) - for j in 1:size(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - end - return (U, S, Vᴴ) -end - -function gaugefix!(::typeof(svd_trunc!), U, S, Vᴴ, m::Int, n::Int) - for j in 1:min(m, n) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - end - return (U, S, Vᴴ) -end - # Implementation # -------------- function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) @@ -189,7 +144,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) S[i, 1] = zero(eltype(S)) end - dogaugefix && gaugefix!(svd_full!, U, S, Vᴴ, m, n) + dogaugefix && gaugefix!(svd_full!, U, Vᴴ, m, n) return USVᴴ end @@ -219,7 +174,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + dogaugefix && gaugefix!(svd_compact!, U, Vᴴ, size(A)...) return USVᴴ end @@ -397,7 +352,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diagview(S) .= view(S, 1:minmn, 1) view(S, 2:minmn, 1) .= zero(eltype(S)) - dogaugefix && gaugefix!(svd_full!, U, S, Vᴴ, m, n) + dogaugefix && gaugefix!(svd_full!, U, Vᴴ, m, n) return USVᴴ end @@ -436,7 +391,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + dogaugefix && gaugefix!(svd_compact!, U, Vᴴ, size(A)...) return USVᴴ end From 9a1ce27586d0310a533ed3a6ac1fa97c52e442dc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:11:55 -0500 Subject: [PATCH 08/18] remove unnecessary size arguments --- src/common/gauge.jl | 8 +++++--- src/implementations/svd.jl | 10 +++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/common/gauge.jl b/src/common/gauge.jl index 571d62e4..1cedd69a 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -17,7 +17,8 @@ function gaugefix!(V::AbstractMatrix) return V end -function gaugefix!(::typeof(svd_full!), U, Vᴴ, m::Int, n::Int) +function gaugefix!(::typeof(svd_full!), U, Vᴴ) + m, n = size(U, 1), size(Vᴴ, 2) for j in 1:max(m, n) if j <= min(m, n) u = view(U, :, j) @@ -38,7 +39,7 @@ function gaugefix!(::typeof(svd_full!), U, Vᴴ, m::Int, n::Int) return (U, Vᴴ) end -function gaugefix!(::typeof(svd_compact!), U, Vᴴ, m::Int, n::Int) +function gaugefix!(::typeof(svd_compact!), U, Vᴴ) for j in 1:size(U, 2) u = view(U, :, j) v = view(Vᴴ, j, :) @@ -49,7 +50,8 @@ function gaugefix!(::typeof(svd_compact!), U, Vᴴ, m::Int, n::Int) return (U, Vᴴ) end -function gaugefix!(::typeof(svd_trunc!), U, Vᴴ, m::Int, n::Int) +function gaugefix!(::typeof(svd_trunc!), U, Vᴴ) + m, n = size(U, 1), size(Vᴴ, 2) for j in 1:min(m, n) u = view(U, :, j) v = view(Vᴴ, j, :) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index e9ec1aaf..09d4377d 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -144,7 +144,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) S[i, 1] = zero(eltype(S)) end - dogaugefix && gaugefix!(svd_full!, U, Vᴴ, m, n) + dogaugefix && gaugefix!(svd_full!, U, Vᴴ) return USVᴴ end @@ -174,7 +174,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, Vᴴ, size(A)...) + dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) return USVᴴ end @@ -352,7 +352,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diagview(S) .= view(S, 1:minmn, 1) view(S, 2:minmn, 1) .= zero(eltype(S)) - dogaugefix && gaugefix!(svd_full!, U, Vᴴ, m, n) + dogaugefix && gaugefix!(svd_full!, U, Vᴴ) return USVᴴ end @@ -363,7 +363,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) dogaugefix = get(alg.alg.kwargs, :gaugefix, true)::Bool - dogaugefix && gaugefix!(svd_trunc!, U, S, Vᴴ, size(A)...) + dogaugefix && gaugefix!(svd_trunc!, U, Vᴴ) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -391,7 +391,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, Vᴴ, size(A)...) + dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) return USVᴴ end From a876e280d27e648b624ff32babe2169ace93d1bb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:21:32 -0500 Subject: [PATCH 09/18] bring eig gaugefix in the same shape --- src/common/gauge.jl | 3 ++- src/implementations/eig.jl | 4 ++-- src/implementations/eigh.jl | 4 ++-- src/implementations/gen_eig.jl | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/common/gauge.jl b/src/common/gauge.jl index 1cedd69a..68a94c79 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -8,7 +8,8 @@ This is achieved by ensuring that the entry with the largest magnitude in `V` or is real and positive. """ gaugefix! -function gaugefix!(V::AbstractMatrix) + +function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix) for j in axes(V, 2) v = view(V, :, j) s = conj(sign(_argmaxabs(v))) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 8bcbbeb7..33e90ab3 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -93,7 +93,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) YALAPACK.geevx!(A, D.diag, V; lapack_kwargs...) end - dogaugefix && (V = gaugefix!(V)) + dogaugefix && (V = gaugefix!(eig_full!, V)) return D, V end @@ -153,7 +153,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) _gpu_geev!(A, D.diag, V) end - dogaugefix && (V = gaugefix!(V)) + dogaugefix && (V = gaugefix!(eig_full!, V)) return D, V end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index e37e7503..81809bc2 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -105,7 +105,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) YALAPACK.heevx!(A, Dd, V; lapack_kwargs...) end - dogaugefix && (V = gaugefix!(V)) + dogaugefix && (V = gaugefix!(eigh_full!, V)) return D, V end @@ -183,7 +183,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) throw(ArgumentError("Unsupported eigh algorithm")) end - dogaugefix && (V = gaugefix!(V)) + dogaugefix && (V = gaugefix!(eigh_full!, V)) return D, V end diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index e796ecfa..3f332ed9 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -69,7 +69,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig throw(ArgumentError("LAPACK_Expert is not supported for ggev")) end - dogaugefix && (V = gaugefix!(V)) + dogaugefix && (V = gaugefix!(gen_eig_full!, V)) return W, V end From ad84593107836c346adc2772037309157fd4d6d1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:21:50 -0500 Subject: [PATCH 10/18] move include order to ensure things are defined --- src/MatrixAlgebraKit.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 4a846f85..d3a04741 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -76,7 +76,6 @@ include("common/safemethods.jl") include("common/view.jl") include("common/regularinv.jl") include("common/matrixproperties.jl") -include("common/gauge.jl") include("yalapack.jl") include("algorithms.jl") @@ -93,6 +92,8 @@ include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("common/gauge.jl") # needs to be defined after the functions are + include("implementations/projections.jl") include("implementations/truncation.jl") include("implementations/qr.jl") From ad98070ae1a1f98ccdf239a1716da7c98d034b9c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:26:37 -0500 Subject: [PATCH 11/18] update error messages in tests --- test/amd/svd.jl | 6 +++--- test/cuda/svd.jl | 6 +++--- test/gen_eig.jl | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/amd/svd.jl b/test/amd/svd.jl index 4f58f925..f8a48045 100644 --- a/test/amd/svd.jl +++ b/test/amd/svd.jl @@ -41,7 +41,7 @@ include(joinpath("..", "utilities.jl")) @test ROCArray(diagview(S)) ≈ Sd # ROCArray is necessary because norm of ROCArray view with non-unit step is broken if alg isa ROCSOLVER_QRIteration - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) end end end @@ -83,8 +83,8 @@ end @test ROCArray(diagview(S)) ≈ Sc # ROCArray is necessary because norm of ROCArray view with non-unit step is broken if alg isa ROCSOLVER_QRIteration - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), ROCSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, ROCSOLVER_QRIteration(; bad = "bad")) end end end diff --git a/test/cuda/svd.jl b/test/cuda/svd.jl index bfe993f8..e2173d03 100644 --- a/test/cuda/svd.jl +++ b/test/cuda/svd.jl @@ -41,7 +41,7 @@ include(joinpath("..", "utilities.jl")) @test CuArray(diagview(S)) ≈ Sd # CuArray is necessary because norm of CuArray view with non-unit step is broken if alg isa CUSOLVER_QRIteration - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_compact!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) end end end @@ -84,8 +84,8 @@ end @test CuArray(diagview(S)) ≈ Sc # CuArray is necessary because norm of CuArray view with non-unit step is broken if alg isa CUSOLVER_QRIteration - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) - @test_warn "GPU_QRIteration does not accept any keyword arguments" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_full!(copy!(Ac, A), (U, S, Vᴴ), CUSOLVER_QRIteration(; bad = "bad")) + @test_warn "invalid keyword arguments for GPU_QRIteration" svd_vals!(copy!(Ac, A), Sc, CUSOLVER_QRIteration(; bad = "bad")) end end end diff --git a/test/gen_eig.jl b/test/gen_eig.jl index eebfeee5..8f905b77 100644 --- a/test/gen_eig.jl +++ b/test/gen_eig.jl @@ -42,9 +42,9 @@ using LinearAlgebra: Diagonal A = randn(rng, T, m, m) B = randn(rng, T, m, m) @test_throws ArgumentError("LAPACK_Expert is not supported for ggev") gen_eig_full(A, B; alg = LAPACK_Expert()) - @test_throws ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments") gen_eig_full(A, B; alg = LAPACK_Simple(bad = "sad")) + @test_throws ArgumentError("invalid keyword arguments for LAPACK_Simple") gen_eig_full(A, B; alg = LAPACK_Simple(bad = "sad")) @test_throws ArgumentError("LAPACK_Expert is not supported for ggev") gen_eig_vals(A, B; alg = LAPACK_Expert()) - @test_throws ArgumentError("LAPACK_Simple (ggev) does not accept any keyword arguments") gen_eig_vals(A, B; alg = LAPACK_Simple(bad = "sad")) + @test_throws ArgumentError("invalid keyword arguments for LAPACK_Simple") gen_eig_vals(A, B; alg = LAPACK_Simple(bad = "sad")) # a tuple of the input types is passed to `default_algorithm` @test_throws MethodError MatrixAlgebraKit.default_algorithm(gen_eig_full, A, B) From 38085370ce052b92d534ba4ae8fe45c8c7327a3e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:42:18 -0500 Subject: [PATCH 12/18] update docs --- docs/src/user_interface/decompositions.md | 59 +++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index 1ae1fefe..a49294ef 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -72,6 +72,11 @@ eigh_trunc eigh_vals ``` +!!! note "Gauge Degrees of Freedom" + The eigenvectors returned by these functions have residual phase degrees of freedom. + By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. + See [Gauge choices](@ref sec_gaugefix) for more details. + Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: ```@autodocs; canonical=false @@ -89,6 +94,11 @@ eig_trunc eig_vals ``` +!!! note "Gauge Degrees of Freedom" + The eigenvectors returned by these functions have residual phase degrees of freedom. + By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. + See [Gauge choices](@ref sec_gaugefix) for more details. + Alongside these functions, we provide a LAPACK-based implementation for dense arrays, as provided by the following algorithms: ```@autodocs; canonical=false @@ -137,6 +147,11 @@ svd_vals svd_trunc ``` +!!! note "Gauge Degrees of Freedom" + The singular vectors returned by these functions have residual phase degrees of freedom. + By default, MatrixAlgebraKit applies a gauge fixing convention to ensure reproducible results. + See [Gauge choices](@ref sec_gaugefix) for more details. + MatrixAlgebraKit again ships with LAPACK-based implementations for dense arrays: ```@autodocs; canonical=false @@ -371,3 +386,47 @@ norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 && # output true ``` + +## [Gauge choices](@id sec_gaugefix) + +Both eigenvalue and singular value decompositions have residual gauge degrees of freedom even when the eigenvalues or singular values are unique. +These arise from the fact that even after normalization, the eigenvectors and singular vectors are only determined up to a phase factor. + +### Phase Ambiguity in Decompositions + +For the eigenvalue decomposition `A * V = V * D`, if `v` is an eigenvector with eigenvalue `λ` and `|v| = 1`, then so is `e^(iθ) * v` for any real phase `θ`. +When `λ` is non-degenerate (i.e., has multiplicity 1), the eigenvector is unique up to this phase. + +Similarly, for the singular value decomposition `A = U * Σ * Vᴴ`, the singular vectors `u` and `v` corresponding to a non-degenerate singular value `σ` are unique only up to a common phase. +We can replace `u → e^(iθ) * u` and `vᴴ → e^(-iθ) * vᴴ` simultaneously. + +### Gauge Fixing Convention + +To remove this phase ambiguity and ensure reproducible results, MatrixAlgebraKit implements a gauge fixing convention by default. +The convention ensures that **the entry with the largest magnitude in each eigenvector or left singular vector is real and positive**. + +For eigenvectors, this means that for each column `v` of `V`, we multiply by `conj(sign(v[i]))` where `i` is the index of the entry with largest absolute value. + +For singular vectors, we apply the phase factor to both `u` and `v` based on the entry with largest magnitude in `u`. +This preserves the decomposition `A = U * Σ * Vᴴ` while fixing the gauge. + +### Disabling Gauge Fixing + +Gauge fixing is enabled by default for all eigenvalue and singular value decompositions. +If you prefer to obtain the raw results from the underlying LAPACK routines without gauge fixing, you can disable it using the `gaugefix` keyword argument: + +```julia +# With gauge fixing (default) +D, V = eigh_full(A) + +# Without gauge fixing +D, V = eigh_full(A; gaugefix = false) +``` + +The same keyword is available for `eig_full`, `eig_trunc`, `svd_full`, `svd_compact`, and `svd_trunc` functions. + +```@docs; canonical=false +MatrixAlgebraKit.gaugefix! +``` + + From 4e41998afba6fc0773c4c05e8ec717052e85e9ec Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 12 Nov 2025 11:45:00 -0500 Subject: [PATCH 13/18] add gaugefix keyword to GenericLinearAlgebra extension --- ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl | 18 +++++++++++++----- src/interface/decompositions.jl | 4 +++- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 7b078381..1b327cff 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -1,7 +1,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit -using MatrixAlgebraKit: sign_safe, check_input, diagview +using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix! using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -13,18 +13,26 @@ for f! in (:svd_compact!, :svd_full!, :svd_vals!) @eval MatrixAlgebraKit.initialize_output(::typeof($f!), A::AbstractMatrix, ::GLA_QRIteration) = nothing end -function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration) +function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) F = svd!(A) U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt - return MatrixAlgebraKit.gaugefix!(svd_compact!, U, S, Vᴴ, size(A)...) + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) + + return U, S, Vᴴ end -function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, ::GLA_QRIteration) +function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIteration) F = svd!(A; full = true) U, Vᴴ = F.U, F.Vt S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) diagview(S) .= F.S - return MatrixAlgebraKit.gaugefix!(svd_full!, U, S, Vᴴ, size(A)...) + + dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + dogaugefix && gaugefix!(svd_full!, U, Vᴴ) + + return U, S, Vᴴ end function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, ::GLA_QRIteration) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 92937e00..df5d55e0 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -131,11 +131,13 @@ const LAPACK_EighAlgorithm = Union{ } """ - GLA_QRIteration() + GLA_QRIteration(; gaugefix::Bool = true) Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of a general matrix. +The `gaugefix` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). """ @algdef GLA_QRIteration From e22b2cf458bf31407d408d28bdca9bf364c7b6d8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 13 Nov 2025 10:21:02 -0500 Subject: [PATCH 14/18] rename `dogaugefix` to `do_gauge_fix` --- ...MatrixAlgebraKitGenericLinearAlgebraExt.jl | 8 ++++---- src/implementations/eig.jl | 8 ++++---- src/implementations/eigh.jl | 8 ++++---- src/implementations/gen_eig.jl | 4 ++-- src/implementations/svd.jl | 20 +++++++++---------- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 1b327cff..67f1072a 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -17,8 +17,8 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIte F = svd!(A) U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool - dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) return U, S, Vᴴ end @@ -29,8 +29,8 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIterat S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) diagview(S) .= F.S - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool - dogaugefix && gaugefix!(svd_full!, U, Vᴴ) + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) return U, S, Vᴴ end diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 33e90ab3..4479f97d 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -82,7 +82,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple @@ -93,7 +93,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) YALAPACK.geevx!(A, D.diag, V; lapack_kwargs...) end - dogaugefix && (V = gaugefix!(eig_full!, V)) + do_gauge_fix && (V = gaugefix!(eig_full!, V)) return D, V end @@ -145,7 +145,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Simple @@ -153,7 +153,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) _gpu_geev!(A, D.diag, V) end - dogaugefix && (V = gaugefix!(eig_full!, V)) + do_gauge_fix && (V = gaugefix!(eig_full!, V)) return D, V end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 81809bc2..fc56bd94 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -92,7 +92,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) D, V = DV Dd = D.diag - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_MultipleRelativelyRobustRepresentations @@ -105,7 +105,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) YALAPACK.heevx!(A, Dd, V; lapack_kwargs...) end - dogaugefix && (V = gaugefix!(eigh_full!, V)) + do_gauge_fix && (V = gaugefix!(eigh_full!, V)) return D, V end @@ -168,7 +168,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) D, V = DV Dd = D.diag - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Jacobi @@ -183,7 +183,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) throw(ArgumentError("Unsupported eigh algorithm")) end - dogaugefix && (V = gaugefix!(eigh_full!, V)) + do_gauge_fix && (V = gaugefix!(eigh_full!, V)) return D, V end diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 3f332ed9..745be187 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -58,7 +58,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig check_input(gen_eig_full!, A, B, WV, alg) W, V = WV - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple @@ -69,7 +69,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig throw(ArgumentError("LAPACK_Expert is not supported for ggev")) end - dogaugefix && (V = gaugefix!(gen_eig_full!, V)) + do_gauge_fix && (V = gaugefix!(gen_eig_full!, V)) return W, V end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 09d4377d..8f465a8f 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -120,7 +120,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) return USVᴴ end - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration @@ -144,7 +144,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) S[i, 1] = zero(eltype(S)) end - dogaugefix && gaugefix!(svd_full!, U, Vᴴ) + do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) return USVᴴ end @@ -153,7 +153,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration @@ -174,7 +174,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) + do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) return USVᴴ end @@ -336,7 +336,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration @@ -352,7 +352,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) diagview(S) .= view(S, 1:minmn, 1) view(S, 2:minmn, 1) .= zero(eltype(S)) - dogaugefix && gaugefix!(svd_full!, U, Vᴴ) + do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) return USVᴴ end @@ -362,8 +362,8 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran U, S, Vᴴ = USVᴴ _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) - dogaugefix = get(alg.alg.kwargs, :gaugefix, true)::Bool - dogaugefix && gaugefix!(svd_trunc!, U, Vᴴ) + do_gauge_fix = get(alg.alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix && gaugefix!(svd_trunc!, U, Vᴴ) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) @@ -377,7 +377,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ - dogaugefix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration @@ -391,7 +391,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) throw(ArgumentError("Unsupported SVD algorithm")) end - dogaugefix && gaugefix!(svd_compact!, U, Vᴴ) + do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) return USVᴴ end From c553fdd52768a87c2361a544ee33b6aa77fb0bc1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 13 Nov 2025 10:29:44 -0500 Subject: [PATCH 15/18] add global gaugefix toggle --- docs/src/user_interface/decompositions.md | 3 ++- ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl | 6 +++--- src/common/defaults.jl | 17 +++++++++++++++++ src/implementations/eig.jl | 4 ++-- src/implementations/eigh.jl | 4 ++-- src/implementations/gen_eig.jl | 2 +- src/implementations/svd.jl | 10 +++++----- 7 files changed, 32 insertions(+), 14 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index a49294ef..510c66e2 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -424,9 +424,10 @@ D, V = eigh_full(A; gaugefix = false) ``` The same keyword is available for `eig_full`, `eig_trunc`, `svd_full`, `svd_compact`, and `svd_trunc` functions. +Additionally, the default value can also be controlled with a global toggle using [`MatrixAlgebraKit.default_gaugefix`](@ref). ```@docs; canonical=false MatrixAlgebraKit.gaugefix! +MatrixAlgebraKit.default_gaugefix ``` - diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index 67f1072a..e09fccfa 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -1,7 +1,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit -using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix! +using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, default_gaugefix using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -17,7 +17,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIte F = svd!(A) U, S, Vᴴ = F.U, Diagonal(F.S), F.Vt - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool do_gauge_fix && gaugefix!(svd_compact!, U, Vᴴ) return U, S, Vᴴ @@ -29,7 +29,7 @@ function MatrixAlgebraKit.svd_full!(A::AbstractMatrix, USVᴴ, alg::GLA_QRIterat S = MatrixAlgebraKit.zero!(similar(F.S, (size(U, 2), size(Vᴴ, 1)))) diagview(S) .= F.S - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool do_gauge_fix && gaugefix!(svd_full!, U, Vᴴ) return U, S, Vᴴ diff --git a/src/common/defaults.jl b/src/common/defaults.jl index 332807da..9f48c941 100644 --- a/src/common/defaults.jl +++ b/src/common/defaults.jl @@ -41,3 +41,20 @@ default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4) Default tolerance for deciding to warn if the provided `A` is not hermitian. """ default_hermitian_tol(A) = eps(norm(A, Inf))^(3 / 4) + + +const DEFAULT_GAUGEFIX = Ref(true) + +@doc """ + default_gaugefix() -> current_value + default_gaugefix(new_value::Bool) -> previous_value + +Global toggle for enabling or disabling the default behavior of gauge fixing the output of the eigen- and singular value decompositions. +""" default_gaugefix + +default_gaugefix() = DEFAULT_GAUGEFIX[] +function default_gaugefix(new_value::Bool) + previous_value = DEFAULT_GAUGEFIX[] + DEFAULT_GAUGEFIX[] = new_value + return previous_value +end diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index 4479f97d..e8fda774 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -82,7 +82,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple @@ -145,7 +145,7 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) check_input(eig_full!, A, DV, alg) D, V = DV - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Simple diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index fc56bd94..0498ac86 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -92,7 +92,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) D, V = DV Dd = D.diag - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_MultipleRelativelyRobustRepresentations @@ -168,7 +168,7 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) D, V = DV Dd = D.diag - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Jacobi diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 745be187..473ed7b9 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -58,7 +58,7 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig check_input(gen_eig_full!, A, B, WV, alg) W, V = WV - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 8f465a8f..6cef8fd4 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -120,7 +120,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) return USVᴴ end - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration @@ -153,7 +153,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration @@ -336,7 +336,7 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) return USVᴴ end - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration @@ -362,7 +362,7 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran U, S, Vᴴ = USVᴴ _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) - do_gauge_fix = get(alg.alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool do_gauge_fix && gaugefix!(svd_trunc!, U, Vᴴ) # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong @@ -377,7 +377,7 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) check_input(svd_compact!, A, USVᴴ, alg) U, S, Vᴴ = USVᴴ - do_gauge_fix = get(alg.kwargs, :gaugefix, true)::Bool + do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration From e0956f16e092bc23565d0969a3540d9e823a35c8 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 13 Nov 2025 10:33:32 -0500 Subject: [PATCH 16/18] implement review comments --- docs/src/user_interface/decompositions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index 510c66e2..60959cc4 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -413,7 +413,7 @@ This preserves the decomposition `A = U * Σ * Vᴴ` while fixing the gauge. ### Disabling Gauge Fixing Gauge fixing is enabled by default for all eigenvalue and singular value decompositions. -If you prefer to obtain the raw results from the underlying LAPACK routines without gauge fixing, you can disable it using the `gaugefix` keyword argument: +If you prefer to obtain the raw results from the underlying computational routines without gauge fixing, you can disable it using the `gaugefix` keyword argument: ```julia # With gauge fixing (default) From e663374d903eb15f580f04b6406bc432f53b0dbb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 13 Nov 2025 10:54:42 -0500 Subject: [PATCH 17/18] more careful about gaugefixing svd_trunc --- src/common/gauge.jl | 43 ++++++++++++++------------------------ src/implementations/svd.jl | 17 ++++++++------- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/src/common/gauge.jl b/src/common/gauge.jl index 68a94c79..e855548f 100644 --- a/src/common/gauge.jl +++ b/src/common/gauge.jl @@ -12,53 +12,42 @@ is real and positive. function gaugefix!(::Union{typeof(eig_full!), typeof(eigh_full!), typeof(gen_eig_full!)}, V::AbstractMatrix) for j in axes(V, 2) v = view(V, :, j) - s = conj(sign(_argmaxabs(v))) - @inbounds v .*= s + s = sign(_argmaxabs(v)) + @inbounds v .*= conj(s) end return V end function gaugefix!(::typeof(svd_full!), U, Vᴴ) - m, n = size(U, 1), size(Vᴴ, 2) + m, n = size(U, 2), size(Vᴴ, 1) for j in 1:max(m, n) if j <= min(m, n) u = view(U, :, j) v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) + s = sign(_argmaxabs(u)) + u .*= conj(s) + v .*= s elseif j <= m u = view(U, :, j) - s = conj(sign(_argmaxabs(u))) - u .*= s + s = sign(_argmaxabs(u)) + u .*= conj(s) else v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(v))) - v .*= s + s = sign(_argmaxabs(v)) + v .*= conj(s) end end return (U, Vᴴ) end -function gaugefix!(::typeof(svd_compact!), U, Vᴴ) - for j in 1:size(U, 2) - u = view(U, :, j) - v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) - end - return (U, Vᴴ) -end - -function gaugefix!(::typeof(svd_trunc!), U, Vᴴ) - m, n = size(U, 1), size(Vᴴ, 2) - for j in 1:min(m, n) +function gaugefix!(::Union{typeof(svd_compact!), typeof(svd_trunc!)}, U, Vᴴ) + @assert axes(U, 2) == axes(Vᴴ, 1) + for j in axes(U, 2) u = view(U, :, j) v = view(Vᴴ, j, :) - s = conj(sign(_argmaxabs(u))) - u .*= s - v .*= conj(s) + s = sign(_argmaxabs(u)) + u .*= conj(s) + v .*= s end return (U, Vᴴ) end diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 6cef8fd4..7bb40f8f 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -1,4 +1,4 @@ -# Inputs +# Input # ------ copy_input(::typeof(svd_full), A::AbstractMatrix) = copy!(similar(A, float(eltype(A))), A) copy_input(::typeof(svd_compact), A) = copy_input(svd_full, A) @@ -362,15 +362,16 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran U, S, Vᴴ = USVᴴ _gpu_Xgesvdr!(A, S.diag, U, Vᴴ; alg.alg.kwargs...) - do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool - do_gauge_fix && gaugefix!(svd_trunc!, U, Vᴴ) - # TODO: make sure that truncation is based on maxrank, otherwise this might be wrong - USVᴴtrunc, ind = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) - Strunc = diagview(USVᴴtrunc[2]) + (Utr, Str, Vᴴtr), _ = truncate(svd_trunc!, (U, S, Vᴴ), alg.trunc) + # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum - ϵ = sqrt(norm(A)^2 - norm(Strunc)^2) # is there a more accurate way to do this? - return USVᴴtrunc..., ϵ + ϵ = sqrt(norm(A)^2 - norm(diagview(Str))^2) # is there a more accurate way to do this? + + do_gauge_fix = get(alg.alg.kwargs, :gaugefix, default_gaugefix())::Bool + do_gauge_fix && gaugefix!(svd_trunc!, Utr, Vᴴtr) + + return Utr, Str, Vᴴtr, ϵ end function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) From d6b2dd349d2ab9471788198d8de6aa8fb22818fc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 14 Nov 2025 08:42:30 -0500 Subject: [PATCH 18/18] rename `lapack_kwargs` to `alg_kwargs` --- src/implementations/eig.jl | 20 +++++++------- src/implementations/eigh.jl | 40 +++++++++++++-------------- src/implementations/gen_eig.jl | 8 +++--- src/implementations/svd.jl | 50 +++++++++++++++++----------------- 4 files changed, 59 insertions(+), 59 deletions(-) diff --git a/src/implementations/eig.jl b/src/implementations/eig.jl index e8fda774..bdb6981a 100644 --- a/src/implementations/eig.jl +++ b/src/implementations/eig.jl @@ -83,14 +83,14 @@ function eig_full!(A::AbstractMatrix, DV, alg::LAPACK_EigAlgorithm) D, V = DV do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.geev!(A, D.diag, V) else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D.diag, V; lapack_kwargs...) + YALAPACK.geevx!(A, D.diag, V; alg_kwargs...) end do_gauge_fix && (V = gaugefix!(eig_full!, V)) @@ -102,14 +102,14 @@ function eig_vals!(A::AbstractMatrix, D, alg::LAPACK_EigAlgorithm) check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.geev!(A, D, V) else # alg isa LAPACK_Expert - YALAPACK.geevx!(A, D, V; lapack_kwargs...) + YALAPACK.geevx!(A, D, V; alg_kwargs...) end return D @@ -146,10 +146,10 @@ function eig_full!(A::AbstractMatrix, DV, alg::GPU_EigAlgorithm) D, V = DV do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Simple - isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_Simple" + isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" _gpu_geev!(A, D.diag, V) end @@ -162,10 +162,10 @@ function eig_vals!(A::AbstractMatrix, D, alg::GPU_EigAlgorithm) check_input(eig_vals!, A, D, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Simple - isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_Simple" + isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_Simple" _gpu_geev!(A, D, V) end diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index 0498ac86..d4a67dd6 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -93,16 +93,16 @@ function eigh_full!(A::AbstractMatrix, DV, alg::LAPACK_EighAlgorithm) Dd = D.diag do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, Dd, V; lapack_kwargs...) + YALAPACK.heevr!(A, Dd, V; alg_kwargs...) elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, Dd, V; lapack_kwargs...) + YALAPACK.heevd!(A, Dd, V; alg_kwargs...) elseif alg isa LAPACK_Simple - YALAPACK.heev!(A, Dd, V; lapack_kwargs...) + YALAPACK.heev!(A, Dd, V; alg_kwargs...) else # alg isa LAPACK_Expert - YALAPACK.heevx!(A, Dd, V; lapack_kwargs...) + YALAPACK.heevx!(A, Dd, V; alg_kwargs...) end do_gauge_fix && (V = gaugefix!(eigh_full!, V)) @@ -114,16 +114,16 @@ function eigh_vals!(A::AbstractMatrix, D, alg::LAPACK_EighAlgorithm) check_input(eigh_vals!, A, D, alg) V = similar(A, (size(A, 1), 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_MultipleRelativelyRobustRepresentations - YALAPACK.heevr!(A, D, V; lapack_kwargs...) + YALAPACK.heevr!(A, D, V; alg_kwargs...) elseif alg isa LAPACK_DivideAndConquer - YALAPACK.heevd!(A, D, V; lapack_kwargs...) + YALAPACK.heevd!(A, D, V; alg_kwargs...) elseif alg isa LAPACK_QRIteration # == LAPACK_Simple - YALAPACK.heev!(A, D, V; lapack_kwargs...) + YALAPACK.heev!(A, D, V; alg_kwargs...) else # alg isa LAPACK_Bisection == LAPACK_Expert - YALAPACK.heevx!(A, D, V; lapack_kwargs...) + YALAPACK.heevx!(A, D, V; alg_kwargs...) end return D @@ -169,16 +169,16 @@ function eigh_full!(A::AbstractMatrix, DV, alg::GPU_EighAlgorithm) Dd = D.diag do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Jacobi - _gpu_heevj!(A, Dd, V; lapack_kwargs...) + _gpu_heevj!(A, Dd, V; alg_kwargs...) elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, Dd, V; lapack_kwargs...) + _gpu_heevd!(A, Dd, V; alg_kwargs...) elseif alg isa GPU_QRIteration # alg isa GPU_QRIteration == GPU_Simple - _gpu_heev!(A, Dd, V; lapack_kwargs...) + _gpu_heev!(A, Dd, V; alg_kwargs...) elseif alg isa GPU_Bisection # alg isa GPU_Bisection == GPU_Expert - _gpu_heevx!(A, Dd, V; lapack_kwargs...) + _gpu_heevx!(A, Dd, V; alg_kwargs...) else throw(ArgumentError("Unsupported eigh algorithm")) end @@ -192,16 +192,16 @@ function eigh_vals!(A::AbstractMatrix, D, alg::GPU_EighAlgorithm) check_input(eigh_vals!, A, D, alg) V = similar(A, (size(A, 1), 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_Jacobi - _gpu_heevj!(A, D, V; lapack_kwargs...) + _gpu_heevj!(A, D, V; alg_kwargs...) elseif alg isa GPU_DivideAndConquer - _gpu_heevd!(A, D, V; lapack_kwargs...) + _gpu_heevd!(A, D, V; alg_kwargs...) elseif alg isa GPU_QRIteration - _gpu_heev!(A, D, V; lapack_kwargs...) + _gpu_heev!(A, D, V; alg_kwargs...) elseif alg isa GPU_Bisection - _gpu_heevx!(A, D, V; lapack_kwargs...) + _gpu_heevx!(A, D, V; alg_kwargs...) else throw(ArgumentError("Unsupported eigh algorithm")) end diff --git a/src/implementations/gen_eig.jl b/src/implementations/gen_eig.jl index 473ed7b9..94dfe47e 100644 --- a/src/implementations/gen_eig.jl +++ b/src/implementations/gen_eig.jl @@ -59,10 +59,10 @@ function gen_eig_full!(A::AbstractMatrix, B::AbstractMatrix, WV, alg::LAPACK_Eig W, V = WV do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.ggev!(A, B, W.diag, V, similar(W.diag, eltype(A))) else # alg isa LAPACK_Expert @@ -78,10 +78,10 @@ function gen_eig_vals!(A::AbstractMatrix, B::AbstractMatrix, W, alg::LAPACK_EigA check_input(gen_eig_vals!, A, B, W, alg) V = similar(A, complex(eltype(A)), (size(A, 1), 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_Simple - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Simple")) YALAPACK.ggev!(A, B, W, V, similar(W, eltype(A))) else # alg isa LAPACK_Expert diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 7bb40f8f..2066b2cd 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -121,14 +121,14 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) end do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa LAPACK_Bisection @@ -154,20 +154,20 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::LAPACK_SVDAlgorithm) U, S, Vᴴ = USVᴴ do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, S.diag, U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, S.diag, U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; lapack_kwargs...) + YALAPACK.gesvdx!(A, S.diag, U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) YALAPACK.gesvj!(A, S.diag, U, Vᴴ) else @@ -183,20 +183,20 @@ function svd_vals!(A::AbstractMatrix, S, alg::LAPACK_SVDAlgorithm) check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa LAPACK_QRIteration - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_QRIteration")) YALAPACK.gesvd!(A, S, U, Vᴴ) elseif alg isa LAPACK_DivideAndConquer - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_DivideAndConquer")) YALAPACK.gesdd!(A, S, U, Vᴴ) elseif alg isa LAPACK_Bisection - YALAPACK.gesvdx!(A, S, U, Vᴴ; lapack_kwargs...) + YALAPACK.gesvdx!(A, S, U, Vᴴ; alg_kwargs...) elseif alg isa LAPACK_Jacobi - isempty(lapack_kwargs) || + isempty(alg_kwargs) || throw(ArgumentError("invalid keyword arguments for LAPACK_Jacobi")) YALAPACK.gesvj!(A, S, U, Vᴴ) else @@ -337,15 +337,15 @@ function svd_full!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) end do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration - isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" + isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, view(S, 1:minmn, 1), U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; lapack_kwargs...) + _gpu_Xgesvdp!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; lapack_kwargs...) + _gpu_gesvdj!(A, view(S, 1:minmn, 1), U, Vᴴ; alg_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -379,15 +379,15 @@ function svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAlgorithm) U, S, Vᴴ = USVᴴ do_gauge_fix = get(alg.kwargs, :gaugefix, default_gaugefix())::Bool - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration - isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" + isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, S.diag, U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; lapack_kwargs...) + _gpu_Xgesvdp!(A, S.diag, U, Vᴴ; alg_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S.diag, U, Vᴴ; lapack_kwargs...) + _gpu_gesvdj!(A, S.diag, U, Vᴴ; alg_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end @@ -403,15 +403,15 @@ function svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm) check_input(svd_vals!, A, S, alg) U, Vᴴ = similar(A, (0, 0)), similar(A, (0, 0)) - lapack_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) + alg_kwargs = Base.structdiff(alg.kwargs, NamedTuple{(:gaugefix,)}) if alg isa GPU_QRIteration - isempty(lapack_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" + isempty(alg_kwargs) || @warn "invalid keyword arguments for GPU_QRIteration" _gpu_gesvd_maybe_transpose!(A, S, U, Vᴴ) elseif alg isa GPU_SVDPolar - _gpu_Xgesvdp!(A, S, U, Vᴴ; lapack_kwargs...) + _gpu_Xgesvdp!(A, S, U, Vᴴ; alg_kwargs...) elseif alg isa GPU_Jacobi - _gpu_gesvdj!(A, S, U, Vᴴ; lapack_kwargs...) + _gpu_gesvdj!(A, S, U, Vᴴ; alg_kwargs...) else throw(ArgumentError("Unsupported SVD algorithm")) end