Skip to content
8 changes: 4 additions & 4 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ end

function MatrixAlgebraKit.householder_qr!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down Expand Up @@ -102,9 +102,9 @@ end

function MatrixAlgebraKit.householder_qr_null!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down
2 changes: 1 addition & 1 deletion src/MatrixAlgebraKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export left_polar!, right_polar!
export left_orth, right_orth, left_null, right_null
export left_orth!, right_orth!, left_null!, right_null!

export Native_HouseholderQR, Native_HouseholderLQ
export Householder, Native_HouseholderQR, Native_HouseholderLQ
export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert,
LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations,
LAPACK_DivideAndConquer, LAPACK_Jacobi
Expand Down
28 changes: 20 additions & 8 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,20 @@ See also [`@algdef`](@ref).
"""
struct Algorithm{name, K} <: AbstractAlgorithm
kwargs::K

# Ensure keywords are always in canonical order
function Algorithm{Name}(kwargs::NamedTuple) where {Name}
kwargs_sorted = _canonicalize_namedtuple(kwargs)
return new{Name, typeof(kwargs_sorted)}(kwargs_sorted)
end
end
Algorithm{Name}(; kwargs...) where {Name} = Algorithm{Name}(NamedTuple(kwargs))

# Utility function to canonicalize keys
# TODO: generated function can likely be dropped once Julia 1.10 support is dropped
@generated _canonicalize_namedtuple(nt::NamedTuple{N}) where {N} =
:(NamedTuple{$(Tuple(sort(collect(N))))}(nt))

name(alg::Algorithm) = name(typeof(alg))
name(::Type{<:Algorithm{N}}) where {N} = N

Expand Down Expand Up @@ -88,7 +101,9 @@ Finally, the same behavior is obtained when the keyword arguments are
passed as the third positional argument in the form of a `NamedTuple`.
""" select_algorithm

