From 485292082a3fe1b7752a8f798dcf9d570e7c9b72 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Oct 2025 10:38:05 -0400 Subject: [PATCH 1/3] avoid twist copies in rrules --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index 3809e7fbf..d93610e40 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -28,7 +28,7 @@ function ChainRulesCore.rrule( # for non-symmetric tensors this might be more efficient like this, # but for symmetric tensors an intermediate object will anyways be created # and then it might be more efficient to use an addition and inner product - tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) _dα = tensorscalar( tensorcontract( A, ((), linearize(pA)), !conjA, @@ -74,7 +74,7 @@ function ChainRulesCore.rrule( conjB′ = conjA ? conjB : !conjB TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α)) # TODO: allocator - tB = twist( + tB = _twist_nocopy( B, TupleTools.vcat( filter(x -> !isdual(space(B, x)), pB[1]), @@ -99,7 +99,7 @@ function ChainRulesCore.rrule( conjA′ = conjB ? conjA : !conjA TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α)) # TODO: allocator - tA = twist( + tA = _twist_nocopy( A, TupleTools.vcat( filter(x -> isdual(space(A, x)), pA[1]), @@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap) end return val, scalar_pullback end + +# temporary function to avoid copies when not needed +# TODO: remove once `twist(t; copy=false)` is defined +function _twist_nocopy(t, inds) + (BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t + return twist(t, inds) +end From 591b3fb91b9a9a61a19ff0393500f6eb081808da Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Oct 2025 16:19:38 -0400 Subject: [PATCH 2/3] Add nocopy twists in tensorproduct --- ext/TensorKitChainRulesCoreExt/linalg.jl | 4 ++-- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl index af0c7f386..9c12af7a6 100644 --- a/ext/TensorKitChainRulesCoreExt/linalg.jl +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -42,7 +42,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipA = (codomainind(A), domainind(A)) pB = (allind(B), ()) dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) + tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B))) dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA) return projectA(dA) end @@ -50,7 +50,7 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe ipB = (codomainind(B), domainind(B)) pA = ((), allind(A)) dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) + tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A))) dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB) return projectB(dB) end diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index d93610e40..b15344d71 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -191,7 +191,7 @@ end # temporary function to avoid copies when not needed # TODO: remove once `twist(t; copy=false)` is defined -function _twist_nocopy(t, inds) +function _twist_nocopy(t, inds; kwargs...) (BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t - return twist(t, inds) + return twist(t, inds; kwargs...) end From 871fafb2f6082ff2f91465bb694dfbdb0a1562dc Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 27 Oct 2025 17:20:35 -0400 Subject: [PATCH 3/3] Update ext/TensorKitChainRulesCoreExt/tensoroperations.jl Co-authored-by: Jutho --- ext/TensorKitChainRulesCoreExt/tensoroperations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl index b15344d71..73b60b5b7 100644 --- a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -192,6 +192,6 @@ end # temporary function to avoid copies when not needed # TODO: remove once `twist(t; copy=false)` is defined function _twist_nocopy(t, inds; kwargs...) - (BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t + (BraidingStyle(sectortype(t)) isa Bosonic || isempty(inds)) && return t return twist(t, inds; kwargs...) end