Commit a304738

Output networks updated
1 parent 2677391 commit a304738

File tree

1 file changed: +43 -12 lines changed

pvnet/models/multimodal/linear_networks/output_networks.py

Lines changed: 43 additions & 12 deletions
@@ -21,7 +21,8 @@
 class DynamicOutputNetwork(AbstractLinearNetwork):
     """ Dynamic output network definition """
 
-    # Input ant output dimension specified here
+    # Input and output dimensions specified here
+    # Defines feature mapping ℝ^n → ℝ^m
     def __init__(
         self,
         in_features: int,
@@ -36,35 +37,39 @@ def __init__(
     ):
         # Initialisation of dynamic output network
         super().__init__(in_features=in_features, out_features=out_features)
+        self.out_features = out_features
 
         # Default hidden architecture
+        # h_i ∈ ℝ^{d_i}, where d_i = [2n, n]
         if hidden_dims is None:
             hidden_dims = [in_features * 2, in_features]
 
         if any(dim <= 0 for dim in hidden_dims):
             raise ValueError("hidden_dims must be positive")
 
         # Construction of network layers - config
+        # Network architecture parameters θ
         self.use_layer_norm = use_layer_norm
         self.use_residual = use_residual
         self.quantile_output = quantile_output
         self.num_forecast_steps = num_forecast_steps
 
         # Construction of hidden layers
-        # Sequence: Linear → LayerNorm → ReLU → Dropout
+        # H_i: ℝ^{d_i} → ℝ^{d_{i+1}}
+        # Sequential transformation φ(x) = Dropout(ReLU(LayerNorm(Wx + b)))
         self.layers = nn.ModuleList()
         prev_dim = in_features
 
         for dim in hidden_dims:
 
-            # Linear transformation / normalisatiom
+            # Affine transformation followed by distribution normalisation
             layer_block = []
             layer_block.append(nn.Linear(prev_dim, dim))
 
             if use_layer_norm:
                 layer_block.append(nn.LayerNorm(dim))
 
-            # Non linearity / regularisation
+            # Non-linear activation and stochastic regularisation
             layer_block.extend([
                 nn.ReLU(),
                 nn.Dropout(dropout)
@@ -73,8 +78,9 @@ def __init__(
             self.layers.append(nn.Sequential(*layer_block))
             prev_dim = dim
 
-        # Output layer definition
-        # Projection for quantile preds over timesteps or standard
+        # Output layer transformation definition
+        # f: ℝ^{d_L} → ℝ^m
+        # Projection mapping P: ℝ^d → ℝ^{m×t} for temporal quantile predictions
         if quantile_output and num_forecast_steps:
             final_out_features = out_features * num_forecast_steps
         else:
@@ -83,16 +89,23 @@ def __init__(
         self.output_layer = nn.Linear(prev_dim, final_out_features)
 
         # Output activation definition
+        # ψ: ℝ^m → [0,1]^m
         if output_activation == "softmax":
             self.output_activation = nn.Softmax(dim=-1)
         elif output_activation == "sigmoid":
             self.output_activation = nn.Sigmoid()
         else:
             self.output_activation = None
-
-        # Optional layer norm for residual connection
+
+        # Optional layer norm and residual projection
+        # g: ℝ^n → ℝ^m
         if use_residual:
-            self.residual_norm = nn.LayerNorm(out_features)
+            if quantile_output and num_forecast_steps:
+                self.residual_norm = nn.LayerNorm(out_features)
+            else:
+                final_out_features = out_features * num_forecast_steps if quantile_output and num_forecast_steps else out_features
+                self.residual_norm = nn.LayerNorm(final_out_features)
+            self.residual_proj = nn.Linear(in_features, out_features)
 
     def reshape_quantile_output(self, x: torch.Tensor) -> torch.Tensor:
 
@@ -124,11 +137,29 @@ def forward(
 
         # Output transform, reshape and apply residual connection
         x = self.output_layer(x)
-        x = self.reshape_quantile_output(x)
-        if self.use_residual and x.shape == residual.shape:
-            x = self.residual_norm(x + residual)
+        x = self.reshape_quantile_output(x)
+
+        if self.use_residual:
+            # Apply residual projection transformation
+            projected_residual = self.residual_proj(residual)
+            if self.quantile_output and self.num_forecast_steps:
+
+                # Apply residual mapping followed by normalisation
+                projected_residual = projected_residual.reshape(x.shape[0], x.shape[2])
+
+                # Collapse temporal dimensions for normalisation
+                # ℝ^{B×T×F} → ℝ^{BT×F}
+                x = x.reshape(-1, x.shape[2])
+                x = self.residual_norm(x + projected_residual.repeat(self.num_forecast_steps, 1))
+
+                # Restore tensor dimensionality
+                # ℝ^{BT×F} → ℝ^{B×T×F}
+                x = x.reshape(-1, self.num_forecast_steps, self.out_features)
+            else:
+                x = self.residual_norm(x + projected_residual)
 
         # Apply output activation
+        # Non-linear transformation ψ
        if self.output_activation:
             x = self.output_activation(x)
 
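The residual path added to forward() above can be traced in isolation: the quantile output of shape (B, T, F) is flattened, the projected residual is tiled across the forecast steps, the sum is layer-normalised, and the tensor is reshaped back. A minimal sketch of that shape flow follows; it is not part of the commit, the sizes B, T, F and in_features are illustrative assumptions, and a standalone Linear/LayerNorm pair stands in for the module's residual_proj and residual_norm.

import torch
import torch.nn as nn

# Illustrative sizes (assumptions): batch, forecast steps, output features, input dim
B, T, F, in_features = 4, 3, 2, 8

x = torch.randn(B, T, F)                   # stands in for the output of reshape_quantile_output
residual = torch.randn(B, in_features)     # stands in for the input saved for the skip connection

residual_proj = nn.Linear(in_features, F)  # plays the role of self.residual_proj (g: ℝ^n → ℝ^m)
residual_norm = nn.LayerNorm(F)            # plays the role of self.residual_norm

projected = residual_proj(residual)        # (B, F)

# Collapse temporal dimension for normalisation: (B, T, F) → (B*T, F)
x_flat = x.reshape(-1, F)

# repeat(T, 1) tiles the whole (B, F) projection T times along dim 0, giving (T*B, F),
# which matches the flattened output's row count and allows the element-wise add
out = residual_norm(x_flat + projected.repeat(T, 1))

# Restore tensor dimensionality: (B*T, F) → (B, T, F)
out = out.reshape(-1, T, F)
print(out.shape)  # torch.Size([4, 3, 2])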