Skip to content

Commit

Permalink
Merge pull request #15 from openclimatefix/jack/python39
Browse files Browse the repository at this point in the history
Must use Python 3.9 or above
  • Loading branch information
jacobbieker authored Feb 23, 2022
2 parents 3f5696c + 55f6830 commit d6c7219
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ pip install -e .

Alternatively, you can also install a usually older version through ```pip install metnet```

Please ensure that you're using Python version 3.9 or above.
## Data
While the exact training data used for both MetNet and MetNet-2 haven't been released, the papers do go into some detail as to the inputs, which were GOES-16 and MRMS precipitation data, as well as the time period covered. We will be making those splits available, as well as a larger dataset that covers a longer time period, with [HuggingFace Datasets](https://huggingface.co/datasets/openclimatefix/goes-mrms)!
Expand Down
14 changes: 8 additions & 6 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_metnet_creation():
)
assert not torch.isnan(out).any(), "Output included NaNs"


def test_metnet_backwards():
model = MetNet(
hidden_dim=32,
Expand All @@ -36,7 +37,7 @@ def test_metnet_backwards():
output_channels=12,
sat_channels=12,
input_size=32,
)
)
# MetNet expects original HxW to be 4x the input size
x = torch.randn((2, 12, 16, 128, 128))
out = model(x)
Expand All @@ -47,7 +48,7 @@ def test_metnet_backwards():
12,
8,
8,
)
)
y = torch.randn((2, 24, 12, 8, 8))
F.mse_loss(out, y).backward()
assert not torch.isnan(out).any(), "Output included NaNs"
Expand Down Expand Up @@ -86,6 +87,7 @@ def test_metnet2_creation():
)
assert not torch.isnan(out).any(), "Output included NaNs"


def test_metnet2_backward():
model = MetNet2(
forecast_steps=8,
Expand All @@ -95,7 +97,7 @@ def test_metnet2_backward():
lstm_channels=32,
encoder_channels=64,
center_crop_size=16,
)
)
# MetNet expects original HxW to be 4x the input size
x = torch.randn((2, 6, 12, 256, 256))
out = model(x)
Expand All @@ -106,7 +108,7 @@ def test_metnet2_backward():
12,
64,
64,
)
y = torch.rand((2,8,12,64,64))
)
y = torch.rand((2, 8, 12, 64, 64))
F.mse_loss(out, y).backward()
assert not torch.isnan(out).any(), "Output included NaNs"
assert not torch.isnan(out).any(), "Output included NaNs"

0 comments on commit d6c7219

Please sign in to comment.