@@ -26,16 +26,18 @@ export InfoNode, InfoLeaf, wrap
26
26
# ##########################
27
27
# ######### Types ##########
28
28
29
- struct Leaf{T}
30
- majority :: T
31
- values :: Vector{T}
29
+ struct Leaf{T, N}
30
+ features :: NTuple{N, T}
31
+ majority :: Int
32
+ values :: NTuple{N, Int}
33
+ total :: Int
32
34
end
33
35
34
- struct Node{S, T}
36
+ struct Node{S, T, N }
35
37
featid :: Int
36
38
featval :: S
37
- left :: Union{Leaf{T}, Node{S, T}}
38
- right :: Union{Leaf{T}, Node{S, T}}
39
+ left :: Union{Leaf{T, N }, Node{S, T, N }}
40
+ right :: Union{Leaf{T, N }, Node{S, T, N }}
39
41
end
40
42
41
43
const LeafOrNode{S, T} = Union{Leaf{T}, Node{S, T}}
@@ -52,13 +54,15 @@ struct Ensemble{S, T}
52
54
featim :: Vector{Float64}
53
55
end
54
56
57
+ Leaf (features:: NTuple{T, N} ) where {T, N} =
58
+ Leaf (features, 0 , Tuple (zeros (T, N)), 0 )
55
59
56
60
is_leaf (l:: Leaf ) = true
57
61
is_leaf (n:: Node ) = false
58
62
59
63
_zero (:: Type{String} ) = " "
60
64
_zero (x:: Any ) = zero (x)
61
- convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , _zero (S), lf, Leaf (_zero (T), [ _zero (T)] ))
65
+ convert (:: Type{Node{S, T}} , lf:: Leaf{T} ) where {S, T} = Node (0 , _zero (S), lf, Leaf (lf . features ))
62
66
convert (:: Type{Root{S, T}} , node:: LeafOrNode{S, T} ) where {S, T} = Root {S, T} (node, 0 , Float64[])
63
67
convert (:: Type{LeafOrNode{S, T}} , tree:: Root{S, T} ) where {S, T} = tree. node
64
68
promote_rule (:: Type{Node{S, T}} , :: Type{Leaf{T}} ) where {S, T} = Node{S, T}
@@ -97,9 +101,8 @@ depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
97
101
depth (tree:: Root ) = depth (tree. node)
98
102
99
103
function print_tree (io:: IO , leaf:: Leaf , depth= - 1 , indent= 0 ; sigdigits= 4 , feature_names= nothing )
100
- n_matches = count (leaf. values .== leaf. majority)
101
- ratio = string (n_matches, " /" , length (leaf. values))
102
- println (io, " $(leaf. majority) : $(ratio) " )
104
+ println (io, leaf. features[leaf. majority], " : " ,
105
+ leaf. values[leaf. majority], ' /' , leaf. total)
103
106
end
104
107
function print_tree (leaf:: Leaf , depth= - 1 , indent= 0 ; sigdigits= 4 , feature_names= nothing )
105
108
return print_tree (stdout , leaf, depth, indent; sigdigits, feature_names)
162
165
163
166
function show (io:: IO , leaf:: Leaf )
164
167
println (io, " Decision Leaf" )
165
- println (io, " Majority: $( leaf. majority) " )
166
- print (io, " Samples: $( length ( leaf. values)) " )
168
+ println (io, " Majority: " , leaf. features[leaf . majority] )
169
+ print (io, " Samples: " , leaf. total )
167
170
end
168
171
169
172
function show (io:: IO , tree:: Node )
0 commit comments