feat: DifferentiateWith + context arguments#975

Draft
gdalle wants to merge 7 commits into main from gd/diffwith

gdalle (Member) commented Mar 4, 2026

Partial fix for #806, #675.

Since DI doesn't compute derivatives with respect to contexts, in each implementation we need a way to mark the derivative as unknown.

  • Enrich DifferentiateWith with a context_wrappers field, denoting how additional arguments beyond x must be taken into account.
  • Implement unknown context derivatives in reverse mode:
    • With Zygote by returning a ChainRulesCore.@not_implemented tangent
    • With Mooncake by poisoning the FData and RData of context arguments using NaN
  • Implement unknown context derivatives in forward mode:
    • Impossible with ForwardDiff because we don't know right away whether a context argument contains duals or not, and any Constant context containing Duals would make the derivative wrong. We could allow Caches though.
    • Possible with Mooncake but the frule!! doesn't yet exist for DifferentiateWith
  • Add DI tests
  • Add native tests: impossible because we are missing some tangents?
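To illustrate the Zygote / ChainRules item above, here is a minimal sketch of a reverse rule that marks the context cotangent as unknown. It is not the code from this PR: `scaled_sumsq` is a hypothetical function, and the cotangent for `x` is written by hand where the real rule would call DI on the wrapped function. Only `rrule`, `NoTangent`, `unthunk` and `@not_implemented` come from ChainRulesCore.

```julia
using ChainRulesCore

# Hypothetical function: `x` is the differentiated argument, `c` plays the role of a context.
scaled_sumsq(x, c) = c * sum(abs2, x)

function ChainRulesCore.rrule(::typeof(scaled_sumsq), x, c)
    y = scaled_sumsq(x, c)
    function scaled_sumsq_pullback(dy)
        # Cotangent for `x`: written by hand here, assumed to come from a DI call in practice.
        dx = (2 * c * unthunk(dy)) .* x
        # Cotangent for the context: explicitly flagged as not computed.
        dc = @not_implemented("derivative with respect to the context argument is not computed")
        return NoTangent(), dx, dc
    end
    return y, scaled_sumsq_pullback
end

y, pb = rrule(scaled_sumsq, [1.0, 2.0], 3.0)
_, dx, dc = pb(1.0)
```

With this rule, `dx` is usable as a gradient while `dc` is a `ChainRulesCore.NotImplemented` tangent, so the missing derivative stays visibly marked as unavailable instead of being mistaken for a real value.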

gdalle changed the title from "Gd/diffwith" to "feat: DifferentiateWith + context arguments" Mar 4, 2026
gdalle (Member, Author) commented Mar 4, 2026

@Technici4n I'd love to have some feedback on this, and the related chalk-lab/Mooncake.jl#548 it enables!

codecov bot commented Mar 4, 2026

Codecov Report

❌ Patch coverage is 25.88235% with 63 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.89%. Comparing base (2b7c5e9) to head (bad6a01).
⚠️ Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ntiationInterfaceMooncakeExt/differentiate_with.jl | 0.00% | 31 Missing ⚠️ |
| ...e/ext/DifferentiationInterfaceMooncakeExt/utils.jl | 0.00% | 12 Missing ⚠️ |
| ...onInterfaceChainRulesCoreExt/differentiate_with.jl | 0.00% | 11 Missing ⚠️ |
| ...ationInterface/test/Back/DifferentiateWith/test.jl | 0.00% | 4 Missing ⚠️ |
| ...ationInterfaceForwardDiffExt/differentiate_with.jl | 0.00% | 2 Missing ⚠️ |
| ...rentiationInterface/src/misc/differentiate_with.jl | 50.00% | 2 Missing ⚠️ |
| DifferentiationInterface/src/utils/context.jl | 0.00% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #975      +/-   ##
==========================================
+ Coverage   94.11%   95.89%   +1.78%     
==========================================
  Files         135      131       -4     
  Lines        7980     7999      +19     
==========================================
+ Hits         7510     7671     +161     
+ Misses        470      328     -142     

| Flag | Coverage Δ |
|---|---|
| DI | 95.95% <25.88%> (+2.66%) ⬆️ |
| DIT | 95.76% <ø> (-0.47%) ⬇️ |

Technici4n commented

I'm not very familiar with contexts, but the basic idea sounds reasonable. Why is the NaN poisoning only used now, and not previously? Does this PR introduce a new kind of context?

gdalle (Member, Author) commented Mar 6, 2026

The idea behind contexts is summed up concisely in https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/stable/explanation/arguments/.
Essentially, DI only computes derivatives with respect to the first argument. Any additional arguments are contexts, either constants or caches: they play a role in the computation, but no derivatives are obtained with respect to them. When writing a rule for a multi-argument function based on an underlying DI call, we thus have to signal that the pullback with respect to the first argument is the only one that is actually correct and usable. The other arguments (constants and caches) therefore get NaN-ified pullbacks.
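
As a rough illustration of the two points above, the sketch below first calls DI with a `Constant` context, then builds a pullback that returns NaN for the context. The names `f`, `poison` and `pullback_all_args` are hypothetical, and the plain NaN fill only stands in for the idea; it is not how the Mooncake extension manipulates FData and RData.

```julia
using DifferentiationInterface
import ForwardDiff

# `x` is the differentiated argument, `c` is a context passed as a Constant.
f(x, c) = c * sum(abs2, x)

backend = AutoForwardDiff()
x = [1.0, 2.0]
c = 3.0

# DI returns the gradient with respect to `x` only.
g = gradient(f, backend, x, Constant(c))

# Hypothetical helper: a value shaped like `c` but filled with NaN.
poison(c::Number) = oftype(float(c), NaN)
poison(c::AbstractArray) = fill!(similar(c, float(eltype(c))), NaN)

# Hypothetical pullback for a scalar output cotangent `dy`:
# only the first entry is meaningful, the context entry is poisoned.
pullback_all_args(dy, x, c) = (dy .* gradient(f, backend, x, Constant(c)), poison(c))

dx, dc = pullback_all_args(1.0, x, c)
```

Any downstream computation that accidentally uses `dc` then propagates NaN, which makes the misuse visible instead of silently producing a wrong number.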
