@@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
113113 if purity_thresh >= 1.0
114114 return tree
115115 end
116- function _prune_run (tree:: LeafOrNode{S, T} , purity_thresh:: Real ) where {S, T}
117- N = length (tree)
118- if N == 1 # # a Leaf
116+ function _prune_run (tree:: LeafOrNode{S, T, N } , purity_thresh:: Real ) where {S, T, N }
117+ L = length (tree)
118+ if L == 1 # # a Leaf
119119 return tree
120- elseif N == 2 # # a stump
121- all_labels = [ tree. left. values; tree. right. values]
122- majority = majority_vote (all_labels)
123- matches = findall (all_labels .== majority )
124- purity = length (matches) / length (all_labels)
120+ elseif L == 2 # # a stump
121+ combined = tree. left. values .+ tree. right. values
122+ total = tree . left . total + tree . right . total
123+ majority = argmax (combined )
124+ purity = combined[majority] / total
125125 if purity >= purity_thresh
126- features = Tuple (unique (all_labels))
127- featfreq = Tuple (sum (all_labels .== f) for f in features)
128- return Leaf {T} (features, argmax (featfreq),
129- featfreq, length (all_labels))
126+ return Leaf {T, N} (tree. left. features, majority, combined, total)
130127 else
131128 return tree
132129 end
133130 else
134- return Node {S, T} (tree. featid, tree. featval,
135- _prune_run (tree. left, purity_thresh),
136- _prune_run (tree. right, purity_thresh))
131+ return Node {S, T, N} (
132+ tree. featid, tree. featval,
133+ _prune_run (tree. left, purity_thresh),
134+ _prune_run (tree. right, purity_thresh))
137135 end
138136 end
139137 pruned = _prune_run (tree, purity_thresh)
0 commit comments