Skip to content

Commit a7f6491

Browse files
committed
fix issue when there is mps fallback enabled during training
1 parent d57b0bb commit a7f6491

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytorch_forecasting/models/base_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,9 @@ def step(
828828
loss = self.loss(prediction, y)
829829
else:
830830
loss = None
831+
# ensure that loss has require_grad
832+
if loss is not None and loss.device.type == "mps":
833+
loss.requires_grad_(True)
831834
self.log(
832835
f"{self.current_stage}_loss",
833836
loss,

0 commit comments

Comments
 (0)