Skip to content

Commit fe67dc4

Browse files
committed
add non-autonomous flow
1 parent 742bd53 commit fe67dc4

11 files changed

Lines changed: 234 additions & 134 deletions

File tree

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1212
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
1414
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
15-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1615
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1716
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1817
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
@@ -42,7 +41,6 @@ DifferentiationInterface = "0.7"
4241
Distributions = "0.25"
4342
DistributionsAD = "0.6"
4443
FillArrays = "1"
45-
ForwardDiff = "1"
4644
LinearAlgebra = "1"
4745
Lux = "1"
4846
LuxCore = "1"

benchmark/Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
66
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
7-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8-
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
97
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
108
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
119
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -17,8 +15,6 @@ BenchmarkTools = "1"
1715
ComponentArrays = "0.15"
1816
DifferentiationInterface = "0.7"
1917
Distributions = "0.25"
20-
ForwardDiff = "1"
21-
Lux = "1"
2218
LuxCore = "1"
2319
PkgBenchmark = "0.2"
2420
StableRNGs = "1"

benchmark/benchmarks.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ import ADTypes,
33
ComponentArrays,
44
DifferentiationInterface,
55
Distributions,
6-
ForwardDiff,
7-
Lux,
86
LuxCore,
97
PkgBenchmark,
108
StableRNGs,
@@ -19,18 +17,8 @@ r = rand(rng, data_dist, ndimension, ndata)
1917
r = convert.(Float32, r)
2018

2119
nvars = size(r, 1)
22-
naugs = nvars + 1
23-
n_in = nvars + naugs
24-
25-
nn = Lux.Chain(
26-
Lux.Dense(n_in => (2 * n_in + 1), tanh),
27-
Lux.Dense((2 * n_in + 1) => n_in, tanh),
28-
)
29-
30-
icnf = ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng)
31-
32-
icnf2 =
33-
ContinuousNormalizingFlows.ICNF(; nn, nvars, naugmented = naugs, rng, inplace = true)
20+
icnf = ContinuousNormalizingFlows.ICNF(; nvars, rng)
21+
icnf2 = ContinuousNormalizingFlows.ICNF(; nvars, rng, inplace = true)
3422

3523
ps, st = LuxCore.setup(icnf.rng, icnf)
3624
ps = ComponentArrays.ComponentArray(ps)

examples/usage.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using ContinuousNormalizingFlows,
3131
# To use gpu, add related packages
3232
# using LuxCUDA
3333

34-
nn = Chain(Dense(n_in => (2 * n_in + 1), tanh), Dense((2 * n_in + 1) => n_in, tanh))
34+
nn = Chain(Dense(n_in + 1 => n_in, tanh))
3535
icnf = ICNF(;
3636
nn = nn,
3737
nvars = nvars, # number of variables
@@ -45,6 +45,7 @@ icnf = ICNF(;
4545
# device = gpu_device(), # process data by GPU
4646
cond = false, # not conditioning on auxiliary input
4747
inplace = false, # not using the inplace version of functions
48+
autonomous = false, # using non-autonomous flow
4849
compute_mode = LuxVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
4950
sol_kwargs = (;
5051
save_everystep = false,

src/ContinuousNormalizingFlows.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import ADTypes,
88
Distributions,
99
DistributionsAD,
1010
FillArrays,
11-
ForwardDiff,
1211
LinearAlgebra,
1312
Lux,
1413
LuxCore,

0 commit comments

Comments
 (0)