Skip to content

Commit

Permalink
Fix and update for newer datapipes
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Jan 31, 2024
1 parent ef30539 commit afff584
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 14 deletions.
6 changes: 2 additions & 4 deletions pvnet/data/pv_site_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(

def _get_datapipe(self, start_time, end_time):
data_pipeline = pvnet_site_netcdf_datapipe(
self.configuration,
keys=["pv", "nwp"],
)

Expand All @@ -89,7 +88,6 @@ def _get_datapipe(self, start_time, end_time):
def _get_premade_batches_datapipe(self, subdir, shuffle=False):
filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
data_pipeline = pvnet_site_netcdf_datapipe(
config_filename=self.configuration,
keys=["pv", "nwp"],
filenames=filenames,
)
Expand All @@ -103,14 +101,14 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline.shuffle(buffer_size=100)
.sharding_filter()
# Split the batches and reshuffle them to be combined into new batches
.split_batches(splitting_key=BatchKey.sensor)
.split_batches(splitting_key=BatchKey.pv)
.shuffle(buffer_size=100 * self.batch_size)
)
else:
data_pipeline = (
data_pipeline.sharding_filter()
# Split the batches so we can use any batch-size
.split_batches(splitting_key=BatchKey.sensor)
.split_batches(splitting_key=BatchKey.pv)
)

data_pipeline = (
Expand Down
5 changes: 2 additions & 3 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _get_datapipe(self, start_time, end_time):
def _get_premade_batches_datapipe(self, subdir, shuffle=False):
filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
data_pipeline = windnet_netcdf_datapipe(
config_filename=self.configuration,
keys=["wind", "nwp"],
filenames=filenames,
)
Expand All @@ -103,14 +102,14 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline.shuffle(buffer_size=100)
.sharding_filter()
# Split the batches and reshuffle them to be combined into new batches
.split_batches(splitting_key=BatchKey.sensor)
.split_batches(splitting_key=BatchKey.wind)
.shuffle(buffer_size=100 * self.batch_size)
)
else:
data_pipeline = (
data_pipeline.sharding_filter()
# Split the batches so we can use any batch-size
.split_batches(splitting_key=BatchKey.sensor)
.split_batches(splitting_key=BatchKey.wind)
)

data_pipeline = (
Expand Down
1 change: 1 addition & 0 deletions pvnet/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
PredAccumulator,
WeightedLosses,
)

from pvnet.optimizers import AbstractOptimizer
from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts

Expand Down
12 changes: 6 additions & 6 deletions pvnet/models/multimodal/encoders/encoders3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def __init__(
super().__init__(sequence_length, image_size_pixels, in_channels, out_features)

# Check that the output shape of the convolutional layers will be at least 1x1
cnn_spatial_output_size = image_size_pixels - 12 * number_of_conv3d_layers
cnn_sequence_length = sequence_length - 6 * number_of_conv3d_layers
cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv3d_layers
cnn_sequence_length = sequence_length #- 6 * number_of_conv3d_layers
if not (cnn_spatial_output_size >= 1):
raise ValueError(
f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
Expand All @@ -51,8 +51,8 @@ def __init__(
nn.Conv3d(
in_channels=in_channels,
out_channels=conv3d_channels,
kernel_size=(7, 13, 13),
padding=(0, 0, 0),
kernel_size=(3, 3, 3),
padding=(1, 0, 0),
),
nn.ELU(),
]
Expand All @@ -61,8 +61,8 @@ def __init__(
nn.Conv3d(
in_channels=conv3d_channels,
out_channels=conv3d_channels,
kernel_size=(7, 13, 13),
padding=(0, 0, 0),
kernel_size=(3, 3, 3),
padding=(1, 0, 0),
),
nn.ELU(),
]
Expand Down
2 changes: 1 addition & 1 deletion pvnet/models/multimodal/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def forward(self, x):

if self.include_sun:
sun = torch.cat(
(x[BatchKey.pv_solar_azimuth], x[BatchKey.pv_solar_elevation]), dim=1
(x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1
).float()
sun = self.sun_fc1(sun)
modes["sun"] = sun
Expand Down

0 comments on commit afff584

Please sign in to comment.