Avoid MArrays with FiniteDiff backend#1019
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests.
Additional details and impacted files@@ Coverage Diff @@
## main #1019 +/- ##
===========================================
- Coverage 98.14% 87.60% -10.54%
===========================================
Files 138 135 -3
Lines 8131 8078 -53
===========================================
- Hits 7980 7077 -903
- Misses 151 1001 +850
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
|
Nice catch! The alternative would be replacing |
|
Yes, using FiniteDiff
using DifferentiationInterface
using StaticArrays
function doit()
backend = AutoFiniteDiff()
x = @SVector [1.0, 2.0]
tx = (2.0.*x,)
f(x) = @. 3.0 * x
# prep = DifferentiationInterface.prepare_pushforward(f, backend, x, tx);
prep = DifferentiationInterface.prepare_pushforward_nokwarg(Val(true), f, backend, x, tx)
return prep
endWhen function DI.prepare_pushforward_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context, C}
) where {C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
cache = if x isa Number || y isa Number
nothing
else
JVPCache(copy(x), y, fdtype(backend))
end
relstep = if isnothing(backend.relstep)
default_relstep(fdtype(backend), eltype(x))
else
backend.relstep
end
absstep = if isnothing(backend.absstep)
relstep
else
backend.absstep
end
dir = backend.dir
return FiniteDiffOneArgPushforwardPrep(_sig, cache, relstep, absstep, dir)
endI see these allocations: julia> @allocated prep = doit()
18134560
julia> @allocated prep = doit()
96
julia> (I assume the first call includes the allocations associated with compiling.) Then, if I add those function DI.prepare_pushforward_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context, C}
) where {C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
cache = if x isa Number || y isa Number || (! DI.ismutable_array(x)) || (! DI.ismutable_array(y))
nothing
else
JVPCache(copy(x), y, fdtype(backend))
end
relstep = if isnothing(backend.relstep)
default_relstep(fdtype(backend), eltype(x))
else
backend.relstep
end
absstep = if isnothing(backend.absstep)
relstep
else
backend.absstep
end
dir = backend.dir
return FiniteDiffOneArgPushforwardPrep(_sig, cache, relstep, absstep, dir)
endI see this in the REPL: julia> @allocated prep = doit()
751888
julia> @allocated prep = doit()
0
julia> Here's another example that isolates the julia> function doit2()
x = @SVector [1.0, 2.0]
y = @SVector [2.0, 3.0]
foo = Val(:forward)
cache = FiniteDiff.JVPCache(x, y, foo)
return cache
end
doit2 (generic function with 1 method)
julia> @allocated cache = doit2()
882800
julia> @allocated cache = doit2()
48
julia> |
|
I tried implementing an upstream fix in JuliaDiff/FiniteDiff.jl#216 |
|
@gdalle Excellent, thanks! That fixes the allocations I was seeing with this small example: julia> function doit2()
x = @SVector [1.0, 2.0]
y = @SVector [2.0, 3.0]
foo = Val(:forward)
cache = FiniteDiff.JVPCache(x, y, foo)
return cache
end
doit2 (generic function with 1 method)
julia> @allocated cache = doit2()
882800
julia> @allocated cache = doit2()
48
julia> To avoid the allocations with this example using FiniteDiff
using DifferentiationInterface
using StaticArrays
function doit()
backend = AutoFiniteDiff()
x = @SVector [1.0, 2.0]
tx = (2.0.*x,)
f(x) = @. 3.0 * x
# prep = DifferentiationInterface.prepare_pushforward(f, backend, x, tx);
prep = DifferentiationInterface.prepare_pushforward_nokwarg(Val(true), f, backend, x, tx)
return prep
endit looks like we'd need to replace the Interestingly, it looks like |
I noticed some allocations when using FiniteDiff with StaticArrays that I eventually traced to calls to
similarin the FiniteDiff extension. Adding the check for immutable input or output arrays seemed to fix this. Haven't attempted to add tests for this—wanted to get some feedback before diving in. What do you think?