@@ -28,6 +28,7 @@ def test_metnet_creation():
28
28
)
29
29
assert not torch .isnan (out ).any (), "Output included NaNs"
30
30
31
+
31
32
def test_metnet_backwards ():
32
33
model = MetNet (
33
34
hidden_dim = 32 ,
@@ -36,7 +37,7 @@ def test_metnet_backwards():
36
37
output_channels = 12 ,
37
38
sat_channels = 12 ,
38
39
input_size = 32 ,
39
- )
40
+ )
40
41
# MetNet expects original HxW to be 4x the input size
41
42
x = torch .randn ((2 , 12 , 16 , 128 , 128 ))
42
43
out = model (x )
@@ -47,7 +48,7 @@ def test_metnet_backwards():
47
48
12 ,
48
49
8 ,
49
50
8 ,
50
- )
51
+ )
51
52
y = torch .randn ((2 , 24 , 12 , 8 , 8 ))
52
53
F .mse_loss (out , y ).backward ()
53
54
assert not torch .isnan (out ).any (), "Output included NaNs"
@@ -86,6 +87,7 @@ def test_metnet2_creation():
86
87
)
87
88
assert not torch .isnan (out ).any (), "Output included NaNs"
88
89
90
+
89
91
def test_metnet2_backward ():
90
92
model = MetNet2 (
91
93
forecast_steps = 8 ,
@@ -95,7 +97,7 @@ def test_metnet2_backward():
95
97
lstm_channels = 32 ,
96
98
encoder_channels = 64 ,
97
99
center_crop_size = 16 ,
98
- )
100
+ )
99
101
# MetNet expects original HxW to be 4x the input size
100
102
x = torch .randn ((2 , 6 , 12 , 256 , 256 ))
101
103
out = model (x )
@@ -106,7 +108,7 @@ def test_metnet2_backward():
106
108
12 ,
107
109
64 ,
108
110
64 ,
109
- )
110
- y = torch .rand ((2 ,8 , 12 ,64 ,64 ))
111
+ )
112
+ y = torch .rand ((2 , 8 , 12 , 64 , 64 ))
111
113
F .mse_loss (out , y ).backward ()
112
- assert not torch .isnan (out ).any (), "Output included NaNs"
114
+ assert not torch .isnan (out ).any (), "Output included NaNs"
0 commit comments