Skip to content
Open
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
# Manifest.toml

# visual studio settings folder
.vscode/*
52 changes: 40 additions & 12 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ git-tree-sha1 = "2b6845cea546604fb4dca4e31414a6a59d39ddcd"
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
version = "0.0.4"

[[Arpack]]
deps = ["BinaryProvider", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6"
uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97"
version = "0.3.1"

[[ArrayInterface]]
deps = ["LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "981354dab938901c2b607a213e62d9defa50b698"
Expand All @@ -50,7 +56,6 @@ git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.8"


[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down Expand Up @@ -80,7 +85,6 @@ git-tree-sha1 = "86a3165cfe6c7944dc9ba5dd3b703a5a1d7bccab"
uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
version = "2.5.4"


[[Calculus]]
deps = ["Compat"]
git-tree-sha1 = "bd8bbd105ba583a42385bd6dc4a20dad8ab3dc11"
Expand Down Expand Up @@ -123,12 +127,6 @@ git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
version = "1.3.0"

[[Crayons]]
deps = ["Test"]
git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523"
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
version = "4.0.0"

[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "639635bcdda71a1fccc4e95e679c517115dbf836"
Expand Down Expand Up @@ -188,6 +186,12 @@ version = "0.8.2"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[Distributions]]
deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"]
git-tree-sha1 = "56a158bc0abe4af5d4027af2275fde484261ca6d"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
version = "0.19.2"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a"
Expand Down Expand Up @@ -398,6 +402,12 @@ git-tree-sha1 = "d5b15b63a76146da84079aeabaf9d7f692afbcdf"
uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
version = "5.19.0"

[[PDMats]]
deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"]
git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563"
uuid = "90014a1f-27ba-587c-ab20-58faa44d9150"
version = "0.9.10"

[[Parameters]]
deps = ["OrderedCollections"]
git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38"
Expand All @@ -422,6 +432,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[QuadGK]]
deps = ["DataStructures", "LinearAlgebra"]
git-tree-sha1 = "1af46bf083b9630a5b27d4fd94f496c5fca642a8"
uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
version = "2.1.1"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Expand Down Expand Up @@ -459,6 +475,12 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
uuid = "ae029012-a4dd-5104-9daa-d747884805df"
version = "0.5.2"

[[Rmath]]
deps = ["BinaryProvider", "Libdl", "Random", "Statistics"]
git-tree-sha1 = "9825383d3453f4606d77f0a5722495f38001c09e"
uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa"
version = "0.5.1"

[[Roots]]
deps = ["Printf"]
git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3"
Expand Down Expand Up @@ -522,8 +544,14 @@ git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.32.0"

[[StatsFuns]]
deps = ["Rmath", "SpecialFunctions"]
git-tree-sha1 = "67745a79d8e83a83737a7e17a383c54720a97f41"
uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
version = "0.9.0"

[[SuiteSparse]]
deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"]
deps = ["Libdl", "LinearAlgebra", "SparseArrays"]
uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[[TableTraits]]
Expand All @@ -537,10 +565,10 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["Crayons", "Printf", "Test", "Unicode"]
git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c"
deps = ["Printf"]
git-tree-sha1 = "d9c67bd7ac89aafa75037307331d050998bb5a96"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.0"
version = "0.5.1"

[[Tokenize]]
git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf"
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ authors = ["Niklas Heim <heim.niklas@gmail.com>"]
version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
11 changes: 11 additions & 0 deletions src/GenerativeModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,20 @@ module GenerativeModels
using Zygote: @nograd, @adjoint
using DiffEqBase: ODEProblem, solve
using OrdinaryDiffEq: Tsit5
using SpecialFunctions
using Distributions
using Adapt

abstract type AbstractGM end
abstract type AbstractVAE{T<:Real} <: AbstractGM end
abstract type AbstractGAN{T<:Real} <: AbstractGM end
abstract type AbstractSVAE{T<:Real} <: AbstractGM end

# functions that are overloaded by this module
import Base.length
import Random.rand
import Statistics.mean
import SpecialFunctions: besselix, logabsgamma

# needed to make e.g. sampling work
@nograd similar, randn!, fill!
Expand All @@ -24,15 +29,21 @@ module GenerativeModels
include(joinpath("utils", "nogradarray.jl"))
include(joinpath("utils", "saveload.jl"))
include(joinpath("utils", "utils.jl"))
include(joinpath("utils", "vmf.jl"))
include(joinpath("utils", "flux_ode_decoder.jl"))

include(joinpath("pdfs", "gaussian.jl"))
include(joinpath("pdfs", "hs_uniform.jl"))
include(joinpath("pdfs", "vonmisesfisher.jl"))
include(joinpath("pdfs", "abstract_cgaussian.jl"))
include(joinpath("pdfs", "cmean_gaussian.jl"))
include(joinpath("pdfs", "cmeanvar_gaussian.jl"))
include(joinpath("pdfs", "abstract_cvmf.jl"))
include(joinpath("pdfs", "cmeanconc_vmf.jl"))

include(joinpath("models", "vae.jl"))
include(joinpath("models", "rodent.jl"))
include(joinpath("models", "gan.jl"))
include(joinpath("models", "svae.jl"))

end # module
87 changes: 87 additions & 0 deletions src/models/svae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
export SVAE, SVAE_vmf_prior, SVAE_hsu_prior

"""
SVAE{T}([prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, zlen::Int] encoder::AbstractCVMF, decoder::AbstractCPDF)

HyperSpherical Variational Auto-Encoder.

# Example
Create an S-VAE with either HSU prior or VMF prior with μ = [1, 0, ..., 0] and κ = 1 with:
```julia-repl
julia> enc = CMeanConcVMF{Float32}(Dense(5,4), 3)
CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dense(4, 3), #51), κ_from_hidden=Dense(4, 1, #52))

julia> dec = CMeanVarGaussian{Float32,ScalarVar}(Dense(3, 6))
CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6))

julia> svae = SVAE(HypersphericalUniform{Float32}(3), enc, dec)
SVAE{Float32}:
prior = HypersphericalUniform{Float32}(3)
encoder = (CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dens...)
decoder = CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6))

julia> mean(svae.decoder, mean(svae.encoder, rand(5, 1)))
5×1 Array{Float32,2}:
-0.7267006
0.6847478
-0.032789093
0.13542232
-0.270345421

julia> elbo(svae, rand(Float32, 5, 1))
15.011719478567946
```
"""
struct SVAE{T} <: AbstractSVAE{T}
prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}}
encoder::AbstractCVMF{T}
decoder::AbstractCPDF{T}
end

Flux.@functor SVAE

# SVAE(p::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, e::AbstractCVMF{T}, d::AbstractCPDF{T}) where T = SVAE{T}(p, e, d)

function SVAE_vmf_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T
μp = NoGradArray(zeros(T, zlength))
μp[1] = T(1)
κp = NoGradArray(ones(T, 1))
prior = VonMisesFisher(μp, κp)
SVAE{T}(prior, enc, dec)
end

function SVAE_hsu_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T
prior = HypersphericalUniform{T}(zlength)
SVAE{T}(prior, enc, dec)
end

"""
elbo(m::SVAE, x::AbstractArray; β=1)

Evidence lower boundary of the SVAE model. `β` scales the KLD term. (Assumes hyperspherical uniform prior)
"""
function elbo(m::SVAE{T}, x::AbstractArray{T}; β=T(1)) where {T}
z = rand(m.encoder, x)
llh = mean(-loglikelihood(m.decoder, x, z))
kl = mean(kld(m.encoder, m.prior, x))
llh + β*kl
end

"""
mmd(m::SVAE, x::AbstractArray, k)

Maximum mean discrepancy of a SVAE model given data `x` and kernel function `k(x,y)`.
"""
mmd(m::SVAE{T}, x::AbstractArray{T}, k) where {T} = mmd(m.encoder, m.prior, x, k)

function Base.show(io::IO, m::SVAE{T}) where T
p = short_repr(m.prior, 70)
e = short_repr(m.encoder, 70)
d = short_repr(m.decoder, 70)
msg = """$(typeof(m)):
prior = $(p)
encoder = $(e)
decoder = $(d)
"""
print(io, msg)
end
54 changes: 54 additions & 0 deletions src/pdfs/abstract_cvmf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
export loglikelihood, kld, rand, mean_conc, concentration

abstract type AbstractCVMF{T} <: AbstractCPDF{T} end

function rand(p::AbstractCVMF{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
sample_vmf(μ, κ)
end

function loglikelihood(p::AbstractCVMF{T}, x::AbstractArray{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
log_vmf(x, μ, κ)
end

# This is here because we always compute KLD with VMF and hyperspherical uniform - nothing else as KLD between two VMFs is rather complicated to compute
"""
kld(p::AbstractCVMF, q::HypersphericalUniform, z::AbstractArray)

Compute Kullback-Leibler divergence between a conditional Von Mises-Fisher distribution `p` given `z`
and a hyperspherical uniform distribution `q` with the same dimensionality.
"""
function kld(p::AbstractCVMF{T}, q::HypersphericalUniform{T}, z::AbstractArray{T}) where {T}
(μ, κ) = mean_conc(p, z)
if size(μ, 1) != q.dims
error("Cannot compute KLD between VMF and HSU with different dimensionality")
end
.- vmfentropy.(q.dims, κ) .+ huentropy(q.dims, T)
end

"""
mean_conc(p::AbstractCVMF, z::AbstractArray)

Returns mean and concentration of a conditional VMF distribution.
"""
mean_conc(p::AbstractCVMF, z::AbstractArray) = error("Not implemented!")


"""
mean(p::AbstractCVMF, z::AbstractArray)

Returns mean of a conditional VMF distribution.
"""
mean(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[1]

"""
concentration(p::AbstractCVMF, z::AbstractArray)

Returns variance of a conditional VMF distribution.
"""
concentration(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[2]




Loading