Skip to content

more NF examples #11

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 123 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
123 commits
Select commit Hold shift + click to select a range
4f591af
add support for hasconverged
zuhengxu Jul 11, 2023
4e31807
fix test error
zuhengxu Jul 11, 2023
99ac0c1
rm example/Manifest.toml
zuhengxu Jul 11, 2023
e4f5efa
minor bug fix for trainig loop
zuhengxu Jul 11, 2023
b345e78
test new stopping criterion
zuhengxu Jul 11, 2023
f10512f
test convergent condition/ rm unready examples
zuhengxu Jul 11, 2023
b93dbe8
Merge branch 'TuringLang:main' into hasconverge
zuhengxu Jul 11, 2023
1c1c88a
rm julia test from CI
zuhengxu Jul 11, 2023
5d2844f
Revert "rm julia test from CI"
zuhengxu Jul 11, 2023
3fdcb0e
make autodiff pkgs as extension + require for bwd compat
zuhengxu Jul 12, 2023
ddad59e
debugging Ext
zuhengxu Jul 12, 2023
ef60ee1
keep debugging
zuhengxu Jul 13, 2023
5a5deb0
Fix AD package extension loading issues
sunxd3 Jul 13, 2023
25d4211
Applying @devmotion's comment
sunxd3 Jul 13, 2023
df5eddd
patch last commit
sunxd3 Jul 13, 2023
d44143b
patch for julia 1.6
sunxd3 Jul 13, 2023
b03e922
loading dep pkgs from main pkg instead of functions for explicitness
zuhengxu Jul 13, 2023
e9acf70
fixing test err
zuhengxu Jul 13, 2023
c6bf68b
rm unready examples
zuhengxu Jul 13, 2023
fabf20a
update realnvp
zuhengxu Jul 24, 2023
7964f91
minor ed
zuhengxu Jul 24, 2023
01de9d4
removing unnecessary import
zuhengxu Jul 24, 2023
6993d33
refactor affinecoupling and example/
zuhengxu Jul 31, 2023
810881f
debug affinecoupling flow
zuhengxu Jul 31, 2023
0277c9e
adapt to the updated autoforwarddiff to resolve test err
zuhengxu Jul 31, 2023
4bc02bf
fix test err
zuhengxu Jul 31, 2023
144c668
add new implementation of affcoupling using Bijectors.Coupling
zuhengxu Jul 31, 2023
2678198
implement ham flow
zuhengxu Aug 1, 2023
ab7ac64
finish hamflow implementation
zuhengxu Aug 1, 2023
308ab61
minor update
zuhengxu Aug 1, 2023
ceafbde
rename hamflow.jl to hamiltonian_layer.jl
zuhengxu Aug 3, 2023
5589f36
upadting readme
zuhengxu Aug 3, 2023
de01e0e
rm hamflow.jl
zuhengxu Aug 3, 2023
93a8572
Merge branch 'main' of github.com:zuhengxu/NormalizingFlows.jl into m…
zuhengxu Aug 8, 2023
b8a8f5f
sync with main
zuhengxu Aug 9, 2023
c858b10
fix minor bugs in affine coupling layer
zuhengxu Aug 16, 2023
baef6a9
test affine coupling flow on banana
zuhengxu Aug 16, 2023
0323680
rename simple flow run files
zuhengxu Aug 16, 2023
144b06a
update loglikelihood to fit in optimize interface
zuhengxu Aug 16, 2023
f70a4b7
fix minor bugs in nsf_layer
zuhengxu Aug 17, 2023
30cfc32
rm unused data in nsf lfow
zuhengxu Aug 17, 2023
00387d3
rm @view to avoid zygote mutation error
zuhengxu Aug 17, 2023
2cacfc6
update optimize for direct minibatch training
zuhengxu Sep 28, 2023
24b1564
add BatchNorm for real nvp and make real nvp compatible with batching
zuhengxu Sep 28, 2023
f720eab
testing
zuhengxu Sep 29, 2023
880976c
update resblock
zuhengxu Oct 4, 2023
f6e4189
update resnet arch
zuhengxu Oct 5, 2023
0a78963
minor update
zuhengxu Oct 5, 2023
7c953c2
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu Oct 5, 2023
e41c3f6
udpate
zuhengxu Oct 5, 2023
91cf585
fix affine bugs
zuhengxu Oct 5, 2023
c523f94
add invertiblemlp
zuhengxu Oct 5, 2023
1f67b10
start testing invertiblenetworks
zuhengxu Oct 5, 2023
6cde9cc
minor update
zuhengxu Oct 5, 2023
eada0e6
minor bug fix
zuhengxu Oct 5, 2023
8255949
update MLP to make it work
zuhengxu Oct 9, 2023
37bd3ab
fix type instability in invertible MLP
zuhengxu Oct 9, 2023
a67a6e2
update deep mlp
zuhengxu Oct 10, 2023
d6c6a2b
update deep hamflow result
zuhengxu Oct 10, 2023
799343a
update flows
zuhengxu Oct 10, 2023
b908a80
add shadowing window computation
zuhengxu Oct 11, 2023
8bde6bf
minor ed
zuhengxu Oct 11, 2023
eaf1775
refactoring training and setup files
zuhengxu Oct 11, 2023
519fd66
minor setup change to allow convenient precision switch
zuhengxu Oct 11, 2023
77816fd
get MLP figs and res'
zuhengxu Oct 12, 2023
8db13c9
obtain figs for ham
zuhengxu Oct 12, 2023
4a656e8
rm duplication
zuhengxu Oct 12, 2023
9f68fdd
add some MLP result
zuhengxu Oct 12, 2023
8c703d5
add shf res
zuhengxu Oct 12, 2023
9e27814
add mom normalization layer implementation
zuhengxu Oct 13, 2023
c63abfc
update shf res
zuhengxu Oct 13, 2023
61862ff
add script for deep shf
zuhengxu Oct 13, 2023
4fe363f
update shf_big res
zuhengxu Oct 13, 2023
a9d46bf
minor update
zuhengxu Oct 13, 2023
8bd27ee
update log reg and new shf res with q0 being std normal
zuhengxu Oct 16, 2023
156faeb
minor bug update
zuhengxu Oct 16, 2023
7e7e782
update sfh banana figs
zuhengxu Oct 17, 2023
8be77ad
add some lip constant figs for shf banana
zuhengxu Oct 17, 2023
0381259
add some lip constant reg
zuhengxu Oct 17, 2023
37835c8
update shf banana/cross
zuhengxu Oct 18, 2023
2f15b1b
update shf/stab
zuhengxu Oct 18, 2023
2ae4f27
update banana plots
zuhengxu Oct 18, 2023
3e453e5
update banana res
zuhengxu Oct 18, 2023
ed088eb
fix conflict
zuhengxu Oct 18, 2023
7db8236
update cross res
zuhengxu Oct 18, 2023
167af8e
update cross res
zuhengxu Oct 18, 2023
d491fee
update
zuhengxu Oct 18, 2023
1aaad46
update cross
zuhengxu Oct 18, 2023
fa80afb
update
zuhengxu Oct 18, 2023
3ffe162
Merge branch 'stability' of github.com:zuhengxu/NormalizingFlows.jl i…
zuhengxu Oct 18, 2023
e25f31b
update cross elbo res
zuhengxu Oct 18, 2023
4e78a5a
update shadowing res
zuhengxu Oct 18, 2023
7704896
udpate banana and cross res
zuhengxu Oct 18, 2023
b4b4aaf
update banana figs
zuhengxu Oct 18, 2023
9d4f6ba
update cross res
zuhengxu Oct 19, 2023
b7d7b67
update cross figs and res
zuhengxu Oct 19, 2023
3b148d7
update cross lp scaling figs
zuhengxu Oct 20, 2023
ae35d63
minor update
zuhengxu Oct 20, 2023
ae9b9f2
new cross figs
zuhengxu Oct 20, 2023
deaf04a
fix conflict
zuhengxu Oct 20, 2023
2dd855a
rm readme
zuhengxu Jan 9, 2025
b30d961
update
zuhengxu Mar 16, 2025
aeda365
Merge branch 'stability' into more_examples
zuhengxu Mar 16, 2025
2c4d793
rm many unrelated code
zuhengxu Mar 16, 2025
6056ef1
merge conflict from upstream
zuhengxu Mar 16, 2025
960a359
keep cleaning
zuhengxu Mar 17, 2025
51bc5b9
merge from turing/main
zuhengxu Apr 9, 2025
f5a2f9a
fix warpgaussian logpdf error/making neal funnel logpdf working with …
zuhengxu Apr 9, 2025
d916b64
add easier model loading code
zuhengxu Apr 9, 2025
d84fce4
restructure example folder and refactor planar and radial examples
zuhengxu Apr 10, 2025
6872760
clean demos for realnvp/planar/radial/ fix a bug in nsf
zuhengxu Apr 10, 2025
e3fb428
rm enzyme dependency
zuhengxu Apr 10, 2025
ddf0ba9
tune nsf flow--enlarge B--to make it work
zuhengxu Apr 10, 2025
f2affe2
rm some useless files from HamVI
zuhengxu Apr 10, 2025
78d5f45
rename MLP_3layer to mlp3 for convenience
zuhengxu Apr 10, 2025
2229577
rename common.jl to utils
zuhengxu Apr 10, 2025
aefc651
minor update file naming, and better nsf implementation
zuhengxu Apr 11, 2025
7f1580c
minor ed
zuhengxu Apr 11, 2025
750c1e1
rm redundant nsf file
zuhengxu Apr 11, 2025
ab48d11
shrink stepsize for nsf
zuhengxu Apr 11, 2025
42bc4a1
cleaned hamiltonian flow
zuhengxu Apr 12, 2025
edf1bac
minor ed
zuhengxu Apr 12, 2025
c05789b
rm redundant pkgs
zuhengxu Apr 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/docs/build/
test/Manifest.toml
example/Manifest.toml
example/LocalPreferences.toml

