Skip to content

Commit b06b74f

Browse files
committed
Fix tree building with N classes type parameter
1 parent 5dfd8da commit b06b74f

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/classification/main.jl

+5-4
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,13 @@ function _build_tree(
179179
impurity_importance::Bool
180180
) where {S, T}
181181
node = _convert(tree.root, tree.list, labels[tree.labels])
182+
n_classes = unique(labels) |> length
182183
if !impurity_importance
183-
return Root{S, T}(node, n_features, Float64[])
184+
return Root{S, T, n_classes}(node, n_features, Float64[])
184185
else
185186
fi = zeros(Float64, n_features)
186187
update_using_impurity!(fi, tree.root)
187-
return Root{S, T}(node, n_features, fi ./ n_samples)
188+
return Root{S, T, n_classes}(node, n_features, fi ./ n_samples)
188189
end
189190
end
190191

@@ -241,10 +242,10 @@ function prune_tree(
241242
return tree
242243
end
243244
end
244-
function _prune_run(tree::Root{S, T}, purity_thresh::Real) where {S, T}
245+
function _prune_run(tree::Root{S, T, N}, purity_thresh::Real) where {S, T, N}
245246
fi = deepcopy(tree.featim) ## recalculate feature importances
246247
node = _prune_run(tree.node, purity_thresh, fi)
247-
return Root{S, T}(node, tree.n_feat, fi)
248+
return Root{S, T, N}(node, fi)
248249
end
249250
function _prune_run(
250251
tree::LeafOrNode{S, T, N},

0 commit comments

Comments
 (0)