diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index af0c7f386..9c12af7a6 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -42,7 +42,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipA = (codomainind(A), domainind(A)) pB = (allind(B), ()) dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) + tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B))) dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA) return projectA(dA) end @@ -50,7 +50,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipB = (codomainind(B), domainind(B)) pA = ((), allind(A)) dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) + tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A))) dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB) return projectB(dB) end diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 3809e7fbf..73b60b5b7 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -28,7 +28,7 @@ function ChainRulesCore.rrule( # for non-symmetric tensors this might be more efficient like this, # but for symmetric tensors an intermediate object will anyways be created # and then it might be more efficient to use an addition and inner product - tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) _dα = tensorscalar( tensorcontract( A, ((), linearize(pA)), !conjA, @@ -74,7 +74,7 @@ function ChainRulesCore.rrule( conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) # TODO: allocator - tB = twist( + tB = _twist_nocopy( B, TupleTools.vcat( filter(x -> !isdual(space(B, x)), pB[1]), @@ -99,7 +99,7 @@ function ChainRulesCore.rrule( conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) # TODO: allocator - tA = twist( + tA = _twist_nocopy( A, TupleTools.vcat( filter(x -> isdual(space(A, x)), pA[1]), @@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap) end return val, scalar_pullback end + +# temporary function to avoid copies when not needed +# TODO: remove once `twist(t; copy=false)` is defined +function _twist_nocopy(t, inds; kwargs...) + (BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t + return twist(t, inds; kwargs...) +end