Skip to content
219 changes: 154 additions & 65 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
using Random: rand!

# Generic fallback: inputs without a specialised method are handed back
# untouched; the `nrow`/`ncol` arguments are ignored here.
function _fix_size(M, nrow, ncol)
    return M
end

# Immutable matrix wrapper that carries its own (fixed) row and column counts.
# It exists to work around the performance issue described in
# https://github.com/JuliaLang/julia/issues/60409
# This is more-or-less a stripped-down version of FixedSizeArrays,
# which we can't easily use without pulling that into the standard library.
#
# `Trans` is a tag type parameter ('N', 'T' or 'C' are the tags used by the
# accessors defined for this type); `R` is the type of the stored reference.
struct _FixedSizeMatrix{Trans,R}
    ref::R
    nrow::Int
    ncol::Int
    # Only `Trans` is supplied explicitly; `R` is taken from the argument.
    _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R} =
        new{Trans,R}(ref, nrow, ncol)
end
# Element access for `_FixedSizeMatrix`. Every accessor computes a 1-based
# linear offset into `A.ref` and dereferences it with bounds checking disabled
# (the `false` argument to `Core.memoryrefnew` plus `@inbounds`), so callers
# are responsible for passing valid indices.

# 'N': the offset is `nrow * (j - 1) + i`.
@inline function Base.getindex(A::_FixedSizeMatrix{'N'}, i, j)
    offset = A.nrow * (j - 1) + i
    @inbounds Core.memoryrefnew(A.ref, offset, false)[]
end
@inline function Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j)
    offset = A.nrow * (j - 1) + i
    @inbounds Core.memoryrefnew(A.ref, offset, false)[] = v
end

# 'T': the offset is `ncol * (i - 1) + j`, and each element is passed through
# `transpose` on the way in and on the way out.
@inline function Base.getindex(A::_FixedSizeMatrix{'T'}, i, j)
    offset = A.ncol * (i - 1) + j
    @inbounds transpose(Core.memoryrefnew(A.ref, offset, false)[])
end
@inline function Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j)
    offset = A.ncol * (i - 1) + j
    @inbounds Core.memoryrefnew(A.ref, offset, false)[] = transpose(v)
end

# 'C': same offset as 'T', but elements are passed through `adjoint`.
@inline function Base.getindex(A::_FixedSizeMatrix{'C'}, i, j)
    offset = A.ncol * (i - 1) + j
    @inbounds adjoint(Core.memoryrefnew(A.ref, offset, false)[])
end
@inline function Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j)
    offset = A.ncol * (i - 1) + j
    @inbounds Core.memoryrefnew(A.ref, offset, false)[] = adjoint(v)
end

# Wrap a plain `Matrix` (tag 'N'), or a `Transpose` / `Adjoint` of one
# (tags 'T' / 'C'), in a `_FixedSizeMatrix` built from the underlying matrix's
# `ref` field and the caller-supplied `nrow` / `ncol`.
@inline function _fix_size(A::Matrix, nrow, ncol)
    return _FixedSizeMatrix{'N'}(A.ref, nrow, ncol)
end
@inline function _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol)
    underlying = A.parent
    return _FixedSizeMatrix{'T'}(underlying.ref, nrow, ncol)
end
@inline function _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol)
    underlying = A.parent
    return _FixedSizeMatrix{'C'}(underlying.ref, nrow, ncol)
end

# NOTE(review): the name suggests this is the length of a tile buffer used
# elsewhere in the file, and the inline note relates the value to 32k/3 --
# confirm units and usage against the code that reads this constant.
const tilebufsize = 10800 # Approximately 32k/3

# In matrix-vector multiplication, the correct orientation of the vector is assumed.
Expand Down Expand Up @@ -74,47 +109,94 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
return C
end

# Slow non-inlined functions for throwing the error without messing up the caller
# Slow, deliberately non-inlined error path so that building the message does
# not bloat the caller.
#
# `mA`/`nA` and `mB`/`nB` are the sizes to be compared against each other and
# against `mC`/`nC`. `At`/`Bt` only affect the wording: for any tag other than
# 'N' the words "first" and "second" are swapped in the message.
# Always throws a `DimensionMismatch`.
@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt)
    Anames = At == 'N' ? ("first", "second") : ("second", "first")
    Bnames = Bt == 'N' ? ("first", "second") : ("second", "first")
    if nA != mB
        throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB"))
    end
    if mA != mC
        throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC"))
    end
    if nB != nC
        throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC"))
    end
    # Only reached when the caller invoked us although every size agrees.
    throw(DimensionMismatch("Unknown dimension mismatch"))
