import torch
import torch.nn.functional as F
+ import logging
from torch import nn
from abc import ABC, abstractmethod
from typing import Optional, List, Dict, Union

from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork


+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger('output_networks')
+
+
class DynamicOutputNetwork(AbstractLinearNetwork):
    """Dynamic output network definition"""

@@ -37,14 +42,20 @@ def __init__(
    ):
        # Initialisation of dynamic output network
        super().__init__(in_features=in_features, out_features=out_features)
+        logger.info(f"Initialising DynamicOutputNetwork with in_features={in_features}, out_features={out_features}")
+        logger.debug(f"Configuration - dropout: {dropout}, layer_norm: {use_layer_norm}, residual: {use_residual}")
+
        self.out_features = out_features

        # Default hidden architecture
        # h_i ∈ ℝ^{d_i}, where d_i = [2n, n]
        if hidden_dims is None:
            hidden_dims = [in_features * 2, in_features]
-
+            logger.debug(f"Using default hidden dimensions: {hidden_dims}")
+
        if any(dim <= 0 for dim in hidden_dims):
+            error_msg = f"Invalid hidden dimensions: {hidden_dims}"
+            logger.error(error_msg)
            raise ValueError("hidden_dims must be positive")

        # Construction of network layers - config
@@ -57,19 +68,19 @@ def __init__(
        # Construction of hidden layers
        # H_i: ℝ^{d_i} → ℝ^{d_{i+1}}
        # Sequential transformation φ(x) = Dropout(ReLU(LayerNorm(Wx + b)))
+        logger.debug("Constructing network layers")
        self.layers = nn.ModuleList()
        prev_dim = in_features

-        for dim in hidden_dims:
-
-            # Affine transformation followed by distribution normalisation
+        for i, dim in enumerate(hidden_dims):
+            logger.debug(f"Building layer {i + 1}: {prev_dim} → {dim}")
            layer_block = []
            layer_block.append(nn.Linear(prev_dim, dim))

            if use_layer_norm:
+                logger.debug(f"Adding LayerNorm for dimension {dim}")
                layer_block.append(nn.LayerNorm(dim))

-            # Non-linear activation and stochastic regularisation
            layer_block.extend([
                nn.ReLU(),
                nn.Dropout(dropout)
@@ -83,14 +94,17 @@ def __init__(
        # Projection mapping P: ℝ^d → ℝ^{m×t} for temporal quantile predictions
        if quantile_output and num_forecast_steps:
            final_out_features = out_features * num_forecast_steps
+            logger.debug(f"Configuring for quantile output with {num_forecast_steps} steps")
        else:
            final_out_features = out_features
-
+
+        logger.debug(f"Creating output layer: {prev_dim} → {final_out_features}")
        self.output_layer = nn.Linear(prev_dim, final_out_features)

        # Output activation definition
        # ψ: ℝ^m → [0,1]^m
        if output_activation == "softmax":
+            logger.debug(f"Setting output activation: {output_activation}")
            self.output_activation = nn.Softmax(dim=-1)
        elif output_activation == "sigmoid":
            self.output_activation = nn.Sigmoid()
@@ -100,6 +114,7 @@ def __init__(
        # Optional layer norm and residual projection
        # g: ℝ^n → ℝ^m
        if use_residual:
+            logger.debug("Initialising residual connection components")
            if quantile_output and num_forecast_steps:
                self.residual_norm = nn.LayerNorm(out_features)
            else:
@@ -108,10 +123,11 @@ def __init__(
            self.residual_proj = nn.Linear(in_features, out_features)

    def reshape_quantile_output(self, x: torch.Tensor) -> torch.Tensor:
-
-        # Reshape output for quantile predictions
+        logger.debug(f"Input shape before reshape: {x.shape}")
        if self.quantile_output and self.num_forecast_steps:
-            return x.reshape(x.shape[0], self.num_forecast_steps, -1)
+            reshaped = x.reshape(x.shape[0], self.num_forecast_steps, -1)
+            logger.debug(f"Reshaped output shape: {reshaped.shape}")
+            return reshaped
        return x

    def forward(
@@ -120,47 +136,59 @@ def forward(
        return_intermediates: bool = False
    ) -> Union[torch.Tensor, tuple]:

+        logger.info("Starting DynamicOutputNetwork forward pass")
+
        # Forward pass for dynamic output network
        # Handle dict input
        # Concatenate multimodal inputs if dict provided
        if isinstance(x, dict):
+            logger.debug(f"Processing dictionary input with keys: {list(x.keys())}")
            x = torch.cat(list(x.values()), dim=-1)
+            logger.debug(f"Concatenated input shape: {x.shape}")

        intermediates = []
        residual = x

        # Process through hidden layers
-        for layer in self.layers:
+        for i, layer in enumerate(self.layers):
+            logger.debug(f"Processing layer {i + 1}, input shape: {x.shape}")
            x = layer(x)
            if return_intermediates:
                intermediates.append(x)

        # Output transform, reshape and apply residual connection
+        logger.debug(f"Applying output layer to shape: {x.shape}")
        x = self.output_layer(x)
        x = self.reshape_quantile_output(x)

        if self.use_residual:
+            logger.debug("Applying residual connection")
+
            # Apply residual projection transformation
            projected_residual = self.residual_proj(residual)
            if self.quantile_output and self.num_forecast_steps:
+                logger.debug("Processing quantile output with residual")

                # Apply residual mapping followed by normalisation
                projected_residual = projected_residual.reshape(x.shape[0], x.shape[2])

                # Collapse temporal dimensions for normalisation
                # ℝ^{B×T×F} → ℝ^{BT×F}
                x = x.reshape(-1, x.shape[2])
+                logger.debug(f"Reshaped for residual: {x.shape}")
                x = self.residual_norm(x + projected_residual.repeat(self.num_forecast_steps, 1))

                # Restore tensor dimensionality
                # ℝ^{BT×F} → ℝ^{B×T×F}
                x = x.reshape(-1, self.num_forecast_steps, self.out_features)
+                logger.debug(f"Final shape after residual: {x.shape}")
            else:
                x = self.residual_norm(x + projected_residual)

        # Apply output activation
        # Non-linear transformation ψ
        if self.output_activation:
+            logger.debug(f"Applying output activation: {type(self.output_activation).__name__}")
            x = self.output_activation(x)

        if return_intermediates:
@@ -180,7 +208,10 @@ def __init__(
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1
    ):
-
+
+        logger.info(f"Initialising QuantileOutputNetwork with in_features={in_features}, num_quantiles={num_quantiles}")
+        logger.debug(f"Forecast steps: {num_forecast_steps}, hidden_dims: {hidden_dims}")
+
        # Initialisation of quantile output network
        super().__init__(
            in_features=in_features,
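
A side note on the reshape_quantile_output change above (not part of the diff): the output layer emits a flat vector of out_features * num_forecast_steps values per sample, and the reshape recovers a (batch, steps, per-step outputs) layout. A minimal standalone sketch of that shape arithmetic, using purely illustrative sizes:

import torch

batch, num_forecast_steps, out_features = 4, 16, 3  # illustrative values only

# Flat projection as produced by the final linear layer: (batch, out_features * num_forecast_steps)
flat = torch.randn(batch, out_features * num_forecast_steps)

# Same reshape as in reshape_quantile_output: (batch, num_forecast_steps, out_features)
reshaped = flat.reshape(flat.shape[0], num_forecast_steps, -1)
assert reshaped.shape == (batch, num_forecast_steps, out_features)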
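For context, a hypothetical end-to-end usage sketch of DynamicOutputNetwork as modified in this commit. It assumes pvnet is installed, that the module path below is correct (the diff only shows the basic_blocks import, so the path is a guess inferred from the logger name), and that the elided constructor signature accepts the keyword arguments referenced in the __init__ body; all sizes are illustrative:

import logging
import torch

# Assumed import path for the modified module
from pvnet.models.multimodal.linear_networks.output_networks import DynamicOutputNetwork

# The diff calls logging.basicConfig(level=logging.DEBUG) at import time; raising the
# module logger's level here is one way to quieten the per-layer debug output.
logging.getLogger("output_networks").setLevel(logging.INFO)

net = DynamicOutputNetwork(
    in_features=128,
    out_features=3,                # e.g. three quantiles per forecast step
    hidden_dims=[256, 128],
    dropout=0.1,
    use_residual=True,
    quantile_output=True,
    num_forecast_steps=16,
    output_activation="softmax",
)

x = torch.randn(4, 128)            # (batch, in_features)
y = net(x)                         # expected shape: (4, 16, 3) -> (batch, steps, quantiles)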