@@ -52,8 +52,11 @@ def test_model_forward_no_satellite(configuration_conv3d):
52
52
# start model
53
53
model = Model (** config )
54
54
55
- dataset_configuration = configuration_conv3d
55
+ dataset_configuration : Configuration = configuration_conv3d
56
56
dataset_configuration .input_data .nwp .nwp_image_size_pixels_height = 16
57
+ dataset_configuration .input_data .nwp .nwp_image_size_pixels_width = 16
58
+ dataset_configuration .input_data .satellite .satellite_image_size_pixels_height = 16
59
+ dataset_configuration .input_data .satellite .satellite_image_size_pixels_width = 16
57
60
58
61
# create fake data loader
59
62
train_dataset = FakeDataset (configuration = dataset_configuration )
@@ -74,8 +77,12 @@ def test_train(configuration_conv3d):
74
77
config_file = "tests/configs/model/conv3d_sat_nwp.yaml"
75
78
config = load_config (config_file )
76
79
77
- dataset_configuration = configuration_conv3d
80
+ dataset_configuration : Configuration = configuration_conv3d
78
81
dataset_configuration .input_data .nwp .nwp_image_size_pixels_height = 16
82
+ dataset_configuration .input_data .nwp .nwp_image_size_pixels_width = 16
83
+
84
+ dataset_configuration .input_data .satellite .satellite_image_size_pixels_height = 16
85
+ dataset_configuration .input_data .satellite .satellite_image_size_pixels_width = 16
79
86
80
87
# start model
81
88
model = Model (** config )
@@ -85,7 +92,7 @@ def test_train(configuration_conv3d):
85
92
train_dataloader = torch .utils .data .DataLoader (train_dataset , batch_size = None )
86
93
87
94
# fit model
88
- trainer = pl .Trainer (gpus = 0 , max_epochs = 1 )
95
+ trainer = pl .Trainer (max_epochs = 1 )
89
96
trainer .fit (model , train_dataloader )
90
97
91
98
# predict over training set
0 commit comments