Skip to content

Commit 4857294

Browse files
committed
Fix tree pruning with NTuples
1 parent 1a452f4 commit 4857294

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

src/DecisionTree.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,10 @@ struct Node{S, T, N}
4242
right :: Union{Leaf{T, N}, Node{S, T, N}}
4343
end
4444

45-
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
45+
const LeafOrNode{S, T, N} = Union{Leaf{T, N}, Node{S, T, N}}
4646

47-
struct Ensemble{S, T}
48-
trees :: Vector{LeafOrNode{S, T}}
47+
struct Ensemble{S, T, N}
48+
trees :: Vector{LeafOrNode{S, T, N}}
4949
end
5050

5151
Leaf(features::NTuple{T, N}) where {T, N} =

src/classification/main.jl

+13-15
Original file line numberDiff line numberDiff line change
@@ -113,27 +113,25 @@ function prune_tree(tree::LeafOrNode{S, T}, purity_thresh=1.0) where {S, T}
113113
if purity_thresh >= 1.0
114114
return tree
115115
end
116-
function _prune_run(tree::LeafOrNode{S, T}, purity_thresh::Real) where {S, T}
117-
N = length(tree)
118-
if N == 1 ## a Leaf
116+
function _prune_run(tree::LeafOrNode{S, T, N}, purity_thresh::Real) where {S, T, N}
117+
L = length(tree)
118+
if L == 1 ## a Leaf
119119
return tree
120-
elseif N == 2 ## a stump
121-
all_labels = [tree.left.values; tree.right.values]
122-
majority = majority_vote(all_labels)
123-
matches = findall(all_labels .== majority)
124-
purity = length(matches) / length(all_labels)
120+
elseif L == 2 ## a stump
121+
combined = tree.left.values .+ tree.right.values
122+
total = tree.left.total + tree.right.total
123+
majority = argmax(combined)
124+
purity = combined[majority] / total
125125
if purity >= purity_thresh
126-
features = Tuple(unique(all_labels))
127-
featfreq = Tuple(sum(all_labels .== f) for f in features)
128-
return Leaf{T}(features, argmax(featfreq),
129-
featfreq, length(all_labels))
126+
return Leaf{T, N}(tree.left.features, majority, combined, total)
130127
else
131128
return tree
132129
end
133130
else
134-
return Node{S, T}(tree.featid, tree.featval,
135-
_prune_run(tree.left, purity_thresh),
136-
_prune_run(tree.right, purity_thresh))
131+
return Node{S, T, N}(
132+
tree.featid, tree.featval,
133+
_prune_run(tree.left, purity_thresh),
134+
_prune_run(tree.right, purity_thresh))
137135
end
138136
end
139137
pruned = _prune_run(tree, purity_thresh)

0 commit comments

Comments
 (0)