Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: true
fail-fast: true # TODO: toggle
matrix:
version:
- '1.10'
Expand Down
104 changes: 55 additions & 49 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,61 +325,69 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Real,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return y, dx
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Complex,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pushforward(f, pushforward_prep, backend, x, (oneunit(x),), contexts...))
b = only(pushforward(f, pushforward_prep, backend, x, (im * oneunit(x),), contexts...))
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return y, dx
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Real},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
tx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
map(ty) do dy
dot(a, dy)
end
end
return y, dx
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
f::F,
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Complex},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dx = map(CartesianIndices(x)) do j
tx = map(CartesianIndices(x)) do j
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
b = only(
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
)
real(dot(a, dy)) + im * real(dot(b, dy))
map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return y, dx
return y, arroftup_to_tupofarr(tx)
end

function value_and_pullback(
Expand All @@ -392,13 +400,7 @@ function value_and_pullback(
) where {F, B, C}
check_prep(f, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
ys_and_tx = ntuple(
b -> _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty[b], contexts...),
Val(B),
)
y = first(first(ys_and_tx))
tx = map(last, ys_and_tx)
return y, tx
return _value_and_pullback_via_pushforward(f, pushforward_prep, backend, x, ty, contexts...)
end

function value_and_pullback!(
Expand Down Expand Up @@ -449,12 +451,14 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Real,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
dx = dot(a, dy)
return dx
tx = map(ty) do dy
dot(a, dy)
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -463,15 +467,17 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::Complex,
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
a = only(pushforward(f!, y, pushforward_prep, backend, x, (oneunit(x),), contexts...))
_, b = onlysecond(
value_and_pushforward(f!, y, pushforward_prep, backend, x, (im * oneunit(x),), contexts...)
)
dx = real(dot(a, dy)) + im * real(dot(b, dy))
return dx
tx = map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -480,14 +486,16 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Real},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
) where {F, B, C}
tx = map(CartesianIndices(x)) do j # preserve shape
_, a = onlysecond(value_and_pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
dot(a, dy)
map(ty) do dy
dot(a, dy)
end
end
return dx
return y, arroftup_to_tupofarr(tx)
end

function _value_and_pullback_via_pushforward(
Expand All @@ -496,19 +504,21 @@ function _value_and_pullback_via_pushforward(
pushforward_prep::PushforwardPrep,
backend::AbstractADType,
x::AbstractArray{<:Complex},
dy,
ty::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dx = map(CartesianIndices(x)) do j # preserve shape
) where {F, B, C}
tx = map(CartesianIndices(x)) do j # preserve shape
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
_, b = onlysecond(
value_and_pushforward(
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
),
)
real(dot(a, dy)) + im * real(dot(b, dy))
map(ty) do dy
real(dot(a, dy)) + im * real(dot(b, dy))
end
end
return dx
return y, arroftup_to_tupofarr(tx)
end

function value_and_pullback(
Expand All @@ -522,13 +532,9 @@ function value_and_pullback(
) where {F, B, C}
check_prep(f!, y, prep, backend, x, ty, contexts...)
(; pushforward_prep) = prep
tx = ntuple(
b -> _value_and_pullback_via_pushforward(
f!, y, pushforward_prep, backend, x, ty[b], contexts...
),
Val(B),
return _value_and_pullback_via_pushforward(
f!, y, pushforward_prep, backend, x, ty, contexts...
)
return y, tx
end

function value_and_pullback!(
Expand Down
83 changes: 42 additions & 41 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,14 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
dy = dot(a, dx)
return y, dy
ty = map(tx) do dx
dot(a, dx)
end
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -338,13 +340,15 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y, a = onlysecond(value_and_pullback(f, pullback_prep, backend, x, (oneunit(y_ex),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * oneunit(y_ex),), contexts...))
dy = real(dot(a, dx)) + im * real(dot(b, dx))
return y, dy
ty = map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -353,15 +357,17 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
ty = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
dot(a, dx)
map(tx) do dx
dot(a, dx)
end
end
return y, dy
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -370,16 +376,18 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
) where {F, B, C}
y = f(x, map(unwrap, contexts)...)
dy = map(CartesianIndices(y_ex)) do i
ty = map(CartesianIndices(y_ex)) do i
a = only(pullback(f, pullback_prep, backend, x, (basis(y_ex, i),), contexts...))
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y_ex, i),), contexts...))
real(dot(a, dx)) + im * real(dot(b, dx))
map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return y, dy
return y, arroftup_to_tupofarr(ty)
end

function value_and_pushforward(
Expand All @@ -392,13 +400,7 @@ function value_and_pushforward(
) where {F, B, C}
check_prep(f, prep, backend, x, tx, contexts...)
(; pullback_prep, y_example) = prep
ys_and_ty = ntuple(
b -> _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
y = first(first(ys_and_ty))
ty = map(last, ys_and_ty)
return y, ty
return _value_and_pushforward_via_pullback(y_example, f, pullback_prep, backend, x, tx, contexts...)
end

function value_and_pushforward!(
Expand Down Expand Up @@ -449,14 +451,16 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
) where {F, B, C}
ty = map(CartesianIndices(y)) do i # preserve shape
_, a = onlysecond(value_and_pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
dot(a, dx)
map(tx) do dx
dot(a, dx)
end
end
return dy
return y, arroftup_to_tupofarr(ty)
end

function _value_and_pushforward_via_pullback(
Expand All @@ -465,17 +469,19 @@ function _value_and_pushforward_via_pullback(
pullback_prep::PullbackPrep,
backend::AbstractADType,
x,
dx,
tx::NTuple{B},
contexts::Vararg{Context, C},
) where {F, C}
dy = map(CartesianIndices(y)) do i # preserve shape
) where {F, B, C}
ty = map(CartesianIndices(y)) do i # preserve shape
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
_, b = onlysecond(
value_and_pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
)
real(dot(a, dx)) + im * real(dot(b, dx))
map(tx) do dx
real(dot(a, dx)) + im * real(dot(b, dx))
end
end
return dy
return y, arroftup_to_tupofarr(ty)
end

function value_and_pushforward(
Expand All @@ -489,12 +495,7 @@ function value_and_pushforward(
) where {F, B, C}
check_prep(f!, y, prep, backend, x, tx, contexts...)
(; pullback_prep) = prep
ty = ntuple(
b ->
_value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx[b], contexts...),
Val(B),
)
return y, ty
return _value_and_pushforward_via_pullback(f!, y, pullback_prep, backend, x, tx, contexts...)
end

function value_and_pushforward!(
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,6 @@ The trivial dense fallback is designed to protect against a change of format in
get_pattern(M::AbstractMatrix) = trues(size(M))

# Pass the first component through unchanged and unwrap the second one with `only`,
# which requires it to contain exactly one element.
function onlysecond(pair)
    head, wrapped = pair
    return (head, only(wrapped))
end

"""
    arroftup_to_tupofarr(x)

Convert an array of equal-length tuples into a tuple of arrays.

- If `x` is already a tuple, it is returned unchanged.
- If `x` is an array whose elements are tuples of length `B`, return a tuple of `B` arrays,
  where the `b`-th array is obtained by broadcasting `getindex(_, b)` over `x` (so it has
  the same shape as `x` and holds the `b`-th component of every element).

The components of a tuple do not need to share a single type.
"""
arroftup_to_tupofarr(x::Tuple) = x
function arroftup_to_tupofarr(x::AbstractArray{<:NTuple{B, Any}}) where {B}
    # `Val(B)` passes the tuple length in the type domain so `ntuple` can unroll over it.
    return ntuple(b -> getindex.(x, b), Val(B))
end
Loading