From 91ca3df89d911cdf44d6c8dd3545365ea5a6dfba Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 4 Nov 2025 21:44:30 -0500 Subject: [PATCH 1/3] rework TensorOperations to use backend and allocator --- src/tensors/tensoroperations.jl | 156 +++++++++++++++++++------------- 1 file changed, 91 insertions(+), 65 deletions(-) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index b75b9d636..516f8d27e 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -253,100 +253,126 @@ the indices of `A` and `B` according to `(oindA, cindA)` and `(cindB, oindB)` re """ function contract!( C::AbstractTensorMap, - A::AbstractTensorMap, (oindA, cindA)::Index2Tuple, - B::AbstractTensorMap, (cindB, oindB)::Index2Tuple, - (p₁, p₂)::Index2Tuple, - α::Number, β::Number, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple, α::Number, β::Number, backend, allocator ) - length(cindA) == length(cindB) || + length(pA[2]) == length(pB[1]) || throw(IndexError("number of contracted indices does not match")) - N₁, N₂ = length(oindA), length(oindB) - - # find optimal contraction scheme - hsp = has_shared_permute - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) + N₁, N₂ = length(pA[1]), length(pB[2]) - qA = TupleTools.sortperm(cindA) - cindA′ = TupleTools.getindices(cindA, qA) - cindB′ = TupleTools.getindices(cindB, qA) + # find optimal contraction scheme by checking the following options: + # - sorting the contracted inds of A or B to avoid permutations + # - contracting B with A instead to avoid permutations - qB = TupleTools.sortperm(cindB) - cindA′′ = TupleTools.getindices(cindA, qB) - cindB′′ = TupleTools.getindices(cindB, qB) + qA = TupleTools.sortperm(pA[2]) + pA′ = Base.setindex(pA, TupleTools.getindices(pA[2], qA), 2) + pB′ = Base.setindex(pB, TupleTools.getindices(pB[1], qA), 1) - dA, dB, dC = dim(A), dim(B), dim(C) + qB = TupleTools.sortperm(pB[1]) + pA″ = Base.setindex(pA, TupleTools.getindices(pA[2], qB), 2) + pB″ = Base.setindex(pB, TupleTools.getindices(pB[1], qB), 1) # keep order A en B, check possibilities for cind - memcost1 = memcost2 = dC * (!hsp(C, (oindAinC, oindBinC))) - memcost1 += dA * (!hsp(A, (oindA, cindA′))) + dB * (!hsp(B, (cindB′, oindB))) - memcost2 += dA * (!hsp(A, (oindA, cindA′′))) + dB * (!hsp(B, (cindB′′, oindB))) + memcost1 = TO.contract_memcost(C, A, pA′, B, pB′, pAB) + memcost2 = TO.contract_memcost(C, A, pA″, B, pB″, pAB) # reverse order A en B, check possibilities for cind - memcost3 = memcost4 = dC * (!hsp(C, (oindBinC, oindAinC))) - memcost3 += dB * (!hsp(B, (oindB, cindB′))) + dA * (!hsp(A, (cindA′, oindA))) - memcost4 += dB * (!hsp(B, (oindB, cindB′′))) + dA * (!hsp(A, (cindA′′, oindA))) + pAB′ = ( + map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[1]), + map(n -> ifelse(n > N₁, n - N₁, n + N₂), pAB[2]), + ) + memcost3 = TO.contract_memcost(C, B, reverse(pB′), A, reverse(pA′), pAB′) + memcost4 = TO.contract_memcost(C, B, reverse(pB″), A, reverse(pA″), pAB′) return if min(memcost1, memcost2) <= min(memcost3, memcost4) if memcost1 <= memcost2 - return _contract!(α, A, B, β, C, oindA, cindA′, oindB, cindB′, p₁, p₂) + return blas_contract!(C, A, pA′, B, pB′, pAB, α, β, backend, allocator) else - return _contract!(α, A, B, β, C, oindA, cindA′′, oindB, cindB′′, p₁, p₂) + return blas_contract!(C, A, pA″, B, pB″, pAB, α, β, backend, allocator) end else - p1′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₁) - p2′ = map(n -> ifelse(n > N₁, n - N₁, n + N₂), p₂) if memcost3 <= memcost4 - return _contract!(α, B, A, β, C, oindB, cindB′, oindA, cindA′, p1′, p2′) + return blas_contract!(C, B, reverse(pB′), A, reverse(pA′), pAB′, α, β, backend, allocator) else - return _contract!(α, B, A, β, C, oindB, cindB′′, oindA, cindA′′, p1′, p2′) + return blas_contract!(C, B, reverse(pB″), A, reverse(pA″), pAB′, α, β, backend, allocator) end end end -# TODO: also transform _contract! into new interface, and add backend support -function _contract!( - α, A::AbstractTensorMap, B::AbstractTensorMap, - β, C::AbstractTensorMap, - oindA::IndexTuple, cindA::IndexTuple, - oindB::IndexTuple, cindB::IndexTuple, - p₁::IndexTuple, p₂::IndexTuple +function TO.contract_memcost( + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple + ) + ipAB = TO.oindABinC(pAB, pA, pB) + return dim(A) * (!TO.isblascontractable(A, pA) || eltype(A) !== eltype(C)) + + dim(B) * (!TO.isblascontractable(B, pB) || eltype(B) !== eltype(C)) + + dim(C) * !TO.isblasdestination(C, ipAB) +end + +function TO.isblascontractable(A::AbstractTensorMap, pA::Index2Tuple) + return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, pA) +end +function TO.isblasdestination(A::AbstractTensorMap, ipAB::Index2Tuple) + return eltype(A) <: LinearAlgebra.BlasFloat && has_shared_permute(A, ipAB) +end + +function blas_contract!( + C::AbstractTensorMap, + A::AbstractTensorMap, pA::Index2Tuple, + B::AbstractTensorMap, pB::Index2Tuple, + pAB::Index2Tuple, α, β, + backend, allocator ) - if !(BraidingStyle(sectortype(C)) isa SymmetricBraiding) + I = sectortype(C) + BraidingStyle(I) isa SymmetricBraiding || throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) - end - N₁, N₂ = length(oindA), length(oindB) - copyA = false - if BraidingStyle(sectortype(A)) isa Fermionic - for i in cindA - if !isdual(space(A, i)) - copyA = true - end + TC = eltype(C) + + # Bring A in the correct form for BLAS contraction + flagA = TO.isblascontractable(A, pA) && eltype(A) === TC && + !(BraidingStyle(I) isa Fermionic && any(i -> isdual(space(A, i)), pA[2])) + if !flagA + Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator) + Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator) + for i in domainind(Anew) + isdual(space(Anew, i)) || twist!(Anew, i) end + else + Anew = permute(A, pA) end - A′ = permute(A, (oindA, cindA); copy = copyA) - B′ = permute(B, (cindB, oindB)) - if BraidingStyle(sectortype(A)) isa Fermionic - for i in domainind(A′) - if !isdual(space(A′, i)) - A′ = twist!(A′, i) - end - end - # A′ = twist!(A′, filter(i -> !isdual(space(A′, i)), domainind(A′))) - # commented version leads to boxing of `A′` and type instabilities in the result + pAnew = (codomainind(Anew), domainind(Anew)) + + # Bring B in the correct form for BLAS contraction + flagB = TO.isblascontractable(B, pB) && eltype(B) === TC + if !flagB + Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator) + Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator) + else + Bnew = permute(B, pB) end - ipAB = TupleTools.invperm((p₁..., p₂...)) - oindAinC = TupleTools.getindices(ipAB, ntuple(n -> n, N₁)) - oindBinC = TupleTools.getindices(ipAB, ntuple(n -> n + N₁, N₂)) - if has_shared_permute(C, (oindAinC, oindBinC)) - C′ = permute(C, (oindAinC, oindBinC)) - mul!(C′, A′, B′, α, β) + pBnew = (codomainind(Bnew), domainind(Bnew)) + + # Bring C in the correct form for BLAS contraction + ipAB = TO.oindABinC(pAB, pAnew, pBnew) + flagC = TO.isblasdestination(C, ipAB) + + if flagC + Cnew = permute(C, ipAB) + mul!(Cnew, Anew, Bnew, α, β) else - C′ = A′ * B′ - add_permute!(C, C′, (p₁, p₂), α, β) + Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) + mul!(Cnew, Anew, Bnew) + TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator) + TO.tensorfree!(Cnew, allocator) end + + flagA || TO.tensorfree!(Anew, allocator) + flagB || TO.tensorfree!(Bnew, allocator) + return C end From fad06255ff729e6d3461ca1ddc1343d73d3f0485 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 5 Nov 2025 13:46:09 -0500 Subject: [PATCH 2/3] rework to twist smallest object --- src/tensors/tensoroperations.jl | 50 ++++++++++++++++++++++----------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 516f8d27e..7ee93018e 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -327,30 +327,46 @@ function blas_contract!( pAB::Index2Tuple, α, β, backend, allocator ) - I = sectortype(C) - BraidingStyle(I) isa SymmetricBraiding || + bstyle = BraidingStyle(sectortype(C)) + bstyle isa SymmetricBraiding || throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead")) TC = eltype(C) + # check which tensors have to be permuted/copied + copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC) + copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC) + + if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1]) + # twist smallest object if neither or both already have to be permuted + # otherwise twist the one that already is copied + if copyA ⊻ copyB + twistA = dim(A) < dim(B) + else + twistA = copyA + end + twistB = !twistA + copyA |= twistA + copyB |= twistB + else + twistA = false + twistB = false + end + # Bring A in the correct form for BLAS contraction - flagA = TO.isblascontractable(A, pA) && eltype(A) === TC && - !(BraidingStyle(I) isa Fermionic && any(i -> isdual(space(A, i)), pA[2])) - if !flagA + if copyA Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator) Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator) - for i in domainind(Anew) - isdual(space(Anew, i)) || twist!(Anew, i) - end + twistA && twist!(Anew, filter(!isdual ∘ Base.Fix1(space, Anew), domainind(Anew))) else Anew = permute(A, pA) end pAnew = (codomainind(Anew), domainind(Anew)) # Bring B in the correct form for BLAS contraction - flagB = TO.isblascontractable(B, pB) && eltype(B) === TC - if !flagB + if copyB Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator) Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator) + twistB && twist!(Bnew, filter(isdual ∘ Base.Fix1(space, Bnew), codomainind(Bnew))) else Bnew = permute(B, pB) end @@ -358,20 +374,20 @@ function blas_contract!( # Bring C in the correct form for BLAS contraction ipAB = TO.oindABinC(pAB, pAnew, pBnew) - flagC = TO.isblasdestination(C, ipAB) + copyC = !TO.isblasdestination(C, ipAB) - if flagC - Cnew = permute(C, ipAB) - mul!(Cnew, Anew, Bnew, α, β) - else + if copyC Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator) mul!(Cnew, Anew, Bnew) TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator) TO.tensorfree!(Cnew, allocator) + else + Cnew = permute(C, ipAB) + mul!(Cnew, Anew, Bnew, α, β) end - flagA || TO.tensorfree!(Anew, allocator) - flagB || TO.tensorfree!(Bnew, allocator) + copyA && TO.tensorfree!(Anew, allocator) + copyB && TO.tensorfree!(Bnew, allocator) return C end From bb4b7efbc3b4dc321bc2fdbb07ee1235dfe462ff Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 15 Nov 2025 07:32:01 -0500 Subject: [PATCH 3/3] fix logic mistake --- src/tensors/tensoroperations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensors/tensoroperations.jl b/src/tensors/tensoroperations.jl index 7ee93018e..7bd0ab96b 100644 --- a/src/tensors/tensoroperations.jl +++ b/src/tensors/tensoroperations.jl @@ -339,7 +339,7 @@ function blas_contract!( if bstyle isa Fermionic && any(isdual ∘ Base.Fix1(space, B), pB[1]) # twist smallest object if neither or both already have to be permuted # otherwise twist the one that already is copied - if copyA ⊻ copyB + if !(copyA ⊻ copyB) twistA = dim(A) < dim(B) else twistA = copyA