-
-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathbasic_blocks.py
218 lines (177 loc) · 6.81 KB
/
basic_blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
"""Basic blocks for image sequence encoders"""
from abc import ABCMeta, abstractmethod
import torch
from torch import nn
class AbstractNWPSatelliteEncoder(nn.Module, metaclass=ABCMeta):
    """Abstract class for NWP/satellite encoder.

    The encoder will take an input of shape (batch_size, sequence_length, channels, height, width)
    and return an output of shape (batch_size, out_features).

    Subclasses must implement ``forward``; the class cannot be instantiated directly.
    """

    def __init__(
        self,
        sequence_length: int,
        image_size_pixels: int,
        in_channels: int,
        out_features: int,
    ):
        """Abstract class for NWP/satellite encoder.

        Args:
            sequence_length: The time sequence length of the data.
            image_size_pixels: The spatial size of the image. Assumed square.
            in_channels: Number of input channels.
            out_features: Number of output features.
        """
        super().__init__()
        self.out_features = out_features
        self.image_size_pixels = image_size_pixels
        self.sequence_length = sequence_length
        # Store the channel count alongside the other configuration values; it used to be
        # accepted and documented but silently dropped.
        self.in_channels = in_channels

    @abstractmethod
    def forward(self, x):
        """Run model forward.

        Args:
            x: Input tensor, laid out as described in the class docstring.

        Returns:
            The encoded tensor, shaped as described in the class docstring.
        """
class ResidualConv3dBlock(nn.Module):
    """Residual block built from stacked 3D convolutions with ELU pre-activations.

    The residual branch repeats ``ELU -> Conv3d(3x3x3, padding 1) -> Dropout3d`` ``n_layers``
    times. Each convolution keeps the channel count at ``in_channels`` and, with a 3x3x3 kernel
    and padding of 1, leaves the three trailing dimensions unchanged, so the branch output can be
    added directly to the block input.
    """

    def __init__(
        self,
        in_channels,
        n_layers: int = 2,
        dropout_frac: float = 0.0,
    ):
        """Build the residual branch.

        Args:
            in_channels: Number of input channels (also the output channels of every conv).
            n_layers: How many ELU/Conv3d/Dropout3d units make up the residual branch.
            dropout_frac: Dropout probability passed to each ``nn.Dropout3d``.
        """
        super().__init__()

        def _unit():
            # One pre-activation unit of the residual branch.
            return [
                nn.ELU(),
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=(3, 3, 3),
                    padding=(1, 1, 1),
                ),
                nn.Dropout3d(p=dropout_frac),
            ]

        branch = [module for _ in range(n_layers) for module in _unit()]
        self.model = nn.Sequential(*branch)

    def forward(self, x):
        """Add the residual branch output to the input ``x`` (identity skip connection)."""
        return x + self.model(x)
class ResidualConv3dBlock2(nn.Module):
    """Residual block of 'full pre-activation' similar to the block in figure 4(e) of [1].

    This was the best performing residual block tested in the study. This implementation differs
    from that block by using LeakyReLU activation to avoid dead neurons, by including optional
    dropout in the residual branch, and by using 3D convolutions rather than 2D convolutions.

    Each unit of the residual branch is ``[BatchNorm3d] -> Dropout3d -> LeakyReLU -> Conv3d``,
    where the conv keeps ``in_channels`` channels and (3x3x3 kernel, padding 1) the three trailing
    dimensions, so the branch output is added directly to the input.

    Sources:
        [1] https://arxiv.org/pdf/1603.05027.pdf
    """

    def __init__(
        self,
        in_channels: int,
        n_layers: int = 2,
        dropout_frac: float = 0.0,
        batch_norm: bool = True,
    ):
        """Build the pre-activation residual branch.

        Sources:
            [1] https://arxiv.org/pdf/1603.05027.pdf

        Args:
            in_channels: Number of input channels (also the output channels of every conv).
            n_layers: Number of normalisation/activation/conv units in the residual branch.
            dropout_frac: Dropout probability passed to each ``nn.Dropout3d``.
            batch_norm: If True, each unit starts with an ``nn.BatchNorm3d``.
        """
        super().__init__()
        branch = []
        for _ in range(n_layers):
            unit = [nn.BatchNorm3d(in_channels)] if batch_norm else []
            unit.append(nn.Dropout3d(p=dropout_frac))
            unit.append(nn.LeakyReLU())
            unit.append(
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=(3, 3, 3),
                    padding=(1, 1, 1),
                )
            )
            branch.extend(unit)
        self.model = nn.Sequential(*branch)

    def forward(self, x):
        """Return the residual branch output plus the identity skip connection."""
        residual = self.model(x)
        return residual + x
