Commit cfc52d8

Author: ArturoAmorQ
Commit message: MNT Update notebooks
1 parent 6a8609d commit cfc52d8

4 files changed: +49 −39 lines changed

notebooks/cross_validation_time.ipynb

Lines changed: 36 additions & 26 deletions
@@ -16,11 +16,13 @@
 "(as in \"independent and identically distributed random variables\").</p>\n",
 "</div>\n",
 "\n",
-"This assumption is usually violated when dealing with time series. A sample\n",
-"depends on past information.\n",
+"This assumption is usually violated in time series, where each sample can be\n",
+"influenced by previous samples (both their feature and target values) in an\n",
+"inherently ordered sequence.\n",
 "\n",
-"We will take an example to highlight such issues with non-i.i.d. data in the\n",
-"previous cross-validation strategies presented. We are going to load financial\n",
+"In this notebook we demonstrate the issues that arise when using the\n",
+"cross-validation strategies we have presented so far, along with non-i.i.d.\n",
+"data. For such purpose we load financial\n",
 "quotations from some energy companies."
 ]
 },
@@ -91,15 +93,21 @@
 "data, target = quotes.drop(columns=[\"Chevron\"]), quotes[\"Chevron\"]\n",
 "data_train, data_test, target_train, target_test = train_test_split(\n",
 " data, target, shuffle=True, random_state=0\n",
-")"
+")\n",
+"\n",
+"# Shuffling breaks the index order, but we still want it to be time-ordered\n",
+"data_train.sort_index(ascending=True, inplace=True)\n",
+"data_test.sort_index(ascending=True, inplace=True)\n",
+"target_train.sort_index(ascending=True, inplace=True)\n",
+"target_test.sort_index(ascending=True, inplace=True)"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
 "We will use a decision tree regressor that we expect to overfit and thus not\n",
-"generalize to unseen data. We will use a `ShuffleSplit` cross-validation to\n",
+"generalize to unseen data. We use a `ShuffleSplit` cross-validation to\n",
 "check the generalization performance of our model.\n",
 "\n",
 "Let's first define our model"
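The shuffle-then-sort pattern added in this hunk can be sketched as a standalone snippet; the date-indexed frame below is a synthetic stand-in for the notebook's `quotes` data.

```python
# Sketch of the pattern added in the hunk above: shuffle when splitting,
# then restore time order with sort_index. The data here is synthetic,
# not the notebook's financial quotations.
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

dates = pd.date_range("2020-01-01", periods=100, freq="D")
data = pd.DataFrame({"x": np.arange(100.0)}, index=dates)
target = pd.Series(np.arange(100.0), index=dates, name="y")

data_train, data_test, target_train, target_test = train_test_split(
    data, target, shuffle=True, random_state=0
)

# Shuffling breaks the index order, but we still want it to be time-ordered
for part in (data_train, data_test, target_train, target_test):
    part.sort_index(ascending=True, inplace=True)

print(data_train.index.is_monotonic_increasing)  # True
```

Sorting restores a time-ordered sequence within each split, but note that train and test rows still interleave in time, which is exactly the issue the rest of the notebook investigates.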
@@ -138,7 +146,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Finally, we perform the evaluation."
+"We then perform the evaluation using the `ShuffleSplit` strategy."
 ]
 },
 {
@@ -161,8 +169,10 @@
 "source": [
 "Surprisingly, we get outstanding generalization performance. We will\n",
 "investigate and find the reason for such good results with a model that is\n",
-"expected to fail. We previously mentioned that `ShuffleSplit` is an iterative\n",
-"cross-validation scheme that shuffles data and split. We will simplify this\n",
+"expected to fail. We previously mentioned that `ShuffleSplit` is a\n",
+"cross-validation method that iteratively shuffles and splits the data.\n",
+"\n",
+"We can simplify the\n",
 "procedure with a single split and plot the prediction. We can use\n",
 "`train_test_split` for this purpose."
 ]
@@ -202,7 +212,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Similarly, we obtain good results in terms of $R^2$. We will plot the\n",
+"Similarly, we obtain good results in terms of $R^2$. We now plot the\n",
 "training, testing and prediction samples."
 ]
 },
@@ -225,18 +235,19 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"So in this context, it seems that the model predictions are following the\n",
-"testing. But we can also see that the testing samples are next to some\n",
-"training sample. And with these time-series, we see a relationship between a\n",
-"sample at the time `t` and a sample at `t+1`. In this case, we are violating\n",
-"the i.i.d. assumption. The insight to get is the following: a model can output\n",
-"of its training set at the time `t` for a testing sample at the time `t+1`.\n",
-"This prediction would be close to the true value even if our model did not\n",
-"learn anything, but just memorized the training dataset.\n",
+"From the plot above, we can see that the training and testing samples are\n",
+"alternating. This structure effectively evaluates the model\u2019s ability to\n",
+"interpolate between neighboring data points, rather than its true\n",
+"generalization ability. As a result, the model\u2019s predictions are close to the\n",
+"actual values, even if it has not learned anything meaningful from the data.\n",
+"This is a form of **data leakage**, where the model gains access to future\n",
+"information (testing data) while training, leading to an over-optimistic\n",
+"estimate of the generalization performance.\n",
 "\n",
-"An easy way to verify this hypothesis is to not shuffle the data when doing\n",
+"An easy way to verify this is to not shuffle the data during\n",
 "the split. In this case, we will use the first 75% of the data to train and\n",
-"the remaining data to test."
+"the remaining data to test. This way we preserve the time order of the data, and\n",
+"ensure training on past data and evaluating on future data."
 ]
 },
 {
@@ -343,20 +354,19 @@
 "from sklearn.model_selection import TimeSeriesSplit\n",
 "\n",
 "cv = TimeSeriesSplit(n_splits=groups.nunique())\n",
-"test_score = cross_val_score(\n",
-" regressor, data, target, cv=cv, groups=groups, n_jobs=2\n",
-")\n",
+"test_score = cross_val_score(regressor, data, target, cv=cv, n_jobs=2)\n",
 "print(f\"The mean R2 is: {test_score.mean():.2f} \u00b1 {test_score.std():.2f}\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"In conclusion, it is really important to not use an out of the shelves\n",
+"In conclusion, it is really important not to carelessly use a\n",
 "cross-validation strategy which do not respect some assumptions such as having\n",
-"i.i.d data. It might lead to absurd results which could make think that a\n",
-"predictive model might work."
+"i.i.d data. It might lead to misleading outcomes, creating the false\n",
+"impression that a predictive model performs well when it may not be the case\n",
+"in the intended real-world scenario."
 ]
 }
 ],
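The `TimeSeriesSplit` evaluation shown in the last hunk can be sketched on synthetic data; the notebook derives the number of splits from quarterly groups, so `n_splits=5` below is an arbitrary stand-in.

```python
# Minimal sketch of the TimeSeriesSplit evaluation from the diff above, on
# synthetic data (the notebook uses financial quotations and derives
# n_splits from groups; the values here are illustrative).
import numpy as np
from sklearn.model_selection import TimeSeriesSplit, cross_val_score
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = np.arange(300, dtype=float).reshape(-1, 1)
y = np.sin(X.ravel() / 10.0) + rng.normal(scale=0.1, size=300)

cv = TimeSeriesSplit(n_splits=5)
regressor = DecisionTreeRegressor(random_state=0)
test_score = cross_val_score(regressor, X, y, cv=cv, n_jobs=2)
print(f"The mean R2 is: {test_score.mean():.2f} ± {test_score.std():.2f}")

# Unlike ShuffleSplit, every test fold comes strictly after its train fold,
# so the model is always evaluated on future, unseen time steps.
for train_idx, test_idx in cv.split(X):
    assert train_idx.max() < test_idx.min()
```

Note that `groups` is no longer passed to `cross_val_score` in the commit: `TimeSeriesSplit` splits purely by position, so a `groups` argument has no effect on it.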

notebooks/datasets_california_housing.ipynb

Lines changed: 3 additions & 3 deletions
@@ -185,9 +185,9 @@
 "huge difference. It confirms the intuitions that there are a couple of extreme\n",
 "values.\n",
 "\n",
-"Up to know, we discarded the longitude and latitude that carry geographical\n",
-"information. In short, the combination of this feature could help us to decide\n",
-"if there are locations associated with high-valued houses. Indeed, we could\n",
+"Up to now, we discarded the longitude and latitude that carry geographical\n",
+"information. In short, the combination of these features could help us decide\n",
+"if there are locations associated with high-value houses. Indeed, we could\n",
 "make a scatter plot where the x- and y-axis would be the latitude and\n",
 "longitude and the circle size and color would be linked with the house value\n",
 "in the district."

notebooks/ensemble_hyperparameters.ipynb

Lines changed: 9 additions & 9 deletions
@@ -199,15 +199,15 @@
 "residuals are corrected and then less learners are required. Therefore,\n",
 "it can be beneficial to increase `max_iter` if `max_depth` is low.\n",
 "\n",
-"Finally, we have overlooked the impact of the `learning_rate` parameter until\n",
-"now. When fitting the residuals, we would like the tree to try to correct all\n",
-"possible errors or only a fraction of them. The learning-rate allows you to\n",
-"control this behaviour. A small learning-rate value would only correct the\n",
-"residuals of very few samples. If a large learning-rate is set (e.g., 1), we\n",
-"would fit the residuals of all samples. So, with a very low learning-rate, we\n",
-"would need more estimators to correct the overall error. However, a too large\n",
-"learning-rate tends to obtain an overfitted ensemble, similar to having very\n",
-"deep trees."
+"Finally, we have overlooked the impact of the `learning_rate` parameter\n",
+"until now. This parameter controls how much each correction contributes to the\n",
+"final prediction. A smaller learning-rate means the corrections of a new\n",
+"tree result in small adjustments to the model prediction. When the\n",
+"learning-rate is small, the model generally needs more trees to achieve good\n",
+"performance. A higher learning-rate makes larger adjustments with each tree,\n",
+"which requires fewer trees and trains faster, at the risk of overfitting. The\n",
+"learning-rate needs to be tuned by hyperparameter tuning to obtain the best\n",
+"value that results in a model with good generalization performance."
 ]
 },
 {

notebooks/parameter_tuning_grid_search.ipynb

Lines changed: 1 addition & 1 deletion
@@ -250,7 +250,7 @@
 "<p>Be aware that the evaluation should normally be performed through\n",
 "cross-validation by providing <tt class=\"docutils literal\">model_grid_search</tt> as a model to the\n",
 "<tt class=\"docutils literal\">cross_validate</tt> function.</p>\n",
-"<p class=\"last\">Here, we used a single train-test split to to evaluate <tt class=\"docutils literal\">model_grid_search</tt>. In\n",
+"<p class=\"last\">Here, we used a single train-test split to evaluate <tt class=\"docutils literal\">model_grid_search</tt>. In\n",
 "a future notebook will go into more detail about nested cross-validation, when\n",
 "you use cross-validation both for hyperparameter tuning and model evaluation.</p>\n",
 "</div>"
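The note above recommends passing `model_grid_search` itself to `cross_validate`; a minimal sketch of that nested pattern follows. The estimator, grid, and dataset are illustrative stand-ins, not the notebook's own.

```python
# Hedged sketch of the nested cross-validation pattern the note recommends:
# the GridSearchCV object is passed to cross_validate, so hyperparameter
# tuning is redone inside each outer fold.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=300, random_state=0)

model_grid_search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"max_depth": [2, 4, 8]},  # illustrative grid
    cv=3,  # inner CV used for tuning
)

# Outer CV used for evaluation; each outer fold refits the grid search
cv_results = cross_validate(model_grid_search, X, y, cv=5)
scores = cv_results["test_score"]
print(f"Nested CV accuracy: {scores.mean():.2f} ± {scores.std():.2f}")
```

This keeps the data used to select hyperparameters separate from the data used to score the final model, avoiding the optimistic bias of a single split.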
