diff --git a/.gitignore b/.gitignore
index ae7f128c3..07aec3ac7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,6 @@
 *-old
 __pycache__
 .ipynb*
-Manifest.toml
\ No newline at end of file
+Manifest.toml
+.vscode
+dev
\ No newline at end of file
diff --git a/src/TensorKit.jl b/src/TensorKit.jl
index a81aabe32..fb0b39fd8 100644
--- a/src/TensorKit.jl
+++ b/src/TensorKit.jl
@@ -205,10 +205,17 @@ include("planar/planaroperations.jl")
 
 # deprecations: to be removed in version 1.0 or sooner
 include("auxiliary/deprecate.jl")
+include("multithreading.jl")
 
 # Extensions
 # ----------
 function __init__()
     @require_extensions
+
+    global nthreads_mul = Threads.nthreads()
+    global nthreads_eigh = Threads.nthreads()
+    global nthreads_svd = Threads.nthreads()
+    global nthreads_add = Threads.nthreads()
+
 end
 end
diff --git a/src/multithreading.jl b/src/multithreading.jl
new file mode 100644
index 000000000..7669567c6
--- /dev/null
+++ b/src/multithreading.jl
@@ -0,0 +1,33 @@
+# Global variables controlling the number of threads used by mul!/eigh!/svd!/add!;
+# initialized to Threads.nthreads() in __init__.
+global nthreads_mul::Int64
+global nthreads_eigh::Int64
+global nthreads_svd::Int64
+global nthreads_add::Int64
+
+function set_num_threads_mul(n::Int64)
+    1 ≤ n ≤ Threads.nthreads() || throw(ArgumentError("n must be in 1:$(Threads.nthreads()), got $n"))
+    global nthreads_mul = n
+    return nothing
+end
+get_num_threads_mul() = nthreads_mul
+
+function set_num_threads_add(n::Int64)
+    1 ≤ n ≤ Threads.nthreads() || throw(ArgumentError("n must be in 1:$(Threads.nthreads()), got $n"))
+    global nthreads_add = n
+    return nothing
+end
+get_num_threads_add() = nthreads_add
+
+function set_num_threads_svd(n::Int64)
+    1 ≤ n ≤ Threads.nthreads() || throw(ArgumentError("n must be in 1:$(Threads.nthreads()), got $n"))
+    global nthreads_svd = n
+    return nothing
+end
+get_num_threads_svd() = nthreads_svd
+
+function set_num_threads_eigh(n::Int64)
+    1 ≤ n ≤ Threads.nthreads() || throw(ArgumentError("n must be in 1:$(Threads.nthreads()), got $n"))
+    global nthreads_eigh = n
+    return nothing
+end
+get_num_threads_eigh() = nthreads_eigh
diff --git a/src/tensors/factorizations.jl b/src/tensors/factorizations.jl
index 03bc20a73..24df11e87 100644
--- a/src/tensors/factorizations.jl
+++ b/src/tensors/factorizations.jl
@@ -445,17 +445,67 @@ function _compute_svddata!(t::TensorMap, alg::Union{SVD,SDD})
     Udata = SectorDict{I,A}()
     Vdata = SectorDict{I,A}()
     dims = SectorDict{I,Int}()
