-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathcommon.jl
53 lines (48 loc) · 1.27 KB
/
common.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
using Random, Distributions, LinearAlgebra, Bijectors
"""
    compare_trained_and_untrained_flow(flow_trained, flow_untrained, true_dist, n_samples; kwargs...)

Draw `n_samples` points from each of `true_dist`, `flow_untrained` and `flow_trained`.
Overlay their first two coordinates in a single scatter plot and return the plot object.

Series are drawn in the order true (blue), untrained (red), trained (green). Later series
are therefore drawn on top of earlier ones.

# Arguments
- `flow_trained`, `flow_untrained`: flows (`Bijectors.MultivariateTransformed`) to sample from.
- `true_dist`: target distribution to compare against.
- `n_samples`: number of points drawn from each distribution.

# Keywords
Any `kwargs` are forwarded to `plot!` on the returned plot. This happens *after* the default
axis labels ("X", "Y") and title are set, so callers can override them
(e.g. `title="..."`, `xlims=(-5, 5)`).

# Throws
- `ArgumentError` if any of the drawn sample matrices has fewer than 2 rows (dimensions).

NOTE(review): `scatter`, `scatter!`, `plot!`, `xlabel!`, `ylabel!` and `title!` are not
imported by this file. The caller must have a plotting package loaded (presumably Plots.jl).
"""
function compare_trained_and_untrained_flow(
    flow_trained::Bijectors.MultivariateTransformed,
    flow_untrained::Bijectors.MultivariateTransformed,
    true_dist::ContinuousMultivariateDistribution,
    n_samples::Int;
    kwargs...,
)
    samples_trained = rand(flow_trained, n_samples)
    samples_untrained = rand(flow_untrained, n_samples)
    samples_true = rand(true_dist, n_samples)

    # Only the first two rows are plotted. Fail with a clear message instead of a
    # BoundsError when a distribution is one-dimensional.
    for samples in (samples_true, samples_untrained, samples_trained)
        size(samples, 1) >= 2 || throw(
            ArgumentError(
                "need samples with at least 2 dimensions to plot, got $(size(samples, 1))",
            ),
        )
    end

    # Shared marker style for all three series.
    marker_style = (markersize=2, alpha=0.5)

    p = scatter(
        samples_true[1, :],
        samples_true[2, :];
        label="True Distribution",
        color=:blue,
        marker_style...,
    )
    scatter!(
        p,
        samples_untrained[1, :],
        samples_untrained[2, :];
        label="Untrained Flow",
        color=:red,
        marker_style...,
    )
    scatter!(
        p,
        samples_trained[1, :],
        samples_trained[2, :];
        label="Trained Flow",
        color=:green,
        marker_style...,
    )

    xlabel!(p, "X")
    ylabel!(p, "Y")
    title!(p, "Comparison of Trained and Untrained Flow")
    # Apply caller attributes last (and to `p` explicitly, not the implicit current
    # plot) so they take precedence over the defaults set above.
    plot!(p; kwargs...)
    return p
end
"""
    create_flow(Ls, q₀)

Build a flow distribution and return it.

The layers `Ls` are combined into a single transformation with `fchain`. That
transformation is then attached to the base distribution `q₀` via `transformed`, and the
result is returned.

NOTE(review): `fchain` is not defined in this file. Presumably it composes the layers in
`Ls` into one transform; confirm against its definition.
"""
create_flow(Ls, q₀) = transformed(q₀, fchain(Ls))