From df5b54f05ba8e839676151c64deb1b4db0ac5d58 Mon Sep 17 00:00:00 2001 From: qyli Date: Thu, 11 Jan 2024 03:57:36 +0800 Subject: [PATCH 1/4] add multi-threading support for mul!, add! and tsvd! --- .gitignore | 4 +- src/tensors/factorizations.jl | 80 +++++++++++++++++++++++++++++----- src/tensors/linalg.jl | 77 ++++++++++++++++++++++++++++---- src/tensors/vectorinterface.jl | 32 ++++++++++++-- 4 files changed, 170 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index ae7f128c3..07aec3ac7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ *-old __pycache__ .ipynb* -Manifest.toml \ No newline at end of file +Manifest.toml +.vscode +dev \ No newline at end of file diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 03bc20a73..b6def3721 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -438,7 +438,8 @@ end # helper functions -function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}) +function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; + numthreads::Int=Threads.nthreads()) InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!) I = sectortype(t) A = storagetype(t) @@ -446,16 +447,75 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}) Vdata = SectorDict{I,A}() dims = SectorDict{I,Int}() local Σdata - for (c, b) in blocks(t) - U, Σ, V = MatrixAlgebra.svd!(b, alg) - Udata[c] = U - Vdata[c] = V - if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction - Σdata[c] = Σ - else - Σdata = SectorDict(c => Σ) + if numthreads == 1 + for (c, b) in blocks(t) + U, Σ, V = MatrixAlgebra.svd!(b, alg) + Udata[c] = U + Vdata[c] = V + if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction + Σdata[c] = Σ + else + Σdata = SectorDict(c => Σ) + end + dims[c] = length(Σ) + end + elseif numthreads == -1 + tasks = map(blocksectors(t)) do c + Threads.@spawn MatrixAlgebra.svd!(blocks(t)[c], alg) + end + for (c, task) in zip(blocksectors(t), tasks) + U, Σ, V = fetch(task) + Udata[c] = U + Vdata[c] = V + if @isdefined Σdata + Σdata[c] = Σ + else + Σdata = SectorDict(c => Σ) + end + dims[c] = length(Σ) + end + else + Σdata = SectorDict{I,Vector{real(scalartype(t))}}() + + # sort sectors by size + lsc = blocksectors(t) + lsD3 = map(lsc) do c + # O(D1^2D2) or O(D1D2^2) + return min(size(blocks(t)[c])[1]^2 * size(blocks(t)[c])[2], + size(blocks(t)[c])[1] * size(blocks(t)[c])[2]^2) end - dims[c] = length(Σ) + lsc = lsc[sortperm(lsD3; rev=true)] + + # producer + taskref = Ref{Task}() + ch = Channel(; taskref=taskref, spawn=true) do ch + for c in vcat(lsc, fill(nothing, numthreads)) + put!(ch, c) + end + end + + # consumers + Lock = Threads.SpinLock() + tasks = map(1:numthreads) do _ + task = Threads.@spawn while true + c = take!(ch) + isnothing(c) && break + U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg) + + # note inserting keys to dict is not thread safe + lock(Lock) + Udata[c] = U + Vdata[c] = V + Σdata[c] = Σ + dims[c] = length(Σ) + unlock(Lock) + end + errormonitor(task) + end + + wait.(tasks) + wait(taskref[]) + end return Udata, Σdata, Vdata, dims end diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 872f3564a..22288f351 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -239,20 +239,79 @@ end # TensorMap multiplication function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, - tB::AbstractTensorMap, α=true, β=false) + tB::AbstractTensorMap, α=true, β=false; + numthreads::Int64=Threads.nthreads()) if !(codomain(tC) == codomain(tA) && domain(tC) == domain(tB) && domain(tA) == codomain(tB)) throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) end - for c in blocksectors(tC) - if hasblock(tA, c) # then also tB should have such a block - A = block(tA, c) - B = block(tB, c) - C = block(tC, c) - mul!(StridedView(C), StridedView(A), StridedView(B), α, β) - elseif β != one(β) - rmul!(block(tC, c), β) + + if numthreads == 1 + for c in blocksectors(tC) + if hasblock(tA, c) # then also tB should have such a block + A = block(tA, c) + B = block(tB, c) + C = block(tC, c) + mul!(StridedView(C), StridedView(A), StridedView(B), α, β) + elseif β != one(β) + rmul!(block(tC, c), β) + end + end + + elseif numthreads == -1 + Threads.@sync for c in blocksectors(tC) + if hasblock(tA, c) + Threads.@spawn mul!(StridedView(block(tC, c)), + StridedView(block(tA, c)), + StridedView(block(tB, c)), + α, β) + elseif β != one(β) + Threads.@spawn rmul!(block(tC, c), β) + end end + + else + + # sort sectors by size + lsc = blocksectors(tC) + lsD3 = map(lsc) do c + if hasblock(tA, c) + return size(blocks(tA)[c], 1) * size(blocks(tA)[c], 2) * + size(blocks(tB)[c], 2) + else + return size(blocks(tC)[c], 1) * size(blocks(tC)[c], 2) + end + end + lsc = lsc[sortperm(lsD3; rev=true)] + + # producer + taskref = Ref{Task}() + ch = Channel(; taskref=taskref, spawn=true) do ch + for c in vcat(lsc, fill(nothing, numthreads)) + put!(ch, c) + end + end + + # consumers + tasks = map(1:numthreads) do _ + task = Threads.@spawn while true + c = take!(ch) + isnothing(c) && break + + if hasblock(tA, c) + mul!(StridedView(block(tC, c)), + StridedView(block(tA, c)), + StridedView(block(tB, c)), + α, β) + elseif β != one(β) + rmul!(block(tC, c), β) + end + end + return errormonitor(task) + end + + wait.(tasks) + wait(taskref[]) end return tC end diff --git a/src/tensors/vectorinterface.jl b/src/tensors/vectorinterface.jl index b25633524..f4d8d873b 100644 --- a/src/tensors/vectorinterface.jl +++ b/src/tensors/vectorinterface.jl @@ -58,10 +58,36 @@ function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap, return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α) end function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, - α::Number, β::Number) + α::Number, β::Number; + numthreads::Int64=1) space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))")) - for c in blocksectors(tx) - VectorInterface.add!(block(ty, c), block(tx, c), α, β) + if numthreads == 1 + for c in blocksectors(tx) + VectorInterface.add!(block(ty, c), block(tx, c), α, β) + end + elseif numthreads == -1 + Threads.@sync for c in blocksectors(tx) + Threads.@spawn VectorInterface.add!(block(ty, c), block(tx, c), α, β) + end + else + # producer + taskref = Ref{Task}() + ch = Channel(; taskref=taskref, spawn=true) do ch + for c in vcat(blocksectors(tx), fill(nothing, numthreads)) + put!(ch, c) + end + end + # consumers + tasks = map(1:numthreads) do _ + task = Threads.@spawn while true + c = take!(ch) + VectorInterface.add!(block(ty, c), block(tx, c), α, β) + end + return errormonitor(tast) + end + + wait.(tasks) + wait(taskref[]) end return ty end From f23c52f01239b21bece0c21ef660e59182249ac1 Mon Sep 17 00:00:00 2001 From: qyli Date: Fri, 12 Jan 2024 18:58:21 +0800 Subject: [PATCH 2/4] compatible with tensors over ElementarySpace --- src/tensors/factorizations.jl | 13 ++++++------- src/tensors/linalg.jl | 8 ++++---- src/tensors/vectorinterface.jl | 4 ++-- test/Project.toml | 10 ++++++++++ 4 files changed, 22 insertions(+), 13 deletions(-) create mode 100644 test/Project.toml diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index b6def3721..1e37ee92a 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -446,8 +446,8 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; Udata = SectorDict{I,A}() Vdata = SectorDict{I,A}() dims = SectorDict{I,Int}() - local Σdata - if numthreads == 1 + if numthreads == 1 || length(blocksectors(t)) == 1 + local Σdata for (c, b) in blocks(t) U, Σ, V = MatrixAlgebra.svd!(b, alg) Udata[c] = U @@ -481,8 +481,8 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; lsc = blocksectors(t) lsD3 = map(lsc) do c # O(D1^2D2) or O(D1D2^2) - return min(size(blocks(t)[c])[1]^2 * size(blocks(t)[c])[2], - size(blocks(t)[c])[1] * size(blocks(t)[c])[2]^2) + return min(size(block(t, c), 1)^2 * size(block(t, c), 2), + size(block(t, c), 1) * size(block(t, c), 2)^2) end lsc = lsc[sortperm(lsD3; rev=true)] @@ -509,13 +509,12 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; Σdata[c] = Σ dims[c] = length(Σ) unlock(Lock) - end - errormonitor(task) + end + return errormonitor(task) end wait.(tasks) wait(taskref[]) - end return Udata, Σdata, Vdata, dims end diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index 22288f351..e25f518cc 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -246,7 +246,7 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) end - if numthreads == 1 + if numthreads == 1 || length(blocksectors(tC)) == 1 for c in blocksectors(tC) if hasblock(tA, c) # then also tB should have such a block A = block(tA, c) @@ -276,10 +276,10 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, lsc = blocksectors(tC) lsD3 = map(lsc) do c if hasblock(tA, c) - return size(blocks(tA)[c], 1) * size(blocks(tA)[c], 2) * - size(blocks(tB)[c], 2) + return size(block(tA, c), 1) * size(block(tA, c), 2) * + size(block(tA, c), 2) else - return size(blocks(tC)[c], 1) * size(blocks(tC)[c], 2) + return size(block(tC, c), 1) * size(block(tC, c), 2) end end lsc = lsc[sortperm(lsD3; rev=true)] diff --git a/src/tensors/vectorinterface.jl b/src/tensors/vectorinterface.jl index f4d8d873b..257ec15fe 100644 --- a/src/tensors/vectorinterface.jl +++ b/src/tensors/vectorinterface.jl @@ -61,7 +61,7 @@ function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number, β::Number; numthreads::Int64=1) space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))")) - if numthreads == 1 + if numthreads == 1 || length(blocksectors) == 1 for c in blocksectors(tx) VectorInterface.add!(block(ty, c), block(tx, c), α, β) end @@ -83,7 +83,7 @@ function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, c = take!(ch) VectorInterface.add!(block(ty, c), block(tx, c), α, β) end - return errormonitor(tast) + return errormonitor(task) end wait.(tasks) diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 000000000..5e3ba6a97 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,10 @@ +[deps] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721" +TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" +WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" From c9a8ef7fc77e17a0312c71367c48e66dc74f32b2 Mon Sep 17 00:00:00 2001 From: qyli Date: Fri, 12 Jan 2024 19:53:22 +0800 Subject: [PATCH 3/4] remove Project.toml --- test/Project.toml | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 test/Project.toml diff --git a/test/Project.toml b/test/Project.toml deleted file mode 100644 index 5e3ba6a97..000000000 --- a/test/Project.toml +++ /dev/null @@ -1,10 +0,0 @@ -[deps] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -HalfIntegers = "f0d1745a-41c9-11e9-1dd9-e5d34d218721" -TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec" -TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" -WignerSymbols = "9f57e263-0b3d-5e2e-b1be-24f2bb48858b" From a4de69ad006faafc20e840180cd7b6cfbaf2855b Mon Sep 17 00:00:00 2001 From: qyli Date: Thu, 8 Feb 2024 02:59:08 +0800 Subject: [PATCH 4/4] update implementations & add multi-threading eigh and interfaces --- src/TensorKit.jl | 7 ++ src/multithreading.jl | 33 ++++++++++ src/tensors/factorizations.jl | 117 +++++++++++++++++++++------------ src/tensors/linalg.jl | 50 +++++--------- src/tensors/vectorinterface.jl | 41 +++++------- 5 files changed, 152 insertions(+), 96 deletions(-) create mode 100644 src/multithreading.jl diff --git a/src/TensorKit.jl b/src/TensorKit.jl index a81aabe32..fb0b39fd8 100644 --- a/src/TensorKit.jl +++ b/src/TensorKit.jl @@ -205,10 +205,17 @@ include("planar/planaroperations.jl") # deprecations: to be removed in version 1.0 or sooner include("auxiliary/deprecate.jl") +include("multithreading.jl") # Extensions # ---------- function __init__() @require_extensions + + global nthreads_mul = Threads.nthreads() + global nthreads_eigh = Threads.nthreads() + global nthreads_svd = Threads.nthreads() + global nthreads_add = Threads.nthreads() + end end diff --git a/src/multithreading.jl b/src/multithreading.jl new file mode 100644 index 000000000..7669567c6 --- /dev/null +++ b/src/multithreading.jl @@ -0,0 +1,33 @@ +# global variables to control multi-threading behaviors +global nthreads_mul::Int64 +global nthreads_eigh::Int64 +global nthreads_svd::Int64 +global nthreads_add::Int64 + +function set_num_threads_mul(n::Int64) + @assert 1 ≤ n ≤ Threads.nthreads() + global nthreads_mul = n + return nothing +end +get_num_threads_mul() = nthreads_mul + +function set_num_threads_add(n::Int64) + @assert 1 ≤ n ≤ Threads.nthreads() + global nthreads_add = n + return nothing +end +get_num_threads_add() = nthreads_add + +function set_num_threads_svd(n::Int64) + @assert 1 ≤ n ≤ Threads.nthreads() + global nthreads_svd = n + return nothing +end +get_num_threads_svd() = nthreads_svd + +function set_num_threads_eigh(n::Int64) + @assert 1 ≤ n ≤ Threads.nthreads() + global nthreads_eigh = n + return nothing +end +get_num_threads_eigh() = nthreads_eigh diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl index 1e37ee92a..24df11e87 100644 --- a/src/tensors/factorizations.jl +++ b/src/tensors/factorizations.jl @@ -438,15 +438,16 @@ end # helper functions -function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; - numthreads::Int=Threads.nthreads()) +function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}) InnerProductStyle(t) === EuclideanProduct() || throw_invalid_innerproduct(:tsvd!) I = sectortype(t) A = storagetype(t) Udata = SectorDict{I,A}() Vdata = SectorDict{I,A}() dims = SectorDict{I,Int}() - if numthreads == 1 || length(blocksectors(t)) == 1 + + num_threads = get_num_threads_svd() + if num_threads == 1 || length(blocksectors(t)) == 1 local Σdata for (c, b) in blocks(t) U, Σ, V = MatrixAlgebra.svd!(b, alg) @@ -459,56 +460,46 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD}; end dims[c] = length(Σ) end - elseif numthreads == -1 - tasks = map(blocksectors(t)) do c - Threads.@spawn MatrixAlgebra.svd!(blocks(t)[c], alg) - end - for (c, task) in zip(blocksectors(t), tasks) - U, Σ, V = fetch(task) - Udata[c] = U - Vdata[c] = V - if @isdefined Σdata - Σdata[c] = Σ - else - Σdata = SectorDict(c => Σ) - end - dims[c] = length(Σ) - end else Σdata = SectorDict{I,Vector{real(scalartype(t))}}() - # sort sectors by size + # try to sort sectors by size lsc = blocksectors(t) - lsD3 = map(lsc) do c - # O(D1^2D2) or O(D1D2^2) - return min(size(block(t, c), 1)^2 * size(block(t, c), 2), - size(block(t, c), 1) * size(block(t, c), 2)^2) + if isa(lsc, AbstractVector) + lsD3 = lsD3 = map(lsc) do c + # O(D1^2D2) or O(D1D2^2) + return min(size(block(t, c), 1)^2 * size(block(t, c), 2), + size(block(t, c), 1) * size(block(t, c), 2)^2) + end + lsc = lsc[sortperm(lsD3; rev=true)] end - lsc = lsc[sortperm(lsD3; rev=true)] # producer taskref = Ref{Task}() ch = Channel(; taskref=taskref, spawn=true) do ch - for c in vcat(lsc, fill(nothing, numthreads)) + for c in lsc put!(ch, c) end end # consumers - Lock = Threads.SpinLock() - tasks = map(1:numthreads) do _ - task = Threads.@spawn while true - c = take!(ch) - isnothing(c) && break + Lock = Threads.ReentrantLock() + tasks = map(1:num_threads) do _ + task = Threads.@spawn for c in ch U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg) # note inserting keys to dict is not thread safe lock(Lock) - Udata[c] = U - Vdata[c] = V - Σdata[c] = Σ - dims[c] = length(Σ) - unlock(Lock) + try + Udata[c] = U + Vdata[c] = V + Σdata[c] = Σ + dims[c] = length(Σ) + catch + rethrow() + finally + unlock(Lock) + end end return errormonitor(task) end @@ -577,13 +568,57 @@ function eigh!(t::TensorMap) Ddata = SectorDict{I,Ar}() Vdata = SectorDict{I,A}() dims = SectorDict{I,Int}() - for (c, b) in blocks(t) - values, vectors = MatrixAlgebra.eigh!(b) - d = length(values) - Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values)) - Vdata[c] = vectors - dims[c] = d + + num_threads = get_num_threads_eigh() + lsc = blocksectors(t) + if num_threads == 1 || length(lsc) == 1 + for c in lsc + values, vectors = MatrixAlgebra.eigh!(block(t, c)) + d = length(values) + Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values)) + Vdata[c] = vectors + dims[c] = d + end + else + # try to sort sectors by size + if isa(lsc, AbstractVector) + lsc = sort(lsc; by=c -> size(block(t, c), 1), rev=true) + end + + # producer + taskref = Ref{Task}() + ch = Channel(; taskref=taskref, spawn=true) do ch + for c in lsc + put!(ch, c) + end + end + + # consumers + Lock = Threads.ReentrantLock() + tasks = map(1:num_threads) do _ + task = Threads.@spawn for c in ch + values, vectors = MatrixAlgebra.eigh!(block(t, c)) + d = length(values) + values = copyto!(similar(values, (d, d)), Diagonal(values)) + + lock(Lock) + try + Ddata[c] = values + Vdata[c] = vectors + dims[c] = d + catch + rethrow() + finally + unlock(Lock) + end + end + return errormonitor(task) + end + + wait.(tasks) + wait(taskref[]) end + if length(domain(t)) == 1 W = domain(t)[1] else diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl index e25f518cc..6f3a317a7 100644 --- a/src/tensors/linalg.jl +++ b/src/tensors/linalg.jl @@ -239,14 +239,14 @@ end # TensorMap multiplication function LinearAlgebra.mul!(tC::AbstractTensorMap, tA::AbstractTensorMap, - tB::AbstractTensorMap, α=true, β=false; - numthreads::Int64=Threads.nthreads()) + tB::AbstractTensorMap, α=true, β=false) if !(codomain(tC) == codomain(tA) && domain(tC) == domain(tB) && domain(tA) == codomain(tB)) throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))")) end - if numthreads == 1 || length(blocksectors(tC)) == 1 + num_threads = get_num_threads_mul() + if num_threads == 1 || length(blocksectors(tC)) == 1 for c in blocksectors(tC) if hasblock(tA, c) # then also tB should have such a block A = block(tA, c) @@ -258,28 +258,17 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, end end - elseif numthreads == -1 - Threads.@sync for c in blocksectors(tC) - if hasblock(tA, c) - Threads.@spawn mul!(StridedView(block(tC, c)), - StridedView(block(tA, c)), - StridedView(block(tB, c)), - α, β) - elseif β != one(β) - Threads.@spawn rmul!(block(tC, c), β) - end - end - else - - # sort sectors by size lsc = blocksectors(tC) - lsD3 = map(lsc) do c - if hasblock(tA, c) - return size(block(tA, c), 1) * size(block(tA, c), 2) * - size(block(tA, c), 2) - else - return size(block(tC, c), 1) * size(block(tC, c), 2) + # try to sort sectors by size + if isa(lsc, AbstractVector) + lsD3 = map(lsc) do c + if hasblock(tA, c) + return size(block(tA, c), 1) * size(block(tA, c), 2) * + size(block(tA, c), 2) + else + return size(block(tC, c), 1) * size(block(tC, c), 2) + end end end lsc = lsc[sortperm(lsD3; rev=true)] @@ -287,21 +276,18 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap, # producer taskref = Ref{Task}() ch = Channel(; taskref=taskref, spawn=true) do ch - for c in vcat(lsc, fill(nothing, numthreads)) + for c in lsc put!(ch, c) end end # consumers - tasks = map(1:numthreads) do _ - task = Threads.@spawn while true - c = take!(ch) - isnothing(c) && break - + tasks = map(1:num_threads) do _ + task = Threads.@spawn for c in ch if hasblock(tA, c) - mul!(StridedView(block(tC, c)), - StridedView(block(tA, c)), - StridedView(block(tB, c)), + mul!(block(tC, c), + block(tA, c), + block(tB, c), α, β) elseif β != one(β) rmul!(block(tC, c), β) diff --git a/src/tensors/vectorinterface.jl b/src/tensors/vectorinterface.jl index 257ec15fe..0ded5cc2c 100644 --- a/src/tensors/vectorinterface.jl +++ b/src/tensors/vectorinterface.jl @@ -58,36 +58,31 @@ function VectorInterface.add(ty::AbstractTensorMap, tx::AbstractTensorMap, return VectorInterface.add!(scale!(similar(ty, T), ty, β), tx, α) end function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, - α::Number, β::Number; - numthreads::Int64=1) + α::Number, β::Number) space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))")) - if numthreads == 1 || length(blocksectors) == 1 - for c in blocksectors(tx) + num_threads = get_num_threads_add() + lsc = blocksectors(tx) + if num_threads == 1 || length(lsc) == 1 + for c in lsc VectorInterface.add!(block(ty, c), block(tx, c), α, β) end - elseif numthreads == -1 - Threads.@sync for c in blocksectors(tx) - Threads.@spawn VectorInterface.add!(block(ty, c), block(tx, c), α, β) + else + # try to sort sectors by size + if isa(lsc, AbstractVector) + # warning: using `sort!` here is not safe. I found it will lead to a "key ... not found" error when show tx again + lsc = sort(lsc; by=c -> prod(size(block(tx, c))), rev=true) end - else - # producer - taskref = Ref{Task}() - ch = Channel(; taskref=taskref, spawn=true) do ch - for c in vcat(blocksectors(tx), fill(nothing, numthreads)) - put!(ch, c) - end - end - # consumers - tasks = map(1:numthreads) do _ - task = Threads.@spawn while true - c = take!(ch) + + idx = Threads.Atomic{Int64}(1) + Threads.@sync for _ in 1:num_threads + Threads.@spawn while true + i = Threads.atomic_add!(idx, 1) + i > length(lsc) && break + + c = lsc[i] VectorInterface.add!(block(ty, c), block(tx, c), α, β) end - return errormonitor(task) end - - wait.(tasks) - wait(taskref[]) end return ty end