Skip to content

Commit 9e77eac

Browse files
authored
Merge pull request #188 from JuliaAI/salbert83-master
Multithreading for random forest predictions
2 parents 33bbec4 + 5811020 commit 9e77eac

File tree

9 files changed

+63
-3
lines changed

9 files changed

+63
-3
lines changed

src/classification/main.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -449,11 +449,21 @@ function apply_forest(forest::Ensemble{S, T}, features::AbstractVector{S}) where
449449
end
450450
end
451451

452-
function apply_forest(forest::Ensemble{S, T}, features::AbstractMatrix{S}) where {S, T}
452+
function apply_forest(
453+
forest::Ensemble{S, T},
454+
features::AbstractMatrix{S};
455+
use_multithreading = false
456+
) where {S, T}
453457
N = size(features,1)
454458
predictions = Array{T}(undef, N)
455-
for i in 1:N
456-
predictions[i] = apply_forest(forest, features[i, :])
459+
if use_multithreading
460+
Threads.@threads for i in 1:N
461+
predictions[i] = apply_forest(forest, @view(features[i, :]))
462+
end
463+
else
464+
for i in 1:N
465+
predictions[i] = apply_forest(forest, @view(features[i, :]))
466+
end
457467
end
458468
return predictions
459469
end

test/classification/adult.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ cm = confusion_matrix(labels, preds)
2222
f1 = impurity_importance(model)
2323
p1 = permutation_importance(model, labels, features, (model, y, X)->accuracy(y, apply_forest(model, X)), rng=StableRNG(1)).mean
2424

25+
preds_MT = apply_forest(model, features, use_multithreading = true)
26+
cm_MT = confusion_matrix(labels, preds_MT)
27+
@test cm_MT.accuracy > 0.9
28+
2529
n_iterations = 15
2630
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
2731
preds = apply_adaboost_stumps(model, coeffs, features);

test/classification/digits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ preds = apply_forest(model, X)
8686
cm = confusion_matrix(Y, preds)
8787
@test cm.accuracy > 0.95
8888

89+
preds_MT = apply_forest(model, X, use_multithreading = true)
90+
cm_MT = confusion_matrix(Y, preds_MT)
91+
@test cm_MT.accuracy > 0.95
92+
8993
n_iterations = 100
9094
model, coeffs = DecisionTree.build_adaboost_stumps(
9195
Y, X,

test/classification/heterogeneous.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ preds = apply_forest(model, features)
2626
cm = confusion_matrix(labels, preds)
2727
@test cm.accuracy > 0.9
2828

29+
preds_MT = apply_forest(model, features, use_multithreading = true)
30+
cm_MT = confusion_matrix(labels, preds_MT)
31+
@test cm_MT.accuracy > 0.9
32+
2933
n_subfeatures = 7
3034
model, coeffs = build_adaboost_stumps(labels, features, n_subfeatures; rng=StableRNG(1))
3135
preds = apply_adaboost_stumps(model, coeffs, features)

test/classification/iris.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ cm = confusion_matrix(labels, preds)
7979
probs = apply_forest_proba(model, features, classes)
8080
@test reshape(sum(probs, dims=2), n) ≈ ones(n)
8181

82+
preds_MT = apply_forest(model, features, use_multithreading = true)
83+
cm_MT = confusion_matrix(labels, preds_MT)
84+
@test cm_MT.accuracy > 0.95
85+
@test typeof(preds_MT) == Vector{String}
86+
@test sum(preds .!= preds_MT) == 0
87+
8288
# run n-fold cross validation for forests
8389
println("\n##### nfoldCV Classification Forest #####")
8490
n_subfeatures = 2

test/classification/low_precision.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ cm = confusion_matrix(labels, preds)
4848
@test typeof(preds) == Vector{Int32}
4949
@test cm.accuracy > 0.9
5050

51+
preds_MT = apply_forest(model, features, use_multithreading = true)
52+
cm_MT = confusion_matrix(labels, preds_MT)
53+
@test typeof(preds_MT) == Vector{Int32}
54+
@test cm_MT.accuracy > 0.9
55+
5156
n_iterations = Int32(25)
5257
model, coeffs = build_adaboost_stumps(labels, features, n_iterations; rng=StableRNG(1));
5358
preds = apply_adaboost_stumps(model, coeffs, features);
@@ -116,6 +121,10 @@ model = build_forest(labels, features)
116121
preds = apply_forest(model, features)
117122
@test typeof(preds) == Vector{Int8}
118123

124+
preds_MT = apply_forest(model, features, use_multithreading = true)
125+
@test typeof(preds_MT) == Vector{Int8}
126+
@test sum(abs.(preds .- preds_MT)) == zero(Int8)
127+
119128
model = build_tree(labels, features)
120129
preds = apply_tree(model, features)
121130
@test typeof(preds) == Vector{Int8}

test/classification/random.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ cm = confusion_matrix(labels, preds)
5555
@test cm.accuracy > 0.9
5656
@test typeof(preds) == Vector{Int}
5757

58+
preds_MT = apply_forest(model, features, use_multithreading = true)
59+
cm_MT = confusion_matrix(labels, preds_MT)
60+
@test cm_MT.accuracy > 0.9
61+
@test typeof(preds_MT) == Vector{Int}
62+
@test sum(abs.(preds .- preds_MT)) == zero(Int)
63+
5864
n_subfeatures = 3
5965
n_trees = 9
6066
partial_sampling = 0.7
@@ -77,6 +83,10 @@ cm = confusion_matrix(labels, preds)
7783
@test cm.accuracy > 0.6
7884
@test length(model) == n_trees
7985

86+
preds_MT = apply_forest(model, features, use_multithreading = true)
87+
cm_MT = confusion_matrix(labels, preds_MT)
88+
@test cm_MT.accuracy > 0.9
89+
8090
# test n_subfeatures
8191
n_subfeatures = 0
8292
m_partial = build_forest(labels, features; rng=StableRNG(1)) # default sqrt(n_features)

test/regression/digits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ model = build_forest(
8080
preds = apply_forest(model, X)
8181
@test R2(Y, preds) > 0.8
8282

83+
preds_MT = apply_forest(model, X, use_multithreading = true)
84+
@test R2(Y, preds_MT) > 0.8
85+
@test sum(abs.(preds .- preds_MT)) < 1e-8
86+
8387
println("\n##### 3 foldCV Regression Tree #####")
8488
n_folds = 5
8589
r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false);

test/regression/low_precision.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ preds = apply_forest(model, features)
4747
@test R2(labels, preds) > 0.9
4848
@test typeof(preds) <: Vector{Float64}
4949

50+
preds_MT = apply_forest(model, features, use_multithreading=true)
51+
@test R2(labels, preds_MT) > 0.9
52+
@test typeof(preds_MT) <: Vector{Float64}
53+
@test sum(abs.(preds .- preds_MT)) < 1.0e-8
54+
5055
println("\n##### nfoldCV Regression Tree #####")
5156
n_folds = Int32(3)
5257
pruning_purity = 1.0
@@ -102,6 +107,10 @@ model = build_forest(labels, features)
102107
preds = apply_forest(model, features)
103108
@test typeof(preds) == Vector{Float16}
104109

110+
preds_MT = apply_forest(model, features, use_multithreading = true)
111+
@test typeof(preds_MT) == Vector{Float16}
112+
@test sum(abs.(preds .- preds_MT)) < 1.0e-8
113+
105114
model = build_tree(labels, features)
106115
preds = apply_tree(model, features)
107116
@test typeof(preds) == Vector{Float16}

0 commit comments

Comments
 (0)