end

# Validate the sizes for `C = op(A) * op(B)` and return the raw sizes
# `(mC, nC, mA, nA, mB, nB)` exactly as reported by `size`, i.e. *without*
# applying the transposition tags.
#
# `At` / `Bt` are compile-time tags: 'N' means the operand is used as is, any
# other tag means its two dimensions are swapped for the compatibility check.
# On a mismatch the non-inlined `_matmul_size_error` is called to throw.
@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt}
    mC, nC = size(C, 1), size(C, 2)
    mA, nA = size(A, 1), size(A, 2)
    mB, nB = size(B, 1), size(B, 2)

    # Effective (post-op) sizes of the two factors.
    if At == 'N'
        opA_rows, opA_cols = mA, nA
    else
        opA_rows, opA_cols = nA, mA
    end
    if Bt == 'N'
        opB_rows, opB_cols = mB, nB
    else
        opB_rows, opB_cols = nB, mB
    end

    # Non-short-circuit `&` keeps this a single branch.
    sizes_agree = (opA_cols == opB_rows) & (opA_rows == mC) & (opB_cols == nC)
    sizes_agree ||
        _matmul_size_error(mC, nC, opA_rows, opA_cols, opB_rows, opB_cols, At, Bt)
    return mC, nC, mA, nA, mB, nB
end

# Convenience wrappers around `_matmul_size` for the three tag combinations
# used in this file: A*B, op(A)*B and A*op(B).
@inline function _matmul_size_AB(C, A, B)
    return _matmul_size(C, A, B, Val('N'), Val('N'))
end
@inline function _matmul_size_AtB(C, A, B)
    return _matmul_size(C, A, B, Val('T'), Val('N'))
end
@inline function _matmul_size_ABt(C, A, B)
    return _matmul_size(C, A, B, Val('N'), Val('T'))
end

# In-place `C = A*B*α + C*β` where `A` is accessed through the CSC interface
# (`nonzeros`, `rowvals`, `nzrange`) and `B`, `C` are indexed as dense matrices.
#
# Sizes are validated once via `_matmul_size_AB`, which throws a
# `DimensionMismatch` on disagreement. Returns `C`.
#
# NOTE(review): the block as it appeared contained both the pre- and the
# post-change lines of a diff (duplicated size checks, β applied twice,
# unbalanced loops). This is the resolved post-change kernel.
function _spmatmul!(C, A, B, α, β)
    Cax2 = axes(C, 2)
    Aax2 = axes(A, 2)
    mC, nC, _, _, mB, nB = _matmul_size_AB(C, A, B)
    nzv = nonzeros(A)
    rv = rowvals(A)
    isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
    # α === false: nothing is accumulated, only the β scaling above applies.
    if α isa Bool && !α
        return C
    end
    # Index through the fixed-size wrappers in the hot loop, but keep `C`
    # bound to the caller's array so the same object is returned on every path.
    Bf = _fix_size(B, mB, nB)
    Cf = _fix_size(C, mC, nC)
    for k in Cax2
        @inbounds for col in Aax2
            # `α isa Bool` is resolved at compile time; α === true needs no multiply.
            αxj = α isa Bool ? Bf[col, k] : Bf[col, k] * α
            for j in nzrange(A, col)
                rvj = rv[j]
                Cf[rvj, k] = muladd(nzv[j], αxj, Cf[rvj, k])
            end
        end
    end
    return C
end

