Commit 8bffee7

added unit tests

1 parent 26e0a39 commit 8bffee7

File tree

2 files changed: +119 -0 lines changed


Project.toml (+10)
@@ -7,6 +7,16 @@ version = "0.1.0"
 DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 
+[extras]
+CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
+MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+
+[targets]
+test = ["CategoricalArrays", "MLJBase", "MLJModels", "Random", "Test"]
+
 [compat]
 julia = "1"
 DecisionTree = "0.10"
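In this Project.toml format, packages under [extras] that are named in the [targets] test entry are test-only dependencies: Pkg installs them only when the test suite runs. A minimal sketch of invoking the suite from a REPL, assuming the package root is the working directory (exact resolution details vary by Pkg version):

using Pkg
Pkg.activate(".")  # activate this package's environment
Pkg.test()         # pulls in the [extras] listed under [targets].test, then runs test/runtests.jl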

test/runtests.jl (+109)
@@ -0,0 +1,109 @@
using Test
import CategoricalArrays
import CategoricalArrays.categorical
using MLJBase
using Random
Random.seed!(1234)

# load code to be tested:
import MLJModels
import DecisionTree
using MLJModels.DecisionTree_

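# (`DecisionTree_` is the glue module in MLJModels implementing the MLJ model
# API for DecisionTree.jl; it supplies `DecisionTreeClassifier`,
# `DecisionTreeRegressor` and the ensemble models exercised below)
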
# get some test data:
X, y = @load_iris

baretree = DecisionTreeClassifier()

baretree.max_depth = 1
fitresult, cache, report = MLJBase.fit(baretree, 2, X, y);
baretree.max_depth = -1 # no max depth
fitresult, cache, report =
    MLJBase.update(baretree, 1, fitresult, cache, X, y);

# in this case the decision tree is a perfect predictor:
yhat = MLJBase.predict_mode(baretree, fitresult, X);
@test yhat == y

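# (the raw MLJ model API under test: `MLJBase.fit(model, verbosity, X, y)`
# returns a `(fitresult, cache, report)` triple, and
# `MLJBase.update(model, verbosity, fitresult, cache, X, y)` refits after
# hyperparameter changes, potentially without retraining from scratch)
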
# but pruning upsets this:
baretree.post_prune = true
baretree.merge_purity_threshold = 0.1
fitresult, cache, report =
    MLJBase.update(baretree, 2, fitresult, cache, X, y)
yhat = MLJBase.predict_mode(baretree, fitresult, X);
@test yhat != y
yhat = MLJBase.predict(baretree, fitresult, X);

# check preservation of levels:
yyhat = predict_mode(baretree, fitresult, MLJBase.selectrows(X, 1:3))
@test MLJBase.classes(yyhat[1]) == MLJBase.classes(y[1])

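# (predictions are categorical, carrying the full class pool of the training
# target, so even a three-row prediction knows all of the iris classes)
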
info_dict(baretree)

# # testing machine interface:
# tree = machine(baretree, X, y)
# fit!(tree)
# yyhat = predict_mode(tree, MLJBase.selectrows(X, 1:3))
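# (the machine interface sketched above is exercised against the ensemble
# models at the end of this file)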
using Random: seed!
seed!(0)

n, m = 10^3, 5;
raw_features = rand(n, m);
weights = rand(-1:1, m);
labels = raw_features * weights;
features = MLJBase.table(raw_features);

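# (the target is an exact linear function of the features, so the regression
# trees below should capture most of the variance; the R2 thresholds rely on
# this)
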
R1Tree = DecisionTreeRegressor(min_samples_leaf=5, merge_purity_threshold=0.1)
R2Tree = DecisionTreeRegressor(min_samples_split=5)
model1, = MLJBase.fit(R1Tree, 1, features, labels) # trailing comma: keep the fitresult only

vals1 = MLJBase.predict(R1Tree, model1, features)
R1Tree.post_prune = true
model1_prune, = MLJBase.fit(R1Tree, 1, features, labels)
vals1_prune = MLJBase.predict(R1Tree, model1_prune, features)
@test vals1 != vals1_prune

@test DecisionTree.R2(labels, vals1) > 0.8

model2, = MLJBase.fit(R2Tree, 1, features, labels)
vals2 = MLJBase.predict(R2Tree, model2, features)
@test DecisionTree.R2(labels, vals2) > 0.8

## TEST ON ORDINAL FEATURES OTHER THAN CONTINUOUS

N = 20
X = (x1=rand(N), x2=categorical(rand("abc", N), ordered=true), x3=collect(1:N))
yfinite = X.x2
ycont = float.(X.x3)

rgs = DecisionTreeRegressor()
fitresult, _, _ = MLJBase.fit(rgs, 1, X, ycont)
@test rms(predict(rgs, fitresult, X), ycont) < 1.5

clf = DecisionTreeClassifier(pdf_smoothing=0)
fitresult, _, _ = MLJBase.fit(clf, 1, X, yfinite)
@test all(predict_mode(clf, fitresult, X) .== yfinite) # perfect prediction
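# (`pdf_smoothing=0` is taken here to mean no smoothing of the predicted class
# probabilities, so leaf distributions are raw training frequencies; with the
# target equal to feature x2, a mode prediction can then be exact)
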
info_dict(R1Tree)

# -- Ensemble

rfc = RandomForestClassifier()
ada = AdaBoostStumpClassifier()

X, y = MLJBase.make_blobs(100, 3; rng=555)

m = machine(rfc, X, y)
fit!(m)
@test accuracy(predict_mode(m, X), y) > 0.95

m = machine(ada, X, y)
fit!(m)
@test accuracy(predict_mode(m, X), y) > 0.95

X, y = MLJBase.make_regression(rng=5124)
rfr = RandomForestRegressor()
m = machine(rfr, X, y)
fit!(m)
@test rms(predict(m, X), y) < 0.4
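For reference, the machine workflow that the commented-out block earlier gestures at, and that the ensemble tests above use, looks like this end to end. A minimal sketch, assuming the same packages and glue module as in the test file:

using MLJBase
import MLJModels
import DecisionTree
using MLJModels.DecisionTree_

X, y = @load_iris
tree = machine(DecisionTreeClassifier(max_depth=2), X, y)
fit!(tree)                                             # train on all rows
yhat = predict_mode(tree, MLJBase.selectrows(X, 1:3))  # point predictions for three rows

Binding a model to data with machine and calling fit! replaces the manual bookkeeping of fitresult, cache and report that the earlier tests perform by hand.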
