1
- import pytest
2
1
from pvnet .data .datamodule import DataModule
3
- from pvnet .data .wind_datamodule import WindDataModule
4
- from pvnet .data .pv_site_datamodule import PVSiteDataModule
5
- import os
6
- from ocf_datapipes .batch .batches import BatchKey , NWPBatchKey
7
2
8
3
9
4
def test_init ():
@@ -17,58 +12,6 @@ def test_init():
17
12
val_period = [None , None ],
18
13
)
19
14
20
-
21
- @pytest .mark .skip (reason = "Has not been updated for ocf-data-sampler yet" )
22
- def test_wind_init ():
23
- dm = WindDataModule (
24
- configuration = None ,
25
- batch_size = 2 ,
26
- num_workers = 0 ,
27
- prefetch_factor = None ,
28
- train_period = [None , None ],
29
- val_period = [None , None ],
30
- test_period = [None , None ],
31
- batch_dir = "tests/data/sample_batches" ,
32
- )
33
-
34
-
35
- @pytest .mark .skip (reason = "Has not been updated for ocf-data-sampler yet" )
36
- def test_wind_init_with_nwp_filter ():
37
- dm = WindDataModule (
38
- configuration = None ,
39
- batch_size = 2 ,
40
- num_workers = 0 ,
41
- prefetch_factor = None ,
42
- train_period = [None , None ],
43
- val_period = [None , None ],
44
- test_period = [None , None ],
45
- batch_dir = "tests/test_data/sample_wind_batches" ,
46
- nwp_channels = {"ecmwf" : ["t2m" , "v200" ]},
47
- )
48
- dataloader = iter (dm .train_dataloader ())
49
-
50
- batch = next (dataloader )
51
- batch_channels = batch [BatchKey .nwp ]["ecmwf" ][NWPBatchKey .nwp_channel_names ]
52
- print (batch_channels )
53
- for v in ["t2m" , "v200" ]:
54
- assert v in batch_channels
55
- assert batch [BatchKey .nwp ]["ecmwf" ][NWPBatchKey .nwp ].shape [2 ] == 2
56
-
57
-
58
- @pytest .mark .skip (reason = "Has not been updated for ocf-data-sampler yet" )
59
- def test_pv_site_init ():
60
- dm = PVSiteDataModule (
61
- configuration = f"{ os .path .dirname (os .path .abspath (__file__ ))} /test_data/sample_batches/data_configuration.yaml" ,
62
- batch_size = 2 ,
63
- num_workers = 0 ,
64
- prefetch_factor = None ,
65
- train_period = [None , None ],
66
- val_period = [None , None ],
67
- test_period = [None , None ],
68
- batch_dir = None ,
69
- )
70
-
71
-
72
15
def test_iter ():
73
16
dm = DataModule (
74
17
configuration = None ,
@@ -104,3 +47,5 @@ def test_iter_multiprocessing():
104
47
105
48
# Make sure we've served 2 batches
106
49
assert served_batches == 2
50
+
51
+ # TODO add test cases with some netcdfs premade samples
0 commit comments