@@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
113
113
if purity_thresh >= 1.0
114
114
return tree
115
115
end
116
# Single pruning pass over `tree`.
#
# Returns `tree` unchanged when `length(tree) == 1` (a leaf). When
# `length(tree) == 2` (a stump), the two children are merged into one
# `Leaf{T, N}` if the merged majority share is at least `purity_thresh`;
# otherwise the stump is returned unchanged. For anything larger, a new
# `Node{S, T, N}` is built from the recursively processed children.
#
# A node that only becomes a stump because of merges performed during this
# call is not re-examined here: the final branch rebuilds the node from the
# recursive results without checking it again.
#
# NOTE(review): assumes `values` holds per-class counts aligned with
# `features`, and that both children carry the same `features` tuple (only
# the left child's is copied into the merged leaf) — confirm against the
# `Leaf` definition.
function _prune_run(tree::LeafOrNode{S, T, N}, purity_thresh::Real) where {S, T, N}
    L = length(tree)
    if L == 1        # a Leaf
        return tree
    elseif L == 2    # a stump
        # Element-wise sum of the two children's `values`.
        combined = tree.left.values .+ tree.right.values
        total = tree.left.total + tree.right.total
        # Index of the largest merged entry; also stored as the merged
        # leaf's majority below.
        majority = argmax(combined)
        purity = combined[majority] / total
        if purity >= purity_thresh
            return Leaf{T, N}(tree.left.features, majority, combined, total)
        else
            return tree
        end
    else
        return Node{S, T, N}(
            tree.featid, tree.featval,
            _prune_run(tree.left, purity_thresh),
            _prune_run(tree.right, purity_thresh))
    end
end
139
137
pruned = _prune_run (tree, purity_thresh)
0 commit comments