From df8f22ec7deaf5cc0e88091966d71ed9551fe8c5 Mon Sep 17 00:00:00 2001
From: Niklas Heim
Date: Mon, 6 Apr 2020 21:38:07 +0200
Subject: [PATCH 1/4] add ardnet and simple fluxdecoder

---
 src/GenerativeModels.jl                       |  3 +-
 src/models/ardnet.jl                          | 47 +++++++++++++++++++
 .../{flux_ode_decoder.jl => flux_decoders.jl} | 25 +++++++++-
 3 files changed, 73 insertions(+), 2 deletions(-)
 create mode 100644 src/models/ardnet.jl
 rename src/utils/{flux_ode_decoder.jl => flux_decoders.jl} (77%)

diff --git a/src/GenerativeModels.jl b/src/GenerativeModels.jl
index bdce672..e255a7b 100644
--- a/src/GenerativeModels.jl
+++ b/src/GenerativeModels.jl
@@ -24,7 +24,7 @@ module GenerativeModels
     abstract type AbstractVAE <: AbstractGM end
     abstract type AbstractGAN <: AbstractGM end
 
-    include(joinpath("utils", "flux_ode_decoder.jl"))
+    include(joinpath("utils", "flux_decoders.jl"))
     include(joinpath("utils", "saveload.jl"))
     include(joinpath("utils", "utils.jl"))
 
@@ -32,5 +32,6 @@ module GenerativeModels
     include(joinpath("models", "rodent.jl"))
     include(joinpath("models", "gan.jl"))
     include(joinpath("models", "vamp.jl"))
+    include(joinpath("models", "ardnet.jl"))
 
 end # module
diff --git a/src/models/ardnet.jl b/src/models/ardnet.jl
new file mode 100644
index 0000000..8e7a471
--- /dev/null
+++ b/src/models/ardnet.jl
@@ -0,0 +1,47 @@
+export ARDNet
+
+const ACG = ConditionalDists.ACGaussian
+const FDCMeanGaussian = CMeanGaussian{V,<:FluxDecoder} where V
+
+"""
+    ARDNet{P<:Gaussian,E<:Gaussian,D<:ACGaussian,H<:TuringInverseWishart}
+
+Generative model that imposes the sparsifying ARD (*Automatic Relevance
+Determination*) prior on the weights of the decoder mapping:
+
+p(x|z) = N(x|ϕ(z),σx²)
+p(z) = N(z|0,diag(λz²))
+p(λz) = iW(λ|ψ,v)
+
+where the posterior on z is a multivariate Gaussian
+q(z|x) = N(z|μz,σz²)
+"""
+struct ARDNet{P<:Gaussian,E<:Gaussian,D<:ACG,H<:InverseGamma} <: AbstractGM
+    prior::P
+    hyperprior::H
+    encoder::E
+    decoder::D
+end
+
+Flux.@functor ARDNet
+
+function ConditionalDists.logpdf(p::ACG, x::AbstractArray{T}, z::AbstractArray{T},
+                                 ps::AbstractVector{T}) where T
+    μ = mean(p, z, ps)
+    σ2 = var(p, z)
+    d = x - μ
+    y = d .* d
+    y = (1 ./ σ2) .* y .+ log.(σ2) .+ T(log(2π))
+    -sum(y, dims=1) / 2
+end
+
+ConditionalDists.mean(p::FDCMeanGaussian, z::AbstractArray, ps::AbstractVector) =
+    p.mapping(z, ps)
+
+function elbo(m::ARDNet, x, y; β=1)
+    ps = reshape(rand(m.encoder),:)
+    llh = sum(logpdf(m.decoder, y, x, ps))
+    kld = sum(kl_divergence(m.encoder, m.prior))
+    lpλ = sum(logpdf(m.hyperprior, var(m.prior)))
+    llh - β*(kld - lpλ)
+end
diff --git a/src/utils/flux_ode_decoder.jl b/src/utils/flux_decoders.jl
similarity index 77%
rename from src/utils/flux_ode_decoder.jl
rename to src/utils/flux_decoders.jl
index 399137a..4e05e35 100644
--- a/src/utils/flux_ode_decoder.jl
+++ b/src/utils/flux_decoders.jl
@@ -1,4 +1,27 @@
-export FluxODEDecoder
+export FluxDecoder, FluxODEDecoder
+
+"""
+    FluxDecoder{M}(model)
+
+Simple decoder that, when called with an additional parameter vector,
+restructures it into `model` and calls `model(x)`.
+
+julia> dec = FluxDecoder(Dense(2,3))
+julia> ps = rand(9)
+julia> dec(rand(2,10), ps)
+3×10 Array{Float64,2}:
+ 0.508304  0.620386  0.423422  …  0.595583  0.551536  0.565597  0.255811
+ 1.75512   1.32246   1.57151      1.82269   1.2394    1.73934   0.844125
+ 1.45708   0.92777   1.28766      1.49607   0.863829  1.4156    0.546213
+"""
+struct FluxDecoder{M}
+    model::M
+    restructure::Function
+end
+
+FluxDecoder(m) = FluxDecoder(m, Flux.destructure(m)[2])
+(d::FluxDecoder)(x::AbstractMatrix, ps::AbstractVector) = d.restructure(ps)(x)
+(d::FluxDecoder)(x::AbstractMatrix) = d.model(x)
 
 """
     FluxODEDecoder{M}(slength::Int, tlength::Int, dt::Real,

From 1fb91ebdacae4fb5371af00241e444cbe2d84c05 Mon Sep 17 00:00:00 2001
From: Niklas Heim
Date: Tue, 7 Apr 2020 22:32:58 +0200
Subject: [PATCH 2/4] add InverseGamma to Rodent

---
 src/models/rodent.jl | 43 ++++++++-----------------------------------
 1 file changed, 8 insertions(+), 35 deletions(-)

diff --git a/src/models/rodent.jl b/src/models/rodent.jl
index 30f976a..c197b42 100644
--- a/src/models/rodent.jl
+++ b/src/models/rodent.jl
@@ -12,7 +12,8 @@ with ARD prior and an ODE decoder.
 * `e`: Encoder p(z|x)
 * `d`: Decoder p(x|z)
 """
-struct Rodent{P<:Gaussian,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractVAE
+struct Rodent{H<:InverseGamma,P<:Gaussian,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractVAE
+    hyperprior::InverseGamma
     prior::Gaussian
     encoder::CMeanGaussian
     decoder::CMeanGaussian
@@ -95,50 +96,22 @@ function Rodent(slen::Int, tlen::Int, dt::T, encoder;
                 olen=slen*tlen) where T
     zlen = length(Flux.destructure(ode)[1]) + slen
 
+    # hyperprior
+    hyperprior = InverseGamma(T(1), T(1), zlen, true)
+
+    # prior
     μpz = NoGradArray(zeros(T, zlen))
     λ2z = ones(T, zlen) / 20
     prior = Gaussian(μpz, λ2z)
 
+    # encoder
     σ2z = ones(T, zlen) / 20
     enc_dist = CMeanGaussian{DiagVar}(encoder, σ2z)
 
+    # decoder
     σ2x = ones(T, 1) / 20
     decoder = FluxODEDecoder(slen, tlen, dt, ode, observe)
     dec_dist = CMeanGaussian{ScalarVar}(decoder, σ2x, olen)
 
     Rodent(prior, enc_dist, dec_dist)
 end
-
-struct ConstSpecRodent{CP<:Gaussian,SP<:Gaussian,E<:ConstSpecGaussian,D<:CMeanGaussian} <: AbstractVAE
-    const_prior::CP
-    spec_prior::SP
-    encoder::E
-    decoder::D
-end
-
-ConstSpecRodent(cp::CP, sp::SP, e::E, d::D) where {CP,SP,E,D} =
-    ConstSpecRodent{CP,SP,E,D}(cp,sp,e,d)
-
-Flux.@functor ConstSpecRodent
-
-function elbo(m::ConstSpecRodent, x::AbstractArray)
-    cz = rand(m.encoder.cnst)
-    sz = rand(m.encoder.spec, x)
-    z = cz .+ sz
-
-    llh = sum(logpdf(m.decoder, x, z))
-    ckl = sum(kl_divergence(m.encoder.cnst, m.const_prior))
-    skl = sum(kl_divergence(m.encoder.spec, m.spec_prior, sz))
-
-    llh - ckl - skl
-end
-
-function Base.show(io::IO, m::ConstSpecRodent)
-    msg = """$(typeof(m)):
-    const_prior = $(summary(m.const_prior)))
-    spec_prior = $(summary(m.spec_prior))
-    encoder = $(summary(m.encoder))
-    decoder = $(summary(m.decoder))
-    """
-    print(io, msg)
-end

From fd4f1be038a2aaeab9d3e9b73a6af88ddf414ce9 Mon Sep 17 00:00:00 2001
From: Niklas Heim
Date: Tue, 7 Apr 2020 22:59:01 +0200
Subject: [PATCH 3/4] add Rodent elbo

---
 src/GenerativeModels.jl |  2 +-
 src/models/ardnet.jl    | 11 ++++++-----
 src/models/rodent.jl    | 10 +++++++++-
 3 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/src/GenerativeModels.jl b/src/GenerativeModels.jl
index e255a7b..712d7f8 100644
--- a/src/GenerativeModels.jl
+++ b/src/GenerativeModels.jl
@@ -29,9 +29,9 @@ module GenerativeModels
     include(joinpath("utils", "utils.jl"))
 
     include(joinpath("models", "vae.jl"))
+    include(joinpath("models", "ardnet.jl"))
     include(joinpath("models", "rodent.jl"))
     include(joinpath("models", "gan.jl"))
     include(joinpath("models", "vamp.jl"))
-    include(joinpath("models", "ardnet.jl"))
 
 end # module
diff --git a/src/models/ardnet.jl b/src/models/ardnet.jl
index 8e7a471..80054eb 100644
--- a/src/models/ardnet.jl
+++ b/src/models/ardnet.jl
@@ -1,24 +1,25 @@
 export ARDNet
 
-const ACG = ConditionalDists.ACGaussian
+const ACG = ConditionalDists.AbstractConditionalGaussian
+
 const FDCMeanGaussian = CMeanGaussian{V,<:FluxDecoder} where V
 
 """
-    ARDNet{P<:Gaussian,E<:Gaussian,D<:ACGaussian,H<:TuringInverseWishart}
+    ARDNet(h::InverseGamma, p::Gaussian, e::Gaussian, d::ACGaussian)
 
 Generative model that imposes the sparsifying ARD (*Automatic Relevance
 Determination*) prior on the weights of the decoder mapping:
 
 p(x|z) = N(x|ϕ(z),σx²)
 p(z) = N(z|0,diag(λz²))
-p(λz) = iW(λ|ψ,v)
+p(λz) = iG(λ|α0,β0)
 
 where the posterior on z is a multivariate Gaussian
 q(z|x) = N(z|μz,σz²)
 """
-struct ARDNet{P<:Gaussian,E<:Gaussian,D<:ACG,H<:InverseGamma} <: AbstractGM
-    prior::P
+struct ARDNet{H<:InverseGamma,P<:Gaussian,E<:Gaussian,D<:ACG} <: AbstractGM
     hyperprior::H
+    prior::P
     encoder::E
     decoder::D
 end
diff --git a/src/models/rodent.jl b/src/models/rodent.jl
index c197b42..025a200 100644
--- a/src/models/rodent.jl
+++ b/src/models/rodent.jl
@@ -12,7 +12,7 @@ with ARD prior and an ODE decoder.
 * `e`: Encoder p(z|x)
 * `d`: Decoder p(x|z)
 """
-struct Rodent{H<:InverseGamma,P<:Gaussian,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractVAE
+struct Rodent{H,P,E<:CMeanGaussian,D<:CMeanGaussian} <: AbstractGM
     hyperprior::InverseGamma
     prior::Gaussian
     encoder::CMeanGaussian
@@ -115,3 +115,11 @@
 
     Rodent(prior, enc_dist, dec_dist)
 end
+
+function elbo(m::Rodent, x::AbstractArray; β=1)
+    z = rand(m.encoder, x)
+    llh = sum(logpdf(m.decoder, x, z))
+    kld = sum(kl_divergence(m.encoder, m.prior))
+    lpλ = sum(logpdf(m.hyperprior, var(m.prior)))
+    llh - β*(kld - lpλ)
+end

From 85c94ef87bc30d1f8e2fdff8aca18ce670ff1020 Mon Sep 17 00:00:00 2001
From: Niklas Heim
Date: Wed, 8 Apr 2020 15:06:40 +0200
Subject: [PATCH 4/4] fix constructor and elbo

---
 src/models/rodent.jl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/models/rodent.jl b/src/models/rodent.jl
index 025a200..ac6ad46 100644
--- a/src/models/rodent.jl
+++ b/src/models/rodent.jl
@@ -21,7 +21,7 @@ end
 
 Flux.@functor Rodent
 
-Rodent(p::P, e::E, d::D) where {P,E,D} = Rodent{P,E,D}(p,e,d)
+Rodent(h::H, p::P, e::E, d::D) where {H,P,E,D} = Rodent{H,P,E,D}(h,p,e,d)
 
 """
     Rodent(slen::Int, tlen::Int, dt::T, encoder;
@@ -113,13 +113,13 @@ function Rodent(slen::Int, tlen::Int, dt::T, encoder;
     decoder = FluxODEDecoder(slen, tlen, dt, ode, observe)
     dec_dist = CMeanGaussian{ScalarVar}(decoder, σ2x, olen)
 
-    Rodent(prior, enc_dist, dec_dist)
+    Rodent(hyperprior, prior, enc_dist, dec_dist)
 end
 
-function elbo(m::Rodent, x::AbstractArray; β=1)
+function elbo(m::Rodent, x::AbstractMatrix; β=1)
     z = rand(m.encoder, x)
     llh = sum(logpdf(m.decoder, x, z))
-    kld = sum(kl_divergence(m.encoder, m.prior))
+    kld = sum(kl_divergence(m.encoder, m.prior, x))
     lpλ = sum(logpdf(m.hyperprior, var(m.prior)))
     llh - β*(kld - lpλ)
 end
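
A minimal usage sketch of the API after this series (not part of the patches themselves): it wires an ARDNet around a small FluxDecoder for a toy regression problem, using only constructors that appear verbatim in the patches above (FluxDecoder, ARDNet, Gaussian, CMeanGaussian{DiagVar}, InverseGamma, NoGradArray, elbo). Whether these names are exported to user code, as well as the toy data, layer sizes, variance initializations, and optimiser settings, are assumptions.

# Sketch under assumptions: Gaussian, CMeanGaussian, DiagVar, InverseGamma,
# NoGradArray and elbo are reachable from GenerativeModels/ConditionalDists;
# data, sizes, variances, and the optimiser are placeholders.
using Flux, ConditionalDists, GenerativeModels

# toy regression data
x = randn(Float32, 2, 100)
y = sin.(x[1:1, :]) .- x[2:2, :] .+ 0.1f0 .* randn(Float32, 1, 100)

# decoder network ϕ and the length of its flattened weight vector
ϕ    = Chain(Dense(2, 10, tanh), Dense(10, 1))
plen = length(Flux.destructure(ϕ)[1])

# hyperprior p(λ²); the trailing `true` mirrors the Rodent constructor above
hyperprior = InverseGamma(1f0, 1f0, plen, true)
# ARD prior p(w|λ²) with a fixed zero mean, as in the Rodent constructor
prior      = Gaussian(NoGradArray(zeros(Float32, plen)), ones(Float32, plen) ./ 20)
# weight posterior q(w) and likelihood p(y|x,w) with mean FluxDecoder(ϕ)(x, w)
posterior  = Gaussian(zeros(Float32, plen), ones(Float32, plen) ./ 20)
likelihood = CMeanGaussian{DiagVar}(FluxDecoder(ϕ), ones(Float32, 1) ./ 20)

net  = ARDNet(hyperprior, prior, posterior, likelihood)
loss(x, y) = -elbo(net, x, y, β=1)

pars = Flux.params(net)
opt  = ADAM(1e-3)
for _ in 1:1000
    gs = Flux.gradient(() -> loss(x, y), pars)
    Flux.Optimise.update!(opt, pars, gs)
end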