@@ -224,22 +224,19 @@ function prune_tree(
224
224
end
225
225
ntt = nsample (tree)
226
226
function _prune_run_stump (
227
- tree:: LeafOrNode{S, T} ,
227
+ tree:: LeafOrNode{S, T, N } ,
228
228
purity_thresh:: Real ,
229
229
fi:: Vector{Float64} = Float64[]
230
- ) where {S, T}
231
- all_labels = [ tree. left. values; tree. right. values]
232
- majority = majority_vote (all_labels)
233
- matches = findall (all_labels .== majority )
234
- purity = length (matches) / length (all_labels)
230
+ ) where {S, T, N }
231
+ combined = tree. left. values .+ tree. right. values
232
+ total = tree . left . total + tree . right . total
233
+ majority = argmax (combined )
234
+ purity = combined[majority] / total
235
235
if purity >= purity_thresh
236
236
if ! isempty (fi)
237
237
update_pruned_impurity! (tree, fi, ntt, loss)
238
238
end
239
- features = Tuple (unique (all_labels))
240
- featfreq = Tuple (sum (all_labels .== f) for f in features)
241
- return Leaf {T} (features, argmax (featfreq),
242
- featfreq, length (all_labels))
239
+ return Leaf {T, N} (tree. left. features, majority, combined, total)
243
240
else
244
241
return tree
245
242
end
@@ -250,19 +247,20 @@ function prune_tree(
250
247
return Root {S, T} (node, tree. n_feat, fi)
251
248
end
252
249
# Recursively prune the subtree rooted at `tree`.
#
# Arguments:
#   tree          - the leaf or node to process.
#   purity_thresh - purity threshold, forwarded unchanged to `_prune_run_stump`.
#   fi            - accumulator vector forwarded to `_prune_run_stump`, which passes it
#                   to `update_pruned_impurity!` when it is non-empty. It defaults to an
#                   empty vector, in which case that update is skipped.
#
# Returns `tree` itself when it is a leaf, the result of `_prune_run_stump` when it is a
# stump, and otherwise a new `Node` with the same `featid`/`featval` and recursively
# processed children.
function _prune_run(
    tree::LeafOrNode{S, T, N},
    purity_thresh::Real,
    fi::Vector{Float64} = Float64[]
) where {S, T, N}
    # `N` is a type parameter here, so the size of the subtree is kept in `L`.
    L = length(tree)
    if L == 1        ## a Leaf
        return tree
    elseif L == 2    ## a stump
        return _prune_run_stump(tree, purity_thresh, fi)
    else
        # `fi` must be passed down explicitly: omitting it makes every recursive call
        # fall back to the empty default, so stumps pruned below this level would never
        # reach `update_pruned_impurity!` and the caller's vector would stay unchanged.
        left = _prune_run(tree.left, purity_thresh, fi)
        right = _prune_run(tree.right, purity_thresh, fi)
        return Node{S, T, N}(tree.featid, tree.featval, left, right)
    end
end
268
266
pruned = _prune_run (tree, purity_thresh)
0 commit comments