diff --git a/PyTorch/Forecasting/TFT/criterions.py b/PyTorch/Forecasting/TFT/criterions.py index 12de5be76..566467a23 100644 --- a/PyTorch/Forecasting/TFT/criterions.py +++ b/PyTorch/Forecasting/TFT/criterions.py @@ -29,6 +29,10 @@ def forward(self, predictions, targets): return losses def qrisk(pred, tgt, quantiles): + if isinstance(pred, torch.Tensor): + pred = pred.detach().cpu().numpy() + if isinstance(tgt, torch.Tensor): + tgt = tgt.detach().cpu().numpy() diff = pred - tgt ql = (1-quantiles)*np.clip(diff,0, float('inf')) + quantiles*np.clip(-diff,0, float('inf')) losses = ql.reshape(-1, ql.shape[-1])