Skip to content

Commit 87d1430

Browse files
authored
fix(global-sklearn-tfm): apply inverse transform to each column (#477)
1 parent f7b6326 commit 87d1430

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

Diff for: mlforecast/target_transforms.py

+5 -2
Original file line number | Diff line number | Diff line change
@@ -330,8 +330,11 @@ def inverse_transform(self, df: DataFrame) -> DataFrame:
330330
cols_to_transform = [
331331
c for c in df.columns if c not in (self.id_col, self.time_col)
332332
]
333-
transformed = self.transformer_.inverse_transform(
334-
df[cols_to_transform].to_numpy()
333+
transformed = np.hstack(
334+
[
335+
self.transformer_.inverse_transform(df[[col]].to_numpy())
336+
for col in cols_to_transform
337+
]
335338
)
336339
return ufp.assign_columns(df, cols_to_transform, transformed)
337340

Diff for: nbs/target_transforms.ipynb

+14 -7
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
"source": [
6464
"import pandas as pd\n",
6565
"from fastcore.test import test_fail\n",
66+
"from sklearn.ensemble import HistGradientBoostingRegressor\n",
6667
"from sklearn.linear_model import LinearRegression\n",
6768
"from sklearn.preprocessing import PowerTransformer\n",
6869
"from utilsforecast.processing import counts_by_id\n",
@@ -707,8 +708,11 @@
707708
" cols_to_transform = [\n",
708709
" c for c in df.columns if c not in (self.id_col, self.time_col)\n",
709710
" ]\n",
710-
" transformed = self.transformer_.inverse_transform(\n",
711-
" df[cols_to_transform].to_numpy()\n",
711+
" transformed = np.hstack(\n",
712+
" [\n",
713+
" self.transformer_.inverse_transform(df[[col]].to_numpy())\n",
714+
" for col in cols_to_transform\n",
715+
" ]\n",
712716
" )\n",
713717
" return ufp.assign_columns(df, cols_to_transform, transformed)\n",
714718
"\n",
@@ -736,7 +740,7 @@
736740
{
737741
"cell_type": "code",
738742
"execution_count": null,
739-
"id": "af1a2af6-53ed-4dc6-8c79-99d21eb97257",
743+
"id": "15501071-9dce-4c6f-82ca-ff57aaf2d663",
740744
"metadata": {},
741745
"outputs": [],
742746
"source": [
@@ -745,7 +749,7 @@
745749
"single_difference = ExportedDifferences([1])\n",
746750
"series = generate_daily_series(10)\n",
747751
"fcst = MLForecast(\n",
748-
" models=[LinearRegression()],\n",
752+
" models=[LinearRegression(), HistGradientBoostingRegressor()],\n",
749753
" freq='D',\n",
750754
" lags=[1, 2],\n",
751755
" target_transforms=[boxcox_global, single_difference]\n",
@@ -759,7 +763,8 @@
759763
" .dropna()\n",
760764
" .values\n",
761765
")\n",
762-
"np.testing.assert_allclose(prep['y'].values, expected)"
766+
"np.testing.assert_allclose(prep['y'].values, expected)\n",
767+
"preds = fcst.fit(series).predict(5)"
763768
]
764769
},
765770
{
@@ -773,13 +778,15 @@
773778
"#| polars\n",
774779
"series_pl = generate_daily_series(10, engine='polars')\n",
775780
"fcst_pl = MLForecast(\n",
776-
" models=[LinearRegression()],\n",
781+
" models=[LinearRegression(), HistGradientBoostingRegressor()],\n",
777782
" freq='1d',\n",
778783
" lags=[1, 2],\n",
779784
" target_transforms=[boxcox_global, single_difference]\n",
780785
")\n",
781786
"prep_pl = fcst_pl.preprocess(series_pl, dropna=False)\n",
782-
"pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())"
787+
"pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())\n",
788+
"pl_preds = fcst_pl.fit(series_pl).predict(5)\n",
789+
"pd.testing.assert_frame_equal(preds, pl_preds.to_pandas())"
783790
]
784791
}
785792
],

0 commit comments

Comments
 (0)