Skip to content

Commit ee97f79

Browse files
committed
Directly operate on leaf tuples
1 parent d92e5aa commit ee97f79

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/classification/main.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ of the output matrix.
320320
apply_tree_proba(tree::Root{S, T}, features::AbstractVector{S}, labels) where {S, T} =
321321
apply_tree_proba(tree.node, features, labels)
322322
apply_tree_proba(leaf::Leaf{T}, features::AbstractVector{S}, labels) where {S, T} =
323-
collect(leaf.values ./ leaf.total)
323+
leaf.values ./ leaf.total
324324

325325
function apply_tree_proba(
326326
tree::Node{S, T},
@@ -335,10 +335,13 @@ function apply_tree_proba(
335335
return apply_tree_proba(tree.right, features, labels)
336336
end
337337
end
338-
apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
339-
apply_tree_proba(tree.node, features, labels)
340-
apply_tree_proba(tree::LeafOrNode{S, T}, features::AbstractMatrix{S}, labels) where {S, T} =
341-
stack_function_results(row->apply_tree_proba(tree, row, labels), features)
338+
function apply_tree_proba(tree::Root{S, T}, features::AbstractMatrix{S}, labels) where {S, T}
339+
predictions = Vector{NTuple{length(labels), Float64}}(undef, size(features, 1))
340+
for i in 1:size(features, 1)
341+
predictions[i] = apply_tree_proba(tree, view(features, i, :), labels)
342+
end
343+
reinterpret(reshape, Float64, predictions) |> transpose |> Matrix
344+
end
342345

343346
function build_forest(
344347
labels :: AbstractVector{T},

0 commit comments

Comments
 (0)