Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 74 additions & 72 deletions src/yalapack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ module YALAPACK # Yet another lapack wrapper

using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK,
LAPACKException, SingularException, PosDefException, checksquare, chkstride1,
require_one_based_indexing, triu!, isposdef, adjoint!
require_one_based_indexing, triu!, isposdef, adjoint!, rmul!

using LinearAlgebra.BLAS: @blasfunc, libblastrampoline
using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror
Expand All @@ -20,66 +20,66 @@ const BlasMat{T <: BlasFloat} = StridedMatrix{T}
# type alias for matrices that are possibly supported by YALAPACK, after conversion
const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}}

# LU factorisation
for (getrf, getrs, elty) in (
(:dgetrf_, :dgetrs_, :Float64),
(:sgetrf_, :sgetrs_, :Float32),
(:zgetrf_, :zgetrs_, :ComplexF64),
(:cgetrf_, :cgetrs_, :ComplexF32),
)
@eval begin
function getrf!(
A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt};
check::Bool = true
)
require_one_based_indexing(A, ipiv)
chkstride1(A, ipiv)
chkfinite(A)
m, n = size(A)
# LU factorisation (currently unused in MatrixAlgebraKit)
# for (getrf, getrs, elty) in (
# (:dgetrf_, :dgetrs_, :Float64),
# (:sgetrf_, :sgetrs_, :Float32),
# (:zgetrf_, :zgetrs_, :ComplexF64),
# (:cgetrf_, :cgetrs_, :ComplexF32),
# )
# @eval begin
# function getrf!(
# A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt};
# check::Bool = true
# )
# require_one_based_indexing(A, ipiv)
# chkstride1(A, ipiv)
# chkfinite(A)
# m, n = size(A)

lda = max(1, stride(A, 2))
info = Ref{BlasInt}()
ccall(
(@blasfunc($getrf), libblastrampoline), Cvoid,
(
Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
),
m, n, A, lda, ipiv, info
)
chkargsok(info[])
return A, ipiv, info[] #Error code is stored in LU factorization type
end
function getrs!(
trans::AbstractChar, A::AbstractMatrix{$elty},
ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty}
)
require_one_based_indexing(A, ipiv, B)
chktrans(trans)
chkstride1(A, B, ipiv)
n = checksquare(A)
if n != size(B, 1)
throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n"))
end
if n != length(ipiv)
throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n"))
end
nrhs = size(B, 2)
info = Ref{BlasInt}()
ccall(
(@blasfunc($getrs), libblastrampoline), Cvoid,
(
Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong,
),
trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B,
max(1, stride(B, 2)), info, 1
)
chklapackerror(info[])
return B
end
end
end
# lda = max(1, stride(A, 2))
# info = Ref{BlasInt}()
# ccall(
# (@blasfunc($getrf), libblastrampoline), Cvoid,
# (
# Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty},
# Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt},
# ),
# m, n, A, lda, ipiv, info
# )
# chkargsok(info[])
# return A, ipiv, info[] #Error code is stored in LU factorization type
# end
# function getrs!(
# trans::AbstractChar, A::AbstractMatrix{$elty},
# ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty}
# )
# require_one_based_indexing(A, ipiv, B)
# chktrans(trans)
# chkstride1(A, B, ipiv)
# n = checksquare(A)
# if n != size(B, 1)
# throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n"))
# end
# if n != length(ipiv)
# throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n"))
# end
# nrhs = size(B, 2)
# info = Ref{BlasInt}()
# ccall(
# (@blasfunc($getrs), libblastrampoline), Cvoid,
# (
# Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
# Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong,
# ),
# trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B,
# max(1, stride(B, 2)), info, 1
# )
# chklapackerror(info[])
# return B
# end
# end
# end

# LQ, RQ, QL, and QR factorisation
const DEFAULT_QR_BLOCKSIZE = 36
Expand Down Expand Up @@ -451,16 +451,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr
k = min(mA, nA)

if side == 'L' && mC != mA
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the first dimension of A, $mA"))
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the first dimension of A, $mA"))
end
if side == 'R' && nC != mA
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the first dimension of A, $mA"))
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the first dimension of A, $mA"))
end
if side == 'L' && k > mC
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC"))
end
if side == 'R' && k > nC
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n"))
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC"))
end
lda = max(1, stride(A, 2))
ldc = max(1, stride(C, 2))
Expand Down Expand Up @@ -503,16 +503,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr
k = min(mA, nA)

if side == 'L' && mC != nA
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $nA"))
throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the second dimension of A, $nA"))
end
if side == 'R' && nC != nA
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the second dimension of A, $nA"))
throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the second dimension of A, $nA"))
end
if side == 'L' && k > mC
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m"))
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC"))
end
if side == 'R' && k > nC
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n"))
throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC"))
end
lda = max(1, stride(A, 2))
ldc = max(1, stride(C, 2))
Expand Down Expand Up @@ -1170,6 +1170,7 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in
n = checksquare(A)
chkuplofinite(A, uplo)
if haskey(kwargs, :irange)
irange = convert(UnitRange{Int}, kwargs[:irange])
il = first(irange)
iu = last(irange)
vl = vu = zero($relty)
Expand Down Expand Up @@ -2143,6 +2144,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
m, n = size(A)
minmn = min(m, n)
if haskey(kwargs, :irange)
irange = convert(UnitRange{Int}, kwargs[:irange])
il = first(irange)
iu = last(irange)
vl = vu = zero($relty)
Expand Down Expand Up @@ -2276,15 +2278,15 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
end
end
length(S) == n ||
throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))"))
throw(DimensionMismatch("length mismatch between A ($n) and S ($(length(S)))"))

lda = max(1, stride(A, 2))
mv = Ref{BlasInt}() # unused
if jobv == 'V'
if U !== A
V = view(U, 1:n, 1:n) # use U as V storage
else
V = view(similar(V), 1:n, 1:n)
V = view(similar(Vᴴ), 1:n, 1:n)
end
else
V = Vᴴ # doesn't matter, V is not used
Expand Down Expand Up @@ -2342,12 +2344,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in
if cmplx
if !isone(rwork[1])
@warn "singular values might have underflowed or overflowed"
LinearAlgebra.rmul!(S, rwork[1])
rmul!(S, rwork[1])
end
else
if !isone(work[1])
@warn "singular values might have underflowed or overflowed"
LinearAlgebra.rmul!(S, work[1])
rmul!(S, work[1])
end
end
if jobu == 'U' && U !== A
Expand Down
4 changes: 2 additions & 2 deletions test/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function stabilize_eigvals!(D::AbstractVector)
end
n = maximum(p)
# rescale eigenvalues so that they lie on distinct radii in the complex plane
# that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n
radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
# that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n
radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n
for i in 1:length(D)
D[i] = sign(D[i]) * radii[p[i]]
end
Expand Down