# In-place `C = op(A)*B*α + C*β`, where `op(A)` swaps the two dimensions of the
# CSC-accessed `A` and applies `tfun` to every stored value (callers pass the
# element-wise transpose/adjoint function).
#
# Sizes are validated once via `_matmul_size_AtB`, which throws a
# `DimensionMismatch` on disagreement. Returns `C`.
#
# NOTE(review): the block as it appeared contained both the pre- and the
# post-change lines of a diff (duplicated size checks, β applied twice,
# unbalanced loops). This is the resolved post-change kernel.
function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
    Cax2 = axes(C, 2)
    Aax2 = axes(A, 2)
    mC, nC, _, _, mB, nB = _matmul_size_AtB(C, A, B)
    nzv = nonzeros(A)
    rv = rowvals(A)
    isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
    # α === false: nothing is accumulated, only the β scaling above applies.
    if α isa Bool && !α
        return C
    end
    C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc
    # Index through the fixed-size wrappers in the hot loop, but keep `C`
    # bound to the caller's array so the same object is returned on every path.
    Bf = _fix_size(B, mB, nB)
    Cf = _fix_size(C, mC, nC)
    for k in Cax2
        @inbounds for col in Aax2
            # Dot product of column `col` of A (after `tfun`) with column `k` of B.
            tmp = C0
            for j in nzrange(A, col)
                tmp = muladd(tfun(nzv[j]), Bf[rv[j], k], tmp)
            end
            # `α isa Bool` is resolved at compile time; α === true needs no multiply.
            Cf[col, k] = α isa Bool ? tmp + Cf[col, k] : muladd(tmp, α, Cf[col, k])
        end
    end
    return C
end

Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number)
Expand All @@ -129,63 +211,71 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB
return C
end
# In-place `C = X*A*α + C*β` for a dense `X` and a CSC-accessed sparse `A`.
# For every stored entry `A[rvk, col]` the whole column `X[:, rvk]` is
# accumulated into `C[:, col]`, so the innermost loop runs down a column.
#
# Sizes are validated once via `_matmul_size_AB`, which throws a
# `DimensionMismatch` on disagreement. Returns `C`.
#
# NOTE(review): the block as it appeared contained both the pre- and the
# post-change lines of a diff (duplicated size checks, β applied twice,
# unbalanced loops). This is the resolved post-change kernel.
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
    Aax2 = axes(A, 2)
    Xax1 = axes(X, 1)
    mC, nC, mX, nX, _, _ = _matmul_size_AB(C, X, A)
    rv = rowvals(A)
    nzv = nonzeros(A)
    isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
    # α === false: nothing is accumulated, only the β scaling above applies.
    if α isa Bool && !α
        return C
    end
    # Index through the fixed-size wrappers in the hot loop, but keep `C`
    # bound to the caller's array so the same object is returned on every path.
    Cf = _fix_size(C, mC, nC)
    Xf = _fix_size(X, mX, nX)
    @inbounds for col in Aax2, k in nzrange(A, col)
        # `α isa Bool` is resolved at compile time; α === true needs no multiply.
        Aiα = α isa Bool ? nzv[k] : nzv[k] * α
        rvk = rv[k]
        @simd for multivec_row in Xax1
            Cf[multivec_row, col] = muladd(Xf[multivec_row, rvk], Aiα,
                                           Cf[multivec_row, col])
        end
    end
    return C
end
# In-place `C = X*A*α + C*β` for an adjoint/transpose-wrapped dense `X` and a
# CSC-accessed sparse `A`. Each output entry `C[row, col]` is built as a single
# accumulated sum over the stored entries of column `col` of `A`; columns of
# `A` without stored entries are skipped.
#
# Sizes are validated once via `_matmul_size_AB`, which throws a
# `DimensionMismatch` on disagreement. Returns `C`.
#
# NOTE(review): the block as it appeared contained both the pre- and the
# post-change lines of a diff (duplicated size checks, β applied twice,
# unbalanced loops). This is the resolved post-change kernel.
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
    Xax1 = axes(X, 1)
    Cax2 = axes(C, 2)
    mC, nC, mX, nX, _, _ = _matmul_size_AB(C, X, A)
    rv = rowvals(A)
    nzv = nonzeros(A)
    isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
    # α === false: nothing is accumulated, only the β scaling above applies.
    if α isa Bool && !α
        return C
    end
    # Index through the fixed-size wrappers in the hot loop, but keep `C`
    # bound to the caller's array so the same object is returned on every path.
    Cf = _fix_size(C, mC, nC)
    Xf = _fix_size(X, mX, nX)
    @inbounds for multivec_row in Xax1, col in Cax2
        nzrng = nzrange(A, col)
        if isempty(nzrng)
            continue
        end
        tmp = Cf[multivec_row, col]
        for k in nzrng
            # `α isa Bool` is resolved at compile time; α === true needs no multiply.
            tmp = muladd(Xf[multivec_row, rv[k]],
                         (α isa Bool ? nzv[k] : nzv[k] * α), tmp)
        end
        Cf[multivec_row, col] = tmp
    end
    return C
