Skip to content

Commit f358671

Browse files
salbert83ablaom
authored andcommitted
Unit tests for multithreading
1 parent 4f19884 commit f358671

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

test/regression/digits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ model = build_forest(
8080
preds = apply_forest(model, X)
8181
@test R2(Y, preds) > 0.8
8282

83+
preds_MT = apply_forest(model, X, use_multithreading = true)
84+
@test R2(Y, preds_MT) > 0.8
85+
@test sum(abs.(preds .- preds_MT)) < 1e-8
86+
8387
println("\n##### 3 foldCV Regression Tree #####")
8488
n_folds = 5
8589
r2 = nfoldCV_tree(Y, X, n_folds; rng=StableRNG(1), verbose=false);

test/regression/low_precision.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ preds = apply_forest(model, features)
4747
@test R2(labels, preds) > 0.9
4848
@test typeof(preds) <: Vector{Float64}
4949

50+
preds_MT = apply_forest(model, features, use_multithreading=true)
51+
@test R2(labels, preds_MT) > 0.9
52+
@test typeof(preds_MT) <: Vector{Float64}
53+
@test sum(abs.(preds .- preds_MT)) < 1.0e-8
54+
5055
println("\n##### nfoldCV Regression Tree #####")
5156
n_folds = Int32(3)
5257
pruning_purity = 1.0
@@ -102,6 +107,10 @@ model = build_forest(labels, features)
102107
preds = apply_forest(model, features)
103108
@test typeof(preds) == Vector{Float16}
104109

110+
preds_MT = apply_forest(model, features, use_multithreading = true)
111+
@test typeof(preds_MT) == Vector{Float16}
112+
@test sum(abs.(preds .- preds_MT)) < 1.0e-8
113+
105114
model = build_tree(labels, features)
106115
preds = apply_tree(model, features)
107116
@test typeof(preds) == Vector{Float16}

0 commit comments

Comments
 (0)