class ImageSequenceEncoder(nn.Module):
    """Simple network which independently encodes each image in a sequence into 1D features"""

    def __init__(
        self,
        image_size_pixels: int,
        in_channels: int,
        number_of_conv2d_layers: int = 4,
        conv2d_channels: int = 32,
        fc_features: int = 128,
    ):
        """Simple network which independently encodes each image in a sequence into 1D features.

        For input image with shape [N, C, L, H, W] the output is of shape [N, L, fc_features] where
        N is number of samples in batch, C is the number of input channels, L is the length of the
        sequence, and H and W are the height and width.

        Args:
            image_size_pixels: The spatial size of the image. Assumed square.
            in_channels: Number of input channels.
            number_of_conv2d_layers: Number of convolution 2D layers that are used. Must be >= 1.
            conv2d_channels: Number of channels used in each conv2d layer.
            fc_features: Number of output nodes for each image in each sequence.

        Raises:
            ValueError: If ``number_of_conv2d_layers`` is less than 1, or if that many unpadded
                3x3 convolutions would shrink an ``image_size_pixels`` square image below 1x1.
        """
        super().__init__()

        # At least one conv layer is always needed: the linear layer below is sized from the
        # conv stack's output, and with zero (or negative) layers that size would not match.
        if number_of_conv2d_layers < 1:
            raise ValueError(
                f"number_of_conv2d_layers must be at least 1, got {number_of_conv2d_layers}"
            )

        # Each unpadded 3x3 conv trims 2 pixels from each spatial dimension. Check that the
        # output shape of the convolutional layers will be at least 1x1.
        cnn_spatial_output_size = image_size_pixels - 2 * number_of_conv2d_layers
        if cnn_spatial_output_size < 1:
            raise ValueError(
                f"cannot use this many conv2d layers ({number_of_conv2d_layers}) with this input "
                f"spatial size ({image_size_pixels})"
            )

        # Conv2d/ELU pairs; only the first conv reads `in_channels`, the rest `conv2d_channels`.
        conv_layers = []
        layer_in_channels = in_channels
        for _ in range(number_of_conv2d_layers):
            conv_layers += [
                nn.Conv2d(
                    in_channels=layer_in_channels,
                    out_channels=conv2d_channels,
                    kernel_size=3,
                    padding=0,
                ),
                nn.ELU(),
            ]
            layer_in_channels = conv2d_channels
        self.conv_layers = nn.Sequential(*conv_layers)

        self.final_block = nn.Sequential(
            nn.Linear(
                in_features=(cnn_spatial_output_size**2) * conv2d_channels,
                out_features=fc_features,
            ),
            nn.ELU(),
        )

    def forward(self, x):
        """Encode every image of every sequence independently.

        Args:
            x: Tensor of shape [N, C, L, H, W]. H and W are expected to equal
                ``image_size_pixels``; otherwise the flattened conv output will not match the
                linear layer's input size.

        Returns:
            Tensor of shape [N, L, fc_features].
        """
        batch_size, channel, seq_len, height, width = x.shape
        # [N, C, L, H, W] -> [N, L, C, H, W] -> [N*L, C, H, W]: fold the sequence into the batch
        # so each image passes through the 2D conv stack on its own.
        x = torch.swapaxes(x, 1, 2)
        x = x.reshape(batch_size * seq_len, channel, height, width)
        out = self.conv_layers(x)
        out = out.reshape(batch_size * seq_len, -1)
        out = self.final_block(out)
        # Unfold back to one feature vector per (sample, sequence step).
        return out.reshape(batch_size, seq_len, -1)