@inline function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
# WARNING: In order to keep everything type stable, this function is marked as foldable.
# This mostly means that the `default_algorithm` implementation must be foldable as well
Base.@assume_effects :foldable function select_algorithm(f::F, A, alg::Alg = nothing; kwargs...) where {F, Alg}
if isnothing(alg)
return default_algorithm(f, A; kwargs...)
elseif alg isa Symbol
Expand Down Expand Up @@ -117,8 +132,10 @@ In general, this is called by [`select_algorithm`](@ref) if no algorithm is spec
explicitly.
New types should prefer to register their default algorithms in the type domain.
""" default_algorithm
default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...)
default_algorithm(f::F, A, B; kwargs...) where {F} = default_algorithm(f, typeof(A), typeof(B); kwargs...)
@inline default_algorithm(f::F, A; kwargs...) where {F} =
default_algorithm(f, typeof(A); kwargs...)
@inline default_algorithm(f::F, A, B; kwargs...) where {F} =
default_algorithm(f, typeof(A), typeof(B); kwargs...)
# avoid infinite recursion:
function default_algorithm(f::F, ::Type{T}; kwargs...) where {F, T}
throw(MethodError(default_algorithm, (f, T)))
Expand Down Expand Up @@ -299,11 +316,6 @@ macro algdef(name)
return esc(
quote
const $name{K} = Algorithm{$(QuoteNode(name)), K}
function $name(; kwargs...)
# TODO: is this necessary/useful?
kw = NamedTuple(kwargs) # normalize type
return $name{typeof(kw)}(kw)
end
function Base.show(io::IO, alg::$name)
return ($_show_alg)(io, alg)
end
Expand Down
17 changes: 10 additions & 7 deletions src/implementations/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,10 @@ householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) =
lq_via_qr!(A, L, Q, Householder(; driver, kwargs...))
function householder_lq!(
driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive = true, pivoted = false,
blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
positive = true, pivoted = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
pivoted && (blocksize > 1) &&
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
Expand Down Expand Up @@ -176,10 +177,10 @@ function householder_lq!(
end
function householder_lq!(
driver::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
Expand Down Expand Up @@ -225,8 +226,10 @@ householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...
lq_null_via_qr!(A, Nᴴ, Householder(; driver, kwargs...))
function householder_lq_null!(
driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = pivoted ? 1 : YALAPACK.default_qr_blocksize(A)
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : (pivoted ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
pivoted && (blocksize > 1) &&
throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition"))
Expand All @@ -248,10 +251,10 @@ function householder_lq_null!(
end
function householder_lq_null!(
driver::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked LQ decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted LQ decomposition"))
Expand Down
16 changes: 9 additions & 7 deletions src/implementations/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ householder_qr!(::DefaultDriver, A, Q, R; kwargs...) =
function householder_qr!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false,
blocksize::Int = ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))
blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A))

# error messages for disallowing driver - setting combinations
(blocksize == 1 || driver === LAPACK()) ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
Expand Down Expand Up @@ -202,10 +204,10 @@ function householder_qr!(
end
function householder_qr!(
driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down Expand Up @@ -249,9 +251,9 @@ householder_qr_null!(::DefaultDriver, A, N; kwargs...) =
householder_qr_null!(default_householder_driver(A), A, N; kwargs...)
function householder_qr_null!(
driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false,
blocksize::Int = ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
blocksize = blocksize > 0 ? blocksize : ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A))
# error messages for disallowing driver - setting combinations
(blocksize == 1 || driver === LAPACK()) ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
Expand All @@ -277,10 +279,10 @@ function householder_qr_null!(
end
function householder_qr_null!(
driver::Native, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 0
)
# error messages for disallowing driver - setting combinations
blocksize == 1 ||
blocksize <= 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))
Expand Down
8 changes: 7 additions & 1 deletion src/interface/decompositions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,17 @@ The optional `driver` symbol can be used to choose between different implementat

- `positive::Bool = true` : Fix the gauge of the resulting factors by making the diagonal elements of `L` or `R` non-negative.
- `pivoted::Bool = false` : Use column- or row-pivoting for low-rank input matrices.
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`.
- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. Use the default if `blocksize ≤ 0`.

Depending on the driver, various other keywords may be (un)available to customize the implementation.
"""
@algdef Householder
function Householder(;
blocksize::Int = 0, driver::Driver = DefaultDriver(),
pivoted::Bool = false, positive::Bool = true
)
return Householder((; blocksize, driver, pivoted, positive))
end

default_householder_driver(A) = default_householder_driver(typeof(A))
default_householder_driver(::Type) = Native()
Expand Down
4 changes: 2 additions & 2 deletions src/interface/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact).
default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...)

default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,)))
default_lq_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
Householder(; driver, kwargs...)
default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
Householder(; kwargs...)
default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
DiagonalAlgorithm(; kwargs...)
default_lq_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
Expand Down
28 changes: 14 additions & 14 deletions src/interface/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,79 +334,79 @@ See also [`left_null(!)`](@ref left_null), [`left_orth(!)`](@ref left_orth) and
@inline select_algorithm(::typeof(right_null!), A, alg::Symbol; kwargs...) =
select_algorithm(right_null!, A, Val(alg); kwargs...)

function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:qr}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("QR-based `left_orth` is incompatible with specifying `trunc`"))
alg′ = select_algorithm(qr_compact!, A; kwargs...)
return LeftOrthViaQR(alg′)
end
function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("Polar-based `left_orth` is incompatible with specifying `trunc`"))
alg′ = select_algorithm(left_polar!, A; kwargs...)
return LeftOrthViaPolar(alg′)
end
function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(left_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...)
alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) :
select_algorithm(svd_trunc!, A; trunc, kwargs...)
return LeftOrthViaSVD(alg′)
end

function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:lq}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("LQ-based `right_orth` is incompatible with specifying `trunc`"))
alg = select_algorithm(lq_compact!, A; kwargs...)
return RightOrthViaLQ(alg)
end
function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:polar}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("Polar-based `right_orth` is incompatible with specifying `trunc`"))
alg = select_algorithm(right_polar!, A; kwargs...)
return RightOrthViaPolar(alg)
end
function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(right_orth!), A, ::Val{:svd}; trunc = nothing, kwargs...)
alg′ = isnothing(trunc) ? select_algorithm(svd_compact!, A; kwargs...) :
select_algorithm(svd_trunc!, A; trunc, kwargs...)
return RightOrthViaSVD(alg′)
end

