-
Notifications
You must be signed in to change notification settings - Fork 5
more NF examples #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
more NF examples #11
Changes from all commits
4f591af
4e31807
99ac0c1
e4f5efa
b345e78
f10512f
b93dbe8
1c1c88a
5d2844f
3fdcb0e
ddad59e
ef60ee1
5a5deb0
25d4211
df5eddd
d44143b
b03e922
e9acf70
c6bf68b
fabf20a
7964f91
01de9d4
6993d33
810881f
0277c9e
4bc02bf
144c668
2678198
ab7ac64
308ab61
ceafbde
5589f36
de01e0e
93a8572
b8a8f5f
c858b10
baef6a9
0323680
144b06a
f70a4b7
30cfc32
00387d3
2cacfc6
24b1564
f720eab
880976c
f6e4189
0a78963
7c953c2
e41c3f6
91cf585
c523f94
1f67b10
6cde9cc
eada0e6
8255949
37bd3ab
a67a6e2
d6c6a2b
799343a
b908a80
8bde6bf
eaf1775
519fd66
77816fd
8db13c9
4a656e8
9f68fdd
8c703d5
9e27814
c63abfc
61862ff
4fe363f
a9d46bf
8bd27ee
156faeb
7e7e782
8be77ad
0381259
37835c8
2f15b1b
2ae4f27
3e453e5
ed088eb
7db8236
167af8e
d491fee
1aaad46
fa80afb
3ffe162
e25f31b
4e78a5a
7704896
b4b4aaf
9d4f6ba
b7d7b67
3b148d7
ae35d63
ae9b9f2
deaf04a
2dd855a
b30d961
aeda365
2c4d793
6056ef1
960a359
51bc5b9
f5a2f9a
d916b64
d84fce4
6872760
e3fb428
ddf0ba9
f2affe2
78d5f45
2229577
aefc651
7f1580c
750c1e1
ab48d11
42bc4a1
edf1bac
c05789b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,9 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | |
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (also this) |
||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,16 +2,22 @@ | |
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this might not be required anymore? |
||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b" | ||
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" | ||
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" | ||
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" | ||
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256" | ||
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" | ||
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" | ||
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SimpleUnPack can be replaced with destructing syntax (like |
||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
|
||
[extras] | ||
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
using DocStringExtensions | ||
using Distributions, Random, LinearAlgebra | ||
using IrrationalConstants | ||
using Plots | ||
|
||
|
||
include("targets/banana.jl") | ||
include("targets/cross.jl") | ||
include("targets/neal_funnel.jl") | ||
include("targets/warped_gaussian.jl") | ||
|
||
|
||
function load_model(name::String) | ||
if name == "Banana" | ||
return Banana(2, 1.0, 10.0) | ||
elseif name == "Cross" | ||
return Cross() | ||
elseif name == "Funnel" | ||
return Funnel(2) | ||
elseif name == "WarpedGaussian" | ||
return WarpedGauss() | ||
else | ||
error("Model not defined") | ||
end | ||
end | ||
Comment on lines
+13
to
+25
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is slightly ad hoc (it required a jump to know what this function will return at the call site), maybe use model constructor directly? |
||
|
||
function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000)) | ||
xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100) | ||
yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100) | ||
z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange] | ||
fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2) | ||
scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright) | ||
return fig | ||
end |
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
using Flux | ||
using Bijectors | ||
using Bijectors: partition, combine, PartitionMask | ||
|
||
using Random, Distributions, LinearAlgebra | ||
using Functors | ||
using Optimisers, ADTypes | ||
using Mooncake | ||
using NormalizingFlows | ||
|
||
include("SyntheticTargets.jl") | ||
include("utils.jl") | ||
|
||
################################## | ||
# define affine coupling layer using Bijectors.jl interface | ||
################################# | ||
struct AffineCoupling <: Bijectors.Bijector | ||
dim::Int | ||
mask::Bijectors.PartitionMask | ||
s::Flux.Chain | ||
t::Flux.Chain | ||
end | ||
|
||
# let params track field s and t | ||
@functor AffineCoupling (s, t) | ||
|
||
function AffineCoupling( | ||
dim::Int, # dimension of input | ||
hdims::Int, # dimension of hidden units for s and t | ||
mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on | ||
) | ||
cdims = length(mask_idx) # dimension of parts used to construct coupling law | ||
s = mlp3(cdims, hdims, cdims) | ||
t = mlp3(cdims, hdims, cdims) | ||
mask = PartitionMask(dim, mask_idx) | ||
return AffineCoupling(dim, mask, s, t) | ||
end | ||
|
||
function Bijectors.transform(af::AffineCoupling, x::AbstractVector) | ||
# partition vector using 'af.mask::PartitionMask` | ||
x₁, x₂, x₃ = partition(af.mask, x) | ||
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂) | ||
return combine(af.mask, y₁, x₂, x₃) | ||
end | ||
|
||
function (af::AffineCoupling)(x::AbstractArray) | ||
return transform(af, x) | ||
end | ||
|
||
function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector) | ||
x_1, x_2, x_3 = Bijectors.partition(af.mask, x) | ||
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2) | ||
logjac = sum(log ∘ abs, af.s(x_2)) | ||
return combine(af.mask, y_1, x_2, x_3), logjac | ||
end | ||
|
||
function Bijectors.with_logabsdet_jacobian( | ||
iaf::Inverse{<:AffineCoupling}, y::AbstractVector | ||
) | ||
af = iaf.orig | ||
# partition vector using `af.mask::PartitionMask` | ||
y_1, y_2, y_3 = partition(af.mask, y) | ||
# inverse transformation | ||
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2) | ||
logjac = -sum(log ∘ abs, af.s(y_2)) | ||
return combine(af.mask, x_1, y_2, y_3), logjac | ||
end | ||
|
||
function Bijectors.logabsdetjac(af::AffineCoupling, x::AbstractVector) | ||
_, x_2, _ = partition(af.mask, x) | ||
logjac = sum(log ∘ abs, af.s(x_2)) | ||
return logjac | ||
end | ||
|
||
################### | ||
# an equivalent definition of AffineCoupling using Bijectors.Coupling | ||
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1) | ||
################### | ||
|
||
# struct AffineCoupling <: Bijectors.Bijector | ||
# dim::Int | ||
# mask::Bijectors.PartitionMask | ||
# s::Flux.Chain | ||
# t::Flux.Chain | ||
# end | ||
|
||
# # let params track field s and t | ||
# @functor AffineCoupling (s, t) | ||
|
||
# function AffineCoupling(dim, mask, s, t) | ||
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask) | ||
# end | ||
|
||
# function AffineCoupling( | ||
# dim::Int, # dimension of input | ||
# hdims::Int, # dimension of hidden units for s and t | ||
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on | ||
# ) | ||
# cdims = length(mask_idx) # dimension of parts used to construct coupling law | ||
# s = mlp3(cdims, hdims, cdims) | ||
# t = mlp3(cdims, hdims, cdims) | ||
# mask = PartitionMask(dim, mask_idx) | ||
# return AffineCoupling(dim, mask, s, t) | ||
# end | ||
|
||
|
||
|
||
################################## | ||
# start demo | ||
################################# | ||
Random.seed!(123) | ||
rng = Random.default_rng() | ||
T = Float32 | ||
|
||
###################################### | ||
# a difficult banana target | ||
###################################### | ||
target = Banana(2, 1.0f0, 100.0f0) | ||
logp = Base.Fix1(logpdf, target) | ||
|
||
###################################### | ||
# learn the target using Affine coupling flow | ||
###################################### | ||
@leaf MvNormal | ||
q0 = MvNormal(zeros(T, 2), ones(T, 2)) | ||
|
||
d = 2 | ||
hdims = 32 | ||
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3] | ||
|
||
flow = create_flow(Ls, q0) | ||
flow_untrained = deepcopy(flow) | ||
|
||
|
||
###################################### | ||
# start training | ||
###################################### | ||
sample_per_iter = 64 | ||
|
||
# callback function to log training progress | ||
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype) | ||
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config()) | ||
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000 | ||
flow_trained, stats, _ = train_flow( | ||
elbo, | ||
flow, | ||
logp, | ||
sample_per_iter; | ||
max_iters=50_000, | ||
optimiser=Optimisers.Adam(5e-4), | ||
ADbackend=adtype, | ||
show_progress=true, | ||
callback=cb, | ||
hasconverged=checkconv, | ||
) | ||
θ, re = Optimisers.destructure(flow_trained) | ||
losses = map(x -> x.loss, stats) | ||
|
||
###################################### | ||
# evaluate trained flow | ||
###################################### | ||
plot(losses; label="Loss", linewidth=2) # plot the loss | ||
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason for adding these dependencies?