diff --git a/ext/TensorOperationsChainRulesCoreExt.jl b/ext/TensorOperationsChainRulesCoreExt.jl index a838917..7dfc564 100644 --- a/ext/TensorOperationsChainRulesCoreExt.jl +++ b/ext/TensorOperationsChainRulesCoreExt.jl @@ -40,15 +40,28 @@ function ChainRulesCore.rrule( return output, tensoralloc_pullback end +# this function more or less boils down to `fill!(similar(x), y)` but does so in a single +# call to allow higher-order derivatives +function similar_and_fill(x, y) + x′ = TensorOperations.tensoralloc(typeof(x), TensorOperations.tensorstructure(x)) + return fill!(x′, y) +end +function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y) + similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx)) + return similar_and_fill(x, y), similar_and_fill_pullback +end function ChainRulesCore.rrule(::typeof(tensorscalar), C) - projectC = ProjectTo(C) - function tensorscalar_pullback(Δc) - _Δc = unthunk(Δc) - return NoTangent(), projectC(_Δc) - end + tensorscalar_pullback(Δc) = NoTangent(), similar_and_fill(C, unthunk(Δc)) return tensorscalar(C), tensorscalar_pullback end +# To avoid computing rrules for α and β when these aren't needed, we want to have a +# type-stable quick bail-out +_needs_tangent(x) = _needs_tangent(typeof(x)) +_needs_tangent(::Type{<:Number}) = true +_needs_tangent(::Type{<:Integer}) = false +_needs_tangent(::Type{<:Union{One, Zero}}) = false + # The current `rrule` design makes sure that the implementation for custom types does # not need to support the backend or allocator arguments # function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!), @@ -99,26 +112,34 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba) _dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...) projectA(_dA) end - dα = @thunk let - _dα = tensorscalar( - tensorcontract( - A, ((), linearize(pA)), !conjA, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... + dα = if _needs_tangent(α) + @thunk let + _dα = tensorscalar( + tensorcontract( + A, ((), linearize(pA)), !conjA, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectα(_dα) + projectα(_dα) + end + else + ZeroTangent() end - dβ = @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pA))), true, - ΔC, (trivtuple(numind(pA)), ()), false, - ((), ()), One(), ba... + dβ = if _needs_tangent(β) + @thunk let + # TODO: consider using `inner` + _dβ = tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pA))), true, + ΔC, (trivtuple(numind(pA)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectβ(_dβ) + projectβ(_dβ) + end + else + ZeroTangent() end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba... @@ -212,28 +233,36 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba) ) projectB(_dB) end - dα = @thunk let - C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) - # TODO: consider using `inner` - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... + dα = if _needs_tangent(α) + @thunk let + C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...) + # TODO: consider using `inner` + _dα = tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectα(_dα) + projectα(_dα) + end + else + ZeroTangent() end - dβ = @thunk let - # TODO: consider using `inner` - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(pAB))), true, - ΔC, (trivtuple(numind(pAB)), ()), false, - ((), ()), One(), ba... + dβ = if _needs_tangent(β) + @thunk let + # TODO: consider using `inner` + _dβ = tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(pAB))), true, + ΔC, (trivtuple(numind(pAB)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectβ(_dβ) + projectβ(_dβ) + end + else + ZeroTangent() end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, @@ -301,27 +330,35 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba) ) projectA(_dA) end - dα = @thunk let - C_αβ = tensortrace(A, p, q, false, One(), ba...) - _dα = tensorscalar( - tensorcontract( - C_αβ, ((), trivtuple(numind(p))), - !conjA, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... + dα = if _needs_tangent(α) + @thunk let + C_αβ = tensortrace(A, p, q, false, One(), ba...) + _dα = tensorscalar( + tensorcontract( + C_αβ, ((), trivtuple(numind(p))), + !conjA, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectα(_dα) + projectα(_dα) + end + else + ZeroTangent() end - dβ = @thunk let - _dβ = tensorscalar( - tensorcontract( - C, ((), trivtuple(numind(p))), true, - ΔC, (trivtuple(numind(p)), ()), false, - ((), ()), One(), ba... + dβ = if _needs_tangent(β) + @thunk let + _dβ = tensorscalar( + tensorcontract( + C, ((), trivtuple(numind(p))), true, + ΔC, (trivtuple(numind(p)), ()), false, + ((), ()), One(), ba... + ) ) - ) - projectβ(_dβ) + projectβ(_dβ) + end + else + ZeroTangent() end dba = map(_ -> NoTangent(), ba) return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...