From b8b023b7636cf8e16273d257e0c9c430013cf8d3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 09:50:09 +0100 Subject: [PATCH 1/9] fix: make wrong-mode pushforward/pullback return the correct array type --- DifferentiationInterface/Project.toml | 5 +++-- .../DifferentiationInterfaceGPUArraysCoreExt.jl | 7 +++++++ .../src/first_order/pullback.jl | 16 ++++++++-------- .../src/first_order/pushforward.jl | 12 ++++++------ DifferentiationInterface/src/utils/linalg.jl | 12 ++++++++++-- 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index af16a6ae2..38fd6239f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,13 +1,14 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -authors = ["Guillaume Dalle", "Adrian Hill"] version = "0.7.16" +authors = ["Guillaume Dalle", "Adrian Hill"] [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c" @@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation" DifferentiationInterfaceFiniteDiffExt = "FiniteDiff" DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences" DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] -DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" +DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"] DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" DifferentiationInterfacePolyesterForwardDiffExt = [ diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl index 60d1ef6c0..ea9dfdaf1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl @@ -1,5 +1,6 @@ module DifferentiationInterfaceGPUArraysCoreExt +using Adapt: adapt import DifferentiationInterface as DI using GPUArraysCore: @allowscalar, AbstractGPUArray @@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T} return b end +function DI.arroftup_to_tupofarr( + tx::AbstractArray{NTuple{B, T}}, x::AbstractGPUArray{T} + ) where {B, T} + return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B)) +end + end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 333b12967..57c2c8513 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy dot(a, dy) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy real(dot(a, dy)) + im * real(dot(b, dy)) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward( dot(a, dy) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward( real(dot(a, dy)) + im * real(dot(b, dy)) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function value_and_pullback( @@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy dot(a, dy) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward( tx = map(ty) do dy real(dot(a, dy)) + im * real(dot(b, dy)) end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward( dot(a, dy) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function _value_and_pullback_via_pushforward( @@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward( real(dot(a, dy)) + im * real(dot(b, dy)) end end - return y, arroftup_to_tupofarr(tx) + return y, arroftup_to_tupofarr(tx, x) end function value_and_pullback( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index e49203615..0aadac34e 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback( ty = map(tx) do dx dot(a, dx) end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback( ty = map(tx) do dx real(dot(a, dx)) + im * real(dot(b, dx)) end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback( dot(a, dx) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback( real(dot(a, dx)) + im * real(dot(b, dx)) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function value_and_pushforward( @@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback( dot(a, dx) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function _value_and_pushforward_via_pullback( @@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback( real(dot(a, dx)) + im * real(dot(b, dx)) end end - return y, arroftup_to_tupofarr(ty) + return y, arroftup_to_tupofarr(ty, y) end function value_and_pushforward( diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index a5fb15cab..dbda0bda5 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -36,5 +36,13 @@ get_pattern(M::AbstractMatrix) = trues(size(M)) onlysecond((a, b)) = (a, only(b)) -arroftup_to_tupofarr(x::NTuple) = x -arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B)) +""" + arroftup_to_tupofarr(tx, x) + +Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`. +""" +arroftup_to_tupofarr(tx::NTuple{B, T}, x::T) where {B, T} = tx + +function arroftup_to_tupofarr(tx::AbstractArray{NTuple{B, T}}, x::AbstractArray{T}) where {B, T} + return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B)) +end From ccadb9e55c34b588674721ecdc1b6900de00508c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 12:04:09 +0100 Subject: [PATCH 2/9] Relax typing, add Adapt bound --- DifferentiationInterface/Project.toml | 1 + DifferentiationInterface/src/utils/linalg.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 38fd6239f..d4058e7ac 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -57,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker" DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] +Adapt = "4.5.0" ADTypes = "1.18.0" ChainRulesCore = "1.23.0" DiffResults = "1.1.0" diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index dbda0bda5..e11b7e266 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -41,8 +41,8 @@ onlysecond((a, b)) = (a, only(b)) Convert an array of tuples `tx` into a tuple of arrays, while respecting the array type of the primal `x`. """ -arroftup_to_tupofarr(tx::NTuple{B, T}, x::T) where {B, T} = tx +arroftup_to_tupofarr(tx::NTuple{B, <:Number}, x::Number) where {B} = tx -function arroftup_to_tupofarr(tx::AbstractArray{NTuple{B, T}}, x::AbstractArray{T}) where {B, T} +function arroftup_to_tupofarr(tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractArray{<:Number}) where {B} return ntuple(b -> similar(x) .= getindex.(tx, b), Val(B)) end From a8a628d20130fb27513e6ce8496cf00bbd2faad0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 16:04:28 +0100 Subject: [PATCH 3/9] Fix method ambiguity --- .../DifferentiationInterfaceGPUArraysCoreExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl index ea9dfdaf1..03d79c89c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl @@ -19,8 +19,8 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T} end function DI.arroftup_to_tupofarr( - tx::AbstractArray{NTuple{B, T}}, x::AbstractGPUArray{T} - ) where {B, T} + tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number} + ) where {B} return ntuple(b -> adapt(typeof(x), getindex.(tx, b)), Val(B)) end From ce5819e7b448bec99160d4f7645b72290ac7b620 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:37:51 +0100 Subject: [PATCH 4/9] Add tests --- .../test/Core/Internals/linalg.jl | 19 ++++++++++++++++++- .../test/Core/SimpleFiniteDiff/test.jl | 5 +++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Core/Internals/linalg.jl b/DifferentiationInterface/test/Core/Internals/linalg.jl index bc12b1b17..c9a372d63 100644 --- a/DifferentiationInterface/test/Core/Internals/linalg.jl +++ b/DifferentiationInterface/test/Core/Internals/linalg.jl @@ -1,6 +1,7 @@ -using DifferentiationInterface: recursive_similar, get_pattern +using DifferentiationInterface: recursive_similar, get_pattern, arroftup_to_tupofarr using SparseArrays using Test +using JLArrays, ComponentArrays @testset "Recursive similar" begin @test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32} @@ -16,3 +17,19 @@ end @test_broken get_pattern(D) == Diagonal(trues(10)) @test get_pattern(sparse(D)) == Diagonal(trues(10)) end + +@testset "Wrong-mode array conversion" begin + x = [1.0, 3.0, 5.0] + xt = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)] + y = ComponentVector(a = [1.0, 3.0], b = [5.0]) + yt = ComponentVector(a = [(1.0, 2.0), (3.0, 4.0)], b = [(5.0, 6.0)]) + z = jl([1.0, 3.0, 5.0]) + zt = jl([(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]) + @test arroftup_to_tupofarr((1.0, 2.0), 1.0) == (1.0, 2.0) + @test arroftup_to_tupofarr(xt, x) == ([1.0, 3.0, 5.0], [2.0, 4.0, 6.0]) + @test arroftup_to_tupofarr(yt, y) == (ComponentVector(a = [1.0, 3.0], b = [5.0]), ComponentVector(a = [2.0, 4.0], b = [6.0])) + @test arroftup_to_tupofarr(zt, z) == (jl([1.0, 3.0, 5.0]), jl([2.0, 4.0, 6.0])) + @test arroftup_to_tupofarr(xt, x)[1] isa Vector + @test arroftup_to_tupofarr(yt, y)[1] isa ComponentVector + @test arroftup_to_tupofarr(zt, z)[1] isa JLVector +end diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index fe739284e..1ff483e17 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -169,3 +169,8 @@ end logging = LOGGING, ) end; + +@testset "Array format preservation in wrong mode" begin + @test gradient(sum, AutoSimpleFiniteDiff(), jl(ones(2))) isa JLVector + @test derivative(t -> jl(fill(t, 2)), AutoReverseFromPrimitive(AutoSimpleFiniteDiff()), 1.0) isa JLVector +end From 5ce58ea16b7cb4558a52a06546497ed76b7984f4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:53:49 +0100 Subject: [PATCH 5/9] feat: DifferentiateWith with contexts --- ...fferentiationInterfaceChainRulesCoreExt.jl | 3 +- .../differentiate_with.jl | 28 +++++--- .../differentiate_with.jl | 4 +- .../DifferentiationInterfaceMooncakeExt.jl | 5 +- .../differentiate_with.jl | 72 ++++++++++++------- .../utils.jl | 16 +++++ .../src/misc/differentiate_with.jl | 16 +++-- DifferentiationInterface/src/utils/context.jl | 2 + .../test/Back/DifferentiateWith/test.jl | 31 ++++++-- 9 files changed, 129 insertions(+), 48 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index bacb7baa4..46ea2487f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -9,7 +9,8 @@ using ChainRulesCore: RuleConfig, frule_via_ad, rrule_via_ad, - unthunk + unthunk, + @not_implemented import DifferentiationInterface as DI ruleconfig(backend::AutoChainRules) = backend.ruleconfig diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl index 292372b81..9a14a7625 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -1,10 +1,22 @@ -function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) - (; f, backend) = dw - y = f(x) - prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,)) - function pullbackfunc(dy) - tx = DI.pullback(f, prep_same, backend, x, (dy,)) - return (NoTangent(), only(tx)) +function ChainRulesCore.rrule( + dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C} + ) where {C} + (; f, backend, context_wrappers) = dw + y = f(x, contexts...) + wrapped_contexts = map(DI.call, context_wrappers, contexts) + prep_same = DI.prepare_pullback_same_point_nokwarg( + Val(false), f, backend, x, (y,), wrapped_contexts... + ) + function diffwith_pullbackfunc(dy) + dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only + dc = map(contexts) do c + @not_implemented( + """ + Derivatives with respect to context arguments are not implemented. + """ + ) + end + return (NoTangent(), dx, dc...) end - return y, pullbackfunc + return y, diffwith_pullbackfunc end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl index 96316f5b6..4d46bc613 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/differentiate_with.jl @@ -1,4 +1,4 @@ -function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N} +function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) @@ -6,7 +6,7 @@ function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N} return make_dual(T, y, ty) end -function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N} +function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N} (; f, backend) = dw xval = myvalue(T, x) tx = mypartials(T, Val(N), x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 3513d548c..626e3660c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -18,6 +18,7 @@ using Mooncake: value_and_pullback!!, zero_dual, zero_tangent, + zero_rdata, rdata_type, fdata, rdata, @@ -26,11 +27,13 @@ using Mooncake: @is_primitive, zero_fcodual, MinimalCtx, + NoFData, NoRData, primal, _copy_output, _copy_to_output!!, - tangent_to_primal!! + tangent_to_primal!!, + increment!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index ad2d9f7c7..59b759cde 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,4 +1,9 @@ -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any} +const NumberOrArray = Union{Number, AbstractArray{<:Number}} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} +# TODO: generate more cases programmatically struct MooncakeDifferentiateWithError <: Exception F::Type @@ -12,28 +17,37 @@ end function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) return print( io, - "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", + "MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.", ) end -function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) +function Mooncake.rrule!!( + dw::CoDual{<:DI.DifferentiateWith{C}}, + x::CoDual{<:Number}, + contexts::Vararg{CoDual, C} + ) where {C} + @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) primal_x = primal(x) - (; f, backend) = primal_func - y = zero_fcodual(f(primal_x)) + primal_contexts = map(primal, contexts) + (; f, backend, context_wrappers) = primal_func + y = zero_fcodual(f(primal_x, primal_contexts...)) + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (y.dx,)) - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) - return NoRData(), rdata(only(tx)) + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only + @assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x))) + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), rdata(dx), rc...) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) - return NoRData(), rdata(only(tx)) + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), rdata(dx), rc...) end pullback = if primal(y) isa Number @@ -41,35 +55,41 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number elseif primal(y) isa AbstractArray pullback_array!! else - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) end return y, pullback end function Mooncake.rrule!!( - dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}} - ) + dw::CoDual{<:DI.DifferentiateWith{C}}, + x::CoDual{<:AbstractArray{<:Number}}, + contexts::Vararg{CoDual, C} + ) where {C} + @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) primal_x = primal(x) - fdata_arg = x.dx - (; f, backend) = primal_func - y = zero_fcodual(f(primal_x)) + primal_contexts = map(primal, contexts) + (; f, backend, context_wrappers) = primal_func + y = zero_fcodual(f(primal_x, primal_contexts...)) + wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) - tx = DI.pullback(f, backend, primal_x, (y.dx,)) - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) - fdata_arg .+= only(tx) - return NoRData(), dy + dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + x.dx .+= dx + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), dy, rc...) end # output is a scalar, so we can use the scalar pullback function pullback_scalar!!(dy::Number) - tx = DI.pullback(f, backend, primal_x, (dy,)) - @assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x)))) - fdata_arg .+= only(tx) - return NoRData(), NoRData() + dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) + x.dx .+= dx + rc = nanify_fdata_and_rdata!!(contexts...) + return (NoRData(), NoRData(), rc...) end pullback = if primal(y) isa Number @@ -77,7 +97,7 @@ function Mooncake.rrule!!( elseif primal(y) isa AbstractArray pullback_array!! else - throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y))) + throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y))) end return y, pullback diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b22d8d49b..25c525308 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) return zero_tangent(x) end end + +nanify(x::AbstractFloat) = convert(typeof(x), NaN) +nanify(x::AbstractArray) = map(nan_tangent, x) +nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x) +nanify(::NoFData) = NoFData() +nanify(::NoRData) = NoRData() + +function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C} + primal_contexts = map(primal, contexts) + fdata_contexts = map(fdata, contexts) + zero_rdata_contexts = map(zero_rdata, primal_contexts) + foreach(fdata_contexts) do fc + increment!!(fc, nanify(fc)) + end + return map(nanify, zero_rdata_contexts) +end diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index f0c2ecf38..b9a6f0322 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -14,7 +14,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be !!! warning - `DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments. + `DifferentiateWith` only supports out-of-place functions `y = f(x, contexts...)`, where the derivatives with respect to `contexts` can be safely ignored in the rest of your code. It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl), reverse-mode [Mooncake](https://github.com/chalk-lab/Mooncake.jl), or if it automatically importing rules from [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) (e.g. [Zygote](https://github.com/FluxML/Zygote.jl)). Some backends are also able to [manually import rules](https://juliadiff.org/ChainRulesCore.jl/stable/#Packages-supporting-importing-rules-from-ChainRules.) from ChainRules. For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper). @@ -25,8 +25,9 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be # Fields - - `f`: the function in question, with signature `f(x)` + - `f`: the function in question, with signature `f(x, contexts...)` - `backend::AbstractADType`: the substitute backend to use for differentiation + - `context_wrappers::NTuple`: a tuple like `(Constant, Cache)`, meaning that `f(x, a, b)` will be differentiated with `Constant(a)` and `Cache(b)` as contexts. !!! note @@ -34,7 +35,7 @@ Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be # Constructor - DifferentiateWith(f, backend) + DifferentiateWith(f, backend, context_wrappers) # Example @@ -69,12 +70,17 @@ julia> Zygote.gradient(alg, [3.0, 5.0])[1] 70.0 ``` """ -struct DifferentiateWith{F, B <: AbstractADType} +struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}} f::F backend::B + context_wrappers::N end -(dw::DifferentiateWith)(x) = dw.f(x) +DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ()) + +function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C} + return dw.f(x, args...) +end function Base.show(io::IO, dw::DifferentiateWith) (; f, backend) = dw diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 2d2575d01..3058a4c63 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -179,3 +179,5 @@ Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are """ @inline fix_tail(f::F) where {F} = f fix_tail(f::F, args::Vararg{Any, N}) where {F, N} = FixTail(f, args...) + +@inline call(f::F, x) where {F} = f(x) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 860d5e85f..98e783777 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -24,9 +24,14 @@ function (adb::ADBreaker)(x::AbstractArray) return adb.f(x) end -function differentiatewith_scenarios() - outofplace_scens = filter(DIT.default_scenarios()) do scen - DIT.function_place(scen) == :out +# TODO: break Mooncake with overlay? + +function differentiatewith_scenarios(; kwargs...) + outofplace_scens = filter(DIT.default_scenarios(; kwargs...)) do scen + DIT.function_place(scen) == :out && + # save some time + !isa(scen.x, AbstractMatrix) && + !isa(scen.y, AbstractMatrix) end # with bad_scens, everything would break bad_scens = map(outofplace_scens) do scen @@ -44,7 +49,23 @@ test_differentiation( differentiatewith_scenarios(); excluded = SECOND_ORDER, logging = LOGGING, - testset_name = "DI tests", + testset_name = "DI tests - normal", +) + +test_differentiation( + [AutoZygote(), AutoMooncake(; config = nothing)], + map(DIT.constantify, differentiatewith_scenarios()); + excluded = SECOND_ORDER, + logging = LOGGING, + testset_name = "DI tests - Constant", +) + +test_differentiation( + [AutoMooncake(; config = nothing)], + map(DIT.cachify, differentiatewith_scenarios()); + excluded = SECOND_ORDER, + logging = LOGGING, + testset_name = "DI tests - Cache", ) @testset "ChainRules tests" begin @@ -71,7 +92,7 @@ end; e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) @test sprint(showerror, e) == - "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported." + "MooncakeDifferentiateWithError: For the function type typeof(identity) and input types (Float64,), the output type Float64 is currently not supported." f_num2tup(x::Number) = (x,) f_vec2tup(x::Vector) = (first(x),) From 72cbf6e7076d6cf13bc70e6344bf0fafda2d0fe7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 4 Mar 2026 20:05:52 +0100 Subject: [PATCH 6/9] Fixes --- .../ext/DifferentiationInterfaceMooncakeExt/utils.jl | 2 +- DifferentiationInterface/src/misc/differentiate_with.jl | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 25c525308..e01631e67 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -26,7 +26,7 @@ nanify(::NoRData) = NoRData() function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C} primal_contexts = map(primal, contexts) - fdata_contexts = map(fdata, contexts) + fdata_contexts = map(tangent, contexts) zero_rdata_contexts = map(zero_rdata, primal_contexts) foreach(fdata_contexts) do fc increment!!(fc, nanify(fc)) diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index b9a6f0322..3dc084a00 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -83,7 +83,7 @@ function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C} end function Base.show(io::IO, dw::DifferentiateWith) - (; f, backend) = dw + (; f, backend, context_wrappers) = dw return print( io, DifferentiateWith, @@ -91,6 +91,8 @@ function Base.show(io::IO, dw::DifferentiateWith) repr(f; context = io), ", ", repr(backend; context = io), + ", ", + repr(context_wrappers; context = io), ")", ) end From bad6a010e0fc5e535513c187228e82bd633e700b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Mar 2026 12:46:51 +0100 Subject: [PATCH 7/9] Fix tests --- DifferentiationInterface/test/Back/DifferentiateWith/test.jl | 2 +- DifferentiationInterface/test/Core/Internals/display.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 98e783777..0454fe9dc 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -62,7 +62,7 @@ test_differentiation( test_differentiation( [AutoMooncake(; config = nothing)], - map(DIT.cachify, differentiatewith_scenarios()); + map(s -> DIT.cachify(s; use_tuples = true), differentiatewith_scenarios()); excluded = SECOND_ORDER, logging = LOGGING, testset_name = "DI tests - Cache", diff --git a/DifferentiationInterface/test/Core/Internals/display.jl b/DifferentiationInterface/test/Core/Internals/display.jl index 316fa6921..f1200fcfe 100644 --- a/DifferentiationInterface/test/Core/Internals/display.jl +++ b/DifferentiationInterface/test/Core/Internals/display.jl @@ -11,7 +11,7 @@ detector = DenseSparsityDetector(AutoForwardDiff(); atol = 1.0e-23) "DenseSparsityDetector(AutoForwardDiff(); atol=1.0e-23, method=:iterative)" diffwith = DifferentiateWith(exp, AutoForwardDiff()) -@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff())" +@test string(diffwith) == "DifferentiateWith(exp, AutoForwardDiff(), ())" @test required_packages(AutoForwardDiff()) == ["ForwardDiff"] @test required_packages(AutoZygote()) == ["Zygote"] From b19bcde6028c8bf721595b882dcd25ce9890404d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 7 Mar 2026 10:20:10 +0100 Subject: [PATCH 8/9] Fix errors --- .../differentiate_with.jl | 14 +++++++------- .../test/Back/DifferentiateWith/test.jl | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 59b759cde..846146402 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,8 +1,8 @@ const NumberOrArray = Union{Number, AbstractArray{<:Number}} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, <:NumberOrArray} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, <:NumberOrArray, <:NumberOrArray} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray, <:NumberOrArray} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any} +@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any} # TODO: generate more cases programmatically struct MooncakeDifferentiateWithError <: Exception @@ -17,14 +17,14 @@ end function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) return print( io, - "MooncakeDifferentiateWithError: For the function type $(e.F) and argument types $(e.X), the output type $(e.Y) is currently not supported.", + "MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.", ) end function Mooncake.rrule!!( dw::CoDual{<:DI.DifferentiateWith{C}}, x::CoDual{<:Number}, - contexts::Vararg{CoDual, C} + contexts::Vararg{CoDual{<:NumberOrArray}, C} ) where {C} @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) @@ -64,7 +64,7 @@ end function Mooncake.rrule!!( dw::CoDual{<:DI.DifferentiateWith{C}}, x::CoDual{<:AbstractArray{<:Number}}, - contexts::Vararg{CoDual, C} + contexts::Vararg{CoDual{<:NumberOrArray}, C} ) where {C} @assert tangent_type(typeof(dw)) == NoTangent primal_func = primal(dw) diff --git a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl index 0454fe9dc..fb261ee57 100644 --- a/DifferentiationInterface/test/Back/DifferentiateWith/test.jl +++ b/DifferentiationInterface/test/Back/DifferentiateWith/test.jl @@ -90,9 +90,9 @@ end; MooncakeDifferentiateWithError = Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError - e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) + e = MooncakeDifferentiateWithError(identity, (1.0,), 2.0) @test sprint(showerror, e) == - "MooncakeDifferentiateWithError: For the function type typeof(identity) and input types (Float64,), the output type Float64 is currently not supported." + "MooncakeDifferentiateWithError: For the function type `typeof(identity)` and input types `Tuple{Float64}`, the output type `Float64` is currently not supported." f_num2tup(x::Number) = (x,) f_vec2tup(x::Vector) = (first(x),) From d5d5b31d4473e97e80f32ad21e3b655b02612e3a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:03:49 +0100 Subject: [PATCH 9/9] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../differentiate_with.jl | 14 ++++++----- .../utils.jl | 5 ++-- .../src/misc/differentiate_with.jl | 23 +++++++++++++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl index 846146402..95efbd7f0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl @@ -1,10 +1,12 @@ const NumberOrArray = Union{Number, AbstractArray{<:Number}} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any} -@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any} -# TODO: generate more cases programmatically +# Mark DifferentiateWith with a range of context arities as primitives. +# For C contexts, the corresponding call tuple type is +# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}: +# one slot for the primal input x and C slots for contexts. +for C in 0:16 + @eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}} +end struct MooncakeDifferentiateWithError <: Exception F::Type X::Type @@ -37,7 +39,7 @@ function Mooncake.rrule!!( # output is a vector, so we need to use the vector pullback function pullback_array!!(dy::NoRData) dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only - @assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x))) + @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) rc = nanify_fdata_and_rdata!!(contexts...) return (NoRData(), rdata(dx), rc...) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index e01631e67..edda7bdb2 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -19,8 +19,9 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) end nanify(x::AbstractFloat) = convert(typeof(x), NaN) -nanify(x::AbstractArray) = map(nan_tangent, x) -nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x) +nanify(x::AbstractArray) = map(nanify, x) +nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x))) +nanify(x::Tuple) = map(nanify, x) nanify(::NoFData) = NoFData() nanify(::NoRData) = NoRData() diff --git a/DifferentiationInterface/src/misc/differentiate_with.jl b/DifferentiationInterface/src/misc/differentiate_with.jl index 3dc084a00..249a08eac 100644 --- a/DifferentiationInterface/src/misc/differentiate_with.jl +++ b/DifferentiationInterface/src/misc/differentiate_with.jl @@ -76,6 +76,29 @@ struct DifferentiateWith{C, F, B <: AbstractADType, N <: NTuple{C, Any}} context_wrappers::N end +function DifferentiateWith( + f::F, + backend::B, + context_wrappers::NTuple{C, Any}, +) where {F, B <: AbstractADType, C} + for (i, wrapper) in pairs(context_wrappers) + # Accept typical constructor-like values: functions or types. + if !(wrapper isa Function || wrapper isa Type) + throw( + ArgumentError( + "Each context wrapper must be a callable object or type " * + "(e.g., a wrapper constructor like `Constant` or `Cache`), " * + "but element $i has type $(typeof(wrapper)).", + ), + ) + end + end + return DifferentiateWith{C, F, B, typeof(context_wrappers)}( + f, + backend, + context_wrappers, + ) +end DifferentiateWith(f::F, backend::AbstractADType) where {F} = DifferentiateWith(f, backend, ()) function (dw::DifferentiateWith{C})(x, args::Vararg{Any, C}) where {C}