Skip to content

Commit 9e2b405

Browse files
committed
Fix some more tests
1 parent 31e5fa4 commit 9e2b405

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

tests/models/conv3d/test_conv3d_model_sat_nwp.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ def test_model_forward_no_satellite(configuration_conv3d):
5252
# start model
5353
model = Model(**config)
5454

55-
dataset_configuration = configuration_conv3d
55+
dataset_configuration: Configuration = configuration_conv3d
5656
dataset_configuration.input_data.nwp.nwp_image_size_pixels_height = 16
57+
dataset_configuration.input_data.nwp.nwp_image_size_pixels_width = 16
58+
dataset_configuration.input_data.satellite.satellite_image_size_pixels_height = 16
59+
dataset_configuration.input_data.satellite.satellite_image_size_pixels_width = 16
5760

5861
# create fake data loader
5962
train_dataset = FakeDataset(configuration=dataset_configuration)
@@ -74,8 +77,12 @@ def test_train(configuration_conv3d):
7477
config_file = "tests/configs/model/conv3d_sat_nwp.yaml"
7578
config = load_config(config_file)
7679

77-
dataset_configuration = configuration_conv3d
80+
dataset_configuration: Configuration = configuration_conv3d
7881
dataset_configuration.input_data.nwp.nwp_image_size_pixels_height = 16
82+
dataset_configuration.input_data.nwp.nwp_image_size_pixels_width = 16
83+
84+
dataset_configuration.input_data.satellite.satellite_image_size_pixels_height = 16
85+
dataset_configuration.input_data.satellite.satellite_image_size_pixels_width = 16
7986

8087
# start model
8188
model = Model(**config)
@@ -85,7 +92,7 @@ def test_train(configuration_conv3d):
8592
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=None)
8693

8794
# fit model
88-
trainer = pl.Trainer(gpus=0, max_epochs=1)
95+
trainer = pl.Trainer(max_epochs=1)
8996
trainer.fit(model, train_dataloader)
9097

9198
# predict over training set

0 commit comments

Comments
 (0)