-
Notifications
You must be signed in to change notification settings - Fork 29
feat: DifferentiateWith + context arguments #975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b8b023b
ccadb9e
a8a628d
ce5819e
5ce58ea
72cbf6e
bad6a01
b19bcde
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,22 @@ | ||
| function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x) | ||
| (; f, backend) = dw | ||
| y = f(x) | ||
| prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,)) | ||
| function pullbackfunc(dy) | ||
| tx = DI.pullback(f, prep_same, backend, x, (dy,)) | ||
| return (NoTangent(), only(tx)) | ||
| function ChainRulesCore.rrule( | ||
| dw::DI.DifferentiateWith{C}, x, contexts::Vararg{Any, C} | ||
| ) where {C} | ||
| (; f, backend, context_wrappers) = dw | ||
| y = f(x, contexts...) | ||
| wrapped_contexts = map(DI.call, context_wrappers, contexts) | ||
| prep_same = DI.prepare_pullback_same_point_nokwarg( | ||
| Val(false), f, backend, x, (y,), wrapped_contexts... | ||
| ) | ||
| function diffwith_pullbackfunc(dy) | ||
| dx = DI.pullback(f, prep_same, backend, x, (dy,), wrapped_contexts...) |> only | ||
| dc = map(contexts) do c | ||
| @not_implemented( | ||
| """ | ||
| Derivatives with respect to context arguments are not implemented. | ||
| """ | ||
| ) | ||
| end | ||
| return (NoTangent(), dx, dc...) | ||
| end | ||
| return y, pullbackfunc | ||
| return y, diffwith_pullbackfunc | ||
| end |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -1,4 +1,9 @@ | ||||||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith, <:Any} | ||||||
| const NumberOrArray = Union{Number, AbstractArray{<:Number}} | ||||||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{0}, Any} | ||||||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{1}, Any, Any} | ||||||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{2}, Any, Any, Any} | ||||||
| @is_primitive MinimalCtx Tuple{DI.DifferentiateWith{3}, Any, Any, Any, Any} | ||||||
| # TODO: generate more cases programmatically | ||||||
|
|
||||||
| struct MooncakeDifferentiateWithError <: Exception | ||||||
| F::Type | ||||||
|
|
@@ -12,72 +17,87 @@ end | |||||
| function Base.showerror(io::IO, e::MooncakeDifferentiateWithError) | ||||||
| return print( | ||||||
| io, | ||||||
| "MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.", | ||||||
| "MooncakeDifferentiateWithError: For the function type `$(e.F)` and input types `$(e.X)`, the output type `$(e.Y)` is currently not supported.", | ||||||
| ) | ||||||
| end | ||||||
|
|
||||||
| function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number}) | ||||||
| function Mooncake.rrule!!( | ||||||
| dw::CoDual{<:DI.DifferentiateWith{C}}, | ||||||
| x::CoDual{<:Number}, | ||||||
| contexts::Vararg{CoDual{<:NumberOrArray}, C} | ||||||
| ) where {C} | ||||||
| @assert tangent_type(typeof(dw)) == NoTangent | ||||||
| primal_func = primal(dw) | ||||||
| primal_x = primal(x) | ||||||
| (; f, backend) = primal_func | ||||||
| y = zero_fcodual(f(primal_x)) | ||||||
| primal_contexts = map(primal, contexts) | ||||||
| (; f, backend, context_wrappers) = primal_func | ||||||
| y = zero_fcodual(f(primal_x, primal_contexts...)) | ||||||
| wrapped_primal_contexts = map(DI.call, context_wrappers, primal_contexts) | ||||||
|
|
||||||
| # output is a vector, so we need to use the vector pullback | ||||||
| function pullback_array!!(dy::NoRData) | ||||||
| tx = DI.pullback(f, backend, primal_x, (y.dx,)) | ||||||
| @assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x))) | ||||||
| return NoRData(), rdata(only(tx)) | ||||||
| dx = DI.pullback(f, backend, primal_x, (y.dx,), wrapped_primal_contexts...) |> only | ||||||
| @assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x))) | ||||||
|
||||||
| @assert rdata(only(dx)) isa rdata_type(tangent_type(typeof(primal_x))) | |
| @assert rdata(dx) isa rdata_type(tangent_type(typeof(primal_x))) |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -17,3 +17,19 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake) | |||||||||||
| return zero_tangent(x) | ||||||||||||
| end | ||||||||||||
| end | ||||||||||||
|
|
||||||||||||
| nanify(x::AbstractFloat) = convert(typeof(x), NaN) | ||||||||||||
| nanify(x::AbstractArray) = map(nan_tangent, x) | ||||||||||||
| nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x) | ||||||||||||
|
Comment on lines
+22
to
+23
|
||||||||||||
| nanify(x::AbstractArray) = map(nan_tangent, x) | |
| nanify(x::Union{Tuple, NamedTuple}) = map(nan_tangent, x) | |
| nanify(x::AbstractArray) = map(nanify, x) | |
| nanify(x::NamedTuple) = NamedTuple{keys(x)}(map(nanify, values(x))) | |
| nanify(x::Tuple) = map(nanify, x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mooncake is only marking `DifferentiateWith{0}` through `DifferentiateWith{3}` as primitives. If a user constructs `DifferentiateWith` with more than 3 context arguments, Mooncake will not treat it as a primitive and will likely differentiate through the wrapper instead of using this custom rule (defeating the purpose of `DifferentiateWith` for contexts). Consider generating these `@is_primitive` declarations programmatically for a reasonable range, or using a Vararg-friendly approach if Mooncake supports it.