@@ -179,12 +179,13 @@ function _build_tree(
179
179
impurity_importance:: Bool
180
180
) where {S, T}
181
181
node = _convert (tree. root, tree. list, labels[tree. labels])
182
+ n_classes = unique (labels) |> length
182
183
if ! impurity_importance
183
- return Root {S, T} (node, n_features, Float64[])
184
+ return Root {S, T, n_classes } (node, n_features, Float64[])
184
185
else
185
186
fi = zeros (Float64, n_features)
186
187
update_using_impurity! (fi, tree. root)
187
- return Root {S, T} (node, n_features, fi ./ n_samples)
188
+ return Root {S, T, n_classes } (node, n_features, fi ./ n_samples)
188
189
end
189
190
end
190
191
@@ -241,10 +242,10 @@ function prune_tree(
241
242
return tree
242
243
end
243
244
end
244
- function _prune_run (tree:: Root{S, T} , purity_thresh:: Real ) where {S, T}
245
+ function _prune_run (tree:: Root{S, T, N } , purity_thresh:: Real ) where {S, T, N }
245
246
fi = deepcopy (tree. featim) # # recalculate feature importances
246
247
node = _prune_run (tree. node, purity_thresh, fi)
247
- return Root {S, T} (node, tree . n_feat , fi)
248
+ return Root {S, T, N } (node, fi)
248
249
end
249
250
function _prune_run (
250
251
tree:: LeafOrNode{S, T, N} ,
0 commit comments