diff --git a/src/algorithms.jl b/src/algorithms.jl index a47be2c8..0c7f2594 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -53,6 +53,105 @@ function _show_alg(io::IO, alg::Algorithm) return print(io, ")") end +# Algorithm traits +# ---------------- +""" + left_orth_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `left_orth!(A, alg)`. +By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`, but +this can be extended to insert arbitrary other decomposition functions, which should follow +the signature `f!(A, F, alg) -> F` +""" +left_orth_kind(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.left_orth_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `left_orth_qr!`, `left_orth_polar!` or `left_orth_svd!`. + """ +) + +""" + right_orth_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `right_orth!(A, alg)`. +By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`, but +this can be extended to insert arbitrary other decomposition functions, which should follow +the signature `f!(A, F, alg) -> F` +""" +right_orth_kind(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.right_orth_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `right_orth_lq!`, `right_orth_polar!` or `right_orth_svd!`. + """ +) + +""" + left_null_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `left_null!(A, alg)`. +By default, this is either `left_null_qr!` or `left_null_svd!`, but this can be extended +to insert arbitrary other decomposition functions, which should follow the signature +`f!(A, F, alg) -> F` +""" +function left_null_kind(alg::AbstractAlgorithm) + left_orth_kind(alg) === left_orth_qr! && return left_null_qr! + left_orth_kind(alg) === left_orth_svd! && return left_null_svd! + return error( + """ + Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.left_null_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `left_null_qr!` or `left_null_svd!`. + """ + ) +end + +""" + right_null_kind(alg::AbstractAlgorithm) -> f! + +Select an appropriate factorization function for applying `right_null!(A, alg)`. +By default, this is either `right_null_lq!` or `right_null_svd!`, but this can be extended +to insert arbitrary other decomposition functions, which should follow the signature +`f!(A, F, alg) -> F` +""" +function right_null_kind(alg::AbstractAlgorithm) + right_orth_kind(alg) === right_orth_lq! && return right_null_lq! + right_orth_kind(alg) === right_orth_svd! && return right_null_svd! + return error( + """ + Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. + To register the algorithm type, define: + + MatrixAlgebraKit.right_null_kind(alg) = f! + + where `f!` should be the factorization function that will be used. + By default, this is either `right_null_lq!` or `right_null_svd!`. + """ + ) +end + +""" + does_truncate(alg::AbstractAlgorithm) -> Bool + +Check whether or not an algorithm can be used for a truncated decomposition. +""" +does_truncate(alg::AbstractAlgorithm) = false + +# Algorithm selection +# ------------------- @doc """ MatrixAlgebraKit.select_algorithm(f, A, alg::AbstractAlgorithm) MatrixAlgebraKit.select_algorithm(f, A, alg::Symbol; kwargs...) @@ -160,6 +259,24 @@ function select_truncation(trunc) end end +@doc """ + MatrixAlgebraKit.select_null_truncation(trunc) + +Construct a [`TruncationStrategy`](@ref) from the given `NamedTuple` of keywords or input strategy, to implement a nullspace selection. +""" select_null_truncation + +function select_null_truncation(trunc) + if isnothing(trunc) + return NoTruncation() + elseif trunc isa NamedTuple + return null_truncation_strategy(; trunc...) + elseif trunc isa TruncationStrategy + return trunc + else + return throw(ArgumentError("Unknown truncation strategy: $trunc")) + end +end + @doc """ MatrixAlgebraKit.findtruncated(values::AbstractVector, strategy::TruncationStrategy) @@ -200,6 +317,10 @@ struct TruncatedAlgorithm{A, T} <: AbstractAlgorithm trunc::T end +left_orth_kind(alg::TruncatedAlgorithm) = left_orth_kind(alg.alg) +right_orth_kind(alg::TruncatedAlgorithm) = right_orth_kind(alg.alg) +does_truncate(::TruncatedAlgorithm) = true + # Utility macros # -------------- diff --git a/src/implementations/orthnull.jl b/src/implementations/orthnull.jl index 76e32ce1..ff7d9983 100644 --- a/src/implementations/orthnull.jl +++ b/src/implementations/orthnull.jl @@ -5,287 +5,124 @@ copy_input(::typeof(right_orth), A) = copy_input(lq_compact, A) # do we ever nee copy_input(::typeof(left_null), A) = copy_input(qr_null, A) # do we ever need anything else copy_input(::typeof(right_null), A) = copy_input(lq_null, A) # do we ever need anything else -function check_input(::typeof(left_orth!), A::AbstractMatrix, VC, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - V, C = VC - @assert V isa AbstractMatrix && C isa AbstractMatrix - @check_size(V, (m, minmn)) - @check_scalar(V, A) - if !isempty(C) - @check_size(C, (minmn, n)) - @check_scalar(C, A) - end - return nothing -end -function check_input(::typeof(right_orth!), A::AbstractMatrix, CVᴴ, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - C, Vᴴ = CVᴴ - @assert C isa AbstractMatrix && Vᴴ isa AbstractMatrix - if !isempty(C) - @check_size(C, (m, minmn)) - @check_scalar(C, A) - end - @check_size(Vᴴ, (minmn, n)) - @check_scalar(Vᴴ, A) - return nothing -end +check_input(::typeof(left_orth!), A, VC, alg::AbstractAlgorithm) = + check_input(left_orth_kind(alg), A, VC, alg) +check_input(::typeof(left_orth_qr!), A, VC, alg::AbstractAlgorithm) = + check_input(qr_compact!, A, VC, alg) +check_input(::typeof(left_orth_polar!), A, VC, alg::AbstractAlgorithm) = + check_input(left_polar!, A, VC, alg) +check_input(::typeof(left_orth_svd!), A, VC, alg::AbstractAlgorithm) = + check_input(qr_compact!, A, VC, alg) -function check_input(::typeof(left_null!), A::AbstractMatrix, N, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - @assert N isa AbstractMatrix - @check_size(N, (m, m - minmn)) - @check_scalar(N, A) - return nothing -end -function check_input(::typeof(right_null!), A::AbstractMatrix, Nᴴ, ::AbstractAlgorithm) - m, n = size(A) - minmn = min(m, n) - @assert Nᴴ isa AbstractMatrix - @check_size(Nᴴ, (n - minmn, n)) - @check_scalar(Nᴴ, A) - return nothing -end +check_input(::typeof(right_orth!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(right_orth_kind(alg), A, CVᴴ, alg) +check_input(::typeof(right_orth_lq!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(lq_compact!, A, CVᴴ, alg) +check_input(::typeof(right_orth_polar!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(right_polar!, A, CVᴴ, alg) +check_input(::typeof(right_orth_svd!), A, CVᴴ, alg::AbstractAlgorithm) = + check_input(lq_compact!, A, CVᴴ, alg) + +check_input(::typeof(left_null!), A, N, alg::AbstractAlgorithm) = + check_input(left_null_kind(alg), A, N, alg) +check_input(::typeof(left_null_qr!), A, N, alg::AbstractAlgorithm) = + check_input(qr_null!, A, N, alg) +check_input(::typeof(left_null_svd!), A, N, alg::AbstractAlgorithm) = nothing + +check_input(::typeof(right_null!), A, Nᴴ, alg::AbstractAlgorithm) = + check_input(right_null_kind(alg), A, Nᴴ, alg) +check_input(::typeof(right_null_lq!), A, Nᴴ, alg::AbstractAlgorithm) = + check_input(lq_null!, A, Nᴴ, alg) +check_input(::typeof(right_null_svd!), A, Nᴴ, alg::AbstractAlgorithm) = nothing # Outputs # ------- -function initialize_output(::typeof(left_orth!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - V = similar(A, (m, minmn)) - C = similar(A, (minmn, n)) - return (V, C) -end -function initialize_output(::typeof(right_orth!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - C = similar(A, (m, minmn)) - Vᴴ = similar(A, (minmn, n)) - return (C, Vᴴ) -end +initialize_output(::typeof(left_orth!), A, alg::AbstractAlgorithm) = + initialize_output(left_orth_kind(alg), A, alg) +initialize_output(::typeof(left_orth_qr!), A, alg::AbstractAlgorithm) = + initialize_output(qr_compact!, A, alg) +initialize_output(::typeof(left_orth_polar!), A, alg::AbstractAlgorithm) = + initialize_output(left_polar!, A, alg) +initialize_output(::typeof(left_orth_svd!), A, alg::AbstractAlgorithm) = + initialize_output(qr_compact!, A, alg) -function initialize_output(::typeof(left_null!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - N = similar(A, (m, m - minmn)) - return N -end -function initialize_output(::typeof(right_null!), A::AbstractMatrix) - m, n = size(A) - minmn = min(m, n) - Nᴴ = similar(A, (n - minmn, n)) - return Nᴴ +initialize_output(::typeof(right_orth!), A, alg::AbstractAlgorithm) = + initialize_output(right_orth_kind(alg), A, alg) +initialize_output(::typeof(right_orth_lq!), A, alg::AbstractAlgorithm) = + initialize_output(lq_compact!, A, alg) +initialize_output(::typeof(right_orth_polar!), A, alg::AbstractAlgorithm) = + initialize_output(right_polar!, A, alg) +initialize_output(::typeof(right_orth_svd!), A, alg::AbstractAlgorithm) = + initialize_output(lq_compact!, A, alg) + +initialize_output(::typeof(left_null!), A, alg::AbstractAlgorithm) = + initialize_output(left_null_kind(alg), A, alg) +initialize_output(::typeof(left_null_qr!), A, alg::AbstractAlgorithm) = + initialize_output(qr_null!, A, alg) +initialize_output(::typeof(left_null_svd!), A, alg::AbstractAlgorithm) = nothing + +initialize_output(::typeof(right_null!), A, alg::AbstractAlgorithm) = + initialize_output(right_null_kind(alg), A, alg) +initialize_output(::typeof(right_null_lq!), A, alg::AbstractAlgorithm) = + initialize_output(lq_null!, A, alg) +initialize_output(::typeof(right_null_svd!), A, alg::AbstractAlgorithm) = nothing + +function initialize_orth_svd(A::AbstractMatrix, F, alg) + S = Diagonal(initialize_output(svd_vals!, A, alg)) + return F[1], S, F[2] end +# fallback doesn't re-use F at all +initialize_orth_svd(A, F, alg) = initialize_output(svd_compact!, A, alg) # Implementation of orth functions # -------------------------------- -function left_orth!( - A, VC; - trunc = nothing, kind = isnothing(trunc) ? :qr : :svd, - alg_qr = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_orth with kind=$kind")) - end - if kind == :qr - return left_orth_qr!(A, VC, alg_qr) - elseif kind == :polar - return left_orth_polar!(A, VC, alg_polar) - elseif kind == :svd - return left_orth_svd!(A, VC, alg_svd, trunc) - else - throw(ArgumentError("`left_orth!` received unknown value `kind = $kind`")) - end -end -function left_orth_qr!(A, VC, alg) - alg′ = select_algorithm(qr_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - return qr_compact!(A, VC, alg′) -end -function left_orth_polar!(A, VC, alg) - alg′ = select_algorithm(left_polar!, A, alg) - check_input(left_orth!, A, VC, alg′) - return left_polar!(A, VC, alg′) -end -function left_orth_svd!(A, VC, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - U, S, Vᴴ = svd_compact!(A, alg′) - V, C = VC - return copy!(V, U), mul!(C, S, Vᴴ) -end -function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg′)) - U, S, Vᴴ = svd_compact!(A, (V, S, C), alg′) - return U, lmul!(S, Vᴴ) -end -function left_orth_svd!(A, VC, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ = svd_trunc!(A, alg_trunc) - V, C = VC - return copy!(V, U), mul!(C, S, Vᴴ) -end -function left_orth_svd!(A::AbstractMatrix, VC, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(left_orth!, A, VC, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - V, C = VC - S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (V, S, C), alg_trunc) - return U, lmul!(S, Vᴴ) -end +left_orth!(A, VC, alg::AbstractAlgorithm) = left_orth_kind(alg)(A, VC, alg) +left_orth_qr!(A, VC, alg::AbstractAlgorithm) = qr_compact!(A, VC, alg) +left_orth_polar!(A, VC, alg::AbstractAlgorithm) = left_polar!(A, VC, alg) -function right_orth!( - A, CVᴴ; - trunc = nothing, kind = isnothing(trunc) ? :lq : :svd, - alg_lq = (; positive = true), alg_polar = (;), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for right_orth with kind=$kind")) - end - if kind == :lq - return right_orth_lq!(A, CVᴴ, alg_lq) - elseif kind == :polar - return right_orth_polar!(A, CVᴴ, alg_polar) - elseif kind == :svd - return right_orth_svd!(A, CVᴴ, alg_svd, trunc) - else - throw(ArgumentError("`right_orth!` received unknown value `kind = $kind`")) - end -end -function right_orth_lq!(A, CVᴴ, alg) - alg′ = select_algorithm(lq_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - return lq_compact!(A, CVᴴ, alg′) -end -function right_orth_polar!(A, CVᴴ, alg) - alg′ = select_algorithm(right_polar!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - return right_polar!(A, CVᴴ, alg′) -end -function right_orth_svd!(A, CVᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - U, S, Vᴴ′ = svd_compact!(A, alg′) - C, Vᴴ = CVᴴ - return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) -end -function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg′)) - U, S, Vᴴ = svd_compact!(A, (C, S, Vᴴ), alg′) - return rmul!(U, S), Vᴴ -end -function right_orth_svd!(A, CVᴴ, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - U, S, Vᴴ′ = svd_trunc!(A, alg_trunc) - C, Vᴴ = CVᴴ - return mul!(C, U, S), copy!(Vᴴ, Vᴴ′) -end -function right_orth_svd!(A::AbstractMatrix, CVᴴ, alg, trunc) - alg′ = select_algorithm(svd_compact!, A, alg) - check_input(right_orth!, A, CVᴴ, alg′) - alg_trunc = select_algorithm(svd_trunc!, A, alg′; trunc) - C, Vᴴ = CVᴴ - S = Diagonal(initialize_output(svd_vals!, A, alg_trunc.alg)) - U, S, Vᴴ = svd_trunc!(A, (C, S, Vᴴ), alg_trunc) - return rmul!(U, S), Vᴴ +right_orth!(A, CVᴴ, alg::AbstractAlgorithm) = right_orth_kind(alg)(A, CVᴴ, alg) +right_orth_lq!(A, CVᴴ, alg::AbstractAlgorithm) = lq_compact!(A, CVᴴ, alg) +right_orth_polar!(A, CVᴴ, alg::AbstractAlgorithm) = right_polar!(A, CVᴴ, alg) + +# orth_svd requires implementations of `lmul!` and `rmul!` +function left_orth_svd!(A, VC, alg::AbstractAlgorithm) + check_input(left_orth_svd!, A, VC, alg) + USVᴴ = initialize_orth_svd(A, VC, alg) + V, S, C = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) + lmul!(S, C) + return V, C +end +function right_orth_svd!(A, CVᴴ, alg::AbstractAlgorithm) + check_input(right_orth_svd!, A, CVᴴ, alg) + USVᴴ = initialize_orth_svd(A, CVᴴ, alg) + C, S, Vᴴ = does_truncate(alg) ? svd_trunc!(A, USVᴴ, alg) : svd_compact!(A, USVᴴ, alg) + rmul!(C, S) + return C, Vᴴ end # Implementation of null functions # -------------------------------- -function null_truncation_strategy(; atol = nothing, rtol = nothing, maxnullity = nothing) - if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) - return notrunc() - end - atol = @something atol 0 - rtol = @something rtol 0 - trunc = trunctol(; atol, rtol, keep_below = true) - return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev = false) : trunc -end +left_null!(A, N, alg::AbstractAlgorithm) = left_null_kind(alg)(A, N, alg) +left_null_qr!(A, N, alg::AbstractAlgorithm) = qr_null!(A, N, alg) -function left_null!( - A, N; - trunc = nothing, kind = isnothing(trunc) ? :qr : :svd, - alg_qr = (; positive = true), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for left_null with kind=$kind")) - end - return if kind == :qr - left_null_qr!(A, N, alg_qr) - elseif kind == :svd - left_null_svd!(A, N, alg_svd, trunc) - else - throw(ArgumentError("`left_null!` received unknown value `kind = $kind`")) - end -end -function left_null_qr!(A, N, alg) - alg′ = select_algorithm(qr_null!, A, alg) - check_input(left_null!, A, N, alg′) - return qr_null!(A, N, alg′) -end -function left_null_svd!(A, N, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(left_null!, A, N, alg′) - U, _, _ = svd_full!(A, alg′) - (m, n) = size(A) - return copy!(N, view(U, 1:m, (n + 1):m)) -end -function left_null_svd!(A, N, alg, trunc) - alg′ = select_algorithm(svd_full!, A, alg) - U, S, _ = svd_full!(A, alg′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return first(truncate(left_null!, (U, S), trunc′)) -end +right_null!(A, Nᴴ, alg::AbstractAlgorithm) = right_null_kind(alg)(A, Nᴴ, alg) +right_null_lq!(A, Nᴴ, alg::AbstractAlgorithm) = lq_null!(A, Nᴴ, alg) -function right_null!( - A, Nᴴ; - trunc = nothing, kind = isnothing(trunc) ? :lq : :svd, - alg_lq = (; positive = true), alg_svd = (;) - ) - if !isnothing(trunc) && kind != :svd - throw(ArgumentError("truncation not supported for right_null with kind=$kind")) - end - if kind == :lq - return right_null_lq!(A, Nᴴ, alg_lq) - elseif kind == :svd - return right_null_svd!(A, Nᴴ, alg_svd, trunc) - else - throw(ArgumentError("`right_null!` received unknown value `kind = $kind`")) - end -end -function right_null_lq!(A, Nᴴ, alg) - alg′ = select_algorithm(lq_null!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - return lq_null!(A, Nᴴ, alg′) -end -function right_null_svd!(A, Nᴴ, alg, trunc::Nothing = nothing) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - _, _, Vᴴ = svd_full!(A, alg′) - (m, n) = size(A) - return copy!(Nᴴ, view(Vᴴ, (m + 1):n, 1:n)) +function left_null_svd!(A, N, alg::TruncatedAlgorithm) + check_input(left_null_svd!, A, N, alg) + U, S, _ = svd_full!(A, alg.alg) + N, _ = truncate(left_null!, (U, S), alg.trunc) + return N end -function right_null_svd!(A, Nᴴ, alg, trunc) - alg′ = select_algorithm(svd_full!, A, alg) - check_input(right_null!, A, Nᴴ, alg′) - _, S, Vᴴ = svd_full!(A, alg′) - trunc′ = trunc isa TruncationStrategy ? trunc : - trunc isa NamedTuple ? null_truncation_strategy(; trunc...) : - throw(ArgumentError("Unknown truncation strategy: $trunc")) - return first(truncate(right_null!, (S, Vᴴ), trunc′)) +function right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm) + check_input(right_null_svd!, A, Nᴴ, alg) + _, S, Vᴴ = svd_full!(A, alg.alg) + Nᴴ, _ = truncate(right_null!, (S, Vᴴ), alg.trunc) + return Nᴴ end + +# randomized algorithms don't work for smallest values: +left_null_svd!(A, N, alg::TruncatedAlgorithm{<:GPU_Randomized}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) +right_null_svd!(A, Nᴴ, alg::TruncatedAlgorithm{<:GPU_Randomized}) = + throw(ArgumentError("Randomized SVD ($alg) cannot be used for null spaces")) diff --git a/src/implementations/svd.jl b/src/implementations/svd.jl index 9ebc7319..fed36cd1 100644 --- a/src/implementations/svd.jl +++ b/src/implementations/svd.jl @@ -285,20 +285,6 @@ end # placed here to avoid code duplication since much of the logic is replicable across # CUDA and AMDGPU ### -const CUSOLVER_SVDAlgorithm = Union{ - CUSOLVER_QRIteration, - CUSOLVER_SVDPolar, - CUSOLVER_Jacobi, - CUSOLVER_Randomized, -} -const ROCSOLVER_SVDAlgorithm = Union{ - ROCSOLVER_QRIteration, - ROCSOLVER_Jacobi, -} -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} - -const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Randomized = Union{CUSOLVER_Randomized} function check_input( ::typeof(svd_trunc!), A::AbstractMatrix, USVᴴ, alg::CUSOLVER_Randomized diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index f6201b07..883c7759 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -26,6 +26,19 @@ function truncate(::typeof(right_null!), (S, Vᴴ), strategy::TruncationStrategy return Vᴴ[ind, :], ind end +# special case `NoTruncation` for null: should keep exact zeros due to rectangularity +function truncate(::typeof(left_null!), (U, S), strategy::NoTruncation) + m, n = size(S) + ind = (n + 1):m + return U[:, ind], ind +end +function truncate(::typeof(right_null!), (S, Vᴴ), strategy::NoTruncation) + m, n = size(S) + ind = (m + 1):n + return Vᴴ[ind, :], ind +end + + # findtruncated # ------------- # Generic fallback @@ -95,16 +108,20 @@ function _truncerr_impl(values::AbstractVector, I; atol::Real = 0, rtol::Real = end function findtruncated(values::AbstractVector, strategy::TruncationIntersection) - return mapreduce( - Base.Fix1(findtruncated, values), _ind_intersect, strategy.components; - init = trues(length(values)) - ) + length(strategy.components) == 0 && return eachindex(values) + length(strategy.components) == 1 && return findtruncated(values, only(strategy.components)) + + ind1 = findtruncated(values, strategy.components[1]) + ind2 = findtruncated(values, TruncationIntersection(Base.tail(strategy.components))) + return _ind_intersect(ind1, ind2) end function findtruncated_svd(values::AbstractVector, strategy::TruncationIntersection) - return mapreduce( - Base.Fix1(findtruncated_svd, values), _ind_intersect, - strategy.components; init = trues(length(values)) - ) + length(strategy.components) == 0 && return eachindex(values) + length(strategy.components) == 1 && return findtruncated_svd(values, only(strategy.components)) + + ind1 = findtruncated_svd(values, strategy.components[1]) + ind2 = findtruncated_svd(values, TruncationIntersection(Base.tail(strategy.components))) + return _ind_intersect(ind1, ind2) end # when one of the ind selections is a bitvector, have to handle differently diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index bdda6612..610c4928 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -36,6 +36,9 @@ elements of `L` are non-negative. @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ +left_orth_kind(::Union{LAPACK_HouseholderQR, LAPACK_HouseholderQL}) = left_orth_qr! +right_orth_kind(::Union{LAPACK_HouseholderLQ, LAPACK_HouseholderRQ}) = right_orth_lq! + # General Eigenvalue Decomposition # ------------------------------- """ @@ -117,6 +120,9 @@ const LAPACK_SVDAlgorithm = Union{ LAPACK_Jacobi, } +left_orth_kind(::LAPACK_SVDAlgorithm) = left_orth_svd! +right_orth_kind(::LAPACK_SVDAlgorithm) = right_orth_svd! + # ========================= # Polar decompositions # ========================= @@ -139,6 +145,9 @@ until convergence up to tolerance `tol`. """ @algdef PolarNewton +left_orth_kind(::Union{PolarViaSVD, PolarNewton}) = left_orth_polar! +right_orth_kind(::Union{PolarViaSVD, PolarNewton}) = right_orth_polar! + # ========================= # DIAGONAL ALGORITHMS # ========================= @@ -162,6 +171,8 @@ the diagonal elements of `R` are non-negative. """ @algdef CUSOLVER_HouseholderQR +left_orth_kind(::CUSOLVER_HouseholderQR) = left_orth_qr! + """ CUSOLVER_QRIteration() @@ -203,6 +214,8 @@ for more information. """ @algdef CUSOLVER_Randomized +does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true + """ CUSOLVER_Simple() @@ -222,6 +235,10 @@ Divide and Conquer algorithm. """ @algdef CUSOLVER_DivideAndConquer +const CUSOLVER_SVDAlgorithm = Union{ + CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, +} + # ========================= # ROCSOLVER ALGORITHMS # ========================= @@ -269,6 +286,7 @@ Divide and Conquer algorithm. """ @algdef ROCSOLVER_DivideAndConquer +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} const GPU_Simple = Union{CUSOLVER_Simple} const GPU_EigAlgorithm = Union{GPU_Simple} @@ -277,8 +295,12 @@ const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} const GPU_Bisection = Union{ROCSOLVER_Bisection} const GPU_EighAlgorithm = Union{ - GPU_QRIteration, - GPU_Jacobi, - GPU_DivideAndConquer, - GPU_Bisection, + GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, } +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} + +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Randomized = Union{CUSOLVER_Randomized} + +left_orth_kind(::GPU_SVDAlgorithm) = left_orth_svd! +right_orth_kind(::GPU_SVDAlgorithm) = right_orth_svd! diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 867d69eb..1ff7c51a 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -45,22 +45,28 @@ selected according to a truncation strategy. The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded eigenvalues. -## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all eigenvalues. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + +## Keyword Arguments +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. @@ -71,7 +77,7 @@ truncation strategy is already embedded in the algorithm. as it may not always be possible to use the provided `DV` as output. !!! note -$(docs_eig_note) +$docs_eig_note See also [`eig_full(!)`](@ref eig_full), [`eig_vals(!)`](@ref eig_vals), and [Truncations](@ref) for more information on truncation strategies. diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 314cb934..ae6a843c 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -50,22 +50,28 @@ selected according to a truncation strategy. The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded eigenvalues. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + ## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all eigenvalues. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. diff --git a/src/interface/orthnull.jl b/src/interface/orthnull.jl index c27179f8..16674abb 100644 --- a/src/interface/orthnull.jl +++ b/src/interface/orthnull.jl @@ -1,220 +1,464 @@ # Orth functions # -------------- """ - left_orth(A; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C - left_orth!(A, [VC]; [kind::Symbol, trunc, alg_qr, alg_polar, alg_svd]) -> V, C - -Compute an orthonormal basis `V` for the image of the matrix `A` of size `(m, n)`, -as well as a matrix `C` (the corestriction) such that `A` factors as `A = V * C`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -precision in determining the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. - -This is a high-level wrapper and will use one of the decompositions -[`qr_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and[`left_polar!`](@ref) -to compute the orthogonal basis `V`, as controlled by the keyword arguments. - -When `kind` is provided, its possible values are - -* `kind == :qr`: `V` and `C` are computed using the QR decomposition. - This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to - `qr_compact!(A, [VC], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` - -* `kind == :polar`: `V` and `C` are computed using the polar decomposition, - This requires `isnothing(trunc)` and `left_orth!(A, [VC])` is equivalent to - `left_polar!(A, [VC], alg)` with a default value `alg = select_algorithm(left_polar!, A)` - -* `kind == :svd`: `V` and `C` are computed using the singular value decomposition `svd_trunc!` when a - truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. - `V` will contain the left singular vectors and `C` is computed as the product of the singular - values and the right singular vectors, i.e. with `U, S, Vᴴ = svd(A)`, we have - `V = U` and `C = S * Vᴴ`. - -When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -for backend factorizations through the `alg_qr`, `alg_polar`, and `alg_svd` keyword arguments, -which will only be used if the corresponding factorization is called based on the other inputs. -If NamedTuples are passed as `alg_qr`, `alg_polar`, or `alg_svd`, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will -be used. + left_orth(A; [trunc], kwargs...) -> V, C + left_orth!(A, [VC]; [trunc], kwargs...) -> V, C + +Compute an orthonormal basis `V` for the image of the matrix `A`, as well as a matrix `C` +(the corestriction) such that `A` factors as `A = V * C`. + +This is a high-level wrapper where the keyword arguments can be used to specify and control +the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` +can optionally be used to control the precision in determining the rank of `A`, typically +via its singular values. + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be QR-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:qr` : Factorize via QR decomposition, with further customizations through the + `alg_qr` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + V, C = qr_compact(A; alg_qr...) +``` + +* `:polar` : Factorize via polar decomposition, with further customizations through the + `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + V, C = left_polar(A; alg_polar...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + This mode is roughly equivalent to: +```julia + V, S, C = svd_trunc(A; trunc, alg_svd...) + C = S * C +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.left_orth_kind(alg)`](@ref). + +--- !!! note - The bang method `left_orth!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `CV` as output. + The bang method `left_orth!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may + not always be possible to use the provided `CV` as output. -See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) +See also [`right_orth(!)`](@ref right_orth), [`left_null(!)`](@ref left_null), +[`right_null(!)`](@ref right_null) """ -function left_orth end -function left_orth! end -function left_orth!(A; kwargs...) - return left_orth!(A, initialize_output(left_orth!, A); kwargs...) -end -function left_orth(A; kwargs...) - return left_orth!(copy_input(left_orth, A); kwargs...) -end +@functiondef left_orth + +# helper functions +function left_orth_qr! end +function left_orth_polar! end +function left_orth_svd! end """ - right_orth(A; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ - right_orth!(A, [CVᴴ]; [kind::Symbol, trunc, alg_lq, alg_polar, alg_svd]) -> C, Vᴴ - -Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. -for the image of `adjoint(A)`, as well as a matrix `C` such that `A = C * Vᴴ`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -precision in determining the rank of `A` via its singular values. - -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxrank`. - -This is a high-level wrapper and will use one of the decompositions -[`lq_compact!`](@ref), [`svd_compact!`](@ref)/[`svd_trunc!`](@ref), and -[`right_polar!`](@ref) to compute the orthogonal basis `V`, as controlled by the -keyword arguments. - -When `kind` is provided, its possible values are - -* `kind == :lq`: `C` and `Vᴴ` are computed using the QR decomposition, - This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to - `lq_compact!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` - -* `kind == :polar`: `C` and `Vᴴ` are computed using the polar decomposition, - This requires `isnothing(trunc)` and `right_orth!(A, [CVᴴ])` is equivalent to - `right_polar!(A, [CVᴴ], alg)` with a default value `alg = select_algorithm(right_polar!, A)` - -* `kind == :svd`: `C` and `Vᴴ` are computed using the singular value decomposition `svd_trunc!` when - a truncation strategy is specified using the `trunc` keyword argument, and using `svd_compact!` otherwise. - `V = adjoint(Vᴴ)` will contain the right singular vectors corresponding to the singular - values and `C` is computed as the product of the singular values and the right singular vectors, - i.e. with `U, S, Vᴴ = svd(A)`, we have `C = rmul!(U, S)` and `Vᴴ = Vᴴ`. - -When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -for backend factorizations through the `alg_lq`, `alg_polar`, and `alg_svd` keyword arguments, -which will only be used if the corresponding factorization is called based on the other inputs. -If `alg_lq`, `alg_polar`, or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will -be used. + right_orth(A; [trunc], kwargs...) -> C, Vᴴ + right_orth!(A, [CVᴴ]; [trunc], kwargs...) -> C, Vᴴ + +Compute an orthonormal basis `V = adjoint(Vᴴ)` for the coimage of the matrix `A`, i.e. for +the image of `adjoint(A)`, as well as a matrix `C` such that `A` factors as `A = C * Vᴴ`. + +This is a high-level wrapper where the keyword arguments can be used to specify and control +the specific orthogonal decomposition that should be used to factor `A`, whereas `trunc` can +optionally be used to control the precision in determining the rank of `A`, typically via +its singular values. + +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and +any non-trivial strategy typically requires an SVD-based decompositions. This keyword can +be either a `NamedTuple` or a [`TruncationStrategy`](@ref). + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$(docs_truncation_kwargs) + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$(docs_truncation_strategies) + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be LQ-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:lq` : Factorize via LQ decomposition, with further customizations through the + `alg_lq` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + C, Vᴴ = lq_compact(A; alg_lq...) +``` + +* `:polar` : Factorize via polar decomposition, with further customizations through the + `alg_polar` keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + C, Vᴴ = right_polar(A; alg_polar...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + This mode is roughly equivalent to: +```julia + C, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + C = C * S +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.right_orth_kind(alg)`](@ref). + +--- !!! note - The bang method `right_orth!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `CVᴴ` as output. + The bang method `right_orth!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `CVᴴ` as output. -See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), [`right_null(!)`](@ref right_null) +See also [`left_orth(!)`](@ref left_orth), [`left_null(!)`](@ref left_null), +[`right_null(!)`](@ref right_null) """ -function right_orth end -function right_orth! end -function right_orth!(A; kwargs...) - return right_orth!(A, initialize_output(right_orth!, A); kwargs...) -end -function right_orth(A; kwargs...) - return right_orth!(copy_input(right_orth, A); kwargs...) -end +@functiondef right_orth + +# helper functions +function right_orth_lq! end +function right_orth_polar! end +function right_orth_svd! end # Null functions # -------------- """ - left_null(A; [kind::Symbol, trunc, alg_qr, alg_svd]) -> N - left_null!(A, [N]; [kind::Symbol, alg_qr, alg_svd]) -> N + left_null(A; [trunc], kwargs...) -> N + left_null!(A, [N]; [trunc], kwargs...) -> N -Compute an orthonormal basis `N` for the cokernel of the matrix `A` of size `(m, n)`, i.e. -the nullspace of `adjoint(A)`, such that `adjoint(A)*N ≈ 0` and `N'*N ≈ I`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -the rank of `A` via its singular values. +Compute an orthonormal basis `N` for the cokernel of the matrix `A`, i.e. the nullspace of +`adjoint(A)`, such that `adjoint(A) * N ≈ 0` and `N' * N ≈ I`. -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxnullity`. +This is a high-level wrapper where the keyword arguments can be used to specify and control +the underlying orthogonal decomposition that should be used to find the null space of `A'`, +whereas `trunc` can optionally be used to control the precision in determining the rank of +`A`, typically via its singular values. -This is a high-level wrapper and will use one of the decompositions `qr!` or `svd!` -to compute the orthogonal basis `N`, as controlled by the keyword arguments. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). -When `kind` is provided, its possible values are +### `trunc::NamedTuple` +The supported truncation keyword arguments are: -* `kind == :qr`: `N` is computed using the QR decomposition. - This requires `isnothing(trunc)` and `left_null!(A, [N], kind=:qr)` is equivalent to - `qr_null!(A, [N], alg)` with a default value `alg = select_algorithm(qr_compact!, A; positive=true)` +$(docs_null_truncation_kwargs) -* `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to either the zero singular values if `trunc` - isn't specified or the singular values specified by `trunc`. +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: -When `kind` is not provided, the default value is `:qr` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg_qr` and `alg_svd` keyword arguments, which will only be used by the corresponding -factorization backend. If `alg_qr` or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_qr` defaults to `(; positive=true)` so that by default a positive QR decomposition will -be used. +$(docs_truncation_strategies) !!! note - The bang method `left_null!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `N` as output. + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional columns of `A`. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be QR-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:qr` : Factorize via QR nullspace, with further customizations through the `alg_qr` + keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + N = qr_null(A; alg_qr...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + It is roughly equivalent to: +```julia + U, S, _ = svd_trunc(A; trunc, alg_svd...) + N = truncate(left_null, (U, S), trunc) +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.left_null_kind(alg)`](@ref). + +--- -See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +!!! note + The bang method `left_null!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `N` as output. + +See also [`right_null(!)`](@ref right_null), [`left_orth(!)`](@ref left_orth), +[`right_orth(!)`](@ref right_orth) """ -function left_null end -function left_null! end -function left_null!(A; kwargs...) - return left_null!(A, initialize_output(left_null!, A); kwargs...) -end -function left_null(A; kwargs...) - return left_null!(copy_input(left_null, A); kwargs...) -end +@functiondef left_null + +# helper functions +function left_null_qr! end +function left_null_svd! end """ - right_null(A; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ - right_null!(A, [Nᴴ]; [kind::Symbol, alg_lq, alg_svd]) -> Nᴴ + right_null(A; [trunc], kwargs...) -> Nᴴ + right_null!(A, [Nᴴ]; [trunc], kwargs...) -> Nᴴ -Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel or nullspace of the matrix `A` -of size `(m, n)`, such that `A*adjoint(Nᴴ) ≈ 0` and `Nᴴ*adjoint(Nᴴ) ≈ I`. -The keyword argument `kind` can be used to specify the specific orthogonal decomposition -that should be used to factor `A`, whereas `trunc` can be used to control the -the rank of `A` via its singular values. +Compute an orthonormal basis `N = adjoint(Nᴴ)` for the kernel of the matrix `A`, i.e. the +nullspace of `A`, such that `A * Nᴴ' ≈ 0` and `Nᴴ * Nᴴ' ≈ I`. -`trunc` can either be a truncation strategy object or a NamedTuple with fields -`atol`, `rtol`, and `maxnullity`. +This is a high-level wrapper where the keyword arguments can be used to specify and control +the underlying orthogonal decomposition that should be used to find the null space of `A`, +whereas `trunc` can optionally be used to control the precision in determining the rank of +`A`, typically via its singular values. -This is a high-level wrapper and will use one of the decompositions `lq!` or `svd!` -to compute the orthogonal basis `Nᴴ`, as controlled by the keyword arguments. +## Truncation +The optional truncation strategy can be controlled via the `trunc` keyword argument, and any +non-trivial strategy typically requires an SVD-based decomposition. This keyword can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). -When `kind` is provided, its possible values are +### `trunc::NamedTuple` +The supported truncation keyword arguments are: -* `kind == :lq`: `Nᴴ` is computed using the (nonpositive) LQ decomposition. - This requires `isnothing(trunc)` and `right_null!(A, [Nᴴ], kind=:lq)` is equivalent to - `lq_null!(A, [Nᴴ], alg)` with a default value `alg = select_algorithm(lq_compact!, A; positive=true)` +$(docs_null_truncation_kwargs) -* `kind == :svd`: `N` is computed using the singular value decomposition and will contain - the left singular vectors corresponding to the singular values that - are smaller than `max(atol, rtol * σ₁)`, where `σ₁` is the largest singular value of `A`. +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. By default, +MatrixAlgebraKit supplies the following: -When `kind` is not provided, the default value is `:lq` when `isnothing(trunc)` -and `:svd` otherwise. Finally, finer control is obtained by providing an explicit algorithm -using the `alg_lq` and `alg_svd` keyword arguments, which will only be used by the corresponding -factorization backend. If `alg_lq` or `alg_svd` are NamedTuples, a default algorithm is chosen -with `select_algorithm` and the NamedTuple is passed as keyword arguments to that algorithm. -`alg_lq` defaults to `(; positive=true)` so that by default a positive LQ decomposition will -be used. +$(docs_truncation_strategies) !!! note - The bang method `right_null!` optionally accepts the output structure and possibly destroys - the input matrix `A`. Always use the return value of the function as it may not always be - possible to use the provided `Nᴴ` as output. + Here [`notrunc`](@ref) has special meaning, and signifies keeping the values that + correspond to the exact zeros determined from the additional rows of `A`. + +## Keyword arguments +There are 3 major modes of operation, based on the `alg` keyword, with slightly different +application purposes. + +### `alg::Nothing` +This default mode uses the presence of a truncation strategy `trunc` to determine an optimal +decomposition type, which will be LQ-based for no truncation, or SVD-based for truncation. +The remaining keyword arguments are passed on directly to the algorithm selection procedure +of the chosen decomposition type. + +### `alg::Symbol` +Here, the driving selector is `alg`, and depending on its value, the algorithm selection +procedure takes other keywords into account: + +* `:lq` : Factorize via LQ nullspace, with further customizations through the `alg_lq` + keyword. This mode requires `isnothing(trunc)`, and is equivalent to +```julia + Nᴴ = lq_null(A; alg_qr...) +``` + +* `:svd` : Factorize via SVD, with further customizations through the `alg_svd` keyword. + This mode further allows truncation, which can be selected through the `trunc` argument. + It is roughly equivalent to: +```julia + _, S, Vᴴ = svd_trunc(A; trunc, alg_svd...) + Nᴴ = truncate(right_null, (S, Vᴴ), trunc) +``` + +### `alg::AbstractAlgorithm` +In this expert mode the algorithm is supplied directly, and the kind of decomposition is +deduced from that. This hinges on the implementation of the algorithm trait +[`MatrixAlgebraKit.right_null_kind(alg)`](@ref). + +--- -See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), [`right_orth(!)`](@ref right_orth) +!!! note + The bang method `right_null!` optionally accepts the output structure and possibly + destroys the input matrix `A`. Always use the return value of the function as it may not + always be possible to use the provided `Nᴴ` as output. + +See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth), +[`right_orth(!)`](@ref right_orth) """ -function right_null end -function right_null! end -function right_null!(A; kwargs...) - return right_null!(A, initialize_output(right_null!, A); kwargs...) +@functiondef right_null + +# helper functions +function right_null_lq! end +function right_null_svd! end + +# Algorithm selection +# ------------------- +# specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. +function select_algorithm(::typeof(left_orth!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + left_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :qr && return select_algorithm(left_orth_qr!, A, get(kwargs, :alg_qr, nothing)) + alg === :polar && return select_algorithm(left_orth_polar!, A, get(kwargs, :alg_polar, nothing)) + + throw(ArgumentError(lazy"Unknown alg symbol $alg")) end -function right_null(A; kwargs...) - return right_null!(copy_input(right_null, A); kwargs...) + +default_algorithm(::typeof(left_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : + select_algorithm(left_orth_svd!, A; trunc, kwargs...) +# disambiguate +default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = + isnothing(trunc) ? select_algorithm(left_orth_qr!, A; kwargs...) : + select_algorithm(left_orth_svd!, A; trunc, kwargs...) + + +select_algorithm(::typeof(left_orth_qr!), A, alg = nothing; kwargs...) = + select_algorithm(qr_compact!, A, alg; kwargs...) +select_algorithm(::typeof(left_orth_polar!), A, alg = nothing; kwargs...) = + select_algorithm(left_polar!, A, alg; kwargs...) +select_algorithm(::typeof(left_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + +# specific override for `alg::Symbol` case, to allow for choosing the kind of factorization. +function select_algorithm(::typeof(right_orth!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + right_orth_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :lq && return select_algorithm(right_orth_lq!, A, get(kwargs, :alg_lq, nothing)) + alg === :polar && return select_algorithm(right_orth_polar!, A, get(kwargs, :alg_polar, nothing)) + + throw(ArgumentError(lazy"Unknown alg symbol $alg")) +end + +default_algorithm(::typeof(right_orth!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : + select_algorithm(right_orth_svd!, A; trunc, kwargs...) +# disambiguate: +default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} = + isnothing(trunc) ? select_algorithm(right_orth_lq!, A; kwargs...) : + select_algorithm(right_orth_svd!, A; trunc, kwargs...) + + +select_algorithm(::typeof(right_orth_lq!), A, alg = nothing; kwargs...) = + select_algorithm(lq_compact!, A, alg; kwargs...) +select_algorithm(::typeof(right_orth_polar!), A, alg = nothing; kwargs...) = + select_algorithm(right_polar!, A, alg; kwargs...) +select_algorithm(::typeof(right_orth_svd!), A, alg = nothing; trunc = nothing, kwargs...) = + isnothing(trunc) ? select_algorithm(svd_compact!, A, alg; kwargs...) : + select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) + +function select_algorithm(::typeof(left_null!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + left_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :qr && return select_algorithm(left_null_qr!, A, get(kwargs, :alg_qr, nothing)) + + throw(ArgumentError(lazy"unkown alg symbol $alg")) +end + +default_algorithm(::typeof(left_null!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : + select_algorithm(left_null_svd!, A; trunc, kwargs...) +# disambiguate +default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = + isnothing(trunc) ? select_algorithm(left_null_qr!, A; kwargs...) : + select_algorithm(left_null_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(left_null_qr!), A, alg = nothing; kwargs...) = + select_algorithm(qr_null!, A, alg; kwargs...) +function select_algorithm(::typeof(left_null_svd!), A, alg = nothing; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + end +end + +function select_algorithm(::typeof(right_null!), A, alg::Symbol; trunc = nothing, kwargs...) + alg === :svd && return select_algorithm( + right_null_svd!, A, get(kwargs, :alg_svd, nothing); trunc + ) + + isnothing(trunc) || throw(ArgumentError(lazy"alg ($alg) incompatible with truncation")) + + alg === :lq && return select_algorithm(right_null_lq!, A, get(kwargs, :alg_lq, nothing)) + + throw(ArgumentError(lazy"unkown alg symbol $alg")) +end + +default_algorithm(::typeof(right_null!), A::TA; trunc = nothing, kwargs...) where {TA} = + isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : + select_algorithm(right_null_svd!, A; trunc, kwargs...) +default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} = + isnothing(trunc) ? select_algorithm(right_null_lq!, A; kwargs...) : + select_algorithm(right_null_svd!, A; trunc, kwargs...) + +select_algorithm(::typeof(right_null_lq!), A, alg = nothing; kwargs...) = + select_algorithm(lq_null!, A, alg; kwargs...) + +function select_algorithm(::typeof(right_null_svd!), A, alg; trunc = nothing, kwargs...) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_null_truncation(trunc)) + end + end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 606a1c4e..04e7121c 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -55,22 +55,28 @@ square diagonal matrix of size `(k, k)`, with `k` is set by the truncation strat The function also returns `ϵ`, the truncation error defined as the 2-norm of the discarded singular values. +## Truncation +The truncation strategy can be controlled via the `trunc` keyword argument. This can be +either a `NamedTuple` or a [`TruncationStrategy`](@ref). If `trunc` is not provided or +nothing, all values will be kept. + +### `trunc::NamedTuple` +The supported truncation keyword arguments are: + +$docs_truncation_kwargs + +### `trunc::TruncationStrategy` +For more control, a truncation strategy can be supplied directly. +By default, MatrixAlgebraKit supplies the following: + +$docs_truncation_strategies + ## Keyword arguments -The behavior of this function is controlled by the following keyword arguments: - -- `trunc`: Specifies the truncation strategy. This can be: - - A `NamedTuple` with fields `atol`, `rtol`, and/or `maxrank`, which will be converted to - a [`TruncationStrategy`](@ref). For details on available truncation strategies, see - [Truncations](@ref). - - A `TruncationStrategy` object directly (e.g., `truncrank(10)`, `trunctol(atol=1e-6)`, or - combinations using `&`). - - `nothing` (default), which keeps all singular values. - -- Other keyword arguments are passed to the algorithm selection procedure. If no explicit - `alg` is provided, these keywords are used to select and configure the algorithm through - [`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm - selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) - for the default algorithm selection behavior. +Other keyword arguments are passed to the algorithm selection procedure. If no explicit +`alg` is provided, these keywords are used to select and configure the algorithm through +[`MatrixAlgebraKit.select_algorithm`](@ref). The remaining keywords after algorithm +selection are passed to the algorithm constructor. See [`MatrixAlgebraKit.default_algorithm`](@ref) +for the default algorithm selection behavior. When `alg` is a [`TruncatedAlgorithm`](@ref), the `trunc` keyword cannot be specified as the truncation strategy is already embedded in the algorithm. diff --git a/src/interface/truncation.jl b/src/interface/truncation.jl index db417edb..a1e993de 100644 --- a/src/interface/truncation.jl +++ b/src/interface/truncation.jl @@ -1,3 +1,25 @@ +const docs_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxrank::Real` : Maximal rank for the truncation +* `maxerror::Real` : Maximal truncation error. +* `filter` : Custom filter to select truncated values. +""" + +const docs_truncation_strategies = """ +- [`notrunc`](@ref) +- [`truncrank`](@ref) +- [`trunctol`](@ref) +- [`truncerror`](@ref) +- [`truncfilter`](@ref) +""" + +const docs_null_truncation_kwargs = """ +* `atol::Real` : Absolute tolerance for the truncation +* `rtol::Real` : Relative tolerance for the truncation +* `maxnullity::Real` : Maximal rank for the truncation +""" + """ TruncationStrategy(; kwargs...) @@ -8,11 +30,7 @@ The following keyword arguments are all optional, and their default value (`noth will be ignored. It is also allowed to combine multiple of these, in which case the kept values will consist of the intersection of the different truncated strategies. -- `atol::Real` : Absolute tolerance for the truncation -- `rtol::Real` : Relative tolerance for the truncation -- `maxrank::Real` : Maximal rank for the truncation -- `maxerror::Real` : Maximal truncation error. -- `filter` : Custom filter to select truncated values. +$docs_truncation_kwargs """ function TruncationStrategy(; atol::Union{Real, Nothing} = nothing, @@ -36,6 +54,28 @@ function TruncationStrategy(; return strategy end +""" + null_truncation_strategy(; kwargs...) + +Select a nullspace truncation strategy based on the provided keyword arguments. + +## Keyword arguments +The following keyword arguments are all optional, and their default value (`nothing`) +will be ignored. It is also allowed to combine multiple of these, in which case the +discarded values will consist of the intersection of the different truncated strategies. + +$docs_null_truncation_kwargs +""" +function null_truncation_strategy(; atol = nothing, rtol = nothing, maxnullity = nothing) + if isnothing(maxnullity) && isnothing(atol) && isnothing(rtol) + return notrunc() + end + atol = @something atol 0 + rtol = @something rtol 0 + trunc = trunctol(; atol, rtol, keep_below = true) + return !isnothing(maxnullity) ? trunc & truncrank(maxnullity; rev = false) : trunc +end + """ NoTruncation() diff --git a/test/chainrules.jl b/test/chainrules.jl index 441a17fb..ba3f0681 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -541,18 +541,18 @@ end ) test_rrule( config, left_orth, A; - fkwargs = (; kind = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + fkwargs = (; alg = :qr), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) m >= n && test_rrule( config, left_orth, A; - fkwargs = (; kind = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - ΔN = left_orth(A; kind = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) + ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) test_rrule( config, left_null, A; - fkwargs = (; kind = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, + fkwargs = (; alg = :qr), output_tangent = ΔN, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @@ -561,19 +561,19 @@ end atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) test_rrule( - config, right_orth, A; fkwargs = (; kind = :lq), + config, right_orth, A; fkwargs = (; alg = :lq), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) m <= n && test_rrule( - config, right_orth, A; fkwargs = (; kind = :polar), + config, right_orth, A; fkwargs = (; alg = :polar), atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) - ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; kind = :lq)[2] + ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] test_rrule( config, right_null, A; - fkwargs = (; kind = :lq), output_tangent = ΔNᴴ, + fkwargs = (; alg = :lq), output_tangent = ΔNᴴ, atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) end diff --git a/test/linearmap.jl b/test/linearmap.jl index 61753cde..bbf84e5e 100644 --- a/test/linearmap.jl +++ b/test/linearmap.jl @@ -34,13 +34,6 @@ module LinearMaps LinearMap.(MAK.$f!(parent(A), parent.(F), alg)) end - for f! in (:left_orth!, :right_orth!) - @eval MAK.check_input(::typeof($f!), A::LinearMap, F, alg) = - MAK.check_input($f!, parent(A), parent.(F), alg) - @eval MAK.initialize_output(::typeof($f!), A::LinearMap) = - LinearMap.(MAK.initialize_output($f!, parent(A))) - end - for f in (:qr, :lq, :svd) default_f = Symbol(:default_, f, :_algorithm) @eval MAK.$default_f(::Type{LinearMap{A}}; kwargs...) where {A} = MAK.$default_f(A; kwargs...) diff --git a/test/orthnull.jl b/test/orthnull.jl index cee54252..f12ae91c 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -2,15 +2,12 @@ using MatrixAlgebraKit using Test using TestExtras using StableRNGs -using LinearAlgebra: LinearAlgebra, I, mul! -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, - initialize_output, AbstractAlgorithm +using LinearAlgebra: LinearAlgebra, I # testing non-AbstractArray codepaths: include("linearmap.jl") eltypes = (Float32, Float64, ComplexF32, ComplexF64) - @testset "left_orth and left_null for T = $T" for T in eltypes rng = StableRNG(123) m = 54 @@ -29,7 +26,7 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test V * V' + N * N' ≈ I M = LinearMap(A) - VM, CM = @constinferred left_orth(M; kind = :svd) + VM, CM = @constinferred left_orth(M; alg = :svd) @test parent(VM) * parent(CM) ≈ A if m > n @@ -45,25 +42,33 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test isisometric(N) end - for alg_qr in ((; positive = true), (; positive = false), LAPACK_HouseholderQR()) - V, C = @constinferred left_orth(A; alg_qr) - N = @constinferred left_null(A; alg_qr) - @test V isa Matrix{T} && size(V) == (m, minmn) - @test C isa Matrix{T} && size(C) == (minmn, n) - @test N isa Matrix{T} && size(N) == (m, m - minmn) - @test V * C ≈ A - @test isisometric(V) - @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) - @test isisometric(N) - @test V * V' + N * N' ≈ I - end + # passing a kind and some kwargs + V, C = @constinferred left_orth(A; alg = :qr, alg_qr = (; positive = true)) + N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test V * V' + N * N' ≈ I + + # passing an algorithm + V, C = @constinferred left_orth(A; alg = LAPACK_HouseholderQR()) + N = @constinferred left_null(A; alg = :qr, alg_qr = (; positive = true)) + @test V isa Matrix{T} && size(V) == (m, minmn) + @test C isa Matrix{T} && size(C) == (minmn, n) + @test N isa Matrix{T} && size(N) == (m, m - minmn) + @test V * C ≈ A + @test isisometric(V) + @test LinearAlgebra.norm(A' * N) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) + @test isisometric(N) + @test V * V' + N * N' ≈ I Ac = similar(A) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C)) N2 = @constinferred left_null!(copy!(Ac, A), N) - @test V2 === V - @test C2 === C - @test N2 === N @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -73,9 +78,6 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) atol = eps(real(T)) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = (; atol = atol)) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = (; atol = atol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -89,9 +91,6 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) ) V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); trunc = trunc_orth) N2 = @constinferred left_null!(copy!(Ac, A), N; trunc = trunc_null) - @test V2 !== V - @test C2 !== C - @test N2 !== C @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -99,49 +98,40 @@ eltypes = (Float32, Float64, ComplexF32, ComplexF64) @test V2 * V2' + N2 * N2' ≈ I end - for kind in (:qr, :polar, :svd) # explicit kind kwarg - m < n && kind == :polar && continue - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind = kind) - @test V2 === V - @test C2 === C + for alg in (:qr, :polar, :svd) # explicit kind kwarg + m < n && alg === :polar && continue + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg) @test V2 * C2 ≈ A @test isisometric(V2) - if kind != :polar - N2 = @constinferred left_null!(copy!(Ac, A), N; kind = kind) - @test N2 === N + if alg != :polar + N2 = @constinferred left_null!(copy!(Ac, A), N; alg) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) @test V2 * V2' + N2 * N2' ≈ I end # with kind and tol kwargs - if kind == :svd - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; atol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind, trunc = (; atol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C + if alg == :svd + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) @test V2 * C2 ≈ A @test V2' * V2 ≈ I @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test N2' * N2 ≈ I @test V2 * V2' + N2 * N2' ≈ I - V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; rtol)) - N2 = @constinferred left_null!(copy!(Ac, A), N; kind, trunc = (; rtol)) - @test V2 !== V - @test C2 !== C - @test N2 !== C + V2, C2 = @constinferred left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + N2 = @constinferred left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) @test V2 * C2 ≈ A @test isisometric(V2) @test LinearAlgebra.norm(A' * N2) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(N2) @test V2 * V2' + N2 * N2' ≈ I else - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; atol)) - @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); kind, trunc = (; rtol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind, trunc = (; atol)) - @test_throws ArgumentError left_null!(copy!(Ac, A), N; kind, trunc = (; rtol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; atol)) + @test_throws ArgumentError left_orth!(copy!(Ac, A), (V, C); alg, trunc = (; rtol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; atol)) + @test_throws ArgumentError left_null!(copy!(Ac, A), N; alg, trunc = (; rtol)) end end end @@ -165,15 +155,12 @@ end @test Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I M = LinearMap(A) - CM, VMᴴ = @constinferred right_orth(M; kind = :svd) + CM, VMᴴ = @constinferred right_orth(M; alg = :svd) @test parent(CM) * parent(VMᴴ) ≈ A Ac = similar(A) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ) - @test C2 === C - @test Vᴴ2 === Vᴴ - @test Nᴴ2 === Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -183,9 +170,6 @@ end atol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; atol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @@ -195,57 +179,45 @@ end rtol = eps(real(T)) C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); trunc = (; rtol)) Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; trunc = (; rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - for kind in (:lq, :polar, :svd) - n < m && kind == :polar && continue - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind) - @test C2 === C - @test Vᴴ2 === Vᴴ + for alg in (:lq, :polar, :svd) + n < m && alg == :polar && continue + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) - if kind != :polar - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind) - @test Nᴴ2 === Nᴴ + if alg != :polar + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I end - if kind == :svd - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; atol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; atol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + if alg == :svd + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I - C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; rtol)) - Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; rtol)) - @test C2 !== C - @test Vᴴ2 !== Vᴴ - @test Nᴴ2 !== Nᴴ + C2, Vᴴ2 = @constinferred right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + Nᴴ2 = @constinferred right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) @test C2 * Vᴴ2 ≈ A @test isisometric(Vᴴ2; side = :right) @test LinearAlgebra.norm(A * adjoint(Nᴴ2)) ≈ 0 atol = MatrixAlgebraKit.defaulttol(T) @test isisometric(Nᴴ2; side = :right) @test Vᴴ2' * Vᴴ2 + Nᴴ2' * Nᴴ2 ≈ I else - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; atol)) - @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); kind, trunc = (; rtol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; atol)) - @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; kind, trunc = (; rtol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; atol)) + @test_throws ArgumentError right_orth!(copy!(Ac, A), (C, Vᴴ); alg, trunc = (; rtol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; atol)) + @test_throws ArgumentError right_null!(copy!(Ac, A), Nᴴ; alg, trunc = (; rtol)) end end end