From 632791c88dfffb8b10119807d0d341f1bc109c48 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 26 May 2026 09:35:10 +0200 Subject: [PATCH 1/2] Make `JVPCache` an immutable struct --- Project.toml | 12 ++++++++++-- src/jvp.jl | 48 ++++++++++++++++++++++++++++-------------------- 2 files changed, 38 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index ff8e51c..5278b06 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDiff" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.31.0" +version = "2.31.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -37,4 +37,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "BlockBandedMatrices", "BandedMatrices", "Pkg", "SafeTestsets", "SparseArrays", "StaticArrays"] +test = [ + "Test", + "BlockBandedMatrices", + "BandedMatrices", + "Pkg", + "SafeTestsets", + "SparseArrays", + "StaticArrays", +] diff --git a/src/jvp.jl b/src/jvp.jl index fa940aa..08e5833 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -11,7 +11,7 @@ and `v` is a vector. - `x1::X1`: Temporary array for perturbed input values - `fx1::FX1`: Temporary array for function evaluations """ -mutable struct JVPCache{X1, FX1, FDType} +struct JVPCache{X1, FX1, FDType} x1::X1 fx1::FX1 end @@ -39,9 +39,10 @@ cache = JVPCache(x, Val(:forward)) """ function JVPCache( x, - fdtype::Union{Val{FD}, Type{FD}} = Val(:forward)) where {FD} + fdtype::Union{Val{FD}, Type{FD}} = Val(:forward) + ) where {FD} fdtype isa Type && (fdtype = fdtype()) - JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x)) + return JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x)) end """ @@ -74,9 +75,10 @@ are not used elsewhere if their values need to be preserved. function JVPCache( x, fx, - fdtype::Union{Val{FD}, Type{FD}} = Val(:forward)) where {FD} + fdtype::Union{Val{FD}, Type{FD}} = Val(:forward) + ) where {FD} fdtype isa Type && (fdtype = fdtype()) - JVPCache{typeof(x), typeof(fx), fdtype}(x, fx) + return JVPCache{typeof(x), typeof(fx), fdtype}(x, fx) end """ @@ -130,19 +132,21 @@ where `h` is the step size and `v` is the direction vector. - Central differences: 2 function evaluations, `O(h²)` accuracy - Particularly efficient when `v` is sparse or when only one directional derivative is needed """ -function finite_difference_jvp(f, x, v, +function finite_difference_jvp( + f, x, v, fdtype = Val(:forward), f_in = nothing; relstep = default_relstep(fdtype, eltype(x)), absstep = relstep, - dir = true) + dir = true + ) if f_in isa Nothing fx = f(x) else fx = f_in end cache = JVPCache(x, fx, fdtype) - finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir) + return finite_difference_jvp(f, x, v, cache, fx; relstep, absstep, dir) end """ @@ -164,7 +168,8 @@ function finite_difference_jvp( f_in = nothing; relstep = default_relstep(fdtype, eltype(x)), absstep = relstep, - dir = true) where {X1, FX1, fdtype} + dir = true + ) where {X1, FX1, fdtype} if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end @@ -175,17 +180,17 @@ function finite_difference_jvp( fx = f_in isa Nothing ? f(x) : f_in x1 = @. x + epsilon * v fx1 = f(x1) - fx1 = @. (fx1-fx)/epsilon + fx1 = @. (fx1 - fx) / epsilon elseif fdtype == Val(:central) x1 = @. x + epsilon * v fx1 = f(x1) x1 = @. x - epsilon * v fx = f(x1) - fx1 = @. (fx1-fx)/(2epsilon) + fx1 = @. (fx1 - fx) / (2epsilon) else fdtype_error(eltype(x)) end - fx1 + return fx1 end """ @@ -202,14 +207,16 @@ end Cache-less. """ -function finite_difference_jvp!(jvp, +function finite_difference_jvp!( + jvp, f, x, v, fdtype = Val(:forward), f_in = nothing; relstep = default_relstep(fdtype, eltype(x)), - absstep = relstep) + absstep = relstep + ) if !isnothing(f_in) cache = JVPCache(x, f_in, fdtype) elseif fdtype == Val(:forward) @@ -219,7 +226,7 @@ function finite_difference_jvp!(jvp, else cache = JVPCache(x, fdtype) end - finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep) + return finite_difference_jvp!(jvp, f, x, v, cache, cache.fx1; relstep, absstep) end """ @@ -244,7 +251,8 @@ function finite_difference_jvp!( f_in = nothing; relstep = default_relstep(fdtype, eltype(x)), absstep = relstep, - dir = true) where {X1, FX1, fdtype} + dir = true + ) where {X1, FX1, fdtype} if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end @@ -260,21 +268,21 @@ function finite_difference_jvp!( end @. x1 = x + epsilon * v f(jvp, x1) - @. jvp = (jvp-fx1)/epsilon + @. jvp = (jvp - fx1) / epsilon elseif fdtype == Val(:central) @. x1 = x - epsilon * v f(fx1, x1) @. x1 = x + epsilon * v f(jvp, x1) - @. jvp = (jvp-fx1)/(2epsilon) + @. jvp = (jvp - fx1) / (2epsilon) else fdtype_error(eltype(x)) end - nothing + return nothing end function resize!(cache::JVPCache, i::Int) resize!(cache.x1, i) cache.fx1 !== nothing && resize!(cache.fx1, i) - nothing + return nothing end From 79ffa230cf1cb937b074f6dddf3309dc6bb67d7d Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 26 May 2026 09:37:04 +0200 Subject: [PATCH 2/2] Proper formatting --- src/jvp.jl | 90 +++++++++++++++++++++++++++--------------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/src/jvp.jl b/src/jvp.jl index 08e5833..821e6ac 100644 --- a/src/jvp.jl +++ b/src/jvp.jl @@ -11,7 +11,7 @@ and `v` is a vector. - `x1::X1`: Temporary array for perturbed input values - `fx1::FX1`: Temporary array for function evaluations """ -struct JVPCache{X1, FX1, FDType} +struct JVPCache{X1,FX1,FDType} x1::X1 fx1::FX1 end @@ -38,11 +38,11 @@ cache = JVPCache(x, Val(:forward)) ``` """ function JVPCache( - x, - fdtype::Union{Val{FD}, Type{FD}} = Val(:forward) - ) where {FD} + x, + fdtype::Union{Val{FD},Type{FD}}=Val(:forward) +) where {FD} fdtype isa Type && (fdtype = fdtype()) - return JVPCache{typeof(x), typeof(x), fdtype}(copy(x), copy(x)) + return JVPCache{typeof(x),typeof(x),fdtype}(copy(x), copy(x)) end """ @@ -73,12 +73,12 @@ The arrays `x` and `fx1` will be modified during JVP computations. Ensure they are not used elsewhere if their values need to be preserved. """ function JVPCache( - x, - fx, - fdtype::Union{Val{FD}, Type{FD}} = Val(:forward) - ) where {FD} + x, + fx, + fdtype::Union{Val{FD},Type{FD}}=Val(:forward) +) where {FD} fdtype isa Type && (fdtype = fdtype()) - return JVPCache{typeof(x), typeof(fx), fdtype}(x, fx) + return JVPCache{typeof(x),typeof(fx),fdtype}(x, fx) end """ @@ -133,13 +133,13 @@ where `h` is the step size and `v` is the direction vector. - Particularly efficient when `v` is sparse or when only one directional derivative is needed """ function finite_difference_jvp( - f, x, v, - fdtype = Val(:forward), - f_in = nothing; - relstep = default_relstep(fdtype, eltype(x)), - absstep = relstep, - dir = true - ) + f, x, v, + fdtype=Val(:forward), + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true +) if f_in isa Nothing fx = f(x) else @@ -161,15 +161,15 @@ end Cached. """ function finite_difference_jvp( - f, - x, - v, - cache::JVPCache{X1, FX1, fdtype}, - f_in = nothing; - relstep = default_relstep(fdtype, eltype(x)), - absstep = relstep, - dir = true - ) where {X1, FX1, fdtype} + f, + x, + v, + cache::JVPCache{X1,FX1,fdtype}, + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true +) where {X1,FX1,fdtype} if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end @@ -208,15 +208,15 @@ end Cache-less. """ function finite_difference_jvp!( - jvp, - f, - x, - v, - fdtype = Val(:forward), - f_in = nothing; - relstep = default_relstep(fdtype, eltype(x)), - absstep = relstep - ) + jvp, + f, + x, + v, + fdtype=Val(:forward), + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep +) if !isnothing(f_in) cache = JVPCache(x, f_in, fdtype) elseif fdtype == Val(:forward) @@ -243,16 +243,16 @@ end Cached. """ function finite_difference_jvp!( - jvp, - f, - x, - v, - cache::JVPCache{X1, FX1, fdtype}, - f_in = nothing; - relstep = default_relstep(fdtype, eltype(x)), - absstep = relstep, - dir = true - ) where {X1, FX1, fdtype} + jvp, + f, + x, + v, + cache::JVPCache{X1,FX1,fdtype}, + f_in=nothing; + relstep=default_relstep(fdtype, eltype(x)), + absstep=relstep, + dir=true +) where {X1,FX1,fdtype} if fdtype == Val(:complex) ArgumentError("finite_difference_jvp doesn't support :complex-mode finite diff") end