From 9481898928b8e2a1e2db6e830c783dc1ad8e7f20 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 21 Nov 2025 07:37:55 +0100 Subject: [PATCH 1/3] fix: improve wrong-mode pushforward/pullback --- .github/workflows/Test.yml | 2 +- .../src/first_order/pullback.jl | 100 ++++++++++-------- .../src/first_order/pushforward.jl | 83 ++++++++------- DifferentiationInterface/src/utils/linalg.jl | 3 + 4 files changed, 99 insertions(+), 89 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index bcc7f30bf..a1d6fc87b 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -166,7 +166,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true + fail-fast: false # TODO: toggle matrix: version: - '1.10' diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 7b1f85fb9..993bd985e 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -325,12 +325,14 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Real, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) - dx = dot(a, dy) - return y, dx + tx = map(ty) do dy + dot(a, dy) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -338,13 +340,15 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Complex, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)) - dx = real(dot(a, dy)) + im * real(dot(b, dy)) - return y, dx + tx = map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -352,15 +356,17 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Real}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dx = map(CartesianIndices(x)) do j + tx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) - dot(a, dy) + map(ty) do dy + dot(a, dy) + end end - return y, dx + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -368,18 +374,20 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Complex}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dx = map(CartesianIndices(x)) do j + tx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...), ) - real(dot(a, dy)) + im * real(dot(b, dy)) + map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end end - return y, dx + return y, arroftup_to_tupofarr(tx) end function value_and_pullback( @@ -392,13 +400,7 @@ function value_and_pullback( ) where {F, B, C} check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep - ys_and_tx = ntuple( - b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...), - Val(B), - ) - y = first(first(ys_and_tx)) - tx = map(last, ys_and_tx) - return y, tx + return _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty, contexts...) end function value_and_pullback!( @@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Real, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) - dx = dot(a, dy) - return dx + tx = map(ty) do dy + dot(a, dy) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Complex, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) _, b = onlysecond( value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) ) - dx = real(dot(a, dy)) + im * real(dot(b, dy)) - return dx + tx = map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Real}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dx = map(CartesianIndices(x)) do j # preserve shape + ) where {F, B, C} + tx = map(CartesianIndices(x)) do j # preserve shape _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) - dot(a, dy) + map(ty) do dy + dot(a, dy) + end end - return dx + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -499,16 +507,18 @@ function _value_and_pullback_via_pushforward( dy, contexts::Vararg{Context, C}, ) where {F, C} - dx = map(CartesianIndices(x)) do j # preserve shape + tx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) _, b = onlysecond( value_and_pushforward( f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts... ), ) - real(dot(a, dy)) + im * real(dot(b, dy)) + map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end end - return dx + return y, arroftup_to_tupofarr(tx) end function value_and_pullback( @@ -522,13 +532,9 @@ function value_and_pullback( ) where {F, B, C} check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep - tx = ntuple( - b -> _value_and_pullback_via_pushforward( - f!, y, pushforward_prep, backend, x, ty[b], contexts... - ), - Val(B), + return _value_and_pullback_via_pushforward( + f!, y, pushforward_prep, backend, x, ty, contexts... ) - return y, tx end function value_and_pullback!( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 295f98145..e49203615 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -324,12 +324,14 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) - dy = dot(a, dx) - return y, dy + ty = map(tx) do dx + dot(a, dx) + end + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -338,13 +340,15 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...)) - dy = real(dot(a, dx)) + im * real(dot(b, dx)) - return y, dy + ty = map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -353,15 +357,17 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dy = map(CartesianIndices(y_ex)) do i + ty = map(CartesianIndices(y_ex)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) - dot(a, dx) + map(tx) do dx + dot(a, dx) + end end - return y, dy + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -370,16 +376,18 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dy = map(CartesianIndices(y_ex)) do i + ty = map(CartesianIndices(y_ex)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...)) - real(dot(a, dx)) + im * real(dot(b, dx)) + map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end end - return y, dy + return y, arroftup_to_tupofarr(ty) end function value_and_pushforward( @@ -392,13 +400,7 @@ function value_and_pushforward( ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) (; pullback_prep, y_example) = prep - ys_and_ty = ntuple( - b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...), - Val(B), - ) - y = first(first(ys_and_ty)) - ty = map(last, ys_and_ty) - return y, ty + return _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx, contexts...) end function value_and_pushforward!( @@ -449,14 +451,16 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dy = map(CartesianIndices(y)) do i # preserve shape + ) where {F, B, C} + ty = map(CartesianIndices(y)) do i # preserve shape _, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) - dot(a, dx) + map(tx) do dx + dot(a, dx) + end end - return dy + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -465,17 +469,19 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dy = map(CartesianIndices(y)) do i # preserve shape + ) where {F, B, C} + ty = map(CartesianIndices(y)) do i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) _, b = onlysecond( value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...) ) - real(dot(a, dx)) + im * real(dot(b, dx)) + map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end end - return dy + return y, arroftup_to_tupofarr(ty) end function value_and_pushforward( @@ -489,12 +495,7 @@ function value_and_pushforward( ) where {F, B, C} check_prep(f!, y, prep, backend, x, tx, contexts...) (; pullback_prep) = prep - ty = ntuple( - b -> - _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), - Val(B), - ) - return y, ty + return _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx, contexts...) end function value_and_pushforward!( diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index 9b952ec2c..a5fb15cab 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -35,3 +35,6 @@ The trivial dense fallback is designed to protect against a change of format in get_pattern(M::AbstractMatrix) = trues(size(M)) onlysecond((a, b)) = (a, only(b)) + +arroftup_to_tupofarr(x::NTuple) = x +arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B)) From 9ee7fc7d0206398dc61520af8d50b68cca71e1d6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 21 Nov 2025 08:21:46 +0100 Subject: [PATCH 2/3] Fix --- DifferentiationInterface/src/first_order/pullback.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 993bd985e..333b12967 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -504,9 +504,9 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Complex}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} tx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) _, b = onlysecond( From a548cf2e110e8097d967d8281abd7357ecb574f2 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 21 Nov 2025 16:15:19 +0100 Subject: [PATCH 3/3] Update .github/workflows/Test.yml --- .github/workflows/Test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index a1d6fc87b..1020665c9 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -166,7 +166,7 @@ jobs: actions: write contents: read strategy: - fail-fast: false # TODO: toggle + fail-fast: true # TODO: toggle matrix: version: - '1.10'