@@ -77,12 +77,27 @@ function MMI.fit(
     return fitresult, cache, report
 end

+# returns a dictionary of categorical elements keyed on ref integer:
 get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))

-MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1],
-     encoding=get_encoding(fitresult[2]),
-     features=fitresult[4])
+# given such a dictionary, return printable class labels, ordered by corresponding ref
+# integer:
+classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
+
+_node_or_leaf(r::DecisionTree.Root) = r.node
+_node_or_leaf(n::Any) = n
+
+function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
+    raw_tree = fitresult[1]
+    encoding = get_encoding(fitresult[2])
+    features = fitresult[4]
+    classlabels = MLJDecisionTreeInterface.classlabels(encoding)
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (featurenames=features, classlabels),
+    )
+    (; tree, raw_tree, encoding, features)
+end

 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
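Side note for reviewers: a rough sketch of what the two helpers above produce, using a toy categorical vector. This assumes MLJ (or MLJBase) is loaded so that MLJModelInterface's `int`/`classes` data front-end is active; the ref-integer values shown in the comments are indicative only.

```julia
# Rough illustration of `get_encoding` and `classlabels` (assumes MLJ is loaded;
# the exact ref integers depend on the categorical pool).
using MLJ
import MLJModelInterface as MMI
using CategoricalArrays

y = categorical(["setosa", "virginica", "setosa"])
classes_seen = y[1]            # any categorical element carries the full pool

# same construction as `get_encoding` above:
encoding = Dict(MMI.int(c) => c for c in MMI.classes(classes_seen))
# e.g. Dict(0x01 => "setosa", 0x02 => "virginica")

# same construction as `classlabels` above:
labels = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
# ["setosa", "virginica"]
```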
@@ -102,7 +117,7 @@ MMI.@mlj_model mutable struct RandomForestClassifier <: MMI.Probabilistic
     min_samples_split::Int = 2::(_ ≥ 2)
     min_purity_increase::Float64 = 0.0::(_ ≥ 0)
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
-    n_trees::Int = 10::(_ ≥ 2)
+    n_trees::Int = 100::(_ ≥ 0)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
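For context on the default change above, a hypothetical session; the warn-and-reset behavior on a violated constraint comes from MLJModelInterface's `@mlj_model` macro, not from this PR, and is sketched here as an assumption:

```julia
using MLJ

# Illustrative only: the new default and the `_ ≥ 0` constraint.
Forest = @load RandomForestClassifier pkg=DecisionTree verbosity=0

forest = Forest()
forest.n_trees               # 100 -- new default introduced here (was 10)

Forest(n_trees = -5)         # violates `_ ≥ 0`: expect a warning and a reset to the default
```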
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
     cache = nothing

     report = (features=features,)
+    fitresult = (tree, features)

-    return tree, cache, report
+    return fitresult, cache, report
 end

-MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
+function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
+    raw_tree = fitresult[1]
+    features = fitresult[2]
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features),
+    )
+    (; tree, raw_tree)
+end

-MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
+MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)

 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true

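To see what the new regressor `fitted_params` wrapping buys, here is a hedged, self-contained sketch that builds a small regression tree directly with DecisionTree.jl and wraps it the same way; the feature names are invented and display details depend on the DecisionTree.jl version:

```julia
using DecisionTree, AbstractTrees

X = rand(100, 3)
y = X[:, 1] .+ 0.1 .* randn(100)

# build a shallow regression tree (n_subfeatures=0, max_depth=3):
raw = DecisionTree.build_tree(y, X, 0, 3)

# newer DecisionTree versions return a `Root`; unwrap it as `_node_or_leaf` does above:
node = raw isa DecisionTree.Root ? raw.node : raw

wrapped = DecisionTree.wrap(node, (; featurenames = [:x1, :x2, :x3]))

# the wrapped tree implements the AbstractTrees.jl interface:
AbstractTrees.print_tree(wrapped)
```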
@@ -304,7 +328,7 @@ MMI.@mlj_model mutable struct RandomForestRegressor <: MMI.Deterministic
     min_samples_split::Int = 2::(_ ≥ 2)
     min_purity_increase::Float64 = 0.0::(_ ≥ 0)
     n_subfeatures::Int = (-)(1)::(_ ≥ -1)
-    n_trees::Int = 10::(_ ≥ 2)
+    n_trees::Int = 100::(_ ≥ 0)
     sampling_fraction::Float64 = 0.7::(0 < _ ≤ 1)
     feature_importance::Symbol = :impurity::(_ ∈ (:impurity, :split))
     rng::Union{AbstractRNG,Integer} = GLOBAL_RNG
@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)

 # get actual arguments needed for importance calculation from various fitresults.
 get_fitresult(
-    m::Union{DecisionTreeClassifier, RandomForestClassifier},
+    m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
     fitresult,
 ) = (fitresult[1],)
 get_fitresult(
-    m::Union{DecisionTreeRegressor, RandomForestRegressor},
+    m::RandomForestRegressor,
     fitresult,
 ) = (fitresult,)
 get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
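For orientation, a rough sketch (not the package's actual `feature_importances` implementation) of how a dispatch helper like `get_fitresult` can feed DecisionTree.jl's `impurity_importance`/`split_importance`; the helper name `_importances_sketch` is made up and the snippet is meant to live inside the interface module where `DecisionTree` and `get_fitresult` are in scope:

```julia
# Hypothetical helper (not from this PR): splat the extracted fitresult
# arguments into DecisionTree.jl's importance functions.
function _importances_sketch(m, fitresult, features)
    args = get_fitresult(m, fitresult)
    imp = m.feature_importance == :impurity ?
        DecisionTree.impurity_importance(args...) :
        DecisionTree.split_importance(args...)
    return [f => v for (f, v) in zip(features, imp)]
end
```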
@@ -561,7 +585,7 @@ where
 Train the machine using `fit!(mach, rows=...)`.


-# Hyper-parameters
+# Hyperparameters

 - `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.

 The fields of `fitted_params(mach)` are:

-- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+  algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+  interface; see "Examples" below

 - `encoding`: dictionary of target classes keyed on integers used
-  internally by DecisionTree.jl; needed to interpret pretty printing
-  of tree (obtained by calling `fit!(mach, verbosity=2)` or from
-  report - see below)
+  internally by DecisionTree.jl

 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
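To make the new fitted-params layout documented above concrete, a short usage sketch mirroring the docstring's own iris example (field names as introduced in this PR; binding the loaded type to `Tree` is just for brevity):

```julia
using MLJ
Tree = @load DecisionTreeClassifier pkg=DecisionTree verbosity=0
X, y = @load_iris
mach = machine(Tree(), X, y) |> fit!

fp = fitted_params(mach)
fp.raw_tree   # the DecisionTree.jl Node/Leaf/Root object, as before
fp.tree       # wrapped, AbstractTrees-compatible version (new field)
fp.encoding   # Dict mapping internal ref integers to target classes
fp.features   # feature names, e.g. [:sepal_length, :sepal_width, :petal_length, :petal_width]
```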
@@ -617,23 +643,28 @@ The fields of `report(mach)` are:

 - `classes_seen`: list of target classes actually observed in training

-- `print_tree`: method to print a pretty representation of the fitted
+- `print_tree`: alternative method to print the fitted
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).

 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)

+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)

 # Examples

 ```
 using MLJ
-Tree = @load DecisionTreeClassifier pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)

 X, y = @load_iris
-mach = machine(tree, X, y) |> fit!
+mach = machine(model, X, y) |> fit!

 Xnew = (sepal_length = [6.4, 7.2, 7.4],
         sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +674,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
 predict_mode(mach, Xnew) # point predictions
 pdf.(yhat, "virginica") # probabilities for the "verginica" class

-fitted_params(mach).tree # raw tree or stump object from DecisionTrees.jl
-
-julia> report(mach).print_tree(3)
-Feature 4, Threshold 0.8
-L-> 1 : 50/50
-R-> Feature 4, Threshold 1.75
-    L-> Feature 3, Threshold 4.95
-        L->
-        R->
-    R-> Feature 3, Threshold 4.85
-        L->
-        R-> 3 : 43/43
-```
-
-To interpret the internal class labelling:
-
-```
-julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  0x00000003 => "virginica"
-  0x00000001 => "setosa"
-  0x00000002 => "versicolor"
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+   ├─ petal_length < 4.95
+   │  ├─ versicolor (47/48)
+   │  └─ virginica (4/6)
+   └─ petal_length < 4.85
+      ├─ virginica (2/3)
+      └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
 ```

-See also
-[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
-the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
+See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
+unwrapped model type
+[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).

 """
 DecisionTreeClassifier
@@ -699,7 +723,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.


-# Hyper-parameters
+# Hyperparameters

 - `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -744,6 +768,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training


+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples

 ```
@@ -800,7 +831,7 @@ where:
 Train the machine with `fit!(mach, rows=...)`.


-# Hyper-parameters
+# Hyperparameters

 - `n_iter=10`: number of iterations of AdaBoost

@@ -834,6 +865,15 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training


+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
+# Examples
+
 ```
 using MLJ
 Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
@@ -852,6 +892,7 @@ pdf.(yhat, "virginica") # probabilities for the "verginica" class

 fitted_params(mach).stumps # raw `Ensemble` object from DecisionTree.jl
 fitted_params(mach).coefs  # coefficient associated with each stump
+feature_importances(mach)
 ```

 See also
@@ -886,7 +927,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.


-# Hyper-parameters
+# Hyperparameters

 - `max_depth=-1`: max depth of the decision tree (-1=any)

@@ -903,7 +944,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`

-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`

 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -921,26 +963,50 @@ The fields of `fitted_params(mach)` are:
 - `tree`: the tree or stump object returned by the core
   DecisionTree.jl algorithm

+- `features`: the names of the features encountered in training
+

 # Report

 - `features`: the names of the features encountered in training


+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples

 ```
 using MLJ
-Tree = @load DecisionTreeRegressor pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
+model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)

-X, y = make_regression(100, 2) # synthetic data
-mach = machine(tree, X, y) |> fit!
+X, y = make_regression(100, 4; rng=123) # synthetic data
+mach = machine(model, X, y) |> fit!

-Xnew, _ = make_regression(3, 2)
+Xnew, _ = make_regression(3, 2; rng=123)
 yhat = predict(mach, Xnew) # new predictions

-fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
+julia> fitted_params(mach).tree
+x1 < 0.2758
+├─ x2 < 0.9137
+│  ├─ x1 < -0.9582
+│  │  ├─ 0.9189256882087312 (0/12)
+│  │  └─ -0.23180616021065256 (0/38)
+│  └─ -1.6461153800037722 (0/9)
+└─ x1 < 1.062
+   ├─ x2 < -0.4969
+   │  ├─ -0.9330755147107384 (0/5)
+   │  └─ -2.3287967825015548 (0/17)
+   └─ x2 < 0.4598
+      ├─ -2.931299926506291 (0/11)
+      └─ -4.726518740473489 (0/8)
+
+feature_importances(mach) # get feature importances
 ```

 See also
@@ -975,24 +1041,25 @@ where
 Train the machine with `fit!(mach, rows=...)`.


-# Hyper-parameters
+# Hyperparameters

-- `max_depth=-1`: max depth of the decision tree (-1=any)
+- `max_depth=-1`: max depth of the decision tree (-1=any)

-- `min_samples_leaf=1`: min number of samples each leaf needs to have
+- `min_samples_leaf=1`: min number of samples each leaf needs to have

-- `min_samples_split=2`: min number of samples needed for a split
+- `min_samples_split=2`: min number of samples needed for a split

 - `min_purity_increase=0`: min purity needed for a split

 - `n_subfeatures=-1`: number of features to select at random (0 for all,
   -1 for square root of number of features)

-- `n_trees=10`: number of trees to train
+- `n_trees=100`: number of trees to train

 - `sampling_fraction=0.7` fraction of samples to train each tree on

-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`

 - `rng=Random.GLOBAL_RNG`: random number generator or seed

@@ -1015,6 +1082,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training


+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples

 ```
@@ -1029,6 +1103,7 @@ Xnew, _ = make_regression(3, 2)
 yhat = predict(mach, Xnew) # new predictions

 fitted_params(mach).forest # raw `Ensemble` object from DecisionTree.jl
+feature_importances(mach)
 ```

 See also