Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
Expand Down Expand Up @@ -128,10 +128,17 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedROCMatrix, B::Strid
end

MatrixAlgebraKit.ishermitian_exact(A::StridedROCMatrix) = all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) = all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_exact(A::StridedROCMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedROCVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_approx(A::StridedROCMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)

function MatrixAlgebraKit._avgdiff!(A::StridedROCMatrix, B::StridedROCMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
20 changes: 14 additions & 6 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
Expand Down Expand Up @@ -134,11 +134,19 @@ function MatrixAlgebraKit._project_hermitian_diag!(A::StridedCuMatrix, B::Stride
return nothing
end

MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) = all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== adjoint(A.diag))

MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) = all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} = all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.ishermitian_exact(A::StridedCuMatrix) =
all(A .== adjoint(A))
MatrixAlgebraKit.ishermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== adjoint(A.diag))
MatrixAlgebraKit.ishermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.ishermitian_approx(A::Any; kwargs...)

MatrixAlgebraKit.isantihermitian_exact(A::StridedCuMatrix) =
all(A .== -adjoint(A))
MatrixAlgebraKit.isantihermitian_exact(A::Diagonal{T, <:StridedCuVector{T}}) where {T} =
all(A.diag .== -adjoint(A.diag))
MatrixAlgebraKit.isantihermitian_approx(A::StridedCuMatrix; kwargs...) =
@invoke MatrixAlgebraKit.isantihermitian_approx(A::Any; kwargs...)

