From eb913cb9db6dec5af611ce44c32765091ac81ced Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Thu, 6 Nov 2025 11:15:34 +0100 Subject: [PATCH 01/20] implement exponential work in progress --- src/MatrixAlgebraKit.jl | 4 + src/implementations/exp.jl | 48 +++++++ src/interface/decompositions.jl | 41 ++++++ src/interface/exp.jl | 18 +++ test/exp.jl | 28 ++++ test/runtests.jl | 237 ++++++++++++++++---------------- 6 files changed, 259 insertions(+), 117 deletions(-) create mode 100644 src/implementations/exp.jl create mode 100644 src/interface/exp.jl create mode 100644 test/exp.jl diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 4a846f85..cff94b8b 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -29,6 +29,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! +export exp, exp! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, @@ -36,6 +37,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton +export LA_exponential, ExponentialViaEig, ExponentialViaEigh export DiagonalAlgorithm export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, @@ -92,6 +94,7 @@ include("interface/gen_eig.jl") include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") +include("interface/exp.jl") include("implementations/projections.jl") include("implementations/truncation.jl") @@ -104,6 +107,7 @@ include("implementations/gen_eig.jl") include("implementations/schur.jl") include("implementations/polar.jl") include("implementations/orthnull.jl") +include("implementations/exp.jl") include("pullbacks/qr.jl") include("pullbacks/lq.jl") diff --git a/src/implementations/exp.jl b/src/implementations/exp.jl new file mode 100644 index 00000000..0882aca0 --- /dev/null +++ b/src/implementations/exp.jl @@ -0,0 +1,48 @@ +# Inputs +# ------ +function copy_input(::typeof(exp), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exp), A::Diagonal) = copy(A) + +function check_input(::typeof(exp!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected")) + @assert expA isa AbstractMatrix + @check_size(expA, (m, m)) + @check_scalar(expA, A) +end + +# Outputs +# ------- +function initialize_output(::typeof(exp!), A::AbstractMatrix, ::AbstractAlgorithm) + n = size(A, 1) # square check will happen later + expA = similar(A, (n, n)) + return expA +end + +# Implementation +# -------------- +function exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + copyto!(expA, LinearAlgebra.exp!(A)) + return A +end + +function MatrixAlgebraKit.exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) + D, V = eigh_full(A, alg.eigh_alg) + return V * Diagonal(exp.(diagview(D))) * inv(V) +end + +function MatrixAlgebraKit.exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig) + D, V = eig_full(A, alg.eig_alg) + return V * Diagonal(exp.(diagview(D))) * inv(V) +end + +# Diagonal logic +# -------------- +function exp!(A::Diagonal, expA, alg::DiagonalAlgorithm) + check_input(exp!, A, expA, alg) + copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A)))) + return expA +end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 1bdf1534..3b3418c5 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -319,6 +319,47 @@ Divide and Conquer algorithm. const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} +# ================================ +# EXPONENTIAL ALGORITHMS +# ================================ +""" + LA_exponential() + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +@algdef LA_exponential + +""" + ExponentialViaEigh() + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +struct ExponentialViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm + eigh_alg::A +end +function Base.show(io::IO, alg::ExponentialViaEigh) + print(io, "ExponentialViaEigh(") + _show_alg(io, alg.eigh_alg) + return print(io, ")") +end + +""" + ExponentialViaEig() + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +struct ExponentialViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm + eig_alg::A +end +function Base.show(io::IO, alg::ExponentialViaEig) + print(io, "ExponentialViaEigh(") + _show_alg(io, alg.eig_alg) + return print(io, ")") +end + # Various consts and unions # ------------------------- diff --git a/src/interface/exp.jl b/src/interface/exp.jl new file mode 100644 index 00000000..c2ffc34a --- /dev/null +++ b/src/interface/exp.jl @@ -0,0 +1,18 @@ +# Exponetial functions +# -------------- + +# TODO: docs +@functiondef exp + +# Algorithm selection +# ------------------- +default_exp_algorithm(A; kwargs...) = default_exp_algorithm(typeof(A); kwargs...) +function default_exp_algorithm(T::Type; kwargs...) + return LA_exponential(; kwargs...) +end + +for f in (:exp!,) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_exp_algorithm(A; kwargs...) + end +end diff --git a/test/exp.jl b/test/exp.jl new file mode 100644 index 00000000..11237844 --- /dev/null +++ b/test/exp.jl @@ -0,0 +1,28 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +@testset "exp! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + + D, V = @constinferred eigh_full(A) + + expA = @constinferred exp(A) + Dexp, Vexp = @constinferred eigh_full(expA) + + println("A = ", A) + println("exp(A) = ", expA) + + println("LHS = ", diagview(Dexp)) + println("RHS = ", LinearAlgebra.exp.(diagview(D))) + @assert diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ec255538..17ef3588 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,125 +2,128 @@ using SafeTestsets # don't run all tests on GPU, only the GPU # specific ones -is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -if !is_buildkite - @safetestset "Algorithms" begin - include("algorithms.jl") - end - @safetestset "Projections" begin - include("projections.jl") - end - @safetestset "Truncate" begin - include("truncate.jl") - end - @safetestset "QR / LQ Decomposition" begin - include("qr.jl") - include("lq.jl") - end - @safetestset "Singular Value Decomposition" begin - include("svd.jl") - end - @safetestset "Hermitian Eigenvalue Decomposition" begin - include("eigh.jl") - end - @safetestset "General Eigenvalue Decomposition" begin - include("eig.jl") - end - @safetestset "Generalized Eigenvalue Decomposition" begin - include("gen_eig.jl") - end - @safetestset "Schur Decomposition" begin - include("schur.jl") - end - @safetestset "Polar Decomposition" begin - include("polar.jl") - end - @safetestset "Image and Null Space" begin - include("orthnull.jl") - end - @safetestset "ChainRules" begin - include("chainrules.jl") - end - @safetestset "MatrixAlgebraKit.jl" begin - @safetestset "Code quality (Aqua.jl)" begin - using MatrixAlgebraKit - using Aqua - Aqua.test_all(MatrixAlgebraKit) - end - @safetestset "Code linting (JET.jl)" begin - using MatrixAlgebraKit - using JET - JET.test_package(MatrixAlgebraKit; target_defined_modules = true) - end - end +# is_buildkite = get(ENV, "BUILDKITE", "false") == "true" +# if !is_buildkite +# @safetestset "Algorithms" begin +# include("algorithms.jl") +# end +# @safetestset "Projections" begin +# include("projections.jl") +# end +# @safetestset "Truncate" begin +# include("truncate.jl") +# end +# @safetestset "QR / LQ Decomposition" begin +# include("qr.jl") +# include("lq.jl") +# end +# @safetestset "Singular Value Decomposition" begin +# include("svd.jl") +# end +# @safetestset "Hermitian Eigenvalue Decomposition" begin +# include("eigh.jl") +# end +# @safetestset "General Eigenvalue Decomposition" begin +# include("eig.jl") +# end +# @safetestset "Generalized Eigenvalue Decomposition" begin +# include("gen_eig.jl") +# end +# @safetestset "Schur Decomposition" begin +# include("schur.jl") +# end +# @safetestset "Polar Decomposition" begin +# include("polar.jl") +# end +# @safetestset "Image and Null Space" begin +# include("orthnull.jl") +# end +# @safetestset "ChainRules" begin +# include("chainrules.jl") +# end +# @safetestset "MatrixAlgebraKit.jl" begin +# @safetestset "Code quality (Aqua.jl)" begin +# using MatrixAlgebraKit +# using Aqua +# Aqua.test_all(MatrixAlgebraKit) +# end +# @safetestset "Code linting (JET.jl)" begin +# using MatrixAlgebraKit +# using JET +# JET.test_package(MatrixAlgebraKit; target_defined_modules = true) +# end +# end +# end +@safetestset "Exponential" begin + include("exp.jl") end -using CUDA -if CUDA.functional() - @safetestset "CUDA QR" begin - include("cuda/qr.jl") - end - @safetestset "CUDA LQ" begin - include("cuda/lq.jl") - end - @safetestset "CUDA Projections" begin - include("cuda/projections.jl") - end - @safetestset "CUDA SVD" begin - include("cuda/svd.jl") - end - @safetestset "CUDA General Eigenvalue Decomposition" begin - include("cuda/eig.jl") - end - @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin - include("cuda/eigh.jl") - end - @safetestset "CUDA Polar Decomposition" begin - include("cuda/polar.jl") - end - @safetestset "CUDA Image and Null Space" begin - include("cuda/orthnull.jl") - end -end +# using CUDA +# if CUDA.functional() +# @safetestset "CUDA QR" begin +# include("cuda/qr.jl") +# end +# @safetestset "CUDA LQ" begin +# include("cuda/lq.jl") +# end +# @safetestset "CUDA Projections" begin +# include("cuda/projections.jl") +# end +# @safetestset "CUDA SVD" begin +# include("cuda/svd.jl") +# end +# @safetestset "CUDA General Eigenvalue Decomposition" begin +# include("cuda/eig.jl") +# end +# @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin +# include("cuda/eigh.jl") +# end +# @safetestset "CUDA Polar Decomposition" begin +# include("cuda/polar.jl") +# end +# @safetestset "CUDA Image and Null Space" begin +# include("cuda/orthnull.jl") +# end +# end -using AMDGPU -if AMDGPU.functional() - @safetestset "AMDGPU QR" begin - include("amd/qr.jl") - end - @safetestset "AMDGPU LQ" begin - include("amd/lq.jl") - end - @safetestset "AMDGPU Projections" begin - include("amd/projections.jl") - end - @safetestset "AMDGPU SVD" begin - include("amd/svd.jl") - end - @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin - include("amd/eigh.jl") - end - @safetestset "AMDGPU Polar Decomposition" begin - include("amd/polar.jl") - end - @safetestset "AMDGPU Image and Null Space" begin - include("amd/orthnull.jl") - end -end +# using AMDGPU +# if AMDGPU.functional() +# @safetestset "AMDGPU QR" begin +# include("amd/qr.jl") +# end +# @safetestset "AMDGPU LQ" begin +# include("amd/lq.jl") +# end +# @safetestset "AMDGPU Projections" begin +# include("amd/projections.jl") +# end +# @safetestset "AMDGPU SVD" begin +# include("amd/svd.jl") +# end +# @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin +# include("amd/eigh.jl") +# end +# @safetestset "AMDGPU Polar Decomposition" begin +# include("amd/polar.jl") +# end +# @safetestset "AMDGPU Image and Null Space" begin +# include("amd/orthnull.jl") +# end +# end -using GenericLinearAlgebra -@safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") -end -@safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") -end +# using GenericLinearAlgebra +# @safetestset "QR / LQ Decomposition" begin +# include("genericlinearalgebra/qr.jl") +# include("genericlinearalgebra/lq.jl") +# end +# @safetestset "Singular Value Decomposition" begin +# include("genericlinearalgebra/svd.jl") +# end +# @safetestset "Hermitian Eigenvalue Decomposition" begin +# include("genericlinearalgebra/eigh.jl") +# end -using GenericSchur -@safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") -end +# using GenericSchur +# @safetestset "General Eigenvalue Decomposition" begin +# include("genericschur/eig.jl") +# end From a3dc04d89ffca5aa2730774ef385e42aa978117a Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 12 Nov 2025 15:07:23 +0100 Subject: [PATCH 02/20] update on exponential --- ext/MatrixAlgebraKitGenericSchurExt.jl | 4 + src/MatrixAlgebraKit.jl | 9 +- src/implementations/exp.jl | 48 ----- src/implementations/exponential.jl | 50 +++++ src/interface/decompositions.jl | 4 +- src/interface/exp.jl | 18 -- src/interface/exponential.jl | 17 ++ test/exp.jl | 28 --- test/exponential.jl | 48 +++++ test/genericlinearalgebra/exponential.jl | 28 +++ test/genericschur/exponential.jl | 29 +++ test/runtests.jl | 246 ++++++++++++----------- 12 files changed, 309 insertions(+), 220 deletions(-) delete mode 100644 src/implementations/exp.jl create mode 100644 src/implementations/exponential.jl delete mode 100644 src/interface/exp.jl create mode 100644 src/interface/exponential.jl delete mode 100644 test/exp.jl create mode 100644 test/exponential.jl create mode 100644 test/genericlinearalgebra/exponential.jl create mode 100644 test/genericschur/exponential.jl diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index d278b5c5..18ec4922 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -22,4 +22,8 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) return GenericSchur.eigvals!(A) end +function MatrixAlgebraKit.default_exponential_algorithm(E::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} + return ExponentialViaEig(GS_QRIteration(; kwargs...)) +end + end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index cff94b8b..a59ff4ab 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -8,6 +8,7 @@ using LinearAlgebra: isposdef, issymmetric using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt +# import LinearAlgebra: exp, exp! export isisometric, isunitary, ishermitian, isantihermitian @@ -29,7 +30,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! -export exp, exp! +export exponential, exponential! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, @@ -37,7 +38,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton -export LA_exponential, ExponentialViaEig, ExponentialViaEigh +export ExponentialViaLA, ExponentialViaEig, ExponentialViaEigh export DiagonalAlgorithm export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, @@ -94,7 +95,7 @@ include("interface/gen_eig.jl") include("interface/schur.jl") include("interface/polar.jl") include("interface/orthnull.jl") -include("interface/exp.jl") +include("interface/exponential.jl") include("implementations/projections.jl") include("implementations/truncation.jl") @@ -107,7 +108,7 @@ include("implementations/gen_eig.jl") include("implementations/schur.jl") include("implementations/polar.jl") include("implementations/orthnull.jl") -include("implementations/exp.jl") +include("implementations/exponential.jl") include("pullbacks/qr.jl") include("pullbacks/lq.jl") diff --git a/src/implementations/exp.jl b/src/implementations/exp.jl deleted file mode 100644 index 0882aca0..00000000 --- a/src/implementations/exp.jl +++ /dev/null @@ -1,48 +0,0 @@ -# Inputs -# ------ -function copy_input(::typeof(exp), A::AbstractMatrix) - return copy!(similar(A, float(eltype(A))), A) -end - -copy_input(::typeof(exp), A::Diagonal) = copy(A) - -function check_input(::typeof(exp!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) - m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) - @assert expA isa AbstractMatrix - @check_size(expA, (m, m)) - @check_scalar(expA, A) -end - -# Outputs -# ------- -function initialize_output(::typeof(exp!), A::AbstractMatrix, ::AbstractAlgorithm) - n = size(A, 1) # square check will happen later - expA = similar(A, (n, n)) - return expA -end - -# Implementation -# -------------- -function exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) - copyto!(expA, LinearAlgebra.exp!(A)) - return A -end - -function MatrixAlgebraKit.exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) - D, V = eigh_full(A, alg.eigh_alg) - return V * Diagonal(exp.(diagview(D))) * inv(V) -end - -function MatrixAlgebraKit.exp!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig) - D, V = eig_full(A, alg.eig_alg) - return V * Diagonal(exp.(diagview(D))) * inv(V) -end - -# Diagonal logic -# -------------- -function exp!(A::Diagonal, expA, alg::DiagonalAlgorithm) - check_input(exp!, A, expA, alg) - copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A)))) - return expA -end diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl new file mode 100644 index 00000000..3a9b6f55 --- /dev/null +++ b/src/implementations/exponential.jl @@ -0,0 +1,50 @@ +# Inputs +# ------ +function copy_input(::typeof(exponential), A::AbstractMatrix) + return copy!(similar(A, float(eltype(A))), A) +end + +copy_input(::typeof(exponential), A::Diagonal) = copy(A) + +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected")) + @assert expA isa AbstractMatrix + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + +# Outputs +# ------- +function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) + n = size(A, 1) # square check will happen later + expA = similar(A, (n, n)) + return expA +end + +# Implementation +# -------------- +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA) + copyto!(expA, LinearAlgebra.exp(A)) + return expA +end + +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) + D, V = eigh_full(A, alg.eigh_alg) + copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) + return expA +end + +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig) + D, V = eig_full(A, alg.eig_alg) + copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) + return expA +end + +# Diagonal logic +# -------------- +function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) + check_input(exponential!, A, expA, alg) + copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A)))) + return expA +end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 3b3418c5..b35b2a76 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -323,12 +323,12 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} # EXPONENTIAL ALGORITHMS # ================================ """ - LA_exponential() + ExponentialViaLA() Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. The `qr_alg` specifies which QR-decomposition implementation to use. """ -@algdef LA_exponential +@algdef ExponentialViaLA """ ExponentialViaEigh() diff --git a/src/interface/exp.jl b/src/interface/exp.jl deleted file mode 100644 index c2ffc34a..00000000 --- a/src/interface/exp.jl +++ /dev/null @@ -1,18 +0,0 @@ -# Exponetial functions -# -------------- - -# TODO: docs -@functiondef exp - -# Algorithm selection -# ------------------- -default_exp_algorithm(A; kwargs...) = default_exp_algorithm(typeof(A); kwargs...) -function default_exp_algorithm(T::Type; kwargs...) - return LA_exponential(; kwargs...) -end - -for f in (:exp!,) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_exp_algorithm(A; kwargs...) - end -end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl new file mode 100644 index 00000000..2913bda4 --- /dev/null +++ b/src/interface/exponential.jl @@ -0,0 +1,17 @@ +# Exponential functions +# -------------- +@functiondef exponential +# @algdef exponential! + +# Algorithm selection +# ------------------- +default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) +function default_exponential_algorithm(T::Type; kwargs...) + return ExponentialViaLA(; kwargs...) +end + +for f in (:exponential!,) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_exponential_algorithm(A; kwargs...) + end +end diff --git a/test/exp.jl b/test/exp.jl deleted file mode 100644 index 11237844..00000000 --- a/test/exp.jl +++ /dev/null @@ -1,28 +0,0 @@ -using MatrixAlgebraKit -using Test -using TestExtras -using StableRNGs -using MatrixAlgebraKit: diagview -using LinearAlgebra - -BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -@testset "exp! for T = $T" for T in BLASFloats - rng = StableRNG(123) - m = 2 - - A = randn(rng, T, m, m) - A = (A + A') / 2 - - D, V = @constinferred eigh_full(A) - - expA = @constinferred exp(A) - Dexp, Vexp = @constinferred eigh_full(expA) - - println("A = ", A) - println("exp(A) = ", expA) - - println("LHS = ", diagview(Dexp)) - println("RHS = ", LinearAlgebra.exp.(diagview(D))) - @assert diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) - -end \ No newline at end of file diff --git a/test/exponential.jl b/test/exponential.jl new file mode 100644 index 00000000..70b21965 --- /dev/null +++ b/test/exponential.jl @@ -0,0 +1,48 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) +GenericFloats = (Float16, BigFloat, Complex{BigFloat}) + +@testset "exp! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + D, V = @constinferred eigh_full(A) + algs = (ExponentialViaLA(), ExponentialViaEig(LAPACK_Simple()), ExponentialViaEigh(LAPACK_QRIteration())) + expA_LA = @constinferred exp(A) + @testset "algorithm $alg" for alg in algs + expA = similar(A) + + @constinferred exponential!(copy(A), expA) + expA2 = @constinferred exponential(A; alg = alg) + @test expA ≈ expA_LA + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eigh_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end + +@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + atol = sqrt(eps(real(T))) + m = 54 + Ad = randn(T, m) + A = Diagonal(Ad) + + expA = similar(A) + @constinferred exponential!(copy(A), expA) + expA2 = @constinferred exponential(A; alg = DiagonalAlgorithm()) + @test expA2 ≈ expA + + D, V = @constinferred eig_full(A) + Dexp, Vexp = @constinferred eig_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) +end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl new file mode 100644 index 00000000..1872b825 --- /dev/null +++ b/test/genericlinearalgebra/exponential.jl @@ -0,0 +1,28 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exp! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + D, V = @constinferred eigh_full(A) + algs = (ExponentialViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expA = similar(A) + + @constinferred exponential!(copy(A), expA; alg) + expA2 = @constinferred exponential(A; alg) + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eigh_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl new file mode 100644 index 00000000..16007e82 --- /dev/null +++ b/test/genericschur/exponential.jl @@ -0,0 +1,29 @@ +using MatrixAlgebraKit +using Test +using TestExtras +using StableRNGs +using MatrixAlgebraKit: diagview +using LinearAlgebra + +GenericFloats = (BigFloat, Complex{BigFloat}) + +@testset "exp! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + D, V = @constinferred eig_full(A) + algs = (ExponentialViaEig(GS_QRIteration()),) + expA_LA = @constinferred exponential(A) + @testset "algorithm $alg" for alg in algs + expA = similar(A) + + @constinferred exponential!(copy(A), expA) + expA2 = @constinferred exponential(A; alg = alg) + @test expA ≈ expA_LA + @test expA2 ≈ expA + + Dexp, Vexp = @constinferred eig_full(expA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 17ef3588..347c4b21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,128 +2,134 @@ using SafeTestsets # don't run all tests on GPU, only the GPU # specific ones -# is_buildkite = get(ENV, "BUILDKITE", "false") == "true" -# if !is_buildkite -# @safetestset "Algorithms" begin -# include("algorithms.jl") -# end -# @safetestset "Projections" begin -# include("projections.jl") -# end -# @safetestset "Truncate" begin -# include("truncate.jl") -# end -# @safetestset "QR / LQ Decomposition" begin -# include("qr.jl") -# include("lq.jl") -# end -# @safetestset "Singular Value Decomposition" begin -# include("svd.jl") -# end -# @safetestset "Hermitian Eigenvalue Decomposition" begin -# include("eigh.jl") -# end -# @safetestset "General Eigenvalue Decomposition" begin -# include("eig.jl") -# end -# @safetestset "Generalized Eigenvalue Decomposition" begin -# include("gen_eig.jl") -# end -# @safetestset "Schur Decomposition" begin -# include("schur.jl") -# end -# @safetestset "Polar Decomposition" begin -# include("polar.jl") -# end -# @safetestset "Image and Null Space" begin -# include("orthnull.jl") -# end -# @safetestset "ChainRules" begin -# include("chainrules.jl") -# end -# @safetestset "MatrixAlgebraKit.jl" begin -# @safetestset "Code quality (Aqua.jl)" begin -# using MatrixAlgebraKit -# using Aqua -# Aqua.test_all(MatrixAlgebraKit) -# end -# @safetestset "Code linting (JET.jl)" begin -# using MatrixAlgebraKit -# using JET -# JET.test_package(MatrixAlgebraKit; target_defined_modules = true) -# end -# end -# end -@safetestset "Exponential" begin - include("exp.jl") +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" +if !is_buildkite + @safetestset "Algorithms" begin + include("algorithms.jl") + end + @safetestset "Projections" begin + include("projections.jl") + end + @safetestset "Truncate" begin + include("truncate.jl") + end + @safetestset "QR / LQ Decomposition" begin + include("qr.jl") + include("lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("eigh.jl") + end + @safetestset "General Eigenvalue Decomposition" begin + include("eig.jl") + end + @safetestset "Generalized Eigenvalue Decomposition" begin + include("gen_eig.jl") + end + @safetestset "Schur Decomposition" begin + include("schur.jl") + end + @safetestset "Polar Decomposition" begin + include("polar.jl") + end + @safetestset "Image and Null Space" begin + include("orthnull.jl") + end + @safetestset "Exponential" begin + include("exponential.jl") + end + @safetestset "ChainRules" begin + include("chainrules.jl") + end + @safetestset "MatrixAlgebraKit.jl" begin + @safetestset "Code quality (Aqua.jl)" begin + using MatrixAlgebraKit + using Aqua + Aqua.test_all(MatrixAlgebraKit) + end + @safetestset "Code linting (JET.jl)" begin + using MatrixAlgebraKit + using JET + JET.test_package(MatrixAlgebraKit; target_defined_modules = true) + end + end end -# using CUDA -# if CUDA.functional() -# @safetestset "CUDA QR" begin -# include("cuda/qr.jl") -# end -# @safetestset "CUDA LQ" begin -# include("cuda/lq.jl") -# end -# @safetestset "CUDA Projections" begin -# include("cuda/projections.jl") -# end -# @safetestset "CUDA SVD" begin -# include("cuda/svd.jl") -# end -# @safetestset "CUDA General Eigenvalue Decomposition" begin -# include("cuda/eig.jl") -# end -# @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin -# include("cuda/eigh.jl") -# end -# @safetestset "CUDA Polar Decomposition" begin -# include("cuda/polar.jl") -# end -# @safetestset "CUDA Image and Null Space" begin -# include("cuda/orthnull.jl") -# end -# end +using CUDA +if CUDA.functional() + @safetestset "CUDA QR" begin + include("cuda/qr.jl") + end + @safetestset "CUDA LQ" begin + include("cuda/lq.jl") + end + @safetestset "CUDA Projections" begin + include("cuda/projections.jl") + end + @safetestset "CUDA SVD" begin + include("cuda/svd.jl") + end + @safetestset "CUDA General Eigenvalue Decomposition" begin + include("cuda/eig.jl") + end + @safetestset "CUDA Hermitian Eigenvalue Decomposition" begin + include("cuda/eigh.jl") + end + @safetestset "CUDA Polar Decomposition" begin + include("cuda/polar.jl") + end + @safetestset "CUDA Image and Null Space" begin + include("cuda/orthnull.jl") + end +end -# using AMDGPU -# if AMDGPU.functional() -# @safetestset "AMDGPU QR" begin -# include("amd/qr.jl") -# end -# @safetestset "AMDGPU LQ" begin -# include("amd/lq.jl") -# end -# @safetestset "AMDGPU Projections" begin -# include("amd/projections.jl") -# end -# @safetestset "AMDGPU SVD" begin -# include("amd/svd.jl") -# end -# @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin -# include("amd/eigh.jl") -# end -# @safetestset "AMDGPU Polar Decomposition" begin -# include("amd/polar.jl") -# end -# @safetestset "AMDGPU Image and Null Space" begin -# include("amd/orthnull.jl") -# end -# end +using AMDGPU +if AMDGPU.functional() + @safetestset "AMDGPU QR" begin + include("amd/qr.jl") + end + @safetestset "AMDGPU LQ" begin + include("amd/lq.jl") + end + @safetestset "AMDGPU Projections" begin + include("amd/projections.jl") + end + @safetestset "AMDGPU SVD" begin + include("amd/svd.jl") + end + @safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin + include("amd/eigh.jl") + end + @safetestset "AMDGPU Polar Decomposition" begin + include("amd/polar.jl") + end + @safetestset "AMDGPU Image and Null Space" begin + include("amd/orthnull.jl") + end +end -# using GenericLinearAlgebra -# @safetestset "QR / LQ Decomposition" begin -# include("genericlinearalgebra/qr.jl") -# include("genericlinearalgebra/lq.jl") -# end -# @safetestset "Singular Value Decomposition" begin -# include("genericlinearalgebra/svd.jl") -# end -# @safetestset "Hermitian Eigenvalue Decomposition" begin -# include("genericlinearalgebra/eigh.jl") -# end +using GenericLinearAlgebra +@safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") +end +@safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") +end +@safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") +end +@safetestset "Exponential" begin + include("genericlinearalgebra/exponential.jl") +end -# using GenericSchur -# @safetestset "General Eigenvalue Decomposition" begin -# include("genericschur/eig.jl") -# end +using GenericSchur +@safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") +end +@safetestset "Exponential" begin + include("genericschur/exponential.jl") +end From 8dc3ecda064ba86b8ef4c59fa4c991f47c5ab277 Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 12 Nov 2025 15:15:44 +0100 Subject: [PATCH 03/20] remove comment --- src/MatrixAlgebraKit.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index a59ff4ab..bf6a906e 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -8,7 +8,6 @@ using LinearAlgebra: isposdef, issymmetric using LinearAlgebra: Diagonal, diag, diagind, isdiag using LinearAlgebra: UpperTriangular, LowerTriangular using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt -# import LinearAlgebra: exp, exp! export isisometric, isunitary, ishermitian, isantihermitian From 5095cdb1c367bddffeecaf6a75ce55c6609cb977 Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Thu, 13 Nov 2025 11:15:32 +0100 Subject: [PATCH 04/20] comments change some input tests remove redundant comment include ComplexF16 in tests fix unchanged test names and docs improve allocations --- src/implementations/exponential.jl | 26 +++++++++++++++++++++----- src/interface/decompositions.jl | 11 +++++------ src/interface/exponential.jl | 1 - test/exponential.jl | 6 +++--- 4 files changed, 29 insertions(+), 15 deletions(-) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 3a9b6f55..63ebfc78 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -8,12 +8,20 @@ copy_input(::typeof(exponential), A::Diagonal) = copy(A) function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) m, n = size(A) - m == n || throw(DimensionMismatch("square input matrix expected")) - @assert expA isa AbstractMatrix + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) @check_size(expA, (m, m)) return @check_scalar(expA, A) end +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + @check_size(expA, (m, m)) + @check_scalar(expA, A) + return nothing +end + # Outputs # ------- function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -22,6 +30,10 @@ function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::Abstract return expA end +function initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) + return similar(A) +end + # Implementation # -------------- function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA) @@ -31,13 +43,17 @@ end function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) D, V = eigh_full(A, alg.eigh_alg) - copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) + iV = inv(V) + map!(exp, diagview(D), diagview(D)) + mul!(expA, rmul!(V, D), iV) return expA end function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig) D, V = eig_full(A, alg.eig_alg) - copyto!(expA, V * Diagonal(exp.(diagview(D))) * inv(V)) + iV = inv(V) + map!(exp, diagview(D), diagview(D)) + mul!(expA, rmul!(V, D), iV) return expA end @@ -45,6 +61,6 @@ end # -------------- function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) check_input(exponential!, A, expA, alg) - copyto!(expA, Diagonal(LinearAlgebra.exp.(diagview(A)))) + map!(exp, diagview(expA), diagview(A)) return expA end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index b35b2a76..ae754a37 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -325,16 +325,15 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} """ ExponentialViaLA() -Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. -The `qr_alg` specifies which QR-decomposition implementation to use. +Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. """ @algdef ExponentialViaLA """ ExponentialViaEigh() -Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. -The `qr_alg` specifies which QR-decomposition implementation to use. +Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. +The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. """ struct ExponentialViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm eigh_alg::A @@ -348,8 +347,8 @@ end """ ExponentialViaEig() -Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. -The `qr_alg` specifies which QR-decomposition implementation to use. +Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. +The `eig_alg` specifies which eigendecomposition implementation to use. """ struct ExponentialViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm eig_alg::A diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl index 2913bda4..5cfcef3f 100644 --- a/src/interface/exponential.jl +++ b/src/interface/exponential.jl @@ -1,7 +1,6 @@ # Exponential functions # -------------- @functiondef exponential -# @algdef exponential! # Algorithm selection # ------------------- diff --git a/test/exponential.jl b/test/exponential.jl index 70b21965..a532c813 100644 --- a/test/exponential.jl +++ b/test/exponential.jl @@ -6,9 +6,9 @@ using MatrixAlgebraKit: diagview using LinearAlgebra BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) -GenericFloats = (Float16, BigFloat, Complex{BigFloat}) +GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) -@testset "exp! for T = $T" for T in BLASFloats +@testset "exponential! for T = $T" for T in BLASFloats rng = StableRNG(123) m = 2 @@ -30,7 +30,7 @@ GenericFloats = (Float16, BigFloat, Complex{BigFloat}) end end -@testset "svd for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) +@testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) atol = sqrt(eps(real(T))) m = 54 From 89dfa238e2e41fc29803643ee0bd38ce5cb7d929 Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 19 Nov 2025 15:07:32 +0100 Subject: [PATCH 05/20] change name of decompositions.jl to matrixfunctions.jl --- src/interface/{decompositions.jl => matrixfunctions.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/interface/{decompositions.jl => matrixfunctions.jl} (100%) diff --git a/src/interface/decompositions.jl b/src/interface/matrixfunctions.jl similarity index 100% rename from src/interface/decompositions.jl rename to src/interface/matrixfunctions.jl From 996ecb52f2208319a057ce7b6d71c7f58f28c29e Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 19 Nov 2025 15:11:33 +0100 Subject: [PATCH 06/20] revert name change --- src/interface/{matrixfunctions.jl => decompositions.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/interface/{matrixfunctions.jl => decompositions.jl} (100%) diff --git a/src/interface/matrixfunctions.jl b/src/interface/decompositions.jl similarity index 100% rename from src/interface/matrixfunctions.jl rename to src/interface/decompositions.jl From f220035e53462c2face270a250eb5b525405b67a Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Thu, 20 Nov 2025 15:47:04 +0100 Subject: [PATCH 07/20] general comments change name to `MatrixFunctionViaEig` etc change `decompositions` to `matrixfunctions` add default algorithm for Diagonal matrices add input checks add @testthrows to catch non-hermitian matrices being given to MatrixFunctionViaEigh change default exponential algorithm to e.g. `MatrixFunctionViaEig` of the default `eig_alg` --- ext/MatrixAlgebraKitGenericSchurExt.jl | 3 +- src/MatrixAlgebraKit.jl | 2 +- src/implementations/exponential.jl | 31 ++++++++++++++++--- src/interface/exponential.jl | 5 ++- .../{decompositions.jl => matrixfunctions.jl} | 20 ++++++------ test/exponential.jl | 8 ++--- test/genericlinearalgebra/exponential.jl | 2 +- test/genericschur/exponential.jl | 2 +- 8 files changed, 50 insertions(+), 23 deletions(-) rename src/interface/{decompositions.jl => matrixfunctions.jl} (97%) diff --git a/ext/MatrixAlgebraKitGenericSchurExt.jl b/ext/MatrixAlgebraKitGenericSchurExt.jl index 18ec4922..371b0481 100644 --- a/ext/MatrixAlgebraKitGenericSchurExt.jl +++ b/ext/MatrixAlgebraKitGenericSchurExt.jl @@ -23,7 +23,8 @@ function MatrixAlgebraKit.eig_vals!(A::AbstractMatrix, D, ::GS_QRIteration) end function MatrixAlgebraKit.default_exponential_algorithm(E::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{BigFloat, Complex{BigFloat}}}} - return ExponentialViaEig(GS_QRIteration(; kwargs...)) + eig_alg = MatrixAlgebraKit.default_eig_algorithm(E; kwargs...) + return MatrixFunctionViaEig(eig_alg) end end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index e21d1d2f..c887b4c4 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -37,7 +37,7 @@ export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, export GLA_HouseholderQR, GLA_QRIteration, GS_QRIteration export LQViaTransposedQR export PolarViaSVD, PolarNewton -export ExponentialViaLA, ExponentialViaEig, ExponentialViaEigh +export MatrixFunctionViaLA, MatrixFunctionViaEig, MatrixFunctionViaEigh export DiagonalAlgorithm export NativeBlocked export CUSOLVER_Simple, CUSOLVER_HouseholderQR, CUSOLVER_QRIteration, CUSOLVER_SVDPolar, diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 63ebfc78..656614df 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -13,6 +13,16 @@ function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMa return @check_scalar(expA, A) end +function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + @check_size(expA, (m, m)) + return @check_scalar(expA, A) +end + function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) m, n = size(A) @assert m == n && isdiag(A) @@ -36,20 +46,33 @@ end # Implementation # -------------- -function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaLA) +function exponential!(A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaLA) where {T <: BlasFloat} + check_input(exponential!, A, expA, alg) copyto!(expA, LinearAlgebra.exp(A)) return expA end -function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEigh) +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + check_input(exponential!, A, expA, alg) D, V = eigh_full(A, alg.eigh_alg) + + diagview(D) .= exp.( diagview(D) ./ 2) + rmul!(V, D) + mul!(expA, V, adjoint(V)) + return expA +end + +function exponential!(A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaEig) where {T <: Real} + check_input(exponential!, A, expA, alg) + D, V = eig_full(A, alg.eig_alg) iV = inv(V) map!(exp, diagview(D), diagview(D)) - mul!(expA, rmul!(V, D), iV) + expA .= real.(rmul!(V, D) * iV) return expA end -function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::ExponentialViaEig) +function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) + check_input(exponential!, A, expA, alg) D, V = eig_full(A, alg.eig_alg) iV = inv(V) map!(exp, diagview(D), diagview(D)) diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl index 5cfcef3f..099edcf8 100644 --- a/src/interface/exponential.jl +++ b/src/interface/exponential.jl @@ -6,7 +6,10 @@ # ------------------- default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) function default_exponential_algorithm(T::Type; kwargs...) - return ExponentialViaLA(; kwargs...) + return MatrixFunctionViaLA(; kwargs...) +end +function default_exponential_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} + return DiagonalAlgorithm(; kwargs...) end for f in (:exponential!,) diff --git a/src/interface/decompositions.jl b/src/interface/matrixfunctions.jl similarity index 97% rename from src/interface/decompositions.jl rename to src/interface/matrixfunctions.jl index 25ad344a..a54f3109 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/matrixfunctions.jl @@ -357,38 +357,38 @@ const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} # EXPONENTIAL ALGORITHMS # ================================ """ - ExponentialViaLA() + MatrixFunctionViaLA() Algorithm type to denote finding the exponential of `A` via the implementation of `LinearAlgebra`. """ -@algdef ExponentialViaLA +@algdef MatrixFunctionViaLA """ - ExponentialViaEigh() + MatrixFunctionViaEigh() Algorithm type to denote finding the exponential `A` by computing the hermitian eigendecomposition of `A`. The `eigh_alg` specifies which hermitian eigendecomposition implementation to use. """ -struct ExponentialViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm +struct MatrixFunctionViaEigh{A <: AbstractAlgorithm} <: AbstractAlgorithm eigh_alg::A end -function Base.show(io::IO, alg::ExponentialViaEigh) - print(io, "ExponentialViaEigh(") +function Base.show(io::IO, alg::MatrixFunctionViaEigh) + print(io, "MatrixFunctionViaEigh(") _show_alg(io, alg.eigh_alg) return print(io, ")") end """ - ExponentialViaEig() + MatrixFunctionViaEig() Algorithm type to denote finding the exponential `A` by computing the eigendecomposition of `A`. The `eig_alg` specifies which eigendecomposition implementation to use. """ -struct ExponentialViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm +struct MatrixFunctionViaEig{A <: AbstractAlgorithm} <: AbstractAlgorithm eig_alg::A end -function Base.show(io::IO, alg::ExponentialViaEig) - print(io, "ExponentialViaEigh(") +function Base.show(io::IO, alg::MatrixFunctionViaEig) + print(io, "MatrixFunctionViaEig(") _show_alg(io, alg.eig_alg) return print(io, ")") end diff --git a/test/exponential.jl b/test/exponential.jl index a532c813..16327b30 100644 --- a/test/exponential.jl +++ b/test/exponential.jl @@ -13,9 +13,8 @@ GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) m = 2 A = randn(rng, T, m, m) - A = (A + A') / 2 - D, V = @constinferred eigh_full(A) - algs = (ExponentialViaLA(), ExponentialViaEig(LAPACK_Simple()), ExponentialViaEigh(LAPACK_QRIteration())) + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) expA_LA = @constinferred exp(A) @testset "algorithm $alg" for alg in algs expA = similar(A) @@ -25,9 +24,10 @@ GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) @test expA ≈ expA_LA @test expA2 ≈ expA - Dexp, Vexp = @constinferred eigh_full(expA) + Dexp, Vexp = @constinferred eig_full(expA) @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) end + @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) end @testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index 1872b825..b3dfbe8c 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -14,7 +14,7 @@ GenericFloats = (BigFloat, Complex{BigFloat}) A = randn(rng, T, m, m) A = (A + A') / 2 D, V = @constinferred eigh_full(A) - algs = (ExponentialViaEigh(GLA_QRIteration()),) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) @testset "algorithm $alg" for alg in algs expA = similar(A) diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl index 16007e82..794b3cb6 100644 --- a/test/genericschur/exponential.jl +++ b/test/genericschur/exponential.jl @@ -13,7 +13,7 @@ GenericFloats = (BigFloat, Complex{BigFloat}) A = randn(rng, T, m, m) D, V = @constinferred eig_full(A) - algs = (ExponentialViaEig(GS_QRIteration()),) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) expA_LA = @constinferred exponential(A) @testset "algorithm $alg" for alg in algs expA = similar(A) From c68afaddd22592ed6c082a3e80446c3a7d4f44da Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Thu, 20 Nov 2025 18:00:32 +0100 Subject: [PATCH 08/20] bug fix Name change in `MatrixAlgebraKit.jl` fix formatting --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/exponential.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index c887b4c4..54946c66 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -82,7 +82,7 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") include("interface/projections.jl") -include("interface/decompositions.jl") +include("interface/matrixfunctions.jl") include("interface/truncation.jl") include("interface/qr.jl") include("interface/lq.jl") diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 656614df..7a69e067 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -56,7 +56,7 @@ function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFuncti check_input(exponential!, A, expA, alg) D, V = eigh_full(A, alg.eigh_alg) - diagview(D) .= exp.( diagview(D) ./ 2) + diagview(D) .= exp.(diagview(D) ./ 2) rmul!(V, D) mul!(expA, V, adjoint(V)) return expA From 95ddb06fdece778b452ec51944b527f9d44cc047 Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Thu, 20 Nov 2025 23:32:03 +0100 Subject: [PATCH 09/20] avoid allocation in diagonal case --- src/implementations/exponential.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 7a69e067..365e730b 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -40,9 +40,7 @@ function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::Abstract return expA end -function initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) - return similar(A) -end +initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) = A # Implementation # -------------- From c8e811c2a2390c12a2e1ed26d3042c90bd04df87 Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 26 Nov 2025 14:10:11 +0100 Subject: [PATCH 10/20] include exponentiali(tau, A) --- src/MatrixAlgebraKit.jl | 2 +- src/implementations/exponential.jl | 81 ++++++++++++++++++++++++ src/interface/exponential.jl | 8 +++ test/exponential.jl | 50 ++++++++++++++- test/genericlinearalgebra/exponential.jl | 30 ++++++++- test/genericschur/exponential.jl | 27 +++++++- 6 files changed, 189 insertions(+), 9 deletions(-) diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 54946c66..99f686c1 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -29,7 +29,7 @@ export left_polar, right_polar export left_polar!, right_polar! export left_orth, right_orth, left_null, right_null export left_orth!, right_orth!, left_null!, right_null! -export exponential, exponential! +export exponential, exponential!, exponentiali, exponentiali! export LAPACK_HouseholderQR, LAPACK_HouseholderLQ, LAPACK_Simple, LAPACK_Expert, LAPACK_QRIteration, LAPACK_Bisection, LAPACK_MultipleRelativelyRobustRepresentations, diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 365e730b..33a63b8b 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -6,6 +6,12 @@ end copy_input(::typeof(exponential), A::Diagonal) = copy(A) +function copy_input(::typeof(exponentiali), τ::Number, A::AbstractMatrix) + return τ, copy!(similar(A, complex(eltype(A))), A) +end + +copy_input(::typeof(exponentiali), τ::Number, A::Diagonal) = τ, copy(A) + function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) m, n = size(A) m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) @@ -32,6 +38,28 @@ function check_input(::typeof(exponential!), A::AbstractMatrix, expA::AbstractMa return nothing end +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, alg::AbstractAlgorithm) + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + if !ishermitian(A) + throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix)")) + end + m, n = size(A) + m == n || throw(DimensionMismatch("square input matrix expected. Got ($m,$n)")) + return @check_size(expA, (m, m)) +end + +function check_input(::typeof(exponentiali!), A::AbstractMatrix, expA::AbstractMatrix, ::DiagonalAlgorithm) + m, n = size(A) + @assert m == n && isdiag(A) + @assert expA isa Diagonal + return @check_size(expA, (m, m)) +end + # Outputs # ------- function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) @@ -42,6 +70,14 @@ end initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) = A +function initialize_output(::typeof(exponentiali!), τ::Number, A::AbstractMatrix, ::AbstractAlgorithm) + n = size(A, 1) # square check will happen later + expA = similar(complex(A), (n, n)) + return expA +end + +initialize_output(::typeof(exponentiali!), τ::Number, A::Diagonal, ::DiagonalAlgorithm) = complex(A) + # Implementation # -------------- function exponential!(A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaLA) where {T <: BlasFloat} @@ -78,6 +114,45 @@ function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFuncti return expA end +function exponentiali!(τ::Number, A::AbstractMatrix{T1}, expA::AbstractMatrix{T2}, alg::MatrixFunctionViaLA) where {T1 <: BlasFloat, T2 <: BlasFloat} + check_input(exponentiali!, A, expA, alg) + copyto!(expA, LinearAlgebra.exp(im*τ*A)) + return expA +end + +function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) + check_input(exponentiali!, A, expA, alg) + Dreal, Vreal = eigh_full(A, alg.eigh_alg) + + Dcomplex = complex(Dreal) + Vcomplex = complex(Vreal) + + iV = copy(adjoint(Vcomplex)) + + diagview(Dcomplex) .= exp.(diagview(Dcomplex) .* (im*τ)) + rmul!(Vcomplex, Dcomplex) + mul!(expA, Vcomplex, iV) + return expA +end + +function exponentiali!(τ::T, A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaEig) where {T <: Real} + check_input(exponentiali!, A, expA, alg) + D, V = eig_full(A, alg.eig_alg) + iV = inv(V) + map!(exp, diagview(D), diagview(D) .* (im*τ)) + expA .= real.(rmul!(V, D) * iV) + return expA +end + +function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) + check_input(exponentiali!, A, expA, alg) + D, V = eig_full(A, alg.eig_alg) + iV = inv(V) + map!(exp, diagview(D), diagview(D) .* (im*τ)) + mul!(expA, rmul!(V, D), iV) + return expA +end + # Diagonal logic # -------------- function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) @@ -85,3 +160,9 @@ function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) map!(exp, diagview(expA), diagview(A)) return expA end + +function exponentiali!(τ::Number, A::Diagonal, expA, alg::DiagonalAlgorithm) + check_input(exponentiali!, A, expA, alg) + map!(exp, diagview(expA), diagview(A) .* (im*τ)) + return expA +end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl index 099edcf8..2a015232 100644 --- a/src/interface/exponential.jl +++ b/src/interface/exponential.jl @@ -2,6 +2,8 @@ # -------------- @functiondef exponential +@functiondef n_args = 2 exponentiali + # Algorithm selection # ------------------- default_exponential_algorithm(A; kwargs...) = default_exponential_algorithm(typeof(A); kwargs...) @@ -17,3 +19,9 @@ for f in (:exponential!,) return default_exponential_algorithm(A; kwargs...) end end + +for f in (:exponentiali!,) + @eval function default_algorithm(::typeof($f), ::Tuple{A,B}; kwargs...) where {A, B} + return default_exponential_algorithm(B; kwargs...) + end +end diff --git a/test/exponential.jl b/test/exponential.jl index 16327b30..c4173ccc 100644 --- a/test/exponential.jl +++ b/test/exponential.jl @@ -10,12 +10,14 @@ GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) @testset "exponential! for T = $T" for T in BLASFloats rng = StableRNG(123) - m = 2 + m = 54 A = randn(rng, T, m, m) + A /= norm(A) + D, V = @constinferred eig_full(A) algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) - expA_LA = @constinferred exp(A) + expA_LA = @constinferred exponential(A) @testset "algorithm $alg" for alg in algs expA = similar(A) @@ -25,11 +27,35 @@ GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) @test expA2 ≈ expA Dexp, Vexp = @constinferred eig_full(expA) - @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) end @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) end +@testset "exponentiali! for T = $T" for T in BLASFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + expiτA_LA = @constinferred exp(im*τ*A) + @testset "algorithm $alg" for alg in algs + expiτA = similar(complex(A)) + + @constinferred exponentiali!(τ, copy(A), expiτA; alg) + expiτA2 = @constinferred exponentiali(τ, A; alg = alg) + @test expiτA ≈ expiτA_LA + @test expiτA2 ≈ expiτA + + Dexp, Vexp = @constinferred eig_full(expiτA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) + end + @test_throws DomainError exponentiali(τ, A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) +end + @testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) atol = sqrt(eps(real(T))) @@ -46,3 +72,21 @@ end Dexp, Vexp = @constinferred eig_full(expA) @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) end + +@testset "exponentiali! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) + rng = StableRNG(123) + atol = sqrt(eps(real(T))) + m = 54 + Ad = randn(T, m) + A = Diagonal(Ad) + τ = randn(rng, T) + + expiτA = similar(complex(A)) + @constinferred exponentiali!(τ, copy(A), expiτA) + expiτA2 = @constinferred exponentiali(τ, A; alg = DiagonalAlgorithm()) + @test expiτA2 ≈ expiτA + + D, V = @constinferred eig_full(A) + Dexp, Vexp = @constinferred eig_full(expiτA) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D) .* (im*τ)) +end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index b3dfbe8c..a55105d0 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -7,9 +7,9 @@ using LinearAlgebra GenericFloats = (BigFloat, Complex{BigFloat}) -@testset "exp! for T = $T" for T in GenericFloats +@testset "exponential! for T = $T" for T in GenericFloats rng = StableRNG(123) - m = 2 + m = 54 A = randn(rng, T, m, m) A = (A + A') / 2 @@ -26,3 +26,29 @@ GenericFloats = (BigFloat, Complex{BigFloat}) @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) end end + +using GenericSchur +@testset "exponentiali! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 2 + + A = randn(rng, T, m, m) + A = (A + A') / 2 + τ = randn(rng, real(T)) + + A = Complex{BigFloat}[1.0 0.0; 0.0 -1.0] + τ = Complex{BigFloat}(2.0) + + D, V = @constinferred eigh_full(A) + algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expiτA = similar(complex(A)) + + @constinferred exponentiali!(τ, copy(A), expiτA; alg) + expiτA2 = @constinferred exponentiali(τ, A; alg) + @test expiτA2 ≈ expiτA + + Dexp, Vexp = @constinferred eig_full(expiτA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) + end +end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl index 794b3cb6..d0a36579 100644 --- a/test/genericschur/exponential.jl +++ b/test/genericschur/exponential.jl @@ -7,9 +7,9 @@ using LinearAlgebra GenericFloats = (BigFloat, Complex{BigFloat}) -@testset "exp! for T = $T" for T in GenericFloats +@testset "exponential! for T = $T" for T in GenericFloats rng = StableRNG(123) - m = 2 + m = 54 A = randn(rng, T, m, m) D, V = @constinferred eig_full(A) @@ -24,6 +24,27 @@ GenericFloats = (BigFloat, Complex{BigFloat}) @test expA2 ≈ expA Dexp, Vexp = @constinferred eig_full(expA) - @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) + end +end + +@testset "exponentiali! for T = $T" for T in GenericFloats + rng = StableRNG(123) + m = 54 + + A = randn(rng, T, m, m) + τ = randn(rng, T) + + D, V = @constinferred eig_full(A) + algs = (MatrixFunctionViaEig(GS_QRIteration()),) + @testset "algorithm $alg" for alg in algs + expiτA = similar(complex(A)) + + @constinferred exponentiali!(τ, copy(A), expiτA) + expiτA2 = @constinferred exponentiali(τ, A; alg) + @test expiτA2 ≈ expiτA + + Dexp, Vexp = @constinferred eig_full(expiτA) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) end end From 022941756b1745ecd636497be7aa85805b0543aa Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 26 Nov 2025 14:13:37 +0100 Subject: [PATCH 11/20] remove simple test case and make the test more general --- test/genericlinearalgebra/exponential.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index a55105d0..394d0ce0 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -34,10 +34,7 @@ using GenericSchur A = randn(rng, T, m, m) A = (A + A') / 2 - τ = randn(rng, real(T)) - - A = Complex{BigFloat}[1.0 0.0; 0.0 -1.0] - τ = Complex{BigFloat}(2.0) + τ = randn(rng, T) D, V = @constinferred eigh_full(A) algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) From cbbf813744d9e8f567afdff15cfea4d061e7287d Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Wed, 26 Nov 2025 14:14:27 +0100 Subject: [PATCH 12/20] fix formatting --- src/implementations/exponential.jl | 10 +++++----- src/interface/exponential.jl | 2 +- test/exponential.jl | 6 +++--- test/genericlinearalgebra/exponential.jl | 2 +- test/genericschur/exponential.jl | 2 +- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 33a63b8b..bc48ed29 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -116,7 +116,7 @@ end function exponentiali!(τ::Number, A::AbstractMatrix{T1}, expA::AbstractMatrix{T2}, alg::MatrixFunctionViaLA) where {T1 <: BlasFloat, T2 <: BlasFloat} check_input(exponentiali!, A, expA, alg) - copyto!(expA, LinearAlgebra.exp(im*τ*A)) + copyto!(expA, LinearAlgebra.exp(im * τ * A)) return expA end @@ -129,7 +129,7 @@ function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg: iV = copy(adjoint(Vcomplex)) - diagview(Dcomplex) .= exp.(diagview(Dcomplex) .* (im*τ)) + diagview(Dcomplex) .= exp.(diagview(Dcomplex) .* (im * τ)) rmul!(Vcomplex, Dcomplex) mul!(expA, Vcomplex, iV) return expA @@ -139,7 +139,7 @@ function exponentiali!(τ::T, A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg check_input(exponentiali!, A, expA, alg) D, V = eig_full(A, alg.eig_alg) iV = inv(V) - map!(exp, diagview(D), diagview(D) .* (im*τ)) + map!(exp, diagview(D), diagview(D) .* (im * τ)) expA .= real.(rmul!(V, D) * iV) return expA end @@ -148,7 +148,7 @@ function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg: check_input(exponentiali!, A, expA, alg) D, V = eig_full(A, alg.eig_alg) iV = inv(V) - map!(exp, diagview(D), diagview(D) .* (im*τ)) + map!(exp, diagview(D), diagview(D) .* (im * τ)) mul!(expA, rmul!(V, D), iV) return expA end @@ -163,6 +163,6 @@ end function exponentiali!(τ::Number, A::Diagonal, expA, alg::DiagonalAlgorithm) check_input(exponentiali!, A, expA, alg) - map!(exp, diagview(expA), diagview(A) .* (im*τ)) + map!(exp, diagview(expA), diagview(A) .* (im * τ)) return expA end diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl index 2a015232..3d0abe42 100644 --- a/src/interface/exponential.jl +++ b/src/interface/exponential.jl @@ -21,7 +21,7 @@ for f in (:exponential!,) end for f in (:exponentiali!,) - @eval function default_algorithm(::typeof($f), ::Tuple{A,B}; kwargs...) where {A, B} + @eval function default_algorithm(::typeof($f), ::Tuple{A, B}; kwargs...) where {A, B} return default_exponential_algorithm(B; kwargs...) end end diff --git a/test/exponential.jl b/test/exponential.jl index c4173ccc..55327442 100644 --- a/test/exponential.jl +++ b/test/exponential.jl @@ -41,7 +41,7 @@ end D, V = @constinferred eig_full(A) algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) - expiτA_LA = @constinferred exp(im*τ*A) + expiτA_LA = @constinferred exp(im * τ * A) @testset "algorithm $alg" for alg in algs expiτA = similar(complex(A)) @@ -51,7 +51,7 @@ end @test expiτA2 ≈ expiτA Dexp, Vexp = @constinferred eig_full(expiτA) - @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) end @test_throws DomainError exponentiali(τ, A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) end @@ -88,5 +88,5 @@ end D, V = @constinferred eig_full(A) Dexp, Vexp = @constinferred eig_full(expiτA) - @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D) .* (im*τ)) + @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D) .* (im * τ)) end diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index 394d0ce0..c0720ba8 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -46,6 +46,6 @@ using GenericSchur @test expiτA2 ≈ expiτA Dexp, Vexp = @constinferred eig_full(expiτA) - @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) end end diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl index d0a36579..7e272a5e 100644 --- a/test/genericschur/exponential.jl +++ b/test/genericschur/exponential.jl @@ -45,6 +45,6 @@ end @test expiτA2 ≈ expiτA Dexp, Vexp = @constinferred eig_full(expiτA) - @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im*τ)); by = imag) + @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) end end From d08d545ae8236a1134c6061cee3efc34acc5595f Mon Sep 17 00:00:00 2001 From: sanderdemeyer Date: Mon, 1 Dec 2025 12:22:13 +0100 Subject: [PATCH 13/20] add docs --- src/interface/exponential.jl | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/interface/exponential.jl b/src/interface/exponential.jl index 3d0abe42..ffd725d0 100644 --- a/src/interface/exponential.jl +++ b/src/interface/exponential.jl @@ -1,7 +1,40 @@ # Exponential functions # -------------- + +""" + exponential(A; kwargs...) -> expA + exponential(A, alg::AbstractAlgorithm) -> expA + exponential!(A, [expA]; kwargs...) -> expA + exponential!(A, [expA], alg::AbstractAlgorithm) -> expA + +Compute the exponential of the square matrix `A`, + +!!! note + The bang method `exponential!` optionally accepts the output structure and + possibly destroys the input matrix `A`. Always use the return value of the function + as it may not always be possible to use the provided `expA` as output. + +See also [`exponentiali(!)`](@ref exponentiali). +""" @functiondef exponential +""" + exponentiali(τ, A; kwargs...) -> expiτA + exponentiali(τ, A, alg::AbstractAlgorithm) -> expiτA + exponentiali!(τ, A, [expiτA]; kwargs...) -> expiτA + exponentiali!(τ, A, [expiτA], alg::AbstractAlgorithm) -> expiτA + +Compute the exponential of `i*τ*A`, where `i` is the imaginary unit, `τ` is a scalar, and `A` is a square matrix. +This allows the user to use the hermitian eigendecomposition when `A` is hermitian, even when `i*τ*A` is not. + +!!! note + The bang method `exponentiali!` optionally accepts the output structure and + possibly destroys the input matrix `A`. + Always use the return value of the function as it may not always be + possible to use the provided `expiτA` as output. + +See also [`exponential(!)`](@ref exponential). +""" @functiondef n_args = 2 exponentiali # Algorithm selection From 720ada57a640683d0311a46210618a06b3bf4b4c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 13:54:45 -0500 Subject: [PATCH 14/20] remove a bunch of allocations and clean up --- src/implementations/exponential.jl | 94 ++++++++++-------------------- 1 file changed, 32 insertions(+), 62 deletions(-) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index bc48ed29..613011a4 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -62,59 +62,40 @@ end # Outputs # ------- -function initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) - n = size(A, 1) # square check will happen later - expA = similar(A, (n, n)) - return expA -end - -initialize_output(::typeof(exponential!), A::Diagonal, ::DiagonalAlgorithm) = A - -function initialize_output(::typeof(exponentiali!), τ::Number, A::AbstractMatrix, ::AbstractAlgorithm) - n = size(A, 1) # square check will happen later - expA = similar(complex(A), (n, n)) - return expA -end - -initialize_output(::typeof(exponentiali!), τ::Number, A::Diagonal, ::DiagonalAlgorithm) = complex(A) +initialize_output(::typeof(exponential!), A::AbstractMatrix, ::AbstractAlgorithm) = A +initialize_output(::typeof(exponentiali!), τ::Number, A::AbstractMatrix, ::AbstractAlgorithm) = + complex(A) # Implementation # -------------- -function exponential!(A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaLA) where {T <: BlasFloat} +function exponential!(A, expA, alg::MatrixFunctionViaLA) check_input(exponential!, A, expA, alg) - copyto!(expA, LinearAlgebra.exp(A)) - return expA + return LinearAlgebra.exp!(A) end function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) check_input(exponential!, A, expA, alg) - D, V = eigh_full(A, alg.eigh_alg) - + D, V = eigh_full!(A, alg.eigh_alg) diagview(D) .= exp.(diagview(D) ./ 2) - rmul!(V, D) - mul!(expA, V, adjoint(V)) - return expA -end - -function exponential!(A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaEig) where {T <: Real} - check_input(exponential!, A, expA, alg) - D, V = eig_full(A, alg.eig_alg) - iV = inv(V) - map!(exp, diagview(D), diagview(D)) - expA .= real.(rmul!(V, D) * iV) - return expA + VexpD = rmul!(V, D) + return mul!(expA, VexpD, V') end function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) check_input(exponential!, A, expA, alg) - D, V = eig_full(A, alg.eig_alg) + D, V = eig_full!(A, alg.eig_alg) + diagview(D) .= exp.(diagview(D)) iV = inv(V) - map!(exp, diagview(D), diagview(D)) - mul!(expA, rmul!(V, D), iV) + VexpD = rmul!(V, D) + if eltype(A) <: Real + expA .= real.(VexpD * iV) + else + mul!(expA, VexpD, iV) + end return expA end -function exponentiali!(τ::Number, A::AbstractMatrix{T1}, expA::AbstractMatrix{T2}, alg::MatrixFunctionViaLA) where {T1 <: BlasFloat, T2 <: BlasFloat} +function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaLA) check_input(exponentiali!, A, expA, alg) copyto!(expA, LinearAlgebra.exp(im * τ * A)) return expA @@ -122,47 +103,36 @@ end function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) check_input(exponentiali!, A, expA, alg) - Dreal, Vreal = eigh_full(A, alg.eigh_alg) - - Dcomplex = complex(Dreal) - Vcomplex = complex(Vreal) - - iV = copy(adjoint(Vcomplex)) - - diagview(Dcomplex) .= exp.(diagview(Dcomplex) .* (im * τ)) - rmul!(Vcomplex, Dcomplex) - mul!(expA, Vcomplex, iV) - return expA -end - -function exponentiali!(τ::T, A::AbstractMatrix{T}, expA::AbstractMatrix{T}, alg::MatrixFunctionViaEig) where {T <: Real} - check_input(exponentiali!, A, expA, alg) - D, V = eig_full(A, alg.eig_alg) - iV = inv(V) - map!(exp, diagview(D), diagview(D) .* (im * τ)) - expA .= real.(rmul!(V, D) * iV) - return expA + D, V = eigh_full!(A, alg.eigh_alg) + expD = diagonal(exp.(diagview(D) .* (im * τ))) + if eltype(A) <: Real + VexpD = V * expD + return expA .= real.(VexpD * V') + else + VexpD = rmul!(V, expD) + return mul!(expA, VexpD, V') + end end function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) check_input(exponentiali!, A, expA, alg) - D, V = eig_full(A, alg.eig_alg) + D, V = eig_full!(A, alg.eig_alg) + diagview(D) .= exp.(diagview(D) .* (im * τ)) iV = inv(V) - map!(exp, diagview(D), diagview(D) .* (im * τ)) - mul!(expA, rmul!(V, D), iV) - return expA + VexpD = rmul!(V, D) + return mul!(expA, VexpD, iV) end # Diagonal logic # -------------- function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) check_input(exponential!, A, expA, alg) - map!(exp, diagview(expA), diagview(A)) + diagview(expA) .= exp.(diagview(A)) return expA end function exponentiali!(τ::Number, A::Diagonal, expA, alg::DiagonalAlgorithm) check_input(exponentiali!, A, expA, alg) - map!(exp, diagview(expA), diagview(A) .* (im * τ)) + diagview(expA) .= exp.(diagview(A) .* (im * τ)) return expA end From be111ea66ad5745cf7ac1f1b7edd13ddb074b879 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 14:25:36 -0500 Subject: [PATCH 15/20] introduce `map_diagonal` to simplify and relax types --- src/common/view.jl | 20 +++++++++++++++++++ src/implementations/exponential.jl | 32 ++++++++++++++---------------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/common/view.jl b/src/common/view.jl index e03bfb88..b9b4c8cc 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -20,6 +20,26 @@ See also [`diagview`](@ref). diagonal(v::AbstractVector) = Diagonal(v) +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, returning +a diagonal result. + +See also [`map_diagonal!`](@ref). +""" +map_diagonal(f, src, srcs...) = diagonal(f.(diagview(src), map(diagview, srcs)...)) + +""" + map_diagonal!(f, dst, src...) + +Map the scalar function `f` over all elements of the diagonal of `src...`, +into the diagonal elements of destination `dst`. + +See also [`map_diagonal`](@ref). +""" +map_diagonal!(f, dst, src, srcs...) = (diagview(dst) .= f.(diagview(src), map(diagview, srcs)...); dst) + # triangularind function lowertriangularind(A::AbstractMatrix) Base.require_one_based_indexing(A) diff --git a/src/implementations/exponential.jl b/src/implementations/exponential.jl index 613011a4..430ff213 100644 --- a/src/implementations/exponential.jl +++ b/src/implementations/exponential.jl @@ -73,20 +73,20 @@ function exponential!(A, expA, alg::MatrixFunctionViaLA) return LinearAlgebra.exp!(A) end -function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) +function exponential!(A, expA, alg::MatrixFunctionViaEigh) check_input(exponential!, A, expA, alg) D, V = eigh_full!(A, alg.eigh_alg) - diagview(D) .= exp.(diagview(D) ./ 2) - VexpD = rmul!(V, D) + expD = map_diagonal!(x -> exp(x / 2), D, D) + VexpD = rmul!(V, expD) return mul!(expA, VexpD, V') end function exponential!(A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) check_input(exponential!, A, expA, alg) D, V = eig_full!(A, alg.eig_alg) - diagview(D) .= exp.(diagview(D)) + expD = map_diagonal!(exp, D, D) iV = inv(V) - VexpD = rmul!(V, D) + VexpD = rmul!(V, expD) if eltype(A) <: Real expA .= real.(VexpD * iV) else @@ -97,14 +97,14 @@ end function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaLA) check_input(exponentiali!, A, expA, alg) - copyto!(expA, LinearAlgebra.exp(im * τ * A)) - return expA + expA .= A .* (im * τ) + return LinearAlgebra.exp!(expA) end function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEigh) check_input(exponentiali!, A, expA, alg) D, V = eigh_full!(A, alg.eigh_alg) - expD = diagonal(exp.(diagview(D) .* (im * τ))) + expD = map_diagonal(x -> exp(x * im * τ), D) if eltype(A) <: Real VexpD = V * expD return expA .= real.(VexpD * V') @@ -114,25 +114,23 @@ function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg: end end -function exponentiali!(τ::Number, A::AbstractMatrix, expA::AbstractMatrix, alg::MatrixFunctionViaEig) +function exponentiali!(τ::Number, A, expA, alg::MatrixFunctionViaEig) check_input(exponentiali!, A, expA, alg) D, V = eig_full!(A, alg.eig_alg) - diagview(D) .= exp.(diagview(D) .* (im * τ)) + expD = map_diagonal!(x -> exp(x * im * τ), D, D) iV = inv(V) - VexpD = rmul!(V, D) + VexpD = rmul!(V, expD) return mul!(expA, VexpD, iV) end # Diagonal logic # -------------- -function exponential!(A::Diagonal, expA, alg::DiagonalAlgorithm) +function exponential!(A, expA, alg::DiagonalAlgorithm) check_input(exponential!, A, expA, alg) - diagview(expA) .= exp.(diagview(A)) - return expA + return map_diagonal!(exp, expA, A) end -function exponentiali!(τ::Number, A::Diagonal, expA, alg::DiagonalAlgorithm) +function exponentiali!(τ::Number, A, expA, alg::DiagonalAlgorithm) check_input(exponentiali!, A, expA, alg) - diagview(expA) .= exp.(diagview(A) .* (im * τ)) - return expA + return map_diagonal!(x -> exp(x * im * τ), expA, A) end From c760a47149a64546e49b39af35cf6897eca3b74e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 14:25:52 -0500 Subject: [PATCH 16/20] rework tests --- test/exponential.jl | 84 +++++++++++++++++++++------------------------ 1 file changed, 40 insertions(+), 44 deletions(-) diff --git a/test/exponential.jl b/test/exponential.jl index 55327442..d817510d 100644 --- a/test/exponential.jl +++ b/test/exponential.jl @@ -4,6 +4,7 @@ using TestExtras using StableRNGs using MatrixAlgebraKit: diagview using LinearAlgebra +using LinearAlgebra: exp BLASFloats = (Float32, Float64, ComplexF32, ComplexF64) GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) @@ -12,23 +13,21 @@ GenericFloats = (Float16, ComplexF16, BigFloat, Complex{BigFloat}) rng = StableRNG(123) m = 54 - A = randn(rng, T, m, m) - A /= norm(A) + A = LinearAlgebra.normalize!(randn(rng, T, m, m)) + Ac = copy(A) + expA = LinearAlgebra.exp(A) + + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac - D, V = @constinferred eig_full(A) algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) - expA_LA = @constinferred exponential(A) @testset "algorithm $alg" for alg in algs - expA = similar(A) - - @constinferred exponential!(copy(A), expA) - expA2 = @constinferred exponential(A; alg = alg) - @test expA ≈ expA_LA - @test expA2 ≈ expA - - Dexp, Vexp = @constinferred eig_full(expA) - @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D)); by = imag) + expA2 = @constinferred exponential(A, alg) + @test expA ≈ expA2 + @test A == Ac end + @test_throws DomainError exponential(A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) end @@ -38,55 +37,52 @@ end A = randn(rng, T, m, m) τ = randn(rng, T) + Ac = copy(A) - D, V = @constinferred eig_full(A) - algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) - expiτA_LA = @constinferred exp(im * τ * A) - @testset "algorithm $alg" for alg in algs - expiτA = similar(complex(A)) + Aimτ = A * im * τ + expAimτ = LinearAlgebra.exp(Aimτ) - @constinferred exponentiali!(τ, copy(A), expiτA; alg) - expiτA2 = @constinferred exponentiali(τ, A; alg = alg) - @test expiτA ≈ expiτA_LA - @test expiτA2 ≈ expiτA + expAimτ2 = @constinferred exponentiali(τ, A) + @test expAimτ ≈ expAimτ2 + @test A == Ac - Dexp, Vexp = @constinferred eig_full(expiτA) - @test sort(diagview(Dexp); by = imag) ≈ sort(LinearAlgebra.exp.(diagview(D) .* (im * τ)); by = imag) + algs = (MatrixFunctionViaLA(), MatrixFunctionViaEig(LAPACK_Simple())) + @testset "algorithm $alg" for alg in algs + expAimτ2 = @constinferred exponentiali(τ, A, alg) + @test expAimτ ≈ expAimτ2 + @test A == Ac end + @test_throws DomainError exponentiali(τ, A; alg = MatrixFunctionViaEigh(LAPACK_QRIteration())) end @testset "exponential! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) - atol = sqrt(eps(real(T))) m = 54 - Ad = randn(T, m) - A = Diagonal(Ad) - expA = similar(A) - @constinferred exponential!(copy(A), expA) - expA2 = @constinferred exponential(A; alg = DiagonalAlgorithm()) - @test expA2 ≈ expA + A = Diagonal(randn(rng, T, m)) + τ = randn(rng, T) + Ac = copy(A) + + expA = LinearAlgebra.exp(A) - D, V = @constinferred eig_full(A) - Dexp, Vexp = @constinferred eig_full(expA) - @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D)) + expA2 = @constinferred exponential(A) + @test expA ≈ expA2 + @test A == Ac end @testset "exponentiali! for Diagonal{$T}" for T in (BLASFloats..., GenericFloats...) rng = StableRNG(123) - atol = sqrt(eps(real(T))) m = 54 - Ad = randn(T, m) - A = Diagonal(Ad) + + A = Diagonal(randn(rng, T, m)) τ = randn(rng, T) + Ac = copy(A) - expiτA = similar(complex(A)) - @constinferred exponentiali!(τ, copy(A), expiτA) - expiτA2 = @constinferred exponentiali(τ, A; alg = DiagonalAlgorithm()) - @test expiτA2 ≈ expiτA + Aimτ = A * im * τ + expAimτ = LinearAlgebra.exp(Aimτ) - D, V = @constinferred eig_full(A) - Dexp, Vexp = @constinferred eig_full(expiτA) - @test diagview(Dexp) ≈ LinearAlgebra.exp.(diagview(D) .* (im * τ)) + expAimτ2 = @constinferred exponentiali(τ, A) + @test expAimτ ≈ expAimτ2 + @test A == Ac end From d0d14e14aa5f4d58bf67420f503710a80354f0b0 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 14:33:50 -0500 Subject: [PATCH 17/20] revert wrong filename changes --- src/MatrixAlgebraKit.jl | 5 +- src/interface/decompositions.jl | 494 ++++++++++++++++++++++++++++++ src/interface/matrixfunctions.jl | 495 ------------------------------- src/matrixfunctions.jl | 1 - 4 files changed, 498 insertions(+), 497 deletions(-) create mode 100644 src/interface/decompositions.jl delete mode 100644 src/matrixfunctions.jl diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index 0a476abd..81d9d486 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -83,9 +83,12 @@ include("common/matrixproperties.jl") include("yalapack.jl") include("algorithms.jl") + include("interface/projections.jl") -include("interface/matrixfunctions.jl") +include("interface/decompositions.jl") include("interface/truncation.jl") +include("interface/matrixfunctions.jl") + include("interface/qr.jl") include("interface/lq.jl") include("interface/svd.jl") diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl new file mode 100644 index 00000000..0fc7403d --- /dev/null +++ b/src/interface/decompositions.jl @@ -0,0 +1,494 @@ +# TODO: module Decompositions? + +# ================= +# LAPACK ALGORITHMS +# ================= + +# reference for naming LAPACK algorithms: +# https://www.netlib.org/lapack/explore-html/topics.html + +# QR, LQ, QL, RQ Decomposition +# ---------------------------- +""" + LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false) + +Algorithm type to denote the standard LAPACK algorithm for computing the QR decomposition of +a matrix using Householder reflectors. The specific LAPACK function can be controlled using +the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With +`blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen +if `pivoted == true`. The keyword `positive = true` can be used to ensure that the diagonal +elements of `R` are non-negative. +""" +@algdef LAPACK_HouseholderQR + +""" + LAPACK_HouseholderLQ(; blocksize, positive = false) + +Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of +a matrix using Householder reflectors. The specific LAPACK function can be controlled using +the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be +chosen if `blocksize == 1`. The keyword `positive = true` can be used to ensure that the diagonal +elements of `L` are non-negative. +""" +@algdef LAPACK_HouseholderLQ + +""" + GLA_HouseholderQR(; positive = false) + +Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the QR decomposition +of a matrix using Householder reflectors. Currently, only `blocksize = 1` and `pivoted == false` +are supported. The keyword `positive = true` can be used to ensure that the diagonal elements +of `R` are non-negative. +""" +@algdef GLA_HouseholderQR + +# TODO: +@algdef LAPACK_HouseholderQL +@algdef LAPACK_HouseholderRQ + +# General Eigenvalue Decomposition +# ------------------------------- +""" + LAPACK_Simple(; fixgauge::Bool = true) + +Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian +eigenvalue decomposition of a matrix. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_Simple + +""" + LAPACK_Expert(; fixgauge::Bool = true) + +Algorithm type to denote the expert LAPACK driver for computing the Schur or non-Hermitian +eigenvalue decomposition of a matrix. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_Expert + +const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} + +""" + GS_QRIteration() + +Algorithm type to denote the GenericSchur.jl implementation for computing the +eigenvalue decomposition of a non-Hermitian matrix. +""" +@algdef GS_QRIteration + +# Hermitian Eigenvalue Decomposition +# ---------------------------------- +""" + LAPACK_QRIteration(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +QR Iteration algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_QRIteration + +""" + LAPACK_Bisection(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Bisection algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_Bisection + +""" + LAPACK_DivideAndConquer(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Divide and Conquer algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_DivideAndConquer + +""" + LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a +Hermitian matrix using the Multiple Relatively Robust Representations algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_MultipleRelativelyRobustRepresentations + +const LAPACK_EighAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_MultipleRelativelyRobustRepresentations, +} + +""" + GLA_QRIteration(; fixgauge::Bool = true) + +Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the +eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of +a general matrix. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef GLA_QRIteration + +# Singular Value Decomposition +# ---------------------------- +""" + LAPACK_Jacobi(; fixgauge::Bool = true) + +Algorithm type to denote the LAPACK driver for computing the singular value decomposition of +a general matrix using the Jacobi algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, +see also [`gaugefix!`](@ref). +""" +@algdef LAPACK_Jacobi + +const LAPACK_SVDAlgorithm = Union{ + LAPACK_QRIteration, + LAPACK_Bisection, + LAPACK_DivideAndConquer, + LAPACK_Jacobi, +} + +# ========================= +# Polar decompositions +# ========================= +""" + PolarViaSVD(svd_alg) + +Algorithm for computing the polar decomposition of a matrix `A` via the singular value +decomposition (SVD) of `A`. The `svd_alg` argument specifies the SVD algorithm to use. +""" +struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm + svd_alg::SVDAlg +end + +""" + PolarNewton(; maxiter = 10, tol = defaulttol(A)) + +Algorithm for computing the polar decomposition of a matrix `A` via +scaled Newton iteration, with a maximum of `maxiter` iterations and +until convergence up to tolerance `tol`. +""" +@algdef PolarNewton + +# ========================= +# Varia +# ========================= +""" + DiagonalAlgorithm(; kwargs...) + +Algorithm type to denote a native Julia implementation of the decompositions making use of +the diagonal structure of the input and outputs. +""" +@algdef DiagonalAlgorithm + +""" + LQViaTransposedQR(qr_alg) + +Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. +The `qr_alg` specifies which QR-decomposition implementation to use. +""" +struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm + qr_alg::A +end +function Base.show(io::IO, alg::LQViaTransposedQR) + print(io, "LQViaTransposedQR(") + _show_alg(io, alg.qr_alg) + return print(io, ")") +end + +# ========================= +# CUSOLVER ALGORITHMS +# ========================= +""" + CUSOLVER_HouseholderQR(; positive = false) + +Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of +a matrix using Householder reflectors. The keyword `positive = true` can be used to ensure that +the diagonal elements of `R` are non-negative. +""" +@algdef CUSOLVER_HouseholderQR + +""" + CUSOLVER_QRIteration(; fixgauge::Bool = true) + +Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +QR Iteration algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef CUSOLVER_QRIteration + +""" + CUSOLVER_SVDPolar(; fixgauge::Bool = true) + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix by using Halley's iterative algorithm to compute the polar decompositon, +followed by the hermitian eigenvalue decomposition of the positive definite factor. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). +""" +@algdef CUSOLVER_SVDPolar + +""" + CUSOLVER_Jacobi(; fixgauge::Bool = true) + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix using the Jacobi algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). +""" +@algdef CUSOLVER_Jacobi + +""" + CUSOLVER_Randomized(; k, p, niters) + +Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of +a general matrix using the randomized SVD algorithm. Here, `k` denotes the number of singular +values that should be computed, therefore requiring `k <= min(size(A))`. This method is accurate +for small values of `k` compared to the size of the input matrix, where the accuracy can be +improved by increasing `p`, the number of additional values used for oversampling, +and `niters`, the number of iterations the solver uses, at the cost of increasing the runtime. + +See also the [CUSOLVER documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgesvdr) +for more information. +""" +@algdef CUSOLVER_Randomized + +does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true + +""" + CUSOLVER_Simple(; fixgauge::Bool = true) + +Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian +eigenvalue decomposition of a matrix. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, +see also [`gaugefix!`](@ref). +""" +@algdef CUSOLVER_Simple + +const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} + +""" + CUSOLVER_DivideAndConquer(; fixgauge::Bool = true) + +Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Divide and Conquer algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef CUSOLVER_DivideAndConquer + +const CUSOLVER_SVDAlgorithm = Union{ + CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, +} + +# ========================= +# ROCSOLVER ALGORITHMS +# ========================= +""" + ROCSOLVER_HouseholderQR(; positive = false) + +Algorithm type to denote the standard ROCSOLVER algorithm for computing the QR decomposition of +a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that +the diagonal elements of `R` are non-negative. +""" +@algdef ROCSOLVER_HouseholderQR + +""" + ROCSOLVER_QRIteration(; fixgauge::Bool = true) + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +QR Iteration algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef ROCSOLVER_QRIteration + +""" + ROCSOLVER_Jacobi(; fixgauge::Bool = true) + +Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of +a general matrix using the Jacobi algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular +vectors, see also [`gaugefix!`](@ref). +""" +@algdef ROCSOLVER_Jacobi + +""" + ROCSOLVER_Bisection(; fixgauge::Bool = true) + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Bisection algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef ROCSOLVER_Bisection + +""" + ROCSOLVER_DivideAndConquer(; fixgauge::Bool = true) + +Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a +Hermitian matrix, or the singular value decomposition of a general matrix using the +Divide and Conquer algorithm. +The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or +singular vectors, see also [`gaugefix!`](@ref). +""" +@algdef ROCSOLVER_DivideAndConquer + +const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} + +# Various consts and unions +# ------------------------- + +const GPU_Simple = Union{CUSOLVER_Simple} +const GPU_EigAlgorithm = Union{GPU_Simple} +const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} +const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} +const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} +const GPU_Bisection = Union{ROCSOLVER_Bisection} +const GPU_EighAlgorithm = Union{ + GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, +} +const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} + +const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} +const GPU_Randomized = Union{CUSOLVER_Randomized} + +const QRAlgorithms = Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} +const LQAlgorithms = Union{LAPACK_HouseholderLQ, LQViaTransposedQR} +const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} +const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} + +# ================================ +# ORTHOGONALIZATION ALGORITHMS +# ================================ + +""" + LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_orth`](@ref). +By default `Kind` is a symbol, which can be either `:qr`, `:polar` or `:svd`. +""" +struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end +LeftOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftOrthAlgorithm{Kind, Alg}(alg) + +# Note: specific algorithm selection is handled by `left_orth_alg` in orthnull.jl +LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_orth`, define + + MatrixAlgebraKit.left_orth_alg(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr`, `:polar` or `:svd`, to select [`qr_compact!`](@ref), + [`left_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const LeftOrthViaQR = LeftOrthAlgorithm{:qr} +const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} +const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} + +""" + RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_orth`](@ref). +By default `Kind` is a symbol, which can be either `:lq`, `:polar` or `:svd`. +""" +struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end +RightOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightOrthAlgorithm{Kind, Alg}(alg) + +# Note: specific algorithm selection is handled by `right_orth_alg` in orthnull.jl +RightOrthAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_orth`, define + + MatrixAlgebraKit.right_orth_alg(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq`, `:polar` or `:svd`, to select [`lq_compact!`](@ref), + [`right_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const RightOrthViaLQ = RightOrthAlgorithm{:lq} +const RightOrthViaPolar = RightOrthAlgorithm{:polar} +const RightOrthViaSVD = RightOrthAlgorithm{:svd} + +""" + LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_null`](@ref). +By default `Kind` is a symbol, which can be either `:qr` or `:svd`. +""" +struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end +LeftNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftNullAlgorithm{Kind, Alg}(alg) + +# Note: specific algorithm selection is handled by `left_null_alg` in orthnull.jl +LeftNullAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `left_null`, define + + MatrixAlgebraKit.left_null_alg(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:qr` or `:svd`, to select [`qr_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const LeftNullViaQR = LeftNullAlgorithm{:qr} +const LeftNullViaSVD = LeftNullAlgorithm{:svd} + +""" + RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) + +Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_null`](@ref). +By default `Kind` is a symbol, which can be either `:lq` or `:svd`. +""" +struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm + alg::Alg +end +RightNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightNullAlgorithm{Kind, Alg}(alg) + +# Note: specific algorithm selection is handled by `right_null_alg` in orthnull.jl +RightNullAlgorithm(alg::AbstractAlgorithm) = error( + """ + Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. + To register the algorithm type for `right_null`, define + + MatrixAlgebraKit.right_null_alg(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) + + where `kind` selects the factorization type that will be used. + By default, this is either `:lq` or `:svd`, to select [`lq_null!`](@ref), + [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. + """ +) + +const RightNullViaLQ = RightNullAlgorithm{:lq} +const RightNullViaSVD = RightNullAlgorithm{:svd} diff --git a/src/interface/matrixfunctions.jl b/src/interface/matrixfunctions.jl index a54f3109..ce24652d 100644 --- a/src/interface/matrixfunctions.jl +++ b/src/interface/matrixfunctions.jl @@ -1,358 +1,3 @@ -# TODO: module Decompositions? - -# ================= -# LAPACK ALGORITHMS -# ================= - -# reference for naming LAPACK algorithms: -# https://www.netlib.org/lapack/explore-html/topics.html - -# QR, LQ, QL, RQ Decomposition -# ---------------------------- -""" - LAPACK_HouseholderQR(; blocksize, positive = false, pivoted = false) - -Algorithm type to denote the standard LAPACK algorithm for computing the QR decomposition of -a matrix using Householder reflectors. The specific LAPACK function can be controlled using -the keyword arugments, i.e. `?geqrt` will be chosen if `blocksize > 1`. With -`blocksize == 1`, `?geqrf` will be chosen if `pivoted == false` and `?geqp3` will be chosen -if `pivoted == true`. The keyword `positive = true` can be used to ensure that the diagonal -elements of `R` are non-negative. -""" -@algdef LAPACK_HouseholderQR - -""" - LAPACK_HouseholderLQ(; blocksize, positive = false) - -Algorithm type to denote the standard LAPACK algorithm for computing the LQ decomposition of -a matrix using Householder reflectors. The specific LAPACK function can be controlled using -the keyword arugments, i.e. `?gelqt` will be chosen if `blocksize > 1` or `?gelqf` will be -chosen if `blocksize == 1`. The keyword `positive = true` can be used to ensure that the diagonal -elements of `L` are non-negative. -""" -@algdef LAPACK_HouseholderLQ - -""" - GLA_HouseholderQR(; positive = false) - -Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the QR decomposition -of a matrix using Householder reflectors. Currently, only `blocksize = 1` and `pivoted == false` -are supported. The keyword `positive = true` can be used to ensure that the diagonal elements -of `R` are non-negative. -""" -@algdef GLA_HouseholderQR - -# TODO: -@algdef LAPACK_HouseholderQL -@algdef LAPACK_HouseholderRQ - -# General Eigenvalue Decomposition -# ------------------------------- -""" - LAPACK_Simple(; fixgauge::Bool = true) - -Algorithm type to denote the simple LAPACK driver for computing the Schur or non-Hermitian -eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_Simple - -""" - LAPACK_Expert(; fixgauge::Bool = true) - -Algorithm type to denote the expert LAPACK driver for computing the Schur or non-Hermitian -eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_Expert - -const LAPACK_EigAlgorithm = Union{LAPACK_Simple, LAPACK_Expert} - -""" - GS_QRIteration() - -Algorithm type to denote the GenericSchur.jl implementation for computing the -eigenvalue decomposition of a non-Hermitian matrix. -""" -@algdef GS_QRIteration - -# Hermitian Eigenvalue Decomposition -# ---------------------------------- -""" - LAPACK_QRIteration(; fixgauge::Bool = true) - -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_QRIteration - -""" - LAPACK_Bisection(; fixgauge::Bool = true) - -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_Bisection - -""" - LAPACK_DivideAndConquer(; fixgauge::Bool = true) - -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_DivideAndConquer - -""" - LAPACK_MultipleRelativelyRobustRepresentations(; fixgauge::Bool = true) - -Algorithm type to denote the LAPACK driver for computing the eigenvalue decomposition of a -Hermitian matrix using the Multiple Relatively Robust Representations algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_MultipleRelativelyRobustRepresentations - -const LAPACK_EighAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_MultipleRelativelyRobustRepresentations, -} - -""" - GLA_QRIteration(; fixgauge::Bool = true) - -Algorithm type to denote the GenericLinearAlgebra.jl implementation for computing the -eigenvalue decomposition of a Hermitian matrix, or the singular value decomposition of -a general matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef GLA_QRIteration - -# Singular Value Decomposition -# ---------------------------- -""" - LAPACK_Jacobi(; fixgauge::Bool = true) - -Algorithm type to denote the LAPACK driver for computing the singular value decomposition of -a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular vectors, -see also [`gaugefix!`](@ref). -""" -@algdef LAPACK_Jacobi - -const LAPACK_SVDAlgorithm = Union{ - LAPACK_QRIteration, - LAPACK_Bisection, - LAPACK_DivideAndConquer, - LAPACK_Jacobi, -} - -# ========================= -# Polar decompositions -# ========================= -""" - PolarViaSVD(svd_alg) - -Algorithm for computing the polar decomposition of a matrix `A` via the singular value -decomposition (SVD) of `A`. The `svd_alg` argument specifies the SVD algorithm to use. -""" -struct PolarViaSVD{SVDAlg} <: AbstractAlgorithm - svd_alg::SVDAlg -end - -""" - PolarNewton(; maxiter = 10, tol = defaulttol(A)) - -Algorithm for computing the polar decomposition of a matrix `A` via -scaled Newton iteration, with a maximum of `maxiter` iterations and -until convergence up to tolerance `tol`. -""" -@algdef PolarNewton - -# ========================= -# Varia -# ========================= -""" - DiagonalAlgorithm(; kwargs...) - -Algorithm type to denote a native Julia implementation of the decompositions making use of -the diagonal structure of the input and outputs. -""" -@algdef DiagonalAlgorithm - -""" - LQViaTransposedQR(qr_alg) - -Algorithm type to denote finding the LQ decomposition of `A` by computing the QR decomposition of `Aᵀ`. -The `qr_alg` specifies which QR-decomposition implementation to use. -""" -struct LQViaTransposedQR{A <: AbstractAlgorithm} <: AbstractAlgorithm - qr_alg::A -end -function Base.show(io::IO, alg::LQViaTransposedQR) - print(io, "LQViaTransposedQR(") - _show_alg(io, alg.qr_alg) - return print(io, ")") -end - -# ========================= -# CUSOLVER ALGORITHMS -# ========================= -""" - CUSOLVER_HouseholderQR(; positive = false) - -Algorithm type to denote the standard CUSOLVER algorithm for computing the QR decomposition of -a matrix using Householder reflectors. The keyword `positive = true` can be used to ensure that -the diagonal elements of `R` are non-negative. -""" -@algdef CUSOLVER_HouseholderQR - -""" - CUSOLVER_QRIteration(; fixgauge::Bool = true) - -Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef CUSOLVER_QRIteration - -""" - CUSOLVER_SVDPolar(; fixgauge::Bool = true) - -Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of -a general matrix by using Halley's iterative algorithm to compute the polar decompositon, -followed by the hermitian eigenvalue decomposition of the positive definite factor. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). -""" -@algdef CUSOLVER_SVDPolar - -""" - CUSOLVER_Jacobi(; fixgauge::Bool = true) - -Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of -a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). -""" -@algdef CUSOLVER_Jacobi - -""" - CUSOLVER_Randomized(; k, p, niters) - -Algorithm type to denote the CUSOLVER driver for computing the singular value decomposition of -a general matrix using the randomized SVD algorithm. Here, `k` denotes the number of singular -values that should be computed, therefore requiring `k <= min(size(A))`. This method is accurate -for small values of `k` compared to the size of the input matrix, where the accuracy can be -improved by increasing `p`, the number of additional values used for oversampling, -and `niters`, the number of iterations the solver uses, at the cost of increasing the runtime. - -See also the [CUSOLVER documentation](https://docs.nvidia.com/cuda/cusolver/index.html#cusolverdnxgesvdr) -for more information. -""" -@algdef CUSOLVER_Randomized - -does_truncate(::TruncatedAlgorithm{<:CUSOLVER_Randomized}) = true - -""" - CUSOLVER_Simple(; fixgauge::Bool = true) - -Algorithm type to denote the simple CUSOLVER driver for computing the non-Hermitian -eigenvalue decomposition of a matrix. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigenvectors, -see also [`gaugefix!`](@ref). -""" -@algdef CUSOLVER_Simple - -const CUSOLVER_EigAlgorithm = Union{CUSOLVER_Simple} - -""" - CUSOLVER_DivideAndConquer(; fixgauge::Bool = true) - -Algorithm type to denote the CUSOLVER driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef CUSOLVER_DivideAndConquer - -const CUSOLVER_SVDAlgorithm = Union{ - CUSOLVER_QRIteration, CUSOLVER_SVDPolar, CUSOLVER_Jacobi, CUSOLVER_Randomized, -} - -# ========================= -# ROCSOLVER ALGORITHMS -# ========================= -""" - ROCSOLVER_HouseholderQR(; positive = false) - -Algorithm type to denote the standard ROCSOLVER algorithm for computing the QR decomposition of -a matrix using Householder reflectors. The keyword `positive=true` can be used to ensure that -the diagonal elements of `R` are non-negative. -""" -@algdef ROCSOLVER_HouseholderQR - -""" - ROCSOLVER_QRIteration(; fixgauge::Bool = true) - -Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -QR Iteration algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef ROCSOLVER_QRIteration - -""" - ROCSOLVER_Jacobi(; fixgauge::Bool = true) - -Algorithm type to denote the ROCSOLVER driver for computing the singular value decomposition of -a general matrix using the Jacobi algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the singular -vectors, see also [`gaugefix!`](@ref). -""" -@algdef ROCSOLVER_Jacobi - -""" - ROCSOLVER_Bisection(; fixgauge::Bool = true) - -Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Bisection algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef ROCSOLVER_Bisection - -""" - ROCSOLVER_DivideAndConquer(; fixgauge::Bool = true) - -Algorithm type to denote the ROCSOLVER driver for computing the eigenvalue decomposition of a -Hermitian matrix, or the singular value decomposition of a general matrix using the -Divide and Conquer algorithm. -The `fixgauge` keyword can be used to toggle whether or not to fix the gauge of the eigen or -singular vectors, see also [`gaugefix!`](@ref). -""" -@algdef ROCSOLVER_DivideAndConquer - -const ROCSOLVER_SVDAlgorithm = Union{ROCSOLVER_QRIteration, ROCSOLVER_Jacobi} - # ================================ # EXPONENTIAL ALGORITHMS # ================================ @@ -392,143 +37,3 @@ function Base.show(io::IO, alg::MatrixFunctionViaEig) _show_alg(io, alg.eig_alg) return print(io, ")") end - -# Various consts and unions -# ------------------------- - -const GPU_Simple = Union{CUSOLVER_Simple} -const GPU_EigAlgorithm = Union{GPU_Simple} -const GPU_QRIteration = Union{CUSOLVER_QRIteration, ROCSOLVER_QRIteration} -const GPU_Jacobi = Union{CUSOLVER_Jacobi, ROCSOLVER_Jacobi} -const GPU_DivideAndConquer = Union{CUSOLVER_DivideAndConquer, ROCSOLVER_DivideAndConquer} -const GPU_Bisection = Union{ROCSOLVER_Bisection} -const GPU_EighAlgorithm = Union{ - GPU_QRIteration, GPU_Jacobi, GPU_DivideAndConquer, GPU_Bisection, -} -const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} - -const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} -const GPU_Randomized = Union{CUSOLVER_Randomized} - -const QRAlgorithms = Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} -const LQAlgorithms = Union{LAPACK_HouseholderLQ, LQViaTransposedQR} -const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} -const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} - -# ================================ -# ORTHOGONALIZATION ALGORITHMS -# ================================ - -""" - LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) - -Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_orth`](@ref). -By default `Kind` is a symbol, which can be either `:qr`, `:polar` or `:svd`. -""" -struct LeftOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end -LeftOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftOrthAlgorithm{Kind, Alg}(alg) - -# Note: specific algorithm selection is handled by `left_orth_alg` in orthnull.jl -LeftOrthAlgorithm(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `left_orth` algorithm type `$(typeof(alg))`. - To register the algorithm type for `left_orth`, define - - MatrixAlgebraKit.left_orth_alg(alg::CustomAlgorithm) = LeftOrthAlgorithm{kind}(alg) - - where `kind` selects the factorization type that will be used. - By default, this is either `:qr`, `:polar` or `:svd`, to select [`qr_compact!`](@ref), - [`left_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. - """ -) - -const LeftOrthViaQR = LeftOrthAlgorithm{:qr} -const LeftOrthViaPolar = LeftOrthAlgorithm{:polar} -const LeftOrthViaSVD = LeftOrthAlgorithm{:svd} - -""" - RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) - -Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_orth`](@ref). -By default `Kind` is a symbol, which can be either `:lq`, `:polar` or `:svd`. -""" -struct RightOrthAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end -RightOrthAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightOrthAlgorithm{Kind, Alg}(alg) - -# Note: specific algorithm selection is handled by `right_orth_alg` in orthnull.jl -RightOrthAlgorithm(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `right_orth` algorithm type `$(typeof(alg))`. - To register the algorithm type for `right_orth`, define - - MatrixAlgebraKit.right_orth_alg(alg::CustomAlgorithm) = RightOrthAlgorithm{kind}(alg) - - where `kind` selects the factorization type that will be used. - By default, this is either `:lq`, `:polar` or `:svd`, to select [`lq_compact!`](@ref), - [`right_polar!`](@ref), [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. - """ -) - -const RightOrthViaLQ = RightOrthAlgorithm{:lq} -const RightOrthViaPolar = RightOrthAlgorithm{:polar} -const RightOrthViaSVD = RightOrthAlgorithm{:svd} - -""" - LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) - -Wrapper type to denote the `Kind` of factorization that is used as a backend for [`left_null`](@ref). -By default `Kind` is a symbol, which can be either `:qr` or `:svd`. -""" -struct LeftNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end -LeftNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = LeftNullAlgorithm{Kind, Alg}(alg) - -# Note: specific algorithm selection is handled by `left_null_alg` in orthnull.jl -LeftNullAlgorithm(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `left_null` algorithm type `$(typeof(alg))`. - To register the algorithm type for `left_null`, define - - MatrixAlgebraKit.left_null_alg(alg::CustomAlgorithm) = LeftNullAlgorithm{kind}(alg) - - where `kind` selects the factorization type that will be used. - By default, this is either `:qr` or `:svd`, to select [`qr_null!`](@ref), - [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. - """ -) - -const LeftNullViaQR = LeftNullAlgorithm{:qr} -const LeftNullViaSVD = LeftNullAlgorithm{:svd} - -""" - RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm}(alg) - -Wrapper type to denote the `Kind` of factorization that is used as a backend for [`right_null`](@ref). -By default `Kind` is a symbol, which can be either `:lq` or `:svd`. -""" -struct RightNullAlgorithm{Kind, Alg <: AbstractAlgorithm} <: AbstractAlgorithm - alg::Alg -end -RightNullAlgorithm{Kind}(alg::Alg) where {Kind, Alg <: AbstractAlgorithm} = RightNullAlgorithm{Kind, Alg}(alg) - -# Note: specific algorithm selection is handled by `right_null_alg` in orthnull.jl -RightNullAlgorithm(alg::AbstractAlgorithm) = error( - """ - Unkown or invalid `right_null` algorithm type `$(typeof(alg))`. - To register the algorithm type for `right_null`, define - - MatrixAlgebraKit.right_null_alg(alg::CustomAlgorithm) = RightNullAlgorithm{kind}(alg) - - where `kind` selects the factorization type that will be used. - By default, this is either `:lq` or `:svd`, to select [`lq_null!`](@ref), - [`svd_compact!`](@ref) or [`svd_trunc!`](@ref) respectively. - """ -) - -const RightNullViaLQ = RightNullAlgorithm{:lq} -const RightNullViaSVD = RightNullAlgorithm{:svd} diff --git a/src/matrixfunctions.jl b/src/matrixfunctions.jl deleted file mode 100644 index 8b137891..00000000 --- a/src/matrixfunctions.jl +++ /dev/null @@ -1 +0,0 @@ - From cf98bd4df962c8b1286a759e3c705ab816386440 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 15:03:41 -0500 Subject: [PATCH 18/20] avoid running non-GPU tests through buildkite --- test/runtests.jl | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index a9151a1d..e17700d9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -115,24 +115,28 @@ if AMDGPU.functional() end using GenericLinearAlgebra -@safetestset "QR / LQ Decomposition" begin - include("genericlinearalgebra/qr.jl") - include("genericlinearalgebra/lq.jl") -end -@safetestset "Singular Value Decomposition" begin - include("genericlinearalgebra/svd.jl") -end -@safetestset "Hermitian Eigenvalue Decomposition" begin - include("genericlinearalgebra/eigh.jl") -end -@safetestset "Exponential" begin - include("genericlinearalgebra/exponential.jl") +if !is_buildkite + @safetestset "QR / LQ Decomposition" begin + include("genericlinearalgebra/qr.jl") + include("genericlinearalgebra/lq.jl") + end + @safetestset "Singular Value Decomposition" begin + include("genericlinearalgebra/svd.jl") + end + @safetestset "Hermitian Eigenvalue Decomposition" begin + include("genericlinearalgebra/eigh.jl") + end + @safetestset "Exponential" begin + include("genericlinearalgebra/exponential.jl") + end end using GenericSchur -@safetestset "General Eigenvalue Decomposition" begin - include("genericschur/eig.jl") -end -@safetestset "Exponential" begin - include("genericschur/exponential.jl") +if !is_buildkite + @safetestset "General Eigenvalue Decomposition" begin + include("genericschur/eig.jl") + end + @safetestset "Exponential" begin + include("genericschur/exponential.jl") + end end From 1536eb4150788ba2292696c093dcbe6da0f402a6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 15:33:23 -0500 Subject: [PATCH 19/20] correct wrong in-place assumptions --- test/genericlinearalgebra/exponential.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index c0720ba8..7166a412 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -16,9 +16,7 @@ GenericFloats = (BigFloat, Complex{BigFloat}) D, V = @constinferred eigh_full(A) algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) @testset "algorithm $alg" for alg in algs - expA = similar(A) - - @constinferred exponential!(copy(A), expA; alg) + expA = @constinferred exponential!(copy(A), expA; alg) expA2 = @constinferred exponential(A; alg) @test expA2 ≈ expA @@ -39,9 +37,7 @@ using GenericSchur D, V = @constinferred eigh_full(A) algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) @testset "algorithm $alg" for alg in algs - expiτA = similar(complex(A)) - - @constinferred exponentiali!(τ, copy(A), expiτA; alg) + expiτA = @constinferred exponentiali!(τ, copy(A), expiτA; alg) expiτA2 = @constinferred exponentiali(τ, A; alg) @test expiτA2 ≈ expiτA From 349800ea54a538467c470a740cc346509fce5ffb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 3 Dec 2025 15:36:42 -0500 Subject: [PATCH 20/20] fixes part II --- test/genericlinearalgebra/exponential.jl | 4 ++-- test/genericschur/exponential.jl | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/test/genericlinearalgebra/exponential.jl b/test/genericlinearalgebra/exponential.jl index 7166a412..c8a09508 100644 --- a/test/genericlinearalgebra/exponential.jl +++ b/test/genericlinearalgebra/exponential.jl @@ -16,7 +16,7 @@ GenericFloats = (BigFloat, Complex{BigFloat}) D, V = @constinferred eigh_full(A) algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) @testset "algorithm $alg" for alg in algs - expA = @constinferred exponential!(copy(A), expA; alg) + expA = @constinferred exponential!(copy(A); alg) expA2 = @constinferred exponential(A; alg) @test expA2 ≈ expA @@ -37,7 +37,7 @@ using GenericSchur D, V = @constinferred eigh_full(A) algs = (MatrixFunctionViaEigh(GLA_QRIteration()),) @testset "algorithm $alg" for alg in algs - expiτA = @constinferred exponentiali!(τ, copy(A), expiτA; alg) + expiτA = @constinferred exponentiali!(τ, copy(A); alg) expiτA2 = @constinferred exponentiali(τ, A; alg) @test expiτA2 ≈ expiτA diff --git a/test/genericschur/exponential.jl b/test/genericschur/exponential.jl index 7e272a5e..9bbfa5d9 100644 --- a/test/genericschur/exponential.jl +++ b/test/genericschur/exponential.jl @@ -16,9 +16,7 @@ GenericFloats = (BigFloat, Complex{BigFloat}) algs = (MatrixFunctionViaEig(GS_QRIteration()),) expA_LA = @constinferred exponential(A) @testset "algorithm $alg" for alg in algs - expA = similar(A) - - @constinferred exponential!(copy(A), expA) + expA = @constinferred exponential!(copy(A)) expA2 = @constinferred exponential(A; alg = alg) @test expA ≈ expA_LA @test expA2 ≈ expA @@ -38,9 +36,7 @@ end D, V = @constinferred eig_full(A) algs = (MatrixFunctionViaEig(GS_QRIteration()),) @testset "algorithm $alg" for alg in algs - expiτA = similar(complex(A)) - - @constinferred exponentiali!(τ, copy(A), expiτA) + expiτA = @constinferred exponentiali!(τ, copy(A)) expiτA2 = @constinferred exponentiali(τ, A; alg) @test expiτA2 ≈ expiτA