diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index bcc7f30bf..1020665c9 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -166,7 +166,7 @@ jobs: actions: write contents: read strategy: - fail-fast: true + fail-fast: true # TODO: toggle matrix: version: - '1.10' diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 7b1f85fb9..333b12967 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -325,12 +325,14 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Real, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) - dx = dot(a, dy) - return y, dx + tx = map(ty) do dy + dot(a, dy) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -338,13 +340,15 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Complex, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...)) b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)) - dx = real(dot(a, dy)) + im * real(dot(b, dy)) - return y, dx + tx = map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -352,15 +356,17 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Real}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dx = map(CartesianIndices(x)) do j + tx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) - dot(a, dy) + map(ty) do dy + dot(a, dy) + end end - return y, dx + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -368,18 +374,20 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Complex}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dx = map(CartesianIndices(x)) do j + tx = map(CartesianIndices(x)) do j a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)) b = only( pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...), ) - real(dot(a, dy)) + im * real(dot(b, dy)) + map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end end - return y, dx + return y, arroftup_to_tupofarr(tx) end function value_and_pullback( @@ -392,13 +400,7 @@ function value_and_pullback( ) where {F, B, C} check_prep(f, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep - ys_and_tx = ntuple( - b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...), - Val(B), - ) - y = first(first(ys_and_tx)) - tx = map(last, ys_and_tx) - return y, tx + return _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty, contexts...) end function value_and_pullback!( @@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Real, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) - dx = dot(a, dy) - return dx + tx = map(ty) do dy + dot(a, dy) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::Complex, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...)) _, b = onlysecond( value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...) ) - dx = real(dot(a, dy)) + im * real(dot(b, dy)) - return dx + tx = map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Real}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dx = map(CartesianIndices(x)) do j # preserve shape + ) where {F, B, C} + tx = map(CartesianIndices(x)) do j # preserve shape _, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) - dot(a, dy) + map(ty) do dy + dot(a, dy) + end end - return dx + return y, arroftup_to_tupofarr(tx) end function _value_and_pullback_via_pushforward( @@ -496,19 +504,21 @@ function _value_and_pullback_via_pushforward( pushforward_prep::PushforwardPrep, backend::AbstractADType, x::AbstractArray{<:Complex}, - dy, + ty::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dx = map(CartesianIndices(x)) do j # preserve shape + ) where {F, B, C} + tx = map(CartesianIndices(x)) do j # preserve shape a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)) _, b = onlysecond( value_and_pushforward( f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts... ), ) - real(dot(a, dy)) + im * real(dot(b, dy)) + map(ty) do dy + real(dot(a, dy)) + im * real(dot(b, dy)) + end end - return dx + return y, arroftup_to_tupofarr(tx) end function value_and_pullback( @@ -522,13 +532,9 @@ function value_and_pullback( ) where {F, B, C} check_prep(f!, y, prep, backend, x, ty, contexts...) (; pushforward_prep) = prep - tx = ntuple( - b -> _value_and_pullback_via_pushforward( - f!, y, pushforward_prep, backend, x, ty[b], contexts... - ), - Val(B), + return _value_and_pullback_via_pushforward( + f!, y, pushforward_prep, backend, x, ty, contexts... ) - return y, tx end function value_and_pullback!( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 295f98145..e49203615 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -324,12 +324,14 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) - dy = dot(a, dx) - return y, dy + ty = map(tx) do dx + dot(a, dx) + end + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -338,13 +340,15 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...)) - dy = real(dot(a, dx)) + im * real(dot(b, dx)) - return y, dy + ty = map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -353,15 +357,17 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dy = map(CartesianIndices(y_ex)) do i + ty = map(CartesianIndices(y_ex)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) - dot(a, dx) + map(tx) do dx + dot(a, dx) + end end - return y, dy + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -370,16 +376,18 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} + ) where {F, B, C} y = f(x, map(unwrap, contexts)...) - dy = map(CartesianIndices(y_ex)) do i + ty = map(CartesianIndices(y_ex)) do i a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...)) b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...)) - real(dot(a, dx)) + im * real(dot(b, dx)) + map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end end - return y, dy + return y, arroftup_to_tupofarr(ty) end function value_and_pushforward( @@ -392,13 +400,7 @@ function value_and_pushforward( ) where {F, B, C} check_prep(f, prep, backend, x, tx, contexts...) (; pullback_prep, y_example) = prep - ys_and_ty = ntuple( - b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...), - Val(B), - ) - y = first(first(ys_and_ty)) - ty = map(last, ys_and_ty) - return y, ty + return _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx, contexts...) end function value_and_pushforward!( @@ -449,14 +451,16 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dy = map(CartesianIndices(y)) do i # preserve shape + ) where {F, B, C} + ty = map(CartesianIndices(y)) do i # preserve shape _, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) - dot(a, dx) + map(tx) do dx + dot(a, dx) + end end - return dy + return y, arroftup_to_tupofarr(ty) end function _value_and_pushforward_via_pullback( @@ -465,17 +469,19 @@ function _value_and_pushforward_via_pullback( pullback_prep::PullbackPrep, backend::AbstractADType, x, - dx, + tx::NTuple{B}, contexts::Vararg{Context, C}, - ) where {F, C} - dy = map(CartesianIndices(y)) do i # preserve shape + ) where {F, B, C} + ty = map(CartesianIndices(y)) do i # preserve shape a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)) _, b = onlysecond( value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...) ) - real(dot(a, dx)) + im * real(dot(b, dx)) + map(tx) do dx + real(dot(a, dx)) + im * real(dot(b, dx)) + end end - return dy + return y, arroftup_to_tupofarr(ty) end function value_and_pushforward( @@ -489,12 +495,7 @@ function value_and_pushforward( ) where {F, B, C} check_prep(f!, y, prep, backend, x, tx, contexts...) (; pullback_prep) = prep - ty = ntuple( - b -> - _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...), - Val(B), - ) - return y, ty + return _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx, contexts...) end function value_and_pushforward!( diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index 9b952ec2c..a5fb15cab 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -35,3 +35,6 @@ The trivial dense fallback is designed to protect against a change of format in get_pattern(M::AbstractMatrix) = trues(size(M)) onlysecond((a, b)) = (a, only(b)) + +arroftup_to_tupofarr(x::NTuple) = x +arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B}}) where {B} = ntuple(b -> getindex.(x, b), Val(B))