
Commit 9232329

Merge pull request #52 from JuliaAI/dev
For a 0.4.0 release
2 parents: e6b3f4b + de9a8c8 · commit 9232329

File tree: 3 files changed (+138, -63 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "MLJDecisionTreeInterface"
 uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
 authors = ["Anthony D. Blaom <[email protected]>"]
-version = "0.3.1"
+version = "0.4.0"
 
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
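For context, the bumped version can be obtained in the usual way once the release is registered (standard Pkg usage; not part of this commit):

```julia
using Pkg
Pkg.add(name="MLJDecisionTreeInterface", version="0.4.0")
```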

src/MLJDecisionTreeInterface.jl

Lines changed: 136 additions & 61 deletions
@@ -77,12 +77,27 @@ function MMI.fit(
     return fitresult, cache, report
 end
 
+# returns a dictionary of categorical elements keyed on ref integer:
 get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
 
-MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1],
-     encoding=get_encoding(fitresult[2]),
-     features=fitresult[4])
+# given such a dictionary, return printable class labels, ordered by corresponding ref
+# integer:
+classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
+
+_node_or_leaf(r::DecisionTree.Root) = r.node
+_node_or_leaf(n::Any) = n
+
+function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
+    raw_tree = fitresult[1]
+    encoding = get_encoding(fitresult[2])
+    features = fitresult[4]
+    classlabels = MLJDecisionTreeInterface.classlabels(encoding)
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (featurenames=features, classlabels),
+    )
+    (; tree, raw_tree, encoding, features)
+end
 
 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
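As a rough illustration of the `classlabels` helper introduced in this hunk (not part of the commit; the toy ref integers below stand in for the `UInt32` refs that `MMI.int` would actually produce), the function simply orders the printable labels by their internal ref:

```julia
# Illustrative sketch only: a hand-made encoding of the kind `get_encoding` returns.
encoding = Dict(0x02 => "versicolor", 0x01 => "setosa", 0x03 => "virginica")

# Same definition as in the hunk above:
classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]

classlabels(encoding)  # ["setosa", "versicolor", "virginica"]
```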
@@ -102,7 +117,7 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     min_samples_split::Int = 2::(_ ≥ 2)
     min_purity_increase::Float64 = 0.0::(_ ≥ 0)
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
-    n_trees::Int = 10::(_ ≥ 2)
+    n_trees::Int = 100::(_ ≥ 0)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
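For orientation, a minimal sketch of what this changed default means at the user level (assumes MLJ and MLJDecisionTreeInterface are installed; the `@load` pattern mirrors the docstrings further down):

```julia
using MLJ
RandomForestClassifier = @load RandomForestClassifier pkg=DecisionTree

forest = RandomForestClassifier()  # n_trees now defaults to 100 (was 10)
@assert forest.n_trees == 100
```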
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
     cache = nothing
 
     report = (features=features,)
+    fitresult = (tree, features)
 
-    return tree, cache, report
+    return fitresult, cache, report
 end
 
-MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
+function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
+    raw_tree = fitresult[1]
+    features = fitresult[2]
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features),
+    )
+    (; tree, raw_tree)
+end
 
-MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
+MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)
 
 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
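A short usage sketch of the effect of this change on the regressor (hypothetical workflow, not part of the commit; it mirrors the regressor docstring further down, where `fitted_params` now exposes both fields):

```julia
using MLJ
DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree

X, y = make_regression(100, 2)  # synthetic data
mach = machine(DecisionTreeRegressor(max_depth=3), X, y) |> fit!

fp = fitted_params(mach)
fp.raw_tree  # bare DecisionTree.jl object (was `fp.tree` before this commit)
fp.tree      # new: wrapped tree carrying feature names, printable and plottable
```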

@@ -304,7 +328,7 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     min_samples_split::Int = 2::(_ ≥ 2)
     min_purity_increase::Float64 = 0.0::(_ ≥ 0)
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
-    n_trees::Int = 10::(_ ≥ 2)
+    n_trees::Int = 100::(_ ≥ 0)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)
 
 # get actual arguments needed for importance calculation from various fitresults.
 get_fitresult(
-    m::Union{DecisionTreeClassifier, RandomForestClassifier},
+    m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
     fitresult,
 ) = (fitresult[1],)
 get_fitresult(
-    m::Union{DecisionTreeRegressor, RandomForestRegressor},
+    m::RandomForestRegressor,
     fitresult,
 ) = (fitresult,)
 get_fitresult(m::AdaBoostStumpClassifier, fitresult)= (fitresult[1], fitresult[2])
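A rough sketch of the dispatch behavior implied by this hunk (illustrative only; `get_fitresult` is an internal, unexported helper, and the placeholder values below stand in for a real tree and ensemble):

```julia
using MLJDecisionTreeInterface
const MDI = MLJDecisionTreeInterface

tree, features, forest = :tree, [:x1, :x2], :forest  # placeholders, not real objects

# DecisionTreeRegressor now shares the "tree is the first slot" convention:
@assert MDI.get_fitresult(MDI.DecisionTreeRegressor(), (tree, features)) == (tree,)

# RandomForestRegressor's fitresult is still the bare ensemble:
@assert MDI.get_fitresult(MDI.RandomForestRegressor(), forest) == (forest,)
```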
@@ -561,7 +585,7 @@ where
 Train the machine using `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.
 
 The fields of `fitted_params(mach)` are:
 
-- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+  algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+  interface; see "Examples" below
 
 - `encoding`: dictionary of target classes keyed on integers used
-  internally by DecisionTree.jl; needed to interpret pretty printing
-  of tree (obtained by calling `fit!(mach, verbosity=2)` or from
-  report - see below)
+  internally by DecisionTree.jl
 
 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
@@ -617,23 +643,28 @@ The fields of `report(mach)` are:
 
 - `classes_seen`: list of target classes actually observed in training
 
-- `print_tree`: method to print a pretty representation of the fitted
+- `print_tree`: alternative method to print the fitted
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).
 
 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
 
 # Examples
 
 ```
 using MLJ
-Tree = @load DecisionTreeClassifier pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
 
 X, y = @load_iris
-mach = machine(tree, X, y) |> fit!
+mach = machine(model, X, y) |> fit!
 
 Xnew = (sepal_length = [6.4, 7.2, 7.4],
         sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +674,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
 predict_mode(mach, Xnew) # point predictions
 pdf.(yhat, "virginica") # probabilities for the "verginica" class
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTrees.jl
-
-julia> report(mach).print_tree(3)
-Feature 4, Threshold 0.8
-L-> 1 : 50/50
-R-> Feature 4, Threshold 1.75
-    L-> Feature 3, Threshold 4.95
-        L->
-        R->
-    R-> Feature 3, Threshold 4.85
-        L->
-        R-> 3 : 43/43
-```
-
-To interpret the internal class labelling:
-
-```
-julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  0x00000003 => "virginica"
-  0x00000001 => "setosa"
-  0x00000002 => "versicolor"
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+   ├─ petal_length < 4.95
+   │  ├─ versicolor (47/48)
+   │  └─ virginica (4/6)
+   └─ petal_length < 4.85
+      ├─ virginica (2/3)
+      └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
 ```
 
-See also
-[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
-the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
+See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
+unwrapped model type
+[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
 
 """
 DecisionTreeClassifier
@@ -699,7 +723,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -744,6 +768,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
@@ -800,7 +831,7 @@ where:
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `n_iter=10`: number of iterations of AdaBoost
 
@@ -834,6 +865,15 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
+# Examples
+
 ```
 using MLJ
 Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
@@ -852,6 +892,7 @@ pdf.(yhat, "virginica") # probabilities for the "verginica" class
 
 fitted_params(mach).stumps # raw `Ensemble` object from DecisionTree.jl
 fitted_params(mach).coefs # coefficient associated with each stump
+feature_importances(mach)
 ```
 
 See also
@@ -886,7 +927,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -903,7 +944,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`
 
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -921,26 +963,50 @@ The fields of `fitted_params(mach)` are:
 - `tree`: the tree or stump object returned by the core
   DecisionTree.jl algorithm
 
+- `features`: the names of the features encountered in training
+
 
 # Report
 
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
 using MLJ
-Tree = @load DecisionTreeRegressor pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
+model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)
 
-X, y = make_regression(100, 2) # synthetic data
-mach = machine(tree, X, y) |> fit!
+X, y = make_regression(100, 4; rng=123) # synthetic data
+mach = machine(model, X, y) |> fit!
 
-Xnew, _ = make_regression(3, 2)
+Xnew, _ = make_regression(3, 2; rng=123)
 yhat = predict(mach, Xnew) # new predictions
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
+julia> fitted_params(mach).tree
+x1 < 0.2758
+├─ x2 < 0.9137
+│  ├─ x1 < -0.9582
+│  │  ├─ 0.9189256882087312 (0/12)
+│  │  └─ -0.23180616021065256 (0/38)
+│  └─ -1.6461153800037722 (0/9)
+└─ x1 < 1.062
+   ├─ x2 < -0.4969
+   │  ├─ -0.9330755147107384 (0/5)
+   │  └─ -2.3287967825015548 (0/17)
+   └─ x2 < 0.4598
+      ├─ -2.931299926506291 (0/11)
+      └─ -4.726518740473489 (0/8)
+
+feature_importances(mach) # get feature importances
 ```
 
 See also
@@ -975,24 +1041,25 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
-- `max_depth=-1`: max depth of the decision tree (-1=any)
+- `max_depth=-1`: max depth of the decision tree (-1=any)
 
-- `min_samples_leaf=1`: min number of samples each leaf needs to have
+- `min_samples_leaf=1`: min number of samples each leaf needs to have
 
-- `min_samples_split=2`: min number of samples needed for a split
+- `min_samples_split=2`: min number of samples needed for a split
 
 - `min_purity_increase=0`: min purity needed for a split
 
 - `n_subfeatures=-1`: number of features to select at random (0 for all,
   -1 for square root of number of features)
 
-- `n_trees=10`: number of trees to train
+- `n_trees=10`: number of trees to train
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -1015,6 +1082,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
@@ -1029,6 +1103,7 @@ Xnew, _ = make_regression(3, 2)
 yhat = predict(mach, Xnew) # new predictions
 
 fitted_params(mach).forest # raw `Ensemble` object from DecisionTree.jl
+feature_importances(mach)
 ```
 
 See also
