Skip to content

Commit e4776e3

Browse files
committed
move interpolate tensors to cpu
1 parent f665bd6 commit e4776e3

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

pytorch_forecasting/models/temporal_fusion_transformer/sub_modules.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,16 @@ def __init__(
5151
self.gate = nn.Sigmoid()
5252

5353
def interpolate(self, x):
54-
upsampled = F.interpolate(
55-
x.unsqueeze(1), self.output_size, mode="linear", align_corners=True
56-
).squeeze(1)
54+
if x.device.type == "mps":
55+
x = x.to("cpu")
56+
upsampled = F.interpolate(
57+
x.unsqueeze(1), self.output_size, mode="linear", align_corners=True
58+
).squeeze(1)
59+
upsampled = upsampled.to("mps")
60+
else:
61+
upsampled = F.interpolate(
62+
x.unsqueeze(1), self.output_size, mode="linear", align_corners=True
63+
).squeeze(1)
5764
if self.trainable:
5865
upsampled = upsampled * self.gate(self.mask.unsqueeze(0)) * 2.0
5966
return upsampled
@@ -284,7 +291,8 @@ def __init__(
284291
prescalers: Dict[str, nn.Linear] = None,
285292
):
286293
"""
287-
Calculate weights for ``num_inputs`` variables which are each of size ``input_size``
294+
Calculate weights for ``num_inputs`` variables which are each of size
295+
``input_size``
288296
"""
289297
super().__init__()
290298

0 commit comments

Comments
 (0)