Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.16"
authors = ["Guillaume Dalle", "Adrian Hill"]

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Expand Down Expand Up @@ -38,7 +39,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
DifferentiationInterfaceGPUArraysCoreExt = ["GPUArraysCore", "Adapt"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = [
Expand All @@ -56,6 +57,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
Adapt = "4.5.0"
ADTypes = "1.18.0"
ChainRulesCore = "1.23.0"
DiffResults = "1.1.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ using ChainRulesCore:
RuleConfig,
frule_via_ad,
rrule_via_ad,
unthunk
unthunk,
@not_implemented
import DifferentiationInterface as DI

ruleconfig(backend::AutoChainRules) = backend.ruleconfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
function pullbackfunc(dy)
tx = DI.pullback(f, prep_same, backend, x, (dy,))
return (NoTangent(), only(tx))
function ChainRulesCore.rrule(
dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C}
) where {C}
(; f, backend, context_wrappers) = dw
y = f(x, contexts...)
wrapped_contexts = map(DI.call, context_wrappers, contexts)
prep_same = DI.prepare_pullback_same_point_nokwarg(
Val(false), f, backend, x, (y,), wrapped_contexts...
)
function diffwith_pullbackfunc(dy)
dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only
dc = map(contexts) do c
@not_implemented(
"""
Derivatives with respect to context arguments are not implemented.
"""
)
end
return (NoTangent(), dx, dc...)
end
return y, pullbackfunc
return y, diffwith_pullbackfunc
end
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function (dw::DI.DifferentiateWith)(x::Dual{T, V, N}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::Dual{T, V, N}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
return make_dual(T, y, ty)
end

function (dw::DI.DifferentiateWith)(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
function (dw::DI.DifferentiateWith{0})(x::AbstractArray{Dual{T, V, N}}) where {T, V, N}
(; f, backend) = dw
xval = myvalue(T, x)
tx = mypartials(T, Val(N), x)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module DifferentiationInterfaceGPUArraysCoreExt

using Adapt: adapt
import DifferentiationInterface as DI
using GPUArraysCore: @allowscalar, AbstractGPUArray

Expand All @@ -17,4 +18,10 @@ function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
return b
end

# GPU-array method: turn an array whose elements are `B`-tuples of numbers into a
# `B`-tuple of arrays. Component `b` is extracted with a broadcasted `getindex`
# and then passed through `adapt` with `typeof(x)` as the target.
function DI.arroftup_to_tupofarr(
        tx::AbstractArray{<:NTuple{B, <:Number}}, x::AbstractGPUArray{<:Number}
    ) where {B}
    return ntuple(Val(B)) do b
        component = getindex.(tx, b)
        adapt(typeof(x), component)
    end
end

end
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using Mooncake:
value_and_pullback!!,
zero_dual,
zero_tangent,
zero_rdata,
rdata_type,
fdata,
rdata,
Expand All @@ -26,11 +27,13 @@ using Mooncake:
@is_primitive,
zero_fcodual,
MinimalCtx,
NoFData,
NoRData,
primal,
_copy_output,
_copy_to_output!!,
tangent_to_primal!!
tangent_to_primal!!,
increment!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any}
const NumberOrArray = Union{Number, AbstractArray{<:Number}}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any}
# TODO: generate more cases programmatically

Comment on lines +2 to 7
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mooncake is only marking DifferentiateWith{0} through DifferentiateWith{3} as primitives. If a user constructs DifferentiateWith with more than 3 context arguments, Mooncake will not treat it as a primitive and will likely differentiate through the wrapper instead of using this custom rule (defeating the purpose of DifferentiateWith for contexts). Consider generating these @is_primitive declarations programmatically for a reasonable range, or using a Vararg-friendly approach if Mooncake supports it.

Suggested change
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any}
# TODO: generate more cases programmatically
# Mark DifferentiateWith with a range of context arities as primitives.
# For C contexts, the corresponding call tuple type is
# Tuple{DI.DifferentiateWith{C}, Any, Vararg{Any, C}}:
# one slot for the primal input x and C slots for contexts.
for C in 0:16
@eval @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{$C}, Vararg{Any, $(C + 1)}}
end

Copilot uses AI. Check for mistakes.
struct MooncakeDifferentiateWithError <: Exception
F::Type
Expand All @@ -12,72 +17,87 @@ end
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
return print(
io,
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
"MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.",
)
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:Number},
contexts::Vararg{CoDual{<:NumberOrArray}, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x)))
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

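only(dx) is redundant here because dx is already the pullback result after |> only. Since dx is a scalar in this method, only(dx) simply returns dx (Julia treats a number as a one-element collection), so it does not fail, but it is misleading and inconsistent with the other pullbacks in this file, which assert on rdata(dx) directly. This assertion should likely use rdata(dx) as well.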

Suggested change
@assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x)))
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))

Copilot uses AI. Check for mistakes.
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), rdata(dx), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
end

function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray{<:Number}}
)
dw::CoDual{<:DI.DifferentiateWith{C}},
x::CoDual{<:AbstractArray{<:Number}},
contexts::Vararg{CoDual{<:NumberOrArray}, C}
) where {C}
@assert tangent_type(typeof(dw)) == NoTangent
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))
primal_contexts = map(primal, contexts)
(; f, backend, context_wrappers) = primal_func
y = zero_fcodual(f(primal_x, primal_contexts...))
wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts)

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), dy, rc...)
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
dx = DI.pullback(f, backend, primal_x, (dy,), wrapped_primal_contexts...) |> only
@assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x)))
x.dx .+= dx
rc = nanify_fdata_and_rdata!!(contexts...)
return (NoRData(), NoRData(), rc...)
end

pullback = if primal(y) isa Number
pullback_scalar!!
elseif primal(y) isa AbstractArray
pullback_array!!
else
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
throw(MooncakeDifferentiateWithError(primal_func, (primal_x, primal_contexts...), primal(y)))
end

return y, pullback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
return zero_tangent(x)
end
end

"""
    nanify(x)

Return a value with the same structure as `x` in which every floating-point
leaf is replaced by a `NaN` of the same type.

Arrays, tuples and named tuples are handled by mapping `nanify` over their
elements, so nested containers are processed recursively. Leaf types without a
`nanify` method raise a `MethodError`.
"""
nanify(x::AbstractFloat) = convert(typeof(x), NaN)
# Recurse with `nanify` itself: the previously referenced `nan_tangent` is not
# defined or imported in this extension and would raise an `UndefVarError`.
nanify(x::AbstractArray) = map(nanify, x)
nanify(x::Union{Tuple, NamedTuple}) = map(nanify, x)
Comment on lines +22 to +23
Copy link

Copilot AI Mar 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nanify calls nan_tangent, but there is no definition/import for nan_tangent in this extension (or elsewhere in the repo), so this will throw an UndefVarError at runtime. Replace nan_tangent with the intended poisoning function (likely nanify recursively), and ensure the returned container matches the expected Mooncake fdata/rdata structure.

Suggested change
nanify(x::AbstractArray) = map(nan_tangent, x)
nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x)
nanify(x::AbstractArray) = map(nanify, x)
nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x)))
nanify(x::Tuple) = map(nanify, x)

Copilot uses AI. Check for mistakes.
# `NoFData` and `NoRData` inputs map to a fresh instance of the same type;
# no `NaN` is produced for them.
nanify(::NoFData) = NoFData()
nanify(::NoRData) = NoRData()

# Poison the derivative data attached to the given `CoDual` context arguments.
#
# - The `tangent` of each context is passed to `increment!!` together with its
#   `nanify`-ed counterpart; the return value of `increment!!` is discarded.
# - The returned tuple contains, for each context, `nanify` applied to the
#   `zero_rdata` of that context's primal.
function nanify_fdata_and_rdata!!(contexts::Vararg{CoDual, C}) where {C}
    fdatas = map(tangent, contexts)
    rdata_zeros = map(c -> zero_rdata(primal(c)), contexts)
    foreach(fd -> increment!!(fd, nanify(fd)), fdatas)
    return map(nanify, rdata_zeros)
end
16 changes: 8 additions & 8 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -348,7 +348,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -366,7 +366,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -387,7 +387,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down Expand Up @@ -458,7 +458,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -477,7 +477,7 @@ function _value_and_pullback_via_pushforward(
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -495,7 +495,7 @@ function _value_and_pullback_via_pushforward(
dot(a, dy)
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -518,7 +518,7 @@ function _value_and_pullback_via_pushforward(
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, arroftup_to_tupofarr(tx)
return y, arroftup_to_tupofarr(tx, x)
end

function value_and_pullback(
Expand Down
12 changes: 6 additions & 6 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
dot(a, dx)
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -348,7 +348,7 @@ function _value_and_pushforward_via_pullback(
ty = map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -367,7 +367,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -387,7 +387,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down Expand Up @@ -460,7 +460,7 @@ function _value_and_pushforward_via_pullback(
dot(a, dx)
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -481,7 +481,7 @@ function _value_and_pushforward_via_pullback(
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, arroftup_to_tupofarr(ty)
return y, arroftup_to_tupofarr(ty, y)
end

function value_and_pushforward(
Expand Down
Loading
Loading