-    local Σdata
-    for (c, b) in blocks(t)
-        U, Σ, V = MatrixAlgebra.svd!(b, alg)
-        Udata[c] = U
-        Vdata[c] = V
-        if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
-            Σdata[c] = Σ
-        else
-            Σdata = SectorDict(c => Σ)
-        end
-        dims[c] = length(Σ)
-    end
+
+    num_threads = get_num_threads_svd()
+    if num_threads == 1 || length(blocksectors(t)) == 1
+        local Σdata
+        for (c, b) in blocks(t)
+            U, Σ, V = MatrixAlgebra.svd!(b, alg)
+            Udata[c] = U
+            Vdata[c] = V
+            if @isdefined Σdata # cannot easily infer the type of Σ, so use this construction
+                Σdata[c] = Σ
+            else
+                Σdata = SectorDict(c => Σ)
+            end
+            dims[c] = length(Σ)
+        end
+    else
+        Σdata = SectorDict{I,Vector{real(scalartype(t))}}()
+
+        # try to sort sectors by size (largest first, so big blocks start earliest)
+        lsc = blocksectors(t)
+        if isa(lsc, AbstractVector)
+            lsD3 = map(lsc) do c
+                # O(D1^2D2) or O(D1D2^2)
+                return min(size(block(t, c), 1)^2 * size(block(t, c), 2),
+                           size(block(t, c), 1) * size(block(t, c), 2)^2)
+            end
+            lsc = lsc[sortperm(lsD3; rev=true)]
+        end
+
+        # producer
+        taskref = Ref{Task}()
+        ch = Channel(; taskref=taskref, spawn=true) do ch
+            for c in lsc
+                put!(ch, c)
+            end
+        end
+
+        # consumers
+        Lock = Threads.ReentrantLock()
+        tasks = map(1:num_threads) do _
+            task = Threads.@spawn for c in ch
+                U, Σ, V = MatrixAlgebra.svd!(blocks(t)[c], alg)
+
+                # note inserting keys to dict is not thread safe
+                lock(Lock)
+                try
+                    Udata[c] = U
+                    Vdata[c] = V
+                    Σdata[c] = Σ
+                    dims[c] = length(Σ)
+                catch
+                    rethrow()
+                finally
+                    unlock(Lock)
+                end
+            end
+            return errormonitor(task)
+        end
+
+        wait.(tasks)
+        wait(taskref[])
+    end
     return Udata, Σdata, Vdata, dims
 end
@@ -518,13 +568,57 @@ function eigh!(t::TensorMap)
     Ddata = SectorDict{I,Ar}()
     Vdata = SectorDict{I,A}()
     dims = SectorDict{I,Int}()
-    for (c, b) in blocks(t)
-        values, vectors = MatrixAlgebra.eigh!(b)
-        d = length(values)
-        Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
-        Vdata[c] = vectors
-        dims[c] = d
+
+    num_threads = get_num_threads_eigh()
+    lsc = blocksectors(t)
+    if num_threads == 1 || length(lsc) == 1
+        for c in lsc
+            values, vectors = MatrixAlgebra.eigh!(block(t, c))
+            d = length(values)
+            Ddata[c] = copyto!(similar(values, (d, d)), Diagonal(values))
+            Vdata[c] = vectors
+            dims[c] = d
+        end
+    else
+        # try to sort sectors by size (largest first)
+        if isa(lsc, AbstractVector)
+            lsc = sort(lsc; by=c -> size(block(t, c), 1), rev=true)
+        end
+
+        # producer
+        taskref = Ref{Task}()
+        ch = Channel(; taskref=taskref, spawn=true) do ch
+            for c in lsc
+                put!(ch, c)
+            end
+        end
+
+        # consumers
+        Lock = Threads.ReentrantLock()
+        tasks = map(1:num_threads) do _
+            task = Threads.@spawn for c in ch
+                values, vectors = MatrixAlgebra.eigh!(block(t, c))
+                d = length(values)
+                values = copyto!(similar(values, (d, d)), Diagonal(values))
+
+                # note inserting keys to dict is not thread safe
+                lock(Lock)
+                try
+                    Ddata[c] = values
+                    Vdata[c] = vectors
+                    dims[c] = d
+                catch
+                    rethrow()
+                finally
+                    unlock(Lock)
+                end
+            end
+            return errormonitor(task)
+        end
+
+        wait.(tasks)
+        wait(taskref[])
     end
+
     if length(domain(t)) == 1
         W = domain(t)[1]
     else