# Files generated by invoking Julia with --code-coverage
*.jl.cov
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for adding these dependencies?

LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(also this)

Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
14 changes: 10 additions & 4 deletions example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,22 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might not be required anymore?

Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FunctionChains = "8e6b2b91-af83-483e-ba35-d00930e4cf9b"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
6 changes: 3 additions & 3 deletions example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ normalizing flow to approximate the target distribution using `NormalizingFlows.
Currently, all examples share the same [Julia project](https://pkgdocs.julialang.org/v1/environments/#Using-someone-else's-project). To run the examples, first activate the project environment:

```julia
# pwd() = "NormalizingFlows.jl/"
using Pkg; Pkg.activate("example"); Pkg.instantiate()
# pwd() = "NormalizingFlows.jl/example"
using Pkg; Pkg.activate("."); Pkg.instantiate()
```
This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with include("<example-to-run>.jl"), or by running the example script line-by-line.
This will install all needed packages, at the exact versions when the model was last updated. Then you can run the model code with `include("<example-to-run>.jl")`, or by running the example script line-by-line.
34 changes: 34 additions & 0 deletions example/SyntheticTargets.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using DocStringExtensions
using Distributions, Random, LinearAlgebra
using IrrationalConstants
using Plots


include("targets/banana.jl")
include("targets/cross.jl")
include("targets/neal_funnel.jl")
include("targets/warped_gaussian.jl")


function load_model(name::String)
if name == "Banana"
return Banana(2, 1.0, 10.0)
elseif name == "Cross"
return Cross()
elseif name == "Funnel"
return Funnel(2)
elseif name == "WarpedGaussian"
return WarpedGauss()
else
error("Model not defined")
end
end
Comment on lines +13 to +25
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is slightly ad hoc (it required a jump to know what this function will return at the call site), maybe use model constructor directly?


function visualize(p::ContinuousMultivariateDistribution, samples=rand(p, 1000))
xrange = range(minimum(samples[1, :]) - 1, maximum(samples[1, :]) + 1; length=100)
yrange = range(minimum(samples[2, :]) - 1, maximum(samples[2, :]) + 1; length=100)
z = [exp(Distributions.logpdf(p, [x, y])) for x in xrange, y in yrange]
fig = contour(xrange, yrange, z'; levels=15, color=:viridis, label="PDF", linewidth=2)
scatter!(samples[1, :], samples[2, :]; label="Samples", alpha=0.3, legend=:bottomright)
return fig
end
53 changes: 0 additions & 53 deletions example/common.jl

This file was deleted.

163 changes: 163 additions & 0 deletions example/demo_RealNVP.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
using Flux
using Bijectors
using Bijectors: partition, combine, PartitionMask

using Random, Distributions, LinearAlgebra
using Functors
using Optimisers, ADTypes
using Mooncake
using NormalizingFlows

include("SyntheticTargets.jl")
include("utils.jl")

##################################
# define affine coupling layer using Bijectors.jl interface
#################################
struct AffineCoupling <: Bijectors.Bijector
dim::Int
mask::Bijectors.PartitionMask
s::Flux.Chain
t::Flux.Chain
end

# let params track field s and t
@functor AffineCoupling (s, t)

function AffineCoupling(
dim::Int, # dimension of input
hdims::Int, # dimension of hidden units for s and t
mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
)
cdims = length(mask_idx) # dimension of parts used to construct coupling law
s = mlp3(cdims, hdims, cdims)
t = mlp3(cdims, hdims, cdims)
mask = PartitionMask(dim, mask_idx)
return AffineCoupling(dim, mask, s, t)
end

function Bijectors.transform(af::AffineCoupling, x::AbstractVector)
# partition vector using 'af.mask::PartitionMask`
x₁, x₂, x₃ = partition(af.mask, x)
y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)
return combine(af.mask, y₁, x₂, x₃)
end

