Skip to content

Commit 13a24e0

Browse files
authored
Add save pretrained test (#318)
add save_pretrained test
1 parent 3f769ea commit 13a24e0

File tree

5 files changed

+63
-35
lines changed

5 files changed

+63
-35
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
include *.txt
2+
recursive-include pvnet/models/model_cards *.md

tests/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def site_encoder_sensor_model_kwargs():
233233
def raw_multimodal_model_kwargs(model_minutes_kwargs):
234234
kwargs = dict(
235235
sat_encoder=dict(
236-
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
236+
_target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
237237
_partial_=True,
238238
in_channels=11,
239239
out_features=128,
@@ -243,7 +243,7 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
243243
),
244244
nwp_encoders_dict={
245245
"ukv": dict(
246-
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
246+
_target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
247247
_partial_=True,
248248
in_channels=11,
249249
out_features=128,
@@ -256,7 +256,7 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
256256
# ocf-data-sampler doesn't supprt PV site inputs yet
257257
pv_encoder=None,
258258
output_network=dict(
259-
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
259+
_target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
260260
_partial_=True,
261261
fc_hidden_features=128,
262262
n_res_blocks=6,

tests/models/multimodal/test_from_pretrained.py

Lines changed: 0 additions & 11 deletions
This file was deleted.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from pvnet.models.base_model import BaseModel
2+
from pathlib import Path
3+
4+
5+
def test_from_pretrained():
6+
model_name = "openclimatefix/pvnet_uk_region"
7+
model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8"
8+
9+
_ = BaseModel.from_pretrained(
10+
model_id=model_name,
11+
revision=model_version,
12+
)
13+
14+
15+
def test_save_pretrained(tmp_path, multimodal_model, raw_multimodal_model_kwargs):
16+
data_config_path = "tests/test_data/presaved_samples_uk_regional/data_configuration.yaml"
17+
18+
# Construct the model config
19+
model_config = {"_target_": "pvnet.models.multimodal.multimodal.Model"}
20+
model_config.update(raw_multimodal_model_kwargs)
21+
22+
# Save the model
23+
model_output_dir = f"{tmp_path}/model"
24+
multimodal_model.save_pretrained(
25+
model_output_dir,
26+
config=model_config,
27+
data_config=data_config_path,
28+
wandb_repo=None,
29+
wandb_ids="excluded-for-text",
30+
push_to_hub=False,
31+
repo_id="openclimatefix/pvnet_uk_region",
32+
)
33+
34+
# Load the model
35+
_ = BaseModel.from_pretrained(
36+
model_id=model_output_dir,
37+
revision=None,
38+
)

tests/test_data/presaved_samples_uk_regional/data_configuration.yaml

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ general:
44

55
input_data:
66
gsp:
7-
gsp_zarr_path: /mnt/disks/nwp_rechunk/pv_gsp/pvlive_gsp.zarr
7+
zarr_path: /mnt/disks/nwp_rechunk/pv_gsp/pvlive_gsp.zarr
88
interval_start_minutes: -120
99
interval_end_minutes: 480
1010
time_resolution_minutes: 30
@@ -16,16 +16,16 @@ input_data:
1616

1717
nwp:
1818
ukv:
19-
nwp_provider: ukv
20-
nwp_zarr_path:
19+
provider: ukv
20+
zarr_path:
2121
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_intermediate_version_7.1.zarr
2222
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2021_missing.zarr
2323
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2022.zarr
2424
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2023.zarr
2525
interval_start_minutes: -120
2626
interval_end_minutes: 480
2727
time_resolution_minutes: 60
28-
nwp_channels:
28+
channels:
2929
# These variables exist in the CEDA training set and in the live MetOffice live service
3030
- t # 2-metre temperature
3131
- dswrf # downwards short-wave radiation flux
@@ -38,20 +38,20 @@ input_data:
3838
- vis # visibility
3939
- si10 # 10-metre wind speed
4040
- prate # precipitation rate
41-
nwp_image_size_pixels_height: 24
42-
nwp_image_size_pixels_width: 24
41+
image_size_pixels_height: 24
42+
image_size_pixels_width: 24
4343
dropout_timedeltas_minutes: [-180]
4444
dropout_fraction: 1.0
4545
max_staleness_minutes: null
4646

4747
ecmwf:
48-
nwp_provider: ecmwf
49-
nwp_zarr_path: /mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr
48+
provider: ecmwf
49+
zarr_path: /mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr
5050
interval_start_minutes: -120
5151
interval_end_minutes: 480
5252
time_resolution_minutes: 60
5353

54-
nwp_channels:
54+
channels:
5555
- t2m # 2-metre temperature
5656
- dswrf # downwards short-wave radiation flux
5757
- dlwrf # downwards long-wave radiation flux
@@ -66,25 +66,25 @@ input_data:
6666
- v10 # 10-metre V component of wind speed
6767

6868
# The following channels are accumulated and need to be diffed
69-
nwp_accum_channels:
69+
accum_channels:
7070
- dswrf
7171
- dlwrf
7272
- sr
7373
- duvrs
7474

75-
nwp_image_size_pixels_height: 12 # roughly equivalent to ukv 48
76-
nwp_image_size_pixels_width: 12 # roughly equivalent to ukv 48
75+
image_size_pixels_height: 12 # roughly equivalent to ukv 48
76+
image_size_pixels_width: 12 # roughly equivalent to ukv 48
7777
dropout_timedeltas_minutes: [-360]
7878
dropout_fraction: 1.0
7979
max_staleness_minutes: null
8080

8181
sat_pred:
82-
nwp_provider: sat_pred
83-
nwp_zarr_path: /mnt/disks/sat_preds/simvp_preds/*.zarr
82+
provider: sat_pred
83+
zarr_path: /mnt/disks/sat_preds/simvp_preds/*.zarr
8484
interval_start_minutes: 15
8585
interval_end_minutes: 180
8686
time_resolution_minutes: 15
87-
nwp_channels:
87+
channels:
8888
- IR_016
8989
- IR_039
9090
- IR_087
@@ -96,14 +96,14 @@ input_data:
9696
- VIS008
9797
- WV_062
9898
- WV_073
99-
nwp_image_size_pixels_height: 24
100-
nwp_image_size_pixels_width: 24
99+
image_size_pixels_height: 24
100+
image_size_pixels_width: 24
101101
dropout_timedeltas_minutes: null
102102
dropout_fraction: 0
103103
max_staleness_minutes: null
104104

105105
satellite:
106-
satellite_zarr_path:
106+
zarr_path:
107107
- /mnt/disks/nwp_rechunk/sat/2019_nonhrv.zarr
108108
- /mnt/disks/nwp_rechunk/sat/2020_nonhrv.zarr
109109
- /mnt/disks/nwp_rechunk/sat/2021_nonhrv.zarr
@@ -112,7 +112,7 @@ input_data:
112112
interval_start_minutes: -30
113113
interval_end_minutes: 0
114114
time_resolution_minutes: 5
115-
satellite_channels:
115+
channels:
116116
- IR_016
117117
- IR_039
118118
- IR_087
@@ -124,7 +124,7 @@ input_data:
124124
- VIS008
125125
- WV_062
126126
- WV_073
127-
satellite_image_size_pixels_height: 24
128-
satellite_image_size_pixels_width: 24
127+
image_size_pixels_height: 24
128+
image_size_pixels_width: 24
129129
dropout_timedeltas_minutes: null
130130
dropout_fraction: 0.

0 commit comments

Comments
 (0)