diff --git a/PyTorch/Forecasting/TFT/inference.py b/PyTorch/Forecasting/TFT/inference.py index a4ed0daff..cbbffe4e4 100644 --- a/PyTorch/Forecasting/TFT/inference.py +++ b/PyTorch/Forecasting/TFT/inference.py @@ -139,6 +139,8 @@ def inference(args, config, model, data_loader, scalers, cat_encodings): if args.joint_visualization or args.save_predictions: ids = torch.from_numpy(ids.squeeze()) #ids = torch.cat([x['id'][0] for x in data_loader.dataset]) + unscaled_predictions = torch.tensor(unscaled_predictions) + unscaled_targets = torch.tensor(unscaled_targets) joint_graphs = torch.cat([unscaled_targets, unscaled_predictions], dim=2) graphs = {i:joint_graphs[ids == i, :, :] for i in set(ids.tolist())} for key, g in graphs.items(): #timeseries id, joint targets and predictions