diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index b75b9d636..7bd0ab96b 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -253,100 +253,142 @@ the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` re """ function contract!( C::AbstractTensorMap, - A::AbstractTensorMap, (oindA, cindA)::Index2Tuple, - B::AbstractTensorMap, (cindB, oindB)::Index2Tuple, - (p₁, p₂)::Index2Tuple, - α::Number, β::Number, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple, α::Number, β::Number, backend, allocator ) - length(cindA) == length(cindB) || + length(pA[2]) == length(pB[1]) || throw(IndexError("number of contracted indices does not match")) - N₁, N₂ = length(oindA), length(oindB) - - # find optimal contraction scheme - hsp = has_shared_permute - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) + N₁, N₂ = length(pA[1]), length(pB[2]) - qA = TupleTools.sortperm(cindA) - cindA′ = TupleTools.getindices(cindA, qA) - cindB′ = TupleTools.getindices(cindB, qA) + # find optimal contraction scheme by checking the following options: + # - sorting the contracted inds of A or B to avoid permutations + # - contracting B with A instead to avoid permutations - qB = TupleTools.sortperm(cindB) - cindA′′ = TupleTools.getindices(cindA, qB) - cindB′′ = TupleTools.getindices(cindB, qB) + qA = TupleTools.sortperm(pA[2]) + pA′ = Base.setindex(pA, TupleTools.getindices(pA[2], qA), 2) + pB′ = Base.setindex(pB, TupleTools.getindices(pB[1], qA), 1) - dA, dB, dC = dim(A), dim(B), dim(C) + qB = TupleTools.sortperm(pB[1]) + pA″ = Base.setindex(pA, TupleTools.getindices(pA[2], qB), 2) + pB″ = Base.setindex(pB, TupleTools.getindices(pB[1], qB), 1) # keep order A en B, check possibilities for cind - memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC))) - memcost1 += dA * (!hsp(A, (oindA, cindA′))) + dB * (!hsp(B, (cindB′, oindB))) - memcost2 += dA * (!hsp(A, (oindA, cindA′′))) + dB * (!hsp(B, (cindB′′, oindB))) + memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB) + memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB) # reverse order A en B, check possibilities for cind - memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC))) - memcost3 += dB * (!hsp(B, (oindB, cindB′))) + dA * (!hsp(A, (cindA′, oindA))) - memcost4 += dB * (!hsp(B, (oindB, cindB′′))) + dA * (!hsp(A, (cindA′′, oindA))) + pAB′ = ( + map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[1]), + map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[2]), + ) + memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pAB′) + memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pAB′) return if min(memcost1, memcost2) <= min(memcost3, memcost4) if memcost1 <= memcost2 - return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂) + return blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator) else - return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂) + return blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator) end else - p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁) - p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂) if memcost3 <= memcost4 - return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′) + return blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pAB′, α, β, backend, allocator) else - return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′) + return blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pAB′, α, β, backend, allocator) end end end -# TODO: also transform _contract! into new interface, and add backend support -function _contract!( - α, A::AbstractTensorMap, B::AbstractTensorMap, - β, C::AbstractTensorMap, - oindA::IndexTuple, cindA::IndexTuple, - oindB::IndexTuple, cindB::IndexTuple, - p₁::IndexTuple, p₂::IndexTuple +function TO.contract_memcost( + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple + ) + ipAB = TO.oindABinC(pAB, pA, pB) + return dim(A) * (!TO.isblascontractable(A, pA) || eltype(A) !== eltype(C)) + + dim(B) * (!TO.isblascontractable(B, pB) || eltype(B) !== eltype(C)) + + dim(C) * !TO.isblasdestination(C, ipAB) +end + +function TO.isblascontractable(A::AbstractTensorMap, pA::Index2Tuple) + return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, pA) +end +function TO.isblasdestination(A::AbstractTensorMap, ipAB::Index2Tuple) + return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, ipAB) +end + +function blas_contract!( + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple, α, β, + backend, allocator ) - if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding) + bstyle = BraidingStyle(sectortype(C)) + bstyle isa SymmetricBraiding || throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) - end - N₁, N₂ = length(oindA), length(oindB) - copyA = false - if BraidingStyle(sectortype(A)) isa Fermionic - for i in cindA - if !isdual(space(A, i)) - copyA = true - end + TC = eltype(C) + + # check which tensors have to be permuted/copied + copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC) + copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC) + + if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1]) + # twist smallest object if neither or both already have to be permuted + # otherwise twist the one that already is copied + if !(copyA ⊻ copyB) + twistA = dim(A) < dim(B) + else + twistA = copyA end + twistB = !twistA + copyA |= twistA + copyB |= twistB + else + twistA = false + twistB = false end - A′ = permute(A, (oindA, cindA); copy = copyA) - B′ = permute(B, (cindB, oindB)) - if BraidingStyle(sectortype(A)) isa Fermionic - for i in domainind(A′) - if !isdual(space(A′, i)) - A′ = twist!(A′, i) - end - end - # A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′))) - # commented version leads to boxing of `A′` and type instabilities in the result + + # Bring A in the correct form for BLAS contraction + if copyA + Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator) + Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator) + twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew))) + else + Anew = permute(A, pA) end - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) - if has_shared_permute(C, (oindAinC, oindBinC)) - C′ = permute(C, (oindAinC, oindBinC)) - mul!(C′, A′, B′, α, β) + pAnew = (codomainind(Anew), domainind(Anew)) + + # Bring B in the correct form for BLAS contraction + if copyB + Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator) + Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator) + twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew))) else - C′ = A′ * B′ - add_permute!(C, C′, (p₁, p₂), α, β) + Bnew = permute(B, pB) end + pBnew = (codomainind(Bnew), domainind(Bnew)) + + # Bring C in the correct form for BLAS contraction + ipAB = TO.oindABinC(pAB, pAnew, pBnew) + copyC = !TO.isblasdestination(C, ipAB) + + if copyC + Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) + mul!(Cnew, Anew, Bnew) + TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator) + TO.tensorfree!(Cnew, allocator) + else + Cnew = permute(C, ipAB) + mul!(Cnew, Anew, Bnew, α, β) + end + + copyA && TO.tensorfree!(Anew, allocator) + copyB && TO.tensorfree!(Bnew, allocator) + return C end