function (af::AffineCoupling)(x::AbstractArray)
return transform(af, x)
end

function Bijectors.with_logabsdet_jacobian(af::AffineCoupling, x::AbstractVector)
x_1, x_2, x_3 = Bijectors.partition(af.mask, x)
y_1 = af.s(x_2) .* x_1 .+ af.t(x_2)
logjac = sum(log ∘ abs, af.s(x_2))
return combine(af.mask, y_1, x_2, x_3), logjac
end

function Bijectors.with_logabsdet_jacobian(
iaf::Inverse{<:AffineCoupling}, y::AbstractVector
)
af = iaf.orig
# partition vector using `af.mask::PartitionMask`
y_1, y_2, y_3 = partition(af.mask, y)
# inverse transformation
x_1 = (y_1 .- af.t(y_2)) ./ af.s(y_2)
logjac = -sum(log ∘ abs, af.s(y_2))
return combine(af.mask, x_1, y_2, y_3), logjac
end

function Bijectors.logabsdetjac(af::AffineCoupling, x::AbstractVector)
_, x_2, _ = partition(af.mask, x)
logjac = sum(log ∘ abs, af.s(x_2))
return logjac
end

###################
# an equivalent definition of AffineCoupling using Bijectors.Coupling
# (see https://github.com/TuringLang/Bijectors.jl/blob/74d52d4eda72a6149b1a89b72524545525419b3f/src/bijectors/coupling.jl#L188C1-L188C1)
###################

