Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FiniteDiff"
uuid = "6a86dc24-6348-571c-b903-95158fe2bd41"
version = "2.31.0"
version = "2.31.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -37,4 +37,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "SparseArrays", "StaticArrays"]
test = [
"Test",
"BlockBandedMatrices",
"BandedMatrices",
"Pkg",
"SafeTestsets",
"SparseArrays",
"StaticArrays",
]
104 changes: 56 additions & 48 deletions src/jvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and `v` is a vector.
- `x1::X1`: Temporary array for perturbed input values
- `fx1::FX1`: Temporary array for function evaluations
"""
mutable struct JVPCache{X1, FX1, FDType}
struct JVPCache{X1,FX1,FDType}
x1::X1
fx1::FX1
end
Expand All @@ -38,10 +38,11 @@ cache = JVPCache(x, Val(:forward))
```
"""
function JVPCache(
x,
fdtype::Union{Val{FD}, Type{FD}} = Val(:forward)) where {FD}
x,
fdtype::Union{Val{FD},Type{FD}}=Val(:forward)
) where {FD}
fdtype isa Type && (fdtype = fdtype())
JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x))
return JVPCache{typeof(x),typeof(x),fdtype}(copy(x), copy(x))
end

"""
Expand Down Expand Up @@ -72,11 +73,12 @@ The arrays `x` and `fx1` will be modified during JVP computations. Ensure they
are not used elsewhere if their values need to be preserved.
"""
function JVPCache(
x,
fx,
fdtype::Union{Val{FD}, Type{FD}} = Val(:forward)) where {FD}
x,
fx,
fdtype::Union{Val{FD},Type{FD}}=Val(:forward)
) where {FD}
fdtype isa Type && (fdtype = fdtype())
JVPCache{typeof(x), typeof(fx), fdtype}(x, fx)
return JVPCache{typeof(x),typeof(fx),fdtype}(x, fx)
end

"""
Expand Down Expand Up @@ -130,19 +132,21 @@ where `h` is the step size and `v` is the direction vector.
- Central differences: 2 function evaluations, `O(h²)` accuracy
- Particularly efficient when `v` is sparse or when only one directional derivative is needed
"""
function finite_difference_jvp(f, x, v,
fdtype = Val(:forward),
f_in = nothing;
relstep = default_relstep(fdtype, eltype(x)),
absstep = relstep,
dir = true)
function finite_difference_jvp(
f, x, v,
fdtype=Val(:forward),
f_in=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
dir=true
)
if f_in isa Nothing
fx = f(x)
else
fx = f_in
end
cache = JVPCache(x, fx, fdtype)
finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir)
return finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir)
end

"""
Expand All @@ -157,14 +161,15 @@ end
Cached.
"""
function finite_difference_jvp(
f,
x,
v,
cache::JVPCache{X1, FX1, fdtype},
f_in = nothing;
relstep = default_relstep(fdtype, eltype(x)),
absstep = relstep,
dir = true) where {X1, FX1, fdtype}
f,
x,
v,
cache::JVPCache{X1,FX1,fdtype},
f_in=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
dir=true
) where {X1,FX1,fdtype}
if fdtype == Val(:complex)
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
end
Expand All @@ -175,17 +180,17 @@ function finite_difference_jvp(
fx = f_in isa Nothing ? f(x) : f_in
x1 = @. x + epsilon * v
fx1 = f(x1)
fx1 = @. (fx1-fx)/epsilon
fx1 = @. (fx1 - fx) / epsilon
elseif fdtype == Val(:central)
x1 = @. x + epsilon * v
fx1 = f(x1)
x1 = @. x - epsilon * v
fx = f(x1)
fx1 = @. (fx1-fx)/(2epsilon)
fx1 = @. (fx1 - fx) / (2epsilon)
else
fdtype_error(eltype(x))
end
fx1
return fx1
end

"""
Expand All @@ -202,14 +207,16 @@ end

Cache-less.
"""
function finite_difference_jvp!(jvp,
f,
x,
v,
fdtype = Val(:forward),
f_in = nothing;
relstep = default_relstep(fdtype, eltype(x)),
absstep = relstep)
function finite_difference_jvp!(
jvp,
f,
x,
v,
fdtype=Val(:forward),
f_in=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep
)
if !isnothing(f_in)
cache = JVPCache(x, f_in, fdtype)
elseif fdtype == Val(:forward)
Expand All @@ -219,7 +226,7 @@ function finite_difference_jvp!(jvp,
else
cache = JVPCache(x, fdtype)
end
finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep)
return finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep)
end

"""
Expand All @@ -236,15 +243,16 @@ end
Cached.
"""
function finite_difference_jvp!(
jvp,
f,
x,
v,
cache::JVPCache{X1, FX1, fdtype},
f_in = nothing;
relstep = default_relstep(fdtype, eltype(x)),
absstep = relstep,
dir = true) where {X1, FX1, fdtype}
jvp,
f,
x,
v,
cache::JVPCache{X1,FX1,fdtype},
f_in=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
dir=true
) where {X1,FX1,fdtype}
if fdtype == Val(:complex)
ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff")
end
Expand All @@ -260,21 +268,21 @@ function finite_difference_jvp!(
end
@. x1 = x + epsilon * v
f(jvp, x1)
@. jvp = (jvp-fx1)/epsilon
@. jvp = (jvp - fx1) / epsilon
elseif fdtype == Val(:central)
@. x1 = x - epsilon * v
f(fx1, x1)
@. x1 = x + epsilon * v
f(jvp, x1)
@. jvp = (jvp-fx1)/(2epsilon)
@. jvp = (jvp - fx1) / (2epsilon)
else
fdtype_error(eltype(x))
end
nothing
return nothing
end

function resize!(cache::JVPCache, i::Int)
resize!(cache.x1, i)
cache.fx1 !== nothing && resize!(cache.fx1, i)
nothing
return nothing
end
Loading