|
63 | 63 | "source": [
|
64 | 64 | "import pandas as pd\n",
|
65 | 65 | "from fastcore.test import test_fail\n",
|
| 66 | + "from sklearn.ensemble import HistGradientBoostingRegressor\n", |
66 | 67 | "from sklearn.linear_model import LinearRegression\n",
|
67 | 68 | "from sklearn.preprocessing import PowerTransformer\n",
|
68 | 69 | "from utilsforecast.processing import counts_by_id\n",
|
|
707 | 708 | " cols_to_transform = [\n",
|
708 | 709 | " c for c in df.columns if c not in (self.id_col, self.time_col)\n",
|
709 | 710 | " ]\n",
|
710 |
| - " transformed = self.transformer_.inverse_transform(\n", |
711 |
| - " df[cols_to_transform].to_numpy()\n", |
| 711 | + " transformed = np.hstack(\n", |
| 712 | + " [\n", |
| 713 | + " self.transformer_.inverse_transform(df[[col]].to_numpy())\n", |
| 714 | + " for col in cols_to_transform\n", |
| 715 | + " ]\n", |
712 | 716 | " )\n",
|
713 | 717 | " return ufp.assign_columns(df, cols_to_transform, transformed)\n",
|
714 | 718 | "\n",
|
|
736 | 740 | {
|
737 | 741 | "cell_type": "code",
|
738 | 742 | "execution_count": null,
|
739 |
| - "id": "af1a2af6-53ed-4dc6-8c79-99d21eb97257", |
| 743 | + "id": "15501071-9dce-4c6f-82ca-ff57aaf2d663", |
740 | 744 | "metadata": {},
|
741 | 745 | "outputs": [],
|
742 | 746 | "source": [
|
|
745 | 749 | "single_difference = ExportedDifferences([1])\n",
|
746 | 750 | "series = generate_daily_series(10)\n",
|
747 | 751 | "fcst = MLForecast(\n",
|
748 |
| - " models=[LinearRegression()],\n", |
| 752 | + " models=[LinearRegression(), HistGradientBoostingRegressor()],\n", |
749 | 753 | " freq='D',\n",
|
750 | 754 | " lags=[1, 2],\n",
|
751 | 755 | " target_transforms=[boxcox_global, single_difference]\n",
|
|
759 | 763 | " .dropna()\n",
|
760 | 764 | " .values\n",
|
761 | 765 | ")\n",
|
762 |
| - "np.testing.assert_allclose(prep['y'].values, expected)" |
| 766 | + "np.testing.assert_allclose(prep['y'].values, expected)\n", |
| 767 | + "preds = fcst.fit(series).predict(5)" |
763 | 768 | ]
|
764 | 769 | },
|
765 | 770 | {
|
|
773 | 778 | "#| polars\n",
|
774 | 779 | "series_pl = generate_daily_series(10, engine='polars')\n",
|
775 | 780 | "fcst_pl = MLForecast(\n",
|
776 |
| - " models=[LinearRegression()],\n", |
| 781 | + " models=[LinearRegression(), HistGradientBoostingRegressor()],\n", |
777 | 782 | " freq='1d',\n",
|
778 | 783 | " lags=[1, 2],\n",
|
779 | 784 | " target_transforms=[boxcox_global, single_difference]\n",
|
780 | 785 | ")\n",
|
781 | 786 | "prep_pl = fcst_pl.preprocess(series_pl, dropna=False)\n",
|
782 |
| - "pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())" |
| 787 | + "pd.testing.assert_frame_equal(prep.reset_index(drop=True), prep_pl.to_pandas())\n", |
| 788 | + "pl_preds = fcst_pl.fit(series_pl).predict(5)\n", |
| 789 | + "pd.testing.assert_frame_equal(preds, pl_preds.to_pandas())" |
783 | 790 | ]
|
784 | 791 | }
|
785 | 792 | ],
|
|
0 commit comments