Skip to content

Commit

Permalink
Add save pretrained test (#318)
Browse files Browse the repository at this point in the history
add save_pretrained test
  • Loading branch information
dfulu authored Feb 12, 2025
1 parent 3f769ea commit 13a24e0
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 35 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include *.txt
recursive-include pvnet/models/model_cards *.md
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def site_encoder_sensor_model_kwargs():
def raw_multimodal_model_kwargs(model_minutes_kwargs):
kwargs = dict(
sat_encoder=dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
_partial_=True,
in_channels=11,
out_features=128,
Expand All @@ -243,7 +243,7 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
),
nwp_encoders_dict={
"ukv": dict(
_target_=pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet,
_target_="pvnet.models.multimodal.encoders.encoders3d.DefaultPVNet",
_partial_=True,
in_channels=11,
out_features=128,
Expand All @@ -256,7 +256,7 @@ def raw_multimodal_model_kwargs(model_minutes_kwargs):
# ocf-data-sampler doesn't supprt PV site inputs yet
pv_encoder=None,
output_network=dict(
_target_=pvnet.models.multimodal.linear_networks.networks.ResFCNet2,
_target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
_partial_=True,
fc_hidden_features=128,
n_res_blocks=6,
Expand Down
11 changes: 0 additions & 11 deletions tests/models/multimodal/test_from_pretrained.py

This file was deleted.

38 changes: 38 additions & 0 deletions tests/models/multimodal/test_save_load_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pvnet.models.base_model import BaseModel
from pathlib import Path


def test_from_pretrained():
model_name = "openclimatefix/pvnet_uk_region"
model_version = "92266cd9040c590a9e90ee33eafd0e7b92548be8"

_ = BaseModel.from_pretrained(
model_id=model_name,
revision=model_version,
)


def test_save_pretrained(tmp_path, multimodal_model, raw_multimodal_model_kwargs):
data_config_path = "tests/test_data/presaved_samples_uk_regional/data_configuration.yaml"

# Construct the model config
model_config = {"_target_": "pvnet.models.multimodal.multimodal.Model"}
model_config.update(raw_multimodal_model_kwargs)

# Save the model
model_output_dir = f"{tmp_path}/model"
multimodal_model.save_pretrained(
model_output_dir,
config=model_config,
data_config=data_config_path,
wandb_repo=None,
wandb_ids="excluded-for-text",
push_to_hub=False,
repo_id="openclimatefix/pvnet_uk_region",
)

# Load the model
_ = BaseModel.from_pretrained(
model_id=model_output_dir,
revision=None,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ general:

input_data:
gsp:
gsp_zarr_path: /mnt/disks/nwp_rechunk/pv_gsp/pvlive_gsp.zarr
zarr_path: /mnt/disks/nwp_rechunk/pv_gsp/pvlive_gsp.zarr
interval_start_minutes: -120
interval_end_minutes: 480
time_resolution_minutes: 30
Expand All @@ -16,16 +16,16 @@ input_data:

nwp:
ukv:
nwp_provider: ukv
nwp_zarr_path:
provider: ukv
zarr_path:
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_intermediate_version_7.1.zarr
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2021_missing.zarr
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2022.zarr
- /mnt/disks/nwp_rechunk/nwp/ukv/UKV_2023.zarr
interval_start_minutes: -120
interval_end_minutes: 480
time_resolution_minutes: 60
nwp_channels:
channels:
# These variables exist in the CEDA training set and in the live MetOffice live service
- t # 2-metre temperature
- dswrf # downwards short-wave radiation flux
Expand All @@ -38,20 +38,20 @@ input_data:
- vis # visibility
- si10 # 10-metre wind speed
- prate # precipitation rate
nwp_image_size_pixels_height: 24
nwp_image_size_pixels_width: 24
image_size_pixels_height: 24
image_size_pixels_width: 24
dropout_timedeltas_minutes: [-180]
dropout_fraction: 1.0
max_staleness_minutes: null

ecmwf:
nwp_provider: ecmwf
nwp_zarr_path: /mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr
provider: ecmwf
zarr_path: /mnt/disks/nwp_rechunk/nwp/ecmwf/UK_v2.zarr
interval_start_minutes: -120
interval_end_minutes: 480
time_resolution_minutes: 60

nwp_channels:
channels:
- t2m # 2-metre temperature
- dswrf # downwards short-wave radiation flux
- dlwrf # downwards long-wave radiation flux
Expand All @@ -66,25 +66,25 @@ input_data:
- v10 # 10-metre V component of wind speed

# The following channels are accumulated and need to be diffed
nwp_accum_channels:
accum_channels:
- dswrf
- dlwrf
- sr
- duvrs

nwp_image_size_pixels_height: 12 # roughly equivalent to ukv 48
nwp_image_size_pixels_width: 12 # roughly equivalent to ukv 48
image_size_pixels_height: 12 # roughly equivalent to ukv 48
image_size_pixels_width: 12 # roughly equivalent to ukv 48
dropout_timedeltas_minutes: [-360]
dropout_fraction: 1.0
max_staleness_minutes: null

sat_pred:
nwp_provider: sat_pred
nwp_zarr_path: /mnt/disks/sat_preds/simvp_preds/*.zarr
provider: sat_pred
zarr_path: /mnt/disks/sat_preds/simvp_preds/*.zarr
interval_start_minutes: 15
interval_end_minutes: 180
time_resolution_minutes: 15
nwp_channels:
channels:
- IR_016
- IR_039
- IR_087
Expand All @@ -96,14 +96,14 @@ input_data:
- VIS008
- WV_062
- WV_073
nwp_image_size_pixels_height: 24
nwp_image_size_pixels_width: 24
image_size_pixels_height: 24
image_size_pixels_width: 24
dropout_timedeltas_minutes: null
dropout_fraction: 0
max_staleness_minutes: null

satellite:
satellite_zarr_path:
zarr_path:
- /mnt/disks/nwp_rechunk/sat/2019_nonhrv.zarr
- /mnt/disks/nwp_rechunk/sat/2020_nonhrv.zarr
- /mnt/disks/nwp_rechunk/sat/2021_nonhrv.zarr
Expand All @@ -112,7 +112,7 @@ input_data:
interval_start_minutes: -30
interval_end_minutes: 0
time_resolution_minutes: 5
satellite_channels:
channels:
- IR_016
- IR_039
- IR_087
Expand All @@ -124,7 +124,7 @@ input_data:
- VIS008
- WV_062
- WV_073
satellite_image_size_pixels_height: 24
satellite_image_size_pixels_width: 24
image_size_pixels_height: 24
image_size_pixels_width: 24
dropout_timedeltas_minutes: null
dropout_fraction: 0.

0 comments on commit 13a24e0

Please sign in to comment.