11"""Encoder modules for the satellite/NWP data based on 3D concolutions.
22"""
3+ from typing import List , Union
4+
35import torch
46from torch import nn
5- from typing import Union , List
67from torchvision .transforms import CenterCrop
78
89from pvnet .models .multimodal .encoders .basic_blocks import (
@@ -25,7 +26,7 @@ def __init__(
2526 fc_features : int = 128 ,
2627 spatial_kernel_size : int = 3 ,
2728 temporal_kernel_size : int = 3 ,
28- padding : Union [int , List [int ]] = (1 ,0 , 0 ),
29+ padding : Union [int , List [int ]] = (1 , 0 , 0 ),
2930 ):
3031 """This is the original encoding module used in PVNet, with a few minor tweaks.
3132
@@ -45,8 +46,12 @@ def __init__(
4546 if isinstance (padding , int ):
4647 padding = (padding , padding , padding )
4748 # Check that the output shape of the convolutional layers will be at least 1x1
48- cnn_spatial_output_size = image_size_pixels - (spatial_kernel_size - 1 ) * number_of_conv3d_layers
49- cnn_sequence_length = ((sequence_length - temporal_kernel_size + 2 * padding [0 ]) + 1 ) * number_of_conv3d_layers
49+ cnn_spatial_output_size = (
50+ image_size_pixels - (spatial_kernel_size - 1 ) * number_of_conv3d_layers
51+ )
52+ cnn_sequence_length = (
53+ (sequence_length - temporal_kernel_size + 2 * padding [0 ]) + 1
54+ ) * number_of_conv3d_layers
5055 if not (cnn_spatial_output_size >= 1 ):
5156 raise ValueError (
5257 f"cannot use this many conv3d layers ({ number_of_conv3d_layers } ) with this input "
0 commit comments