 class DynamicOutputNetwork(AbstractLinearNetwork):
     """Dynamic output network definition"""
 
-    # Input ant output dimension specified here
+    # Input and output dimensions specified here
+    # Defines the feature mapping ℝ^n → ℝ^m
     def __init__(
         self,
         in_features: int,
@@ -36,35 +37,39 @@ def __init__(
     ):
         # Initialisation of dynamic output network
         super().__init__(in_features=in_features, out_features=out_features)
+        self.out_features = out_features
 
         # Default hidden architecture
+        # h_i ∈ ℝ^{d_i}, with (d_1, d_2) = (2n, n) by default
         if hidden_dims is None:
             hidden_dims = [in_features * 2, in_features]
 
         if any(dim <= 0 for dim in hidden_dims):
             raise ValueError("hidden_dims must be positive")
 
         # Construction of network layers - config
+        # Network architecture parameters θ
         self.use_layer_norm = use_layer_norm
         self.use_residual = use_residual
         self.quantile_output = quantile_output
         self.num_forecast_steps = num_forecast_steps
 
         # Construction of hidden layers
-        # Sequence: Linear → LayerNorm → ReLU → Dropout
+        # H_i: ℝ^{d_i} → ℝ^{d_{i+1}}
+        # Sequential transformation φ(x) = Dropout(ReLU(LayerNorm(Wx + b)))
         self.layers = nn.ModuleList()
         prev_dim = in_features
 
         for dim in hidden_dims:
 
-            # Linear transformation / normalisatiom
+            # Affine transformation followed by distribution normalisation
             layer_block = []
             layer_block.append(nn.Linear(prev_dim, dim))
 
             if use_layer_norm:
                 layer_block.append(nn.LayerNorm(dim))
 
-            # Non linearity / regularisation
+            # Non-linear activation and stochastic regularisation
             layer_block.extend([
                 nn.ReLU(),
                 nn.Dropout(dropout)
@@ -73,8 +78,9 @@ def __init__(
             self.layers.append(nn.Sequential(*layer_block))
             prev_dim = dim
 
-        # Output layer definition
-        # Projection for quantile preds over timesteps or standard
+        # Output layer transformation definition
+        # f: ℝ^{d_L} → ℝ^m
+        # Projection mapping P: ℝ^d → ℝ^{m×t} for temporal quantile predictions
         if quantile_output and num_forecast_steps:
             final_out_features = out_features * num_forecast_steps
         else:
@@ -83,16 +89,23 @@ def __init__(
         self.output_layer = nn.Linear(prev_dim, final_out_features)
 
         # Output activation definition
+        # ψ: ℝ^m → [0,1]^m
         if output_activation == "softmax":
             self.output_activation = nn.Softmax(dim=-1)
         elif output_activation == "sigmoid":
             self.output_activation = nn.Sigmoid()
         else:
             self.output_activation = None
-
-        # Optional layer norm for residual connection
+
+        # Optional layer norm and residual projection
+        # g: ℝ^n → ℝ^m
         if use_residual:
-            self.residual_norm = nn.LayerNorm(out_features)
+            # Both the quantile and standard paths normalise over
+            # out_features, so a single LayerNorm suffices here
+            self.residual_norm = nn.LayerNorm(out_features)
+            self.residual_proj = nn.Linear(in_features, out_features)
 
     def reshape_quantile_output(self, x: torch.Tensor) -> torch.Tensor:
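A quick note on the shape flow, since the body of reshape_quantile_output is collapsed in this diff: with quantile output enabled, output_layer emits out_features * num_forecast_steps values per sample, which are then unflattened to (batch, steps, quantiles) — the same layout the restore reshape in forward below assumes. A small illustrative sketch (the numbers are made up; the actual method body is not shown here):

import torch

B, F, T = 8, 3, 5             # batch, out_features (quantiles), num_forecast_steps
flat = torch.randn(B, F * T)  # what output_layer produces: (8, 15)
seq = flat.reshape(B, T, F)   # per-step quantile predictions: (8, 5, 3)
assert seq.shape == (8, 5, 3)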
@@ -124,11 +137,29 @@ def forward(
 
         # Output transform, reshape and apply residual connection
         x = self.output_layer(x)
-        x = self.reshape_quantile_output(x)
-        if self.use_residual and x.shape == residual.shape:
-            x = self.residual_norm(x + residual)
+        x = self.reshape_quantile_output(x)
+
+        if self.use_residual:
+            # Project the input residual into the output feature space
+            projected_residual = self.residual_proj(residual)
+            if self.quantile_output and self.num_forecast_steps:
+                # Ensure the projected residual has shape (B, F)
+                projected_residual = projected_residual.reshape(x.shape[0], x.shape[2])
+
+                # Collapse temporal dimensions for normalisation
+                # ℝ^{B×T×F} → ℝ^{BT×F}
+                x = x.reshape(-1, x.shape[2])
+                # repeat_interleave keeps the residual rows batch-major,
+                # matching the row order produced by the reshape above
+                x = self.residual_norm(x + projected_residual.repeat_interleave(self.num_forecast_steps, dim=0))
+
+                # Restore tensor dimensionality
+                # ℝ^{BT×F} → ℝ^{B×T×F}
+                x = x.reshape(-1, self.num_forecast_steps, self.out_features)
+            else:
+                x = self.residual_norm(x + projected_residual)
 
         # Apply output activation
+        # Non-linear transformation ψ
         if self.output_activation:
             x = self.output_activation(x)
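For completeness, a minimal smoke test of the changed path. This is a sketch: only part of the constructor signature is visible in this diff, so the keyword arguments below are taken from __init__ above, and the remaining arguments (hidden_dims, dropout, output_activation, use_layer_norm) are assumed to have usable defaults:

import torch

net = DynamicOutputNetwork(
    in_features=16,
    out_features=3,        # e.g. three quantiles per step
    quantile_output=True,
    num_forecast_steps=5,
    use_residual=True,
)
x = torch.randn(8, 16)     # batch of 8 samples
y = net(x)
print(y.shape)             # expected: torch.Size([8, 5, 3])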