diff --git a/src/tensors/linalg.jl b/src/tensors/linalg.jl
index 872f3564a..6f3a317a7 100644
--- a/src/tensors/linalg.jl
+++ b/src/tensors/linalg.jl
@@ -244,15 +244,60 @@ function LinearAlgebra.mul!(tC::AbstractTensorMap,
         domain(tA) == codomain(tB))
         throw(SpaceMismatch("$(space(tC)) ≠ $(space(tA)) * $(space(tB))"))
     end
-    for c in blocksectors(tC)
-        if hasblock(tA, c) # then also tB should have such a block
-            A = block(tA, c)
-            B = block(tB, c)
-            C = block(tC, c)
-            mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
-        elseif β != one(β)
-            rmul!(block(tC, c), β)
-        end
+
+    num_threads = get_num_threads_mul()
+    if num_threads == 1 || length(blocksectors(tC)) == 1
+        for c in blocksectors(tC)
+            if hasblock(tA, c) # then also tB should have such a block
+                A = block(tA, c)
+                B = block(tB, c)
+                C = block(tC, c)
+                mul!(StridedView(C), StridedView(A), StridedView(B), α, β)
+            elseif β != one(β)
+                rmul!(block(tC, c), β)
+            end
+        end
+
+    else
+        lsc = blocksectors(tC)
+        # try to sort sectors by size (matmul cost is M*K*N)
+        if isa(lsc, AbstractVector)
+            lsD3 = map(lsc) do c
+                if hasblock(tA, c)
+                    return size(block(tA, c), 1) * size(block(tA, c), 2) *
+                           size(block(tB, c), 2)
+                else
+                    return size(block(tC, c), 1) * size(block(tC, c), 2)
+                end
+            end
+            lsc = lsc[sortperm(lsD3; rev=true)]
+        end
+
+        # producer
+        taskref = Ref{Task}()
+        ch = Channel(; taskref=taskref, spawn=true) do ch
+            for c in lsc
+                put!(ch, c)
+            end
+        end
+
+        # consumers
+        tasks = map(1:num_threads) do _
+            task = Threads.@spawn for c in ch
+                if hasblock(tA, c)
+                    mul!(block(tC, c),
+                         block(tA, c),
+                         block(tB, c),
+                         α, β)
+                elseif β != one(β)
+                    rmul!(block(tC, c), β)
+                end
+            end
+            return errormonitor(task)
+        end
+
+        wait.(tasks)
+        wait(taskref[])
     end
     return tC
 end
diff --git a/src/tensors/vectorinterface.jl b/src/tensors/vectorinterface.jl
index b25633524..0ded5cc2c 100644
--- a/src/tensors/vectorinterface.jl
+++ b/src/tensors/vectorinterface.jl
@@ -60,8 +60,29 @@ end
 function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number,
                               β::Number)
     space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty)) ≠ $(space(tx))"))
-    for c in blocksectors(tx)
-        VectorInterface.add!(block(ty, c), block(tx, c), α, β)
-    end
+    num_threads = get_num_threads_add()
+    lsc = blocksectors(tx)
+    if num_threads == 1 || length(lsc) == 1
+        for c in lsc
+            VectorInterface.add!(block(ty, c), block(tx, c), α, β)
+        end
+    else
+        # try to sort sectors by size (largest first)
+        if isa(lsc, AbstractVector)
+            # warning: using `sort!` here is not safe. I found it will lead to a "key ... not found" error when show tx again
+            lsc = sort(lsc; by=c -> prod(size(block(tx, c))), rev=true)
+        end
+
+        idx = Threads.Atomic{Int64}(1)
+        Threads.@sync for _ in 1:num_threads
+            Threads.@spawn while true
+                i = Threads.atomic_add!(idx, 1)
+                i > length(lsc) && break
+
+                c = lsc[i]
+                VectorInterface.add!(block(ty, c), block(tx, c), α, β)
+            end
+        end
+    end
     return ty
 end