@@ -106,40 +106,37 @@ def sample_train_val_datamodule():
106
106
107
107
file_n = 0
108
108
109
- for file in glob .glob ("tests/test_data/sample_batches /train/*.pt" ):
110
- batch = torch .load (file )
109
+ for file_n , file in enumerate ( glob .glob ("tests/test_data/presaved_samples /train/*.pt" ) ):
110
+ sample = torch .load (file )
111
111
112
112
for i in range (n_duplicates ):
113
113
# Save fopr both train and val
114
- torch .save (batch , f"{ tmpdirname } /train/{ file_n :06} .pt" )
115
- torch .save (batch , f"{ tmpdirname } /val/{ file_n :06} .pt" )
116
-
117
- file_n += 1
114
+ torch .save (sample , f"{ tmpdirname } /train/{ file_n :06} .pt" )
115
+ torch .save (sample , f"{ tmpdirname } /val/{ file_n :06} .pt" )
118
116
119
117
dm = DataModule (
120
118
configuration = None ,
119
+ sample_dir = f"{ tmpdirname } " ,
121
120
batch_size = 2 ,
122
121
num_workers = 0 ,
123
122
prefetch_factor = None ,
124
123
train_period = [None , None ],
125
124
val_period = [None , None ],
126
- test_period = [None , None ],
127
- batch_dir = f"{ tmpdirname } " ,
125
+
128
126
)
129
127
yield dm
130
128
131
129
132
130
@pytest .fixture ()
133
131
def sample_datamodule ():
134
132
dm = DataModule (
133
+ sample_dir = "tests/test_data/presaved_samples" ,
135
134
configuration = None ,
136
135
batch_size = 2 ,
137
136
num_workers = 0 ,
138
137
prefetch_factor = None ,
139
138
train_period = [None , None ],
140
139
val_period = [None , None ],
141
- test_period = [None , None ],
142
- batch_dir = "tests/test_data/sample_batches" ,
143
140
)
144
141
return dm
145
142
@@ -157,9 +154,10 @@ def sample_satellite_batch(sample_batch):
157
154
158
155
159
156
@pytest .fixture ()
160
- def sample_pv_batch (sample_batch ):
161
- pv_data = sample_batch [BatchKey .pv ]
162
- return pv_data
157
+ def sample_pv_batch ():
158
+ # TODO: Once PV site inputs are available from ocf-data-sampler UK regional remove these
159
+ # old batches. For now we use the old batches to test the site encoder models
160
+ return torch .load ("tests/test_data/presaved_batches/train/000000.pt" )
163
161
164
162
165
163
@pytest .fixture ()
@@ -191,7 +189,7 @@ def model_minutes_kwargs():
191
189
def encoder_model_kwargs ():
192
190
# Used to test encoder model on satellite data
193
191
kwargs = dict (
194
- sequence_length = ( 90 - 30 ) // 5 + 1 ,
192
+ sequence_length = 7 , # 30 minutes of 5 minutely satellite data = 7 time steps
195
193
image_size_pixels = 24 ,
196
194
in_channels = 11 ,
197
195
out_features = 128 ,
@@ -240,23 +238,16 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
240
238
"ukv" : dict (
241
239
_target_ = pvnet .models .multimodal .encoders .encoders3d .DefaultPVNet ,
242
240
_partial_ = True ,
243
- in_channels = 2 ,
241
+ in_channels = 11 ,
244
242
out_features = 128 ,
245
243
number_of_conv3d_layers = 6 ,
246
244
conv3d_channels = 32 ,
247
245
image_size_pixels = 24 ,
248
246
),
249
247
},
250
248
add_image_embedding_channel = True ,
251
- pv_encoder = dict (
252
- _target_ = pvnet .models .multimodal .site_encoders .encoders .SingleAttentionNetwork ,
253
- _partial_ = True ,
254
- num_sites = 349 ,
255
- out_features = 40 ,
256
- num_heads = 4 ,
257
- kdim = 40 ,
258
- id_embed_dim = 20 ,
259
- ),
249
+ # ocf-data-sampler doesn't supprt PV site inputs yet
250
+ pv_encoder = None ,
260
251
output_network = dict (
261
252
_target_ = pvnet .models .multimodal .linear_networks .networks .ResFCNet2 ,
262
253
_partial_ = True ,
@@ -268,11 +259,10 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
268
259
embedding_dim = 16 ,
269
260
include_sun = True ,
270
261
include_gsp_yield_history = True ,
271
- sat_history_minutes = 90 ,
262
+ sat_history_minutes = 30 ,
272
263
nwp_history_minutes = {"ukv" : 120 },
273
264
nwp_forecast_minutes = {"ukv" : 480 },
274
- pv_history_minutes = 180 ,
275
- min_sat_delay_minutes = 30 ,
265
+ min_sat_delay_minutes = 0 ,
276
266
)
277
267
278
268
kwargs .update (model_minutes_kwargs )
@@ -297,14 +287,6 @@ def multimodal_quantile_model(multimodal_model_kwargs):
297
287
return model
298
288
299
289
300
- @pytest .fixture ()
301
- def multimodal_weighted_quantile_model (multimodal_model_kwargs ):
302
- model = Model (
303
- output_quantiles = [0.1 , 0.5 , 0.9 ], ** multimodal_model_kwargs , use_weighted_loss = True
304
- )
305
- return model
306
-
307
-
308
290
@pytest .fixture ()
309
291
def multimodal_quantile_model_ignore_minutes (multimodal_model_kwargs ):
310
292
"""Only forecsat second half of the 8 hours"""
0 commit comments