diff --git a/pvnet/data/pv_site_datamodule.py b/pvnet/data/pv_site_datamodule.py
index 1b52fcf4..06101386 100644
--- a/pvnet/data/pv_site_datamodule.py
+++ b/pvnet/data/pv_site_datamodule.py
@@ -75,7 +75,6 @@ def __init__(
 
     def _get_datapipe(self, start_time, end_time):
         data_pipeline = pvnet_site_netcdf_datapipe(
-            self.configuration,
             keys=["pv", "nwp"],
         )
 
@@ -89,7 +88,6 @@ def _get_datapipe(self, start_time, end_time):
     def _get_premade_batches_datapipe(self, subdir, shuffle=False):
         filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
         data_pipeline = pvnet_site_netcdf_datapipe(
-            config_filename=self.configuration,
             keys=["pv", "nwp"],
             filenames=filenames,
         )
@@ -103,14 +101,14 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
                 data_pipeline.shuffle(buffer_size=100)
                 .sharding_filter()
                 # Split the batches and reshuffle them to be combined into new batches
-                .split_batches(splitting_key=BatchKey.sensor)
+                .split_batches(splitting_key=BatchKey.pv)
                 .shuffle(buffer_size=100 * self.batch_size)
             )
         else:
             data_pipeline = (
                 data_pipeline.sharding_filter()
                 # Split the batches so we can use any batch-size
-                .split_batches(splitting_key=BatchKey.sensor)
+                .split_batches(splitting_key=BatchKey.pv)
             )
 
         data_pipeline = (
diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py
index 330271ef..ba804685 100644
--- a/pvnet/data/wind_datamodule.py
+++ b/pvnet/data/wind_datamodule.py
@@ -89,7 +88,6 @@ def _get_datapipe(self, start_time, end_time):
     def _get_premade_batches_datapipe(self, subdir, shuffle=False):
         filenames = list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
         data_pipeline = windnet_netcdf_datapipe(
-            config_filename=self.configuration,
             keys=["wind", "nwp"],
             filenames=filenames,
         )
@@ -103,14 +102,14 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
                 data_pipeline.shuffle(buffer_size=100)
                 .sharding_filter()
                 # Split the batches and reshuffle them to be combined into new batches
-                .split_batches(splitting_key=BatchKey.sensor)
+                .split_batches(splitting_key=BatchKey.wind)
                 .shuffle(buffer_size=100 * self.batch_size)
             )
         else:
             data_pipeline = (
                 data_pipeline.sharding_filter()
                 # Split the batches so we can use any batch-size
-                .split_batches(splitting_key=BatchKey.sensor)
+                .split_batches(splitting_key=BatchKey.wind)
             )
 
         data_pipeline = (
diff --git a/pvnet/models/base_model.py b/pvnet/models/base_model.py
index d0f462ce..ade7895b 100644
--- a/pvnet/models/base_model.py
+++ b/pvnet/models/base_model.py
@@ -27,6 +27,7 @@
     PredAccumulator,
     WeightedLosses,
 )
+
 from pvnet.optimizers import AbstractOptimizer
 from pvnet.utils import construct_ocf_ml_metrics_batch_df, plot_batch_forecasts
 
diff --git a/pvnet/models/multimodal/encoders/encoders3d.py b/pvnet/models/multimodal/encoders/encoders3d.py
index 8d093280..056d284b 100644
--- a/pvnet/models/multimodal/encoders/encoders3d.py
+++ b/pvnet/models/multimodal/encoders/encoders3d.py
@@ -37,8 +37,8 @@ def __init__(
         super().__init__(sequence_length, image_size_pixels, in_channels, out_features)
 
         # Check that the output shape of the convolutional layers will be at least 1x1
-        cnn_spatial_output_size = image_size_pixels - 12 * number_of_conv3d_layers
-        cnn_sequence_length = sequence_length - 6 * number_of_conv3d_layers
+        cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv3d_layers
+        cnn_sequence_length = sequence_length  # - 6 * number_of_conv3d_layers
         if not (cnn_spatial_output_size >= 1):
             raise ValueError(
                 f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "
@@ -51,8 +51,8 @@ def __init__(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=conv3d_channels,
-               kernel_size=(7, 13, 13),
-               padding=(0, 0, 0),
+               kernel_size=(3, 3, 3),
+               padding=(1, 0, 0),
            ),
            nn.ELU(),
        ]
@@ -61,8 +61,8 @@ def __init__(
                nn.Conv3d(
                    in_channels=conv3d_channels,
                    out_channels=conv3d_channels,
-                   kernel_size=(7, 13, 13),
-                   padding=(0, 0, 0),
+                   kernel_size=(3, 3, 3),
+                   padding=(1, 0, 0),
                ),
                nn.ELU(),
            ]
diff --git a/pvnet/models/multimodal/multimodal.py b/pvnet/models/multimodal/multimodal.py
index f482c21e..a58582c0 100644
--- a/pvnet/models/multimodal/multimodal.py
+++ b/pvnet/models/multimodal/multimodal.py
@@ -297,7 +297,7 @@ def forward(self, x):
 
         if self.include_sun:
             sun = torch.cat(
-                (x[BatchKey.pv_solar_azimuth], x[BatchKey.pv_solar_elevation]), dim=1
+                (x[BatchKey.gsp_solar_azimuth], x[BatchKey.gsp_solar_elevation]), dim=1
             ).float()
             sun = self.sun_fc1(sun)
             modes["sun"] = sun
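
Note on the encoders3d.py hunks above: a Conv3d with kernel_size=(3, 3, 3) and padding=(1, 0, 0) pads only the time axis, so each layer keeps the sequence length unchanged and shrinks each spatial dimension by 2. That is what the updated bookkeeping assumes (cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv3d_layers, cnn_sequence_length = sequence_length). The sketch below is a minimal check of that shape arithmetic, not code from this PR; the sizes (sequence_length=12, image_size_pixels=24, four layers) are made-up example values.

import torch
from torch import nn

# Example values only; not taken from the PR or from any PVNet config
sequence_length, image_size_pixels = 12, 24
in_channels, conv3d_channels, n_layers = 2, 32, 4

# Same kernel/padding as the updated encoder: pad the time axis, leave space unpadded
layers = [
    nn.Conv3d(in_channels, conv3d_channels, kernel_size=(3, 3, 3), padding=(1, 0, 0)),
    nn.ELU(),
]
for _ in range(n_layers - 1):
    layers += [
        nn.Conv3d(conv3d_channels, conv3d_channels, kernel_size=(3, 3, 3), padding=(1, 0, 0)),
        nn.ELU(),
    ]
encoder = nn.Sequential(*layers)

x = torch.randn(1, in_channels, sequence_length, image_size_pixels, image_size_pixels)
out = encoder(x)  # (batch, conv3d_channels, time, height, width)

assert out.shape[2] == sequence_length                    # time dim preserved by padding=(1, 0, 0)
assert out.shape[3] == image_size_pixels - 2 * n_layers   # matches cnn_spatial_output_size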