function MatrixAlgebraKit._avgdiff!(A::StridedCuMatrix, B::StridedCuMatrix)
axes(A) == axes(B) || throw(DimensionMismatch())
Expand Down
10 changes: 10 additions & 0 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ function default_pullback_gaugetol(a)
n = norm(a, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end

"""
default_hermitian_tol(A)

Default tolerance for deciding to warn if the provided `A` is not hermitian.
"""
function default_hermitian_tol(A)
n = norm(A, Inf)
return eps(eltype(n))^(3 / 4) * max(n, one(n))
end
71 changes: 62 additions & 9 deletions src/common/matrixproperties.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,31 +69,32 @@ is_right_isometric(A; kwargs...) = is_left_isometric(A'; kwargs...)
Test whether a linear map is Hermitian, i.e. `A = A'`.
The `isapprox_kwargs` can be used to control the tolerances of the equality.
"""
function ishermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
function ishermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
if iszero(atol) && iszero(rtol)
return ishermitian_exact(A; kwargs...)
else
return 2 * norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
return ishermitian_approx(A; atol, rtol, kwargs...)
end
end
function ishermitian_exact(A)
return A == A'
end
function ishermitian_exact(A::StridedMatrix; kwargs...)
return strided_ishermitian_exact(A, Val(false); kwargs...)

ishermitian_exact(A) = A == A'
ishermitian_exact(A::StridedMatrix; kwargs...) = strided_ishermitian_exact(A, Val(false); kwargs...)
function ishermitian_approx(A; atol, rtol, kwargs...)
return norm(project_antihermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
end
ishermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(false); kwargs...)

"""
isantihermitian(A; isapprox_kwargs...)

Test whether a linear map is anti-Hermitian, i.e. `A = -A'`.
The `isapprox_kwargs` can be used to control the tolerances of the equality.
"""
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, norm = LinearAlgebra.norm, kwargs...)
function isantihermitian(A; atol::Real = 0, rtol::Real = 0, kwargs...)
if iszero(atol) && iszero(rtol)
return isantihermitian_exact(A; kwargs...)
else
return 2 * norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
return isantihermitian_approx(A; atol, rtol, kwargs...)
end
end
function isantihermitian_exact(A)
Expand All @@ -102,6 +103,10 @@ end
function isantihermitian_exact(A::StridedMatrix; kwargs...)
return strided_ishermitian_exact(A, Val(true); kwargs...)
end
function isantihermitian_approx(A; atol, rtol, kwargs...)
return norm(project_hermitian(A; kwargs...)) ≤ max(atol, rtol * norm(A))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to use the norm provided in kwargs here? Or do we not want to support this option (would be fine with me, I am very happy with the standard Frobenius norm).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also above for ishermitian_approx.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I want to support that option for now, because I don't think I really want to deal with the strided_ishermitian_approx implementation that would be required for that.

end
isantihermitian_approx(A::StridedMatrix; kwargs...) = strided_ishermitian_approx(A, Val(true); kwargs...)

# blocked implementation of exact checks for strided matrices
# -----------------------------------------------------------
Expand Down Expand Up @@ -139,3 +144,51 @@ function _ishermitian_exact_offdiag(Al, Au, ::Val{anti}) where {anti}
end
return true
end


function strided_ishermitian_approx(
A::AbstractMatrix, anti::Val;
blocksize = 32, atol::Real = default_hermitian_tol(A), rtol::Real = 0
)
n = size(A, 1)
ϵ² = abs2(zero(eltype(A)))
ϵ²max = oftype(ϵ², rtol > 0 ? max(atol, rtol * norm(A)) : atol)^2
for j in 1:blocksize:n
jb = min(blocksize, n - j + 1)
ϵ² += _ishermitian_approx_diag(view(A, j:(j + jb - 1), j:(j + jb - 1)), anti)
ϵ² < ϵ²max || return false
for i in 1:blocksize:(j - 1)
ib = blocksize
ϵ² += 2 * _ishermitian_approx_offdiag(
view(A, i:(i + ib - 1), j:(j + jb - 1)),
view(A, j:(j + jb - 1), i:(i + ib - 1)),
anti
)
ϵ² < ϵ²max || return false
end
end
return true
end

function _ishermitian_approx_diag(A, ::Val{anti}) where {anti}
n = size(A, 1)
ϵ² = abs2(zero(eltype(A)))
@inbounds for j in 1:n
@simd for i in 1:j
val = _project_hermitian(A[i, j], A[j, i], !anti)
ϵ² += abs2(val) * (1 + Int(i < j))
end
end
return ϵ²
end
function _ishermitian_approx_offdiag(Al, Au, ::Val{anti}) where {anti}
m, n = size(Al) # == reverse(size(Al))
ϵ² = abs2(zero(eltype(Al)))
@inbounds for j in 1:n
@simd for i in 1:m
val = _project_hermitian(Al[i, j], Au[j, i], !anti)
ϵ² += abs2(val)
end
end
return ϵ²
end
38 changes: 24 additions & 14 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,40 @@ copy_input(::typeof(eigh_trunc), A) = copy_input(eigh_full, A)

copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A)
check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A)))
function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
ishermitian(A; atol, rtol) ||
throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix."))
return nothing
end

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
check_hermitian(A, alg)
D, V = DV
m = size(A, 1)
@assert D isa Diagonal && V isa AbstractMatrix
@check_size(D, (m, m))
@check_scalar(D, A, real)
@check_size(V, (m, m))
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::AbstractAlgorithm)
check_hermitian(A, alg)
m = size(A, 1)
@assert D isa AbstractVector
@check_size(D, (n,))
@check_size(D, (m,))
@check_scalar(D, A, real)
return nothing
end

function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::DiagonalAlgorithm)
check_hermitian(A, alg)
@assert isdiag(A)
m = size(A, 1)
D, V = DV
@assert D isa Diagonal && V isa Diagonal
@check_size(D, (m, m))
Expand All @@ -40,12 +50,12 @@ function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, ::DiagonalAlgo
@check_scalar(V, A)
return nothing
end
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
@assert (eltype(A) <: Real && issymmetric(A)) || ishermitian(A)
function check_input(::typeof(eigh_vals!), A::AbstractMatrix, D, alg::DiagonalAlgorithm)
check_hermitian(A, alg)
@assert isdiag(A)
m = size(A, 1)
@assert D isa AbstractVector
@check_size(D, (n,))
@check_size(D, (m,))
@check_scalar(D, A, real)
return nothing
end
Expand Down
8 changes: 5 additions & 3 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::V
return B
end

@inline function _project_hermitian(Aij::Number, Aji::Number, anti::Bool)
return anti ? (Aij - Aji') / 2 : (Aij + Aji') / 2
end
function _project_hermitian_offdiag!(
Au::AbstractMatrix, Al::AbstractMatrix, Bu::AbstractMatrix, Bl::AbstractMatrix, ::Val{anti}
) where {anti}

m, n = size(Au) # == reverse(size(Au))
return @inbounds for j in 1:n
@simd for i in 1:m
val = anti ? (Au[i, j] - adjoint(Al[j, i])) / 2 : (Au[i, j] + adjoint(Al[j, i])) / 2
val = _project_hermitian(Au[i, j], Al[j, i], anti)
Bu[i, j] = val
aval = adjoint(val)
Bl[j, i] = anti ? -aval : aval
Expand All @@ -104,7 +106,7 @@ function _project_hermitian_diag!(A::AbstractMatrix, B::AbstractMatrix, ::Val{an
n = size(A, 1)
@inbounds for j in 1:n
@simd for i in 1:(j - 1)
val = anti ? (A[i, j] - adjoint(A[j, i])) / 2 : (A[i, j] + adjoint(A[j, i])) / 2
val = _project_hermitian(A[i, j], A[j, i], anti)
B[i, j] = val
aval = adjoint(val)
B[j, i] = anti ? -aval : aval
Expand Down
28 changes: 5 additions & 23 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module YALAPACK # Yet another lapack wrapper

using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
LAPACKException, SingularException, PosDefException, checksquare, chkstride1,
require_one_based_indexing, triu!, issymmetric, ishermitian, isposdef, adjoint!
require_one_based_indexing, triu!, isposdef, adjoint!

using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror
Expand Down Expand Up @@ -984,16 +984,12 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
A::AbstractMatrix{$elty},
W::AbstractVector{$relty} = similar(A, $relty, size(A, 1)),
V::AbstractMatrix{$elty} = A;
uplo::AbstractChar = 'U'
uplo::AbstractChar = 'U',
kwargs...
) # shouldn't matter but 'U' seems slightly faster than 'L'
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
if $elty <: Real
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
chkuplofinite(A, uplo)
n == length(W) || throw(DimensionMismatch("length mismatch between A and W"))
if length(V) == 0
Expand Down Expand Up @@ -1063,11 +1059,6 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
if $elty <: Real
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
chkuplofinite(A, uplo)
if haskey(kwargs, :irange)
il = first(kwargs[:irange])
Expand Down Expand Up @@ -1175,11 +1166,6 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
if $elty <: Real
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
chkuplofinite(A, uplo)
if haskey(kwargs, :irange)
il = first(irange)
Expand Down Expand Up @@ -1289,16 +1275,12 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
A::AbstractMatrix{$elty},
W::AbstractVector{$relty} = similar(A, $relty, size(A, 1)),
V::AbstractMatrix{$elty} = A;
uplo::AbstractChar = 'U'
uplo::AbstractChar = 'U',
kwargs...
) # shouldn't matter but 'U' seems slightly faster than 'L'
require_one_based_indexing(A, V, W)
chkstride1(A, V, W)
n = checksquare(A)
if $elty <: Real
issymmetric(A) || throw(ArgumentError("A must be symmetric"))
else
ishermitian(A) || throw(ArgumentError("A must be Hermitian"))
end
uplo = 'U' # shouldn't matter but 'U' seems slightly faster than 'L'
chkuplofinite(A, uplo)
n == length(W) || throw(DimensionMismatch("length mismatch between A and W"))
Expand Down
17 changes: 16 additions & 1 deletion test/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using MatrixAlgebraKit
using Test
using TestExtras
using StableRNGs
using LinearAlgebra: LinearAlgebra, Diagonal, norm
using LinearAlgebra: LinearAlgebra, Diagonal, norm, normalize!

const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)

Expand Down Expand Up @@ -43,6 +43,21 @@ const BLASFloats = (Float32, Float64, ComplexF32, ComplexF64)
@test isantihermitian(Ba)
@test Ba ≈ Aa
end

# test approximate error calculation
A = normalize!(randn(rng, T, m, m))
Ah = project_hermitian(A)
Aa = project_antihermitian(A)

Ah_approx = Ah + noisefactor * Aa
ϵ = norm(project_antihermitian(Ah_approx))
@test !ishermitian(Ah_approx; atol = (999 // 1000) * ϵ)
@test ishermitian(Ah_approx; atol = (1001 // 1000) * ϵ)

Aa_approx = Aa + noisefactor * Ah
ϵ = norm(project_hermitian(Aa_approx))
@test !isantihermitian(Aa_approx; atol = (999 // 1000) * ϵ)
@test isantihermitian(Aa_approx; atol = (1001 // 1000) * ϵ)
end

@testset "project_isometric! for T = $T" for T in BLASFloats
Expand Down