@@ -179,7 +179,7 @@ def node_split(self, sample_indice):
179
179
if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
180
180
continue
181
181
182
- if self .min_samples_leaf < n_samples / ( self .n_split_grid - 1 ):
182
+ if self .min_samples_leaf < n_samples / max (( self .n_split_grid - 1 ), 2 ):
183
183
if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_split_grid + 1 ):
184
184
continue
185
185
elif n_samples > 2 * self .min_samples_leaf :
@@ -312,54 +312,62 @@ def fit(self, x, y):
312
312
"is_left" : False })
313
313
return self
314
314
315
- def plot_tree (self , folder = "./results/" , name = "demo" , save_png = False , save_eps = False ):
315
+ def plot_tree (self , draw_depth = np . inf , start_node_id = 1 , folder = "./results/" , name = "demo" , save_png = False , save_eps = False ):
316
316
317
+ idx = 0
318
+ draw_subtree = {}
317
319
draw_tree = copy .deepcopy (self .tree )
318
- pending_node_list = [draw_tree [1 ]]
319
- max_depth = 1 + np .max ([item ["depth" ] for key , item in self .tree .items ()])
320
+ pending_node_list = [draw_tree [start_node_id ]]
321
+ start_depth = draw_tree [start_node_id ]["depth" ]
322
+ total_depth = 1 + min (np .max ([item ["depth" ] for key , item in self .tree .items ()]) - start_depth , draw_depth )
323
+ max_depth = min (np .max ([item ["depth" ] for key , item in self .tree .items ()]), start_depth + draw_depth )
320
324
while len (pending_node_list ) > 0 :
321
325
322
326
item = pending_node_list .pop ()
323
- if item ["parent_id" ] is None :
327
+ if item ["depth" ] > max_depth :
328
+ continue
329
+ if item ["parent_id" ] is None or idx == 0 :
324
330
xy = (0.5 , 0 )
325
331
parent_xy = None
326
332
else :
327
- parent_xy = draw_tree [item ["parent_id" ]]["xy" ]
333
+ parent_xy = draw_subtree [item ["parent_id" ]]["xy" ]
328
334
if item ["is_left" ]:
329
- xy = (parent_xy [0 ] - 1 / 2 ** (item ["depth" ] + 1 ), 3 * item ["depth" ] / (3 * max_depth - 2 ))
335
+ xy = (parent_xy [0 ] - 1 / 2 ** (item ["depth" ] - start_depth + 1 ), 3 * ( item ["depth" ] - start_depth ) / (3 * total_depth - 2 ))
330
336
else :
331
- xy = (parent_xy [0 ] + 1 / 2 ** (item ["depth" ] + 1 ), 3 * item ["depth" ] / (3 * max_depth - 2 ))
337
+ xy = (parent_xy [0 ] + 1 / 2 ** (item ["depth" ] - start_depth + 1 ), 3 * (item ["depth" ] - start_depth ) / (3 * total_depth - 2 ))
338
+ idx += 1
332
339
340
+ draw_subtree [item ["node_id" ]] = item
333
341
if item ["is_leaf" ]:
334
342
if is_regressor (self ):
335
- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
343
+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
336
344
"parent_xy" : parent_xy ,
337
345
"estimator" : item ["estimator" ],
338
346
"label" : "____Node " + str (item ["node_id" ]) + "____" +
339
347
"\n MSE: " + str (np .round (item ["impurity" ], 3 ))
340
348
+ "\n Size: " + str (int (item ["n_samples" ]))
341
349
+ "\n Mean: " + str (np .round (item ["value" ], 3 ))})
342
350
elif is_classifier (self ):
343
- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
351
+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
344
352
"parent_xy" : parent_xy ,
345
353
"estimator" : item ["estimator" ],
346
354
"label" : "____Node " + str (item ["node_id" ]) + "____" +
347
355
"\n CEntropy: " + str (np .round (item ["impurity" ], 3 ))
348
356
+ "\n Size: " + str (int (item ["n_samples" ]))
349
- + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
357
+ + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
350
358
else :
351
359
fill_width = len (self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 )))
352
360
fill_width = int (round ((fill_width - 2 ) / 2 ))
353
361
if is_regressor (self ):
354
- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
362
+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
355
363
"parent_xy" : parent_xy ,
356
364
"label" : "_" * fill_width + "Node " + str (item ["node_id" ]) + "_" * fill_width
357
365
+ "\n " + self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 ))
358
366
+ "\n MSE: " + str (np .round (item ["impurity" ], 3 ))
359
367
+ "\n Size: " + str (int (item ["n_samples" ]))
360
368
+ "\n Mean: " + str (np .round (item ["value" ], 3 ))})
361
369
elif is_classifier (self ):
362
- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
370
+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
363
371
"parent_xy" : parent_xy ,
364
372
"label" : "_" * fill_width + "Node " + str (item ["node_id" ]) + "_" * fill_width
365
373
+ "\n " + self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 ))
@@ -370,7 +378,7 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
370
378
pending_node_list .append (self .tree [item ["left_child_id" ]])
371
379
pending_node_list .append (self .tree [item ["right_child_id" ]])
372
380
373
- fig = plt .figure (figsize = (2 ** max_depth , (max_depth - 0.8 ) * 2 ))
381
+ fig = plt .figure (figsize = (2 ** total_depth , (total_depth - 0.8 ) * 2 ))
374
382
tree = fig .add_axes ([0.0 , 0.0 , 1 , 1 ])
375
383
ax_width = tree .get_window_extent ().width
376
384
ax_height = tree .get_window_extent ().height
@@ -380,7 +388,8 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
380
388
values = np .array ([item ["value" ] for key , item in self .tree .items ()])
381
389
min_value , max_value = values .min (), values .max ()
382
390
383
- for key , item in draw_tree .items ():
391
+ idx = 0
392
+ for key , item in draw_subtree .items ():
384
393
385
394
if max_value == min_value :
386
395
if item ["is_leaf" ]:
@@ -396,22 +405,24 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
396
405
color = [int (round (alpha * c + (1 - alpha ) * 255 , 0 )) for c in color_list ]
397
406
398
407
kwargs = dict (bbox = {"fc" : '#%2x%2x%2x' % tuple (color ), "boxstyle" : "round" }, arrowprops = {"arrowstyle" : "<-" },
399
- ha = 'center' , va = 'center' , zorder = 100 - 10 * item ["depth" ], xycoords = 'axes pixels' , fontsize = 14 )
408
+ ha = 'center' , va = 'center' , zorder = 100 - 10 * ( item ["depth" ] - start_depth ) , xycoords = 'axes pixels' , fontsize = 14 )
400
409
401
- if item ["parent_id" ] is None :
410
+ if item ["parent_id" ] is None or idx == 0 :
402
411
tree .annotate (item ["label" ], (item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
403
412
else :
404
413
if item ["is_left" ]:
405
- tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] - 0.01 / 2 ** (item ["depth" ] + 1 )) * ax_width ,
406
- (1 - item ["parent_xy" ][1 ] - 0.1 / max_depth ) * ax_height ),
414
+ tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] - 0.01 / 2 ** (item ["depth" ] - start_depth + 1 )) * ax_width ,
415
+ (1 - item ["parent_xy" ][1 ] - 0.1 / total_depth ) * ax_height ),
407
416
(item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
408
417
else :
409
- tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] + 0.01 / 2 ** (item ["depth" ] + 1 )) * ax_width ,
410
- (1 - item ["parent_xy" ][1 ] - 0.1 / max_depth ) * ax_height ),
418
+ tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] + 0.01 / 2 ** (item ["depth" ] - start_depth + 1 )) * ax_width ,
419
+ (1 - item ["parent_xy" ][1 ] - 0.1 / total_depth ) * ax_height ),
411
420
(item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
421
+ idx += 1
412
422
413
423
tree .set_axis_off ()
414
424
plt .show ()
425
+
415
426
if max_depth > 0 :
416
427
save_path = folder + name
417
428
if save_eps :
0 commit comments