Skip to content

Commit

Permalink
Update criterions.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tdktrang authored May 21, 2024
1 parent 879d294 commit 5590978
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions PyTorch/Forecasting/TFT/criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def forward(self, predictions, targets):
return losses

def qrisk(pred, tgt, quantiles):
if isinstance(pred, torch.Tensor):
pred = pred.detach().cpu().numpy()
if isinstance(tgt, torch.Tensor):
tgt = tgt.detach().cpu().numpy()
diff = pred - tgt
ql = (1-quantiles)*np.clip(diff,0, float('inf')) + quantiles*np.clip(-diff,0, float('inf'))
losses = ql.reshape(-1, ql.shape[-1])
Expand Down

0 comments on commit 5590978

Please sign in to comment.