function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(left_null!), A, ::Val{:qr}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("QR-based `left_null` is incompatible with specifying `trunc`"))
alg = select_algorithm(qr_null!, A; kwargs...)
return LeftNullViaQR(alg)
end
function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(left_null!), A, ::Val{:svd}; trunc = nothing, kwargs...)
alg_svd = select_algorithm(svd_full!, A, get(kwargs, :svd, nothing))
alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc))
return LeftNullViaSVD(alg)
end

function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(right_null!), A, ::Val{:lq}; trunc = nothing, kwargs...)
isnothing(trunc) ||
throw(ArgumentError("LQ-based `right_null` is incompatible with specifying `trunc`"))
alg = select_algorithm(lq_null!, A; kwargs...)
return RightNullViaLQ(alg)
end
function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...)
@inline function select_algorithm(::typeof(right_null!), A, ::Val{:svd}; trunc = nothing, kwargs...)
alg_svd = select_algorithm(svd_full!, A; kwargs...)
alg = TruncatedAlgorithm(alg_svd, select_null_truncation(trunc))
return RightNullViaSVD(alg)
end

default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
@inline default_algorithm(::typeof(left_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
isnothing(trunc) ? select_algorithm(left_orth!, A, Val(:qr); kwargs...) :
select_algorithm(left_orth!, A, Val(:svd); trunc, kwargs...)

default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
@inline default_algorithm(::typeof(right_orth!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
isnothing(trunc) ? select_algorithm(right_orth!, A, Val(:lq); kwargs...) :
select_algorithm(right_orth!, A, Val(:svd); trunc, kwargs...)

default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
@inline default_algorithm(::typeof(left_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
isnothing(trunc) ? select_algorithm(left_null!, A, Val(:qr); kwargs...) :
select_algorithm(left_null!, A, Val(:svd); trunc, kwargs...)

default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
@inline default_algorithm(::typeof(right_null!), ::Type{A}; trunc = nothing, kwargs...) where {A} =
isnothing(trunc) ? select_algorithm(right_null!, A, Val(:lq); kwargs...) :
select_algorithm(right_null!, A, Val(:svd); trunc, kwargs...)

Expand Down
4 changes: 2 additions & 2 deletions src/interface/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact).
default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...)

default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,)))
default_qr_algorithm(::Type{T}; driver = default_householder_driver(T), kwargs...) where {T <: AbstractMatrix} =
Householder(; driver, kwargs...)
default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} =
Householder(; kwargs...)
default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} =
DiagonalAlgorithm(; kwargs...)
default_qr_algorithm(::Type{<:Base.ReshapedArray{T, N, A}}) where {T, N, A} =
Expand Down
8 changes: 4 additions & 4 deletions test/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixAlgebraKit
using Test
using TestExtras
using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
default_algorithm, select_algorithm, Householder, LAPACK
default_algorithm, select_algorithm, Householder, DefaultDriver

@testset "default_algorithm" begin
A = randn(3, 3)
Expand All @@ -17,21 +17,21 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm,
LAPACK_MultipleRelativelyRobustRepresentations()
end
for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null)
@test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK())
@test @constinferred(default_algorithm(f, A)) == Householder()
end
for f in (left_polar!, left_polar, right_polar!, right_polar)
@test @constinferred(default_algorithm(f, A)) ==
PolarViaSVD(LAPACK_DivideAndConquer())
end
for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null)
@test @constinferred(default_algorithm(f, A)) == Householder(; driver = LAPACK())
@test @constinferred(default_algorithm(f, A)) == Householder()
end
for f in (schur_full!, schur_full, schur_vals!, schur_vals)
@test @constinferred(default_algorithm(f, A)) === LAPACK_Expert()
end

@test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) ==
Householder(; driver = LAPACK(), blocksize = 2)
Householder(; blocksize = 2)
end

@testset "select_algorithm" begin
Expand Down
Loading
Loading