Skip to content

Commit 77e3736

Browse files
committed
don't propagate NaN/Inf if Partials are zero
1 parent 43ee7d8 commit 77e3736

4 files changed

Lines changed: 42 additions & 18 deletions

File tree

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ script:
1010
- if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
1111
- if (julia -e 'VERSION < v"0.5" && exit(1)'); then
1212
julia -e 'include(joinpath(JULIA_HOME, Base.DATAROOTDIR, "julia", "build_sysimg.jl")); build_sysimg(force=true)';
13-
julia -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.test("ForwardDiff"; coverage=true)';
13+
julia -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.checkout("DiffBase"); Pkg.test("ForwardDiff"; coverage=true)';
1414
julia -O3 -e 'include(joinpath(Pkg.dir("ForwardDiff"), "test/SIMDTest.jl"))';
1515
else
16-
julia -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.test("ForwardDiff"; coverage=true)';
16+
julia -e 'Pkg.clone(pwd()); Pkg.build("ForwardDiff"); Pkg.checkout("DiffBase"); Pkg.test("ForwardDiff"; coverage=true)';
1717
fi
1818
after_success:
1919
- julia -e 'cd(Pkg.dir("ForwardDiff")); Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'

src/partials.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,29 @@ Base.convert{N,T}(::Type{Partials{N,T}}, partials::Partials{N,T}) = partials
7575
@inline @compat(Base.:+){N}(a::Partials{N}, b::Partials{N}) = Partials(add_tuples(a.values, b.values))
7676
@inline @compat(Base.:-){N}(a::Partials{N}, b::Partials{N}) = Partials(sub_tuples(a.values, b.values))
7777
@inline @compat(Base.:-)(partials::Partials) = Partials(minus_tuple(partials.values))
78-
@inline @compat(Base.:*)(partials::Partials, x::Real) = Partials(scale_tuple(partials.values, x))
7978
@inline @compat(Base.:*)(x::Real, partials::Partials) = partials*x
80-
@inline @compat(Base.:/)(partials::Partials, x::Real) = Partials(div_tuple_by_scalar(partials.values, x))
8179

82-
@inline function _mul_partials{N}(a::Partials{N}, b::Partials{N}, afactor, bfactor)
83-
return Partials(mul_tuples(a.values, b.values, afactor, bfactor))
80+
# NaN/Inf-safe methods #
81+
#----------------------#
82+
83+
@inline function @compat(Base.:*)(partials::Partials, x::Real)
84+
x = ifelse((isnan(x) || isinf(x)) && iszero(partials), one(x), x)
85+
return Partials(scale_tuple(partials.values, x))
86+
end
87+
88+
@inline function @compat(Base.:/)(partials::Partials, x::Real)
89+
x = ifelse(x == zero(x) && iszero(partials), one(x), x)
90+
return Partials(div_tuple_by_scalar(partials.values, x))
91+
end
92+
93+
@inline function _mul_partials{N}(a::Partials{N}, b::Partials{N}, x_a, x_b)
94+
x_a = ifelse((isnan(x_a) || isinf(x_a)) && iszero(a), one(x_a), x_a)
95+
x_b = ifelse((isnan(x_b) || isinf(x_b)) && iszero(b), one(x_b), x_b)
96+
return Partials(mul_tuples(a.values, b.values, x_a, x_b))
8497
end
8598

8699
@inline function _div_partials(a::Partials, b::Partials, aval, bval)
87-
afactor = inv(bval)
88-
bfactor = -aval/(bval*bval)
89-
return _mul_partials(a, b, afactor, bfactor)
100+
return _mul_partials(a, b, inv(bval), -(aval / (bval*bval)))
90101
end
91102

92103
# edge cases where N == 0 #

test/JacobianTest.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ for f in DiffBase.ARRAY_TO_ARRAY_FUNCS
9999
out = ForwardDiff.jacobian(f, X, cfg)
100100
@test_approx_eq out j
101101

