1
1
"""Encoder modules for the satellite/NWP data based on 3D concolutions.
2
2
"""
3
+ from typing import List , Union
4
+
3
5
import torch
4
6
from torch import nn
5
- from typing import Union , List
6
7
from torchvision .transforms import CenterCrop
7
8
8
9
from pvnet .models .multimodal .encoders .basic_blocks import (
@@ -25,7 +26,7 @@ def __init__(
25
26
fc_features : int = 128 ,
26
27
spatial_kernel_size : int = 3 ,
27
28
temporal_kernel_size : int = 3 ,
28
- padding : Union [int , List [int ]] = (1 ,0 , 0 ),
29
+ padding : Union [int , List [int ]] = (1 , 0 , 0 ),
29
30
):
30
31
"""This is the original encoding module used in PVNet, with a few minor tweaks.
31
32
@@ -45,8 +46,12 @@ def __init__(
45
46
if isinstance (padding , int ):
46
47
padding = (padding , padding , padding )
47
48
# Check that the output shape of the convolutional layers will be at least 1x1
48
- cnn_spatial_output_size = image_size_pixels - (spatial_kernel_size - 1 ) * number_of_conv3d_layers
49
- cnn_sequence_length = ((sequence_length - temporal_kernel_size + 2 * padding [0 ]) + 1 ) * number_of_conv3d_layers
49
+ cnn_spatial_output_size = (
50
+ image_size_pixels - (spatial_kernel_size - 1 ) * number_of_conv3d_layers
51
+ )
52
+ cnn_sequence_length = (
53
+ (sequence_length - temporal_kernel_size + 2 * padding [0 ]) + 1
54
+ ) * number_of_conv3d_layers
50
55
if not (cnn_spatial_output_size >= 1 ):
51
56
raise ValueError (
52
57
f"cannot use this many conv3d layers ({ number_of_conv3d_layers } ) with this input "
0 commit comments