From f940c51726ee27b9f4aad0230e2de8a356545095 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Sun, 28 Dec 2025 23:48:22 +0100 Subject: [PATCH] some yalapack fixes and test tryouts --- src/yalapack.jl | 146 ++++++++++++++++++++++++----------------------- test/ad_utils.jl | 4 +- 2 files changed, 76 insertions(+), 74 deletions(-) diff --git a/src/yalapack.jl b/src/yalapack.jl index 18541d41..3d3613cc 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -10,7 +10,7 @@ module YALAPACK # Yet another lapack wrapper using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK, LAPACKException, SingularException, PosDefException, checksquare, chkstride1, - require_one_based_indexing, triu!, isposdef, adjoint! + require_one_based_indexing, triu!, isposdef, adjoint!, rmul! using LinearAlgebra.BLAS: @blasfunc, libblastrampoline using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror @@ -20,66 +20,66 @@ const BlasMat{T <: BlasFloat} = StridedMatrix{T} # type alias for matrices that are possibly supported by YALAPACK, after conversion const MaybeBlasMat = Union{BlasMat, AbstractMatrix{<:Integer}} -# LU factorisation -for (getrf, getrs, elty) in ( - (:dgetrf_, :dgetrs_, :Float64), - (:sgetrf_, :sgetrs_, :Float32), - (:zgetrf_, :zgetrs_, :ComplexF64), - (:cgetrf_, :cgetrs_, :ComplexF32), - ) - @eval begin - function getrf!( - A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt}; - check::Bool = true - ) - require_one_based_indexing(A, ipiv) - chkstride1(A, ipiv) - chkfinite(A) - m, n = size(A) +# LU factorisation (currently unused in MatrixAlgebraKit) +# for (getrf, getrs, elty) in ( +# (:dgetrf_, :dgetrs_, :Float64), +# (:sgetrf_, :sgetrs_, :Float32), +# (:zgetrf_, :zgetrs_, :ComplexF64), +# (:cgetrf_, :cgetrs_, :ComplexF32), +# ) +# @eval begin +# function getrf!( +# A::AbstractMatrix{$elty}, ipiv::AbstractVector{BlasInt}; +# check::Bool = true +# ) +# require_one_based_indexing(A, ipiv) +# chkstride1(A, ipiv) +# chkfinite(A) +# m, n = size(A) - lda = max(1, stride(A, 2)) - info = Ref{BlasInt}() - ccall( - (@blasfunc($getrf), libblastrampoline), Cvoid, - ( - Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, - Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, - ), - m, n, A, lda, ipiv, info - ) - chkargsok(info[]) - return A, ipiv, info[] #Error code is stored in LU factorization type - end - function getrs!( - trans::AbstractChar, A::AbstractMatrix{$elty}, - ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty} - ) - require_one_based_indexing(A, ipiv, B) - chktrans(trans) - chkstride1(A, B, ipiv) - n = checksquare(A) - if n != size(B, 1) - throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n")) - end - if n != length(ipiv) - throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n")) - end - nrhs = size(B, 2) - info = Ref{BlasInt}() - ccall( - (@blasfunc($getrs), libblastrampoline), Cvoid, - ( - Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, - Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong, - ), - trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, - max(1, stride(B, 2)), info, 1 - ) - chklapackerror(info[]) - return B - end - end -end +# lda = max(1, stride(A, 2)) +# info = Ref{BlasInt}() +# ccall( +# (@blasfunc($getrf), libblastrampoline), Cvoid, +# ( +# Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, +# Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}, +# ), +# m, n, A, lda, ipiv, info +# ) +# chkargsok(info[]) +# return A, ipiv, info[] #Error code is stored in LU factorization type +# end +# function getrs!( +# trans::AbstractChar, A::AbstractMatrix{$elty}, +# ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{$elty} +# ) +# require_one_based_indexing(A, ipiv, B) +# chktrans(trans) +# chkstride1(A, B, ipiv) +# n = checksquare(A) +# if n != size(B, 1) +# throw(DimensionMismatch(lazy"B has leading dimension $(size(B,1)), but needs $n")) +# end +# if n != length(ipiv) +# throw(DimensionMismatch(lazy"ipiv has length $(length(ipiv)), but needs to be $n")) +# end +# nrhs = size(B, 2) +# info = Ref{BlasInt}() +# ccall( +# (@blasfunc($getrs), libblastrampoline), Cvoid, +# ( +# Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt}, +# Ptr{BlasInt}, Ptr{$elty}, Ref{BlasInt}, Ptr{BlasInt}, Clong, +# ), +# trans, n, size(B, 2), A, max(1, stride(A, 2)), ipiv, B, +# max(1, stride(B, 2)), info, 1 +# ) +# chklapackerror(info[]) +# return B +# end +# end +# end # LQ, RQ, QL, and QR factorisation const DEFAULT_QR_BLOCKSIZE = 36 @@ -451,16 +451,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr k = min(mA, nA) if side == 'L' && mC != mA - throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the first dimension of A, $mA")) + throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the first dimension of A, $mA")) end if side == 'R' && nC != mA - throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the first dimension of A, $mA")) + throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the first dimension of A, $mA")) end if side == 'L' && k > mC - throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m")) + throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC")) end if side == 'R' && k > nC - throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n")) + throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC")) end lda = max(1, stride(A, 2)) ldc = max(1, stride(C, 2)) @@ -503,16 +503,16 @@ for (gemqr, gemlq, ungqr, unglq, ungql, ungrq, unmqr, unmlq, unmql, unmrq, gemqr k = min(mA, nA) if side == 'L' && mC != nA - throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $m, must equal the second dimension of A, $nA")) + throw(DimensionMismatch(lazy"for a left-sided multiplication, the first dimension of C, $mC, must equal the second dimension of A, $nA")) end if side == 'R' && nC != nA - throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $n, must equal the second dimension of A, $nA")) + throw(DimensionMismatch(lazy"for a right-sided multiplication, the second dimension of C, $nC, must equal the second dimension of A, $nA")) end if side == 'L' && k > mC - throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $m")) + throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= m = $mC")) end if side == 'R' && k > nC - throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $n")) + throw(DimensionMismatch(lazy"invalid number of reflectors: k = $k should be <= n = $nC")) end lda = max(1, stride(A, 2)) ldc = max(1, stride(C, 2)) @@ -1170,6 +1170,7 @@ for (heev, heevx, heevr, heevd, hegvd, elty, relty) in n = checksquare(A) chkuplofinite(A, uplo) if haskey(kwargs, :irange) + irange = convert(UnitRange{Int}, kwargs[:irange]) il = first(irange) iu = last(irange) vl = vu = zero($relty) @@ -2143,6 +2144,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in m, n = size(A) minmn = min(m, n) if haskey(kwargs, :irange) + irange = convert(UnitRange{Int}, kwargs[:irange]) il = first(irange) iu = last(irange) vl = vu = zero($relty) @@ -2276,7 +2278,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in end end length(S) == n || - throw(DimensionMismatch("length mismatch between A ($minmn) and S ($(length(S)))")) + throw(DimensionMismatch("length mismatch between A ($n) and S ($(length(S)))")) lda = max(1, stride(A, 2)) mv = Ref{BlasInt}() # unused @@ -2284,7 +2286,7 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in if U !== A V = view(U, 1:n, 1:n) # use U as V storage else - V = view(similar(V), 1:n, 1:n) + V = view(similar(Vᴴ), 1:n, 1:n) end else V = Vᴴ # doesn't matter, V is not used @@ -2342,12 +2344,12 @@ for (gesvd, gesdd, gesvdx, gejsv, gesvj, elty, relty) in if cmplx if !isone(rwork[1]) @warn "singular values might have underflowed or overflowed" - LinearAlgebra.rmul!(S, rwork[1]) + rmul!(S, rwork[1]) end else if !isone(work[1]) @warn "singular values might have underflowed or overflowed" - LinearAlgebra.rmul!(S, work[1]) + rmul!(S, work[1]) end end if jobu == 'U' && U !== A diff --git a/test/ad_utils.jl b/test/ad_utils.jl index 7a7cf39a..fccc6c00 100644 --- a/test/ad_utils.jl +++ b/test/ad_utils.jl @@ -38,8 +38,8 @@ function stabilize_eigvals!(D::AbstractVector) end n = maximum(p) # rescale eigenvalues so that they lie on distinct radii in the complex plane - # that are chosen randomly in non-overlapping intervals [k/n, (k+0.5)/n)] for k=1,...,n - radii = ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n + # that are chosen randomly in non-overlapping intervals [10 * k/n, 10 * (k+0.5)/n)] for k=1,...,n + radii = 10 .* ((1:n) .+ rand(real(eltype(D)), n) ./ 2) ./ n for i in 1:length(D) D[i] = sign(D[i]) * radii[p[i]] end