# struct AffineCoupling <: Bijectors.Bijector
# dim::Int
# mask::Bijectors.PartitionMask
# s::Flux.Chain
# t::Flux.Chain
# end

# # let params track field s and t
# @functor AffineCoupling (s, t)

# function AffineCoupling(dim, mask, s, t)
# return Bijectors.Coupling(θ -> Bijectors.Shift(t(θ)) ∘ Bijectors.Scale(s(θ)), mask)
# end

# function AffineCoupling(
# dim::Int, # dimension of input
# hdims::Int, # dimension of hidden units for s and t
# mask_idx::AbstractVector, # index of dimensione that one wants to apply transformations on
# )
# cdims = length(mask_idx) # dimension of parts used to construct coupling law
# s = mlp3(cdims, hdims, cdims)
# t = mlp3(cdims, hdims, cdims)
# mask = PartitionMask(dim, mask_idx)
# return AffineCoupling(dim, mask, s, t)
# end



##################################
# start demo
#################################
Random.seed!(123)
rng = Random.default_rng()
T = Float32

######################################
# a difficult banana target
######################################
target = Banana(2, 1.0f0, 100.0f0)
logp = Base.Fix1(logpdf, target)

######################################
# learn the target using Affine coupling flow
######################################
@leaf MvNormal
q0 = MvNormal(zeros(T, 2), ones(T, 2))

d = 2
hdims = 32
Ls = [AffineCoupling(d, hdims, [1]) ∘ AffineCoupling(d, hdims, [2]) for i in 1:3]

flow = create_flow(Ls, q0)
flow_untrained = deepcopy(flow)


######################################
# start training
######################################
sample_per_iter = 64

# callback function to log training progress
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
flow_trained, stats, _ = train_flow(
elbo,
flow,
logp,
sample_per_iter;
max_iters=50_000,
optimiser=Optimisers.Adam(5e-4),
ADbackend=adtype,
show_progress=true,
callback=cb,
hasconverged=checkconv,
)
θ, re = Optimisers.destructure(flow_trained)
losses = map(x -> x.loss, stats)

######################################
# evaluate trained flow
######################################
plot(losses; label="Loss", linewidth=2) # plot the loss
compare_trained_and_untrained_flow(flow_trained, flow_untrained, target, 1000)
Loading