using Pkg # hideall
Pkg.activate("_literate/A-ensembles-2/Project.toml")
Pkg.instantiate()
macro OUTPUT()
return isdefined(Main, :Franklin) ? Franklin.OUT_PATH[] : "/tmp/"
end;
# @@dropdown
# ## Prelims
# @@
# @@dropdown-content
#
# This tutorial builds upon the previous ensemble tutorial, this time constructing a home-made Random Forest regressor and applying it to the "boston" dataset.
#
using MLJ
using PrettyPrinting
using StableRNGs
import DataFrames: DataFrame, describe
MLJ.color_off() # hide
X, y = @load_boston
sch = schema(X)
p = length(sch.names)
describe(y) # From DataFrames
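# To get a feel for the features themselves, we can (purely for illustration)
# wrap `X` in a `DataFrame` and peek at the first few rows:
first(DataFrame(X), 3)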
# Let's load the decision tree regressor
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
# Let's first check the performance of a single Decision Tree Regressor (DTR for short):
tree = machine(DecisionTreeRegressor(), X, y)
e = evaluate!(tree, resampling=Holdout(fraction_train=0.8),
measure=[rms, rmslp1])
e
# Note that multiple measures can be reported simultaneously.
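# For instance, the aggregated scores can also be retrieved programmatically
# (this assumes the standard `PerformanceEvaluation` fields returned by `evaluate!`):
e.measurement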
#
# @@
# @@dropdown
# ## Random forest
# @@
# @@dropdown-content
#
# Let's create an ensemble of DTRs and fix the number of subfeatures to 3 for now.
forest = EnsembleModel(model=DecisionTreeRegressor())
forest.model.n_subfeatures = 3
# (**NB**: we could have fixed `n_subfeatures` in the DTR constructor too).
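# For illustration only, the following would have been equivalent (the name
# `forest_alt` is just a throwaway and is not used further below):
forest_alt = EnsembleModel(model=DecisionTreeRegressor(n_subfeatures=3))
forest_alt.model.n_subfeatures == forest.model.n_subfeatures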
#
# To get an idea of how many trees are needed, we can follow the evolution of the error (say the `rms`) as the number of trees in the ensemble increases.
rng = StableRNG(5123) # for reproducibility
m = machine(forest, X, y)
r = range(forest, :n, lower=10, upper=1000)
curves = learning_curve(m, resampling=Holdout(fraction_train=0.8, rng=rng),
range=r, measure=rms);
# Let's plot the curve:
using Plots
Plots.scalefontsizes() #hide
Plots.scalefontsizes(1.2) #hide
plot(curves.parameter_values, curves.measurements,
xticks = [10, 100, 250, 500, 750, 1000],
size=(800,600), linewidth=2, legend=false)
xlabel!("Number of trees")
ylabel!("Root Mean Squared error")
savefig(joinpath(@OUTPUT, "A-ensembles-2-curves.svg")); # hide
# \figalt{RMS vs number of trees}{A-ensembles-2-curves.svg}
#
# Let's go for 150 trees
forest.n = 150;
# @@dropdown
# ### Tuning
# @@
# @@dropdown-content
#
# As `forest` is a composite model, it has nested hyperparameters:
params(forest) |> pprint
# Let's define a range for the number of subfeatures and for the bagging fraction:
r_sf = range(forest, :(model.n_subfeatures), lower=1, upper=12)
r_bf = range(forest, :bagging_fraction, lower=0.4, upper=1.0);
# And build a tuned model, as usual, that we fit on an 80/20 split.
# We use a low-resolution grid here to keep this tutorial fast, but you could of course use a finer grid.
tuned_forest = TunedModel(model=forest,
tuning=Grid(resolution=3),
resampling=CV(nfolds=6, rng=StableRNG(32)),
ranges=[r_sf, r_bf],
measure=rms)
m = machine(tuned_forest, X, y)
e = evaluate!(m, resampling=Holdout(fraction_train=0.8),
measure=[rms, rmslp1])
e
#
# @@
# @@dropdown
# ### Reporting
# @@
# @@dropdown-content
# Again, we can visualize the results from the hyperparameter search
plot(m)
savefig(joinpath(@OUTPUT, "A-ensembles-2-heatmap.svg")); # hide
# \fig{A-ensembles-2-heatmap.svg}
#
# Even though we've only done a very rough search, it seems that around 7 subfeatures and a bagging fraction of around `0.75` work well.
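# We can cross-check this by inspecting the best model found by the search
# (this assumes the usual `fitted_params` accessor exposed by `TunedModel` machines):
best = fitted_params(m).best_model
@show best.model.n_subfeatures
@show best.bagging_fraction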
#
# Now that the machine `m` is trained, you can use it for predictions (implicitly, this will use the best model).
# For instance, we could look at predictions on the whole dataset:
ŷ = predict(m, X)
@show rms(ŷ, y)
#
# @@
#
# @@