From 6d43f15f558c7ce7fb9908ed78c70275a477ebd9 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 10 Oct 2025 13:23:07 -0400 Subject: [PATCH 01/47] `left_orth` and `right_orth` with `@functiondef` --- src/interface/orthnull.jl | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index c27179f8..00640acc 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -49,14 +49,7 @@ be used. See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ -function left_orth end -function left_orth! end -function left_orth!(A; kwargs...) - return left_orth!(A, initialize_output(left_orth!, A); kwargs...) -end -function left_orth(A; kwargs...) - return left_orth!(copy_input(left_orth, A); kwargs...) -end +@functiondef left_orth """ right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ @@ -108,14 +101,7 @@ be used. See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ -function right_orth end -function right_orth! end -function right_orth!(A; kwargs...) - return right_orth!(A, initialize_output(right_orth!, A); kwargs...) -end -function right_orth(A; kwargs...) - return right_orth!(copy_input(right_orth, A); kwargs...) -end +@functiondef right_orth # Null functions # -------------- From 9c558d929c9527857466b2a3ef369d606508f434 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 19:16:34 -0400 Subject: [PATCH 02/47] refactor truncationintersection for type stability --- src/implementations/truncation.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f6201b07..df327c8b 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -95,16 +95,20 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real = end function findtruncated(values::AbstractVector, strategy::TruncationIntersection) - return mapreduce( - Base.Fix1(findtruncated, values), _ind_intersect, strategy.components; - init = trues(length(values)) - ) + length(strategy.components) == 0 && return eachindex(values) + length(strategy.components) == 1 && return findtruncated(values, only(strategy.components)) + + ind1 = findtruncated(values, strategy.components[1]) + ind2 = findtruncated(values, TruncationIntersection(Base.tail(strategy.components))) + return _ind_intersect(ind1, ind2) end function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection) - return mapreduce( - Base.Fix1(findtruncated_svd, values), _ind_intersect, - strategy.components; init = trues(length(values)) - ) + length(strategy.components) == 0 && return eachindex(values) + length(strategy.components) == 1 && return findtruncated_svd(values, only(strategy.components)) + + ind1 = findtruncated_svd(values, strategy.components[1]) + ind2 = findtruncated_svd(values, TruncationIntersection(Base.tail(strategy.components))) + return _ind_intersect(ind1, ind2) end # when one of the ind selections is a bitvector, have to handle differently From 651dca3bbf63b3f96e63f69bc7c3c8f0beb6e5c2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 19:17:38 -0400 Subject: [PATCH 03/47] factor out `linearmap.jl` --- test/linearmap.jl | 7 ------- test/orthnull.jl | 7 ++++--- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/test/linearmap.jl b/test/linearmap.jl index 61753cde..bbf84e5e 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -34,13 +34,6 @@ module LinearMaps LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) end - for f! in (:left_orth!, :right_orth!) - @eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg) = - MAK.check_input($f!, parent(A), parent.(F), alg) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap) = - LinearMap.(MAK.initialize_output($f!, parent(A))) - end - for f in (:qr, :lq, :svd) default_f = Symbol(:default_, f, :_algorithm) @eval MAK.$default_f(::Type{LinearMap{A}}; kwargs...) where {A} = MAK.$default_f(A; kwargs...) diff --git a/test/orthnull.jl b/test/orthnull.jl index cee54252..8d690375 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -2,9 +2,10 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I, mul! -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output, AbstractAlgorithm +using LinearAlgebra: LinearAlgebra, I + +# testing non-AbstractArray codepaths: +include("linearmap.jl") # testing non-AbstractArray codepaths: include("linearmap.jl") From d4da1ec2409df63af5443ed476a90ef4601cfd39 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 19:19:09 -0400 Subject: [PATCH 04/47] refactor orth algorithm selection --- src/interface/orthnull.jl | 64 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 00640acc..2851f1b2 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -51,6 +51,11 @@ See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [ """ @functiondef left_orth +# helper functions +function left_orth_qr! end +function left_orth_polar! end +function left_orth_svd! end + """ right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ right_orth!(A, [CVᴴ]; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ @@ -103,6 +108,11 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`r """ @functiondef right_orth +# helper functions +function right_orth_lq! end +function right_orth_polar! end +function right_orth_svd! end + # Null functions # -------------- """ @@ -204,3 +214,57 @@ end function right_null(A; kwargs...) return right_null!(copy_input(right_null, A); kwargs...) end + +# Algorithm selection +# ------------------- +# specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. +function select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + left_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :qr && return select_algorithm(left_orth_qr!, A, get(kwargs, :alg_qr, nothing)) + alg === :polar && return select_algorithm(left_orth_polar!, A, get(kwargs, :alg_polar, nothing)) + + throw(ArgumentError(lazy"Unknown alg symbol $alg")) +end + +default_algorithm(::typeof(left_orth!), A; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : + select_algorithm(left_orth_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(left_orth_qr!), A, alg = nothing; kwargs...) = + select_algorithm(qr_compact!, A, alg; kwargs...) +select_algorithm(::typeof(left_orth_polar!), A, alg = nothing; kwargs...) = + select_algorithm(left_polar!, A, alg; kwargs...) +select_algorithm(::typeof(left_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + +# specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. +function select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + right_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :lq && return select_algorithm(right_orth_lq!, A, get(kwargs, :alg_lq, nothing)) + alg === :polar && return select_algorithm(right_orth_polar!, A, get(kwargs, :alg_polar, nothing)) + + throw(ArgumentError(lazy"Unknown alg symbol $alg")) +end + +default_algorithm(::typeof(right_orth!), A; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : + select_algorithm(right_orth_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(right_orth_lq!), A, alg = nothing; kwargs...) = + select_algorithm(lq_compact!, A, alg; kwargs...) +select_algorithm(::typeof(right_orth_polar!), A, alg = nothing; kwargs...) = + select_algorithm(right_polar!, A, alg; kwargs...) +select_algorithm(::typeof(right_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) From 136cf62b35cea3060af812ac1b2123ffa89d6248 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 19:43:40 -0400 Subject: [PATCH 05/47] add algorithm traits --- src/algorithms.jl | 55 +++++++++++++++++++++++++++++++++ src/interface/decompositions.jl | 24 ++++++++++---- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index a47be2c8..c4cd64e1 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -53,6 +53,57 @@ function _show_alg(io::IO, alg::Algorithm) return print(io, ")") end +# Algorithm traits +# ---------------- +""" + left_orth_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `left_orth!(A, alg)`. +By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`, but +this can be extended to insert arbitrary other decomposition functions, which should follow +the signature `f!(A, F, alg) -> F` +""" +left_orth_kind(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.left_orth_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`. + """ +) + +""" + right_orth_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `right_orth!(A, alg)`. +By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`, but +this can be extended to insert arbitrary other decomposition functions, which should follow +the signature `f!(A, F, alg) -> F` +""" +right_orth_kind(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.right_orth_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`. + """ +) + +""" + does_truncate(alg::AbstractAlgorithm) -> Bool + +Check whether or not an algorithm can be used for a truncated decomposition. +""" +does_truncate(alg::AbstractAlgorithm) = false + +# Algorithm selection +# ------------------- @doc """ MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm) MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...) @@ -200,6 +251,10 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm trunc::T end +left_orth_kind(alg::TruncatedAlgorithm) = left_orth_kind(alg.alg) +right_orth_kind(alg::TruncatedAlgorithm) = right_orth_kind(alg.alg) +does_truncate(::TruncatedAlgorithm) = true + # Utility macros # -------------- diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index bdda6612..ef6af6c7 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -36,6 +36,9 @@ elements of `L` are non-negative. @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ +left_orth_kind(::Union{LAPACK_HouseholderQR, LAPACK_HouseholderQL}) = left_orth_qr! +right_orth_kind(::Union{LAPACK_HouseholderLQ, LAPACK_HouseholderRQ}) = right_orth_lq! + # General Eigenvalue Decomposition # ------------------------------- """ @@ -117,6 +120,9 @@ const LAPACK_SVDAlgorithm = Union{ LAPACK_Jacobi, } +left_orth_kind(::LAPACK_SVDAlgorithm) = left_orth_svd! +right_orth_kind(::LAPACK_SVDAlgorithm) = right_orth_svd! + # ========================= # Polar decompositions # ========================= @@ -139,6 +145,9 @@ until convergence up to tolerance `tol`. """ @algdef PolarNewton +left_orth_kind(::Union{PolarViaSVD, PolarNewton}) = left_orth_polar! +right_orth_kind(::Union{PolarViaSVD, PolarNewton}) = right_orth_polar! + # ========================= # DIAGONAL ALGORITHMS # ========================= @@ -162,6 +171,8 @@ the diagonal elements of `R` are non-negative. """ @algdef CUSOLVER_HouseholderQR +left_orth_kind(::CUSOLVER_HouseholderQR) = left_orth_qr! + """ CUSOLVER_QRIteration() @@ -203,6 +214,8 @@ for more information. """ @algdef CUSOLVER_Randomized +does_truncate(::CUSOLVER_Randomized) = true + """ CUSOLVER_Simple() @@ -276,9 +289,8 @@ const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} const GPU_Bisection = Union{ROCSOLVER_Bisection} -const GPU_EighAlgorithm = Union{ - GPU_QRIteration, - GPU_Jacobi, - GPU_DivideAndConquer, - GPU_Bisection, -} +const GPU_SVDAlgorithm = Union{GPU_Jacobi, GPU_DivideAndConquer, GPU_QRIteration} +const GPU_EighAlgorithm = Union{GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, CUSOLVER_SVDPolar} + +left_orth_kind(::GPU_SVDAlgorithm) = left_orth_svd! +right_orth_kind(::GPU_SVDAlgorithm) = right_orth_svd! From 545ef5d76dd2f13c83bde94ba77f7101aa742199 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 19:43:56 -0400 Subject: [PATCH 06/47] refactor left_orth and right_orth implementations --- src/implementations/orthnull.jl | 230 ++++++++++---------------------- 1 file changed, 71 insertions(+), 159 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 76e32ce1..ea57d916 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -5,32 +5,48 @@ copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever nee copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else -function check_input(::typeof(left_orth!), A::AbstractMatrix, VC, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - V, C = VC - @assert V isa AbstractMatrix && C isa AbstractMatrix - @check_size(V, (m, minmn)) - @check_scalar(V, A) - if !isempty(C) - @check_size(C, (minmn, n)) - @check_scalar(C, A) - end - return nothing -end -function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - C, Vᴴ = CVᴴ - @assert C isa AbstractMatrix && Vᴴ isa AbstractMatrix - if !isempty(C) - @check_size(C, (m, minmn)) - @check_scalar(C, A) - end - @check_size(Vᴴ, (minmn, n)) - @check_scalar(Vᴴ, A) - return nothing -end +check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = + check_input(left_orth_kind(alg), A, VC, alg) + +check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(right_orth_kind(alg), A, CVᴴ, alg) + + +check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = + check_input(qr_compact!, A, VC, alg) +check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = + check_input(left_polar!, A, VC, alg) +check_input(::typeof(left_orth_svd!), A, VC, alg::AbstractAlgorithm) = + check_input(qr_compact!, A, VC, alg) + +check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(lq_compact!, A, CVᴴ, alg) +check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(right_polar!, A, CVᴴ, alg) +check_input(::typeof(right_orth_svd!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(lq_compact!, A, CVᴴ, alg) + + +initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = + initialize_output(left_orth_kind(alg), A, alg) +initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = + initialize_output(right_orth_kind(alg), A, alg) + + +initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = + initialize_output(qr_compact!, A, alg) +initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = + initialize_output(left_polar!, A, alg) +initialize_output(::typeof(left_orth_svd!), A, alg::AbstractAlgorithm) = + initialize_output(qr_compact!, A, alg) + +initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = + initialize_output(lq_compact!, A, alg) +initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = + initialize_output(right_polar!, A, alg) +initialize_output(::typeof(right_orth_svd!), A, alg::AbstractAlgorithm) = + initialize_output(lq_compact!, A, alg) + function check_input(::typeof(left_null!), A::AbstractMatrix, N, ::AbstractAlgorithm) m, n = size(A) @@ -51,20 +67,13 @@ end # Outputs # ------- -function initialize_output(::typeof(left_orth!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - V = similar(A, (m, minmn)) - C = similar(A, (minmn, n)) - return (V, C) -end -function initialize_output(::typeof(right_orth!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - C = similar(A, (m, minmn)) - Vᴴ = similar(A, (minmn, n)) - return (C, Vᴴ) + +function initialize_orth_svd(A::AbstractMatrix, F, alg) + S = Diagonal(initialize_output(svd_vals!, A, alg)) + return F[1], S, F[2] end +# fallback doesn't re-use F at all +initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) function initialize_output(::typeof(left_null!), A::AbstractMatrix) m, n = size(A) @@ -81,126 +90,28 @@ end # Implementation of orth functions # -------------------------------- -function left_orth!( - A, VC; - trunc = nothing, kind = isnothing(trunc) ? :qr : :svd, - alg_qr = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) - end - if kind == :qr - return left_orth_qr!(A, VC, alg_qr) - elseif kind == :polar - return left_orth_polar!(A, VC, alg_polar) - elseif kind == :svd - return left_orth_svd!(A, VC, alg_svd, trunc) - else - throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) - end -end -function left_orth_qr!(A, VC, alg) - alg′ = select_algorithm(qr_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - return qr_compact!(A, VC, alg′) -end -function left_orth_polar!(A, VC, alg) - alg′ = select_algorithm(left_polar!, A, alg) - check_input(left_orth!, A, VC, alg′) - return left_polar!(A, VC, alg′) -end -function left_orth_svd!(A, VC, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - U, S, Vᴴ = svd_compact!(A, alg′) - V, C = VC - return copy!(V, U), mul!(C, S, Vᴴ) -end -function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg′)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′) - return U, lmul!(S, Vᴴ) -end -function left_orth_svd!(A, VC, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ = svd_trunc!(A, alg_trunc) - V, C = VC - return copy!(V, U), mul!(C, S, Vᴴ) -end -function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) - return U, lmul!(S, Vᴴ) -end +left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth_kind(alg)(A, VC, alg) +right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) -function right_orth!( - A, CVᴴ; - trunc = nothing, kind = isnothing(trunc) ? :lq : :svd, - alg_lq = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) - end - if kind == :lq - return right_orth_lq!(A, CVᴴ, alg_lq) - elseif kind == :polar - return right_orth_polar!(A, CVᴴ, alg_polar) - elseif kind == :svd - return right_orth_svd!(A, CVᴴ, alg_svd, trunc) - else - throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) - end -end -function right_orth_lq!(A, CVᴴ, alg) - alg′ = select_algorithm(lq_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - return lq_compact!(A, CVᴴ, alg′) -end -function right_orth_polar!(A, CVᴴ, alg) - alg′ = select_algorithm(right_polar!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - return right_polar!(A, CVᴴ, alg′) -end -function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - U, S, Vᴴ′ = svd_compact!(A, alg′) - C, Vᴴ = CVᴴ - return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) -end -function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg′)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′) - return rmul!(U, S), Vᴴ -end -function right_orth_svd!(A, CVᴴ, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) - C, Vᴴ = CVᴴ - return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) -end -function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) - return rmul!(U, S), Vᴴ +left_orth_qr!(A, VC, alg::AbstractAlgorithm) = qr_compact!(A, VC, alg) +right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) +left_orth_polar!(A, VC, alg::AbstractAlgorithm) = left_polar!(A, VC, alg) +right_orth_polar!(A, CVᴴ, alg::AbstractAlgorithm) = right_polar!(A, CVᴴ, alg) + +# orth_svd requires implementations of `lmul!` and `rmul!` +function left_orth_svd!(A, VC, alg::AbstractAlgorithm) + check_input(left_orth_svd!, A, VC, alg) + USVᴴ = initialize_orth_svd(A, VC, alg) + V, S, C = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) + lmul!(S, C) + return V, C +end +function right_orth_svd!(A, CVᴴ, alg::AbstractAlgorithm) + check_input(right_orth_svd!, A, CVᴴ, alg) + USVᴴ = initialize_orth_svd(A, CVᴴ, alg) + C, S, Vᴴ = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) + rmul!(C, S) + return C, Vᴴ end # Implementation of null functions @@ -249,7 +160,8 @@ function left_null_svd!(A, N, alg, trunc) trunc′ = trunc isa TruncationStrategy ? trunc : trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : throw(ArgumentError("Unknown truncation strategy: $trunc")) - return first(truncate(left_null!, (U, S), trunc′)) + N, ind = truncate(left_null!, (U, S), trunc′) + return N end function right_null!( From 1d0684028619b8704535efc2b39adc5bb9aa8564 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 20:21:02 -0400 Subject: [PATCH 07/47] refactor null algorithm selection --- src/algorithms.jl | 18 +++++++++++ src/interface/orthnull.jl | 60 +++++++++++++++++++++++++++++++++++++ src/interface/truncation.jl | 10 +++++++ 3 files changed, 88 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index c4cd64e1..ed42988a 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -211,6 +211,24 @@ function select_truncation(trunc) end end +@doc """ + MatrixAlgebraKit.select_null_truncation(trunc) + +Construct a [`TruncationStrategy`](@ref) from the given `NamedTuple` of keywords or input strategy, to implement a nullspace selection. +""" select_null_truncation + +function select_null_truncation(trunc) + if isnothing(trunc) + return NoTruncation() + elseif trunc isa NamedTuple + return null_truncation_strategy(; trunc...) + elseif trunc isa TruncationStrategy + return trunc + else + return throw(ArgumentError("Unknown truncation strategy: $trunc")) + end +end + @doc """ MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 2851f1b2..a459d512 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -268,3 +268,63 @@ select_algorithm(::typeof(right_orth_polar!), A, alg = nothing; kwargs...) = select_algorithm(::typeof(right_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + +function select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + left_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :qr && return select_algorithm(left_null_qr!, A, get(kwargs, :alg_qr, nothing)) + + throw(ArgumentError(lazy"unkown alg symbol $alg")) +end + +default_algorithm(::typeof(left_null!), A; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : + select_algorithm(left_null_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(left_null_qr!), A, alg = nothing; kwargs...) = + select_algorithm(qr_null!, A, alg; kwargs...) +function select_algorithm(::typeof(left_null_svd!), A, alg = nothing; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + end +end + +function select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + right_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :lq && return select_algorithm(right_null_lq!, A, get(kwargs, :alg_lq, nothing)) + + throw(ArgumentError(lazy"unkown alg symbol $alg")) +end + +default_algorithm(::typeof(right_null!), A; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : + select_algorithm(right_null_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(right_null_lq!), A, alg = nothing; kwargs...) = + select_algorithm(lq_null!, A, alg; kwargs...) + +function select_algorithm(::typeof(right_null_svd!), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + end + +end diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index db417edb..3f6ba313 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -36,6 +36,16 @@ function TruncationStrategy(; return strategy end +function null_truncation_strategy(; atol = nothing, rtol = nothing, maxnullity = nothing) + if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) + return notrunc() + end + atol = @something atol 0 + rtol = @something rtol 0 + trunc = trunctol(; atol, rtol, keep_below = true) + return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev = false) : trunc +end + """ NoTruncation() From b68d8c80d372f1cac6578178e8a92406b3ffff48 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 20:21:11 -0400 Subject: [PATCH 08/47] refactor null implementation --- src/implementations/orthnull.jl | 172 +++++++++----------------------- 1 file changed, 47 insertions(+), 125 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index ea57d916..5172d815 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -1,5 +1,5 @@ -# Inputs -# ------ +# Orthogonalization +# ----------------- copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else @@ -7,11 +7,6 @@ copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need a check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = check_input(left_orth_kind(alg), A, VC, alg) - -check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = - check_input(right_orth_kind(alg), A, CVᴴ, alg) - - check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = check_input(qr_compact!, A, VC, alg) check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = @@ -19,6 +14,8 @@ check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = check_input(::typeof(left_orth_svd!), A, VC, alg::AbstractAlgorithm) = check_input(qr_compact!, A, VC, alg) +check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(right_orth_kind(alg), A, CVᴴ, alg) check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(lq_compact!, A, CVᴴ, alg) check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = @@ -26,13 +23,21 @@ check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(::typeof(right_orth_svd!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(lq_compact!, A, CVᴴ, alg) +check_input(::typeof(left_null!), A, N, alg::AbstractAlgorithm) = + check_input(left_null_kind(alg), A, N, alg) +check_input(::typeof(left_null_qr!), A, N, alg::AbstractAlgorithm) = + check_input(qr_null!, A, N, alg) +check_input(::typeof(left_null_svd!), A, N, alg::AbstractAlgorithm) = nothing -initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = - initialize_output(left_orth_kind(alg), A, alg) -initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = - initialize_output(right_orth_kind(alg), A, alg) +check_input(::typeof(right_null!), A, Nᴴ, alg::AbstractAlgorithm) = + check_input(right_null_kind(alg), A, Nᴴ, alg) +check_input(::typeof(right_null_lq!), A, Nᴴ, alg::AbstractAlgorithm) = + check_input(lq_null!, A, Nᴴ, alg) +check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothing +initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = + initialize_output(left_orth_kind(alg), A, alg) initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = initialize_output(qr_compact!, A, alg) initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = @@ -40,6 +45,8 @@ initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = initialize_output(::typeof(left_orth_svd!), A, alg::AbstractAlgorithm) = initialize_output(qr_compact!, A, alg) +initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = + initialize_output(right_orth_kind(alg), A, alg) initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = initialize_output(lq_compact!, A, alg) initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = @@ -47,27 +54,20 @@ initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = initialize_output(::typeof(right_orth_svd!), A, alg::AbstractAlgorithm) = initialize_output(lq_compact!, A, alg) +initialize_output(::typeof(left_null!), A, alg::AbstractAlgorithm) = + initialize_output(left_null_kind(alg), A, alg) +initialize_output(::typeof(left_null_qr!), A, alg::AbstractAlgorithm) = + initialize_output(qr_null!, A, alg) +initialize_output(::typeof(left_null_svd!), A, alg::AbstractAlgorithm) = nothing -function check_input(::typeof(left_null!), A::AbstractMatrix, N, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - @assert N isa AbstractMatrix - @check_size(N, (m, m - minmn)) - @check_scalar(N, A) - return nothing -end -function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - @assert Nᴴ isa AbstractMatrix - @check_size(Nᴴ, (n - minmn, n)) - @check_scalar(Nᴴ, A) - return nothing -end +initialize_output(::typeof(right_null!), A, alg::AbstractAlgorithm) = + initialize_output(right_null_kind(alg), A, alg) +initialize_output(::typeof(right_null_lq!), A, alg::AbstractAlgorithm) = + initialize_output(lq_null!, A, alg) +initialize_output(::typeof(right_null_svd!), A, alg::AbstractAlgorithm) = nothing # Outputs # ------- - function initialize_orth_svd(A::AbstractMatrix, F, alg) S = Diagonal(initialize_output(svd_vals!, A, alg)) return F[1], S, F[2] @@ -75,27 +75,14 @@ end # fallback doesn't re-use F at all initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) -function initialize_output(::typeof(left_null!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - N = similar(A, (m, m - minmn)) - return N -end -function initialize_output(::typeof(right_null!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - Nᴴ = similar(A, (n - minmn, n)) - return Nᴴ -end - # Implementation of orth functions # -------------------------------- left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth_kind(alg)(A, VC, alg) -right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) - left_orth_qr!(A, VC, alg::AbstractAlgorithm) = qr_compact!(A, VC, alg) -right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) left_orth_polar!(A, VC, alg::AbstractAlgorithm) = left_polar!(A, VC, alg) + +right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) +right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) right_orth_polar!(A, CVᴴ, alg::AbstractAlgorithm) = right_polar!(A, CVᴴ, alg) # orth_svd requires implementations of `lmul!` and `rmul!` @@ -116,88 +103,23 @@ end # Implementation of null functions # -------------------------------- -function null_truncation_strategy(; atol = nothing, rtol = nothing, maxnullity = nothing) - if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) - return notrunc() - end - atol = @something atol 0 - rtol = @something rtol 0 - trunc = trunctol(; atol, rtol, keep_below = true) - return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev = false) : trunc -end -function left_null!( - A, N; - trunc = nothing, kind = isnothing(trunc) ? :qr : :svd, - alg_qr = (; positive = true), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_null with kind=$kind")) - end - return if kind == :qr - left_null_qr!(A, N, alg_qr) - elseif kind == :svd - left_null_svd!(A, N, alg_svd, trunc) - else - throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) - end -end -function left_null_qr!(A, N, alg) - alg′ = select_algorithm(qr_null!, A, alg) - check_input(left_null!, A, N, alg′) - return qr_null!(A, N, alg′) -end -function left_null_svd!(A, N, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(left_null!, A, N, alg′) - U, _, _ = svd_full!(A, alg′) - (m, n) = size(A) - return copy!(N, view(U, 1:m, (n + 1):m)) -end -function left_null_svd!(A, N, alg, trunc) - alg′ = select_algorithm(svd_full!, A, alg) - U, S, _ = svd_full!(A, alg′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - N, ind = truncate(left_null!, (U, S), trunc′) - return N -end -function right_null!( - A, Nᴴ; - trunc = nothing, kind = isnothing(trunc) ? :lq : :svd, - alg_lq = (; positive = true), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for right_null with kind=$kind")) - end - if kind == :lq - return right_null_lq!(A, Nᴴ, alg_lq) - elseif kind == :svd - return right_null_svd!(A, Nᴴ, alg_svd, trunc) - else - throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) - end -end -function right_null_lq!(A, Nᴴ, alg) - alg′ = select_algorithm(lq_null!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - return lq_null!(A, Nᴴ, alg′) -end -function right_null_svd!(A, Nᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - _, _, Vᴴ = svd_full!(A, alg′) - (m, n) = size(A) - return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) +left_null!(A, N, alg::AbstractAlgorithm) = left_null_kind(alg)(A, N, alg) +left_null_qr!(A, N, alg::AbstractAlgorithm) = qr_null!(A, N, alg) + +right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null_kind(alg)(A, Nᴴ, alg) +right_null_lq!(A, Nᴴ, alg::AbstractAlgorithm) = lq_null!(A, Nᴴ, alg) + +function left_null_svd!(A, N, alg::TruncatedAlgorithm) + check_input(left_null_svd!, A, N, alg) + U, S, _ = svd_full!(A, alg.alg) + N, _ = truncate(left_null!, (U, S), alg.trunc) + return N end -function right_null_svd!(A, Nᴴ, alg, trunc) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - _, S, Vᴴ = svd_full!(A, alg′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return first(truncate(right_null!, (S, Vᴴ), trunc′)) +function right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm) + check_input(right_null_svd!, A, Nᴴ, alg) + _, S, Vᴴ = svd_full!(A, alg.alg) + Nᴴ, _ = truncate(right_null!, (S, Vᴴ), alg.trunc) + return Nᴴ end From 4c31a3672c4d101e24f03e77f3b1805b99503189 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 20:28:26 -0400 Subject: [PATCH 09/47] more algorithm traits --- src/algorithms.jl | 48 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index ed42988a..0c7f2594 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -95,6 +95,54 @@ right_orth_kind(alg::AbstractAlgorithm) = error( """ ) +""" + left_null_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `left_null!(A, alg)`. +By default, this is either `left_null_qr!` or `left_null_svd!`, but this can be extended +to insert arbitrary other decomposition functions, which should follow the signature +`f!(A, F, alg) -> F` +""" +function left_null_kind(alg::AbstractAlgorithm) + left_orth_kind(alg) === left_orth_qr! && return left_null_qr! + left_orth_kind(alg) === left_orth_svd! && return left_null_svd! + return error( + """ + Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.left_null_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `left_null_qr!` or `left_null_svd!`. + """ + ) +end + +""" + right_null_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `right_null!(A, alg)`. +By default, this is either `right_null_lq!` or `right_null_svd!`, but this can be extended +to insert arbitrary other decomposition functions, which should follow the signature +`f!(A, F, alg) -> F` +""" +function right_null_kind(alg::AbstractAlgorithm) + right_orth_kind(alg) === right_orth_lq! && return right_null_lq! + right_orth_kind(alg) === right_orth_svd! && return right_null_svd! + return error( + """ + Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.right_null_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `right_null_lq!` or `right_null_svd!`. + """ + ) +end + """ does_truncate(alg::AbstractAlgorithm) -> Bool From 18b5a94a15008be438624b3ceb70809ecc8ff4b6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 20:31:27 -0400 Subject: [PATCH 10/47] `left_null` and `right_null` with `@functiondef` --- src/interface/orthnull.jl | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index a459d512..f4ce4b2d 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -156,14 +156,11 @@ be used. See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) """ -function left_null end -function left_null! end -function left_null!(A; kwargs...) - return left_null!(A, initialize_output(left_null!, A); kwargs...) -end -function left_null(A; kwargs...) - return left_null!(copy_input(left_null, A); kwargs...) -end +@functiondef left_null + +# helper functions +function left_null_qr! end +function left_null_svd! end """ right_null(A; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ @@ -206,14 +203,11 @@ be used. See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) """ -function right_null end -function right_null! end -function right_null!(A; kwargs...) - return right_null!(A, initialize_output(right_null!, A); kwargs...) -end -function right_null(A; kwargs...) - return right_null!(copy_input(right_null, A); kwargs...) -end +@functiondef right_null + +# helper functions +function right_null_lq! end +function right_null_svd! end # Algorithm selection # ------------------- From 3b98df67242af7560649385e231889b83cb0db88 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 11 Oct 2025 20:36:35 -0400 Subject: [PATCH 11/47] reorganize algorithm unions --- src/implementations/svd.jl | 14 -------------- src/interface/decompositions.jl | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9ebc7319..fed36cd1 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -285,20 +285,6 @@ end # placed here to avoid code duplication since much of the logic is replicable across # CUDA and AMDGPU ### -const CUSOLVER_SVDAlgorithm = Union{ - CUSOLVER_QRIteration, - CUSOLVER_SVDPolar, - CUSOLVER_Jacobi, - CUSOLVER_Randomized, -} -const ROCSOLVER_SVDAlgorithm = Union{ - ROCSOLVER_QRIteration, - ROCSOLVER_Jacobi, -} -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} - -const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Randomized = Union{CUSOLVER_Randomized} function check_input( ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index ef6af6c7..35840124 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -235,6 +235,10 @@ Divide and Conquer algorithm. """ @algdef CUSOLVER_DivideAndConquer +const CUSOLVER_SVDAlgorithm = Union{ + CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, +} + # ========================= # ROCSOLVER ALGORITHMS # ========================= @@ -282,6 +286,7 @@ Divide and Conquer algorithm. """ @algdef ROCSOLVER_DivideAndConquer +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} const GPU_Simple = Union{CUSOLVER_Simple} const GPU_EigAlgorithm = Union{GPU_Simple} @@ -289,8 +294,13 @@ const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} const GPU_Bisection = Union{ROCSOLVER_Bisection} -const GPU_SVDAlgorithm = Union{GPU_Jacobi, GPU_DivideAndConquer, GPU_QRIteration} -const GPU_EighAlgorithm = Union{GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, CUSOLVER_SVDPolar} +const GPU_EighAlgorithm = Union{ + GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, +} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} + +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Randomized = Union{CUSOLVER_Randomized} left_orth_kind(::GPU_SVDAlgorithm) = left_orth_svd! right_orth_kind(::GPU_SVDAlgorithm) = right_orth_svd! From 1541f9f8dc663580648bbcc8b994b0b90eb87108 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 08:12:20 -0400 Subject: [PATCH 12/47] refactor null truncation --- src/implementations/truncation.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index df327c8b..934de34f 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -26,6 +26,17 @@ function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy return Vᴴ[ind, :], ind end +# special case `NoTruncation` for null: should keep exact zeros due to rectangularity +function truncate(::typeof(left_null!), (U, S), strategy::NoTruncation) + ind = (1:max(0, size(S, 1) - size(S, 2))) .+ length(diagview(S)) + return U[:, ind], ind +end +function truncate(::typeof(right_null!), (S, Vᴴ), strategy::NoTruncation) + ind = (1:max(0, size(S, 2) - size(S, 1))) .+ length(diagview(S)) + return Vᴴ[ind, :], ind +end + + # findtruncated # ------------- # Generic fallback From 08944ec0912a99e0efc11c635145311d741b997c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 09:00:21 -0400 Subject: [PATCH 13/47] update docstrings --- src/interface/orthnull.jl | 460 +++++++++++++++++++++++++------------- 1 file changed, 303 insertions(+), 157 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index f4ce4b2d..5c507850 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -1,53 +1,104 @@ # Orth functions # -------------- + +const docs_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxrank::Real` : Maximal rank for the truncation +* `maxerror::Real` : Maximal truncation error. +* `filter` : Custom filter to select truncated values. +""" + +const docs_truncation_strategies = """ +- [`notrunc`](@ref) +- [`truncrank`](@ref) +- [`trunctol`](@ref) +- [`truncerror`](@ref) +- [`truncfilter`](@ref) +""" + +const docs_null_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxnullity::Real` : Maximal rank for the truncation """ - left_orth(A; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C - left_orth!(A, [VC]; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C - -Compute an orthonormal basis `V` for the image of the matrix `A` of size `(m, n)`, -as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -precision in determining the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. - -This is a high-level wrapper and will use one of the decompositions -[`qr_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and[`left_polar!`](@ref) -to compute the orthogonal basis `V`, as controlled by the keyword arguments. - -When `kind` is provided, its possible values are - -* `kind == :qr`: `V` and `C` are computed using the QR decomposition. - This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to - `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` - -* `kind == :polar`: `V` and `C` are computed using the polar decomposition, - This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to - `left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)` - -* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!` when a - truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. - `V` will contain the left singular vectors and `C` is computed as the product of the singular - values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have - `V = U` and `C = S * Vᴴ`. - -When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -for backend factorizations through the `alg_qr`, `alg_polar`, and `alg_svd` keyword arguments, -which will only be used if the corresponding factorization is called based on the other inputs. -If NamedTuples are passed as `alg_qr`, `alg_polar`, or `alg_svd`, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will -be used. + +""" + left_orth(A; [trunc], kwargs...) -> V, C + left_orth!(A, [VC]; [trunc], kwargs...) -> V, C + +Compute an orthonormal basis `V` for the image of the matrix `A`, as well as a matrix `C` +(the corestriction) such that `A` factors as `A = V * C`. + +This is a high-level wrapper where te keyword arguments can be used to specify and control +the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` +can optionally be used to control the precision in determining the rank of `A`, typically +via its singular values. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be QR-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:qr` : Factorize via QR decomposition, with further customizations through the + `alg_qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + V, C = qr_compact(A; alg_qr...) +``` + +* `:polar` : Factorize via polar decomposition, with further customizations through the + `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + V, C = left_polar(A; alg_polar...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + This mode is roughly equivalent to: +```julia + V, S, C = svd_trunc(A; trunc, alg_svd...) + C = S * C +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.left_orth_kind(alg)`](@ref). + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +--- !!! note - The bang method `left_orth!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `CV` as output. + The bang method `left_orth!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may + not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), +[`right_null(!)`](@ref right_null) """ @functiondef left_orth @@ -57,54 +108,81 @@ function left_orth_polar! end function left_orth_svd! end """ - right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ - right_orth!(A, [CVᴴ]; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ - -Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. -for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -precision in determining the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. - -This is a high-level wrapper and will use one of the decompositions -[`lq_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and -[`right_polar!`](@ref) to compute the orthogonal basis `V`, as controlled by the -keyword arguments. - -When `kind` is provided, its possible values are - -* `kind == :lq`: `C` and `Vᴴ` are computed using the QR decomposition, - This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - -* `kind == :polar`: `C` and `Vᴴ` are computed using the polar decomposition, - This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to - `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)` - -* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!` when - a truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. - `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular - values and `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. - -When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments, -which will only be used if the corresponding factorization is called based on the other inputs. -If `alg_lq`, `alg_polar`, or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will -be used. + right_orth(A; [trunc], kwargs...) -> C, Vᴴ + right_orth!(A, [CVᴴ]; [trunc], kwargs...) -> C, Vᴴ + +Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for +the image of `adjoint(A)`, as well as a matrix `C` such that `A` factors as `A = C * Vᴴ`. + +This is a high-level wrapper where the keyword arguments can be used to specify and control +the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` can +optionally be used to control the precision in determining the rank of `A`, typically via +its singular values. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be LQ-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:lq` : Factorize via LQ decomposition, with further customizations through the + `alg_lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + C, Vᴴ = lq_compact(A; alg_lq...) +``` + +* `:polar` : Factorize via polar decomposition, with further customizations through the + `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + C, Vᴴ = right_polar(A; alg_polar...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + This mode is roughly equivalent to: +```julia + C, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + C = C * S +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.right_orth_kind(alg)`](@ref). + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +--- !!! note - The bang method `right_orth!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `CVᴴ` as output. + The bang method `right_orth!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `CVᴴ` as output. -See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) +See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), +[`right_null(!)`](@ref right_null) """ @functiondef right_orth @@ -116,45 +194,79 @@ function right_orth_svd! end # Null functions # -------------- """ - left_null(A; [kind::Symbol, trunc, alg_qr, alg_svd]) -> N - left_null!(A, [N]; [kind::Symbol, alg_qr, alg_svd]) -> N - -Compute an orthonormal basis `N` for the cokernel of the matrix `A` of size `(m, n)`, i.e. -the nullspace of `adjoint(A)`, such that `adjoint(A)*N ≈ 0` and `N'*N ≈ I`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxnullity`. - -This is a high-level wrapper and will use one of the decompositions `qr!` or `svd!` -to compute the orthogonal basis `N`, as controlled by the keyword arguments. + left_null(A; [trunc], kwargs...) -> N + left_null!(A, [N]; [trunc], kwargs...) -> N + +Compute an orthonormal basis `N` for the cokernel of the matrix `A`, i.e. the nullspace of +`adjoint(A)`, such that `adjoint(A) * N ≈ 0` and `N' * N ≈ I`. + +This is a high-level wrapper where the keyword arguments can be used to specify and control +the underlying orthogonal decomposition that should be used to find the null space of `A'`, +whereas `trunc` can optionally be used to control the precision in determining the rank of +`A`, typically via its singular values. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be QR-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:qr` : Factorize via QR nullspace, with further customizations through the `alg_qr` + keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + N = qr_null(A; alg_qr...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + It is roughly equivalent to: +```julia + U, S, _ = svd_trunc(A; trunc, alg_svd...) + N = truncate(left_null, (U, S), trunc) +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.left_null_kind(alg)`](@ref). + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_null_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) -When `kind` is provided, its possible values are - -* `kind == :qr`: `N` is computed using the QR decomposition. - This requires `isnothing(trunc)` and `left_null!(A, [N], kind=:qr)` is equivalent to - `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` - -* `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to either the zero singular values if `trunc` - isn't specified or the singular values specified by `trunc`. +!!! note + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional columns of `A`. -When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg_qr` and `alg_svd` keyword arguments, which will only be used by the corresponding -factorization backend. If `alg_qr` or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will -be used. +--- !!! note - The bang method `left_null!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `N` as output. + The bang method `left_null!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `N` as output. -See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), +[`right_orth(!)`](@ref right_orth) """ @functiondef left_null @@ -163,45 +275,79 @@ function left_null_qr! end function left_null_svd! end """ - right_null(A; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ - right_null!(A, [Nᴴ]; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ - -Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel or nullspace of the matrix `A` -of size `(m, n)`, such that `A*adjoint(Nᴴ) ≈ 0` and `Nᴴ*adjoint(Nᴴ) ≈ I`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxnullity`. + right_null(A; [trunc], kwargs...) -> Nᴴ + right_null!(A, [Nᴴ]; [trunc], kwargs...) -> Nᴴ + +Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel of the matrix `A`, i.e. the +nullspace of `A`, such that `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. + +This is a high-level wrapper where the keyword arguments can be used to specify and control +the underlying orthogonal decomposition that should be used to find the null space of `A`, +whereas `trunc` can optionally be used to control the precision in determining the rank of +`A`, typically via its singular values. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be LQ-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:lq` : Factorize via LQ nullspace, with further customizations through the `alg_lq` + keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + Nᴴ = lq_null(A; alg_qr...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + It is roughly equivalent to: +```julia + _, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + Nᴴ = truncate(right_null, (S, Vᴴ), trunc) +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.right_null_kind(alg)`](@ref). + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_null_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) -This is a high-level wrapper and will use one of the decompositions `lq!` or `svd!` -to compute the orthogonal basis `Nᴴ`, as controlled by the keyword arguments. - -When `kind` is provided, its possible values are - -* `kind == :lq`: `Nᴴ` is computed using the (nonpositive) LQ decomposition. - This requires `isnothing(trunc)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - -* `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to the singular values that - are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. +!!! note + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional rows of `A`. -When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg_lq` and `alg_svd` keyword arguments, which will only be used by the corresponding -factorization backend. If `alg_lq` or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will -be used. +--- !!! note - The bang method `right_null!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `Nᴴ` as output. + The bang method `right_null!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `Nᴴ` as output. -See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), +[`right_orth(!)`](@ref right_orth) """ @functiondef right_null From 67dc55a9d13816f45fa9aa5f88b201a394235b9a Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 10:00:39 -0400 Subject: [PATCH 14/47] disambiguate alg selection --- src/interface/orthnull.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 5c507850..6a195a54 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -371,9 +371,14 @@ function select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, throw(ArgumentError(lazy"Unknown alg symbol $alg")) end -default_algorithm(::typeof(left_orth!), A; trunc = nothing, kwargs...) = +default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : select_algorithm(left_orth_svd!, A; trunc, kwargs...) +# disambiguate +default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = + isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : + select_algorithm(left_orth_svd!, A; trunc, kwargs...) + select_algorithm(::typeof(left_orth_qr!), A, alg = nothing; kwargs...) = select_algorithm(qr_compact!, A, alg; kwargs...) @@ -397,10 +402,15 @@ function select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing throw(ArgumentError(lazy"Unknown alg symbol $alg")) end -default_algorithm(::typeof(right_orth!), A; trunc = nothing, kwargs...) = +default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : + select_algorithm(right_orth_svd!, A; trunc, kwargs...) +# disambiguate: +default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : select_algorithm(right_orth_svd!, A; trunc, kwargs...) + select_algorithm(::typeof(right_orth_lq!), A, alg = nothing; kwargs...) = select_algorithm(lq_compact!, A, alg; kwargs...) select_algorithm(::typeof(right_orth_polar!), A, alg = nothing; kwargs...) = @@ -421,7 +431,11 @@ function select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, throw(ArgumentError(lazy"unkown alg symbol $alg")) end -default_algorithm(::typeof(left_null!), A; trunc = nothing, kwargs...) = +default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : + select_algorithm(left_null_svd!, A; trunc, kwargs...) +# disambiguate +default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : select_algorithm(left_null_svd!, A; trunc, kwargs...) @@ -450,7 +464,10 @@ function select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing throw(ArgumentError(lazy"unkown alg symbol $alg")) end -default_algorithm(::typeof(right_null!), A; trunc = nothing, kwargs...) = +default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : + select_algorithm(right_null_svd!, A; trunc, kwargs...) +default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : select_algorithm(right_null_svd!, A; trunc, kwargs...) From 3fc6cbbc29ce2db08af923a96c3bf964395cadb2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 10:01:05 -0400 Subject: [PATCH 15/47] update tests --- test/chainrules.jl | 16 +++--- test/orthnull.jl | 132 ++++++++++++++++++--------------------------- 2 files changed, 61 insertions(+), 87 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 441a17fb..ba3f0681 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -541,18 +541,18 @@ end ) test_rrule( config, left_orth, A; - fkwargs = (; kind = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) m >= n && test_rrule( config, left_orth, A; - fkwargs = (; kind = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - ΔN = left_orth(A; kind = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) test_rrule( config, left_null, A; - fkwargs = (; kind = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @@ -561,19 +561,19 @@ end atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) test_rrule( - config, right_orth, A; fkwargs = (; kind = :lq), + config, right_orth, A; fkwargs = (; alg = :lq), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) m <= n && test_rrule( - config, right_orth, A; fkwargs = (; kind = :polar), + config, right_orth, A; fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind = :lq)[2] + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] test_rrule( config, right_null, A; - fkwargs = (; kind = :lq), output_tangent = ΔNᴴ, + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end diff --git a/test/orthnull.jl b/test/orthnull.jl index 8d690375..40c5056d 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -11,7 +11,6 @@ include("linearmap.jl") include("linearmap.jl") eltypes = (Float32, Float64, ComplexF32, ComplexF64) - @testset "left_orth and left_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @@ -30,7 +29,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test V * V' + N * N' ≈ I M = LinearMap(A) - VM, CM = @constinferred left_orth(M; kind = :svd) + VM, CM = @constinferred left_orth(M; alg = :svd) @test parent(VM) * parent(CM) ≈ A if m > n @@ -46,25 +45,33 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test isisometric(N) end - for alg_qr in ((; positive = true), (; positive = false), LAPACK_HouseholderQR()) - V, C = @constinferred left_orth(A; alg_qr) - N = @constinferred left_null(A; alg_qr) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - @test V * V' + N * N' ≈ I - end + # passing a kind and some kwargs + V, C = @constinferred left_orth(A; alg = :qr, alg_qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test V * V' + N * N' ≈ I + + # passing an algorithm + V, C = @constinferred left_orth(A; alg = LAPACK_HouseholderQR()) + N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test V * V' + N * N' ≈ I Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 === V - @test C2 === C - @test N2 === N @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -74,9 +81,6 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) atol = eps(real(T)) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -90,9 +94,6 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) ) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -100,49 +101,40 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test V2 * V2' + N2 * N2' ≈ I end - for kind in (:qr, :polar, :svd) # explicit kind kwarg - m < n && kind == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind) - @test V2 === V - @test C2 === C + for alg in (:qr, :polar, :svd) # explicit kind kwarg + m < n && alg === :polar && continue + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg) @test V2 * C2 ≈ A @test isisometric(V2) - if kind != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind) - @test N2 === N + if alg != :polar + N2 = @constinferred left_null!(copy!(Ac, A), N; alg) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) @test V2 * V2' + N2 * N2' ≈ I end # with kind and tol kwargs - if kind == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind, trunc = (; atol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C + if alg == :svd + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A @test V2' * V2 ≈ I @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test N2' * N2 ≈ I @test V2 * V2' + N2 * N2' ≈ I - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind, trunc = (; rtol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) @test V2 * V2' + N2 * N2' ≈ I else - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; atol)) - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; rtol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind, trunc = (; atol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind, trunc = (; rtol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) end end end @@ -166,15 +158,12 @@ end @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; kind = :svd) + CM, VMᴴ = @constinferred right_orth(M; alg = :svd) @test parent(CM) * parent(VMᴴ) ≈ A Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 === C - @test Vᴴ2 === Vᴴ - @test Nᴴ2 === Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -184,9 +173,6 @@ end atol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -196,57 +182,45 @@ end rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - for kind in (:lq, :polar, :svd) - n < m && kind == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind) - @test C2 === C - @test Vᴴ2 === Vᴴ + for alg in (:lq, :polar, :svd) + n < m && alg == :polar && continue + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) - if kind != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind) - @test Nᴴ2 === Nᴴ + if alg != :polar + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I end - if kind == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + if alg == :svd + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I else - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; atol)) - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; rtol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; atol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; rtol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) end end end From fdbb86ceaaf677ad568c565b288d8391d1d07b2e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 17:51:33 -0400 Subject: [PATCH 16/47] more docs updates --- src/interface/eig.jl | 40 +++++++++++++++++++++---------------- src/interface/eigh.jl | 36 +++++++++++++++++++-------------- src/interface/orthnull.jl | 23 --------------------- src/interface/svd.jl | 36 +++++++++++++++++++-------------- src/interface/truncation.jl | 40 ++++++++++++++++++++++++++++++++----- 5 files changed, 100 insertions(+), 75 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 867d69eb..1ff7c51a 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -45,22 +45,28 @@ selected according to a truncation strategy. The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded eigenvalues. -## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all eigenvalues. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword Arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. @@ -71,7 +77,7 @@ truncation strategy is already embedded in the algorithm. as it may not always be possible to use the provided `DV` as output. !!! note -$(docs_eig_note) +$docs_eig_note See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), and [Truncations](@ref) for more information on truncation strategies. diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 314cb934..ae6a843c 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -50,22 +50,28 @@ selected according to a truncation strategy. The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded eigenvalues. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + ## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all eigenvalues. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 6a195a54..8e54efed 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -1,28 +1,5 @@ # Orth functions # -------------- - -const docs_truncation_kwargs = """ -* `atol::Real` : Absolute tolerance for the truncation -* `rtol::Real` : Relative tolerance for the truncation -* `maxrank::Real` : Maximal rank for the truncation -* `maxerror::Real` : Maximal truncation error. -* `filter` : Custom filter to select truncated values. -""" - -const docs_truncation_strategies = """ -- [`notrunc`](@ref) -- [`truncrank`](@ref) -- [`trunctol`](@ref) -- [`truncerror`](@ref) -- [`truncfilter`](@ref) -""" - -const docs_null_truncation_kwargs = """ -* `atol::Real` : Absolute tolerance for the truncation -* `rtol::Real` : Relative tolerance for the truncation -* `maxnullity::Real` : Maximal rank for the truncation -""" - """ left_orth(A; [trunc], kwargs...) -> V, C left_orth!(A, [VC]; [trunc], kwargs...) -> V, C diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 606a1c4e..04e7121c 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -55,22 +55,28 @@ square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strat The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded singular values. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + ## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all singular values. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index 3f6ba313..a1e993de 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -1,3 +1,25 @@ +const docs_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxrank::Real` : Maximal rank for the truncation +* `maxerror::Real` : Maximal truncation error. +* `filter` : Custom filter to select truncated values. +""" + +const docs_truncation_strategies = """ +- [`notrunc`](@ref) +- [`truncrank`](@ref) +- [`trunctol`](@ref) +- [`truncerror`](@ref) +- [`truncfilter`](@ref) +""" + +const docs_null_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxnullity::Real` : Maximal rank for the truncation +""" + """ TruncationStrategy(; kwargs...) @@ -8,11 +30,7 @@ The following keyword arguments are all optional, and their default value (`noth will be ignored. It is also allowed to combine multiple of these, in which case the kept values will consist of the intersection of the different truncated strategies. -- `atol::Real` : Absolute tolerance for the truncation -- `rtol::Real` : Relative tolerance for the truncation -- `maxrank::Real` : Maximal rank for the truncation -- `maxerror::Real` : Maximal truncation error. -- `filter` : Custom filter to select truncated values. +$docs_truncation_kwargs """ function TruncationStrategy(; atol::Union{Real, Nothing} = nothing, @@ -36,6 +54,18 @@ function TruncationStrategy(; return strategy end +""" + null_truncation_strategy(; kwargs...) + +Select a nullspace truncation strategy based on the provided keyword arguments. + +## Keyword arguments +The following keyword arguments are all optional, and their default value (`nothing`) +will be ignored. It is also allowed to combine multiple of these, in which case the +discarded values will consist of the intersection of the different truncated strategies. + +$docs_null_truncation_kwargs +""" function null_truncation_strategy(; atol = nothing, rtol = nothing, maxnullity = nothing) if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) return notrunc() From 0559aeec71d6a04483059cc448a1bd145012ffa1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 12 Oct 2025 18:19:54 -0400 Subject: [PATCH 17/47] mark randomized SVD as unusable for nullspaces --- src/implementations/orthnull.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 5172d815..8e50aad7 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -103,8 +103,6 @@ end # Implementation of null functions # -------------------------------- - - left_null!(A, N, alg::AbstractAlgorithm) = left_null_kind(alg)(A, N, alg) left_null_qr!(A, N, alg::AbstractAlgorithm) = qr_null!(A, N, alg) @@ -123,3 +121,9 @@ function right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm) Nᴴ, _ = truncate(right_null!, (S, Vᴴ), alg.trunc) return Nᴴ end + +# randomized algorithms don't work for smallest values: +left_null_svd!(A, N, alg::TruncatedAlgorithm{<:GPU_Randomized}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) +right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) From 86f0c8ddffd8acafb02996ca4b24b9e8c979a7df Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 13 Oct 2025 08:15:14 -0400 Subject: [PATCH 18/47] fix JET complaint --- src/interface/decompositions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 35840124..610c4928 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -214,7 +214,7 @@ for more information. """ @algdef CUSOLVER_Randomized -does_truncate(::CUSOLVER_Randomized) = true +does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true """ CUSOLVER_Simple() From 84537c7bcac752c57c21781764e607ec07de56d6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 14 Oct 2025 08:40:30 -0400 Subject: [PATCH 19/47] docstring reorganization --- src/interface/orthnull.jl | 146 +++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 8e54efed..16674abb 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -7,11 +7,27 @@ Compute an orthonormal basis `V` for the image of the matrix `A`, as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. -This is a high-level wrapper where te keyword arguments can be used to specify and control +This is a high-level wrapper where the keyword arguments can be used to specify and control the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` can optionally be used to control the precision in determining the rank of `A`, typically via its singular values. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + ## Keyword arguments There are 3 major modes of operation, based on the `alg` keyword, with slightly different application purposes. @@ -51,22 +67,6 @@ In this expert mode the algorithm is supplied directly, and the kind of decompos deduced from that. This hinges on the implementation of the algorithm trait [`MatrixAlgebraKit.left_orth_kind(alg)`](@ref). -## Truncation -The optional truncation strategy can be controlled via the `trunc` keyword argument, and -any non-trivial strategy typically requires an SVD-based decompositions. This keyword can -be either a `NamedTuple` or a [`TruncationStrategy`](@ref). - -### `trunc::NamedTuple` -The supported truncation keyword arguments are: - -$(docs_truncation_kwargs) - -### `trunc::TruncationStrategy` -For more control, a truncation strategy can be supplied directly. -By default, MatrixAlgebraKit supplies the following: - -$(docs_truncation_strategies) - --- !!! note @@ -96,6 +96,22 @@ the specific orthogonal decomposition that should be used to factor `A`, whereas optionally be used to control the precision in determining the rank of `A`, typically via its singular values. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + ## Keyword arguments There are 3 major modes of operation, based on the `alg` keyword, with slightly different application purposes. @@ -135,22 +151,6 @@ In this expert mode the algorithm is supplied directly, and the kind of decompos deduced from that. This hinges on the implementation of the algorithm trait [`MatrixAlgebraKit.right_orth_kind(alg)`](@ref). -## Truncation -The optional truncation strategy can be controlled via the `trunc` keyword argument, and -any non-trivial strategy typically requires an SVD-based decompositions. This keyword can -be either a `NamedTuple` or a [`TruncationStrategy`](@ref). - -### `trunc::NamedTuple` -The supported truncation keyword arguments are: - -$(docs_truncation_kwargs) - -### `trunc::TruncationStrategy` -For more control, a truncation strategy can be supplied directly. -By default, MatrixAlgebraKit supplies the following: - -$(docs_truncation_strategies) - --- !!! note @@ -182,6 +182,26 @@ the underlying orthogonal decomposition that should be used to find the null spa whereas `trunc` can optionally be used to control the precision in determining the rank of `A`, typically via its singular values. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_null_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +!!! note + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional columns of `A`. + ## Keyword arguments There are 3 major modes of operation, based on the `alg` keyword, with slightly different application purposes. @@ -215,26 +235,6 @@ In this expert mode the algorithm is supplied directly, and the kind of decompos deduced from that. This hinges on the implementation of the algorithm trait [`MatrixAlgebraKit.left_null_kind(alg)`](@ref). -## Truncation -The optional truncation strategy can be controlled via the `trunc` keyword argument, and any -non-trivial strategy typically requires an SVD-based decomposition. This keyword can be -either a `NamedTuple` or a [`TruncationStrategy`](@ref). - -### `trunc::NamedTuple` -The supported truncation keyword arguments are: - -$(docs_null_truncation_kwargs) - -### `trunc::TruncationStrategy` -For more control, a truncation strategy can be supplied directly. By default, -MatrixAlgebraKit supplies the following: - -$(docs_truncation_strategies) - -!!! note - Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that - correspond to the exact zeros determined from the additional columns of `A`. - --- !!! note @@ -263,6 +263,26 @@ the underlying orthogonal decomposition that should be used to find the null spa whereas `trunc` can optionally be used to control the precision in determining the rank of `A`, typically via its singular values. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_null_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +!!! note + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional rows of `A`. + ## Keyword arguments There are 3 major modes of operation, based on the `alg` keyword, with slightly different application purposes. @@ -296,26 +316,6 @@ In this expert mode the algorithm is supplied directly, and the kind of decompos deduced from that. This hinges on the implementation of the algorithm trait [`MatrixAlgebraKit.right_null_kind(alg)`](@ref). -## Truncation -The optional truncation strategy can be controlled via the `trunc` keyword argument, and any -non-trivial strategy typically requires an SVD-based decomposition. This keyword can be -either a `NamedTuple` or a [`TruncationStrategy`](@ref). - -### `trunc::NamedTuple` -The supported truncation keyword arguments are: - -$(docs_null_truncation_kwargs) - -### `trunc::TruncationStrategy` -For more control, a truncation strategy can be supplied directly. By default, -MatrixAlgebraKit supplies the following: - -$(docs_truncation_strategies) - -!!! note - Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that - correspond to the exact zeros determined from the additional rows of `A`. - --- !!! note From e8d8b126cd321715fb31f5a5ff414cbcc06b04b1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 15 Oct 2025 14:41:53 -0400 Subject: [PATCH 20/47] improve truncation --- src/implementations/truncation.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 934de34f..883c7759 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -28,11 +28,13 @@ end # special case `NoTruncation` for null: should keep exact zeros due to rectangularity function truncate(::typeof(left_null!), (U, S), strategy::NoTruncation) - ind = (1:max(0, size(S, 1) - size(S, 2))) .+ length(diagview(S)) + m, n = size(S) + ind = (n + 1):m return U[:, ind], ind end function truncate(::typeof(right_null!), (S, Vᴴ), strategy::NoTruncation) - ind = (1:max(0, size(S, 2) - size(S, 1))) .+ length(diagview(S)) + m, n = size(S) + ind = (m + 1):n return Vᴴ[ind, :], ind end From 01ee3942142d14bac5e085fe7736f102d41fde89 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 15 Oct 2025 14:43:19 -0400 Subject: [PATCH 21/47] headers --- src/implementations/orthnull.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 8e50aad7..ff7d9983 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -1,5 +1,5 @@ -# Orthogonalization -# ----------------- +# Inputs +# ------ copy_input(::typeof(left_orth), A) = copy_input(qr_compact, A) # do we ever need anything else copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever need anything else copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else @@ -35,7 +35,8 @@ check_input(::typeof(right_null_lq!), A, Nᴴ, alg::AbstractAlgorithm) = check_input(lq_null!, A, Nᴴ, alg) check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothing - +# Outputs +# ------- initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = initialize_output(left_orth_kind(alg), A, alg) initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = @@ -66,8 +67,6 @@ initialize_output(::typeof(right_null_lq!), A, alg::AbstractAlgorithm) = initialize_output(lq_null!, A, alg) initialize_output(::typeof(right_null_svd!), A, alg::AbstractAlgorithm) = nothing -# Outputs -# ------- function initialize_orth_svd(A::AbstractMatrix, F, alg) S = Diagonal(initialize_output(svd_vals!, A, alg)) return F[1], S, F[2] From c0fda8cb4074e001f7ebc49170207b8e6e126e25 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 22 Oct 2025 15:26:54 -0400 Subject: [PATCH 22/47] fix merge --- test/orthnull.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index 40c5056d..f12ae91c 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -7,9 +7,6 @@ using LinearAlgebra: LinearAlgebra, I # testing non-AbstractArray codepaths: include("linearmap.jl") -# testing non-AbstractArray codepaths: -include("linearmap.jl") - eltypes = (Float32, Float64, ComplexF32, ComplexF64) @testset "left_orth and left_null for T = $T" for T in eltypes rng = StableRNG(123) From 3bc008122ad6af11d3c0cc56ac5066a51ee2ab27 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 15 Oct 2025 16:11:30 -0400 Subject: [PATCH 23/47] work out alternate proposition --- src/algorithms.jl | 3 +- src/implementations/orthnull.jl | 51 ++++++++++++++++++ src/interface/decompositions.jl | 20 +++++++ src/interface/orthnull.jl | 94 +++++++++++++++------------------ test/orthnull.jl | 6 +-- 5 files changed, 119 insertions(+), 55 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 0c7f2594..7125e2db 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -351,7 +351,8 @@ end function _arg_expr(::Val{1}, f, f!) return quote # out of place to inplace - $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) + @inline $f(A; alg = nothing, kwargs...) = $f(A, select_algorithm($f, A, alg; kwargs...)) + # $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) # fill in arguments diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index ff7d9983..4330376b 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -7,6 +7,14 @@ copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need a check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = check_input(left_orth_kind(alg), A, VC, alg) + +check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaQR) = + check_input(qr_compact!, A, VC, alg) +check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaPolar) = + check_input(left_polar!, A, VC, alg) +check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = + check_input(qr_compact!, A, VC, alg) + check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = check_input(qr_compact!, A, VC, alg) check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = @@ -16,6 +24,14 @@ check_input(::typeof(left_orth_svd!), A, VC, alg::AbstractAlgorithm) = check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(right_orth_kind(alg), A, CVᴴ, alg) + +check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaLQ) = + check_input(lq_compact!, A, VC, alg) +check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaPolar) = + check_input(right_polar!, A, VC, alg) +check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = + check_input(lq_compact!, A, VC, alg) + check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(lq_compact!, A, CVᴴ, alg) check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = @@ -39,6 +55,14 @@ check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothin # ------- initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = initialize_output(left_orth_kind(alg), A, alg) + +initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaQR) = + initialize_output(qr_compact!, A, alg) +initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaPolar) = + initialize_output(left_polar!, A, alg) +initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = + initialize_output(qr_compact!, A, alg) + initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = initialize_output(qr_compact!, A, alg) initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = @@ -48,6 +72,14 @@ initialize_output(::typeof(left_orth_svd!), A, alg::AbstractAlgorithm) = initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = initialize_output(right_orth_kind(alg), A, alg) + +initialize_output(::typeof(right_orth!), A, alg::RightOrthViaLQ) = + initialize_output(lq_compact!, A, alg) +initialize_output(::typeof(right_orth!), A, alg::RightOrthViaPolar) = + initialize_output(right_polar!, A, alg) +initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = + initialize_output(lq_compact!, A, alg) + initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = initialize_output(lq_compact!, A, alg) initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = @@ -77,10 +109,14 @@ initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) # Implementation of orth functions # -------------------------------- left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth_kind(alg)(A, VC, alg) +left_orth!(A, VC, alg::LeftOrthViaQR) = qr_compact!(A, VC, alg.alg) +left_orth!(A, VC, alg::LeftOrthViaPolar) = left_polar!(A, VC, alg.alg) left_orth_qr!(A, VC, alg::AbstractAlgorithm) = qr_compact!(A, VC, alg) left_orth_polar!(A, VC, alg::AbstractAlgorithm) = left_polar!(A, VC, alg) right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) +right_orth!(A, CVᴴ, alg::RightOrthViaLQ) = lq_compact!(A, CVᴴ, alg.alg) +right_orth!(A, CVᴴ, alg::RightOrthViaPolar) = right_polar!(A, CVᴴ, alg.alg) right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) right_orth_polar!(A, CVᴴ, alg::AbstractAlgorithm) = right_polar!(A, CVᴴ, alg) @@ -92,6 +128,14 @@ function left_orth_svd!(A, VC, alg::AbstractAlgorithm) lmul!(S, C) return V, C end +function left_orth!(A, VC, alg::LeftOrthViaSVD) + check_input(left_orth!, A, VC, alg) + USVᴴ = initialize_orth_svd(A, VC, alg.alg) + V, S, C = does_truncate(alg.alg) ? svd_trunc!(A, USVᴴ, alg.alg) : svd_compact!(A, USVᴴ, alg.alg) + lmul!(S, C) + return V, C +end + function right_orth_svd!(A, CVᴴ, alg::AbstractAlgorithm) check_input(right_orth_svd!, A, CVᴴ, alg) USVᴴ = initialize_orth_svd(A, CVᴴ, alg) @@ -99,6 +143,13 @@ function right_orth_svd!(A, CVᴴ, alg::AbstractAlgorithm) rmul!(C, S) return C, Vᴴ end +function right_orth!(A, CVᴴ, alg::RightOrthViaSVD) + check_input(right_orth!, A, CVᴴ, alg) + USVᴴ = initialize_orth_svd(A, CVᴴ, alg.alg) + C, S, Vᴴ = does_truncate(alg.alg) ? svd_trunc!(A, USVᴴ, alg.alg) : svd_compact!(A, USVᴴ, alg.alg) + rmul!(C, S) + return C, Vᴴ +end # Implementation of null functions # -------------------------------- diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 610c4928..8816ae5d 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -148,6 +148,26 @@ until convergence up to tolerance `tol`. left_orth_kind(::Union{PolarViaSVD, PolarNewton}) = left_orth_polar! right_orth_kind(::Union{PolarViaSVD, PolarNewton}) = right_orth_polar! +# ========================= +# ORTHOGONALIZATION ALGORITHMS +# ========================= + +struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +const LeftOrthViaQR = LeftOrthAlgorithm{:qr} +const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} +const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} + +struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +const RightOrthViaLQ = RightOrthAlgorithm{:lq} +const RightOrthViaPolar = RightOrthAlgorithm{:polar} +const RightOrthViaSVD = RightOrthAlgorithm{:svd} + # ========================= # DIAGONAL ALGORITHMS # ========================= diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 16674abb..3bfe5652 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -335,66 +335,58 @@ function right_null_svd! end # Algorithm selection # ------------------- # specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. -function select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, kwargs...) - alg === :svd && return select_algorithm( - left_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc - ) - - isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) - - alg === :qr && return select_algorithm(left_orth_qr!, A, get(kwargs, :alg_qr, nothing)) - alg === :polar && return select_algorithm(left_orth_polar!, A, get(kwargs, :alg_polar, nothing)) +@inline select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, kwargs...) = + LeftOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) +@inline select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) = + RightOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) + +function LeftOrthViaQR(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("QR-based `left_orth` is incompatible with specifying `trunc`")) + alg = select_algorithm(qr_compact!, A, alg; kwargs...) + return LeftOrthViaQR{typeof(alg)}(alg) +end +function LeftOrthViaPolar(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("Polar-based `left_orth` is incompatible with specifying `trunc`")) + alg = select_algorithm(left_polar!, A, alg; kwargs...) + return LeftOrthViaPolar{typeof(alg)}(alg) +end +function LeftOrthViaSVD(A; alg = nothing, trunc = nothing, kwargs...) + alg = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + return LeftOrthViaSVD{typeof(alg)}(alg) +end - throw(ArgumentError(lazy"Unknown alg symbol $alg")) +function RightOrthViaLQ(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("LQ-based `right_orth` is incompatible with specifying `trunc`")) + alg = select_algorithm(lq_compact!, A, alg; kwargs...) + return RightOrthViaLQ{typeof(alg)}(alg) +end +function RightOrthViaPolar(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("Polar-based `right_orth` is incompatible with specifying `trunc`")) + alg = select_algorithm(right_polar!, A, alg; kwargs...) + return RightOrthViaPolar{typeof(alg)}(alg) +end +function RightOrthViaSVD(A; alg = nothing, trunc = nothing, kwargs...) + alg = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + return RightOrthViaSVD{typeof(alg)}(alg) end default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : - select_algorithm(left_orth_svd!, A; trunc, kwargs...) + isnothing(trunc) ? LeftOrthViaQR(A; kwargs...) : LeftOrthViaSVD(A; trunc, kwargs...) # disambiguate default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : - select_algorithm(left_orth_svd!, A; trunc, kwargs...) - - -select_algorithm(::typeof(left_orth_qr!), A, alg = nothing; kwargs...) = - select_algorithm(qr_compact!, A, alg; kwargs...) -select_algorithm(::typeof(left_orth_polar!), A, alg = nothing; kwargs...) = - select_algorithm(left_polar!, A, alg; kwargs...) -select_algorithm(::typeof(left_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = - isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : - select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) - -# specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. -function select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) - alg === :svd && return select_algorithm( - right_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc - ) - - isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) - - alg === :lq && return select_algorithm(right_orth_lq!, A, get(kwargs, :alg_lq, nothing)) - alg === :polar && return select_algorithm(right_orth_polar!, A, get(kwargs, :alg_polar, nothing)) - - throw(ArgumentError(lazy"Unknown alg symbol $alg")) -end + isnothing(trunc) ? LeftOrthViaQR(A; kwargs...) : LeftOrthViaSVD(A; trunc, kwargs...) default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : - select_algorithm(right_orth_svd!, A; trunc, kwargs...) -# disambiguate: + isnothing(trunc) ? RightOrthViaLQ(A; kwargs...) : RightOrthViaSVD(A; trunc, kwargs...) +# disambiguate default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : - select_algorithm(right_orth_svd!, A; trunc, kwargs...) - - -select_algorithm(::typeof(right_orth_lq!), A, alg = nothing; kwargs...) = - select_algorithm(lq_compact!, A, alg; kwargs...) -select_algorithm(::typeof(right_orth_polar!), A, alg = nothing; kwargs...) = - select_algorithm(right_polar!, A, alg; kwargs...) -select_algorithm(::typeof(right_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = - isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : - select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + isnothing(trunc) ? RightOrthViaLQ(A; kwargs...) : RightOrthViaSVD(A; trunc, kwargs...) function select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) alg === :svd && return select_algorithm( diff --git a/test/orthnull.jl b/test/orthnull.jl index f12ae91c..dc532c07 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -43,8 +43,8 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, alg_qr = (; positive = true)) - N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) @test V isa Matrix{T} && size(V) == (m, minmn) @test C isa Matrix{T} && size(C) == (minmn, n) @test N isa Matrix{T} && size(N) == (m, m - minmn) @@ -56,7 +56,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # passing an algorithm V, C = @constinferred left_orth(A; alg = LAPACK_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) @test V isa Matrix{T} && size(V) == (m, minmn) @test C isa Matrix{T} && size(C) == (minmn, n) @test N isa Matrix{T} && size(N) == (m, m - minmn) From d3e155bcf5054fd21c95587dedf482c3f2cf1d97 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 22 Oct 2025 15:31:10 -0400 Subject: [PATCH 24/47] unpack algorithms --- src/implementations/orthnull.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 4330376b..77c21573 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -9,11 +9,11 @@ check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = check_input(left_orth_kind(alg), A, VC, alg) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaQR) = - check_input(qr_compact!, A, VC, alg) + check_input(qr_compact!, A, VC, alg.alg) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaPolar) = - check_input(left_polar!, A, VC, alg) + check_input(left_polar!, A, VC, alg.alg) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = - check_input(qr_compact!, A, VC, alg) + check_input(qr_compact!, A, VC, alg.alg) check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = check_input(qr_compact!, A, VC, alg) @@ -26,11 +26,11 @@ check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(right_orth_kind(alg), A, CVᴴ, alg) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaLQ) = - check_input(lq_compact!, A, VC, alg) + check_input(lq_compact!, A, VC, alg.alg) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaPolar) = - check_input(right_polar!, A, VC, alg) + check_input(right_polar!, A, VC, alg.alg) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = - check_input(lq_compact!, A, VC, alg) + check_input(lq_compact!, A, VC, alg.alg) check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(lq_compact!, A, CVᴴ, alg) @@ -57,11 +57,11 @@ initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = initialize_output(left_orth_kind(alg), A, alg) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaQR) = - initialize_output(qr_compact!, A, alg) + initialize_output(qr_compact!, A, alg.alg) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaPolar) = - initialize_output(left_polar!, A, alg) + initialize_output(left_polar!, A, alg.alg) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = - initialize_output(qr_compact!, A, alg) + initialize_output(qr_compact!, A, alg.alg) initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = initialize_output(qr_compact!, A, alg) @@ -74,11 +74,11 @@ initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = initialize_output(right_orth_kind(alg), A, alg) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaLQ) = - initialize_output(lq_compact!, A, alg) + initialize_output(lq_compact!, A, alg.alg) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaPolar) = - initialize_output(right_polar!, A, alg) + initialize_output(right_polar!, A, alg.alg) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = - initialize_output(lq_compact!, A, alg) + initialize_output(lq_compact!, A, alg.alg) initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = initialize_output(lq_compact!, A, alg) From 437ed8c9ddaff70faef6130a0b7fea1e72bb3b0e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 07:36:26 -0400 Subject: [PATCH 25/47] rework traits --- src/algorithms.jl | 39 -------------- src/implementations/orthnull.jl | 65 ++++------------------ src/interface/decompositions.jl | 95 ++++++++++++++++++++++++++------- src/interface/lq.jl | 10 ---- src/interface/orthnull.jl | 47 +++++++--------- 5 files changed, 104 insertions(+), 152 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 7125e2db..6c9b9e28 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -55,45 +55,6 @@ end # Algorithm traits # ---------------- -""" - left_orth_kind(alg::AbstractAlgorithm) -> f! - -Select an appropriate factorization function for applying `left_orth!(A, alg)`. -By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`, but -this can be extended to insert arbitrary other decomposition functions, which should follow -the signature `f!(A, F, alg) -> F` -""" -left_orth_kind(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. - To register the algorithm type, define: - - MatrixAlgebraKit.left_orth_kind(alg) = f! - - where `f!` should be the factorization function that will be used. - By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`. - """ -) - -""" - right_orth_kind(alg::AbstractAlgorithm) -> f! - -Select an appropriate factorization function for applying `right_orth!(A, alg)`. -By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`, but -this can be extended to insert arbitrary other decomposition functions, which should follow -the signature `f!(A, F, alg) -> F` -""" -right_orth_kind(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. - To register the algorithm type, define: - - MatrixAlgebraKit.right_orth_kind(alg) = f! - - where `f!` should be the factorization function that will be used. - By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`. - """ -) """ left_null_kind(alg::AbstractAlgorithm) -> f! diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 77c21573..3b35d79a 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -6,7 +6,7 @@ copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need an copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = - check_input(left_orth_kind(alg), A, VC, alg) + check_input(left_orth!, A, VC, left_orth_alg(alg)) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaQR) = check_input(qr_compact!, A, VC, alg.alg) @@ -15,15 +15,8 @@ check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaPolar) = check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = check_input(qr_compact!, A, VC, alg.alg) -check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = - check_input(qr_compact!, A, VC, alg) -check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = - check_input(left_polar!, A, VC, alg) -check_input(::typeof(left_orth_svd!), A, VC, alg::AbstractAlgorithm) = - check_input(qr_compact!, A, VC, alg) - check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = - check_input(right_orth_kind(alg), A, CVᴴ, alg) + check_input(right_orth!, A, CVᴴ, right_orth_alg(alg)) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaLQ) = check_input(lq_compact!, A, VC, alg.alg) @@ -32,13 +25,6 @@ check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaPolar) = check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = check_input(lq_compact!, A, VC, alg.alg) -check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = - check_input(lq_compact!, A, CVᴴ, alg) -check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = - check_input(right_polar!, A, CVᴴ, alg) -check_input(::typeof(right_orth_svd!), A, CVᴴ, alg::AbstractAlgorithm) = - check_input(lq_compact!, A, CVᴴ, alg) - check_input(::typeof(left_null!), A, N, alg::AbstractAlgorithm) = check_input(left_null_kind(alg), A, N, alg) check_input(::typeof(left_null_qr!), A, N, alg::AbstractAlgorithm) = @@ -54,7 +40,7 @@ check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothin # Outputs # ------- initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = - initialize_output(left_orth_kind(alg), A, alg) + initialize_output(left_orth!, A, left_orth_alg(alg)) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaQR) = initialize_output(qr_compact!, A, alg.alg) @@ -63,15 +49,8 @@ initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaPolar) = initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = initialize_output(qr_compact!, A, alg.alg) -initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = - initialize_output(qr_compact!, A, alg) -initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = - initialize_output(left_polar!, A, alg) -initialize_output(::typeof(left_orth_svd!), A, alg::AbstractAlgorithm) = - initialize_output(qr_compact!, A, alg) - initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = - initialize_output(right_orth_kind(alg), A, alg) + initialize_output(right_orth!, A, right_orth_alg(alg)) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaLQ) = initialize_output(lq_compact!, A, alg.alg) @@ -80,13 +59,6 @@ initialize_output(::typeof(right_orth!), A, alg::RightOrthViaPolar) = initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = initialize_output(lq_compact!, A, alg.alg) -initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = - initialize_output(lq_compact!, A, alg) -initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = - initialize_output(right_polar!, A, alg) -initialize_output(::typeof(right_orth_svd!), A, alg::AbstractAlgorithm) = - initialize_output(lq_compact!, A, alg) - initialize_output(::typeof(left_null!), A, alg::AbstractAlgorithm) = initialize_output(left_null_kind(alg), A, alg) initialize_output(::typeof(left_null_qr!), A, alg::AbstractAlgorithm) = @@ -108,26 +80,10 @@ initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) # Implementation of orth functions # -------------------------------- -left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth_kind(alg)(A, VC, alg) +left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth!(A, VC, left_orth_alg(alg)) left_orth!(A, VC, alg::LeftOrthViaQR) = qr_compact!(A, VC, alg.alg) left_orth!(A, VC, alg::LeftOrthViaPolar) = left_polar!(A, VC, alg.alg) -left_orth_qr!(A, VC, alg::AbstractAlgorithm) = qr_compact!(A, VC, alg) -left_orth_polar!(A, VC, alg::AbstractAlgorithm) = left_polar!(A, VC, alg) - -right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) -right_orth!(A, CVᴴ, alg::RightOrthViaLQ) = lq_compact!(A, CVᴴ, alg.alg) -right_orth!(A, CVᴴ, alg::RightOrthViaPolar) = right_polar!(A, CVᴴ, alg.alg) -right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) -right_orth_polar!(A, CVᴴ, alg::AbstractAlgorithm) = right_polar!(A, CVᴴ, alg) - # orth_svd requires implementations of `lmul!` and `rmul!` -function left_orth_svd!(A, VC, alg::AbstractAlgorithm) - check_input(left_orth_svd!, A, VC, alg) - USVᴴ = initialize_orth_svd(A, VC, alg) - V, S, C = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) - lmul!(S, C) - return V, C -end function left_orth!(A, VC, alg::LeftOrthViaSVD) check_input(left_orth!, A, VC, alg) USVᴴ = initialize_orth_svd(A, VC, alg.alg) @@ -135,14 +91,11 @@ function left_orth!(A, VC, alg::LeftOrthViaSVD) lmul!(S, C) return V, C end +# orth_svd requires implementations of `lmul!` and `rmul!` +right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth!(A, CVᴴ, right_orth_alg(alg)) +right_orth!(A, CVᴴ, alg::RightOrthViaLQ) = lq_compact!(A, CVᴴ, alg.alg) +right_orth!(A, CVᴴ, alg::RightOrthViaPolar) = right_polar!(A, CVᴴ, alg.alg) -function right_orth_svd!(A, CVᴴ, alg::AbstractAlgorithm) - check_input(right_orth_svd!, A, CVᴴ, alg) - USVᴴ = initialize_orth_svd(A, CVᴴ, alg) - C, S, Vᴴ = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) - rmul!(C, S) - return C, Vᴴ -end function right_orth!(A, CVᴴ, alg::RightOrthViaSVD) check_input(right_orth!, A, CVᴴ, alg) USVᴴ = initialize_orth_svd(A, CVᴴ, alg.alg) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 8816ae5d..bc19cdd6 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -148,25 +148,6 @@ until convergence up to tolerance `tol`. left_orth_kind(::Union{PolarViaSVD, PolarNewton}) = left_orth_polar! right_orth_kind(::Union{PolarViaSVD, PolarNewton}) = right_orth_polar! -# ========================= -# ORTHOGONALIZATION ALGORITHMS -# ========================= - -struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end - -const LeftOrthViaQR = LeftOrthAlgorithm{:qr} -const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} -const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} - -struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end - -const RightOrthViaLQ = RightOrthAlgorithm{:lq} -const RightOrthViaPolar = RightOrthAlgorithm{:polar} -const RightOrthViaSVD = RightOrthAlgorithm{:svd} # ========================= # DIAGONAL ALGORITHMS @@ -308,6 +289,24 @@ Divide and Conquer algorithm. const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} +# Alternative algorithm (necessary for CUDA) +""" + LQViaTransposedQR(qr_alg) + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm + qr_alg::A +end +function Base.show(io::IO, alg::LQViaTransposedQR) + print(io, "LQViaTransposedQR(") + _show_alg(io, alg.qr_alg) + return print(io, ")") +end + +# Various consts and unions +# ------------------------- const GPU_Simple = Union{CUSOLVER_Simple} const GPU_EigAlgorithm = Union{GPU_Simple} const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} @@ -324,3 +323,61 @@ const GPU_Randomized = Union{CUSOLVER_Randomized} left_orth_kind(::GPU_SVDAlgorithm) = left_orth_svd! right_orth_kind(::GPU_SVDAlgorithm) = right_orth_svd! + +# ================================ +# ORTHOGONALIZATION ALGORITHMS +# ================================ + +struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_orth`, define + + MatrixAlgebraKit.LeftOrthAlgorithm(alg) = LeftOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr`, `:polar` or `:svd`, to select [`qr_compact!`](@ref), + [`left_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const LeftOrthViaQR = LeftOrthAlgorithm{:qr} +LeftOrthAlgorithm(alg::Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) = + LeftOrthViaQR(alg) + +const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} +LeftOrthAlgorithm(alg::Union{PolarSVD, PolarNewton}) = LeftOrthViaPolar(alg) + +const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} +LeftOrthAlgorithm(alg::Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm}) = LeftOrthViaSVD(alg) + + +struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +RightOrthAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_orth`, define + + MatrixAlgebraKit.RightOrthAlgorithm(alg) = RightOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq`, `:polar` or `:svd`, to select [`lq_compact!`](@ref), + [`right_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const RightOrthViaLQ = RightOrthAlgorithm{:lq} +RightOrthAlgorithm(alg::Union{LAPACK_HouseholderLQ, LQViaTransposedQR}) = LeftOrthViaLQ(alg) + +const RightOrthViaPolar = RightOrthAlgorithm{:polar} +RightOrthAlgorithm(alg::Union{PolarSVD, PolarNewton}) = RightOrthViaPolar(alg) + +const RightOrthViaSVD = RightOrthAlgorithm{:svd} +RightOrthAlgorithm(alg::Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm}) = RightOrthViaSVD(alg) diff --git a/src/interface/lq.jl b/src/interface/lq.jl index b0d138d4..97a70184 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -84,13 +84,3 @@ for f in (:lq_full!, :lq_compact!, :lq_null!) return default_lq_algorithm(A; kwargs...) end end - -# Alternative algorithm (necessary for CUDA) -struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm - qr_alg::A -end -function Base.show(io::IO, alg::LQViaTransposedQR) - print(io, "LQViaTransposedQR(") - _show_alg(io, alg.qr_alg) - return print(io, ")") -end diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 3bfe5652..c8937712 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -43,29 +43,30 @@ Here, the driving selector is `alg`, and depending on its value, the algorithm s procedure takes other keywords into account: * `:qr` : Factorize via QR decomposition, with further customizations through the - `alg_qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + `qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - V, C = qr_compact(A; alg_qr...) + V, C = qr_compact(A; alg = qr) ``` * `:polar` : Factorize via polar decomposition, with further customizations through the - `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + `polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - V, C = left_polar(A; alg_polar...) + V, C = left_polar(A; alg = polar) ``` -* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. This mode further allows truncation, which can be selected through the `trunc` argument. This mode is roughly equivalent to: ```julia - V, S, C = svd_trunc(A; trunc, alg_svd...) + V, S, C = svd_trunc(A; trunc, alg = svd) C = S * C ``` ### `alg::AbstractAlgorithm` In this expert mode the algorithm is supplied directly, and the kind of decomposition is -deduced from that. This hinges on the implementation of the algorithm trait -[`MatrixAlgebraKit.left_orth_kind(alg)`](@ref). +deduced from that. This is achieved either directly by providing a +[`LeftOrthAlgorithm{kind}`](@ref LeftOrthAlgorithm), or automatically by attempting to +deduce the decomposition kind with `LeftOrthAlgorithm(alg)`. --- @@ -74,16 +75,10 @@ deduced from that. This hinges on the implementation of the algorithm trait destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), -[`right_null(!)`](@ref right_null) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) """ @functiondef left_orth -# helper functions -function left_orth_qr! end -function left_orth_polar! end -function left_orth_svd! end - """ right_orth(A; [trunc], kwargs...) -> C, Vᴴ right_orth!(A, [CVᴴ]; [trunc], kwargs...) -> C, Vᴴ @@ -127,29 +122,30 @@ Here, the driving selector is `alg`, and depending on its value, the algorithm s procedure takes other keywords into account: * `:lq` : Factorize via LQ decomposition, with further customizations through the - `alg_lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + `lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - C, Vᴴ = lq_compact(A; alg_lq...) + C, Vᴴ = lq_compact(A; alg = lq) ``` * `:polar` : Factorize via polar decomposition, with further customizations through the - `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + `polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - C, Vᴴ = right_polar(A; alg_polar...) + C, Vᴴ = right_polar(A; alg = polar) ``` -* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. This mode further allows truncation, which can be selected through the `trunc` argument. This mode is roughly equivalent to: ```julia - C, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + C, S, Vᴴ = svd_trunc(A; trunc, alg = svd) C = C * S ``` ### `alg::AbstractAlgorithm` In this expert mode the algorithm is supplied directly, and the kind of decomposition is -deduced from that. This hinges on the implementation of the algorithm trait -[`MatrixAlgebraKit.right_orth_kind(alg)`](@ref). +deduced from that. This is achieved either directly by providing a +[`RightOrthAlgorithm{kind}`](@ref RightOrthAlgorithm), or automatically by attempting to +deduce the decomposition kind with `RightOrthAlgorithm(alg)`. --- @@ -163,11 +159,6 @@ See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), """ @functiondef right_orth -# helper functions -function right_orth_lq! end -function right_orth_polar! end -function right_orth_svd! end - # Null functions # -------------- """ From f6266ea703afdfecf7635a563df460da9ac5c9c5 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 10:22:44 -0400 Subject: [PATCH 26/47] also include null implementation --- src/implementations/orthnull.jl | 66 ++++++++-------- src/interface/decompositions.jl | 71 ++++++++++++++--- src/interface/orthnull.jl | 130 +++++++++++++------------------- test/orthnull.jl | 2 + 4 files changed, 147 insertions(+), 122 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 3b35d79a..92e4cd1b 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -26,16 +26,16 @@ check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = check_input(lq_compact!, A, VC, alg.alg) check_input(::typeof(left_null!), A, N, alg::AbstractAlgorithm) = - check_input(left_null_kind(alg), A, N, alg) -check_input(::typeof(left_null_qr!), A, N, alg::AbstractAlgorithm) = - check_input(qr_null!, A, N, alg) -check_input(::typeof(left_null_svd!), A, N, alg::AbstractAlgorithm) = nothing + check_input(left_null!, A, N, left_null_alg(alg)) +check_input(::typeof(left_null!), A, N, alg::LeftNullViaQR) = + check_input(qr_null!, A, N, alg.alg) +check_input(::typeof(left_null!), A, N, alg::LeftNullViaSVD) = nothing check_input(::typeof(right_null!), A, Nᴴ, alg::AbstractAlgorithm) = - check_input(right_null_kind(alg), A, Nᴴ, alg) -check_input(::typeof(right_null_lq!), A, Nᴴ, alg::AbstractAlgorithm) = - check_input(lq_null!, A, Nᴴ, alg) -check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothing + check_input(right_null!, A, Nᴴ, right_null_alg(alg)) +check_input(::typeof(right_null!), A, Nᴴ, alg::RightNullViaLQ) = + check_input(lq_null!, A, Nᴴ, alg.alg) +check_input(::typeof(right_null!), A, Nᴴ, alg::RightNullViaSVD) = nothing # Outputs # ------- @@ -60,16 +60,16 @@ initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = initialize_output(lq_compact!, A, alg.alg) initialize_output(::typeof(left_null!), A, alg::AbstractAlgorithm) = - initialize_output(left_null_kind(alg), A, alg) -initialize_output(::typeof(left_null_qr!), A, alg::AbstractAlgorithm) = - initialize_output(qr_null!, A, alg) -initialize_output(::typeof(left_null_svd!), A, alg::AbstractAlgorithm) = nothing + initialize_output(left_null!, A, left_null_alg(alg)) +initialize_output(::typeof(left_null!), A, alg::LeftNullViaQR) = + initialize_output(qr_null!, A, alg.alg) +initialize_output(::typeof(left_null!), A, alg::LeftNullViaSVD) = nothing initialize_output(::typeof(right_null!), A, alg::AbstractAlgorithm) = - initialize_output(right_null_kind(alg), A, alg) -initialize_output(::typeof(right_null_lq!), A, alg::AbstractAlgorithm) = - initialize_output(lq_null!, A, alg) -initialize_output(::typeof(right_null_svd!), A, alg::AbstractAlgorithm) = nothing + initialize_output(right_null!, A, right_null_alg(alg)) +initialize_output(::typeof(right_null!), A, alg::RightNullViaLQ) = + initialize_output(lq_null!, A, alg.alg) +initialize_output(::typeof(right_null!), A, alg::RightNullViaSVD) = nothing function initialize_orth_svd(A::AbstractMatrix, F, alg) S = Diagonal(initialize_output(svd_vals!, A, alg)) @@ -106,27 +106,27 @@ end # Implementation of null functions # -------------------------------- -left_null!(A, N, alg::AbstractAlgorithm) = left_null_kind(alg)(A, N, alg) -left_null_qr!(A, N, alg::AbstractAlgorithm) = qr_null!(A, N, alg) +left_null!(A, N, alg::AbstractAlgorithm) = left_null!(A, N, left_null_alg(alg)) +left_null!(A, N, alg::LeftNullViaQR) = qr_null!(A, N, alg.alg) -right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null_kind(alg)(A, Nᴴ, alg) -right_null_lq!(A, Nᴴ, alg::AbstractAlgorithm) = lq_null!(A, Nᴴ, alg) +right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null!(A, Nᴴ, right_null_alg(alg)) +right_null!(A, Nᴴ, alg::RightNullViaLQ) = lq_null!(A, Nᴴ, alg.alg) -function left_null_svd!(A, N, alg::TruncatedAlgorithm) - check_input(left_null_svd!, A, N, alg) - U, S, _ = svd_full!(A, alg.alg) - N, _ = truncate(left_null!, (U, S), alg.trunc) +function left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm}) + check_input(left_null!, A, N, alg) + U, S, _ = svd_full!(A, alg.alg.alg) + N, _ = truncate(left_null!, (U, S), alg.alg.trunc) return N end -function right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm) - check_input(right_null_svd!, A, Nᴴ, alg) - _, S, Vᴴ = svd_full!(A, alg.alg) - Nᴴ, _ = truncate(right_null!, (S, Vᴴ), alg.trunc) +function right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm}) + check_input(right_null!, A, Nᴴ, alg) + _, S, Vᴴ = svd_full!(A, alg.alg.alg) + Nᴴ, _ = truncate(right_null!, (S, Vᴴ), alg.alg.trunc) return Nᴴ end -# randomized algorithms don't work for smallest values: -left_null_svd!(A, N, alg::TruncatedAlgorithm{<:GPU_Randomized}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) -right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) +# randomized algorithms don't currently work for smallest values: +left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) +right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index bc19cdd6..838d267e 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -307,6 +307,7 @@ end # Various consts and unions # ------------------------- + const GPU_Simple = Union{CUSOLVER_Simple} const GPU_EigAlgorithm = Union{GPU_Simple} const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} @@ -321,8 +322,10 @@ const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} const GPU_Randomized = Union{CUSOLVER_Randomized} -left_orth_kind(::GPU_SVDAlgorithm) = left_orth_svd! -right_orth_kind(::GPU_SVDAlgorithm) = right_orth_svd! +const QRAlgorithms = Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} +const LQAlgorithms = Union{LAPACK_HouseholderLQ, LQViaTransposedQR} +const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} +const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ================================ # ORTHOGONALIZATION ALGORITHMS @@ -346,15 +349,14 @@ LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( ) const LeftOrthViaQR = LeftOrthAlgorithm{:qr} -LeftOrthAlgorithm(alg::Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR}) = - LeftOrthViaQR(alg) +LeftOrthAlgorithm(alg::QRAlgorithms) = LeftOrthViaQR{typeof(alg)}(alg) const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} -LeftOrthAlgorithm(alg::Union{PolarSVD, PolarNewton}) = LeftOrthViaPolar(alg) +LeftOrthAlgorithm(alg::PolarAlgorithms) = LeftOrthViaPolar{typeof(alg)}(alg) const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} -LeftOrthAlgorithm(alg::Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm}) = LeftOrthViaSVD(alg) - +LeftOrthAlgorithm(alg::SVDAlgorithms) = LeftOrthViaSVD{typeof(alg)}(alg) +LeftOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD{typeof(alg)}(alg) struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg @@ -374,10 +376,59 @@ RightOrthAlgorithm(alg::AbstractAlgorithm) = error( ) const RightOrthViaLQ = RightOrthAlgorithm{:lq} -RightOrthAlgorithm(alg::Union{LAPACK_HouseholderLQ, LQViaTransposedQR}) = LeftOrthViaLQ(alg) +RightOrthAlgorithm(alg::LQAlgorithms) = RightOrthViaLQ{typeof(alg)}(alg) const RightOrthViaPolar = RightOrthAlgorithm{:polar} -RightOrthAlgorithm(alg::Union{PolarSVD, PolarNewton}) = RightOrthViaPolar(alg) +RightOrthAlgorithm(alg::PolarAlgorithms) = RightOrthViaPolar{typeof(alg)}(alg) const RightOrthViaSVD = RightOrthAlgorithm{:svd} -RightOrthAlgorithm(alg::Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm}) = RightOrthViaSVD(alg) +RightOrthAlgorithm(alg::SVDAlgorithms) = RightOrthViaSVD{typeof(alg)}(alg) +RightOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD{typeof(alg)}(alg) + +struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +LeftNullAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_null`, define + + MatrixAlgebraKit.LeftNullAlgorithm(alg) = LeftNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr` or `:svd`, to select [`qr_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const LeftNullViaQR = LeftNullAlgorithm{:qr} +LeftNullAlgorithm(alg::QRAlgorithms) = LeftNullViaQR{typeof(alg)}(alg) + +const LeftNullViaSVD = LeftNullAlgorithm{:svd} +LeftNullAlgorithm(alg::SVDAlgorithms) = LeftNullViaSVD{typeof(alg)}(alg) +LeftNullAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftNullViaSVD{typeof(alg)}(alg) + +struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end + +RightNullAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_null`, define + + MatrixAlgebraKit.RightNullAlgorithm(alg) = RightNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq` or `:svd`, to select [`lq_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const RightNullViaLQ = RightNullAlgorithm{:lq} +RightNullAlgorithm(alg::LQAlgorithms) = RightNullViaLQ{typeof(alg)}(alg) + +const RightNullViaSVD = RightNullAlgorithm{:svd} +RightNullAlgorithm(alg::SVDAlgorithms) = RightNullViaSVD{typeof(alg)}(alg) +RightNullAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightNullViaSVD{typeof(alg)}(alg) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index c8937712..1e691cc1 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -207,24 +207,25 @@ of the chosen decomposition type. Here, the driving selector is `alg`, and depending on its value, the algorithm selection procedure takes other keywords into account: -* `:qr` : Factorize via QR nullspace, with further customizations through the `alg_qr` +* `:qr` : Factorize via QR nullspace, with further customizations through the `qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - N = qr_null(A; alg_qr...) + N = qr_null(A; alg = qr) ``` -* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. This mode further allows truncation, which can be selected through the `trunc` argument. It is roughly equivalent to: ```julia - U, S, _ = svd_trunc(A; trunc, alg_svd...) + U, S, _ = svd_trunc(A; trunc, alg = svd) N = truncate(left_null, (U, S), trunc) ``` ### `alg::AbstractAlgorithm` In this expert mode the algorithm is supplied directly, and the kind of decomposition is -deduced from that. This hinges on the implementation of the algorithm trait -[`MatrixAlgebraKit.left_null_kind(alg)`](@ref). +deduced from that. This is achieved either directly by providing a +[`LeftNullAlgorithm{kind}`](@ref LeftNullAlgorithm), or automatically by attempting to +deduce the decomposition kind with `LeftNullAlgorithm(alg)`. --- @@ -238,10 +239,6 @@ See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), """ @functiondef left_null -# helper functions -function left_null_qr! end -function left_null_svd! end - """ right_null(A; [trunc], kwargs...) -> Nᴴ right_null!(A, [Nᴴ]; [trunc], kwargs...) -> Nᴴ @@ -288,24 +285,25 @@ of the chosen decomposition type. Here, the driving selector is `alg`, and depending on its value, the algorithm selection procedure takes other keywords into account: -* `:lq` : Factorize via LQ nullspace, with further customizations through the `alg_lq` +* `:lq` : Factorize via LQ nullspace, with further customizations through the `lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - Nᴴ = lq_null(A; alg_qr...) + Nᴴ = lq_null(A; alg = lq) ``` -* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. This mode further allows truncation, which can be selected through the `trunc` argument. It is roughly equivalent to: ```julia - _, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + _, S, Vᴴ = svd_trunc(A; trunc, alg = svd) Nᴴ = truncate(right_null, (S, Vᴴ), trunc) ``` ### `alg::AbstractAlgorithm` In this expert mode the algorithm is supplied directly, and the kind of decomposition is -deduced from that. This hinges on the implementation of the algorithm trait -[`MatrixAlgebraKit.right_null_kind(alg)`](@ref). +deduced from that. This is achieved either directly by providing a +[`RightNullAlgorithm{kind}`](@ref RightNullAlgorithm), or automatically by attempting to +deduce the decomposition kind with `RightNullAlgorithm(alg)`. --- @@ -319,10 +317,6 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), """ @functiondef right_null -# helper functions -function right_null_lq! end -function right_null_svd! end - # Algorithm selection # ------------------- # specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. @@ -330,6 +324,10 @@ function right_null_svd! end LeftOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) @inline select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) = RightOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) +@inline select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) = + LeftNullAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) +@inline select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing, kwargs...) = + RightNullAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) function LeftOrthViaQR(A; alg = nothing, trunc = nothing, kwargs...) isnothing(trunc) || @@ -367,6 +365,30 @@ function RightOrthViaSVD(A; alg = nothing, trunc = nothing, kwargs...) return RightOrthViaSVD{typeof(alg)}(alg) end +function LeftNullViaQR(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("QR-based `left_null` is incompatible with specifying `trunc`")) + alg = select_algorithm(qr_null!, A, alg; kwargs...) + return LeftNullViaQR{typeof(alg)}(alg) +end +function LeftNullViaSVD(A; alg = nothing, trunc = nothing, kwargs...) + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + return LeftNullViaSVD{typeof(alg)}(alg) +end + +function RightNullViaLQ(A; alg = nothing, trunc = nothing, kwargs...) + isnothing(trunc) || + throw(ArgumentError("LQ-based `right_null` is incompatible with specifying `trunc`")) + alg = select_algorithm(lq_null!, A, alg; kwargs...) + return RightNullViaLQ{typeof(alg)}(alg) +end +function RightNullViaSVD(A; alg = nothing, trunc = nothing, kwargs...) + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + return RightNullViaSVD{typeof(alg)}(alg) +end + default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = isnothing(trunc) ? LeftOrthViaQR(A; kwargs...) : LeftOrthViaSVD(A; trunc, kwargs...) # disambiguate @@ -379,69 +401,19 @@ default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) wher default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? RightOrthViaLQ(A; kwargs...) : RightOrthViaSVD(A; trunc, kwargs...) -function select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) - alg === :svd && return select_algorithm( - left_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc - ) - - isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) - - alg === :qr && return select_algorithm(left_null_qr!, A, get(kwargs, :alg_qr, nothing)) - - throw(ArgumentError(lazy"unkown alg symbol $alg")) -end - default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : - select_algorithm(left_null_svd!, A; trunc, kwargs...) + isnothing(trunc) ? LeftNullViaQR(A; kwargs...) : LeftNullViaSVD(A; trunc, kwargs...) # disambiguate default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : - select_algorithm(left_null_svd!, A; trunc, kwargs...) - -select_algorithm(::typeof(left_null_qr!), A, alg = nothing; kwargs...) = - select_algorithm(qr_null!, A, alg; kwargs...) -function select_algorithm(::typeof(left_null_svd!), A, alg = nothing; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) - end -end - -function select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing, kwargs...) - alg === :svd && return select_algorithm( - right_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc - ) - - isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) - - alg === :lq && return select_algorithm(right_null_lq!, A, get(kwargs, :alg_lq, nothing)) - - throw(ArgumentError(lazy"unkown alg symbol $alg")) -end + isnothing(trunc) ? LeftNullViaQR(A; kwargs...) : LeftNullViaSVD(A; trunc, kwargs...) default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : - select_algorithm(right_null_svd!, A; trunc, kwargs...) + isnothing(trunc) ? RightNullViaLQ(A; kwargs...) : RightNullViaSVD(A; trunc, kwargs...) +# disambiguate default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : - select_algorithm(right_null_svd!, A; trunc, kwargs...) - -select_algorithm(::typeof(right_null_lq!), A, alg = nothing; kwargs...) = - select_algorithm(lq_null!, A, alg; kwargs...) - -function select_algorithm(::typeof(right_null_svd!), A, alg; trunc = nothing, kwargs...) - if alg isa TruncatedAlgorithm - isnothing(trunc) || - throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) - return alg - else - alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) - end + isnothing(trunc) ? RightNullViaLQ(A; kwargs...) : RightNullViaSVD(A; trunc, kwargs...) -end +left_orth_alg(alg::AbstractAlgorithm) = LeftOrthAlgorithm(alg) +right_orth_alg(alg::AbstractAlgorithm) = RightOrthAlgorithm(alg) +left_null_alg(alg::AbstractAlgorithm) = LeftNullAlgorithm(alg) +right_null_alg(alg::AbstractAlgorithm) = RightNullAlgorithm(alg) diff --git a/test/orthnull.jl b/test/orthnull.jl index dc532c07..ec74c710 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -130,6 +130,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) else @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + alg == :polar && continue @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) end @@ -216,6 +217,7 @@ end else @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + alg == :polar && continue @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) end From 6bc158a8a650e1cf8f1c2974dc9c240bfea5fdef Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 10:35:58 -0400 Subject: [PATCH 27/47] maybeblasmat --- src/interface/eig.jl | 2 +- src/interface/eigh.jl | 2 +- src/interface/gen_eig.jl | 2 +- src/interface/lq.jl | 2 +- src/interface/qr.jl | 2 +- src/interface/svd.jl | 2 +- src/yalapack.jl | 2 ++ 7 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 1ff7c51a..28b5c69c 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -108,7 +108,7 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # ------------------- default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) -function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat} +function default_eig_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} return LAPACK_Expert(; kwargs...) end function default_eig_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index ae6a843c..97f8f95c 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -115,7 +115,7 @@ default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs. function default_eigh_algorithm(T::Type; kwargs...) throw(MethodError(default_eigh_algorithm, (T,))) end -function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat} +function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end function default_eigh_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} diff --git a/src/interface/gen_eig.jl b/src/interface/gen_eig.jl index 92cd3477..1903837f 100644 --- a/src/interface/gen_eig.jl +++ b/src/interface/gen_eig.jl @@ -60,7 +60,7 @@ See also [`gen_eig_full(!)`](@ref gen_eig_full). default_gen_eig_algorithm(A, B; kwargs...) = default_gen_eig_algorithm(typeof(A), typeof(B); kwargs...) default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA, TB} = throw(MethodError(default_gen_eig_algorithm, (TA, TB))) -default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA <: YALAPACK.BlasMat, TB <: YALAPACK.BlasMat} = +default_gen_eig_algorithm(::Type{TA}, ::Type{TB}; kwargs...) where {TA <: YALAPACK.MaybeBlasMat, TB <: YALAPACK.MaybeBlasMat} = LAPACK_Simple(; kwargs...) for f in (:gen_eig_full!, :gen_eig_vals!) diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 97a70184..a1efa39b 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -72,7 +72,7 @@ default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...) function default_lq_algorithm(T::Type; kwargs...) throw(MethodError(default_lq_algorithm, (T,))) end -function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat} +function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} return LAPACK_HouseholderLQ(; kwargs...) end function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} diff --git a/src/interface/qr.jl b/src/interface/qr.jl index 7d5f8dfb..56624282 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -72,7 +72,7 @@ default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...) function default_qr_algorithm(T::Type; kwargs...) throw(MethodError(default_qr_algorithm, (T,))) end -function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat} +function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} return LAPACK_HouseholderQR(; kwargs...) end function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 04e7121c..2ea26204 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -112,7 +112,7 @@ default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs... function default_svd_algorithm(T::Type; kwargs...) throw(MethodError(default_svd_algorithm, (T,))) end -function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.BlasMat} +function default_svd_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} return LAPACK_DivideAndConquer(; kwargs...) end function default_svd_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} diff --git a/src/yalapack.jl b/src/yalapack.jl index 5bc87073..b2d4c27e 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -17,6 +17,8 @@ using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapa # type alias for matrices that are definitely supported by YALAPACK const BlasMat{T <: BlasFloat} = StridedMatrix{T} +# type alias for matrices that are possibly supported by YALAPACK, after conversion +const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}} # LU factorisation for (getrf, getrs, elty) in ( From c7942c0b470384c461398dd7c844c9a02d2c5453 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 14:48:07 -0400 Subject: [PATCH 28/47] fix docs build --- docs/src/user_interface/decompositions.md | 15 ++++++++++++++ src/interface/decompositions.jl | 24 +++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index b6d0e94e..cca993fb 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -179,6 +179,14 @@ left_orth right_orth ``` +In order to dispatch to the underlying factorizations, the following wrapper functions are used: + +```@docs; canonical=false +LeftOrthAlgorithm +RightOrthAlgorithm +``` + + ## Null Spaces Similarly, it can be convenient to obtain an orthogonal basis for the kernel or cokernel of a matrix. @@ -189,3 +197,10 @@ Again, this is typically implemented through a combination of the decompositions left_null right_null ``` + +Again, dispatching happens through the following wrapper algorithm types: + +```@docs; canonical=false +LeftNullAlgorithm +RightNullAlgorithm +``` diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 838d267e..c500931f 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -331,6 +331,12 @@ const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} # ORTHOGONALIZATION ALGORITHMS # ================================ +""" + LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_orth`](@ref). +By default `Kind` is a symbol, which can be either `:qr`, `:polar` or `:svd`. +""" struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end @@ -358,6 +364,12 @@ const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} LeftOrthAlgorithm(alg::SVDAlgorithms) = LeftOrthViaSVD{typeof(alg)}(alg) LeftOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD{typeof(alg)}(alg) +""" + RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_orth`](@ref). +By default `Kind` is a symbol, which can be either `:lq`, `:polar` or `:svd`. +""" struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end @@ -385,6 +397,12 @@ const RightOrthViaSVD = RightOrthAlgorithm{:svd} RightOrthAlgorithm(alg::SVDAlgorithms) = RightOrthViaSVD{typeof(alg)}(alg) RightOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD{typeof(alg)}(alg) +""" + LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_null`](@ref). +By default `Kind` is a symbol, which can be either `:qr` or `:svd`. +""" struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end @@ -409,6 +427,12 @@ const LeftNullViaSVD = LeftNullAlgorithm{:svd} LeftNullAlgorithm(alg::SVDAlgorithms) = LeftNullViaSVD{typeof(alg)}(alg) LeftNullAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftNullViaSVD{typeof(alg)}(alg) +""" + RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_null`](@ref). +By default `Kind` is a symbol, which can be either `:lq` or `:svd`. +""" struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end From f547ae083bc9fdb5ff059602109a4a1a4f161df1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 14:50:58 -0400 Subject: [PATCH 29/47] some cleanup --- src/algorithms.jl | 50 --------------------------------- src/interface/decompositions.jl | 12 -------- 2 files changed, 62 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 6c9b9e28..e7bd635a 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -56,54 +56,6 @@ end # Algorithm traits # ---------------- -""" - left_null_kind(alg::AbstractAlgorithm) -> f! - -Select an appropriate factorization function for applying `left_null!(A, alg)`. -By default, this is either `left_null_qr!` or `left_null_svd!`, but this can be extended -to insert arbitrary other decomposition functions, which should follow the signature -`f!(A, F, alg) -> F` -""" -function left_null_kind(alg::AbstractAlgorithm) - left_orth_kind(alg) === left_orth_qr! && return left_null_qr! - left_orth_kind(alg) === left_orth_svd! && return left_null_svd! - return error( - """ - Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. - To register the algorithm type, define: - - MatrixAlgebraKit.left_null_kind(alg) = f! - - where `f!` should be the factorization function that will be used. - By default, this is either `left_null_qr!` or `left_null_svd!`. - """ - ) -end - -""" - right_null_kind(alg::AbstractAlgorithm) -> f! - -Select an appropriate factorization function for applying `right_null!(A, alg)`. -By default, this is either `right_null_lq!` or `right_null_svd!`, but this can be extended -to insert arbitrary other decomposition functions, which should follow the signature -`f!(A, F, alg) -> F` -""" -function right_null_kind(alg::AbstractAlgorithm) - right_orth_kind(alg) === right_orth_lq! && return right_null_lq! - right_orth_kind(alg) === right_orth_svd! && return right_null_svd! - return error( - """ - Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. - To register the algorithm type, define: - - MatrixAlgebraKit.right_null_kind(alg) = f! - - where `f!` should be the factorization function that will be used. - By default, this is either `right_null_lq!` or `right_null_svd!`. - """ - ) -end - """ does_truncate(alg::AbstractAlgorithm) -> Bool @@ -278,8 +230,6 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm trunc::T end -left_orth_kind(alg::TruncatedAlgorithm) = left_orth_kind(alg.alg) -right_orth_kind(alg::TruncatedAlgorithm) = right_orth_kind(alg.alg) does_truncate(::TruncatedAlgorithm) = true # Utility macros diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index c500931f..ab2df6d8 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -36,9 +36,6 @@ elements of `L` are non-negative. @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ -left_orth_kind(::Union{LAPACK_HouseholderQR, LAPACK_HouseholderQL}) = left_orth_qr! -right_orth_kind(::Union{LAPACK_HouseholderLQ, LAPACK_HouseholderRQ}) = right_orth_lq! - # General Eigenvalue Decomposition # ------------------------------- """ @@ -120,9 +117,6 @@ const LAPACK_SVDAlgorithm = Union{ LAPACK_Jacobi, } -left_orth_kind(::LAPACK_SVDAlgorithm) = left_orth_svd! -right_orth_kind(::LAPACK_SVDAlgorithm) = right_orth_svd! - # ========================= # Polar decompositions # ========================= @@ -145,10 +139,6 @@ until convergence up to tolerance `tol`. """ @algdef PolarNewton -left_orth_kind(::Union{PolarViaSVD, PolarNewton}) = left_orth_polar! -right_orth_kind(::Union{PolarViaSVD, PolarNewton}) = right_orth_polar! - - # ========================= # DIAGONAL ALGORITHMS # ========================= @@ -172,8 +162,6 @@ the diagonal elements of `R` are non-negative. """ @algdef CUSOLVER_HouseholderQR -left_orth_kind(::CUSOLVER_HouseholderQR) = left_orth_qr! - """ CUSOLVER_QRIteration() From b3407a2ed9eb88c8149ba4f783d5081f3252f5ad Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 16:04:20 -0400 Subject: [PATCH 30/47] fix type stability again --- src/algorithms.jl | 6 +- src/interface/decompositions.jl | 4 ++ src/interface/orthnull.jl | 108 +++++++++++++++++--------------- test/orthnull.jl | 18 +++--- 4 files changed, 75 insertions(+), 61 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index e7bd635a..40065b5f 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -95,7 +95,7 @@ function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg} return Algorithm{alg}(; kwargs...) elseif alg isa Type return alg(; kwargs...) - elseif alg isa NamedTuple + elseif alg isa NamedTuple || alg isa Base.Pairs isempty(kwargs) || throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) return default_algorithm(f, A; alg...) @@ -267,10 +267,10 @@ function _arg_expr(::Val{1}, f, f!) $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) # fill in arguments - function $f!(A; alg = nothing, kwargs...) + @inline function $f!(A; alg = nothing, kwargs...) return $f!(A, select_algorithm($f!, A, alg; kwargs...)) end - function $f!(A, out; alg = nothing, kwargs...) + @inline function $f!(A, out; alg = nothing, kwargs...) return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) end function $f!(A, alg::AbstractAlgorithm) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index ab2df6d8..dfaa0654 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -328,6 +328,7 @@ By default `Kind` is a symbol, which can be either `:qr`, `:polar` or `:svd`. struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end +LeftOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftOrthAlgorithm{Kind, Alg}(alg) LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( """ @@ -361,6 +362,7 @@ By default `Kind` is a symbol, which can be either `:lq`, `:polar` or `:svd`. struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end +RightOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightOrthAlgorithm{Kind, Alg}(alg) RightOrthAlgorithm(alg::AbstractAlgorithm) = error( """ @@ -394,6 +396,7 @@ By default `Kind` is a symbol, which can be either `:qr` or `:svd`. struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end +LeftNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftNullAlgorithm{Kind, Alg}(alg) LeftNullAlgorithm(alg::AbstractAlgorithm) = error( """ @@ -424,6 +427,7 @@ By default `Kind` is a symbol, which can be either `:lq` or `:svd`. struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm alg::Alg end +RightNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightNullAlgorithm{Kind, Alg}(alg) RightNullAlgorithm(alg::AbstractAlgorithm) = error( """ diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 1e691cc1..3267a68c 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -320,98 +320,108 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), # Algorithm selection # ------------------- # specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. -@inline select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, kwargs...) = - LeftOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) -@inline select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) = - RightOrthAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) -@inline select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) = - LeftNullAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) -@inline select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing, kwargs...) = - RightNullAlgorithm{alg}(A; alg = get(kwargs, alg, nothing), trunc) - -function LeftOrthViaQR(A; alg = nothing, trunc = nothing, kwargs...) +@inline select_algorithm(::typeof(left_orth!), A, alg::Symbol; kwargs...) = + select_algorithm(left_orth!, A, Val(alg); kwargs...) +@inline select_algorithm(::typeof(right_orth!), A, alg::Symbol; kwargs...) = + select_algorithm(right_orth!, A, Val(alg); kwargs...) +@inline select_algorithm(::typeof(left_null!), A, alg::Symbol; kwargs...) = + select_algorithm(left_null!, A, Val(alg); kwargs...) +@inline select_algorithm(::typeof(right_null!), A, alg::Symbol; kwargs...) = + select_algorithm(right_null!, A, Val(alg); kwargs...) + +function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(qr_compact!, A, alg; kwargs...) - return LeftOrthViaQR{typeof(alg)}(alg) + alg′ = select_algorithm(qr_compact!, A, get(kwargs, :qr, nothing)) + return LeftOrthViaQR(alg′) end -function LeftOrthViaPolar(A; alg = nothing, trunc = nothing, kwargs...) +function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `left_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(left_polar!, A, alg; kwargs...) - return LeftOrthViaPolar{typeof(alg)}(alg) + alg′ = select_algorithm(left_polar!, A, get(kwargs, :polar, nothing)) + return LeftOrthViaPolar(alg′) end -function LeftOrthViaSVD(A; alg = nothing, trunc = nothing, kwargs...) - alg = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : - select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) - return LeftOrthViaSVD{typeof(alg)}(alg) +function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) + alg = get(kwargs, :svd, nothing) + alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg) : + select_algorithm(svd_trunc!, A, alg; trunc) + return LeftOrthViaSVD(alg′) end -function RightOrthViaLQ(A; alg = nothing, trunc = nothing, kwargs...) +function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(lq_compact!, A, alg; kwargs...) - return RightOrthViaLQ{typeof(alg)}(alg) + alg = select_algorithm(lq_compact!, A, get(kwargs, :lq, nothing)) + return RightOrthViaLQ(alg) end -function RightOrthViaPolar(A; alg = nothing, trunc = nothing, kwargs...) +function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `right_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(right_polar!, A, alg; kwargs...) - return RightOrthViaPolar{typeof(alg)}(alg) + alg = select_algorithm(right_polar!, A, get(kwargs, :polar, nothing)) + return RightOrthViaPolar(alg) end -function RightOrthViaSVD(A; alg = nothing, trunc = nothing, kwargs...) - alg = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : - select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) - return RightOrthViaSVD{typeof(alg)}(alg) +function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) + alg = get(kwargs, :svd, nothing) + alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg) : + select_algorithm(svd_trunc!, A, alg; trunc) + return RightOrthViaSVD(alg′) end -function LeftNullViaQR(A; alg = nothing, trunc = nothing, kwargs...) +function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_null` is incompatible with specifying `trunc`")) - alg = select_algorithm(qr_null!, A, alg; kwargs...) - return LeftNullViaQR{typeof(alg)}(alg) + alg = select_algorithm(qr_null!, A, get(kwargs, :qr, nothing)) + return LeftNullViaQR(alg) end -function LeftNullViaSVD(A; alg = nothing, trunc = nothing, kwargs...) - alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) +function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) + alg_svd = select_algorithm(svd_full!, A, get(kwargs, :svd, nothing)) alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) - return LeftNullViaSVD{typeof(alg)}(alg) + return LeftNullViaSVD(alg) end -function RightNullViaLQ(A; alg = nothing, trunc = nothing, kwargs...) +function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_null` is incompatible with specifying `trunc`")) - alg = select_algorithm(lq_null!, A, alg; kwargs...) - return RightNullViaLQ{typeof(alg)}(alg) + alg = select_algorithm(lq_null!, A, get(kwargs, :lq, nothing)) + return RightNullViaLQ(alg) end -function RightNullViaSVD(A; alg = nothing, trunc = nothing, kwargs...) - alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) +function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) + alg_svd = select_algorithm(svd_full!, A, get(kwargs, :svd, nothing)) alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) - return RightNullViaSVD{typeof(alg)}(alg) + return RightNullViaSVD(alg) end default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? LeftOrthViaQR(A; kwargs...) : LeftOrthViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); qr = kwargs) : + select_algorithm(left_orth!, A, Val(:svd); svd = kwargs) # disambiguate default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? LeftOrthViaQR(A; kwargs...) : LeftOrthViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); qr = kwargs) : + select_algorithm(left_orth!, A, Val(:svd); svd = kwargs) default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? RightOrthViaLQ(A; kwargs...) : RightOrthViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); lq = kwargs) : + select_algorithm(right_orth!, A, Val(:svd); svd = kwargs) # disambiguate default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? RightOrthViaLQ(A; kwargs...) : RightOrthViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); lq = kwargs) : + select_algorithm(right_orth!, A, Val(:svd); svd = kwargs) default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? LeftNullViaQR(A; kwargs...) : LeftNullViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); qr = kwargs) : + select_algorithm(left_null!, A, Val(:svd); svd = kwargs, trunc) # disambiguate default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? LeftNullViaQR(A; kwargs...) : LeftNullViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); qr = kwargs) : + select_algorithm(left_null!, A, Val(:svd); svd = kwargs, trunc) default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? RightNullViaLQ(A; kwargs...) : RightNullViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); lq = kwargs) : + select_algorithm(right_null!, A, Val(:svd); svd = kwargs, trunc) # disambiguate default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? RightNullViaLQ(A; kwargs...) : RightNullViaSVD(A; trunc, kwargs...) + isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); lq = kwargs) : + select_algorithm(right_null!, A, Val(:svd); svd = kwargs, trunc) left_orth_alg(alg::AbstractAlgorithm) = LeftOrthAlgorithm(alg) right_orth_alg(alg::AbstractAlgorithm) = RightOrthAlgorithm(alg) diff --git a/test/orthnull.jl b/test/orthnull.jl index ec74c710..6fc3fd27 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -100,7 +100,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) for alg in (:qr, :polar, :svd) # explicit kind kwarg m < n && alg === :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) @test V2 * C2 ≈ A @test isisometric(V2) if alg != :polar @@ -112,7 +112,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # with kind and tol kwargs if alg == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A @test V2' * V2 ≈ I @@ -120,7 +120,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test N2' * N2 ≈ I @test V2 * V2' + N2 * N2' ≈ I - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) @test V2 * C2 ≈ A @test isisometric(V2) @@ -188,27 +188,27 @@ end for alg in (:lq, :polar, :svd) n < m && alg == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) if alg != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I end if alg == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) From 916fc28ed79fc552f74051738e40384e37adf796 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 16:26:29 -0400 Subject: [PATCH 31/47] update gpu tests --- test/amd/orthnull.jl | 205 ++++++++++++++---------------------------- test/cuda/orthnull.jl | 202 ++++++++++++++--------------------------- test/orthnull.jl | 4 +- 3 files changed, 136 insertions(+), 275 deletions(-) diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl index 97f5d915..756e8064 100644 --- a/test/amd/orthnull.jl +++ b/test/amd/orthnull.jl @@ -10,7 +10,9 @@ using AMDGPU # testing non-AbstractArray codepaths: include(joinpath("..", "linearmap.jl")) -@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +eltypes = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "left_orth and left_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @testset for n in (37, m, 63) @@ -30,7 +32,7 @@ include(joinpath("..", "linearmap.jl")) @test hV * hV' + hN * hN' ≈ I M = LinearMap(A) - VM, CM = @constinferred left_orth(M; kind = :svd) + VM, CM = @constinferred left_orth(M; alg = :svd) @test parent(VM) * parent(CM) ≈ A if m > n @@ -48,20 +50,33 @@ include(joinpath("..", "linearmap.jl")) @test isisometric(N) end - for alg_qr in ((; positive = true), (; positive = false), ROCSOLVER_HouseholderQR()) - V, C = @constinferred left_orth(A; alg_qr) - N = @constinferred left_null(A; alg_qr) - @test V isa ROCMatrix{T} && size(V) == (m, minmn) - @test C isa ROCMatrix{T} && size(C) == (minmn, n) - @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - end + # passing a kind and some kwargs + V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + @test V isa ROCMatrix{T} && size(V) == (m, minmn) + @test C isa ROCMatrix{T} && size(C) == (minmn, n) + @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + hV = collect(V) + hN = collect(N) + @test hV * hV' + hN * hN' ≈ I + + # passing an algorithm + V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + @test V isa ROCMatrix{T} && size(V) == (m, minmn) + @test C isa ROCMatrix{T} && size(C) == (minmn, n) + @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + hV = collect(V) + hN = collect(N) + @test hV * hV' + hN * hN' ≈ I Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) @@ -82,9 +97,6 @@ include(joinpath("..", "linearmap.jl")) AMDGPU.@allowscalar begin N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) end - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -102,9 +114,6 @@ include(joinpath("..", "linearmap.jl")) AMDGPU.@allowscalar begin N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) end - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -114,16 +123,13 @@ include(joinpath("..", "linearmap.jl")) @test hV2 * hV2' + hN2 * hN2' ≈ I end - @testset for kind in (:qr, :polar, :svd) # explicit kind kwarg - m < n && kind == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind) - @test V2 === V - @test C2 === C + @testset for alg in (:qr, :polar, :svd) # explicit alg kwarg + m < n && alg == :polar && continue + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) @test V2 * C2 ≈ A @test isisometric(V2) - if kind != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind) - @test N2 === N + if alg != :polar + N2 = @constinferred left_null!(copy!(Ac, A), N; alg) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) hV2 = collect(V2) @@ -131,42 +137,20 @@ include(joinpath("..", "linearmap.jl")) @test hV2 * hV2' + hN2 * hN2' ≈ I end - # with kind and tol kwargs - if kind == :svd - V2, C2 = @constinferred left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; atol = atol) - ) - AMDGPU.@allowscalar begin - N2 = @constinferred left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; atol = atol) - ) - end - @test V2 !== V - @test C2 !== C - @test N2 !== C + # with alg and tol kwargs + if alg == :svd + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A - @test V2' * V2 ≈ I + @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) hV2 = collect(V2) hN2 = collect(N2) @test hV2 * hV2' + hN2 * hN2' ≈ I - V2, C2 = @constinferred left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; rtol = rtol) - ) - AMDGPU.@allowscalar begin - N2 = @constinferred left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; rtol = rtol) - ) - end - @test V2 !== V - @test C2 !== C - @test N2 !== C + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -175,31 +159,17 @@ include(joinpath("..", "linearmap.jl")) hN2 = collect(N2) @test hV2 * hV2' + hN2 * hN2' ≈ I else - @test_throws ArgumentError left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; rtol = rtol) - ) - @test_throws ArgumentError left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; rtol = rtol) - ) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) end end end end -@testset "right_orth and right_null for T = $T" for T in ( - Float32, Float64, ComplexF32, - ComplexF64, - ) +@testset "right_orth and right_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @testset for n in (37, m, 63) @@ -219,15 +189,12 @@ end @test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ ≈ I M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; kind = :svd) + CM, VMᴴ = @constinferred right_orth(M; alg = :svd) @test parent(CM) * parent(VMᴴ) ≈ A Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 === C - @test Vᴴ2 === Vᴴ - @test Nᴴ2 === Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -240,9 +207,6 @@ end rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -253,9 +217,6 @@ end C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -264,16 +225,14 @@ end hNᴴ2 = collect(Nᴴ2) @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - @testset "kind = $kind" for kind in (:lq, :polar, :svd) - n < m && kind == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind) - @test C2 === C - @test Vᴴ2 === Vᴴ + + @testset "alg = $alg" for alg in (:lq, :polar, :svd) + n < m && alg == :polar && continue + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) - if kind != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind) - @test Nᴴ2 === Nᴴ + if alg != :polar + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) hVᴴ2 = collect(Vᴴ2) @@ -281,18 +240,9 @@ end @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I end - if kind == :svd - C2, Vᴴ2 = @constinferred right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; atol = atol) - ) - Nᴴ2 = @constinferred right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; atol = atol) - ) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + if alg == :svd + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -301,42 +251,23 @@ end hNᴴ2 = collect(Nᴴ2) @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - C2, Vᴴ2 = @constinferred right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; rtol = rtol) - ) - Nᴴ2 = @constinferred right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; rtol = rtol) - ) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) hVᴴ2 = collect(Vᴴ2) hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T) + @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I else - @test_throws ArgumentError right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; rtol = rtol) - ) - @test_throws ArgumentError right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; rtol = rtol) - ) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) end end + end end diff --git a/test/cuda/orthnull.jl b/test/cuda/orthnull.jl index cfe7d9e3..3c3ec0c0 100644 --- a/test/cuda/orthnull.jl +++ b/test/cuda/orthnull.jl @@ -10,7 +10,9 @@ using CUDA # testing non-AbstractArray codepaths: include(joinpath("..", "linearmap.jl")) -@testset "left_orth and left_null for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) +eltypes = (Float32, Float64, ComplexF32, ComplexF64) + +@testset "left_orth and left_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @testset for n in (37, m, 63) @@ -30,7 +32,7 @@ include(joinpath("..", "linearmap.jl")) @test hV * hV' + hN * hN' ≈ I M = LinearMap(A) - VM, CM = @constinferred left_orth(M; kind = :svd) + VM, CM = @constinferred left_orth(M; alg = :svd) @test parent(VM) * parent(CM) ≈ A if m > n @@ -48,27 +50,37 @@ include(joinpath("..", "linearmap.jl")) @test isisometric(N) end - for alg_qr in ((; positive = true), (; positive = false), CUSOLVER_HouseholderQR()) - V, C = @constinferred left_orth(A; alg_qr) - N = @constinferred left_null(A; alg_qr) - @test V isa CuMatrix{T} && size(V) == (m, minmn) - @test C isa CuMatrix{T} && size(C) == (minmn, n) - @test N isa CuMatrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - hV = collect(V) - hN = collect(N) - @test hV * hV' + hN * hN' ≈ I - end + # passing a kind and some kwargs + V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + @test V isa CuMatrix{T} && size(V) == (m, minmn) + @test C isa CuMatrix{T} && size(C) == (minmn, n) + @test N isa CuMatrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + hV = collect(V) + hN = collect(N) + @test hV * hV' + hN * hN' ≈ I + + # passing an algorithm + V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) + N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + @test V isa CuMatrix{T} && size(V) == (m, minmn) + @test C isa CuMatrix{T} && size(C) == (minmn, n) + @test N isa CuMatrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + hV = collect(V) + hN = collect(N) + @test hV * hV' + hN * hN' ≈ I Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 === V - @test C2 === C - @test N2 === N @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -80,9 +92,6 @@ include(joinpath("..", "linearmap.jl")) atol = eps(real(T)) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -98,9 +107,6 @@ include(joinpath("..", "linearmap.jl")) ) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -110,16 +116,13 @@ include(joinpath("..", "linearmap.jl")) @test hV2 * hV2' + hN2 * hN2' ≈ I end - @testset for kind in (:qr, :polar, :svd) # explicit kind kwarg - m < n && kind == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind) - @test V2 === V - @test C2 === C + @testset for alg in (:qr, :polar, :svd) # explicit alg kwarg + m < n && alg == :polar && continue + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg))) @test V2 * C2 ≈ A @test isisometric(V2) - if kind != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind) - @test N2 === N + if alg != :polar + N2 = @constinferred left_null!(copy!(Ac, A), N; alg) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) hV2 = collect(V2) @@ -127,38 +130,20 @@ include(joinpath("..", "linearmap.jl")) @test hV2 * hV2' + hN2 * hN2' ≈ I end - # with kind and tol kwargs - if kind == :svd - V2, C2 = @constinferred left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; atol = atol) - ) - N2 = @constinferred left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; atol = atol) - ) - @test V2 !== V - @test C2 !== C - @test N2 !== C + # with alg and tol kwargs + if alg == :svd + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A - @test V2' * V2 ≈ I + @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) hV2 = collect(V2) hN2 = collect(N2) @test hV2 * hV2' + hN2 * hN2' ≈ I - V2, C2 = @constinferred left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; rtol = rtol) - ) - N2 = @constinferred left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; rtol = rtol) - ) - @test V2 !== V - @test C2 !== C - @test N2 !== C + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -167,31 +152,17 @@ include(joinpath("..", "linearmap.jl")) hN2 = collect(N2) @test hV2 * hV2' + hN2 * hN2' ≈ I else - @test_throws ArgumentError left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError left_orth!( - copy!(Ac, A), (V, C); kind = kind, - trunc = (; rtol = rtol) - ) - @test_throws ArgumentError left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError left_null!( - copy!(Ac, A), N; kind = kind, - trunc = (; rtol = rtol) - ) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) end end end end -@testset "right_orth and right_null for T = $T" for T in ( - Float32, Float64, ComplexF32, - ComplexF64, - ) +@testset "right_orth and right_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @testset for n in (37, m, 63) @@ -211,15 +182,12 @@ end @test hVᴴ' * hVᴴ + hNᴴ' * hNᴴ ≈ I M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; kind = :svd) + CM, VMᴴ = @constinferred right_orth(M; alg = :svd) @test parent(CM) * parent(VMᴴ) ≈ A Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 === C - @test Vᴴ2 === Vᴴ - @test Nᴴ2 === Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -232,9 +200,6 @@ end rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -245,9 +210,6 @@ end C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -256,16 +218,13 @@ end hNᴴ2 = collect(Nᴴ2) @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - @testset "kind = $kind" for kind in (:lq, :polar, :svd) - n < m && kind == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind = kind) - @test C2 === C - @test Vᴴ2 === Vᴴ + @testset "alg = $alg" for alg in (:lq, :polar, :svd) + n < m && alg == :polar && continue + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg))) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) - if kind != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind = kind) - @test Nᴴ2 === Nᴴ + if alg != :polar + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg))) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) hVᴴ2 = collect(Vᴴ2) @@ -273,18 +232,9 @@ end @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I end - if kind == :svd - C2, Vᴴ2 = @constinferred right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; atol = atol) - ) - Nᴴ2 = @constinferred right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; atol = atol) - ) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + if alg == :svd + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -293,41 +243,21 @@ end hNᴴ2 = collect(Nᴴ2) @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I - C2, Vᴴ2 = @constinferred right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; rtol = rtol) - ) - Nᴴ2 = @constinferred right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; rtol = rtol) - ) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) hVᴴ2 = collect(Vᴴ2) hNᴴ2 = collect(Nᴴ2) - @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ diagm(ones(T, size(Vᴴ2, 2))) atol = m * n * MatrixAlgebraKit.defaulttol(T) + @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I else - @test_throws ArgumentError right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError right_orth!( - copy!(Ac, A), (C, Vᴴ); kind = kind, - trunc = (; rtol = rtol) - ) - @test_throws ArgumentError right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; atol = atol) - ) - @test_throws ArgumentError right_null!( - copy!(Ac, A), Nᴴ; kind = kind, - trunc = (; rtol = rtol) - ) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + alg == :polar && continue + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) end end end diff --git a/test/orthnull.jl b/test/orthnull.jl index 6fc3fd27..1bd779c0 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -115,9 +115,9 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A - @test V2' * V2 ≈ I + @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test N2' * N2 ≈ I + @test isisometric(N2) @test V2 * V2' + N2 * N2' ≈ I V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) From 14bc3c7274f94b20372e8698528419b6a1605e28 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Oct 2025 17:58:14 -0400 Subject: [PATCH 32/47] address some review comments --- docs/src/user_interface/decompositions.md | 2 +- src/algorithms.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index cca993fb..21dc38d2 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -179,7 +179,7 @@ left_orth right_orth ``` -In order to dispatch to the underlying factorizations, the following wrapper functions are used: +In order to dispatch to the underlying factorizations, the following wrapper algorithms are used: ```@docs; canonical=false LeftOrthAlgorithm diff --git a/src/algorithms.jl b/src/algorithms.jl index 40065b5f..388e1eac 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -263,7 +263,6 @@ end function _arg_expr(::Val{1}, f, f!) return quote # out of place to inplace @inline $f(A; alg = nothing, kwargs...) = $f(A, select_algorithm($f, A, alg; kwargs...)) - # $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) # fill in arguments From d45cfa95d8f0a991607b9f2f2eef7ac694f49cdc Mon Sep 17 00:00:00 2001 From: lkdvos Date: Thu, 23 Oct 2025 18:08:58 -0400 Subject: [PATCH 33/47] some AMD fixes --- .../MatrixAlgebraKitAMDGPUExt.jl | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 48bf56fd..5b6eea5c 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -170,5 +170,26 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tu Utrunc = U[:, trunc_cols] return Utrunc, ind end +function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TS, TVᴴ <: ROCArray} + # TODO: avoid allocation? + S, Vᴴ = SVᴴ + extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) + ind = MatrixAlgebraKit.findtruncated(extended_S, strategy) + trunc_rows = collect(1:size(Vᴴ, 1))[ind] + Vᴴtrunc = Vᴴ[trunc_rows, :] + return Vᴴtrunc, ind +end + +# disambiguate: +function MatrixAlgebraKit.truncate(::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation) where {TU <: ROCArray, TS} + m, n = size(S) + ind = (n + 1):m + return U[:, ind], ind +end +function MatrixAlgebraKit.truncate(::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation) where {TS, TVᴴ <: ROCArray} + m, n = size(S) + ind = (m + 1):n + return Vᴴ[ind, :], ind +end end From 2b62dfdb10a3caed85e3edc3b03680886be842f0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 23 Oct 2025 21:12:10 -0400 Subject: [PATCH 34/47] some more AMD fixes --- .../MatrixAlgebraKitAMDGPUExt.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 5b6eea5c..04bee40e 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -4,7 +4,7 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm +using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx! @@ -161,7 +161,9 @@ function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix) return A, B end -function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tuple{TU, TS}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TU <: ROCArray, TS} +function MatrixAlgebraKit.truncate( + ::typeof(left_null!), US::Tuple{TU, TS}, strategy::TruncationStrategy + ) where {TU <: ROCMatrix, TS} # TODO: avoid allocation? U, S = US extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 1) - size(S, 2)))) @@ -170,7 +172,9 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.left_null!), US::Tu Utrunc = U[:, trunc_cols] return Utrunc, ind end -function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::MatrixAlgebraKit.TruncationStrategy) where {TS, TVᴴ <: ROCArray} +function MatrixAlgebraKit.truncate( + ::typeof(right_null!), SVᴴ::Tuple{TS, TVᴴ}, strategy::TruncationStrategy + ) where {TS, TVᴴ <: ROCMatrix} # TODO: avoid allocation? S, Vᴴ = SVᴴ extended_S = vcat(diagview(S), zeros(eltype(S), max(0, size(S, 2) - size(S, 1)))) @@ -181,12 +185,16 @@ function MatrixAlgebraKit.truncate(::typeof(MatrixAlgebraKit.right_null!), SVᴴ end # disambiguate: -function MatrixAlgebraKit.truncate(::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation) where {TU <: ROCArray, TS} +function MatrixAlgebraKit.truncate( + ::typeof(left_null!), (U, S)::Tuple{TU, TS}, ::NoTruncation + ) where {TU <: ROCMatrix, TS} m, n = size(S) ind = (n + 1):m return U[:, ind], ind end -function MatrixAlgebraKit.truncate(::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation) where {TS, TVᴴ <: ROCArray} +function MatrixAlgebraKit.truncate( + ::typeof(right_null!), (S, Vᴴ)::Tuple{TS, TVᴴ}, ::NoTruncation + ) where {TS, TVᴴ <: ROCMatrix} m, n = size(S) ind = (m + 1):n return Vᴴ[ind, :], ind From adc36bb8951dea82e95ebc78ab6ac2a9f6641cd4 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 24 Oct 2025 15:20:23 -0400 Subject: [PATCH 35/47] more more AMD fixes --- test/amd/orthnull.jl | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl index 756e8064..b86bfdf6 100644 --- a/test/amd/orthnull.jl +++ b/test/amd/orthnull.jl @@ -206,7 +206,9 @@ end atol = eps(real(T)) rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol = atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) + AMDGPU.@allowscalar begin + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol = atol)) + end @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -216,7 +218,9 @@ end @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol = rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) + AMDGPU.@allowscalar begin + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol = rtol)) + end @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -242,7 +246,9 @@ end if alg == :svd C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) + AMDGPU.@allowscalar begin + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; atol)) + end @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -252,7 +258,9 @@ end @test hVᴴ2' * hVᴴ2 + hNᴴ2' * hNᴴ2 ≈ I C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg = $(QuoteNode(alg)), trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) + AMDGPU.@allowscalar begin + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg = $(QuoteNode(alg)), trunc = (; rtol)) + end @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) From 9c916af5c00bebd574b1ff2a115d86a00bb33bfd Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 12:57:53 -0400 Subject: [PATCH 36/47] move lqviatransposedqr --- src/interface/decompositions.jl | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index dfaa0654..3a81f086 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -140,7 +140,7 @@ until convergence up to tolerance `tol`. @algdef PolarNewton # ========================= -# DIAGONAL ALGORITHMS +# Varia # ========================= """ DiagonalAlgorithm(; kwargs...) @@ -150,6 +150,21 @@ the diagonal structure of the input and outputs. """ @algdef DiagonalAlgorithm +""" + LQViaTransposedQR(qr_alg) + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm + qr_alg::A +end +function Base.show(io::IO, alg::LQViaTransposedQR) + print(io, "LQViaTransposedQR(") + _show_alg(io, alg.qr_alg) + return print(io, ")") +end + # ========================= # CUSOLVER ALGORITHMS # ========================= @@ -277,22 +292,6 @@ Divide and Conquer algorithm. const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} -# Alternative algorithm (necessary for CUDA) -""" - LQViaTransposedQR(qr_alg) - -Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. -The `qr_alg` specifies which QR-decomposition implementation to use. -""" -struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm - qr_alg::A -end -function Base.show(io::IO, alg::LQViaTransposedQR) - print(io, "LQViaTransposedQR(") - _show_alg(io, alg.qr_alg) - return print(io, ")") -end - # Various consts and unions # ------------------------- From 185d939c909882c8720cb78397e0bc2a99998e39 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 13:03:07 -0400 Subject: [PATCH 37/47] no randomized svd for null --- src/implementations/orthnull.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 92e4cd1b..27c98712 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -127,6 +127,6 @@ end # randomized algorithms don't currently work for smallest values: left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) From f89a318ad43be6ba0154a007b8b4fe5e12630938 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 13:03:31 -0400 Subject: [PATCH 38/47] no initialization for orthnull with SVD --- src/implementations/orthnull.jl | 36 +++++++++------------------------ 1 file changed, 10 insertions(+), 26 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 27c98712..651855a3 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -12,8 +12,7 @@ check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaQR) = check_input(qr_compact!, A, VC, alg.alg) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaPolar) = check_input(left_polar!, A, VC, alg.alg) -check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = - check_input(qr_compact!, A, VC, alg.alg) +check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = nothing check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(right_orth!, A, CVᴴ, right_orth_alg(alg)) @@ -22,8 +21,7 @@ check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaLQ) = check_input(lq_compact!, A, VC, alg.alg) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaPolar) = check_input(right_polar!, A, VC, alg.alg) -check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = - check_input(lq_compact!, A, VC, alg.alg) +check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaSVD) = nothing check_input(::typeof(left_null!), A, N, alg::AbstractAlgorithm) = check_input(left_null!, A, N, left_null_alg(alg)) @@ -46,8 +44,7 @@ initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaQR) = initialize_output(qr_compact!, A, alg.alg) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaPolar) = initialize_output(left_polar!, A, alg.alg) -initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = - initialize_output(qr_compact!, A, alg.alg) +initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = nothing initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = initialize_output(right_orth!, A, right_orth_alg(alg)) @@ -56,8 +53,7 @@ initialize_output(::typeof(right_orth!), A, alg::RightOrthViaLQ) = initialize_output(lq_compact!, A, alg.alg) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaPolar) = initialize_output(right_polar!, A, alg.alg) -initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = - initialize_output(lq_compact!, A, alg.alg) +initialize_output(::typeof(right_orth!), A, alg::RightOrthViaSVD) = nothing initialize_output(::typeof(left_null!), A, alg::AbstractAlgorithm) = initialize_output(left_null!, A, left_null_alg(alg)) @@ -71,35 +67,24 @@ initialize_output(::typeof(right_null!), A, alg::RightNullViaLQ) = initialize_output(lq_null!, A, alg.alg) initialize_output(::typeof(right_null!), A, alg::RightNullViaSVD) = nothing -function initialize_orth_svd(A::AbstractMatrix, F, alg) - S = Diagonal(initialize_output(svd_vals!, A, alg)) - return F[1], S, F[2] -end -# fallback doesn't re-use F at all -initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) - # Implementation of orth functions # -------------------------------- left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth!(A, VC, left_orth_alg(alg)) left_orth!(A, VC, alg::LeftOrthViaQR) = qr_compact!(A, VC, alg.alg) left_orth!(A, VC, alg::LeftOrthViaPolar) = left_polar!(A, VC, alg.alg) -# orth_svd requires implementations of `lmul!` and `rmul!` function left_orth!(A, VC, alg::LeftOrthViaSVD) check_input(left_orth!, A, VC, alg) - USVᴴ = initialize_orth_svd(A, VC, alg.alg) - V, S, C = does_truncate(alg.alg) ? svd_trunc!(A, USVᴴ, alg.alg) : svd_compact!(A, USVᴴ, alg.alg) + V, S, C = does_truncate(alg.alg) ? svd_trunc!(A, alg.alg) : svd_compact!(A, alg.alg) lmul!(S, C) return V, C end -# orth_svd requires implementations of `lmul!` and `rmul!` + right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth!(A, CVᴴ, right_orth_alg(alg)) right_orth!(A, CVᴴ, alg::RightOrthViaLQ) = lq_compact!(A, CVᴴ, alg.alg) right_orth!(A, CVᴴ, alg::RightOrthViaPolar) = right_polar!(A, CVᴴ, alg.alg) - function right_orth!(A, CVᴴ, alg::RightOrthViaSVD) check_input(right_orth!, A, CVᴴ, alg) - USVᴴ = initialize_orth_svd(A, CVᴴ, alg.alg) - C, S, Vᴴ = does_truncate(alg.alg) ? svd_trunc!(A, USVᴴ, alg.alg) : svd_compact!(A, USVᴴ, alg.alg) + C, S, Vᴴ = does_truncate(alg.alg) ? svd_trunc!(A, alg.alg) : svd_compact!(A, alg.alg) rmul!(C, S) return C, Vᴴ end @@ -108,16 +93,15 @@ end # -------------------------------- left_null!(A, N, alg::AbstractAlgorithm) = left_null!(A, N, left_null_alg(alg)) left_null!(A, N, alg::LeftNullViaQR) = qr_null!(A, N, alg.alg) - -right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null!(A, Nᴴ, right_null_alg(alg)) -right_null!(A, Nᴴ, alg::RightNullViaLQ) = lq_null!(A, Nᴴ, alg.alg) - function left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm}) check_input(left_null!, A, N, alg) U, S, _ = svd_full!(A, alg.alg.alg) N, _ = truncate(left_null!, (U, S), alg.alg.trunc) return N end + +right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null!(A, Nᴴ, right_null_alg(alg)) +right_null!(A, Nᴴ, alg::RightNullViaLQ) = lq_null!(A, Nᴴ, alg.alg) function right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm}) check_input(right_null!, A, Nᴴ, alg) _, S, Vᴴ = svd_full!(A, alg.alg.alg) From 878b428ec8e9a1d891ad889fd06e750cf79a2e2c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 13:03:35 -0400 Subject: [PATCH 39/47] fix docstring --- src/algorithms.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 388e1eac..be88e64a 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -59,7 +59,8 @@ end """ does_truncate(alg::AbstractAlgorithm) -> Bool -Check whether or not an algorithm can be used for a truncated decomposition. +Indicate whether or not an algorithm will compute a truncated decomposition +(such that composing the factors only approximates the input up to some tolerance). """ does_truncate(alg::AbstractAlgorithm) = false From 1744416bdd739b8f359cda94a51af1bfc9264e11 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 13:07:10 -0400 Subject: [PATCH 40/47] more more more AMD fixes --- test/amd/orthnull.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl index b86bfdf6..034f19ff 100644 --- a/test/amd/orthnull.jl +++ b/test/amd/orthnull.jl @@ -140,7 +140,9 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # with alg and tol kwargs if alg == :svd V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + AMDGPU.@allowscalar begin + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + end @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -150,7 +152,9 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test hV2 * hV2' + hN2 * hN2' ≈ I V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg = $(QuoteNode(alg)), trunc = (; rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) + AMDGPU.@allowscalar begin + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) + end @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) From 54dda052107ab133ca3852d1449cbed0d2c5f87b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 25 Oct 2025 13:10:19 -0400 Subject: [PATCH 41/47] Revert "no randomized svd for null" This reverts commit 185d939c909882c8720cb78397e0bc2a99998e39. --- src/implementations/orthnull.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 651855a3..413e4b5c 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -111,6 +111,6 @@ end # randomized algorithms don't currently work for smallest values: left_null!(A, N, alg::LeftNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) right_null!(A, Nᴴ, alg::RightNullViaSVD{<:TruncatedAlgorithm{<:GPU_Randomized}}) = - throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces yet")) From 15ff6aee33197cb3cc4b4f3a99f44d0a8d971d49 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Oct 2025 11:09:46 -0400 Subject: [PATCH 42/47] update syntax for leftorth --- src/interface/orthnull.jl | 189 +++++++++++++++++++------------------- test/amd/orthnull.jl | 6 +- test/cuda/orthnull.jl | 6 +- test/orthnull.jl | 6 +- 4 files changed, 105 insertions(+), 102 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 3267a68c..6aa33fd8 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -1,8 +1,8 @@ # Orth functions # -------------- """ - left_orth(A; [trunc], kwargs...) -> V, C - left_orth!(A, [VC]; [trunc], kwargs...) -> V, C + left_orth(A; [alg], [trunc], kwargs...) -> V, C + left_orth!(A, [VC], [alg]; [trunc], kwargs...) -> V, C Compute an orthonormal basis `V` for the image of the matrix `A`, as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. @@ -34,31 +34,32 @@ application purposes. ### `alg::Nothing` This default mode uses the presence of a truncation strategy `trunc` to determine an optimal -decomposition type, which will be QR-based for no truncation, or SVD-based for truncation. -The remaining keyword arguments are passed on directly to the algorithm selection procedure -of the chosen decomposition type. +decomposition type, which will typically be QR-based for no truncation, or SVD-based for +truncation. The remaining keyword arguments are passed on directly to the algorithm selection +procedure of the chosen decomposition type. ### `alg::Symbol` -Here, the driving selector is `alg`, and depending on its value, the algorithm selection -procedure takes other keywords into account: +Here, the driving selector is `alg`, which is used to select the kind of decomposition. The +remaining keyword arguments are passed on directly to the algorithm selection procedure of +the chosen decomposition type. By default, the supported kinds are: -* `:qr` : Factorize via QR decomposition, with further customizations through the - `qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +* `:qr` : Factorize via QR decomposition, with further customizations through the other + keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - V, C = qr_compact(A; alg = qr) + V, C = qr_compact(A; kwargs...) ``` * `:polar` : Factorize via polar decomposition, with further customizations through the - `polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + other keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - V, C = left_polar(A; alg = polar) + V, C = left_polar(A; kwargs...) ``` -* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. - This mode further allows truncation, which can be selected through the `trunc` argument. - This mode is roughly equivalent to: +* `:svd` : Factorize via SVD, with further customizations through the other keywords. + This mode further allows truncation, which can be selected through the `trunc` argument, + and is roughly equivalent to: ```julia - V, S, C = svd_trunc(A; trunc, alg = svd) + V, S, C = svd_trunc(A; trunc, kwargs...) C = S * C ``` @@ -66,7 +67,7 @@ procedure takes other keywords into account: In this expert mode the algorithm is supplied directly, and the kind of decomposition is deduced from that. This is achieved either directly by providing a [`LeftOrthAlgorithm{kind}`](@ref LeftOrthAlgorithm), or automatically by attempting to -deduce the decomposition kind with `LeftOrthAlgorithm(alg)`. +deduce the decomposition kind with [`left_orth_alg(alg)`](@ref left_orth_alg). --- @@ -75,21 +76,22 @@ deduce the decomposition kind with `LeftOrthAlgorithm(alg)`. destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null) and +[`right_null(!)`](@ref right_null). """ @functiondef left_orth """ - right_orth(A; [trunc], kwargs...) -> C, Vᴴ - right_orth!(A, [CVᴴ]; [trunc], kwargs...) -> C, Vᴴ + right_orth(A; [alg], [trunc], kwargs...) -> C, Vᴴ + right_orth!(A, [CVᴴ], [alg]; [trunc], kwargs...) -> C, Vᴴ Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for the image of `adjoint(A)`, as well as a matrix `C` such that `A` factors as `A = C * Vᴴ`. This is a high-level wrapper where the keyword arguments can be used to specify and control -the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` can -optionally be used to control the precision in determining the rank of `A`, typically via -its singular values. +the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` +can optionally be used to control the precision in determining the rank of `A`, typically +via its singular values. ## Truncation The optional truncation strategy can be controlled via the `trunc` keyword argument, and @@ -113,31 +115,32 @@ application purposes. ### `alg::Nothing` This default mode uses the presence of a truncation strategy `trunc` to determine an optimal -decomposition type, which will be LQ-based for no truncation, or SVD-based for truncation. -The remaining keyword arguments are passed on directly to the algorithm selection procedure -of the chosen decomposition type. +decomposition type, which will typicall be LQ-based for no truncation, or SVD-based for +truncation. The remaining keyword arguments are passed on directly to the algorithm selection +procedure of the chosen decomposition type. ### `alg::Symbol` -Here, the driving selector is `alg`, and depending on its value, the algorithm selection -procedure takes other keywords into account: +Here, the driving selector is `alg`, which is used to select the kind of decomposition. The +remaining keyword arguments are passed on directly to the algorithm selection procedure of +the chosen decomposition type. By default, the supported kinds are: -* `:lq` : Factorize via LQ decomposition, with further customizations through the - `lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +* `:lq` : Factorize via LQ decomposition, with further customizations through the other + keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - C, Vᴴ = lq_compact(A; alg = lq) + C, Vᴴ = lq_compact(A; kwargs...) ``` * `:polar` : Factorize via polar decomposition, with further customizations through the - `polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to + other keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - C, Vᴴ = right_polar(A; alg = polar) + C, Vᴴ = right_polar(A; kwargs...) ``` -* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. - This mode further allows truncation, which can be selected through the `trunc` argument. - This mode is roughly equivalent to: +* `:svd` : Factorize via SVD, with further customizations through the other keywords. + This mode further allows truncation, which can be selected through the `trunc` argument, + and is roughly equivalent to: ```julia - C, S, Vᴴ = svd_trunc(A; trunc, alg = svd) + C, S, Vᴴ = svd_trunc(A; trunc, kwargs...) C = C * S ``` @@ -145,7 +148,7 @@ procedure takes other keywords into account: In this expert mode the algorithm is supplied directly, and the kind of decomposition is deduced from that. This is achieved either directly by providing a [`RightOrthAlgorithm{kind}`](@ref RightOrthAlgorithm), or automatically by attempting to -deduce the decomposition kind with `RightOrthAlgorithm(alg)`. +deduce the decomposition kind with [`right_orth_alg`](@ref). --- @@ -154,16 +157,16 @@ deduce the decomposition kind with `RightOrthAlgorithm(alg)`. destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `CVᴴ` as output. -See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), -[`right_null(!)`](@ref right_null) +See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null) and +[`right_null(!)`](@ref right_null). """ @functiondef right_orth # Null functions # -------------- """ - left_null(A; [trunc], kwargs...) -> N - left_null!(A, [N]; [trunc], kwargs...) -> N + left_null(A; [alg], [trunc], kwargs...) -> N + left_null!(A, [N], [alg]; [trunc], kwargs...) -> N Compute an orthonormal basis `N` for the cokernel of the matrix `A`, i.e. the nullspace of `adjoint(A)`, such that `adjoint(A) * N ≈ 0` and `N' * N ≈ I`. @@ -204,20 +207,21 @@ The remaining keyword arguments are passed on directly to the algorithm selectio of the chosen decomposition type. ### `alg::Symbol` -Here, the driving selector is `alg`, and depending on its value, the algorithm selection -procedure takes other keywords into account: +Here, the driving selector is `alg`, which is used to select the kind of decomposition. The +remaining keyword arguments are passed on directly to the algorithm selection procedure of +the chosen decomposition type. By default, the supported kinds are: -* `:qr` : Factorize via QR nullspace, with further customizations through the `qr` - keyword. This mode requires `isnothing(trunc)`, and is equivalent to +* `:qr` : Factorize via QR nullspace, with further customizations through the other + keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - N = qr_null(A; alg = qr) + N = qr_null(A; kwargs...) ``` -* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the other keywords. This mode further allows truncation, which can be selected through the `trunc` argument. It is roughly equivalent to: ```julia - U, S, _ = svd_trunc(A; trunc, alg = svd) + U, S, _ = svd_full(A; kwargs...) N = truncate(left_null, (U, S), trunc) ``` @@ -225,7 +229,7 @@ procedure takes other keywords into account: In this expert mode the algorithm is supplied directly, and the kind of decomposition is deduced from that. This is achieved either directly by providing a [`LeftNullAlgorithm{kind}`](@ref LeftNullAlgorithm), or automatically by attempting to -deduce the decomposition kind with `LeftNullAlgorithm(alg)`. +deduce the decomposition kind with [`left_null_alg(alg)`](@ref left_null_alg). --- @@ -234,14 +238,14 @@ deduce the decomposition kind with `LeftNullAlgorithm(alg)`. destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `N` as output. -See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), -[`right_orth(!)`](@ref right_orth) +See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth) and +[`right_orth(!)`](@ref right_orth). """ @functiondef left_null """ - right_null(A; [trunc], kwargs...) -> Nᴴ - right_null!(A, [Nᴴ]; [trunc], kwargs...) -> Nᴴ + right_null(A; [alg], [trunc], kwargs...) -> Nᴴ + right_null!(A, [Nᴴ], [alg]; [trunc], kwargs...) -> Nᴴ Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel of the matrix `A`, i.e. the nullspace of `A`, such that `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. @@ -282,20 +286,21 @@ The remaining keyword arguments are passed on directly to the algorithm selectio of the chosen decomposition type. ### `alg::Symbol` -Here, the driving selector is `alg`, and depending on its value, the algorithm selection -procedure takes other keywords into account: +Here, the driving selector is `alg`, which is used to select the kind of decomposition. The +remaining keyword arguments are passed on directly to the algorithm selection procedure of +the chosen decomposition type. By default, the supported kinds are: -* `:lq` : Factorize via LQ nullspace, with further customizations through the `lq` - keyword. This mode requires `isnothing(trunc)`, and is equivalent to +* `:lq` : Factorize via LQ nullspace, with further customizations through the other + keywords. This mode requires `isnothing(trunc)`, and is equivalent to ```julia - Nᴴ = lq_null(A; alg = lq) + Nᴴ = lq_null(A; kwargs...) ``` -* `:svd` : Factorize via SVD, with further customizations through the `svd` keyword. +* `:svd` : Factorize via SVD, with further customizations through the other keywords. This mode further allows truncation, which can be selected through the `trunc` argument. It is roughly equivalent to: ```julia - _, S, Vᴴ = svd_trunc(A; trunc, alg = svd) + _, S, Vᴴ = svd_full(A; kwargs...) Nᴴ = truncate(right_null, (S, Vᴴ), trunc) ``` @@ -303,7 +308,7 @@ procedure takes other keywords into account: In this expert mode the algorithm is supplied directly, and the kind of decomposition is deduced from that. This is achieved either directly by providing a [`RightNullAlgorithm{kind}`](@ref RightNullAlgorithm), or automatically by attempting to -deduce the decomposition kind with `RightNullAlgorithm(alg)`. +deduce the decomposition kind with [`right_null_alg(alg)`](@ref right_null_alg). --- @@ -312,8 +317,8 @@ deduce the decomposition kind with `RightNullAlgorithm(alg)`. destroys the input matrix `A`. Always use the return value of the function as it may not always be possible to use the provided `Nᴴ` as output. -See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), -[`right_orth(!)`](@ref right_orth) +See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth) and +[`right_orth(!)`](@ref right_orth). """ @functiondef right_null @@ -332,45 +337,43 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_orth` is incompatible with specifying `trunc`")) - alg′ = select_algorithm(qr_compact!, A, get(kwargs, :qr, nothing)) + alg′ = select_algorithm(qr_compact!, A; kwargs...) return LeftOrthViaQR(alg′) end function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `left_orth` is incompatible with specifying `trunc`")) - alg′ = select_algorithm(left_polar!, A, get(kwargs, :polar, nothing)) + alg′ = select_algorithm(left_polar!, A; kwargs...) return LeftOrthViaPolar(alg′) end function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) - alg = get(kwargs, :svd, nothing) - alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg) : - select_algorithm(svd_trunc!, A, alg; trunc) + alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) : + select_algorithm(svd_trunc!, A; trunc, kwargs...) return LeftOrthViaSVD(alg′) end function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(lq_compact!, A, get(kwargs, :lq, nothing)) + alg = select_algorithm(lq_compact!, A; kwargs...) return RightOrthViaLQ(alg) end function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("Polar-based `right_orth` is incompatible with specifying `trunc`")) - alg = select_algorithm(right_polar!, A, get(kwargs, :polar, nothing)) + alg = select_algorithm(right_polar!, A; kwargs...) return RightOrthViaPolar(alg) end function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...) - alg = get(kwargs, :svd, nothing) - alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A, alg) : - select_algorithm(svd_trunc!, A, alg; trunc) + alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) : + select_algorithm(svd_trunc!, A; trunc, kwargs...) return RightOrthViaSVD(alg′) end function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("QR-based `left_null` is incompatible with specifying `trunc`")) - alg = select_algorithm(qr_null!, A, get(kwargs, :qr, nothing)) + alg = select_algorithm(qr_null!, A; kwargs...) return LeftNullViaQR(alg) end function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) @@ -382,46 +385,46 @@ end function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...) isnothing(trunc) || throw(ArgumentError("LQ-based `right_null` is incompatible with specifying `trunc`")) - alg = select_algorithm(lq_null!, A, get(kwargs, :lq, nothing)) + alg = select_algorithm(lq_null!, A; kwargs...) return RightNullViaLQ(alg) end function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...) - alg_svd = select_algorithm(svd_full!, A, get(kwargs, :svd, nothing)) + alg_svd = select_algorithm(svd_full!, A; kwargs...) alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) return RightNullViaSVD(alg) end default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); qr = kwargs) : - select_algorithm(left_orth!, A, Val(:svd); svd = kwargs) + isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) : + select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...) # disambiguate default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); qr = kwargs) : - select_algorithm(left_orth!, A, Val(:svd); svd = kwargs) + isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) : + select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...) default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); lq = kwargs) : - select_algorithm(right_orth!, A, Val(:svd); svd = kwargs) + isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) : + select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...) # disambiguate default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); lq = kwargs) : - select_algorithm(right_orth!, A, Val(:svd); svd = kwargs) + isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) : + select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...) default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); qr = kwargs) : - select_algorithm(left_null!, A, Val(:svd); svd = kwargs, trunc) + isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) : + select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...) # disambiguate default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); qr = kwargs) : - select_algorithm(left_null!, A, Val(:svd); svd = kwargs, trunc) + isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) : + select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...) default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); lq = kwargs) : - select_algorithm(right_null!, A, Val(:svd); svd = kwargs, trunc) + isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : + select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) # disambiguate default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = - isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); lq = kwargs) : - select_algorithm(right_null!, A, Val(:svd); svd = kwargs, trunc) + isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : + select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) left_orth_alg(alg::AbstractAlgorithm) = LeftOrthAlgorithm(alg) right_orth_alg(alg::AbstractAlgorithm) = RightOrthAlgorithm(alg) diff --git a/test/amd/orthnull.jl b/test/amd/orthnull.jl index 034f19ff..6ed44228 100644 --- a/test/amd/orthnull.jl +++ b/test/amd/orthnull.jl @@ -51,8 +51,8 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + V, C = @constinferred left_orth(A; alg = :qr, positive = true) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa ROCMatrix{T} && size(V) == (m, minmn) @test C isa ROCMatrix{T} && size(C) == (minmn, n) @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) @@ -66,7 +66,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # passing an algorithm V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa ROCMatrix{T} && size(V) == (m, minmn) @test C isa ROCMatrix{T} && size(C) == (minmn, n) @test N isa ROCMatrix{T} && size(N) == (m, m - minmn) diff --git a/test/cuda/orthnull.jl b/test/cuda/orthnull.jl index 3c3ec0c0..2a2a26f6 100644 --- a/test/cuda/orthnull.jl +++ b/test/cuda/orthnull.jl @@ -51,8 +51,8 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + V, C = @constinferred left_orth(A; alg = :qr, positive = true) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa CuMatrix{T} && size(V) == (m, minmn) @test C isa CuMatrix{T} && size(C) == (minmn, n) @test N isa CuMatrix{T} && size(N) == (m, m - minmn) @@ -66,7 +66,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # passing an algorithm V, C = @constinferred left_orth(A; alg = CUSOLVER_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa CuMatrix{T} && size(V) == (m, minmn) @test C isa CuMatrix{T} && size(C) == (minmn, n) @test N isa CuMatrix{T} && size(N) == (m, m - minmn) diff --git a/test/orthnull.jl b/test/orthnull.jl index 1bd779c0..ce742e8f 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -43,8 +43,8 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) end # passing a kind and some kwargs - V, C = @constinferred left_orth(A; alg = :qr, qr = (; positive = true)) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + V, C = @constinferred left_orth(A; alg = :qr, positive = true) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa Matrix{T} && size(V) == (m, minmn) @test C isa Matrix{T} && size(C) == (minmn, n) @test N isa Matrix{T} && size(N) == (m, m - minmn) @@ -56,7 +56,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) # passing an algorithm V, C = @constinferred left_orth(A; alg = LAPACK_HouseholderQR()) - N = @constinferred left_null(A; alg = :qr, qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, positive = true) @test V isa Matrix{T} && size(V) == (m, minmn) @test C isa Matrix{T} && size(C) == (minmn, n) @test N isa Matrix{T} && size(N) == (m, m - minmn) From aff5766fd534b396e1622dce61026fe16b72bef1 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Oct 2025 11:24:48 -0400 Subject: [PATCH 43/47] update algselector docstrings --- src/interface/orthnull.jl | 71 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 6aa33fd8..fb04df33 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -426,7 +426,78 @@ default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) +""" + left_orth_alg(alg::AbstractAlgorithm) -> LeftOrthAlgorithm + +Convert an algorithm to a [`LeftOrthAlgorithm`](@ref) wrapper for use with [`left_orth`](@ref). + +This function attempts to deduce the appropriate factorization kind (`:qr`, `:polar`, or `:svd`) +from the algorithm type and wraps it in a `LeftOrthAlgorithm`. Custom algorithm types can be +registered by defining: + +```julia +MatrixAlgebraKit.LeftOrthAlgorithm(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) +``` + +where `kind` specifies the factorization backend to use. + +See also [`LeftOrthAlgorithm`](@ref), [`left_orth`](@ref). +""" left_orth_alg(alg::AbstractAlgorithm) = LeftOrthAlgorithm(alg) + +""" + right_orth_alg(alg::AbstractAlgorithm) -> RightOrthAlgorithm + +Convert an algorithm to a [`RightOrthAlgorithm`](@ref) wrapper for use with [`right_orth`](@ref). + +This function attempts to deduce the appropriate factorization kind (`:lq`, `:polar`, or `:svd`) +from the algorithm type and wraps it in a `RightOrthAlgorithm`. Custom algorithm types can be +registered by defining: + +```julia +MatrixAlgebraKit.RightOrthAlgorithm(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) +``` + +where `kind` specifies the factorization backend to use. + +See also [`RightOrthAlgorithm`](@ref), [`right_orth`](@ref). +""" right_orth_alg(alg::AbstractAlgorithm) = RightOrthAlgorithm(alg) + +""" + left_null_alg(alg::AbstractAlgorithm) -> LeftNullAlgorithm + +Convert an algorithm to a [`LeftNullAlgorithm`](@ref) wrapper for use with [`left_null`](@ref). + +This function attempts to deduce the appropriate factorization kind (`:qr` or `:svd`) from +the algorithm type and wraps it in a `LeftNullAlgorithm`. Custom algorithm types can be +registered by defining: + +```julia +MatrixAlgebraKit.LeftNullAlgorithm(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) +``` + +where `kind` specifies the factorization backend to use. + +See also [`LeftNullAlgorithm`](@ref), [`left_null`](@ref). +""" left_null_alg(alg::AbstractAlgorithm) = LeftNullAlgorithm(alg) + +""" + right_null_alg(alg::AbstractAlgorithm) -> RightNullAlgorithm + +Convert an algorithm to a [`RightNullAlgorithm`](@ref) wrapper for use with [`right_null`](@ref). + +This function attempts to deduce the appropriate factorization kind (`:lq` or `:svd`) from +the algorithm type and wraps it in a `RightNullAlgorithm`. Custom algorithm types can be +registered by defining: + +```julia +MatrixAlgebraKit.RightNullAlgorithm(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) +``` + +where `kind` specifies the factorization backend to use. + +See also [`RightNullAlgorithm`](@ref), [`right_null`](@ref). +""" right_null_alg(alg::AbstractAlgorithm) = RightNullAlgorithm(alg) From afa64bc309a1449f8ac2707157b9cc263ecf9daf Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Oct 2025 11:36:17 -0400 Subject: [PATCH 44/47] migrate logic from constructors to functions --- src/interface/decompositions.jl | 32 ++++---------- src/interface/orthnull.jl | 78 +++++++++++++++++++++++++++++---- 2 files changed, 78 insertions(+), 32 deletions(-) diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 3a81f086..c36a6b02 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -329,12 +329,13 @@ struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm end LeftOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftOrthAlgorithm{Kind, Alg}(alg) +# Note: specific algorithm selection is handled by `left_orth_alg` in orthnull.jl LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( """ Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. To register the algorithm type for `left_orth`, define - MatrixAlgebraKit.LeftOrthAlgorithm(alg) = LeftOrthAlgorithm{kind}(alg) + MatrixAlgebraKit.left_orth_alg(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) where `kind` selects the factorization type that will be used. By default, this is either `:qr`, `:polar` or `:svd`, to select [`qr_compact!`](@ref), @@ -343,14 +344,8 @@ LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( ) const LeftOrthViaQR = LeftOrthAlgorithm{:qr} -LeftOrthAlgorithm(alg::QRAlgorithms) = LeftOrthViaQR{typeof(alg)}(alg) - const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} -LeftOrthAlgorithm(alg::PolarAlgorithms) = LeftOrthViaPolar{typeof(alg)}(alg) - const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} -LeftOrthAlgorithm(alg::SVDAlgorithms) = LeftOrthViaSVD{typeof(alg)}(alg) -LeftOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD{typeof(alg)}(alg) """ RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) @@ -363,12 +358,13 @@ struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm end RightOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightOrthAlgorithm{Kind, Alg}(alg) +# Note: specific algorithm selection is handled by `right_orth_alg` in orthnull.jl RightOrthAlgorithm(alg::AbstractAlgorithm) = error( """ Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. To register the algorithm type for `right_orth`, define - MatrixAlgebraKit.RightOrthAlgorithm(alg) = RightOrthAlgorithm{kind}(alg) + MatrixAlgebraKit.right_orth_alg(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) where `kind` selects the factorization type that will be used. By default, this is either `:lq`, `:polar` or `:svd`, to select [`lq_compact!`](@ref), @@ -377,14 +373,8 @@ RightOrthAlgorithm(alg::AbstractAlgorithm) = error( ) const RightOrthViaLQ = RightOrthAlgorithm{:lq} -RightOrthAlgorithm(alg::LQAlgorithms) = RightOrthViaLQ{typeof(alg)}(alg) - const RightOrthViaPolar = RightOrthAlgorithm{:polar} -RightOrthAlgorithm(alg::PolarAlgorithms) = RightOrthViaPolar{typeof(alg)}(alg) - const RightOrthViaSVD = RightOrthAlgorithm{:svd} -RightOrthAlgorithm(alg::SVDAlgorithms) = RightOrthViaSVD{typeof(alg)}(alg) -RightOrthAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD{typeof(alg)}(alg) """ LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) @@ -397,12 +387,13 @@ struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm end LeftNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftNullAlgorithm{Kind, Alg}(alg) +# Note: specific algorithm selection is handled by `left_null_alg` in orthnull.jl LeftNullAlgorithm(alg::AbstractAlgorithm) = error( """ Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. To register the algorithm type for `left_null`, define - MatrixAlgebraKit.LeftNullAlgorithm(alg) = LeftNullAlgorithm{kind}(alg) + MatrixAlgebraKit.left_null_alg(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) where `kind` selects the factorization type that will be used. By default, this is either `:qr` or `:svd`, to select [`qr_null!`](@ref), @@ -411,11 +402,7 @@ LeftNullAlgorithm(alg::AbstractAlgorithm) = error( ) const LeftNullViaQR = LeftNullAlgorithm{:qr} -LeftNullAlgorithm(alg::QRAlgorithms) = LeftNullViaQR{typeof(alg)}(alg) - const LeftNullViaSVD = LeftNullAlgorithm{:svd} -LeftNullAlgorithm(alg::SVDAlgorithms) = LeftNullViaSVD{typeof(alg)}(alg) -LeftNullAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftNullViaSVD{typeof(alg)}(alg) """ RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) @@ -428,12 +415,13 @@ struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm end RightNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightNullAlgorithm{Kind, Alg}(alg) +# Note: specific algorithm selection is handled by `right_null_alg` in orthnull.jl RightNullAlgorithm(alg::AbstractAlgorithm) = error( """ Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. To register the algorithm type for `right_null`, define - MatrixAlgebraKit.RightNullAlgorithm(alg) = RightNullAlgorithm{kind}(alg) + MatrixAlgebraKit.right_null_alg(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) where `kind` selects the factorization type that will be used. By default, this is either `:lq` or `:svd`, to select [`lq_null!`](@ref), @@ -442,8 +430,4 @@ RightNullAlgorithm(alg::AbstractAlgorithm) = error( ) const RightNullViaLQ = RightNullAlgorithm{:lq} -RightNullAlgorithm(alg::LQAlgorithms) = RightNullViaLQ{typeof(alg)}(alg) - const RightNullViaSVD = RightNullAlgorithm{:svd} -RightNullAlgorithm(alg::SVDAlgorithms) = RightNullViaSVD{typeof(alg)}(alg) -RightNullAlgorithm(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightNullViaSVD{typeof(alg)}(alg) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index fb04df33..34193ec8 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -436,14 +436,30 @@ from the algorithm type and wraps it in a `LeftOrthAlgorithm`. Custom algorithm registered by defining: ```julia -MatrixAlgebraKit.LeftOrthAlgorithm(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) +MatrixAlgebraKit.left_orth_alg(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) ``` where `kind` specifies the factorization backend to use. See also [`LeftOrthAlgorithm`](@ref), [`left_orth`](@ref). """ -left_orth_alg(alg::AbstractAlgorithm) = LeftOrthAlgorithm(alg) +left_orth_alg(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_orth`, define + + MatrixAlgebraKit.left_orth_alg(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr`, `:polar` or `:svd`, to select [`qr_compact!`](@ref), + [`left_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) +left_orth_alg(alg::LeftOrthAlgorithm) = alg +left_orth_alg(alg::QRAlgorithms) = LeftOrthViaQR(alg) +left_orth_alg(alg::PolarAlgorithms) = LeftOrthViaPolar(alg) +left_orth_alg(alg::SVDAlgorithms) = LeftOrthViaSVD(alg) +left_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD(alg) """ right_orth_alg(alg::AbstractAlgorithm) -> RightOrthAlgorithm @@ -455,14 +471,30 @@ from the algorithm type and wraps it in a `RightOrthAlgorithm`. Custom algorithm registered by defining: ```julia -MatrixAlgebraKit.RightOrthAlgorithm(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) +MatrixAlgebraKit.right_orth_alg(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) ``` where `kind` specifies the factorization backend to use. See also [`RightOrthAlgorithm`](@ref), [`right_orth`](@ref). """ -right_orth_alg(alg::AbstractAlgorithm) = RightOrthAlgorithm(alg) +right_orth_alg(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_orth`, define + + MatrixAlgebraKit.right_orth_alg(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq`, `:polar` or `:svd`, to select [`lq_compact!`](@ref), + [`right_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) +right_orth_alg(alg::RightOrthAlgorithm) = alg +right_orth_alg(alg::LQAlgorithms) = RightOrthViaLQ(alg) +right_orth_alg(alg::PolarAlgorithms) = RightOrthViaPolar(alg) +right_orth_alg(alg::SVDAlgorithms) = RightOrthViaSVD(alg) +right_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD(alg) """ left_null_alg(alg::AbstractAlgorithm) -> LeftNullAlgorithm @@ -474,14 +506,29 @@ the algorithm type and wraps it in a `LeftNullAlgorithm`. Custom algorithm types registered by defining: ```julia -MatrixAlgebraKit.LeftNullAlgorithm(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) +MatrixAlgebraKit.left_null_alg(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) ``` where `kind` specifies the factorization backend to use. See also [`LeftNullAlgorithm`](@ref), [`left_null`](@ref). """ -left_null_alg(alg::AbstractAlgorithm) = LeftNullAlgorithm(alg) +left_null_alg(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_null`, define + + MatrixAlgebraKit.left_null_alg(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr` or `:svd`, to select [`qr_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) +left_null_alg(alg::LeftNullAlgorithm) = alg +left_null_alg(alg::QRAlgorithms) = LeftNullViaQR(alg) +left_null_alg(alg::SVDAlgorithms) = LeftNullViaSVD(alg) +left_null_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftNullViaSVD(alg) """ right_null_alg(alg::AbstractAlgorithm) -> RightNullAlgorithm @@ -493,11 +540,26 @@ the algorithm type and wraps it in a `RightNullAlgorithm`. Custom algorithm type registered by defining: ```julia -MatrixAlgebraKit.RightNullAlgorithm(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) +MatrixAlgebraKit.right_null_alg(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) ``` where `kind` specifies the factorization backend to use. See also [`RightNullAlgorithm`](@ref), [`right_null`](@ref). """ -right_null_alg(alg::AbstractAlgorithm) = RightNullAlgorithm(alg) +right_null_alg(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_null`, define + + MatrixAlgebraKit.right_null_alg(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq` or `:svd`, to select [`lq_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) +right_null_alg(alg::RightNullAlgorithm) = alg +right_null_alg(alg::LQAlgorithms) = RightNullViaLQ(alg) +right_null_alg(alg::SVDAlgorithms) = RightNullViaSVD(alg) +right_null_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightNullViaSVD(alg) From fb60fe3ebf1374197297d87f594bbb249632a73e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Oct 2025 14:59:46 -0400 Subject: [PATCH 45/47] also update docs --- docs/src/user_interface/decompositions.md | 181 +++++++++++++++++++++- 1 file changed, 174 insertions(+), 7 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index 21dc38d2..a41b366f 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -169,38 +169,205 @@ PolarNewton ## Orthogonal Subspaces -Often it is useful to compute orthogonal bases for a particular subspace defined by a matrix. -Given a matrix `A` we can compute an orthonormal basis for its image or coimage, and factorize the matrix accordingly. +Often it is useful to compute orthogonal bases for particular subspaces defined by a matrix. +Given a matrix `A`, we can compute an orthonormal basis for its image or coimage, and factorize the matrix accordingly. These bases are accessible through [`left_orth`](@ref) and [`right_orth`](@ref) respectively. -This is implemented through a combination of the decompositions mentioned above, and serves as a convenient interface to these operations. + +### Overview + +The [`left_orth`](@ref) function computes an orthonormal basis `V` for the image (column space) of `A`, along with a corestriction matrix `C` such that `A = V * C`. +The resulting `V` has orthonormal columns (`V' * V ≈ I` or `isisometric(V)`). + +Similarly, [`right_orth`](@ref) computes an orthonormal basis for the coimage (row space) of `A`, i.e., the image of `A'`. +It returns matrices `C` and `Vᴴ` such that `A = C * Vᴴ`, where `V = (Vᴴ)'` has orthonormal columns (`isisometric(Vᴴ; side = :right)`). + +These functions serve as high-level interfaces that automatically select the most appropriate decomposition based on the specified options, making them convenient for users who want orthonormalization without worrying about the underlying implementation details. ```@docs; canonical=false left_orth right_orth ``` -In order to dispatch to the underlying factorizations, the following wrapper algorithms are used: +### Algorithm Selection + +Both functions support multiple decomposition drivers, which can be selected through the `alg` keyword argument: + +**For `left_orth`:** +- `alg = :qr` (default without truncation): Uses QR decomposition via [`qr_compact`](@ref) +- `alg = :polar`: Uses polar decomposition via [`left_polar`](@ref) +- `alg = :svd` (default with truncation): Uses SVD via [`svd_compact`](@ref) or [`svd_trunc`](@ref) + +**For `right_orth`:** +- `alg = :lq` (default without truncation): Uses LQ decomposition via [`lq_compact`](@ref) +- `alg = :polar`: Uses polar decomposition via [`right_polar`](@ref) +- `alg = :svd` (default with truncation): Uses SVD via [`svd_compact`](@ref) or [`svd_trunc`](@ref) + +When `alg` is not specified, the function automatically selects `:qr`/`:lq` for exact orthogonalization, or `:svd` when a truncation strategy is provided. + +### Extending with Custom Algorithms + +To register a custom algorithm type for use with these functions, you need to define the appropriate conversion function, for example: + +```julia +# For left_orth +MatrixAlgebraKit.left_orth_alg(alg::MyCustomAlgorithm) = LeftOrthAlgorithm{:qr}(alg) + +# For right_orth +MatrixAlgebraKit.right_orth_alg(alg::MyCustomAlgorithm) = RightOrthAlgorithm{:lq}(alg) +``` + +The type parameter (`:qr`, `:lq`, `:polar`, or `:svd`) indicates which factorization backend will be used. +The wrapper algorithm types handle the dispatch to the appropriate implementation: ```@docs; canonical=false +left_orth_alg +right_orth_alg LeftOrthAlgorithm RightOrthAlgorithm ``` +### Examples + +Basic orthogonalization: + +```jldoctest orthnull; output=false +using MatrixAlgebraKit +using LinearAlgebra + +A = [1.0 2.0; 3.0 4.0; 5.0 6.0] +V, C = left_orth(A) +(V' * V) ≈ I && A ≈ V * C + +# output +true +``` + +Using different algorithms: + +```jldoctest orthnull; output=false +A = randn(4, 3) +V1, C1 = left_orth(A; alg = :qr) +V2, C2 = left_orth(A; alg = :polar) +V3, C3 = left_orth(A; alg = :svd) +A ≈ V1 * C1 ≈ V2 * C2 ≈ V3 * C3 + +# output +true +``` + +With truncation: + +```jldoctest orthnull; output=false +A = [1.0 0.0; 0.0 1e-10; 0.0 0.0] +V, C = left_orth(A; trunc = (atol = 1e-8,)) +size(V, 2) == 1 # Only one column retained + +# output +true +``` + ## Null Spaces Similarly, it can be convenient to obtain an orthogonal basis for the kernel or cokernel of a matrix. -These are the compliments of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions. -Again, this is typically implemented through a combination of the decompositions mentioned above, and serves as a convenient interface to these operations. +These are the complements of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions. + +### Overview + +The [`left_null`](@ref) function computes an orthonormal basis `N` for the cokernel (left nullspace) of `A`, which is the nullspace of `A'`. +This means `A' * N ≈ 0` and `N' * N ≈ I`. + +Similarly, [`right_null`](@ref) computes an orthonormal basis for the kernel (right nullspace) of `A`. +It returns `Nᴴ` such that `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`, where `N = (Nᴴ)'` has orthonormal columns. + +These functions automatically handle rank determination and provide convenient access to nullspace computation without requiring detailed knowledge of the underlying decomposition methods. ```@docs; canonical=false left_null right_null ``` -Again, dispatching happens through the following wrapper algorithm types: +### Algorithm Selection + +Both functions support multiple decomposition drivers, which can be selected through the `alg` keyword argument: + +**For `left_null`:** +- `alg = :qr` (default without truncation): Uses QR-based nullspace computation via [`qr_null`](@ref) +- `alg = :svd` (default with truncation): Uses SVD via [`svd_full`](@ref) with appropriate truncation + +**For `right_null`:** +- `alg = :lq` (default without truncation): Uses LQ-based nullspace computation via [`lq_null`](@ref) +- `alg = :svd` (default with truncation): Uses SVD via [`svd_full`](@ref) with appropriate truncation + +When `alg` is not specified, the function automatically selects `:qr`/`:lq` for exact nullspace computation, or `:svd` when a truncation strategy is provided to handle numerical rank determination. + +!!! note + For nullspace functions, [`notrunc`](@ref) has special meaning when used with the default QR/LQ algorithms. + It indicates that the nullspace should be computed from the exact zeros determined by the additional rows/columns of the extended matrix, without any tolerance-based truncation. + +### Extending with Custom Algorithms + +To register a custom algorithm type for use with these functions, you need to define the appropriate conversion function: + +```julia +# For left_null +MatrixAlgebraKit.left_null_alg(alg::MyCustomAlgorithm) = LeftNullAlgorithm{:qr}(alg) + +# For right_null +MatrixAlgebraKit.right_null_alg(alg::MyCustomAlgorithm) = RightNullAlgorithm{:lq}(alg) +``` + +The type parameter (`:qr`, `:lq`, or `:svd`) indicates which factorization backend will be used. +The wrapper algorithm types handle the dispatch to the appropriate implementation: ```@docs; canonical=false LeftNullAlgorithm RightNullAlgorithm +left_null_alg +right_null_alg +``` + +### Examples + +Basic nullspace computation: + +```jldoctest orthnull; output=false +A = [1.0 2.0 3.0; 4.0 5.0 6.0] # Rank 2 matrix +N = left_null(A) +size(N) == (2, 0) + +# output +true +``` + +```jldoctest orthnull; output=false +Nᴴ = right_null(A) +size(Nᴴ) == (1, 3) && norm(A * Nᴴ') < 1e-14 && isisometric(Nᴴ; side = :right) + +# output +true +``` + +Computing nullspace with rank detection: + +```jldoctest orthnull; output=false +A = [1.0 2.0; 2.0 4.0; 3.0 6.0] # Rank 1 matrix (second column = 2*first) +N = left_null(A; alg = :svd, trunc = (atol = 1e-10,)) +size(N) == (3, 2) && norm(A' * N) < 1e-12 && isisometric(N) + +# output +true +``` + +Using different algorithms: + +```jldoctest orthnull; output=false +A = [1.0 0.0 0.0; 0.0 1.0 0.0] +N1 = right_null(A; alg = :lq) +N2 = right_null(A; alg = :svd) +norm(A * N1') < 1e-14 && norm(A * N2') < 1e-14 && + isisometric(N1; side = :right) && isisometric(N2; side = :right) + +# output +true ``` From a2b7e9a71feb3a114234f5d99f544e4c745ceba2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 30 Oct 2025 15:43:41 -0400 Subject: [PATCH 46/47] remove unnecessary overloads --- src/interface/orthnull.jl | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 34193ec8..67dc9ff4 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -394,34 +394,18 @@ function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing return RightNullViaSVD(alg) end -default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) : - select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...) -# disambiguate default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) : select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) : - select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...) -# disambiguate default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) : select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) : - select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...) -# disambiguate default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) : select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...) -default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = - isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : - select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) -# disambiguate default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) : select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...) From f04422853c98334f8f0272a89133e43654199af0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 1 Nov 2025 08:21:55 -0400 Subject: [PATCH 47/47] apply suggestions --- docs/src/user_interface/decompositions.md | 2 +- src/algorithms.jl | 1 - src/implementations/orthnull.jl | 4 ---- src/interface/orthnull.jl | 4 ++-- 4 files changed, 3 insertions(+), 8 deletions(-) diff --git a/docs/src/user_interface/decompositions.md b/docs/src/user_interface/decompositions.md index a41b366f..1ae1fefe 100644 --- a/docs/src/user_interface/decompositions.md +++ b/docs/src/user_interface/decompositions.md @@ -270,7 +270,7 @@ true ## Null Spaces Similarly, it can be convenient to obtain an orthogonal basis for the kernel or cokernel of a matrix. -These are the complements of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions. +These are the orthogonal complements of the coimage and image, respectively, and can be computed using the [`left_null`](@ref) and [`right_null`](@ref) functions. ### Overview diff --git a/src/algorithms.jl b/src/algorithms.jl index be88e64a..e9e7b8e8 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -55,7 +55,6 @@ end # Algorithm traits # ---------------- - """ does_truncate(alg::AbstractAlgorithm) -> Bool diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 413e4b5c..2a15259b 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -7,7 +7,6 @@ copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need a check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = check_input(left_orth!, A, VC, left_orth_alg(alg)) - check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaQR) = check_input(qr_compact!, A, VC, alg.alg) check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaPolar) = @@ -16,7 +15,6 @@ check_input(::typeof(left_orth!), A, VC, alg::LeftOrthViaSVD) = nothing check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = check_input(right_orth!, A, CVᴴ, right_orth_alg(alg)) - check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaLQ) = check_input(lq_compact!, A, VC, alg.alg) check_input(::typeof(right_orth!), A, VC, alg::RightOrthViaPolar) = @@ -39,7 +37,6 @@ check_input(::typeof(right_null!), A, Nᴴ, alg::RightNullViaSVD) = nothing # ------- initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = initialize_output(left_orth!, A, left_orth_alg(alg)) - initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaQR) = initialize_output(qr_compact!, A, alg.alg) initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaPolar) = @@ -48,7 +45,6 @@ initialize_output(::typeof(left_orth!), A, alg::LeftOrthViaSVD) = nothing initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = initialize_output(right_orth!, A, right_orth_alg(alg)) - initialize_output(::typeof(right_orth!), A, alg::RightOrthViaLQ) = initialize_output(lq_compact!, A, alg.alg) initialize_output(::typeof(right_orth!), A, alg::RightOrthViaPolar) = diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index 67dc9ff4..64acb509 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -60,7 +60,7 @@ the chosen decomposition type. By default, the supported kinds are: and is roughly equivalent to: ```julia V, S, C = svd_trunc(A; trunc, kwargs...) - C = S * C + C = lmul!(S, C) ``` ### `alg::AbstractAlgorithm` @@ -141,7 +141,7 @@ the chosen decomposition type. By default, the supported kinds are: and is roughly equivalent to: ```julia C, S, Vᴴ = svd_trunc(A; trunc, kwargs...) - C = C * S + C = rmul!(C, S) ``` ### `alg::AbstractAlgorithm`