end

# In-place `C = A*op(B)*α + C*β`, where `op(B)` swaps the two dimensions of the
# CSC-accessed sparse `B` and applies `tfun` to every stored value (callers
# pass the element-wise transpose/adjoint function). For every stored entry
# `B[rvk, col]` the column `A[:, col]` is accumulated into `C[:, rvk]`.
#
# Sizes are validated once via `_matmul_size_ABt`, which throws a
# `DimensionMismatch` on disagreement. Returns `C`.
#
# NOTE(review): the block as it appeared contained both the pre- and the
# post-change lines of a diff (duplicated size checks, β applied twice,
# unbalanced loops). This is the resolved post-change kernel.
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
    Bax2 = axes(B, 2)
    Aax1 = axes(A, 1)
    mC, nC, mA, nA, _, _ = _matmul_size_ABt(C, A, B)
    rv = rowvals(B)
    nzv = nonzeros(B)
    isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
    # α === false: nothing is accumulated, only the β scaling above applies.
    if α isa Bool && !α
        return C
    end
    # Index through the fixed-size wrappers in the hot loop, but keep `C`
    # bound to the caller's array so the same object is returned on every path.
    Cf = _fix_size(C, mC, nC)
    Af = _fix_size(A, mA, nA)
    @inbounds for col in Bax2, k in nzrange(B, col)
        # `α isa Bool` is resolved at compile time; α === true needs no multiply.
        Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α
        rvk = rv[k]
        @simd for multivec_col in Aax1
            Cf[multivec_col, rvk] = muladd(Af[multivec_col, col], Biα,
                                           Cf[multivec_col, rvk])
        end
    end
    return C
end

function *(A::Diagonal, b::AbstractSparseVector)
Expand Down Expand Up @@ -1240,7 +1330,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
rv = rowvals(A)
nzv = nonzeros(A)
let z = T(0), sumcol=z, αxj=z, aarc=z, α = α
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
isone(β) || LinearAlgebra._rmul_or_fill!(C, β)
@inbounds for k in axes(B,2)
for col in axes(B,1)
αxj = B[col,k] * α
Expand All @@ -1259,7 +1349,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
end
end
end
C
end

# row range up to (and including if excl=false) diagonal
Expand Down
23 changes: 10 additions & 13 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1930,9 +1930,9 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector,
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -1946,7 +1946,6 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector,
end
end
end
return y
end

function _At_or_Ac_mul_B!(tfun::Function,
Expand All @@ -1958,14 +1957,14 @@ function _At_or_Ac_mul_B!(tfun::Function,
"Matrix A has $n rows, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m columns, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
_nnz = length(xnzind)
_nnz == 0 && return y
_nnz == 0 && return

Ty = promote_op(matprod, eltype(A), eltype(x))
@inbounds for j = 1:m
Expand All @@ -1975,7 +1974,7 @@ function _At_or_Ac_mul_B!(tfun::Function,
end
y[j] += s * α
end
return y
return
end

function *(A::AdjOrTrans{<:Any,<:StridedMatrix}, x::AbstractSparseVector)
Expand Down Expand Up @@ -2053,9 +2052,9 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == m || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
m == 0 && return y
m == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -2073,7 +2072,6 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars
end
end
end
return y
end

function _At_or_Ac_mul_B!(tfun::Function,
Expand All @@ -2085,9 +2083,9 @@ function _At_or_Ac_mul_B!(tfun::Function,
"Matrix A has $n columns, but vector x has a length $(length(x))"))
length(y) == n || throw(DimensionMismatch(
"Matrix A has $m rows, but vector y has a length $(length(y))"))
n == 0 && return y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this type of change? In LinearAlgebra, the style seems to be to return the result array. That's also what the docstring suggests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, I see it is tested that mul!(C, ...) === C. This seems to be handled by upper-level functions then.

n == 0 && return
β != one(β) && LinearAlgebra._rmul_or_fill!(y, β)
_iszero(α) && return y
_iszero(α) && return

xnzind = nonzeroinds(x)
xnzval = nonzeros(x)
Expand All @@ -2102,7 +2100,6 @@ function _At_or_Ac_mul_B!(tfun::Function,
1, mx, xnzind, xnzval)
@inbounds y[j] += s * α
end
return y
end


Expand Down
Loading
Loading