diff --git a/.gitignore b/.gitignore index 27f320a..7db5fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ docs/site/ # committed for packages, but should be committed for applications that require a static # environment. # Manifest.toml + +# visual studio settings folder +.vscode/* diff --git a/Manifest.toml b/Manifest.toml index db1cd32..cd18433 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -24,6 +24,12 @@ git-tree-sha1 = "2b6845cea546604fb4dca4e31414a6a59d39ddcd" uuid = "ec485272-7323-5ecc-a04f-4719b315124d" version = "0.0.4" +[[Arpack]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra"] +git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.3.1" + [[ArrayInterface]] deps = ["LinearAlgebra", "Requires", "SparseArrays"] git-tree-sha1 = "981354dab938901c2b607a213e62d9defa50b698" @@ -50,7 +56,6 @@ git-tree-sha1 = "5b08ed6036d9d3f0ee6369410b830f8873d4024c" uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" version = "0.5.8" - [[CEnum]] git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -80,7 +85,6 @@ git-tree-sha1 = "86a3165cfe6c7944dc9ba5dd3b703a5a1d7bccab" uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17" version = "2.5.4" - [[Calculus]] deps = ["Compat"] git-tree-sha1 = "bd8bbd105ba583a42385bd6dc4a20dad8ab3dc11" @@ -123,12 +127,6 @@ git-tree-sha1 = "9a11d428dcdc425072af4aea19ab1e8c3e01c032" uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" version = "1.3.0" -[[Crayons]] -deps = ["Test"] -git-tree-sha1 = "f621b8ef51fd2004c7cf157ea47f027fdeac5523" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.0.0" - [[CuArrays]] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "DataStructures", "GPUArrays", "Libdl", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] git-tree-sha1 = "639635bcdda71a1fccc4e95e679c517115dbf836" @@ -188,6 +186,12 @@ version = 
"0.8.2" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[Distributions]] +deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "56a158bc0abe4af5d4027af2275fde484261ca6d" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.19.2" + [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a" @@ -398,6 +402,12 @@ git-tree-sha1 = "d5b15b63a76146da84079aeabaf9d7f692afbcdf" uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" version = "5.19.0" +[[PDMats]] +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.9.10" + [[Parameters]] deps = ["OrderedCollections"] git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38" @@ -422,6 +432,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "1af46bf083b9630a5b27d4fd94f496c5fca642a8" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.1.1" + [[REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -459,6 +475,12 @@ git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "0.5.2" +[[Rmath]] +deps = ["BinaryProvider", "Libdl", "Random", "Statistics"] +git-tree-sha1 = "9825383d3453f4606d77f0a5722495f38001c09e" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.5.1" + [[Roots]] deps = ["Printf"] git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3" @@ -522,8 +544,14 @@ git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.32.0" +[[StatsFuns]] +deps = ["Rmath", 
"SpecialFunctions"] +git-tree-sha1 = "67745a79d8e83a83737a7e17a383c54720a97f41" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.0" + [[SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +deps = ["Libdl", "LinearAlgebra", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[TableTraits]] @@ -537,10 +565,10 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] -deps = ["Crayons", "Printf", "Test", "Unicode"] -git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +deps = ["Printf"] +git-tree-sha1 = "d9c67bd7ac89aafa75037307331d050998bb5a96" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.0" +version = "0.5.1" [[Tokenize]] git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" diff --git a/Project.toml b/Project.toml index 9032da1..4efbf3f 100644 --- a/Project.toml +++ b/Project.toml @@ -4,15 +4,18 @@ authors = ["Niklas Heim "] version = "0.1.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/GenerativeModels.jl b/src/GenerativeModels.jl index b36d5f7..7abbdbc 100644 --- a/src/GenerativeModels.jl +++ b/src/GenerativeModels.jl @@ -6,15 +6,20 @@ module GenerativeModels using Zygote: 
@nograd, @adjoint using DiffEqBase: ODEProblem, solve using OrdinaryDiffEq: Tsit5 + using SpecialFunctions + using Distributions + using Adapt abstract type AbstractGM end abstract type AbstractVAE{T<:Real} <: AbstractGM end abstract type AbstractGAN{T<:Real} <: AbstractGM end + abstract type AbstractSVAE{T<:Real} <: AbstractGM end # functions that are overloaded by this module import Base.length import Random.rand import Statistics.mean + import SpecialFunctions: besselix, logabsgamma # needed to make e.g. sampling work @nograd similar, randn!, fill! @@ -24,15 +29,21 @@ module GenerativeModels include(joinpath("utils", "nogradarray.jl")) include(joinpath("utils", "saveload.jl")) include(joinpath("utils", "utils.jl")) + include(joinpath("utils", "vmf.jl")) include(joinpath("utils", "flux_ode_decoder.jl")) include(joinpath("pdfs", "gaussian.jl")) + include(joinpath("pdfs", "hs_uniform.jl")) + include(joinpath("pdfs", "vonmisesfisher.jl")) include(joinpath("pdfs", "abstract_cgaussian.jl")) include(joinpath("pdfs", "cmean_gaussian.jl")) include(joinpath("pdfs", "cmeanvar_gaussian.jl")) + include(joinpath("pdfs", "abstract_cvmf.jl")) + include(joinpath("pdfs", "cmeanconc_vmf.jl")) include(joinpath("models", "vae.jl")) include(joinpath("models", "rodent.jl")) include(joinpath("models", "gan.jl")) + include(joinpath("models", "svae.jl")) end # module diff --git a/src/models/svae.jl b/src/models/svae.jl new file mode 100644 index 0000000..25f6945 --- /dev/null +++ b/src/models/svae.jl @@ -0,0 +1,87 @@ +export SVAE, SVAE_vmf_prior, SVAE_hsu_prior + +""" + SVAE{T}([prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, zlen::Int] encoder::AbstractCVMF, decoder::AbstractCPDF) + +HyperSpherical Variational Auto-Encoder. 
+ +# Example +Create an S-VAE with either HSU prior or VMF prior with μ = [1, 0, ..., 0] and κ = 1 with: +```julia-repl +julia> enc = CMeanConcVMF{Float32}(Dense(5,4), 3) +CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dense(4, 3), #51), κ_from_hidden=Dense(4, 1, #52)) + +julia> dec = CMeanVarGaussian{Float32,ScalarVar}(Dense(3, 6)) +CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6)) + +julia> svae = SVAE(HypersphericalUniform{Float32}(3), enc, dec) +SVAE{Float32}: + prior = HypersphericalUniform{Float32}(3) + encoder = (CMeanConcVMF{Float32}(mapping=Dense(5, 4), μ_from_hidden=Chain(Dens...) + decoder = CMeanVarGaussian{Float32,ScalarVar}(mapping=Dense(3, 6)) + +julia> mean(svae.decoder, mean(svae.encoder, rand(5, 1))) +5×1 Array{Float32,2}: + -0.7267006 + 0.6847478 + -0.032789093 + 0.13542232 + -0.270345421 + +julia> elbo(svae, rand(Float32, 5, 1)) +15.011719478567946 +``` +""" +struct SVAE{T} <: AbstractSVAE{T} + prior::Union{HypersphericalUniform{T}, VonMisesFisher{T}} + encoder::AbstractCVMF{T} + decoder::AbstractCPDF{T} +end + +Flux.@functor SVAE + +# SVAE(p::Union{HypersphericalUniform{T}, VonMisesFisher{T}}, e::AbstractCVMF{T}, d::AbstractCPDF{T}) where T = SVAE{T}(p, e, d) + +function SVAE_vmf_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T + μp = NoGradArray(zeros(T, zlength)) + μp[1] = T(1) + κp = NoGradArray(ones(T, 1)) + prior = VonMisesFisher(μp, κp) + SVAE{T}(prior, enc, dec) +end + +function SVAE_hsu_prior(zlength::Int, enc::AbstractCPDF{T}, dec::AbstractCPDF{T}) where T + prior = HypersphericalUniform{T}(zlength) + SVAE{T}(prior, enc, dec) +end + +""" + elbo(m::SVAE, x::AbstractArray; β=1) + +Evidence lower boundary of the SVAE model. `β` scales the KLD term. 
(Assumes hyperspherical uniform prior) +""" +function elbo(m::SVAE{T}, x::AbstractArray{T}; β=T(1)) where {T} + z = rand(m.encoder, x) + llh = mean(-loglikelihood(m.decoder, x, z)) + kl = mean(kld(m.encoder, m.prior, x)) + llh + β*kl +end + +""" + mmd(m::SVAE, x::AbstractArray, k) + +Maximum mean discrepancy of a SVAE model given data `x` and kernel function `k(x,y)`. +""" +mmd(m::SVAE{T}, x::AbstractArray{T}, k) where {T} = mmd(m.encoder, m.prior, x, k) + +function Base.show(io::IO, m::SVAE{T}) where T + p = short_repr(m.prior, 70) + e = short_repr(m.encoder, 70) + d = short_repr(m.decoder, 70) + msg = """$(typeof(m)): + prior = $(p) + encoder = $(e) + decoder = $(d) + """ + print(io, msg) +end \ No newline at end of file diff --git a/src/pdfs/abstract_cvmf.jl b/src/pdfs/abstract_cvmf.jl new file mode 100644 index 0000000..6f22e3b --- /dev/null +++ b/src/pdfs/abstract_cvmf.jl @@ -0,0 +1,54 @@ +export loglikelihood, kld, rand, mean_conc, concentration + +abstract type AbstractCVMF{T} <: AbstractCPDF{T} end + +function rand(p::AbstractCVMF{T}, z::AbstractArray{T}) where {T} + (μ, κ) = mean_conc(p, z) + sample_vmf(μ, κ) +end + +function loglikelihood(p::AbstractCVMF{T}, x::AbstractArray{T}, z::AbstractArray{T}) where {T} + (μ, κ) = mean_conc(p, z) + log_vmf(x, μ, κ) +end + +# This is here because we always compute KLD with VMF and hyperspherical uniform - nothing else as KLD between two VMFs is rather complicated to compute +""" +kld(p::AbstractCVMF, q::HypersphericalUniform, z::AbstractArray) + +Compute Kullback-Leibler divergence between a conditional Von Mises-Fisher distribution `p` given `z` +and a hyperspherical uniform distribution `q` with the same dimensionality. 
+""" +function kld(p::AbstractCVMF{T}, q::HypersphericalUniform{T}, z::AbstractArray{T}) where {T} + (μ, κ) = mean_conc(p, z) + if size(μ, 1) != q.dims + error("Cannot compute KLD between VMF and HSU with different dimensionality") + end + .- vmfentropy.(q.dims, κ) .+ huentropy(q.dims, T) +end + +""" + mean_conc(p::AbstractCVMF, z::AbstractArray) + +Returns mean and concentration of a conditional VMF distribution. +""" +mean_conc(p::AbstractCVMF, z::AbstractArray) = error("Not implemented!") + + +""" + mean(p::AbstractCVMF, z::AbstractArray) + +Returns mean of a conditional VMF distribution. +""" +mean(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[1] + +""" + concentration(p::AbstractCVMF, z::AbstractArray) + +Returns concentration of a conditional VMF distribution. +""" +concentration(p::AbstractCVMF, z::AbstractArray) = mean_conc(p, z)[2] + + + + diff --git a/src/pdfs/cmeanconc_vmf.jl b/src/pdfs/cmeanconc_vmf.jl new file mode 100644 index 0000000..6ad8ecc --- /dev/null +++ b/src/pdfs/cmeanconc_vmf.jl @@ -0,0 +1,67 @@ +export CMeanConcVMF + +""" +CMeanConcVMF(mapping, xlength::Int) + +Conditional Von Mises-Fisher that maps an input z to a mean μx and a concentration κ. +The mapping should end with the last hidden layer because the constructor will add +transformations for μ and κ. + ```julia-repl + μ_from_hidden = Chain(Dense(hidden_dim, xlength), x -> normalizecolumns(x)) + ``` + + ```julia-repl + κ_from_hidden = Dense(hidden_dim, 1, x -> σ.(x) .* 100) + ``` +# Arguments +- `mapping`: maps condition z to mean and concentration (e.g. a Flux Chain) +- `T`: expected eltype. E.g. `rand` will try to sample arrays of this eltype. + If the mapping returns a different eltype the output of `mean`,`concentration`, + and `rand` is not necessarily of eltype T. 
+ +# Example +```julia-repl +julia> p = CMeanConcVMF{Float32}(Dense(2, 3), 3) +CMeanConcVMF{Float32}(mapping=Dense(2, 3), μ_from_hidden=Chain(Dense(3, 3), #45), κ_from_hidden=Dense(3, 1, #46)) + +julia> mean_conc(p, ones(2, 1)) +(Float32[-0.1507113; -0.9488135; 0.27755922], Float32[85.390144]) + +julia> rand(p, ones(2,1)) +3×1 Array{Float32,2}: + 0.22024345 + -0.9597406 + -0.1743287 +``` +""" +struct CMeanConcVMF{T} <: AbstractCVMF{T} + mapping + μ_from_hidden + κ_from_hidden +end + +#! Watch out, kappa is capped between 0 and 100 because it was exploding before. You might want to change this to softmax for kappa but in practice it did not behave well +CMeanConcVMF{T}(mapping, hidden_dim::Int, xlength::Int) where {T} = CMeanConcVMF{T}(mapping, Chain(Dense(hidden_dim, xlength), x -> normalizecolumns(x)), Dense(hidden_dim, 1, x -> σ.(x) .* 100)) +CMeanConcVMF{T}(mapping::Chain{C}, xlength::Int) where {C, T} = CMeanConcVMF{T}(mapping, size(mapping[length(mapping)].W, 1), xlength) +CMeanConcVMF{T}(mapping::Dense{D}, xlength::Int) where {D, T} = CMeanConcVMF{T}(mapping, size(mapping.W, 1), xlength) + +function mean_conc(p::CMeanConcVMF{T}, z::AbstractArray) where {T} + ex = p.mapping(z) + if eltype(ex) != T + error("Mapping should return eltype $T. Found: $(eltype(ex))") + end + + return p.μ_from_hidden(ex), p.κ_from_hidden(ex) +end + +# make sure that parameteric constructor is called... +function Flux.functor(p::CMeanConcVMF{T}) where {T} + fs = fieldnames(typeof(p)) + nt = (; (name=>getfield(p, name) for name in fs)...) + nt, y -> CMeanConcVMF{T}(y...) 
+end + +function Base.show(io::IO, p::CMeanConcVMF{T}) where {T} + msg = "CMeanConcVMF{$T}(mapping=$(short_repr(p.mapping)), μ_from_hidden=$(short_repr(p.μ_from_hidden)), κ_from_hidden=$(short_repr(p.κ_from_hidden)))" + print(io, msg) +end \ No newline at end of file diff --git a/src/pdfs/hs_uniform.jl b/src/pdfs/hs_uniform.jl new file mode 100644 index 0000000..0689771 --- /dev/null +++ b/src/pdfs/hs_uniform.jl @@ -0,0 +1,20 @@ +export HypersphericalUniform + +""" +HypersphericalUniform{T} + +Hyperspherical uniform distribution in `dims` dimensions. + +""" +struct HypersphericalUniform{T} <: AbstractPDF{T} + dims::Int +end + +HypersphericalUniform(d::Int) = HypersphericalUniform{Float32}(d) + +length(p::HypersphericalUniform) = p.dims + +function rand(p::HypersphericalUniform{T}, batchsize::Int=1) where {T} + v = randn(T, p.dims, batchsize) + normalizecolumns(v) +end diff --git a/src/pdfs/vonmisesfisher.jl b/src/pdfs/vonmisesfisher.jl new file mode 100644 index 0000000..c37826c --- /dev/null +++ b/src/pdfs/vonmisesfisher.jl @@ -0,0 +1,88 @@ +export VonMisesFisher + +""" +VonMisesFisher{T} + +Von Mises-Fisher distribution defined with mean μ and concentration κ that can be any `AbstractArray` and `Real` number respectively + +# Arguments +- `μ::AbstractArray`: mean of VMF +- `κ::AbstractArray`: concentration of VMF + +# Example +```julia-repl +julia> using Flux + +julia> p = VonMisesFisher(zeros(3), 1.0) +VonMisesFisher{Float64}(μ=3-element Array{Float64,1}, κ=[1.0]) + +julia> mean_conc(p) +([0.0, 0.0, 0.0], [1.0]) + +julia> rand(p) +3×1 Array{Float64,2}: + -0.534718473601494 + 0.4131946025140243 + 0.7371203256202924 +``` +""" +struct VonMisesFisher{T} <: AbstractPDF{T} + μ::AbstractArray{T} + κ::AbstractArray{T} + _nograd::Dict{Symbol,Bool} +end + +#! Watch out, there is no check for μ actually being on a sphere, even though all the methods count with that! 
+VonMisesFisher(μ::AbstractMatrix{T}, κ::Union{T, AbstractArray{T}}) where {T} = VonMisesFisher(vec(μ), κ) +VonMisesFisher(μ::AbstractVector{T}, κ::T) where {T} = VonMisesFisher(μ, NoGradArray([κ])) +function VonMisesFisher(μ::AbstractVector{T}, κ::AbstractArray{T}) where {T} + _nograd = Dict( + :μ => μ isa NoGradArray, + :κ => κ isa NoGradArray) + μ = _nograd[:μ] ? μ.data : μ + κ = _nograd[:κ] ? κ.data : κ + VonMisesFisher{T}(μ, κ, _nograd) +end + +Flux.@functor VonMisesFisher + +function Flux.trainable(p::VonMisesFisher) + ps = (;(k=>getfield(p,k) for k in keys(p._nograd) if !p._nograd[k])...) +end + +length(p::VonMisesFisher) = size(p.μ, 1) +mean_conc(p::VonMisesFisher) = (p.μ, p.κ) +mean(p::VonMisesFisher) = p.μ +concentration(p::VonMisesFisher) = p.κ + +function rand(p::VonMisesFisher{T}, batchsize::Int=1) where {T} + (μ, κ) = mean_conc(p) + μ = μ .* ones(T, size(μ, 1), batchsize) + κ = κ .* ones(T, 1, batchsize) + sample_vmf(μ, κ) +end + +loglikelihood(p::VonMisesFisher{T}, x::AbstractVector{T}) where T = loglikelihood(p, x * ones(T, 1, 1)) +function loglikelihood(p::VonMisesFisher{T}, x::AbstractMatrix{T}) where T + (μ, κ) = mean_conc(p) + μ = μ * ones(T, 1, size(x, 2)) + log_vmf(x, μ, κ[1]) +end + +""" +kld(p::VonMisesFisher, q::HypersphericalUniform) + +Compute Kullback-Leibler divergence between a Von Mises-Fisher distribution `p` +and a hyperspherical uniform distribution `q` with the same dimensionality. 
+""" +function kld(p::VonMisesFisher{T}, q::HypersphericalUniform{T}) where {T} + if length(p.μ) != q.dims + error("Cannot compute KLD between VMF and HSU with different dimensionality") + end + .- vmfentropy(q.dims, concentration(p)[1]) .+ huentropy(q.dims, T) +end + +function Base.show(io::IO, p::VonMisesFisher{T}) where T + msg = "VonMisesFisher{$T}(μ=$(summary(mean(p))), κ=$(concentration(p)))" + print(io, msg) +end diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 13a0c53..f517150 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -201,3 +201,13 @@ function restructure(m, xs::AbstractVector) return x end end + +""" + short_repr(x) + +Shortens `repr(x)` to 50 chars... +""" +function short_repr(x, maxchars = 50) + e = repr(x) + e = sizeof(e)>maxchars ? "($(e[1:maxchars - 3])...)" : e +end \ No newline at end of file diff --git a/src/utils/vmf.jl b/src/utils/vmf.jl new file mode 100644 index 0000000..de009d2 --- /dev/null +++ b/src/utils/vmf.jl @@ -0,0 +1,111 @@ +# Utils for Von Mises-Fisher distribution + +∇besselix(ν, x) = @. 
besselix(ν - 1, x) - besselix(ν, x) * (ν + x) / x +@adjoint SpecialFunctions.besselix(ν, x) = besselix(ν, x), Δ -> (nothing, ∇besselix(ν, x) * Δ) +@adjoint SpecialFunctions.loggamma(x) = loggamma(x), Δ -> (digamma(x) * Δ,) + +""" + vmfentropy(d, κ) + +Entropy of a Von Mises-Fisher distribution with dimensionality `d` and concentration `κ` +""" +vmfentropy(d, κ::T) where {T} = vmfentropy(T(d), κ) +vmfentropy(d::T, κ::T) where {T} = .-κ .* besselix(d / 2, κ) ./ besselix(d / 2 - 1, κ) .- ((d ./ 2 .- 1) .* log.(κ) .- (d ./ 2) .* log(T(2) * π) .- (κ .+ log.(besselix(d / 2 - 1, κ)))) + +""" + huentropy(d) + +Entropy of a Hyperspherical Uniform distribution with dimensionality `d` +""" +huentropy(d, T) = huentropy(T(d)) +huentropy(d::T) where {T <: AbstractFloat} = d / 2 * log(T(1) * π) + log(T(2)) - (loggamma(d / 2))[1] + +# Likelihood estimation of a sample x under VMF with given parameters taken from https://pdfs.semanticscholar.org/2b5b/724fb175f592c1ff919cc61499adb26996b1.pdf + +""" + vmf_norm_const(d, κ) + +Likelihood normalizing constant of a Von Mises-Fisher distribution with dimensionality `d` and concentration `κ` +""" +vmf_norm_const(d, κ::T) where {T} = vmf_norm_const(T(d), κ) +vmf_norm_const(d::T, κ::T) where {T} = κ ^ (d / 2 - 1) / ((T(2) * π) ^ (d / 2) * besseli(d / 2 - 1, κ)) + +# log likelihood of one sample under the VMF dist with given parameters +""" + log_vmf(x, μ, κ) + +Loglikelihood of `x` under the Von Mises-Fisher distribution with mean `μ` and concentration `κ` +""" +log_vmf(x::AbstractVector{T}, μ::AbstractVector{T}, κ::T) where {T} = κ * μ' * x .+ log(vmf_norm_const(length(μ), κ)) + +#? Will we need these as well? Can we actually make this without the for cycle? They can probably be optimised by computing the norm constant just once etc. +#! Not vectorized - will be slow! 
+log_vmf(x::AbstractMatrix{T}, μ::AbstractMatrix{T}, κ::T) where {T} = [log_vmf(x[:, i], μ[:, i], κ) for i in 1:size(x, 2)]' +log_vmf(x::AbstractMatrix{T}, μ::AbstractMatrix{T}, κ::AbstractVector{T}) where {T} = [log_vmf(x[:, i], μ[:, i], κ[i]) for i in 1:size(x, 2)]' + +""" + log_vmf_wo_c(x, μ, κ) + +Loglikelihood of `x` under the Von Mises-Fisher distribution with mean `μ` and concentration `κ` **without** the normalizing constant. +It can be very useful when it is used just for comparison of likelihoods etc. because it is very expensive to compute and in many applications +it has no effect on the outcome. +""" +log_vmf_wo_c(x, μ, κ) = κ * μ' * x + + +# This sampling procedure is one that can be differentiated through taken from https://arxiv.org/pdf/1804.00891.pdf +normalizecolumns(m::AbstractArray{T, 2}) where {T} = m ./ sqrt.(sum(m .^ 2, dims = 1) .+ eps(T)) + +sample_vmf(μ::AbstractArray{T}, κ::Union{T, AbstractArray{T}}) where {T} = sample_vmf(μ, κ, size(μ, 1)) +function sample_vmf(μ::AbstractArray{T}, κ::Union{T, AbstractArray{T}}, dims) where {T} + ω = sampleω(κ, dims) + v = zeros(T, dims - 1, size(μ, 2)) + randn!(v) + v = normalizecolumns(v) + householderrotation(vcat(ω, sqrt.(1 .- ω .^ 2) .* v), μ) +end + +@nograd function sampleω(κ::Union{T, AbstractArray{T}}, dims) where {T} + c = @. sqrt(4κ ^ 2 + (dims - 1) ^ 2) + b = @. (-2κ + c) / (dims - 1) + a = @. (dims - 1 + 2κ + c) / 4 + d = @. 
(4 * a * b) / (1 + b) - (dims - 1) * log(dims - 1) + ω = rejectionsampling(dims, a, b, d, κ) +end + +function householderrotation(zprime::AbstractArray{T}, μ::AbstractArray{T}) where {T} + # e1 = similar(μ) .= 0 + e1 = vcat(ones(T, 1, size(μ, 2)), zeros(T, size(μ, 1) - 1, size(μ, 2))) + u = e1 .- μ + normalizedu = normalizecolumns(u) + zprime .- T(2) .* sum(zprime .* normalizedu, dims = 1) .* normalizedu +end + +function rejectionsampling(dims, a, b, d, κ::Union{T, AbstractArray{T}}) where {T} + beta = Beta((dims - 1) / 2, (dims - 1) / 2) + ϵ = Adapt.adapt(T, rand(beta, size(a)...)) #! This is really stupid but even Beta{Float32} samples Float64 values :( - we should switch to our sampler + u = rand(T, size(a)...) + + accepted = isaccepted(ϵ, u, dims, a, b, d) + it = 0 + while (!all(accepted)) & (it < 10000) + mask = .! accepted + ϵ[mask] = Adapt.adapt(T, rand(beta, sum(mask))) #! same issue as above + u[mask] = rand(T, sum(mask)) + accepted[mask] = isaccepted(mask, ϵ, u, dims, a, b, d) + it += 1 + end + if it >= 10000 + println("Warning - sampler was stopped by 10000 iterations - it did not accept the sample!") + # perhaps this can be removed but some networks were causing issues in too high kappas etc so it was better to let + # it continue with a bit imprecise number from time to time than crashing the whole computation - might be an issue though + end + return @. (1 - (1 + b) * ϵ) / (1 - (1 - b) * ϵ) +end + +isaccepted(mask, ϵ, u, dims::Int, a, b, d) = isaccepted(ϵ[mask], u[mask], dims, a[mask], b[mask], d[mask]); +function isaccepted(ϵ, u, dims::Int, a, b, d) + ω = @. (1 - (1 + b) * ϵ) / (1 - (1 - b) * ϵ) + t = @. 2 * a * b / (1 - (1 - b) * ϵ) + @. 
(dims - 1) * log(t) - t + d >= log(u) +end \ No newline at end of file diff --git a/test/Manifest.toml b/test/Manifest.toml index cd8d0f3..9c8d5e9 100644 --- a/test/Manifest.toml +++ b/test/Manifest.toml @@ -24,6 +24,12 @@ git-tree-sha1 = "2b6845cea546604fb4dca4e31414a6a59d39ddcd" uuid = "ec485272-7323-5ecc-a04f-4719b315124d" version = "0.0.4" +[[Arpack]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra"] +git-tree-sha1 = "07a2c077bdd4b6d23a40342a8a108e2ee5e58ab6" +uuid = "7d9fca2a-8960-54d3-9f78-7d1dccf2cb97" +version = "0.3.1" + [[ArrayInterface]] deps = ["LinearAlgebra", "Requires", "SparseArrays"] git-tree-sha1 = "981354dab938901c2b607a213e62d9defa50b698" @@ -55,12 +61,6 @@ git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.2.0" -[[CSTParser]] -deps = ["Tokenize"] -git-tree-sha1 = "99dda94f5af21a4565dc2b97edf6a95485f116c3" -uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" -version = "1.0.0" - [[CUDAapi]] deps = ["Libdl", "Logging"] git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b" @@ -172,9 +172,9 @@ version = "6.4.2" [[DiffEqDiffTools]] deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "21b855cb29ec4594f9651e0e9bdc0cdcfdcd52c1" +git-tree-sha1 = "81edfb3a8b55154772bb6080b5db40868e1778ed" uuid = "01453d9d-ee7c-5054-8395-0335cb756afa" -version = "1.3.0" +version = "1.4.0" [[DiffResults]] deps = ["Compat", "StaticArrays"] @@ -198,6 +198,12 @@ version = "0.8.2" deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[Distributions]] +deps = ["LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns"] +git-tree-sha1 = "838a37797ac24c1b4c3353d46ec87ea6598f2308" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.21.7" + [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] git-tree-sha1 = 
"88bb0edb352b16608036faadcc071adda068582a" @@ -269,8 +275,8 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" version = "2.0.0" [[GenerativeModels]] -deps = ["BSON", "CuArrays", "DiffEqBase", "DrWatson", "Flux", "ForwardDiff", "OrdinaryDiffEq", "Random", "Requires", "Statistics", "ValueHistories", "Zygote"] -path = "/home/niklas/repos/GenerativeModels.jl" +deps = ["Adapt", "BSON", "CuArrays", "DiffEqBase", "Distributions", "DrWatson", "Flux", "ForwardDiff", "OrdinaryDiffEq", "Random", "Requires", "SpecialFunctions", "Statistics", "ValueHistories", "Zygote"] +path = "/home/bimjan/dev/julia/GenerativeModels.jl" uuid = "6ac2c632-c4cd-11e9-0501-33c4b9b2f9c9" version = "0.1.0" @@ -369,10 +375,10 @@ uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" version = "0.3.8" [[MacroTools]] -deps = ["CSTParser", "Compat", "DataStructures", "Test", "Tokenize"] -git-tree-sha1 = "d6e9dedb8c92c3465575442da456aec15a89ff76" +deps = ["Compat", "DataStructures", "Test"] +git-tree-sha1 = "82921f0e3bde6aebb8e524efc20f4042373c0c06" uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.1" +version = "0.5.2" [[Markdown]] deps = ["Base64"] @@ -435,6 +441,12 @@ git-tree-sha1 = "d5b15b63a76146da84079aeabaf9d7f692afbcdf" uuid = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" version = "5.19.0" +[[PDMats]] +deps = ["Arpack", "LinearAlgebra", "SparseArrays", "SuiteSparse", "Test"] +git-tree-sha1 = "035f8d60ba2a22cb1d2580b1e0e5ce0cb05e4563" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.9.10" + [[Parameters]] deps = ["OrderedCollections"] git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38" @@ -459,6 +471,12 @@ uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" deps = ["Printf"] uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" +[[QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "1af46bf083b9630a5b27d4fd94f496c5fca642a8" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.1.1" + [[REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets"] uuid = 
"3fa0cd96-eef1-5676-8a61-b3b8758bbffb" @@ -502,6 +520,12 @@ git-tree-sha1 = "301706196827bdcc045658fc6df3e52fd3d76f83" uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" version = "2.2.2" +[[Rmath]] +deps = ["BinaryProvider", "Libdl", "Random", "Statistics"] +git-tree-sha1 = "9825383d3453f4606d77f0a5722495f38001c09e" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.5.1" + [[Roots]] deps = ["Printf"] git-tree-sha1 = "9cc4b586c71f9aea25312b94be8c195f119b0ec3" @@ -539,9 +563,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SparseDiffTools]] deps = ["Adapt", "ArrayInterface", "DataStructures", "DiffEqDiffTools", "ForwardDiff", "LightGraphs", "LinearAlgebra", "Requires", "SparseArrays", "VertexSafeGraphs"] -git-tree-sha1 = "10537f7c6d3cfda414c7b9fb378bbd165f92735c" +git-tree-sha1 = "77083200046ca5c56a6aca9a9b6f5af240a1b419" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" -version = "0.10.1" +version = "0.10.3" [[SpecialFunctions]] deps = ["BinDeps", "BinaryProvider", "Libdl"] @@ -565,6 +589,12 @@ git-tree-sha1 = "c53e809e63fe5cf5de13632090bc3520649c9950" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" version = "0.32.0" +[[StatsFuns]] +deps = ["Rmath", "SpecialFunctions"] +git-tree-sha1 = "67745a79d8e83a83737a7e17a383c54720a97f41" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "0.9.0" + [[SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -586,15 +616,10 @@ deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[TimerOutputs]] -deps = ["Crayons", "Printf", "Test", "Unicode"] -git-tree-sha1 = "b80671c06f8f8bae08c55d67b5ce292c5ae2660c" +deps = ["Printf"] +git-tree-sha1 = "8f22dc0c23e1cd4ab8070a01ba32285926f104f1" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.0" - -[[Tokenize]] -git-tree-sha1 = "dfcdbbfb2d0370716c815cbd6f8a364efb6f42cf" -uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" -version = "0.5.6" 
+version = "0.5.2" [[TranscodingStreams]] deps = ["Random", "Test"] diff --git a/test/Project.toml b/test/Project.toml index 1f5acbb..b48205f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GenerativeModels = "6ac2c632-c4cd-11e9-0501-33c4b9b2f9c9" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" diff --git a/test/models/svae.jl b/test/models/svae.jl new file mode 100644 index 0000000..ce55332 --- /dev/null +++ b/test/models/svae.jl @@ -0,0 +1,43 @@ +@testset "models/svae.jl" begin + + Random.seed!(0) + + @testset "Vanilla SVAE" begin + T = Float32 + xlen = 4 + zlen = 3 + batch = 20 + test_data = hcat(ones(T,xlen,Int(batch/2)), -ones(T,xlen,Int(batch/2))) + + enc = GenerativeModels.ae_layer_builder([xlen, 10, 10, zlen], relu, Dense) + enc_dist = CMeanConcVMF{T}(enc, zlen) + + dec = GenerativeModels.ae_layer_builder([zlen, 10, 10, xlen], relu, Dense) + dec_dist = CMeanGaussian{T,DiagVar}(dec, NoGradArray(ones(T,xlen))) + + model = SVAE(HypersphericalUniform{T}(zlen), enc_dist, dec_dist) + + loss = elbo(model, test_data) + ps = params(model) + @test length(ps) > 0 + @test isa(loss, T) + + zs = rand(model.encoder, test_data) + @test size(zs) == (zlen, batch) + xs = rand(model.decoder, zs) + @test size(xs) == (xlen, batch) + + # test training + params_init = get_params(model) + opt = ADAM() + data = [(test_data,) for i in 1:10000] + lossf(x) = elbo(model, x, β=1e-3) + Flux.train!(lossf, params(model), data, opt) + + @test all(param_change(params_init, model)) # did the params change? 
+ zs = rand(model.encoder, test_data) + xs = mean(model.decoder, zs) + @debug maximum(test_data - xs) + # @test all(abs.(test_data - xs) .< 0.2) # is the reconstruction ok? #! it was giving me errors for SVAE it may need a different constant + end +end \ No newline at end of file diff --git a/test/pdfs/abstract_cvmf.jl b/test/pdfs/abstract_cvmf.jl new file mode 100644 index 0000000..681972c --- /dev/null +++ b/test/pdfs/abstract_cvmf.jl @@ -0,0 +1,10 @@ +@testset "pdfs/abstract_cvmf.jl" begin + struct CVMF{T<:Real} <: GenerativeModels.AbstractCVMF{T} end + + cvmf = CVMF{Float32}() + x = ones(1) + + @test_throws ErrorException mean_var(cvmf, x) + @test_throws ErrorException concentration(cvmf, x) + @test_throws ErrorException mean(cvmf, x) +end \ No newline at end of file diff --git a/test/pdfs/hs_uniform.jl b/test/pdfs/hs_uniform.jl new file mode 100644 index 0000000..94e9c24 --- /dev/null +++ b/test/pdfs/hs_uniform.jl @@ -0,0 +1,11 @@ +@testset "src/hs_uniform.jl" begin + + hu = HypersphericalUniform(3) + @test length(hu) == 3 + @test size(rand(hu, 5), 2) == 5 + + r = rand(hu, 5) + for i in 1:size(r, 2) + @test norm(r[:, i]) ≈ 1 + end +end \ No newline at end of file diff --git a/test/pdfs/vonmisesfisher.jl b/test/pdfs/vonmisesfisher.jl new file mode 100644 index 0000000..e282258 --- /dev/null +++ b/test/pdfs/vonmisesfisher.jl @@ -0,0 +1,30 @@ +@testset "src/vonmisesfisher.jl" begin + + T = Float32 + + p = VonMisesFisher(T.([1, 0, 0]), T(1)) + μ = mean(p) + κ = concentration(p) + @test mean_conc(p) == (μ, κ) + @test size(rand(p, 10)) == (3, 10) + @test size(loglikelihood(p, randn(T, 3, 10))) == (1, 10) + @test size(loglikelihood(p, randn(T, 3))) == (1, 1) + @test length(Flux.trainable(p)) == 1 + + @test eltype(loglikelihood(p, randn(T, 3, 10))) == T + @test eltype(rand(p, 10)) == T + + q = VonMisesFisher(zeros(2), ones(1)) + @test length(Flux.trainable(q)) == 2 + @test size(kld(q, HypersphericalUniform{Float64}(2))) == () + + μ = NoGradArray(zeros(2)) + p = 
VonMisesFisher(μ, ones(1)) + @test length(Flux.trainable(p)) == 1 + + p = VonMisesFisher(NoGradArray([1, 0, 0.]), NoGradArray([1.0])) + @test length(Flux.trainable(p)) == 0 + + # msg = @capture_out show(p) + # @test occursin("VonMisesFisher", msg) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 1dd45b4..e9e5fa8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,13 @@ using Test, Suppressor, Logging, Parameters, Random using BSON, DrWatson, ValueHistories using Flux, Zygote, ForwardDiff using DiffEqBase, OrdinaryDiffEq +using LinearAlgebra using Revise using GenerativeModels -if Flux.use_cuda[] using CuArrays end +# if Flux.use_cuda[] using CuArrays end +using CuArrays @warn """Remove `Flux.gpu(x) = identity(x)` from runtests.jl once CUDAdrv does not try to load CUDA anymore even though it @@ -26,8 +28,13 @@ include(joinpath("pdfs", "abstract_pdf.jl")) include(joinpath("pdfs", "gaussian.jl")) include(joinpath("pdfs", "cmean_gaussian.jl")) include(joinpath("pdfs", "cmeanvar_gaussian.jl")) +include(joinpath("pdfs", "abstract_cvmf.jl")) +include(joinpath("pdfs", "vonmisesfisher.jl")) +include(joinpath("pdfs", "hs_uniform.jl")) + include(joinpath("models", "vae.jl")) +include(joinpath("models", "svae.jl")) include(joinpath("models", "gan.jl")) include(joinpath("models", "rodent.jl")) @@ -35,3 +42,4 @@ include(joinpath("utils", "utils.jl")) include(joinpath("utils", "saveload.jl")) include(joinpath("utils", "nogradarray.jl")) include(joinpath("utils", "flux_ode_decoder.jl")) +include(joinpath("utils", "vmf.jl")) diff --git a/test/utils/vmf.jl b/test/utils/vmf.jl new file mode 100644 index 0000000..a93a8e2 --- /dev/null +++ b/test/utils/vmf.jl @@ -0,0 +1,6 @@ +@testset "utils/vmf.jl" begin + +#! Some tests are needed here + + +end