Skip to content

Commit 1e7b648

Browse files
committed
Merge remote-tracking branch 'origin/jacob/windnet' into jacob/windnet
2 parents a3867c1 + ee0d98c commit 1e7b648

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

pvnet/models/multimodal/encoders/encoders3d.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Encoder modules for the satellite/NWP data based on 3D concolutions.
22
"""
3+
from typing import List, Union
4+
35
import torch
46
from torch import nn
5-
from typing import Union, List
67
from torchvision.transforms import CenterCrop
78

89
from pvnet.models.multimodal.encoders.basic_blocks import (
@@ -25,7 +26,7 @@ def __init__(
2526
fc_features: int = 128,
2627
spatial_kernel_size: int = 3,
2728
temporal_kernel_size: int = 3,
28-
padding: Union[int, List[int]] = (1,0,0),
29+
padding: Union[int, List[int]] = (1, 0, 0),
2930
):
3031
"""This is the original encoding module used in PVNet, with a few minor tweaks.
3132
@@ -45,8 +46,12 @@ def __init__(
4546
if isinstance(padding, int):
4647
padding = (padding, padding, padding)
4748
# Check that the output shape of the convolutional layers will be at least 1x1
48-
cnn_spatial_output_size = image_size_pixels - (spatial_kernel_size-1) * number_of_conv3d_layers
49-
cnn_sequence_length = ((sequence_length - temporal_kernel_size + 2*padding[0]) + 1) * number_of_conv3d_layers
49+
cnn_spatial_output_size = (
50+
image_size_pixels - (spatial_kernel_size - 1) * number_of_conv3d_layers
51+
)
52+
cnn_sequence_length = (
53+
(sequence_length - temporal_kernel_size + 2 * padding[0]) + 1
54+
) * number_of_conv3d_layers
5055
if not (cnn_spatial_output_size >= 1):
5156
raise ValueError(
5257
f"cannot use this many conv3d layers ({number_of_conv3d_layers}) with this input "

0 commit comments

Comments
 (0)