Skip to content

Commit 7f81bef

Browse files
committed
more lenient Normal/Exponential constructors
1 parent dda69ba commit 7f81bef

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

src/distributions.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,12 @@ struct Normalμσ{T} <: Normal{T}
137137
σ::T
138138
end
139139

140-
Normal(::Type{T}=Float64) where {T} = Normal01{T}()
141-
Normal::T, σ::T) where {T} = Normalμσ(μ, σ)
140+
const NormalTypes = Union{AbstractFloat,Complex{<:AbstractFloat}}
141+
142+
Normal(::Type{T}=Float64) where {T<:NormalTypes} = Normal01{T}()
143+
Normal::T, σ::T) where {T<:NormalTypes} = Normalμσ(μ, σ)
144+
Normal::T, σ::T) where {T<:Real} = Normalμσ(AbstractFloat(μ), AbstractFloat(σ))
145+
Normal(μ, σ) = Normal(promote(μ, σ)...)
142146

143147
abstract type Exponential{T} <: Distribution{T} end
144148

@@ -150,6 +154,7 @@ end
150154

151155
Exponential(::Type{T}=Float64) where {T<:AbstractFloat} = Exponential1{T}()
152156
Exponential::T) where {T<:AbstractFloat} = Exponentialθ(θ)
157+
Exponential::Real) = Exponentialθ(AbstractFloat(θ))
153158

154159

155160
## floats

src/sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ end
116116

117117
## Normal & Exponential
118118

119-
rand(rng::AbstractRNG, ::SamplerTrivial{Normal01{T}}) where {T<:Union{AbstractFloat,Complex{<:AbstractFloat}}} =
119+
rand(rng::AbstractRNG, ::SamplerTrivial{Normal01{T}}) where {T<:NormalTypes} =
120120
randn(rng, T)
121121

122122
Sampler(RNG::Type{<:AbstractRNG}, d::Normalμσ{T}, n::Repetition) where {T} =

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ using Test
66
# Normal/Exponential
77
@test rand(Normal()) isa Float64
88
@test rand(Normal(0.0, 1.0)) isa Float64
9+
@test rand(Normal(0, 1)) isa Float64
10+
@test rand(Normal(0, 1.0)) isa Float64
911
@test rand(Exponential()) isa Float64
1012
@test rand(Exponential(1.0)) isa Float64
13+
@test rand(Exponential(1)) isa Float64
1114
@test rand(Normal(Float32)) isa Float32
1215
@test rand(Exponential(Float32)) isa Float32
1316
@test rand(Normal(ComplexF64)) isa ComplexF64

0 commit comments

Comments
 (0)