From 8b3176b1e9c945e3800bdfb0e259130b1eceb99b Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Mon, 5 Jan 2026 14:52:55 +0100 Subject: [PATCH 1/2] implement ZeroTangent for Zero beta --- ext/TensorOperationsChainRulesCoreExt.jl | 88 ++++-------------------- 1 file changed, 15 insertions(+), 73 deletions(-) diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index 7dfc564..42b6af7 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -64,28 +64,6 @@ _needs_tangent(::Type{<:Union{One, Zero}}) = false # The current `rrule` design makes sure that the implementation for custom types does # not need to support the backend or allocator arguments -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# α::Number, β::Number, -# backend, allocator) -# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator)) -# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# α::Number, β::Number, -# backend) -# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,)) -# return val, ΔC -> (pb(ΔC)..., NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# α::Number, β::Number) -# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ()) -# end function ChainRulesCore.rrule( ::typeof(TensorOperations.tensoradd!), C, @@ -105,7 +83,11 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) function pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = if β === Zero() + ZeroTangent() + else + @thunk projectC(scale(ΔC, conj(β))) + end dA = @thunk let ipA = invperm(linearize(pA)) _dA = zerovector(A, VectorInterface.promote_add(ΔC, α)) @@ -148,35 +130,6 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) return C′, pullback end -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# B, pB::Index2Tuple, conjB::Bool, -# pAB::Index2Tuple, -# α::Number, β::Number, -# backend, allocator) -# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, -# (backend, allocator)) -# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# B, pB::Index2Tuple, conjB::Bool, -# pAB::Index2Tuple, -# α::Number, β::Number, -# backend) -# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,)) -# return val, ΔC -> (pb(ΔC)..., NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!), -# C, -# A, pA::Index2Tuple, conjA::Bool, -# B, pB::Index2Tuple, conjB::Bool, -# pAB::Index2Tuple, -# α::Number, β::Number) -# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ()) -# end function ChainRulesCore.rrule( ::typeof(TensorOperations.tensorcontract!), C, @@ -204,7 +157,11 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) TupleTools.getindices(ipAB, trivtuple(numout(pA))), TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))), ) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = if β === Zero() + ZeroTangent() + else + @thunk projectC(scale(ΔC, conj(β))) + end dA = @thunk let ipA = (invperm(linearize(pA)), ()) conjΔC = conjA @@ -273,25 +230,6 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) return C′, pullback end -# function ChainRulesCore.rrule(::typeof(tensortrace!), C, -# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, -# α::Number, β::Number, -# backend, allocator) -# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator)) -# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(tensortrace!), C, -# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, -# α::Number, β::Number, -# backend) -# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,)) -# return val, ΔC -> (pb(ΔC)..., NoTangent()) -# end -# function ChainRulesCore.rrule(::typeof(tensortrace!), C, -# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, -# α::Number, β::Number) -# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ()) -# end function ChainRulesCore.rrule( ::typeof(tensortrace!), C, A, p::Index2Tuple, q::Index2Tuple, conjA::Bool, @@ -310,7 +248,11 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) function pullback(ΔC′) ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) + dC = if β === Zero() + ZeroTangent() + else + @thunk projectC(scale(ΔC, conj(β))) + end dA = @thunk let ip = invperm((linearize(p)..., q[1]..., q[2]...)) Es = map(q[1], q[2]) do i1, i2 From cab0548a8565c7dd40d087772a1650b4c299187c Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Mon, 5 Jan 2026 14:53:11 +0100 Subject: [PATCH 2/2] fix cuTensor compat and bump minor version --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index d5f8928..a8a9ddc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "TensorOperations" uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" -version = "5.3.2" +version = "5.4" authors = ["Lukas Devos ", "Maarten Van Damme ", "Jutho Haegeman "] [deps] @@ -49,7 +49,7 @@ StridedViews = "0.3, 0.4" Test = "1" TupleTools = "1.6" VectorInterface = "0.4.1,0.5" -cuTENSOR = ">=2.1.1" +cuTENSOR = "2.1.1" julia = "1.8" [extras]