Skip to content

Commit 55f6830

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1a750d6 commit 55f6830

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

tests/test_model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def test_metnet_creation():
2828
)
2929
assert not torch.isnan(out).any(), "Output included NaNs"
3030

31+
3132
def test_metnet_backwards():
3233
model = MetNet(
3334
hidden_dim=32,
@@ -36,7 +37,7 @@ def test_metnet_backwards():
3637
output_channels=12,
3738
sat_channels=12,
3839
input_size=32,
39-
)
40+
)
4041
# MetNet expects original HxW to be 4x the input size
4142
x = torch.randn((2, 12, 16, 128, 128))
4243
out = model(x)
@@ -47,7 +48,7 @@ def test_metnet_backwards():
4748
12,
4849
8,
4950
8,
50-
)
51+
)
5152
y = torch.randn((2, 24, 12, 8, 8))
5253
F.mse_loss(out, y).backward()
5354
assert not torch.isnan(out).any(), "Output included NaNs"
@@ -86,6 +87,7 @@ def test_metnet2_creation():
8687
)
8788
assert not torch.isnan(out).any(), "Output included NaNs"
8889

90+
8991
def test_metnet2_backward():
9092
model = MetNet2(
9193
forecast_steps=8,
@@ -95,7 +97,7 @@ def test_metnet2_backward():
9597
lstm_channels=32,
9698
encoder_channels=64,
9799
center_crop_size=16,
98-
)
100+
)
99101
# MetNet expects original HxW to be 4x the input size
100102
x = torch.randn((2, 6, 12, 256, 256))
101103
out = model(x)
@@ -106,7 +108,7 @@ def test_metnet2_backward():
106108
12,
107109
64,
108110
64,
109-
)
110-
y = torch.rand((2,8,12,64,64))
111+
)
112+
y = torch.rand((2, 8, 12, 64, 64))
111113
F.mse_loss(out, y).backward()
112-
assert not torch.isnan(out).any(), "Output included NaNs"
114+
assert not torch.isnan(out).any(), "Output included NaNs"

0 commit comments

Comments
 (0)