Skip to content

Commit f36ea22

Browse files
author
ZebinYang
committed
skip grid search cv in build leaf if param_dict is empty; version 0.2.4
1 parent f62cdc8 commit f36ea22

File tree

2 files changed

+14
-50
lines changed

2 files changed

+14
-50
lines changed

examples/demo.ipynb

+13-49
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
"execution_count": 1,
66
"metadata": {
77
"ExecuteTime": {
8-
"end_time": "2021-12-27T11:58:43.079438Z",
9-
"start_time": "2021-12-27T11:58:41.821923Z"
8+
"end_time": "2021-12-27T12:22:23.721251Z",
9+
"start_time": "2021-12-27T12:22:22.456359Z"
1010
}
1111
},
1212
"outputs": [],
@@ -29,11 +29,11 @@
2929
},
3030
{
3131
"cell_type": "code",
32-
"execution_count": 9,
32+
"execution_count": 2,
3333
"metadata": {
3434
"ExecuteTime": {
35-
"end_time": "2021-12-27T12:07:47.437564Z",
36-
"start_time": "2021-12-27T12:07:47.428224Z"
35+
"end_time": "2021-12-27T12:22:23.745780Z",
36+
"start_time": "2021-12-27T12:22:23.722805Z"
3737
}
3838
},
3939
"outputs": [],
@@ -128,11 +128,11 @@
128128
},
129129
{
130130
"cell_type": "code",
131-
"execution_count": 16,
131+
"execution_count": 3,
132132
"metadata": {
133133
"ExecuteTime": {
134-
"end_time": "2021-12-27T12:10:02.283742Z",
135-
"start_time": "2021-12-27T12:10:02.097688Z"
134+
"end_time": "2021-12-27T12:22:28.522606Z",
135+
"start_time": "2021-12-27T12:22:28.338619Z"
136136
}
137137
},
138138
"outputs": [],
@@ -209,6 +209,8 @@
209209
}
210210
],
211211
"source": [
212+
"# here we use LogisticRegression from sklearn\n",
213+
"# reg_lambda corresponds to the LogisticRegression parameter \"C\", which is the inverse of the regularization strength.\n",
212214
"clf = GLMTreeClassifier(max_depth=3, min_samples_leaf=50, reg_lambda=np.logspace(-5, 5, 10).tolist(),\n",
213215
" n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n",
214216
"clf.fit(train_x, train_y)\n",
@@ -217,57 +219,19 @@
217219
"roc_auc_score(train_y, pred_train.ravel()), roc_auc_score(test_y, pred_test.ravel())"
218220
]
219221
},
220-
{
221-
"cell_type": "code",
222-
"execution_count": 18,
223-
"metadata": {
224-
"ExecuteTime": {
225-
"end_time": "2021-12-27T12:11:03.435992Z",
226-
"start_time": "2021-12-27T12:10:29.328703Z"
227-
}
228-
},
229-
"outputs": [
230-
{
231-
"ename": "TypeError",
232-
"evalue": "__init__() got an unexpected keyword argument 'alpha'",
233-
"output_type": "error",
234-
"traceback": [
235-
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
236-
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
237-
"\u001b[0;32m/tmp/ipykernel_37837/3864539870.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m clf = GLMTreeClassifier(max_depth=1, min_samples_leaf=50, reg_lambda=[0],\n\u001b[1;32m 2\u001b[0m n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_y\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mpred_train\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpred_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_proba\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_x\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
238-
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 294\u001b[0m node_id = self.add_node(parent_id, is_left, is_leaf, depth,\n\u001b[0;32m--> 295\u001b[0;31m None, None, impurity, sample_indice)\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m node_id = self.add_node(parent_id, is_left, is_leaf, depth,\n",
239-
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/simtree/mobtree.py\u001b[0m in \u001b[0;36madd_node\u001b[0;34m(self, parent_id, is_left, is_leaf, depth, feature, threshold, impurity, sample_indice)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[0mn_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_indice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 232\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 233\u001b[0;31m \u001b[0mpredict_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbest_impurity\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuild_leaf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msample_indice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 234\u001b[0m node = {\"node_id\": node_id, \"parent_id\": parent_id, \"depth\": depth, \"feature\": feature, \"impurity\": best_impurity,\n\u001b[1;32m 235\u001b[0m \u001b[0;34m\"n_samples\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mn_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"is_left\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mis_left\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"is_leaf\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"value\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msample_indice\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
240-
"\u001b[0;32m~/anaconda3/envs/py37/lib/python3.7/site-packages/simtree/glmtree.py\u001b[0m in \u001b[0;36mbuild_leaf\u001b[0;34m(self, sample_indice)\u001b[0m\n\u001b[1;32m 96\u001b[0m cv=5, random_state=self.random_state)\n\u001b[1;32m 97\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 98\u001b[0;31m \u001b[0mbest_estimator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLogisticRegression\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malpha\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreg_lambda\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprecompute\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrandom_state\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandom_state\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 99\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mmx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0msample_indice\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
241-
"\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'alpha'"
242-
]
243-
}
244-
],
245-
"source": [
246-
"clf = GLMTreeClassifier(max_depth=1, min_samples_leaf=50, reg_lambda=[0],\n",
247-
" n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n",
248-
"clf.fit(train_x, train_y)\n",
249-
"pred_train = clf.predict_proba(train_x)[:, 1]\n",
250-
"pred_test = clf.predict_proba(test_x)[:, 1]\n",
251-
"roc_auc_score(train_y, pred_train.ravel()), roc_auc_score(test_y, pred_test.ravel())"
252-
]
253-
},
254222
{
255223
"cell_type": "code",
256224
"execution_count": null,
257225
"metadata": {
258226
"ExecuteTime": {
259-
"end_time": "2021-12-27T12:11:03.437503Z",
260-
"start_time": "2021-12-27T12:11:03.437482Z"
227+
"start_time": "2021-12-27T12:27:37.691Z"
261228
}
262229
},
263230
"outputs": [],
264231
"source": [
265-
"clf = SIMTreeClassifier(max_depth=1, min_samples_leaf=50, knot_num=30,\n",
266-
" n_split_grid=20, n_screen_grid=5, n_feature_search=10,\n",
267-
" reg_lambda=[0],\n",
268-
" reg_gamma=[1e-3, 1e-5, 1e-7])\n",
232+
"clf = GLMTreeClassifier(max_depth=1, min_samples_leaf=50, reg_lambda=[1e4],\n",
233+
" n_split_grid=20, n_screen_grid=5, n_feature_search=10)\n",
269234
"clf.fit(train_x, train_y)\n",
270-
"clf.plot_tree()\n",
271235
"pred_train = clf.predict_proba(train_x)[:, 1]\n",
272236
"pred_test = clf.predict_proba(test_x)[:, 1]\n",
273237
"roc_auc_score(train_y, pred_train.ravel()), roc_auc_score(test_y, pred_test.ravel())"

simtree/glmtree.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def build_leaf(self, sample_indice):
9595
best_estimator = LogisticRegressionCV(Cs=self.reg_lambda, penalty="l1", solver="liblinear", scoring="roc_auc",
9696
cv=5, random_state=self.random_state)
9797
else:
98-
best_estimator = LogisticRegression(C=self.reg_lambda[0], random_state=self.random_state)
98+
best_estimator = LogisticRegression(C=self.reg_lambda[0], penalty="l1", solver="liblinear", random_state=self.random_state)
9999

100100
mx = self.x[sample_indice].mean(0)
101101
sx = self.x[sample_indice].std(0) + self.EPSILON

0 commit comments

Comments
 (0)