
Commit 63f3fc0: Clean up

Sukh-P committed Jan 21, 2025
1 parent 3f07b32
Showing 7 changed files with 10 additions and 55 deletions.
pvnet/models/multimodal/site_encoders/encoders.py (4 changes: 2 additions & 2 deletions)
@@ -206,8 +206,8 @@ def __init__(
         )

     def _encode_inputs(self, x):
-        # Shape: [batch size, sequence length, number of sites] -> [8, 197, 1]
-        # Shape: [batch size, station_id, sequence length, channels] -> [8, 197, 26, 23]
+        # Shape: [batch size, sequence length, number of sites]
+        # Shape: [batch size, station_id, sequence length, channels]
         input_data = x[f"{self.input_key_to_use}"]
         if len(input_data.shape) == 2:  # one site per sample
             input_data = input_data.unsqueeze(-1)  # add dimension of 1 to end to make 3D
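For readers skimming the diff: the unsqueeze branch above normalizes single-site input to the same 3D layout that multi-site input already has. A minimal standalone sketch of that step (tensor sizes are illustrative, not taken from the repo):

import torch

batch_size, seq_len = 2, 197  # hypothetical sizes, for illustration only
input_data = torch.randn(batch_size, seq_len)  # 2D: [batch size, sequence length]

if len(input_data.shape) == 2:  # one site per sample
    input_data = input_data.unsqueeze(-1)  # add a trailing site dimension to make it 3D

print(input_data.shape)  # torch.Size([2, 197, 1])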
pvnet/models/utils.py (7 changes: 3 additions & 4 deletions)
@@ -4,7 +4,6 @@

 import numpy as np
 import torch
-from ocf_datapipes.batch import BatchKey

 logger = logging.getLogger(__name__)

@@ -84,7 +83,7 @@ class BatchAccumulator(DictListAccumulator):
     """A class for accumulating batches when using grad accumulation and the batch size is small.

     Attributes:
-        _batches (Dict[BatchKey, list[torch.Tensor]]): Dictionary containing lists of metrics.
+        _batches (Dict[str, list[torch.Tensor]]): Dictionary containing lists of metrics.
     """

     def __init__(self, key_to_keep: str = "gsp"):
@@ -105,14 +104,14 @@ def _filter_batch_dict(self, d):
         ]
         return {k: v for k, v in d.items() if k in keep_keys}

-    def append(self, batch: dict[BatchKey, list[torch.Tensor]]):
+    def append(self, batch: dict[str, list[torch.Tensor]]):
         """Append batch to self"""
         if not self:
             self._batches = self._dict_init_list(self._filter_batch_dict(batch))
         else:
             self._dict_list_append(self._batches, self._filter_batch_dict(batch))

-    def flush(self) -> dict[BatchKey, list[torch.Tensor]]:
+    def flush(self) -> dict[str, list[torch.Tensor]]:
         """Concatenate all accumulated batches, return, and clear self"""
         batch = {}
         for k, v in self._batches.items():
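The signature changes above swap BatchKey enum keys for plain strings. A hedged usage sketch of the accumulator after this change (the keys shown are illustrative; which keys actually survive depends on _filter_batch_dict, which is only partially shown in this diff):

import torch
from pvnet.models.utils import BatchAccumulator

accumulator = BatchAccumulator(key_to_keep="gsp")

# Append several small batches, now keyed by plain strings instead of BatchKey members
for _ in range(4):
    accumulator.append({"gsp": torch.randn(2, 16)})

big_batch = accumulator.flush()  # accumulated batches concatenated, accumulator cleared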
pvnet/utils.py (2 changes: 1 addition & 1 deletion)
@@ -258,7 +258,7 @@ def plot_batch_forecasts(
     def _get_numpy(key):
         return batch[key].cpu().numpy().squeeze()

-    y_key = f"{key_to_plot}"
+    y_key = key_to_plot
     y_id_key = f"{key_to_plot}_id"
     time_utc_key = f"{key_to_plot}_time_utc"
     y = batch[y_key].cpu().numpy()  # Select the one it is trained on
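The one-line change above just drops a redundant f-string: key_to_plot is already a plain string. For illustration, with a hypothetical key_to_plot of "gsp" the three batch keys resolve as:

key_to_plot = "gsp"

y_key = key_to_plot                       # "gsp"
y_id_key = f"{key_to_plot}_id"            # "gsp_id"
time_utc_key = f"{key_to_plot}_time_utc"  # "gsp_time_utc"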
tests/conftest.py (3 changes: 1 addition & 2 deletions)
@@ -9,7 +9,6 @@
 import torch
 import hydra

-from ocf_datapipes.batch import BatchKey
 from datetime import timedelta

 import pvnet

@@ -164,7 +163,7 @@ def sample_pv_batch():
 def sample_site_batch():
     dm = SiteDataModule(
         configuration=None,
-        batch_size=8,
+        batch_size=2,
         num_workers=0,
         prefetch_factor=None,
         train_period=[None, None],
tests/models/multimodal/site_encoders/test_encoders.py (5 changes: 2 additions & 3 deletions)
@@ -1,5 +1,4 @@
 import torch
-from ocf_datapipes.batch import BatchKey
 from torch import nn

 from pvnet.models.multimodal.site_encoders.encoders import (

@@ -38,11 +37,11 @@ def test_singleattentionnetwork_forward(sample_site_batch, site_encoder_model_kw
         sample_site_batch,
         SingleAttentionNetwork,
         site_encoder_model_kwargs_dsampler,
-        batch_size=8,
+        batch_size=2,
     )


-# TODO once we have updated the sample batches for sites include this test
+# TODO once we have test data which includes sensor data with sites include this test
 # def test_singleattentionnetwork_forward_4d(sample_wind_batch, site_encoder_sensor_model_kwargs):
 #     _test_model_forward(
 #         sample_wind_batch,
tests/models/multimodal/test_multimodal.py (2 changes: 0 additions & 2 deletions)
@@ -1,7 +1,5 @@
 from torch.optim import SGD
 import pytest
-from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey
-

 def test_model_forward(multimodal_model, sample_batch):
     y = multimodal_model(sample_batch)
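The unused-import deletions here and in the files above are all part of the same migration: batches are now plain dicts keyed by strings, so the ocf_datapipes enums are no longer needed. A hedged before/after sketch of what a call site looks like (the batch contents are invented for illustration):

import torch

batch = {"gsp": torch.randn(2, 16)}  # hypothetical batch

# Before (enum keys, removed by this commit):
#   from ocf_datapipes.batch import BatchKey
#   y = batch[BatchKey.gsp]

# After (plain string keys):
y = batch["gsp"]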
tests/test_data/presaved_samples_site/data_configuration.yaml (42 changes: 1 addition & 41 deletions)
@@ -12,28 +12,6 @@ input_data:
     dropout_timedeltas_minutes: null
     dropout_fraction: 0 # Fraction of samples with dropout

-  # nwp:
-  #   mo_global:
-  #     provider: mo_global
-  #     zarr_path: /mnt/storage_b/nwp/ceda/global/20*.zarr
-  #     interval_start_minutes: -60
-  #     interval_end_minutes: 1860
-  #     time_resolution_minutes: 60
-  #     channels:
-  #       - temperature_sl
-  #       - wind_u_component_10m
-  #       - wind_v_component_10m
-  #     image_size_pixels_height: 4
-  #     image_size_pixels_width: 4
-  #     # A random value from the list below will be chosen as the delay when dropout is used
-  #     # If set to null no dropout is applied. Values must be negative.
-  #     dropout_timedeltas_minutes: [-180]
-  #     # Dropout applied with this probability
-  #     dropout_fraction: 1.0
-  #     # How long after the NWP init-time are we still willing to use this forecast
-  #     # If null we use each init-time for all steps it covers
-  #     max_staleness_minutes: null
-
   nwp:
     ecmwf:
       provider: ecmwf
@@ -56,22 +34,4 @@ input_data:
       image_size_pixels_width: 24
       dropout_timedeltas_minutes: [-360]
       dropout_fraction: 1.0
-      max_staleness_minutes: null
-
-    # gfs:
-    #   provider: gfs
-    #   dropout_fraction: 1.0
-    #   dropout_timedeltas_minutes: [-300]
-    #   interval_start_minutes: 0
-    #   interval_end_minutes: 2160
-    #   channels:
-    #     - t
-    #     - prate
-    #     - u10
-    #     - v10
-    #     - u100
-    #     - v100
-    #   image_size_pixels_height: 2
-    #   image_size_pixels_width: 2
-    #   zarr_path: /mnt/storage_b/nwp/gfs/nw_india/gfs_nw_india_*.zarr
-    #   time_resolution_minutes: 180
+      max_staleness_minutes: null
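The dropout fields in the removed comments and in the surviving ecmwf entry encode a simple rule: with probability dropout_fraction, a delay drawn from dropout_timedeltas_minutes is applied to the NWP data. A hedged sketch of that semantic, not the actual loader code:

import random

def choose_nwp_dropout_delay(timedeltas_minutes, dropout_fraction):
    # No candidate delays, or the dropout roll fails: use the data undelayed
    if not timedeltas_minutes or random.random() >= dropout_fraction:
        return None
    # Otherwise pick one of the (negative) delays at random
    return random.choice(timedeltas_minutes)

# With the ecmwf settings above, every sample is delayed by -360 minutes
print(choose_nwp_dropout_delay([-360], 1.0))  # -360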
