Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 107 additions & 65 deletions src/tensors/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,100 +253,142 @@ the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` re
"""
function contract!(
C::AbstractTensorMap,
A::AbstractTensorMap, (oindA, cindA)::Index2Tuple,
B::AbstractTensorMap, (cindB, oindB)::Index2Tuple,
(p₁, p₂)::Index2Tuple,
α::Number, β::Number,
A::AbstractTensorMap, pA::Index2Tuple,
B::AbstractTensorMap, pB::Index2Tuple,
pAB::Index2Tuple, α::Number, β::Number,
backend, allocator
)
length(cindA) == length(cindB) ||
length(pA[2]) == length(pB[1]) ||
throw(IndexError("number of contracted indices does not match"))
N₁, N₂ = length(oindA), length(oindB)

# find optimal contraction scheme
hsp = has_shared_permute
ipAB = TupleTools.invperm((p₁..., p₂...))
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
N₁, N₂ = length(pA[1]), length(pB[2])

qA = TupleTools.sortperm(cindA)
cindA′ = TupleTools.getindices(cindA, qA)
cindB′ = TupleTools.getindices(cindB, qA)
# find optimal contraction scheme by checking the following options:
# - sorting the contracted inds of A or B to avoid permutations
# - contracting B with A instead to avoid permutations

qB = TupleTools.sortperm(cindB)
cindA′′ = TupleTools.getindices(cindA, qB)
cindB′′ = TupleTools.getindices(cindB, qB)
qA = TupleTools.sortperm(pA[2])
pA′ = Base.setindex(pA, TupleTools.getindices(pA[2], qA), 2)
pB′ = Base.setindex(pB, TupleTools.getindices(pB[1], qA), 1)

dA, dB, dC = dim(A), dim(B), dim(C)
qB = TupleTools.sortperm(pB[1])
pA″ = Base.setindex(pA, TupleTools.getindices(pA[2], qB), 2)
pB″ = Base.setindex(pB, TupleTools.getindices(pB[1], qB), 1)

# keep order A en B, check possibilities for cind
memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC)))
memcost1 += dA * (!hsp(A, (oindA, cindA′))) + dB * (!hsp(B, (cindB′, oindB)))
memcost2 += dA * (!hsp(A, (oindA, cindA′′))) + dB * (!hsp(B, (cindB′′, oindB)))
memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB)
memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB)

# reverse order A en B, check possibilities for cind
memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC)))
memcost3 += dB * (!hsp(B, (oindB, cindB′))) + dA * (!hsp(A, (cindA′, oindA)))
memcost4 += dB * (!hsp(B, (oindB, cindB′′))) + dA * (!hsp(A, (cindA′′, oindA)))
pAB′ = (
map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[1]),
map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[2]),
)
memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pAB′)
memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pAB′)

return if min(memcost1, memcost2) <= min(memcost3, memcost4)
if memcost1 <= memcost2
return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂)
return blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator)
else
return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂)
return blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator)
end
else
p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁)
p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂)
if memcost3 <= memcost4
return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′)
return blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pAB′, α, β, backend, allocator)
else
return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′)
return blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pAB′, α, β, backend, allocator)
end
end
end

# TODO: also transform _contract! into new interface, and add backend support
function _contract!(
α, A::AbstractTensorMap, B::AbstractTensorMap,
β, C::AbstractTensorMap,
oindA::IndexTuple, cindA::IndexTuple,
oindB::IndexTuple, cindB::IndexTuple,
p₁::IndexTuple, p₂::IndexTuple
function TO.contract_memcost(
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple,
B::AbstractTensorMap, pB::Index2Tuple,
pAB::Index2Tuple
)
ipAB = TO.oindABinC(pAB, pA, pB)
return dim(A) * (!TO.isblascontractable(A, pA) || eltype(A) !== eltype(C)) +
dim(B) * (!TO.isblascontractable(B, pB) || eltype(B) !== eltype(C)) +
dim(C) * !TO.isblasdestination(C, ipAB)
end

function TO.isblascontractable(A::AbstractTensorMap, pA::Index2Tuple)
return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, pA)
end
function TO.isblasdestination(A::AbstractTensorMap, ipAB::Index2Tuple)
return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, ipAB)
end

function blas_contract!(
C::AbstractTensorMap,
A::AbstractTensorMap, pA::Index2Tuple,
B::AbstractTensorMap, pB::Index2Tuple,
pAB::Index2Tuple, α, β,
backend, allocator
)
if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding)
bstyle = BraidingStyle(sectortype(C))
bstyle isa SymmetricBraiding ||
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
end
N₁, N₂ = length(oindA), length(oindB)
copyA = false
if BraidingStyle(sectortype(A)) isa Fermionic
for i in cindA
if !isdual(space(A, i))
copyA = true
end
TC = eltype(C)

# check which tensors have to be permuted/copied
copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC)
copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC)

if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1])
# twist smallest object if neither or both already have to be permuted
# otherwise twist the one that already is copied
if !(copyA ⊻ copyB)
twistA = dim(A) < dim(B)
else
twistA = copyA
end
twistB = !twistA
copyA |= twistA
copyB |= twistB
else
twistA = false
twistB = false
end
A′ = permute(A, (oindA, cindA); copy = copyA)
B′ = permute(B, (cindB, oindB))
if BraidingStyle(sectortype(A)) isa Fermionic
for i in domainind(A′)
if !isdual(space(A′, i))
A′ = twist!(A′, i)
end
end
# A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′)))
# commented version leads to boxing of `A′` and type instabilities in the result

# Bring A in the correct form for BLAS contraction
if copyA
Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew)))
else
Anew = permute(A, pA)
end
ipAB = TupleTools.invperm((p₁..., p₂...))
oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁))
oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂))
if has_shared_permute(C, (oindAinC, oindBinC))
C′ = permute(C, (oindAinC, oindBinC))
mul!(C′, A′, B′, α, β)
pAnew = (codomainind(Anew), domainind(Anew))

# Bring B in the correct form for BLAS contraction
if copyB
Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator)
Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator)
twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew)))
else
C′ = A′ * B′
add_permute!(C, C′, (p₁, p₂), α, β)
Bnew = permute(B, pB)
end
pBnew = (codomainind(Bnew), domainind(Bnew))

# Bring C in the correct form for BLAS contraction
ipAB = TO.oindABinC(pAB, pAnew, pBnew)
copyC = !TO.isblasdestination(C, ipAB)

if copyC
Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
mul!(Cnew, Anew, Bnew)
TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator)
TO.tensorfree!(Cnew, allocator)
else
Cnew = permute(C, ipAB)
mul!(Cnew, Anew, Bnew, α, β)
end

copyA && TO.tensorfree!(Anew, allocator)
copyB && TO.tensorfree!(Bnew, allocator)

return C
end

Expand Down
Loading