Skip to content

Commit ec0d871

Browse files
author
[zebinyang]
committed
add custom plots for plot_tree function; update version
1 parent 225799f commit ec0d871

File tree

4 files changed, with 35 additions (+35) and 23 deletions (-23).

4 files changed

+35
-23
lines changed

.gitignore

+1
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@ examples/.ipynb_checkpoints/*
2 | 2 |   history/
3 | 3 |   scripts/
4 | 4 |   dist/
  | 5 | + simtree.egg-info/

setup.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22

33
setup(name='simtree',
4-
version='0.1.1',
4+
version='0.1.2',
55
description='Single-index model tree',
66
url='https://github.com/ZebinYang/SIMTree',
77
author='Zebin Yang',

simtree/__init__.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -8,5 +8,5 @@
88
"SIMTreeRegressor", "SIMTreeClassifier",
99
"CustomMobTreeRegressor", "CustomMobTreeClassifier"]
1010

11-
__version__ = '0.1.1'
11+
__version__ = '0.1.2'
1212
__author__ = 'Zebin Yang'

simtree/mobtree.py

+32-21
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ def node_split(self, sample_indice):
179179
if sortted_feature[i + 1] <= sortted_feature[i] + self.EPSILON:
180180
continue
181181

182-
if self.min_samples_leaf < n_samples / (self.n_split_grid - 1):
182+
if self.min_samples_leaf < n_samples / max((self.n_split_grid - 1), 2):
183183
if (i + 1) / n_samples < (split_point + 1) / (self.n_split_grid + 1):
184184
continue
185185
elif n_samples > 2 * self.min_samples_leaf:
@@ -312,54 +312,62 @@ def fit(self, x, y):
312312
"is_left": False})
313313
return self
314314

315-
def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=False):
315+
def plot_tree(self, draw_depth=np.inf, start_node_id=1, folder="./results/", name="demo", save_png=False, save_eps=False):
316316

317+
idx = 0
318+
draw_subtree = {}
317319
draw_tree = copy.deepcopy(self.tree)
318-
pending_node_list = [draw_tree[1]]
319-
max_depth = 1 + np.max([item["depth"] for key, item in self.tree.items()])
320+
pending_node_list = [draw_tree[start_node_id]]
321+
start_depth = draw_tree[start_node_id]["depth"]
322+
total_depth = 1 + min(np.max([item["depth"] for key, item in self.tree.items()]) - start_depth, draw_depth)
323+
max_depth = min(np.max([item["depth"] for key, item in self.tree.items()]), start_depth + draw_depth)
320324
while len(pending_node_list) > 0:
321325

322326
item = pending_node_list.pop()
323-
if item["parent_id"] is None:
327+
if item["depth"] > max_depth:
328+
continue
329+
if item["parent_id"] is None or idx == 0:
324330
xy = (0.5, 0)
325331
parent_xy = None
326332
else:
327-
parent_xy = draw_tree[item["parent_id"]]["xy"]
333+
parent_xy = draw_subtree[item["parent_id"]]["xy"]
328334
if item["is_left"]:
329-
xy = (parent_xy[0] - 1 / 2 ** (item["depth"] + 1), 3 * item["depth"] / (3 * max_depth - 2))
335+
xy = (parent_xy[0] - 1 / 2 ** (item["depth"] - start_depth + 1), 3 * (item["depth"] - start_depth) / (3 * total_depth - 2))
330336
else:
331-
xy = (parent_xy[0] + 1 / 2 ** (item["depth"] + 1), 3 * item["depth"] / (3 * max_depth - 2))
337+
xy = (parent_xy[0] + 1 / 2 ** (item["depth"] - start_depth + 1), 3 * (item["depth"] - start_depth) / (3 * total_depth - 2))
338+
idx += 1
332339

340+
draw_subtree[item["node_id"]] = item
333341
if item["is_leaf"]:
334342
if is_regressor(self):
335-
draw_tree[item["node_id"]].update({"xy": xy,
343+
draw_subtree[item["node_id"]].update({"xy": xy,
336344
"parent_xy": parent_xy,
337345
"estimator": item["estimator"],
338346
"label": "____Node " + str(item["node_id"]) + "____" +
339347
"\nMSE: " + str(np.round(item["impurity"], 3))
340348
+ "\nSize: " + str(int(item["n_samples"]))
341349
+ "\nMean: " + str(np.round(item["value"], 3))})
342350
elif is_classifier(self):
343-
draw_tree[item["node_id"]].update({"xy": xy,
351+
draw_subtree[item["node_id"]].update({"xy": xy,
344352
"parent_xy": parent_xy,
345353
"estimator": item["estimator"],
346354
"label": "____Node " + str(item["node_id"]) + "____" +
347355
"\nCEntropy: " + str(np.round(item["impurity"], 3))
348356
+ "\nSize: " + str(int(item["n_samples"]))
349-
+ "\nMean: " + str(np.round(item["value"], 3))})
357+
+ "\nMean: " + str(np.round(item["value"], 3))})
350358
else:
351359
fill_width = len(self.feature_names[item["feature"]] + " <=" + str(np.round(item["threshold"], 3)))
352360
fill_width = int(round((fill_width - 2) / 2))
353361
if is_regressor(self):
354-
draw_tree[item["node_id"]].update({"xy": xy,
362+
draw_subtree[item["node_id"]].update({"xy": xy,
355363
"parent_xy": parent_xy,
356364
"label": "_" * fill_width + "Node " + str(item["node_id"]) + "_" * fill_width
357365
+ "\n" + self.feature_names[item["feature"]] + " <=" + str(np.round(item["threshold"], 3))
358366
+ "\nMSE: " + str(np.round(item["impurity"], 3))
359367
+ "\nSize: " + str(int(item["n_samples"]))
360368
+ "\nMean: " + str(np.round(item["value"], 3))})
361369
elif is_classifier(self):
362-
draw_tree[item["node_id"]].update({"xy": xy,
370+
draw_subtree[item["node_id"]].update({"xy": xy,
363371
"parent_xy": parent_xy,
364372
"label": "_" * fill_width + "Node " + str(item["node_id"]) + "_" * fill_width
365373
+ "\n" + self.feature_names[item["feature"]] + " <=" + str(np.round(item["threshold"], 3))
@@ -370,7 +378,7 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
370378
pending_node_list.append(self.tree[item["left_child_id"]])
371379
pending_node_list.append(self.tree[item["right_child_id"]])
372380

373-
fig = plt.figure(figsize=(2 ** max_depth, (max_depth - 0.8) * 2))
381+
fig = plt.figure(figsize=(2 ** total_depth, (total_depth - 0.8) * 2))
374382
tree = fig.add_axes([0.0, 0.0, 1, 1])
375383
ax_width = tree.get_window_extent().width
376384
ax_height = tree.get_window_extent().height
@@ -380,7 +388,8 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
380388
values = np.array([item["value"] for key, item in self.tree.items()])
381389
min_value, max_value = values.min(), values.max()
382390

383-
for key, item in draw_tree.items():
391+
idx = 0
392+
for key, item in draw_subtree.items():
384393

385394
if max_value == min_value:
386395
if item["is_leaf"]:
@@ -396,22 +405,24 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
396405
color = [int(round(alpha * c + (1 - alpha) * 255, 0)) for c in color_list]
397406

398407
kwargs = dict(bbox={"fc": '#%2x%2x%2x' % tuple(color), "boxstyle": "round"}, arrowprops={"arrowstyle": "<-"},
399-
ha='center', va='center', zorder=100 - 10 * item["depth"], xycoords='axes pixels', fontsize=14)
408+
ha='center', va='center', zorder=100 - 10 * (item["depth"] - start_depth), xycoords='axes pixels', fontsize=14)
400409

401-
if item["parent_id"] is None:
410+
if item["parent_id"] is None or idx == 0:
402411
tree.annotate(item["label"], (item["xy"][0] * ax_width, (1 - item["xy"][1]) * ax_height), **kwargs)
403412
else:
404413
if item["is_left"]:
405-
tree.annotate(item["label"], ((item["parent_xy"][0] - 0.01 / 2 ** (item["depth"] + 1)) * ax_width,
406-
(1 - item["parent_xy"][1] - 0.1 / max_depth) * ax_height),
414+
tree.annotate(item["label"], ((item["parent_xy"][0] - 0.01 / 2 ** (item["depth"] - start_depth + 1)) * ax_width,
415+
(1 - item["parent_xy"][1] - 0.1 / total_depth) * ax_height),
407416
(item["xy"][0] * ax_width, (1 - item["xy"][1]) * ax_height), **kwargs)
408417
else:
409-
tree.annotate(item["label"], ((item["parent_xy"][0] + 0.01 / 2 ** (item["depth"] + 1)) * ax_width,
410-
(1 - item["parent_xy"][1] - 0.1 / max_depth) * ax_height),
418+
tree.annotate(item["label"], ((item["parent_xy"][0] + 0.01 / 2 ** (item["depth"] - start_depth + 1)) * ax_width,
419+
(1 - item["parent_xy"][1] - 0.1 / total_depth) * ax_height),
411420
(item["xy"][0] * ax_width, (1 - item["xy"][1]) * ax_height), **kwargs)
421+
idx += 1
412422

413423
tree.set_axis_off()
414424
plt.show()
425+
415426
if max_depth > 0:
416427
save_path = folder + name
417428
if save_eps:

0 commit comments

Comments
 (0)