Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: X_df handling in direct approach #468

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions mlforecast/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def _predict_multi(
result = df_constructor({self.id_col: uids, self.time_col: dates})
for name, model in models.items():
with self._backup():
self._predict_setup()
new_x = self._get_features_for_next_step(X_df)
if before_predict_callback is not None:
new_x = before_predict_callback(new_x)
Expand Down Expand Up @@ -789,10 +790,16 @@ def predict(
raise ValueError(
f"The following features were provided through `X_df` but were considered as static during fit: {common}.\n"
"Please re-run the fit step using the `static_features` argument to indicate which features are static. "
"If all your features are dynamic please pass an empty list (static_features=[])."
"If all your features are dynamic please provide an empty list (static_features=[])."
)
starts = ufp.offset_times(self.last_dates, self.freq, 1)
ends = ufp.offset_times(self.last_dates, self.freq, horizon)
if getattr(self, "max_horizon", None) is None:
ends = ufp.offset_times(self.last_dates, self.freq, horizon)
expected_rows_X = len(self.uids) * horizon
else:
# direct approach uses only the immediate next timestamp
ends = starts
expected_rows_X = len(self.uids)
dates_validation = type(X_df)(
{
self.id_col: self.uids,
Expand All @@ -803,7 +810,7 @@ def predict(
X_df = ufp.join(X_df, dates_validation, on=self.id_col)
mask = ufp.between(X_df[self.time_col], X_df["_start"], X_df["_end"])
X_df = ufp.filter_with_mask(X_df, mask)
if X_df.shape[0] != len(self.uids) * horizon:
if X_df.shape[0] != expected_rows_X:
msg = (
"Found missing inputs in X_df. "
"It should have one row per id and time for the complete forecasting horizon.\n"
Expand Down
21 changes: 14 additions & 7 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@
" result = df_constructor({self.id_col: uids, self.time_col: dates})\n",
" for name, model in models.items():\n",
" with self._backup():\n",
" self._predict_setup()\n",
" new_x = self._get_features_for_next_step(X_df)\n",
" if before_predict_callback is not None:\n",
" new_x = before_predict_callback(new_x)\n",
Expand Down Expand Up @@ -1261,10 +1262,16 @@
" raise ValueError(\n",
" f\"The following features were provided through `X_df` but were considered as static during fit: {common}.\\n\"\n",
" \"Please re-run the fit step using the `static_features` argument to indicate which features are static. \"\n",
" \"If all your features are dynamic please pass an empty list (static_features=[]).\"\n",
" \"If all your features are dynamic please provide an empty list (static_features=[]).\"\n",
" )\n",
" starts = ufp.offset_times(self.last_dates, self.freq, 1)\n",
" ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n",
" if getattr(self, 'max_horizon', None) is None:\n",
" ends = ufp.offset_times(self.last_dates, self.freq, horizon)\n",
" expected_rows_X = len(self.uids) * horizon\n",
" else:\n",
" # direct approach uses only the immediate next timestamp\n",
" ends = starts\n",
" expected_rows_X = len(self.uids)\n",
" dates_validation = type(X_df)(\n",
" {\n",
" self.id_col: self.uids,\n",
Expand All @@ -1275,7 +1282,7 @@
" X_df = ufp.join(X_df, dates_validation, on=self.id_col)\n",
" mask = ufp.between(X_df[self.time_col], X_df['_start'], X_df['_end'])\n",
" X_df = ufp.filter_with_mask(X_df, mask)\n",
" if X_df.shape[0] != len(self.uids) * horizon:\n",
" if X_df.shape[0] != expected_rows_X:\n",
" msg = (\n",
" \"Found missing inputs in X_df. \"\n",
" \"It should have one row per id and time for the complete forecasting horizon.\\n\"\n",
Expand Down Expand Up @@ -2015,7 +2022,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand All @@ -2029,7 +2036,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L757){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L758){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.predict\n",
"\n",
Expand Down Expand Up @@ -2167,7 +2174,7 @@
"text/markdown": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
Expand All @@ -2180,7 +2187,7 @@
"text/plain": [
"---\n",
"\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L862){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"[source](https://github.com/Nixtla/mlforecast/blob/main/mlforecast/core.py#L869){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"## TimeSeries.update\n",
"\n",
Expand Down
28 changes: 28 additions & 0 deletions nbs/forecast.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@
"import numpy as np\n",
"import xgboost as xgb\n",
"from sklearn.linear_model import LinearRegression\n",
"from utilsforecast.feature_engineering import time_features\n",
"from utilsforecast.plotting import plot_series\n",
"\n",
"from mlforecast.lag_transforms import ExpandingMean, ExponentiallyWeightedMean, RollingMean\n",
Expand Down Expand Up @@ -5439,6 +5440,33 @@
"preds2 = fcst2.predict(10)\n",
"pd.testing.assert_frame_equal(preds, preds2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a39447b-3d6b-4960-83c9-4f927c472ebb",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# direct approach requires only one timestamp and produces same results for two models\n",
"series = generate_daily_series(5)\n",
"h = 5\n",
"freq = 'D'\n",
"train, future = time_features(series, freq=freq, features=['day'], h=h)\n",
"models = [LinearRegression(), lgb.LGBMRegressor(n_estimators=5)]\n",
"\n",
"fcst1 = MLForecast(models=models, freq=freq, date_features=['dayofweek'])\n",
"fcst1.fit(train, max_horizon=h, static_features=[])\n",
"preds1 = fcst1.predict(h=h, X_df=future) # extra timestamps\n",
"\n",
"fcst2 = MLForecast(models=models[::-1], freq=freq, date_features=['dayofweek'])\n",
"fcst2.fit(train, max_horizon=h, static_features=[])\n",
"X_df_one = future.groupby('unique_id', observed=True).head(1)\n",
"preds2 = fcst2.predict(h=h, X_df=X_df_one) # only needed timestamp\n",
"\n",
"pd.testing.assert_frame_equal(preds1, preds2[preds1.columns])"
]
}
],
"metadata": {
Expand Down
Loading