Skip to content

Commit 5dfd8da

Browse files
committed
Fix tree pruning with NTuples
1 parent ee97f79 commit 5dfd8da

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

src/DecisionTree.jl

+5-5
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,16 @@ struct Node{S, T, N}
4040
right :: Union{Leaf{T, N}, Node{S, T, N}}
4141
end
4242

43-
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
43+
const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}}
4444

45-
struct Root{S, T}
46-
node :: LeafOrNode{S, T}
45+
struct Root{S, T, N}
46+
node :: LeafOrNode{S, T, N}
4747
n_feat :: Int
4848
featim :: Vector{Float64} # impurity importance
4949
end
5050

51-
struct Ensemble{S, T}
52-
trees :: Vector{LeafOrNode{S, T}}
51+
struct Ensemble{S, T, N}
52+
trees :: Vector{LeafOrNode{S, T, N}}
5353
n_feat :: Int
5454
featim :: Vector{Float64}
5555
end

src/classification/main.jl

+16-18
Original file line numberDiff line numberDiff line change
@@ -224,22 +224,19 @@ function prune_tree(
224224
end
225225
ntt = nsample(tree)
226226
function _prune_run_stump(
227-
tree::LeafOrNode{S, T},
227+
tree::LeafOrNode{S, T, N},
228228
purity_thresh::Real,
229229
fi::Vector{Float64} = Float64[]
230-
) where {S, T}
231-
all_labels = [tree.left.values; tree.right.values]
232-
majority = majority_vote(all_labels)
233-
matches = findall(all_labels .== majority)
234-
purity = length(matches) / length(all_labels)
230+
) where {S, T, N}
231+
combined = tree.left.values .+ tree.right.values
232+
total = tree.left.total + tree.right.total
233+
majority = argmax(combined)
234+
purity = combined[majority] / total
235235
if purity >= purity_thresh
236236
if !isempty(fi)
237237
update_pruned_impurity!(tree, fi, ntt, loss)
238238
end
239-
features = Tuple(unique(all_labels))
240-
featfreq = Tuple(sum(all_labels .== f) for f in features)
241-
return Leaf{T}(features, argmax(featfreq),
242-
featfreq, length(all_labels))
239+
return Leaf{T, N}(tree.left.features, majority, combined, total)
243240
else
244241
return tree
245242
end
@@ -250,19 +247,20 @@ function prune_tree(
250247
return Root{S, T}(node, tree.n_feat, fi)
251248
end
252249
function _prune_run(
253-
tree::LeafOrNode{S, T},
250+
tree::LeafOrNode{S, T, N},
254251
purity_thresh::Real,
255252
fi::Vector{Float64} = Float64[]
256-
) where {S, T}
257-
N = length(tree)
258-
if N == 1 ## a Leaf
253+
) where {S, T, N}
254+
L = length(tree)
255+
if L == 1 ## a Leaf
259256
return tree
260-
elseif N == 2 ## a stump
257+
elseif L == 2 ## a stump
261258
return _prune_run_stump(tree, purity_thresh, fi)
262259
else
263-
left = _prune_run(tree.left, purity_thresh, fi)
264-
right = _prune_run(tree.right, purity_thresh, fi)
265-
return Node{S, T}(tree.featid, tree.featval, left, right)
260+
return Node{S, T, N}(
261+
tree.featid, tree.featval,
262+
_prune_run(tree.left, purity_thresh),
263+
_prune_run(tree.right, purity_thresh))
266264
end
267265
end
268266
pruned = _prune_run(tree, purity_thresh)

0 commit comments

Comments
 (0)