Skip to content

Commit a8d448c

Browse files
authored
Merge pull request #54 from adarshpalaskar1/print-feature-names-23
Print the feature names in report.print_tree()
2 parents 51dc062 + b1a41d5 commit a8d448c

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/MLJDecisionTreeInterface.jl

+4-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ const PKG = "MLJDecisionTreeInterface"
1515

1616
struct TreePrinter{T}
1717
tree::T
18+
features::Vector{Symbol}
1819
end
19-
(c::TreePrinter)(depth) = DT.print_tree(c.tree, depth)
20-
(c::TreePrinter)() = DT.print_tree(c.tree, 5)
20+
(c::TreePrinter)(depth) = DT.print_tree(c.tree, depth, feature_names = c.features)
21+
(c::TreePrinter)() = DT.print_tree(c.tree, 5, feature_names = c.features)
2122

2223
Base.show(stream::IO, c::TreePrinter) =
2324
print(stream, "TreePrinter object (call with display depth)")
@@ -71,7 +72,7 @@ function MMI.fit(
7172
cache = nothing
7273
report = (
7374
classes_seen=classes_seen,
74-
print_tree=TreePrinter(tree),
75+
print_tree=TreePrinter(tree, features),
7576
features=features,
7677
)
7778
return fitresult, cache, report

0 commit comments

Comments
 (0)