102-
out = similar(X, length(X), length(X))
102+
out = similar(X, length(v), length(X))
103103
ForwardDiff.jacobian!(out, f, X, cfg)
104104
@test_approx_eq out j
105105

106-
out = DiffBase.JacobianResult(X)
106+
out = DiffBase.DiffResult(similar(v, length(v)), similar(v, length(v), length(X)))
107107
ForwardDiff.jacobian!(out, f, X, cfg)
108108
@test_approx_eq DiffBase.value(out) v
109109
@test_approx_eq DiffBase.jacobian(out) j

test/PartialsTest.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ samerng() = MersenneTwister(1)
99
for N in (0, 3), T in (Int, Float32, Float64)
1010
println(" ...testing Partials{$N,$T}")
1111

12-
VALUES = ntuple(n -> rand(T), Val{N})
12+
VALUES = (rand(T,N)...)
1313
PARTIALS = Partials{N,T}(VALUES)
1414

15-
VALUES2 = ntuple(n -> rand(T), Val{N})
15+
VALUES2 = (rand(T,N)...)
1616
PARTIALS2 = Partials{N,T}(VALUES2)
1717

1818
##############################
@@ -70,7 +70,7 @@ for N in (0, 3), T in (Int, Float32, Float64)
7070
@test hash(PARTIALS, hash(1)) == hash(copy(PARTIALS), hash(1))
7171
@test hash(PARTIALS, hash(1)) == hash(copy(PARTIALS), hash(1))
7272

73-
const TMPIO = IOBuffer()
73+
TMPIO = IOBuffer()
7474
write(TMPIO, PARTIALS)
7575
seekstart(TMPIO)
7676
@test read(TMPIO, typeof(PARTIALS)) == PARTIALS
@@ -84,8 +84,8 @@ for N in (0, 3), T in (Int, Float32, Float64)
8484
# Conversion/Promotion #
8585
########################
8686

87-
const WIDE_T = widen(T)
88-
const WIDE_PARTIALS = convert(Partials{N,WIDE_T}, PARTIALS)
87+
WIDE_T = widen(T)
88+
WIDE_PARTIALS = convert(Partials{N,WIDE_T}, PARTIALS)
8989

9090
@test typeof(WIDE_PARTIALS) == Partials{N,WIDE_T}
9191
@test WIDE_PARTIALS == PARTIALS
@@ -101,8 +101,8 @@ for N in (0, 3), T in (Int, Float32, Float64)
101101
@test (PARTIALS - PARTIALS).values == map(v -> v - v, VALUES)
102102
@test getfield(-(PARTIALS), :values) == map(-, VALUES)
103103

104-
const X = rand()
105-
const Y = rand()
104+
X = rand()
105+
Y = rand()
106106

107107
@test X * PARTIALS == PARTIALS * X
108108
@test (X * PARTIALS).values == map(v -> X * v, VALUES)
@@ -111,6 +111,19 @@ for N in (0, 3), T in (Int, Float32, Float64)
111111
if N > 0
112112
@test ForwardDiff._mul_partials(PARTIALS, PARTIALS2, X, Y).values == map((a, b) -> (X * a) + (Y * b), VALUES, VALUES2)
113113
@test ForwardDiff._div_partials(PARTIALS, PARTIALS2, X, Y) == ForwardDiff._mul_partials(PARTIALS, PARTIALS2, inv(Y), -X/(Y^2))
114+
115+
ZEROS = Partials((zeros(T, N)...))
116+
117+
@test (NaN * ZEROS).values == ZEROS.values
118+
@test (Inf * ZEROS).values == ZEROS.values
119+
@test (ZEROS / 0).values == ZEROS.values
120+
121+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, NaN).values == ZEROS.values
122+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, X).values == ZEROS.values
123+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, X, Inf).values == ZEROS.values
124+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, X).values == ZEROS.values
125+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, Inf, NaN).values == ZEROS.values
126+
@test ForwardDiff._mul_partials(ZEROS, ZEROS, NaN, Inf).values == ZEROS.values
114127
end
115128
end
116129

0 commit comments

Comments
 (0)