diff --git a/decision_tree/decision_tree.py b/decision_tree/decision_tree.py index d1f837f..d9b650a 100644 --- a/decision_tree/decision_tree.py +++ b/decision_tree/decision_tree.py @@ -72,7 +72,7 @@ def calc_ent(x): calculate shanno ent of x """ - x_value_list = set([x[i] for i in range(x.shape[0])]) + x_value_list = set(x) ent = 0.0 for x_value in x_value_list: p = float(x[x == x_value].shape[0]) / x.shape[0] @@ -87,7 +87,7 @@ def calc_condition_ent(x, y): """ # calc ent(y|x) - x_value_list = set([x[i] for i in range(x.shape[0])]) + x_value_list = set(x) ent = 0.0 for x_value in x_value_list: sub_y = y[x == x_value]