From fed3fc4506602472a581145e5992d3281fe28b3b Mon Sep 17 00:00:00 2001 From: RUPESH-KUMAR01 <118011558+RUPESH-KUMAR01@users.noreply.github.com> Date: Sun, 23 Feb 2025 23:31:01 +0530 Subject: [PATCH] [BUG] pytorch-forecasting#1752 Fixing --- pytorch_forecasting/data/timeseries.py | 2 +- pytorch_forecasting/models/base/_base_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 336eecd5..afe46587 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -2513,7 +2513,7 @@ def to_dataloader( Parameters ---------- - train : bool, optional, default=Trze + train : bool, optional, default=True whether dataloader is used for training (True) or prediction (False). Will shuffle and drop last batch if True. Defaults to True. batch_size : int, optional, default=64 diff --git a/pytorch_forecasting/models/base/_base_model.py b/pytorch_forecasting/models/base/_base_model.py index 5e6c6839..d3b9a789 100644 --- a/pytorch_forecasting/models/base/_base_model.py +++ b/pytorch_forecasting/models/base/_base_model.py @@ -387,7 +387,7 @@ def on_predict_epoch_end( if self.return_decoder_lengths: output["decoder_lengths"] = torch.cat(self._decode_lengths, dim=0) if self.return_y: - y = concat_sequences([yi[0] for yi in self._y]) + y = _torch_cat_na([yi[0] for yi in self._y]) if self._y[-1][1] is None: weight = None else: