Commit f7b73d1

Fix more test results
1 parent d178d40 commit f7b73d1

4 files changed: +38 −26 lines

src/DecisionTree.jl (+10 −5)

@@ -27,7 +27,7 @@ export InfoNode, InfoLeaf, wrap
 ########## Types ##########
 
 struct Leaf{T, N}
-    features :: NTuple{N, T}
+    classes  :: NTuple{N, T}
     majority :: Int
     values   :: NTuple{N, Int}
     total    :: Int
@@ -54,15 +54,20 @@ struct Ensemble{S, T, N}
     featim :: Vector{Float64}
 end
 
-Leaf(features::NTuple{T, N}) where {T, N} =
+Leaf(features::NTuple{N, T}) where {T, N} =
    Leaf(features, 0, Tuple(zeros(T, N)), 0)
+Leaf(features::NTuple{N, T}, frequencies::NTuple{N, Int}) where {T, N} =
+    Leaf(features, argmax(frequencies), frequencies, sum(frequencies))
+Leaf(features::Union{<:AbstractVector, <:Tuple},
+     frequencies::Union{<:AbstractVector{Int}, <:Tuple}) =
+    Leaf(Tuple(features), Tuple(frequencies))
 
 is_leaf(l::Leaf) = true
 is_leaf(n::Node) = false
 
 _zero(::Type{String}) = ""
 _zero(x::Any) = zero(x)
-convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.features))
+convert(::Type{Node{S, T, N}}, lf::Leaf{T, N}) where {S, T, N} = Node(0, _zero(S), lf, Leaf(lf.classes))
 convert(::Type{Root{S, T, N}}, node::LeafOrNode{S, T, N}) where {S, T, N} = Root{S, T, N}(node, 0, Float64[])
 convert(::Type{LeafOrNode{S, T, N}}, tree::Root{S, T, N}) where {S, T, N} = tree.node
 promote_rule(::Type{Node{S, T, N}}, ::Type{Leaf{T, N}}) where {S, T, N} = Node{S, T, N}
@@ -101,7 +106,7 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
 depth(tree::Root) = depth(tree.node)
 
 function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
-    println(io, leaf.features[leaf.majority], " : ",
+    println(io, leaf.classes[leaf.majority], " : ",
            leaf.values[leaf.majority], '/', leaf.total)
 end
 function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
@@ -165,7 +170,7 @@ end
 
 function show(io::IO, leaf::Leaf)
    println(io, "Decision Leaf")
-    println(io, "Majority: ", leaf.features[leaf.majority])
+    println(io, "Majority: ", leaf.classes[leaf.majority])
    print(io, "Samples: ", leaf.total)
 end
 
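For context, a quick sketch of how the new Leaf convenience constructors above behave; the class labels and counts here are made up for illustration and are not part of the commit:

    using DecisionTree

    # Two-argument form: majority and total are derived from the frequencies.
    leaf = Leaf((1, 2, 3), (1, 4, 1))
    leaf.classes[leaf.majority]   # == 2, the class with the highest count
    leaf.total                    # == 6, the sum of the frequencies

    # Vectors are accepted too and converted to tuples internally.
    leaf2 = Leaf(["spam", "ham"], [3, 7])
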
src/classification/main.jl (+22 −15)

@@ -117,6 +117,7 @@ function build_stump(
        labels              :: AbstractVector{T},
        features            :: AbstractMatrix{S},
        weights              = nothing;
+        n_classes           :: Int = length(unique(labels)),
        rng                  = Random.GLOBAL_RNG,
        impurity_importance :: Bool = true) where {S, T}
 
@@ -133,7 +134,7 @@ function build_stump(
        min_purity_increase = 0.0;
        rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function build_tree(
@@ -144,6 +145,7 @@ function build_tree(
        min_samples_leaf     = 1,
        min_samples_split    = 2,
        min_purity_increase  = 0.0;
+        n_classes           :: Int = length(unique(labels)),
        loss                 = util.entropy :: Function,
        rng                  = Random.GLOBAL_RNG,
        impurity_importance :: Bool = true) where {S, T}
@@ -168,18 +170,18 @@ function build_tree(
        min_purity_increase = Float64(min_purity_increase),
        rng                 = rng)
 
-    return _build_tree(t, labels, size(features, 2), size(features, 1), impurity_importance)
+    return _build_tree(t, labels, n_classes, size(features, 2), size(features, 1), impurity_importance)
 end
 
 function _build_tree(
    tree::treeclassifier.Tree{S, T},
    labels::AbstractVector{T},
+    n_classes::Int,
    n_features,
    n_samples,
    impurity_importance::Bool
 ) where {S, T}
    node = _convert(tree.root, tree.list, labels[tree.labels])
-    n_classes = unique(labels) |> length
    if !impurity_importance
        return Root{S, T, n_classes}(node, n_features, Float64[])
    else
@@ -237,15 +239,15 @@ function prune_tree(
            if !isempty(fi)
                update_pruned_impurity!(tree, fi, ntt, loss)
            end
-            return Leaf{T, N}(tree.left.features, majority, combined, total)
+            return Leaf{T, N}(tree.left.classes, majority, combined, total)
        else
            return tree
        end
    end
    function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
        fi = deepcopy(tree.featim) ## recalculate feature importances
        node = _prune_run(tree.node, purity_thresh, fi)
-        return Root{S, T, N}(node, fi)
+        return Root{S, T, N}(node, tree.n_feat, fi)
    end
    function _prune_run(
        tree::LeafOrNode{S, T, N},
@@ -273,7 +275,7 @@ function prune_tree(
 end
 
 
-apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.features[leaf.majority]
+apply_tree(leaf::Leaf, feature::AbstractVector) = leaf.classes[leaf.majority]
 apply_tree(
    tree::Root{S, T},
    features::AbstractVector{S}
@@ -369,10 +371,11 @@ function build_forest(
 
    t_samples = length(labels)
    n_samples = floor(Int, partial_sampling * t_samples)
+    n_classes = length(unique(labels))
 
    forest = impurity_importance ?
-        Vector{Root{S, T}}(undef, n_trees) :
-        Vector{LeafOrNode{S, T}}(undef, n_trees)
+        Vector{Root{S, T, n_classes}}(undef, n_trees) :
+        Vector{LeafOrNode{S, T, n_classes}}(undef, n_trees)
 
    entropy_terms = util.compute_entropy_terms(n_samples)
    loss = (ns, n) -> util.entropy(ns, n, entropy_terms)
@@ -392,7 +395,8 @@ function build_forest(
                max_depth,
                min_samples_leaf,
                min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                loss = loss,
                rng = _rng,
                impurity_importance = impurity_importance)
@@ -408,7 +412,8 @@ function build_forest(
                max_depth,
                min_samples_leaf,
                min_samples_split,
-                min_purity_increase,
+                min_purity_increase;
+                n_classes,
                loss = loss,
                impurity_importance = impurity_importance)
        end
@@ -418,13 +423,13 @@ function build_forest(
 end
 
 function _build_forest(
-        forest              :: Vector{<: Union{Root{S, T}, LeafOrNode{S, T}}},
+        forest              :: Vector{<: Union{Root{S, T, N}, LeafOrNode{S, T, N}}},
        n_features          ,
        n_trees             ,
-        impurity_importance :: Bool) where {S, T}
+        impurity_importance :: Bool) where {S, T, N}
 
    if !impurity_importance
-        return Ensemble{S, T}(forest, n_features, Float64[])
+        return Ensemble{S, T, N}(forest, n_features, Float64[])
    else
        fi = zeros(Float64, n_features)
        for tree in forest
@@ -434,12 +439,12 @@ function _build_forest(
            end
        end
 
-        forest_new = Vector{LeafOrNode{S, T}}(undef, n_trees)
+        forest_new = Vector{LeafOrNode{S, T, N}}(undef, n_trees)
        Threads.@threads for i in 1:n_trees
            forest_new[i] = forest[i].node
        end
 
-        return Ensemble{S, T}(forest_new, n_features, fi ./ n_trees)
+        return Ensemble{S, T, N}(forest_new, n_features, fi ./ n_trees)
    end
 end
 
@@ -516,11 +521,13 @@ function build_adaboost_stumps(
    stumps = Node{S, T}[]
    coeffs = Float64[]
    n_features = size(features, 2)
+    n_classes = length(unique(labels))
    for i in 1:n_iterations
        new_stump = build_stump(
            labels,
            features,
            weights;
+            n_classes,
            rng=mk_rng(rng),
            impurity_importance=false
        )
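A hypothetical usage sketch of the n_classes keyword threaded through above (argument values are illustrative, not taken from the commit):

    using DecisionTree, Random

    rng      = MersenneTwister(1)
    features = rand(rng, 100, 3)
    labels   = rand(rng, ["a", "b", "c"], 100)

    # n_classes defaults to length(unique(labels)); passing it explicitly keeps the
    # class-count type parameter N of Root{S, T, N} stable even if a particular
    # label subset happens to miss a class.
    tree = build_tree(labels, features; n_classes=3)

    # build_forest and build_adaboost_stumps now derive n_classes themselves and
    # forward it to build_tree / build_stump, as shown in the hunks above.
    forest = build_forest(labels, features)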

test/miscellaneous/abstract_trees_test.jl (+4 −4)

@@ -17,9 +17,9 @@ clabel_pattern(clabel) = "─ " * clabel * " (" # class labels are embedde
 check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)
 
 @info("Test base functionality")
-l1 = Leaf(1, [1,1,2])
-l2 = Leaf(2, [1,2,2])
-l3 = Leaf(3, [3,3,1])
+l1 = Leaf((1,2,3), 1, (2, 1, 0), 3)
+l2 = Leaf((1,2,3), 2, (1, 2, 0), 3)
+l3 = Leaf((1,2,3), 3, (1, 0, 2), 3)
 n2 = Node(2, 0.5, l2, l3)
 n1 = Node(1, 0.7, l1, n2)
 feature_names = ["firstFt", "secondFt"]
@@ -81,4 +81,4 @@ end
 traverse_tree(leaf::InfoLeaf) = nothing
 
 traverse_tree(wrapped_tree)
-end
+end
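With the new field layout, each test leaf above carries the full class pool plus per-class counts. Based on the print_tree method updated in src/DecisionTree.jl earlier in this commit, l1 should print roughly as follows (output shown as a comment, for illustration only):

    l1 = Leaf((1, 2, 3), 1, (2, 1, 0), 3)
    print_tree(l1)   # expected: 1 : 2/3   (majority class 1 in 2 of 3 samples)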

test/miscellaneous/convert.jl (+2 −2)

@@ -2,7 +2,7 @@
 
 @testset "convert.jl" begin
 
-lf = Leaf(1, [1])
+lf = Leaf([1], [1])
 nv = Node{Int, Int}[]
 rv = Root{Int, Int}[]
 push!(nv, lf)
@@ -22,7 +22,7 @@ push!(rv, nv[1])
 @test apply_tree(rv[1], [0]) == 1.0
 @test apply_tree(rv[2], [0]) == 1.0
 
-lf = Leaf("A", ["B", "A"])
+lf = Leaf(["A", "B"], [2, 1])
 nv = Node{Int, String}[]
 rv = Root{Int, String}[]
 push!(nv, lf)
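For reference, the rewritten string-labelled leaf now separates the class pool from its counts; under the two-argument constructor added in src/DecisionTree.jl it should expand to the following fields (values shown as comments, for illustration):

    lf = Leaf(["A", "B"], [2, 1])
    lf.classes    # ("A", "B")
    lf.majority   # 1  -- "A" has the larger count
    lf.values     # (2, 1)
    lf.total      # 3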
