Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/jacob/windnet' into jacob/windnet
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Feb 5, 2024
2 parents 4e2e575 + 7c1b8fc commit 693ac46
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
15 changes: 11 additions & 4 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ocf_datapipes.training.pvnet import pvnet_datapipe
from torch.utils.data.datapipes.iter import FileLister

from pvnet.data.utils import batch_to_tensor
from pvnet.data.base import BaseDataModule
from pvnet.data.utils import batch_to_tensor


class DataModule(BaseDataModule):
Expand Down Expand Up @@ -40,8 +40,16 @@ def __init__(
`configuration` or 'train/val/test_period'.
"""
super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, train_period=train_period, val_period=val_period, test_period=test_period, batch_dir=batch_dir)

super().__init__(
configuration=configuration,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
train_period=train_period,
val_period=val_period,
test_period=test_period,
batch_dir=batch_dir,
)

def _get_datapipe(self, start_time, end_time):
data_pipeline = pvnet_datapipe(
Expand Down Expand Up @@ -82,4 +90,3 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
)

return data_pipeline

10 changes: 8 additions & 2 deletions pvnet/data/pv_site_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
from ocf_datapipes.training.pvnet_site import pvnet_site_netcdf_datapipe
from pvnet.data.base import BaseDataModule

from pvnet.data.base import BaseDataModule
from pvnet.data.utils import batch_to_tensor


Expand Down Expand Up @@ -33,7 +33,13 @@ def __init__(
'train/val/test_period'.
"""
super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers, prefetch_factor=prefetch_factor, batch_dir=batch_dir)
super().__init__(
configuration=configuration,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
batch_dir=batch_dir,
)

def _get_datapipe(self, start_time, end_time):
data_pipeline = pvnet_site_netcdf_datapipe(
Expand Down
11 changes: 8 additions & 3 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from ocf_datapipes.batch import BatchKey, stack_np_examples_into_batch
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from pvnet.data.base import BaseDataModule

from pvnet.data.base import BaseDataModule
from pvnet.data.utils import batch_to_tensor


Expand Down Expand Up @@ -33,8 +33,13 @@ def __init__(
'train/val/test_period'.
"""
super().__init__(configuration=configuration, batch_size=batch_size, num_workers=num_workers,
prefetch_factor=prefetch_factor, batch_dir=batch_dir)
super().__init__(
configuration=configuration,
batch_size=batch_size,
num_workers=num_workers,
prefetch_factor=prefetch_factor,
batch_dir=batch_dir,
)

def _get_datapipe(self, start_time, end_time):
data_pipeline = windnet_netcdf_datapipe(
Expand Down

0 comments on commit 693ac46

Please sign in to comment.