Skip to content

Commit 18d4b39

Browse files
optimize code
2 parents 6948ac1 + ef3bd88 commit 18d4b39

File tree

2 files changed

+164
-164
lines changed

2 files changed

+164
-164
lines changed

decision_tree_c45.py

+161-161
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,162 @@
1-
import numpy as np
2-
import treelib
3-
import scipy.stats
4-
5-
class C45():
6-
class __data:
7-
def __init__(self):
8-
self.feature_split = None
9-
self.threshold_split = None
10-
self.data_number = None
11-
self.error_number = None
12-
self.result = None
13-
14-
def __init__(self):
15-
self.__tree = treelib.Tree()
16-
17-
def __get_entropy(self, y):
18-
_, counts = np.unique(y, return_counts=True)
19-
prob_classes = counts / np.sum(counts)
20-
return scipy.stats.entropy(prob_classes)
21-
22-
def __get_info_gain(self, X_subs, y_subs, y):
23-
return self.__get_entropy(y) - sum([self.__get_entropy(y_sub) * len(y_sub) for y_sub in y_subs]) / len(y)
24-
25-
def __get_info_gain_ratio(self, X_subs, y_subs, y):
26-
info_gain = self.__get_info_gain(X_subs, y_subs, y)
27-
if info_gain == 0:
28-
return 0
29-
30-
return info_gain / self.__get_entropy(X_subs)
31-
32-
def __process_discrete(self, x, y):
33-
y_subs = [y[np.flatnonzero(x == feature_label)] for feature_label in np.unique(x)]
34-
return self.__get_info_gain_ratio(x, y_subs, y), None
35-
36-
def __process_continuous(self, x, y):
37-
info_gain_max = -np.inf
38-
x_sort = np.unique(np.sort(x))
39-
for j in range(len(x_sort) - 1):
40-
threshold = (x_sort[j] + x_sort[j + 1]) / 2
41-
42-
less_items = np.flatnonzero(x < threshold)
43-
greater_items = np.flatnonzero(x > threshold)
44-
y_subs = [y[less_items], y[greater_items]]
45-
X_subs = np.append(np.zeros(len(less_items)), np.ones(len(greater_items)))
46-
47-
info_gain = self.__get_info_gain(X_subs, y_subs, y)
48-
if info_gain > info_gain_max:
49-
info_gain_max = info_gain
50-
threshold_split = threshold
51-
info_gain_ratio = self.__get_info_gain_ratio(X_subs, y_subs, y)
52-
53-
return info_gain_ratio, threshold_split
54-
55-
def __create_tree(self, parent, X, y):
56-
data_number, feature_number = X.shape
57-
58-
if data_number == 0:
59-
return
60-
61-
data = self.__data()
62-
data.data_number = data_number
63-
data.result = max(set(y), key=y.tolist().count)
64-
data.error_number = sum(y != data.result)
65-
66-
if len(np.unique(y)) == 1 or (X == X[0]).all():
67-
self.__tree.update_node(parent.identifier, data=data)
68-
return
69-
70-
info_gain_ratio_max = -np.inf
71-
for i in range(feature_number):
72-
if len(np.unique(X[:, i])) == 1:
73-
continue
74-
75-
try:
76-
feature = X[:, i].astype(float)
77-
except:
78-
info_gain_ratio, threshold = self.__process_discrete(X[:, i], y)
79-
else:
80-
info_gain_ratio, threshold = self.__process_continuous(feature, y)
81-
82-
if info_gain_ratio > info_gain_ratio_max:
83-
info_gain_ratio_max = info_gain_ratio
84-
data.feature_split = i
85-
data.threshold_split = threshold
86-
87-
self.__tree.update_node(parent.identifier, data=data)
88-
if data.threshold_split:
89-
feature = X[:, data.feature_split].astype(float)
90-
91-
less_items = np.flatnonzero(feature < data.threshold_split)
92-
greater_items = np.flatnonzero(feature > data.threshold_split)
93-
94-
node = self.__tree.create_node('less ' + str(data.threshold_split), parent=parent)
95-
self.__create_tree(node, X[less_items], y[less_items])
96-
97-
node = self.__tree.create_node('greater ' + str(data.threshold_split), parent=parent)
98-
self.__create_tree(node, X[greater_items], y[greater_items])
99-
else:
100-
for feature_label in np.unique(X[:, data.feature_split]):
101-
node = self.__tree.create_node(feature_label, parent=parent)
102-
self.__create_tree(node, X[np.flatnonzero(X[:, data.feature_split] == feature_label)], y[np.flatnonzero(X[:, data.feature_split] == feature_label)])
103-
104-
def fit(self, X, y):
105-
'''
106-
Parameters
107-
----------
108-
X : shape (data_number, feature_number)
109-
Training data
110-
y : shape (data_number)
111-
Target values, discrete value
112-
'''
113-
root = self.__tree.create_node('root')
114-
self.__create_tree(root, X, y)
115-
self.__tree.show()
116-
117-
def prune_pep(self):
118-
for level in reversed(range(self.__tree.depth())):
119-
for node in self.__tree.all_nodes():
120-
if not self.__tree.contains(node.identifier):
121-
continue
122-
123-
if self.__tree.level(node.identifier) == level and not node.is_leaf():
124-
leaves_number = len(self.__tree.leaves(node.identifier))
125-
leaves_error = sum([leaf.data.error_number for leaf in self.__tree.leaves(node.identifier)])
126-
error = (leaves_error + leaves_number * 0.5) / node.data.data_number
127-
std = np.sqrt(error * (1 - error) * node.data.data_number)
128-
if leaves_error + leaves_number * 0.5 + std > node.data.error_number + 0.5:
129-
for child in self.__tree.children(node.identifier):
130-
self.__tree.remove_node(child.identifier)
131-
132-
self.__tree.show()
133-
134-
def __query(self, x, node):
135-
if node.is_leaf():
136-
return node.data.result
137-
138-
for child in self.__tree.children(node.identifier):
139-
try:
140-
feature = x[node.data.feature_split].astype(float)
141-
except:
142-
if x[node.data.feature_split] == child.tag:
143-
return self.__query(x, child)
144-
else:
145-
if feature < node.data.threshold_split and child.tag == 'less ' + str(node.data.threshold_split):
146-
return self.__query(x, child)
147-
elif feature > node.data.threshold_split and child.tag == 'greater ' + str(node.data.threshold_split):
148-
return self.__query(x, child)
149-
150-
def predict(self, X):
151-
'''
152-
Parameters
153-
----------
154-
X : shape (data_number, feature_number)
155-
Predicting data
156-
157-
Returns
158-
-------
159-
y : shape (data_number,)
160-
Predicted class label per sample
161-
'''
class C45():
    '''C4.5 decision tree classifier built on treelib.

    Supports continuous features (binary threshold splits on the midpoint
    between consecutive unique values) and discrete features (one branch
    per label). Splits are chosen by information gain ratio. Pessimistic
    error pruning (PEP) is available via `prune_pep`.
    '''

    class __data:
        # Per-node payload attached to every treelib node.
        def __init__(self):
            self.feature_split = None    # column index of the chosen split feature
            self.threshold_split = None  # threshold for a continuous split; None for discrete
            self.data_number = None      # number of training samples reaching this node
            self.error_number = None     # training samples misclassified by the majority label
            self.result = None           # majority class label at this node

    def __init__(self):
        self.__tree = treelib.Tree()

    def __get_entropy(self, y):
        # Shannon entropy (natural log) of the class distribution in y.
        _, counts = np.unique(y, return_counts=True)
        prob_classes = counts / np.sum(counts)
        return scipy.stats.entropy(prob_classes)

    def __get_info_gain(self, X_subs, y_subs, y):
        # Information gain: H(y) minus the size-weighted mean entropy of the subsets.
        return self.__get_entropy(y) - sum([self.__get_entropy(y_sub) * len(y_sub) for y_sub in y_subs]) / len(y)

    def __get_info_gain_ratio(self, X_subs, y_subs, y):
        # Gain ratio: information gain normalized by the split information,
        # i.e. the entropy of the partition induced by the feature values.
        info_gain = self.__get_info_gain(X_subs, y_subs, y)
        if info_gain == 0:
            return 0

        return info_gain / self.__get_entropy(X_subs)

    def __process_discrete(self, x, y):
        # Discrete feature: one subset per unique label; no threshold.
        y_subs = [y[np.flatnonzero(x == feature_label)] for feature_label in np.unique(x)]
        return self.__get_info_gain_ratio(x, y_subs, y), None

    def __process_continuous(self, x, y):
        # Scan candidate thresholds (midpoints between consecutive unique
        # values), keep the one with the highest information gain, and
        # report the gain ratio of that best split.
        info_gain_max = -np.inf
        # BUGFIX: initialize so a feature with a single unique value cannot
        # raise UnboundLocalError (the loop body would never run).
        info_gain_ratio, threshold_split = 0, None
        x_sort = np.unique(x)  # np.unique already returns sorted values
        for j in range(len(x_sort) - 1):
            threshold = (x_sort[j] + x_sort[j + 1]) / 2

            less_items = np.flatnonzero(x <= threshold)
            greater_items = np.flatnonzero(x > threshold)
            y_subs = [y[less_items], y[greater_items]]
            # Binary pseudo-feature encoding the partition, used for the
            # split-information term of the gain ratio.
            X_subs = np.append(np.zeros(len(less_items)), np.ones(len(greater_items)))

            info_gain = self.__get_info_gain(X_subs, y_subs, y)
            if info_gain > info_gain_max:
                info_gain_max = info_gain
                threshold_split = threshold
                info_gain_ratio = self.__get_info_gain_ratio(X_subs, y_subs, y)

        return info_gain_ratio, threshold_split

    def __create_tree(self, parent, X, y):
        '''Recursively grow the subtree rooted at `parent` from (X, y).'''
        data_number, feature_number = X.shape

        if data_number == 0:
            return

        data = self.__data()
        data.data_number = data_number
        data.result = max(set(y), key=y.tolist().count)
        data.error_number = sum(y != data.result)

        # Stop when the node is pure or no feature can discriminate.
        if len(np.unique(y)) == 1 or (X == X[0]).all():
            self.__tree.update_node(parent.identifier, data=data)
            return

        info_gain_ratio_max = -np.inf
        for i in range(feature_number):
            if len(np.unique(X[:, i])) == 1:
                continue

            try:
                feature = X[:, i].astype(float)
            except (ValueError, TypeError):
                # Non-numeric column: treat the feature as discrete.
                info_gain_ratio, threshold = self.__process_discrete(X[:, i], y)
            else:
                info_gain_ratio, threshold = self.__process_continuous(feature, y)

            if info_gain_ratio > info_gain_ratio_max:
                info_gain_ratio_max = info_gain_ratio
                data.feature_split = i
                data.threshold_split = threshold

        self.__tree.update_node(parent.identifier, data=data)
        # BUGFIX: compare against None — a valid threshold of 0.0 is falsy
        # and would previously be misrouted into the discrete branch.
        if data.threshold_split is not None:
            feature = X[:, data.feature_split].astype(float)

            less_items = np.flatnonzero(feature <= data.threshold_split)
            greater_items = np.flatnonzero(feature > data.threshold_split)

            node = self.__tree.create_node('less ' + str(data.threshold_split), parent=parent)
            self.__create_tree(node, X[less_items], y[less_items])

            node = self.__tree.create_node('greater ' + str(data.threshold_split), parent=parent)
            self.__create_tree(node, X[greater_items], y[greater_items])
        else:
            for feature_label in np.unique(X[:, data.feature_split]):
                items = np.flatnonzero(X[:, data.feature_split] == feature_label)
                node = self.__tree.create_node(feature_label, parent=parent)
                self.__create_tree(node, X[items], y[items])

    def fit(self, X, y):
        '''
        Parameters
        ----------
        X : shape (data_number, feature_number)
            Training data
        y : shape (data_number)
            Target values, discrete value
        '''
        root = self.__tree.create_node('root')
        self.__create_tree(root, X, y)
        self.__tree.show()

    def prune_pep(self):
        '''Pessimistic error pruning: collapse an internal node into a leaf
        when the subtree's pessimistic error (with 0.5 per-leaf correction
        plus one standard deviation) exceeds the node's own error.'''
        # Bottom-up: deepest internal nodes are considered first.
        for level in reversed(range(self.__tree.depth())):
            for node in self.__tree.all_nodes():
                if not self.__tree.contains(node.identifier):
                    continue  # already removed while pruning an ancestor

                if self.__tree.level(node.identifier) == level and not node.is_leaf():
                    leaves = self.__tree.leaves(node.identifier)
                    leaves_number = len(leaves)
                    leaves_error = sum([leaf.data.error_number for leaf in leaves])
                    error = (leaves_error + leaves_number * 0.5) / node.data.data_number
                    std = np.sqrt(error * (1 - error) * node.data.data_number)
                    if leaves_error + leaves_number * 0.5 + std > node.data.error_number + 0.5:
                        for child in self.__tree.children(node.identifier):
                            self.__tree.remove_node(child.identifier)

        self.__tree.show()

    def __query(self, x, node):
        '''Route one sample `x` down from `node` and return the leaf label.'''
        if node.is_leaf():
            return node.data.result

        for child in self.__tree.children(node.identifier):
            try:
                feature = x[node.data.feature_split].astype(float)
            except (ValueError, TypeError):
                if x[node.data.feature_split] == child.tag:
                    return self.__query(x, child)
            else:
                if feature <= node.data.threshold_split and child.tag == 'less ' + str(node.data.threshold_split):
                    return self.__query(x, child)
                elif feature > node.data.threshold_split and child.tag == 'greater ' + str(node.data.threshold_split):
                    return self.__query(x, child)

        # ROBUSTNESS: an unseen discrete label matched no child — fall back
        # to this node's majority class instead of returning None.
        return node.data.result

    def predict(self, X):
        '''
        Parameters
        ----------
        X : shape (data_number, feature_number)
            Predicting data

        Returns
        -------
        y : shape (data_number,)
            Predicted class label per sample
        '''
        return np.apply_along_axis(self.__query, 1, X, self.__tree.get_node(self.__tree.root))

