From 7134baad5370c58b10d8b78738feedc1f91a3c3d Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 15:29:16 -0600 Subject: [PATCH 1/9] Use isone instead of comparing to one Better performance for mutable type like BigFloat --- src/linalg.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 0bc9f7ce..983c0fc6 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -83,7 +83,7 @@ function _spmatmul!(C, A, B, α, β) throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) nzv = nonzeros(A) rv = rowvals(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) for k in axes(C, 2) @inbounds for col in axes(A,2) αxj = B[col,k] * α @@ -104,7 +104,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) nzv = nonzeros(A) rv = rowvals(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) for k in axes(C, 2) @inbounds for col in axes(A,2) tmp = zero(eltype(C)) @@ -138,7 +138,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2 throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) rv = rowvals(A) nzv = nonzeros(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @inbounds for col in axes(A,2), k in nzrange(A, col) Aiα = nzv[k] * α rvk = rv[k] @@ -158,7 +158,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) rv = rowvals(A) nzv = nonzeros(A) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) for multivec_row in axes(X,1), col in axes(C, 2) @inbounds for k in nzrange(A, col) C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α @@ -177,7 +177,7 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) rv = rowvals(B) nzv = nonzeros(B) - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @inbounds for col in axes(B, 2), k in nzrange(B, col) Biα = tfun(nzv[k]) * α rvk = rv[k] @@ -1240,7 +1240,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided rv = rowvals(A) nzv = nonzeros(A) let z = T(0), sumcol=z, αxj=z, aarc=z, α = α - β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) + isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @inbounds for k in axes(B,2) for col in axes(B,1) αxj = B[col,k] * α From 119314753d20f7369a114dedb12d4b9585e4a9c4 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 15:41:44 -0600 Subject: [PATCH 2/9] Remove unnecessary return values from internal functions --- src/linalg.jl | 6 ------ src/sparsevector.jl | 23 ++++++++++------------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 983c0fc6..06520938 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -92,7 +92,6 @@ function _spmatmul!(C, A, B, α, β) end end end - C end function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) @@ -114,7 +113,6 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) C[col,k] += tmp * α end end - C end Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, alpha::Number, beta::Number) @@ -146,7 +144,6 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2 C[multivec_row, col] += X[multivec_row, rvk] * Aiα end end - C end function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number) mX, nX = size(X) @@ -164,7 +161,6 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α end end - C end function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number) @@ -185,7 +181,6 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B C[multivec_col, rvk] += A[multivec_col, col] * Biα end end - C end function *(A::Diagonal, b::AbstractSparseVector) @@ -1259,7 +1254,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided end end end - C end # row range up to (and including if excl=false) diagonal diff --git a/src/sparsevector.jl b/src/sparsevector.jl index 6e47de67..908a14da 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -1930,9 +1930,9 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector, "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -1946,7 +1946,6 @@ function _spmul!(y::AbstractVector, A::AbstractMatrix, x::AbstractSparseVector, end end end - return y end function _At_or_Ac_mul_B!(tfun::Function, @@ -1958,14 +1957,14 @@ function _At_or_Ac_mul_B!(tfun::Function, "Matrix A has $n rows, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m columns, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) _nnz = length(xnzind) - _nnz == 0 && return y + _nnz == 0 && return Ty = promote_op(matprod, eltype(A), eltype(x)) @inbounds for j = 1:m @@ -1975,7 +1974,7 @@ function _At_or_Ac_mul_B!(tfun::Function, end y[j] += s * α end - return y + return end function *(A::AdjOrTrans{<:Any,<:StridedMatrix}, x::AbstractSparseVector) @@ -2053,9 +2052,9 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == m || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - m == 0 && return y + m == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -2073,7 +2072,6 @@ function _spmul!(y::AbstractVector, A::AbstractSparseMatrixCSC, x::AbstractSpars end end end - return y end function _At_or_Ac_mul_B!(tfun::Function, @@ -2085,9 +2083,9 @@ function _At_or_Ac_mul_B!(tfun::Function, "Matrix A has $n columns, but vector x has a length $(length(x))")) length(y) == n || throw(DimensionMismatch( "Matrix A has $m rows, but vector y has a length $(length(y))")) - n == 0 && return y + n == 0 && return β != one(β) && LinearAlgebra._rmul_or_fill!(y, β) - _iszero(α) && return y + _iszero(α) && return xnzind = nonzeroinds(x) xnzval = nonzeros(x) @@ -2102,7 +2100,6 @@ function _At_or_Ac_mul_B!(tfun::Function, 1, mx, xnzind, xnzval) @inbounds y[j] += s * α end - return y end From 715dd78fff3e2481f16e9fcb86fd0cd0badf5c4f Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 16:10:00 -0600 Subject: [PATCH 3/9] Helper function to get sizes of all input matrices and outline error throwing function --- src/linalg.jl | 81 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 33 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 06520938..f84b373a 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -74,13 +74,51 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta) return C end +# Slow non-inlined functions for throwing the error without messing up the caller +@noinline function _matmul_size_error(mC, nC, mA, nA, mB, nB, At, Bt) + if At == 'N' + Anames = "first", "second" + else + Anames = "second", "first" + end + if Bt == 'N' + Bnames = "first", "second" + else + Bnames = "second", "first" + end + nA == mB || + throw(DimensionMismatch("$(Anames[2]) dimension of A, $nA, does not match the $(Bnames[1]) dimension of B, $mB")) + mA == mC || + throw(DimensionMismatch("$(Anames[1]) dimension of A, $mA, does not match the first dimension of C, $mC")) + nB == nC || + throw(DimensionMismatch("$(Bnames[2]) dimension of B, $nB, does not match the second dimension of C, $nC")) + # unreachable + throw(DimensionMismatch("Unknown dimension mismatch")) +end + +@inline function _matmul_size(C, A, B, ::Val{At}, ::Val{Bt}) where {At,Bt} + mC = size(C, 1) + nC = size(C, 2) + mA = size(A, 1) + nA = size(A, 2) + mB = size(B, 1) + nB = size(B, 2) + + _mA, _nA = At == 'N' ? (mA, nA) : (nA, mA) + _mB, _nB = Bt == 'N' ? (mB, nB) : (nB, mB) + + if (_nA != _mB) | (_mA != mC) | (_nB != nC) + _matmul_size_error(mC, nC, _mA, _nA, _mB, _nB, At, Bt) + end + return mC, nC, mA, nA, mB, nB +end + +@inline _matmul_size_AB(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('N')) +@inline _matmul_size_AtB(C, A, B) = _matmul_size(C, A, B, Val('T'), Val('N')) +@inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T')) + function _spmatmul!(C, A, B, α, β) - size(A, 2) == size(B, 1) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))")) - size(A, 1) == size(C, 1) || - throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))")) - size(B, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + _matmul_size_AB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @@ -95,12 +133,7 @@ function _spmatmul!(C, A, B, α, β) end function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) - size(A, 2) == size(C, 1) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of C, $(size(C,1))")) - size(A, 1) == size(B, 1) || - throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of B, $(size(B,1))")) - size(B, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + _matmul_size_AtB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @@ -127,13 +160,7 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB return C end function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number) - mX, nX = size(X) - nX == size(A, 1) || - throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))")) - mX == size(C, 1) || - throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))")) - size(A, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) + _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @@ -146,13 +173,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2 end end function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number) - mX, nX = size(X) - nX == size(A, 1) || - throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))")) - mX == size(C, 1) || - throw(DimensionMismatch("first dimension of X, $mX, does not match the first dimension of C, $(size(C,1))")) - size(A, 2) == size(C, 2) || - throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the second dimension of C, $(size(C,2))")) + _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) @@ -164,13 +185,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S end function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number) - mA, nA = size(A) - nA == size(B, 2) || - throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))")) - mA == size(C, 1) || - throw(DimensionMismatch("first dimension of A, $mA, does not match the first dimension of C, $(size(C,1))")) - size(B, 1) == size(C, 2) || - throw(DimensionMismatch("first dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) + _matmul_size_ABt(C, A, B) rv = rowvals(B) nzv = nonzeros(B) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) From 5862843e09d8cea57fbadeb9849da8139182c3ef Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 16:13:40 -0600 Subject: [PATCH 4/9] Pre-compute matrix axes out of the loop --- src/linalg.jl | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index f84b373a..78818020 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -118,12 +118,14 @@ end @inline _matmul_size_ABt(C, A, B) = _matmul_size(C, A, B, Val('N'), Val('T')) function _spmatmul!(C, A, B, α, β) + Cax2 = axes(C, 2) + Aax2 = axes(A, 2) _matmul_size_AB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) - for k in axes(C, 2) - @inbounds for col in axes(A,2) + for k in Cax2 + @inbounds for col in Aax2 αxj = B[col,k] * α for j in nzrange(A, col) C[rv[j], k] += nzv[j]*αxj @@ -133,12 +135,14 @@ function _spmatmul!(C, A, B, α, β) end function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) + Cax2 = axes(C, 2) + Aax2 = axes(A, 2) _matmul_size_AtB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) - for k in axes(C, 2) - @inbounds for col in axes(A,2) + for k in Cax2 + @inbounds for col in Aax2 tmp = zero(eltype(C)) for j in nzrange(A, col) tmp += tfun(nzv[j])*B[rv[j],k] @@ -160,24 +164,28 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix, tA, tB return C end function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number) + Aax2 = axes(A, 2) + Xax1 = axes(X, 1) _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) - @inbounds for col in axes(A,2), k in nzrange(A, col) + @inbounds for col in Aax2, k in nzrange(A, col) Aiα = nzv[k] * α rvk = rv[k] - @simd for multivec_row in axes(X,1) + @simd for multivec_row in Xax1 C[multivec_row, col] += X[multivec_row, rvk] * Aiα end end end function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number) + Xax1 = axes(X, 1) + Cax2 = axes(C, 2) _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) - for multivec_row in axes(X,1), col in axes(C, 2) + for multivec_row in Xax1, col in Cax2 @inbounds for k in nzrange(A, col) C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α end @@ -185,14 +193,16 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S end function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number) + Bax2 = axes(B, 2) + Aax1 = axes(A, 1) _matmul_size_ABt(C, A, B) rv = rowvals(B) nzv = nonzeros(B) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) - @inbounds for col in axes(B, 2), k in nzrange(B, col) + @inbounds for col in Bax2, k in nzrange(B, col) Biα = tfun(nzv[k]) * α rvk = rv[k] - @simd for multivec_col in axes(A,1) + @simd for multivec_col in Aax1 C[multivec_col, rvk] += A[multivec_col, col] * Biα end end From 8d7ab9c96397ef40282e8f7fd8fd8a31e58b3c11 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 16:22:54 -0600 Subject: [PATCH 5/9] Optimize for alpha being boolean --- src/linalg.jl | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 78818020..67c43338 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -124,9 +124,12 @@ function _spmatmul!(C, A, B, α, β) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end for k in Cax2 @inbounds for col in Aax2 - αxj = B[col,k] * α + αxj = α isa Bool ? B[col,k] : B[col,k] * α for j in nzrange(A, col) C[rv[j], k] += nzv[j]*αxj end @@ -141,13 +144,16 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end for k in Cax2 @inbounds for col in Aax2 tmp = zero(eltype(C)) for j in nzrange(A, col) tmp += tfun(nzv[j])*B[rv[j],k] end - C[col,k] += tmp * α + C[col,k] += α isa Bool ? tmp : tmp * α end end end @@ -170,8 +176,11 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2 rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end @inbounds for col in Aax2, k in nzrange(A, col) - Aiα = nzv[k] * α + Aiα = α isa Bool ? nzv[k] : nzv[k] * α rvk = rv[k] @simd for multivec_row in Xax1 C[multivec_row, col] += X[multivec_row, rvk] * Aiα @@ -185,9 +194,14 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end for multivec_row in Xax1, col in Cax2 @inbounds for k in nzrange(A, col) - C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α + C[multivec_row, col] += + (α isa Bool ? X[multivec_row, rv[k]] * nzv[k] : + X[multivec_row, rv[k]] * nzv[k] * α) end end end @@ -199,8 +213,11 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B rv = rowvals(B) nzv = nonzeros(B) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) + if α isa Bool && !α + return + end @inbounds for col in Bax2, k in nzrange(B, col) - Biα = tfun(nzv[k]) * α + Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α rvk = rv[k] @simd for multivec_col in Aax1 C[multivec_col, rvk] += A[multivec_col, col] * Biα From 9426aafabb550b2cd5bc182d2628471e4f93c924 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 16:32:00 -0600 Subject: [PATCH 6/9] Use a fixed size wrapper to workaround julia 1.11+ bug --- src/linalg.jl | 55 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 67c43338..3f116d26 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap using Random: rand! +_fix_size(M, nrow, ncol) = M + +# An immutable fixed size wrapper for matrices to work around +# the performance issue caused by https://github.com/JuliaLang/julia/issues/60409 +# This is more-of-less a stripped down version of FixedSizeArrays +# which we can't easily use without pulling that into the standard library. +struct _FixedSizeMatrix{Trans,R} + ref::R + nrow::Int + ncol::Int + function _FixedSizeMatrix{Trans}(ref::R, nrow, ncol) where {Trans,R} + new{Trans,R}(ref, nrow, ncol) + end +end +@inline Base.getindex(A::_FixedSizeMatrix{'N'}, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] +@inline Base.setindex!(A::_FixedSizeMatrix{'N'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.nrow * (j - 1) + i, false)[] = v + +@inline Base.getindex(A::_FixedSizeMatrix{'T'}, i, j) = + @inbounds transpose(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[]) +@inline Base.setindex!(A::_FixedSizeMatrix{'T'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = transpose(v) + +@inline Base.getindex(A::_FixedSizeMatrix{'C'}, i, j) = + @inbounds adjoint(Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[]) +@inline Base.setindex!(A::_FixedSizeMatrix{'C'}, v, i, j) = + @inbounds Core.memoryrefnew(A.ref, A.ncol * (i - 1) + j, false)[] = adjoint(v) + +@inline _fix_size(A::Matrix, nrow, ncol) = _FixedSizeMatrix{'N'}(A.ref, nrow, ncol) +@inline _fix_size(A::Transpose{<:Any,<:Matrix}, nrow, ncol) = + _FixedSizeMatrix{'T'}(A.parent.ref, nrow, ncol) +@inline _fix_size(A::Adjoint{<:Any,<:Matrix}, nrow, ncol) = + _FixedSizeMatrix{'C'}(A.parent.ref, nrow, ncol) + const tilebufsize = 10800 # Approximately 32k/3 # In matrix-vector multiplication, the correct orientation of the vector is assumed. @@ -120,13 +155,15 @@ end function _spmatmul!(C, A, B, α, β) Cax2 = axes(C, 2) Aax2 = axes(A, 2) - _matmul_size_AB(C, A, B) + mC, nC, mA, nA, mB, nB = _matmul_size_AB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) if α isa Bool && !α return end + B = _fix_size(B, mB, nB) + C = _fix_size(C, mC, nC) for k in Cax2 @inbounds for col in Aax2 αxj = α isa Bool ? B[col,k] : B[col,k] * α @@ -140,13 +177,15 @@ end function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) Cax2 = axes(C, 2) Aax2 = axes(A, 2) - _matmul_size_AtB(C, A, B) + mC, nC, mA, nA, mB, nB = _matmul_size_AtB(C, A, B) nzv = nonzeros(A) rv = rowvals(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) if α isa Bool && !α return end + B = _fix_size(B, mB, nB) + C = _fix_size(C, mC, nC) for k in Cax2 @inbounds for col in Aax2 tmp = zero(eltype(C)) @@ -172,13 +211,15 @@ end function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number) Aax2 = axes(A, 2) Xax1 = axes(X, 1) - _matmul_size_AB(C, X, A) + mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) if α isa Bool && !α return end + C = _fix_size(C, mC, nC) + X = _fix_size(X, mX, nX) @inbounds for col in Aax2, k in nzrange(A, col) Aiα = α isa Bool ? nzv[k] : nzv[k] * α rvk = rv[k] @@ -190,13 +231,15 @@ end function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number) Xax1 = axes(X, 1) Cax2 = axes(C, 2) - _matmul_size_AB(C, X, A) + mC, nC, mX, nX, mA, nA = _matmul_size_AB(C, X, A) rv = rowvals(A) nzv = nonzeros(A) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) if α isa Bool && !α return end + C = _fix_size(C, mC, nC) + X = _fix_size(X, mX, nX) for multivec_row in Xax1, col in Cax2 @inbounds for k in nzrange(A, col) C[multivec_row, col] += @@ -209,13 +252,15 @@ end function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number) Bax2 = axes(B, 2) Aax1 = axes(A, 1) - _matmul_size_ABt(C, A, B) + mC, nC, mA, nA, mB, nB = _matmul_size_ABt(C, A, B) rv = rowvals(B) nzv = nonzeros(B) isone(β) || LinearAlgebra._rmul_or_fill!(C, β) if α isa Bool && !α return end + C = _fix_size(C, mC, nC) + A = _fix_size(A, mA, nA) @inbounds for col in Bax2, k in nzrange(B, col) Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α rvk = rv[k] From 4757d04db7f3a8fc2d5e4315a6ed8eb0d178a306 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Wed, 31 Dec 2025 16:41:12 -0600 Subject: [PATCH 7/9] Use muladd when possible in sparse matrix multiplication --- src/linalg.jl | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 3f116d26..75bcdf52 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -168,7 +168,8 @@ function _spmatmul!(C, A, B, α, β) @inbounds for col in Aax2 αxj = α isa Bool ? B[col,k] : B[col,k] * α for j in nzrange(A, col) - C[rv[j], k] += nzv[j]*αxj + rvj = rv[j] + C[rvj, k] = muladd(nzv[j], αxj, C[rvj, k]) end end end @@ -184,15 +185,16 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β) if α isa Bool && !α return end + C0 = zero(eltype(C)) # Pre-allocate for BigFloat/BigInt etc B = _fix_size(B, mB, nB) C = _fix_size(C, mC, nC) for k in Cax2 @inbounds for col in Aax2 - tmp = zero(eltype(C)) + tmp = C0 for j in nzrange(A, col) - tmp += tfun(nzv[j])*B[rv[j],k] + tmp = muladd(tfun(nzv[j]), B[rv[j], k], tmp) end - C[col,k] += α isa Bool ? tmp : tmp * α + C[col, k] = α isa Bool ? tmp + C[col, k] : muladd(tmp, α, C[col, k]) end end end @@ -224,7 +226,8 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2 Aiα = α isa Bool ? nzv[k] : nzv[k] * α rvk = rv[k] @simd for multivec_row in Xax1 - C[multivec_row, col] += X[multivec_row, rvk] * Aiα + C[multivec_row, col] = muladd(X[multivec_row, rvk], Aiα, + C[multivec_row, col]) end end end @@ -240,12 +243,17 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::S end C = _fix_size(C, mC, nC) X = _fix_size(X, mX, nX) - for multivec_row in Xax1, col in Cax2 - @inbounds for k in nzrange(A, col) - C[multivec_row, col] += - (α isa Bool ? X[multivec_row, rv[k]] * nzv[k] : - X[multivec_row, rv[k]] * nzv[k] * α) + @inbounds for multivec_row in Xax1, col in Cax2 + nzrng = nzrange(A, col) + if isempty(nzrng) + continue end + tmp = C[multivec_row, col] + for k in nzrng + tmp = muladd(X[multivec_row, rv[k]], + (α isa Bool ? nzv[k] : nzv[k] * α), tmp) + end + C[multivec_row, col] = tmp end end @@ -265,7 +273,7 @@ function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B Biα = α isa Bool ? tfun(nzv[k]) : tfun(nzv[k]) * α rvk = rv[k] @simd for multivec_col in Aax1 - C[multivec_col, rvk] += A[multivec_col, col] * Biα + C[multivec_col, rvk] = muladd(A[multivec_col, col], Biα, C[multivec_col, rvk]) end end end From 8c7039d7802b4e9d3288d1c913eeb48ba2d54d08 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Mon, 5 Jan 2026 11:43:59 -0500 Subject: [PATCH 8/9] Add more complete multiplication tests --- test/linalg.jl | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/linalg.jl b/test/linalg.jl index 1e849305..885fda21 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -1040,4 +1040,35 @@ end @test_throws DimensionMismatch D1 * S * D1 end +@testset "multiplication of sparse and dense matrices" begin + function test_mul(A, B) + expected = Matrix(A) * Matrix(B) + @test A * B ≈ expected + C = similar(expected) + @test mul!(C, A, B) === C + @test C ≈ expected + ElType = eltype(C) + vs = Any[false, true, zero(ElType), one(ElType), one(ElType) + one(ElType)] + for α in vs + for β in vs + C .= rand.(ElType) + expected′ = expected .* α .+ C .* β + @test mul!(C, A, B, α, β) === C + @test C ≈ expected′ + end + end + end + + for ElType in [Int, Float64, ComplexF64, BigFloat] + SP = sprand(ElType, 10, 10, 0.3) + D = rand(ElType, 10, 10) + for f1 in [identity, adjoint, transpose] + for f2 in [identity, adjoint, transpose] + test_mul(f1(SP), f2(D)) + test_mul(f1(D), f2(SP)) + end + end + end +end + end From e9ec23cbd887ccf5c7721e13cdd15b029a981803 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Mon, 5 Jan 2026 15:05:22 -0500 Subject: [PATCH 9/9] Add more dimension error check --- test/linalg.jl | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/test/linalg.jl b/test/linalg.jl index 885fda21..510b4342 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -1049,26 +1049,38 @@ end @test C ≈ expected ElType = eltype(C) vs = Any[false, true, zero(ElType), one(ElType), one(ElType) + one(ElType)] - for α in vs - for β in vs - C .= rand.(ElType) - expected′ = expected .* α .+ C .* β - @test mul!(C, A, B, α, β) === C - @test C ≈ expected′ - end + for α in vs, β in vs + C .= rand.(ElType) + expected′ = expected .* α .+ C .* β + @test mul!(C, A, B, α, β) === C + @test C ≈ expected′ end end for ElType in [Int, Float64, ComplexF64, BigFloat] SP = sprand(ElType, 10, 10, 0.3) D = rand(ElType, 10, 10) - for f1 in [identity, adjoint, transpose] - for f2 in [identity, adjoint, transpose] - test_mul(f1(SP), f2(D)) - test_mul(f1(D), f2(SP)) - end + fs = [identity, adjoint, transpose] + for f1 in fs, f2 in fs + test_mul(f1(SP), f2(D)) + test_mul(f1(D), f2(SP)) end end end +@testset "dimension mismatch error" begin + fs = [rand, (x, y)->adjoint(rand(y, x)), (x, y)->transpose(rand(y, x)), + (x, y)->sprand(x, y, 0.5), (x, y)->adjoint(sprand(y, x, 0.5)), + (x, y)->transpose(sprand(y, x, 0.5))] + for fA in fs, fB in fs + mul!(zeros(6, 10), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(7, 10), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 11), fA(6, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(5, 8), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 9), fB(8, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 8), fB(7, 10)) + @test_throws DimensionMismatch mul!(zeros(6, 10), fA(6, 8), fB(8, 9)) + end +end + end