Skip to content

Commit 5811020

Browse files
salbert83ablaom
authored andcommitted
Unit tests for multithreading
1 parent f358671 commit 5811020

File tree

6 files changed

+37
-0
lines changed

6 files changed

+37
-0
lines changed

test/classification/adult.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ cm = confusion_matrix(labels, preds)
2222
f1 = impurity_importance(model)
2323
p1 = permutation_importance(model, labels, features, (model, y, X)->accuracy(y, apply_forest(model, X)), rng=StableRNG(1)).mean
2424

25+
preds_MT = apply_forest(model, features, use_multithreading = true)
26+
cm_MT = confusion_matrix(labels, preds_MT)
27+
@test cm_MT.accuracy > 0.9
28+
2529
n_iterations = 15
2630
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
2731
preds = apply_adaboost_stumps(model, coeffs, features);

test/classification/digits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ preds = apply_forest(model, X)
8686
cm = confusion_matrix(Y, preds)
8787
@test cm.accuracy > 0.95
8888

89+
preds_MT = apply_forest(model, X, use_multithreading = true)
90+
cm_MT = confusion_matrix(Y, preds_MT)
91+
@test cm_MT.accuracy > 0.95
92+
8993
n_iterations = 100
9094
model, coeffs = DecisionTree.build_adaboost_stumps(
9195
Y, X,

test/classification/heterogeneous.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ preds = apply_forest(model, features)
2626
cm = confusion_matrix(labels, preds)
2727
@test cm.accuracy > 0.9
2828

29+
preds_MT = apply_forest(model, features, use_multithreading = true)
30+
cm_MT = confusion_matrix(labels, preds_MT)
31+
@test cm_MT.accuracy > 0.9
32+
2933
n_subfeatures = 7
3034
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
3135
preds = apply_adaboost_stumps(model, coeffs, features)

test/classification/iris.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ cm = confusion_matrix(labels, preds)
7979
probs = apply_forest_proba(model, features, classes)
8080
@test reshape(sum(probs, dims=2), n) ones(n)
8181

82+
preds_MT = apply_forest(model, features, use_multithreading = true)
83+
cm_MT = confusion_matrix(labels, preds_MT)
84+
@test cm_MT.accuracy > 0.95
85+
@test typeof(preds_MT) == Vector{String}
86+
@test sum(preds .!= preds_MT) == 0
87+
8288
# run n-fold cross validation for forests
8389
println("\n##### nfoldCV Classification Forest #####")
8490
n_subfeatures = 2

test/classification/low_precision.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ cm = confusion_matrix(labels, preds)
4848
@test typeof(preds) == Vector{Int32}
4949
@test cm.accuracy > 0.9
5050

51+
preds_MT = apply_forest(model, features, use_multithreading = true)
52+
cm_MT = confusion_matrix(labels, preds_MT)
53+
@test typeof(preds_MT) == Vector{Int32}
54+
@test cm_MT.accuracy > 0.9
55+
5156
n_iterations = Int32(25)
5257
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
5358
preds = apply_adaboost_stumps(model, coeffs, features);
@@ -116,6 +121,10 @@ model = build_forest(labels, features)
116121
preds = apply_forest(model, features)
117122
@test typeof(preds) == Vector{Int8}
118123

124+
preds_MT = apply_forest(model, features, use_multithreading = true)
125+
@test typeof(preds_MT) == Vector{Int8}
126+
@test sum(abs.(preds .- preds_MT)) == zero(Int8)
127+
119128
model = build_tree(labels, features)
120129
preds = apply_tree(model, features)
121130
@test typeof(preds) == Vector{Int8}

test/classification/random.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ cm = confusion_matrix(labels, preds)
5555
@test cm.accuracy > 0.9
5656
@test typeof(preds) == Vector{Int}
5757

58+
preds_MT = apply_forest(model, features, use_multithreading = true)
59+
cm_MT = confusion_matrix(labels, preds_MT)
60+
@test cm_MT.accuracy > 0.9
61+
@test typeof(preds_MT) == Vector{Int}
62+
@test sum(abs.(preds .- preds_MT)) == zero(Int)
63+
5864
n_subfeatures = 3
5965
n_trees = 9
6066
partial_sampling = 0.7
@@ -77,6 +83,10 @@ cm = confusion_matrix(labels, preds)
7783
@test cm.accuracy > 0.6
7884
@test length(model) == n_trees
7985

86+
preds_MT = apply_forest(model, features, use_multithreading = true)
87+
cm_MT = confusion_matrix(labels, preds_MT)
88+
@test cm_MT.accuracy > 0.9
89+
8090
# test n_subfeatures
8191
n_subfeatures = 0
8292
m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features)

0 commit comments

Comments
 (0)