decision_tree_cart.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __process_continuous(self, x, y):
4949
for j in range(len(x_sort) - 1):
5050
threshold = (x_sort[j] + x_sort[j + 1]) / 2
5151

52-
less_items = np.flatnonzero(x < threshold)
52+
less_items = np.flatnonzero(x <= threshold)
5353
greater_items = np.flatnonzero(x > threshold)
5454
score = self.__get_score(y[less_items], y[greater_items])
5555
if score > score_max:
@@ -98,7 +98,7 @@ def __create_tree(self, parent, X, y):
9898
if data.threshold_split:
9999
feature = X[:, data.feature_split].astype(float)
100100

101-
less_items = np.flatnonzero(feature < data.threshold_split)
101+
less_items = np.flatnonzero(feature <= data.threshold_split)
102102
greater_items = np.flatnonzero(feature > data.threshold_split)
103103

104104
node = self.__tree.create_node('less ' + str(data.threshold_split), parent=parent)
@@ -145,7 +145,7 @@ def __query(self, x, node):
145145
elif x[node.data.feature_split] != node.data.feature_label_split and child.tag == 'not ' + str(node.data.feature_label_split):
146146
return self.__query(x, child)
147147
else:
148-
if feature < node.data.threshold_split and child.tag == 'less ' + str(node.data.threshold_split):
148+
if feature <= node.data.threshold_split and child.tag == 'less ' + str(node.data.threshold_split):
149149
return self.__query(x, child)
150150
elif feature > node.data.threshold_split and child.tag == 'greater ' + str(node.data.threshold_split):
151151
return self.__query(x, child)

0 commit comments

Comments
 (0)