From a3632057489a986d7c40a78b7adfb10b91754d85 Mon Sep 17 00:00:00 2001
From: Lukas Devos
Date: Tue, 30 Dec 2025 22:35:48 +0100
Subject: [PATCH 1/3] insert opt-out for non-derivable alpha and beta

---
 ext/TensorOperationsChainRulesCoreExt.jl | 139 ++++++++++++++---
 1 file changed, 85 insertions(+), 54 deletions(-)

diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl
index a838917..90feccb 100644
--- a/ext/TensorOperationsChainRulesCoreExt.jl
+++ b/ext/TensorOperationsChainRulesCoreExt.jl
@@ -49,6 +49,13 @@ function ChainRulesCore.rrule(::typeof(tensorscalar), C)
     return tensorscalar(C), tensorscalar_pullback
 end
 
+# To avoid computing rrules for α and β when these aren't needed, we want to have a
+# type-stable quick bail-out
+_needs_tangent(x) = _needs_tangent(typeof(x))
+_needs_tangent(::Type{<:Number}) = true
+_needs_tangent(::Type{<:Integer}) = false
+_needs_tangent(::Type{<:Union{One, Zero}}) = false
+
 # The current `rrule` design makes sure that the implementation for custom types does
 # not need to support the backend or allocator arguments
 # function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
@@ -99,26 +106,34 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
         _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
         projectA(_dA)
     end
-    dα = @thunk let
-        _dα = tensorscalar(
-            tensorcontract(
-                A, ((), linearize(pA)), !conjA,
-                ΔC, (trivtuple(numind(pA)), ()), false,
-                ((), ()), One(), ba...
+    dα = if _needs_tangent(α)
+        @thunk let
+            _dα = tensorscalar(
+                tensorcontract(
+                    A, ((), linearize(pA)), !conjA,
+                    ΔC, (trivtuple(numind(pA)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectα(_dα)
+            projectα(_dα)
+        end
+    else
+        ZeroTangent()
     end
-    dβ = @thunk let
-        # TODO: consider using `inner`
-        _dβ = tensorscalar(
-            tensorcontract(
-                C, ((), trivtuple(numind(pA))), true,
-                ΔC, (trivtuple(numind(pA)), ()), false,
-                ((), ()), One(), ba...
+    dβ = if _needs_tangent(β)
+        @thunk let
+            # TODO: consider using `inner`
+            _dβ = tensorscalar(
+                tensorcontract(
+                    C, ((), trivtuple(numind(pA))), true,
+                    ΔC, (trivtuple(numind(pA)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectβ(_dβ)
+            projectβ(_dβ)
+        end
+    else
+        ZeroTangent()
     end
     dba = map(_ -> NoTangent(), ba)
     return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
@@ -212,28 +227,36 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
         )
         projectB(_dB)
     end
-    dα = @thunk let
-        C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
-        # TODO: consider using `inner`
-        _dα = tensorscalar(
-            tensorcontract(
-                C_αβ, ((), trivtuple(numind(pAB))), true,
-                ΔC, (trivtuple(numind(pAB)), ()), false,
-                ((), ()), One(), ba...
+    dα = if _needs_tangent(α)
+        @thunk let
+            C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
+            # TODO: consider using `inner`
+            _dα = tensorscalar(
+                tensorcontract(
+                    C_αβ, ((), trivtuple(numind(pAB))), true,
+                    ΔC, (trivtuple(numind(pAB)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectα(_dα)
+            projectα(_dα)
+        end
+    else
+        ZeroTangent()
     end
-    dβ = @thunk let
-        # TODO: consider using `inner`
-        _dβ = tensorscalar(
-            tensorcontract(
-                C, ((), trivtuple(numind(pAB))), true,
-                ΔC, (trivtuple(numind(pAB)), ()), false,
-                ((), ()), One(), ba...
+    dβ = if _needs_tangent(β)
+        @thunk let
+            # TODO: consider using `inner`
+            _dβ = tensorscalar(
+                tensorcontract(
+                    C, ((), trivtuple(numind(pAB))), true,
+                    ΔC, (trivtuple(numind(pAB)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectβ(_dβ)
+            projectβ(_dβ)
+        end
+    else
+        ZeroTangent()
     end
     dba = map(_ -> NoTangent(), ba)
     return NoTangent(), dC,
@@ -301,27 +324,35 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
         )
         projectA(_dA)
     end
-    dα = @thunk let
-        C_αβ = tensortrace(A, p, q, false, One(), ba...)
-        _dα = tensorscalar(
-            tensorcontract(
-                C_αβ, ((), trivtuple(numind(p))),
-                !conjA,
-                ΔC, (trivtuple(numind(p)), ()), false,
-                ((), ()), One(), ba...
+    dα = if _needs_tangent(α)
+        @thunk let
+            C_αβ = tensortrace(A, p, q, false, One(), ba...)
+            _dα = tensorscalar(
+                tensorcontract(
+                    C_αβ, ((), trivtuple(numind(p))),
+                    !conjA,
+                    ΔC, (trivtuple(numind(p)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectα(_dα)
+            projectα(_dα)
+        end
+    else
+        ZeroTangent()
     end
-    dβ = @thunk let
-        _dβ = tensorscalar(
-            tensorcontract(
-                C, ((), trivtuple(numind(p))), true,
-                ΔC, (trivtuple(numind(p)), ()), false,
-                ((), ()), One(), ba...
+    dβ = if _needs_tangent(β)
+        @thunk let
+            _dβ = tensorscalar(
+                tensorcontract(
+                    C, ((), trivtuple(numind(p))), true,
+                    ΔC, (trivtuple(numind(p)), ()), false,
+                    ((), ()), One(), ba...
+                )
             )
-        )
-        projectβ(_dβ)
+            projectβ(_dβ)
+        end
+    else
+        ZeroTangent()
     end
     dba = map(_ -> NoTangent(), ba)
     return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...

From c29673b3d8bbe4ce3cdf346fc5261b4a18883f44 Mon Sep 17 00:00:00 2001
From: lkdvos
Date: Thu, 1 Jan 2026 08:37:02 -0500
Subject: [PATCH 2/3] revert change to allocate device 0-dim arrays

---
 ext/TensorOperationsChainRulesCoreExt.jl | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl
index 90feccb..92d7cb2 100644
--- a/ext/TensorOperationsChainRulesCoreExt.jl
+++ b/ext/TensorOperationsChainRulesCoreExt.jl
@@ -41,10 +41,9 @@ function ChainRulesCore.rrule(
 end
 
 function ChainRulesCore.rrule(::typeof(tensorscalar), C)
-    projectC = ProjectTo(C)
     function tensorscalar_pullback(Δc)
-        _Δc = unthunk(Δc)
-        return NoTangent(), projectC(_Δc)
+        ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C))
+        return NoTangent(), fill!(ΔC, unthunk(Δc))
     end
     return tensorscalar(C), tensorscalar_pullback
 end

From 758b0b90f26ca835144f290427fca1387bcdbe8d Mon Sep 17 00:00:00 2001
From: lkdvos
Date: Thu, 1 Jan 2026 08:44:59 -0500
Subject: [PATCH 3/3] restore higher-order differentiability

---
 ext/TensorOperationsChainRulesCoreExt.jl | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl
index 92d7cb2..7dfc564 100644
--- a/ext/TensorOperationsChainRulesCoreExt.jl
+++ b/ext/TensorOperationsChainRulesCoreExt.jl
@@ -40,11 +40,18 @@ function ChainRulesCore.rrule(
     return output, tensoralloc_pullback
 end
 
+# this function more or less boils down to `fill!(similar(x), y)` but does so in a single
+# call to allow higher-order derivatives
+function similar_and_fill(x, y)
+    x′ = TensorOperations.tensoralloc(typeof(x), TensorOperations.tensorstructure(x))
+    return fill!(x′, y)
+end
+function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
+    similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx))
+    return similar_and_fill(x, y), similar_and_fill_pullback
+end
 function ChainRulesCore.rrule(::typeof(tensorscalar), C)
-    function tensorscalar_pullback(Δc)
-        ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C))
-        return NoTangent(), fill!(ΔC, unthunk(Δc))
-    end
+    tensorscalar_pullback(Δc) = NoTangent(), similar_and_fill(C, unthunk(Δc))
     return tensorscalar(C), tensorscalar_pullback
 end
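
The opt-out in PATCH 1/3 works purely by dispatch on the scalar type: the default α = One() / β = Zero(), as well as plain integer scalars, make dα/dβ collapse to ZeroTangent() before any thunk is built, while generic numeric scalars still get the full tangent computation. A self-contained sketch of that dispatch behaviour; `One`/`Zero` are replaced here by local stand-in structs purely to keep the sketch runnable, not the VectorInterface scalar types TensorOperations actually uses:

    # Stand-ins for the One()/Zero() scalar types used in the extension; defined
    # locally only so this sketch runs on its own.
    struct One end
    struct Zero end

    # Mirror of the type-stable bail-out added in PATCH 1/3.
    _needs_tangent(x) = _needs_tangent(typeof(x))
    _needs_tangent(::Type{<:Number}) = true
    _needs_tangent(::Type{<:Integer}) = false   # more specific than Number, so it wins
    _needs_tangent(::Type{<:Union{One, Zero}}) = false

    @assert _needs_tangent(1.5 + 2.0im)   # generic scalars still get thunked dα/dβ
    @assert !_needs_tangent(3)            # integer α/β are treated as non-derivable
    @assert !_needs_tangent(One())        # default α = One(), β = Zero() bail out early
    @assert !_needs_tangent(Zero())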
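
PATCH 2/3 and 3/3 together change how the tensorscalar pullback materialises its cotangent: instead of projecting the scalar through ProjectTo(C), it allocates a 0-dimensional container of the same type as C via tensoralloc (per the PATCH 2/3 subject, so device-backed 0-dim arrays stay on the device) and fills it with the incoming scalar, and PATCH 3/3 routes that allocate-and-fill pair through similar_and_fill with its own rrule so the pullback itself stays differentiable one order higher. A rough plain-Array sketch of that pattern, assuming only ChainRulesCore; the _demo names are made up for illustration, and sum stands in for tensorscalar on the 0-dimensional cotangent:

    using ChainRulesCore

    # Plain-Array analogue of similar_and_fill: allocate a container with the same
    # structure as x and fill it with the scalar y, as a single differentiable call.
    similar_and_fill_demo(x, y) = fill!(similar(x), y)

    # Custom rrule so the pullback is itself differentiable: x only supplied the
    # structure (ZeroTangent()), while y receives the scalar content of ΔX.
    function ChainRulesCore.rrule(::typeof(similar_and_fill_demo), x, y)
        pullback(ΔX) = NoTangent(), ZeroTangent(), sum(unthunk(ΔX))
        return similar_and_fill_demo(x, y), pullback
    end

    C = fill(1.0)      # 0-dimensional array, playing the role of C in tensorscalar's rrule
    Y, back = ChainRulesCore.rrule(similar_and_fill_demo, C, 2.5)
    Y[]                # 2.5
    back(fill(3.0))    # (NoTangent(), ZeroTangent(), 3.0)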