Skip to content

Commit 971ea3b

Browse files
author
ZebinYang
committed
skip grid search cv in build leaf if param_dict is empty; version 0.2.4
1 parent f97c846 commit 971ea3b

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

examples/demo.ipynb

+15-15
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
"execution_count": 1,
66
"metadata": {
77
"ExecuteTime": {
8-
"end_time": "2021-12-27T12:22:23.721251Z",
9-
"start_time": "2021-12-27T12:22:22.456359Z"
8+
"end_time": "2021-12-27T12:56:44.118235Z",
9+
"start_time": "2021-12-27T12:56:42.884995Z"
1010
}
1111
},
1212
"outputs": [],
@@ -32,8 +32,8 @@
3232
"execution_count": 2,
3333
"metadata": {
3434
"ExecuteTime": {
35-
"end_time": "2021-12-27T12:22:23.745780Z",
36-
"start_time": "2021-12-27T12:22:23.722805Z"
35+
"end_time": "2021-12-27T12:56:44.142989Z",
36+
"start_time": "2021-12-27T12:56:44.119690Z"
3737
}
3838
},
3939
"outputs": [],
@@ -48,8 +48,8 @@
4848
"execution_count": 3,
4949
"metadata": {
5050
"ExecuteTime": {
51-
"end_time": "2021-12-27T12:00:04.703220Z",
52-
"start_time": "2021-12-27T11:58:43.105225Z"
51+
"end_time": "2021-12-27T12:58:05.418585Z",
52+
"start_time": "2021-12-27T12:56:44.143926Z"
5353
},
5454
"scrolled": true
5555
},
@@ -94,8 +94,8 @@
9494
"execution_count": 4,
9595
"metadata": {
9696
"ExecuteTime": {
97-
"end_time": "2021-12-27T12:00:07.483934Z",
98-
"start_time": "2021-12-27T12:00:04.704162Z"
97+
"end_time": "2021-12-27T12:58:08.216658Z",
98+
"start_time": "2021-12-27T12:58:05.419872Z"
9999
}
100100
},
101101
"outputs": [
@@ -128,11 +128,11 @@
128128
},
129129
{
130130
"cell_type": "code",
131-
"execution_count": 3,
131+
"execution_count": 5,
132132
"metadata": {
133133
"ExecuteTime": {
134-
"end_time": "2021-12-27T12:22:28.522606Z",
135-
"start_time": "2021-12-27T12:22:28.338619Z"
134+
"end_time": "2021-12-27T12:58:08.449903Z",
135+
"start_time": "2021-12-27T12:58:08.217997Z"
136136
}
137137
},
138138
"outputs": [],
@@ -147,8 +147,8 @@
147147
"execution_count": 6,
148148
"metadata": {
149149
"ExecuteTime": {
150-
"end_time": "2021-12-27T12:04:03.092316Z",
151-
"start_time": "2021-12-27T12:00:32.684692Z"
150+
"end_time": "2021-12-27T13:01:41.797837Z",
151+
"start_time": "2021-12-27T12:58:08.450917Z"
152152
}
153153
},
154154
"outputs": [
@@ -192,8 +192,8 @@
192192
"execution_count": 7,
193193
"metadata": {
194194
"ExecuteTime": {
195-
"end_time": "2021-12-27T12:05:37.162885Z",
196-
"start_time": "2021-12-27T12:04:03.093335Z"
195+
"end_time": "2021-12-27T13:03:16.196360Z",
196+
"start_time": "2021-12-27T13:01:41.798966Z"
197197
}
198198
},
199199
"outputs": [

simtree/customtree.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,16 @@ def build_root(self):
4343
def build_leaf(self, sample_indice):
4444

4545
if len(self.param_dict) == 0:
46-
self.base_estimator.fit(self.x[sample_indice], self.y[sample_indice].ravel())
47-
best_estimator = self.base_estimator
46+
best_estimator = clone(self.base_estimator)
47+
best_estimator.fit(self.x[sample_indice], self.y[sample_indice].ravel())
4848
else:
4949
param_size = 1
5050
for key, item in self.param_dict.items():
5151
param_size *= len(item)
5252
if param_size == 1:
53-
self.base_estimator.set_params(**{key: item[0] for key, item in self.param_dict.items()})
54-
self.base_estimator.fit(self.x[sample_indice], self.y[sample_indice].ravel())
55-
best_estimator = self.base_estimator
53+
best_estimator = clone(self.base_estimator)
54+
best_estimator.set_params(**{key: item[0] for key, item in self.param_dict.items()})
55+
best_estimator.fit(self.x[sample_indice], self.y[sample_indice].ravel())
5656
else:
5757
grid = GridSearchCV(self.base_estimator, param_grid=self.param_dict,
5858
scoring={"mse": make_scorer(mean_squared_error, greater_is_better=False)},

0 commit comments

Comments
 (0)