File tree Expand file tree Collapse file tree 1 file changed +12
-4
lines changed
pytorch_forecasting/models/temporal_fusion_transformer Expand file tree Collapse file tree 1 file changed +12
-4
lines changed Original file line number Diff line number Diff line change @@ -51,9 +51,16 @@ def __init__(
51
51
self .gate = nn .Sigmoid ()
52
52
53
53
def interpolate (self , x ):
54
- upsampled = F .interpolate (
55
- x .unsqueeze (1 ), self .output_size , mode = "linear" , align_corners = True
56
- ).squeeze (1 )
54
+ if x .device .type == "mps" :
55
+ x = x .to ("cpu" )
56
+ upsampled = F .interpolate (
57
+ x .unsqueeze (1 ), self .output_size , mode = "linear" , align_corners = True
58
+ ).squeeze (1 )
59
+ upsampled = upsampled .to ("mps" )
60
+ else :
61
+ upsampled = F .interpolate (
62
+ x .unsqueeze (1 ), self .output_size , mode = "linear" , align_corners = True
63
+ ).squeeze (1 )
57
64
if self .trainable :
58
65
upsampled = upsampled * self .gate (self .mask .unsqueeze (0 )) * 2.0
59
66
return upsampled
@@ -284,7 +291,8 @@ def __init__(
284
291
prescalers : Dict [str , nn .Linear ] = None ,
285
292
):
286
293
"""
287
- Calculate weights for ``num_inputs`` variables which are each of size ``input_size``
294
+ Calculate weights for ``num_inputs`` variables which are each of size
295
+ ``input_size``
288
296
"""
289
297
super ().__init__ ()
290
298
You can’t perform that action at this time.
0 commit comments