diff --git a/src/linalg.jl b/src/linalg.jl index 0bc9f7ce..75bcdf52 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap using Random: rand! +_fix_size(M, nrow, ncol) = M + +# An immutable fixed size wrapper for matrices to work around +# the performance issue caused by https://github.com/JuliaLang/julia/issues/60409 +# This is more-of-less a stripped down version of FixedSizeArrays +# which we can't easily use without pulling that into the standard library. +struct _FixedSizeMatrix{Trans,R} + ref::R + nrow::Int + ncol::Int + function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R} + new{Trans,R}(ref, nrow, ncol) + end +end +@inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] +@inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] = v + +@inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) = + @inbounds transpose(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[]) +@inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v) + +@inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) = + @inbounds adjoint(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[]) +@inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v) + +@inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol) +@inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) = + _FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol) +@inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) = + _FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol) + const tilebufsize = 10800 # Approximately 32k/3 # In matrix-vector multiplication, the correct orientation of the vector is assumed. @@ -74,47 +109,94 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta) return C end +# Slow non-inlined functions for throwing the error without messing up the caller +@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt) + if At == 'N' + Anames = "first", "second" + else + Anames = "second", "first" + end + if Bt == 'N' + Bnames = "first", "second" + else + Bnames = "second", "first" + end + nA == mB || + throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB")) + mA == mC || + throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC")) + nB == nC || + throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC")) + # unreachable + throw(DimensionMismatch("Unknown dimension mismatch")) +end + +@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt} + mC = size(C, 1) + nC = size(C, 2) + mA = size(A, 1) + nA = size(A, 2) + mB = size(B, 1) + nB = size(B, 2) + + _mA, _nA = At == 'N' ? (mA, nA) : (nA, mA) + _mB, _nB = Bt == 'N' ? (mB, nB) : (nB, mB) + + if (_nA != _mB) | (_mA != mC) | (_nB != nC) + _matmul_size_error(mC, nC, _mA, _nA, _mB, _nB, At, Bt) + end + return mC, nC, mA, nA, mB, nB +end + +@inline _matmul_size_AB(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('N')) +@inline _matmul_size_AtB(C, A, B) = _matmul_size(C, A, B, Val('T'), Val('N')) +@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T')) + function _spmatmul!(C, A, B, α, β) - size(A, 2) == size(B, 1) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))")) - size(A, 1) == size(C, 1) || - throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))")) - size(B, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + Cax2 = axes(C, 2) + Aax2 = axes(A, 2) + mC, nC, mA, nA, mB, nB = _matmul_size_AB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) - for k in axes(C, 2) - @inbounds for col in axes(A,2) - αxj = B[col,k] * α + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end + B = _fix_size(B, mB, nB) + C = _fix_size(C, mC, nC) + for k in Cax2 + @inbounds for col in Aax2 + αxj = α isa Bool ? B[col,k] : B[col,k] * α for j in nzrange(A, col) - C[rv[j], k] += nzv[j]*αxj + rvj = rv[j] + C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k]) end end end - C end function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) - size(A, 2) == size(C, 1) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))")) - size(A, 1) == size(B, 1) || - throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))")) - size(B, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + Cax2 = axes(C, 2) + Aax2 = axes(A, 2) + mC, nC, mA, nA, mB, nB = _matmul_size_AtB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) - for k in axes(C, 2) - @inbounds for col in axes(A,2) - tmp = zero(eltype(C)) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end + C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc + B = _fix_size(B, mB, nB) + C = _fix_size(C, mC, nC) + for k in Cax2 + @inbounds for col in Aax2 + tmp = C0 for j in nzrange(A, col) - tmp += tfun(nzv[j])*B[rv[j],k] + tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp) end - C[col,k] += tmp * α + C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k]) end end - C end Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number) @@ -129,63 +211,71 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB return C end function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number) - mX, nX = size(X) - nX == size(A, 1) || - throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))")) - mX == size(C, 1) || - throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))")) - size(A, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) + Aax2 = axes(A, 2) + Xax1 = axes(X, 1) + mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) - @inbounds for col in axes(A,2), k in nzrange(A, col) - Aiα = nzv[k] * α + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end + C = _fix_size(C, mC, nC) + X = _fix_size(X, mX, nX) + @inbounds for col in Aax2, k in nzrange(A, col) + Aiα = α isa Bool ? nzv[k] : nzv[k] * α rvk = rv[k] - @simd for multivec_row in axes(X,1) - C[multivec_row, col] += X[multivec_row, rvk] * Aiα + @simd for multivec_row in Xax1 + C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα, + C[multivec_row, col]) end end - C end function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number) - mX, nX = size(X) - nX == size(A, 1) || - throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))")) - mX == size(C, 1) || - throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))")) - size(A, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) + Xax1 = axes(X, 1) + Cax2 = axes(C, 2) + mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) - for multivec_row in axes(X,1), col in axes(C, 2) - @inbounds for k in nzrange(A, col) - C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end + C = _fix_size(C, mC, nC) + X = _fix_size(X, mX, nX) + @inbounds for multivec_row in Xax1, col in Cax2 + nzrng = nzrange(A, col) + if isempty(nzrng) + continue + end + tmp = C[multivec_row, col] + for k in nzrng + tmp = muladd(X[multivec_row, rv[k]], + (α isa Bool ? nzv[k] : nzv[k] * α), tmp) end + C[multivec_row, col] = tmp end - C end function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number) - mA, nA = size(A) - nA == size(B, 2) || - throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))")) - mA == size(C, 1) || - throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))")) - size(B, 1) == size(C, 2) || - throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + Bax2 = axes(B, 2) + Aax1 = axes(A, 1) + mC, nC, mA, nA, mB, nB = _matmul_size_ABt(C, A, B) rv = rowvals(B) nzv = nonzeros(B) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) - @inbounds for col in axes(B, 2), k in nzrange(B, col) - Biα = tfun(nzv[k]) * α + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end + C = _fix_size(C, mC, nC) + A = _fix_size(A, mA, nA) + @inbounds for col in Bax2, k in nzrange(B, col) + Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α rvk = rv[k] - @simd for multivec_col in axes(A,1) - C[multivec_col, rvk] += A[multivec_col, col] * Biα + @simd for multivec_col in Aax1 + C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk]) end end - C end function *(A::Diagonal, b::AbstractSparseVector) @@ -1240,7 +1330,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided rv = rowvals(A) nzv = nonzeros(A) let z = T(0), sumcol=z, αxj=z, aarc=z, α = α - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @inbounds for k in axes(B,2) for col in axes(B,1) αxj = B[col,k] * α @@ -1259,7 +1349,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided end end end - C end # row range up to (and including if excl=false) diagonal diff --git a/src/sparsevector.jl b/src/sparsevector.jl index 6e47de67..908a14da 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -1930,9 +1930,9 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector, "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -1946,7 +1946,6 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector, end end end - return y end function _At_or_Ac_mul_B!(tfun::Function, @@ -1958,14 +1957,14 @@ function _At_or_Ac_mul_B!(tfun::Function, "Matrix A has $n rows, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m columns, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) _nnz = length(xnzind) - _nnz == 0 && return y + _nnz == 0 && return Ty = promote_op(matprod, eltype(A), eltype(x)) @inbounds for j = 1:m @@ -1975,7 +1974,7 @@ function _At_or_Ac_mul_B!(tfun::Function, end y[j] += s * α end - return y + return end function *(A::AdjOrTrans{<:Any,<:StridedMatrix}, x::AbstractSparseVector) @@ -2053,9 +2052,9 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -2073,7 +2072,6 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars end end end - return y end function _At_or_Ac_mul_B!(tfun::Function, @@ -2085,9 +2083,9 @@ function _At_or_Ac_mul_B!(tfun::Function, "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == n || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - n == 0 && return y + n == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -2102,7 +2100,6 @@ function _At_or_Ac_mul_B!(tfun::Function, 1, mx, xnzind, xnzval) @inbounds y[j] += s * α end - return y end diff --git a/test/linalg.jl b/test/linalg.jl index 1e849305..510b4342 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -1040,4 +1040,47 @@ end @test_throws DimensionMismatch D1 * S * D1 end +@testset "multiplication of sparse and dense matrices" begin + function test_mul(A, B) + expected = Matrix(A) * Matrix(B) + @test A * B ≈ expected + C = similar(expected) + @test mul!(C, A, B) === C + @test C ≈ expected + ElType = eltype(C) + vs = Any[false, true, zero(ElType), one(ElType), one(ElType) + one(ElType)] + for α in vs, β in vs + C .= rand.(ElType) + expected′ = expected .* α .+ C .* β + @test mul!(C, A, B, α, β) === C + @test C ≈ expected′ + end + end + + for ElType in [Int, Float64, ComplexF64, BigFloat] + SP = sprand(ElType, 10, 10, 0.3) + D = rand(ElType, 10, 10) + fs = [identity, adjoint, transpose] + for f1 in fs, f2 in fs + test_mul(f1(SP), f2(D)) + test_mul(f1(D), f2(SP)) + end + end +end + +@testset "dimension mismatch error" begin + fs = [rand, (x, y)->adjoint(rand(y, x)), (x, y)->transpose(rand(y, x)), + (x, y)->sprand(x, y, 0.5), (x, y)->adjoint(sprand(y, x, 0.5)), + (x, y)->transpose(sprand(y, x, 0.5))] + for fA in fs, fB in fs + mul!(zeros(6, 10), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(7, 10), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 11), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(5, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 9), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 8), fB(7, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 8), fB(8, 9)) + end +end + end