Skip to content

Commit d92e5aa

Browse files
committed
Turn Leaf struct into a frequency map
1 parent 9dab9c1 commit d92e5aa

File tree

3 files changed

+30
-18
lines changed

3 files changed

+30
-18
lines changed

src/DecisionTree.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,18 @@ export InfoNode, InfoLeaf, wrap
2626
###########################
2727
########## Types ##########
2828

29-
struct Leaf{T}
30-
majority :: T
31-
values :: Vector{T}
29+
struct Leaf{T, N}
30+
features :: NTuple{N, T}
31+
majority :: Int
32+
values :: NTuple{N, Int}
33+
total :: Int
3234
end
3335

34-
struct Node{S, T}
36+
struct Node{S, T, N}
3537
featid :: Int
3638
featval :: S
37-
left :: Union{Leaf{T}, Node{S, T}}
38-
right :: Union{Leaf{T}, Node{S, T}}
39+
left :: Union{Leaf{T, N}, Node{S, T, N}}
40+
right :: Union{Leaf{T, N}, Node{S, T, N}}
3941
end
4042

4143
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -52,13 +54,15 @@ struct Ensemble{S, T}
5254
featim :: Vector{Float64}
5355
end
5456

57+
Leaf(features::NTuple{T, N}) where {T, N} =
58+
Leaf(features, 0, Tuple(zeros(T, N)), 0)
5559

5660
is_leaf(l::Leaf) = true
5761
is_leaf(n::Node) = false
5862

5963
_zero(::Type{String}) = ""
6064
_zero(x::Any) = zero(x)
61-
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(_zero(T), [_zero(T)]))
65+
convert(::Type{Node{S, T}}, lf::Leaf{T}) where {S, T} = Node(0, _zero(S), lf, Leaf(lf.features))
6266
convert(::Type{Root{S, T}}, node::LeafOrNode{S, T}) where {S, T} = Root{S, T}(node, 0, Float64[])
6367
convert(::Type{LeafOrNode{S, T}}, tree::Root{S, T}) where {S, T} = tree.node
6468
promote_rule(::Type{Node{S, T}}, ::Type{Leaf{T}}) where {S, T} = Node{S, T}
@@ -97,9 +101,8 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
97101
depth(tree::Root) = depth(tree.node)
98102

99103
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
100-
n_matches = count(leaf.values .== leaf.majority)
101-
ratio = string(n_matches, "/", length(leaf.values))
102-
println(io, "$(leaf.majority) : $(ratio)")
104+
println(io, leaf.features[leaf.majority], " : ",
105+
leaf.values[leaf.majority], '/', leaf.total)
103106
end
104107
function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
105108
return print_tree(stdout, leaf, depth, indent; sigdigits, feature_names)
@@ -162,8 +165,8 @@ end
162165

163166
function show(io::IO, leaf::Leaf)
164167
println(io, "Decision Leaf")
165-
println(io, "Majority: $(leaf.majority)")
166-
print(io, "Samples: $(length(leaf.values))")
168+
println(io, "Majority: ", leaf.features[leaf.majority])
169+
print(io, "Samples: ", leaf.total)
167170
end
168171

169172
function show(io::IO, tree::Node)

src/classification/main.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ function _convert(
4141
) where {S, T}
4242

4343
if node.is_leaf
44-
return Leaf{T}(list[node.label], labels[node.region])
44+
featfreq = Tuple(sum(labels[node.region] .== l) for l in list)
45+
return Leaf{T, length(list)}(
46+
Tuple(list), argmax(featfreq), featfreq, length(node.region))
4547
else
4648
left = _convert(node.l, list, labels)
4749
right = _convert(node.r, list, labels)
48-
return Node{S, T}(node.feature, node.threshold, left, right)
50+
return Node{S, T, length(list)}(
51+
node.feature, node.threshold, left, right)
4952
end
5053
end
5154

@@ -233,7 +236,10 @@ function prune_tree(
233236
if !isempty(fi)
234237
update_pruned_impurity!(tree, fi, ntt, loss)
235238
end
236-
return Leaf{T}(majority, all_labels)
239+
features = Tuple(unique(all_labels))
240+
featfreq = Tuple(sum(all_labels .== f) for f in features)
241+
return Leaf{T}(features, argmax(featfreq),
242+
featfreq, length(all_labels))
237243
else
238244
return tree
239245
end
@@ -268,7 +274,7 @@ function prune_tree(
268274
end
269275

270276

271-
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.majority
277+
apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
272278
apply_tree(
273279
tree::Root{S, T},
274280
features::AbstractVector{S}
@@ -314,7 +320,7 @@ of the output matrix.
314320
apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} =
315321
apply_tree_proba(tree.node, features, labels)
316322
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
317-
compute_probabilities(labels, leaf.values)
323+
collect(leaf.values ./ leaf.total)
318324

319325
function apply_tree_proba(
320326
tree::Node{S, T},

src/regression/main.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@ include("tree.jl")
22

33
function _convert(node::treeregressor.NodeMeta{S}, labels::Array{T}) where {S, T <: Float64}
44
if node.is_leaf
5-
return Leaf{T}(node.label, labels[node.region])
5+
features = Tuple(unique(labels))
6+
featfreq = Tuple(sum(labels[node.region] .== f) for f in features)
7+
return Leaf{T, length(features)}(
8+
features, argmax(featfreq), featfreq, length(node.region))
69
else
710
left = _convert(node.l, labels)
811
right = _convert(node.r, labels)

0 commit comments

Comments
 (0)