Commit 8cd6165

Output networks updated

1 parent 133f2da

1 file changed (+42, -11 lines):

pvnet/models/multimodal/linear_networks/output_networks.py
@@ -11,13 +11,18 @@
 
 import torch
 import torch.nn.functional as F
+import logging
 from torch import nn
 from abc import ABC, abstractmethod
 from typing import Optional, List, Dict, Union
 
 from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
 
 
+logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logger = logging.getLogger('output_networks')
+
+
 class DynamicOutputNetwork(AbstractLinearNetwork):
     """ Dynamic output network definition """
 
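Note (not part of the commit): the hunk above calls logging.basicConfig(level=logging.DEBUG) at import time, so importing this module configures the root logger to DEBUG if nothing has configured it yet. A minimal sketch, using only the standard library logging API, of how downstream code could quieten just this module's logger:

import logging

# 'output_networks' is the logger name used in the commit above;
# raising its level suppresses the per-layer DEBUG messages.
logging.getLogger("output_networks").setLevel(logging.INFO)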
@@ -37,14 +42,20 @@ def __init__(
     ):
         # Initialisation of dynamic output network
         super().__init__(in_features=in_features, out_features=out_features)
+        logger.info(f"Initialising DynamicOutputNetwork with in_features={in_features}, out_features={out_features}")
+        logger.debug(f"Configuration - dropout: {dropout}, layer_norm: {use_layer_norm}, residual: {use_residual}")
+
         self.out_features = out_features
 
         # Default hidden architecture
         # h_i ∈ ℝ^{d_i}, where d_i = [2n, n]
         if hidden_dims is None:
             hidden_dims = [in_features * 2, in_features]
-
+            logger.debug(f"Using default hidden dimensions: {hidden_dims}")
+
         if any(dim <= 0 for dim in hidden_dims):
+            error_msg = f"Invalid hidden dimensions: {hidden_dims}"
+            logger.error(error_msg)
             raise ValueError("hidden_dims must be positive")
 
         # Construction of network layers - config
@@ -57,19 +68,19 @@ def __init__(
         # Construction of hidden layers
         # H_i: ℝ^{d_i} → ℝ^{d_{i+1}}
         # Sequential transformation φ(x) = Dropout(ReLU(LayerNorm(Wx + b)))
+        logger.debug("Constructing network layers")
         self.layers = nn.ModuleList()
         prev_dim = in_features
 
-        for dim in hidden_dims:
-
-            # Affine transformation followed by distribution normalisation
+        for i, dim in enumerate(hidden_dims):
+            logger.debug(f"Building layer {i+1}: {prev_dim} → {dim}")
             layer_block = []
             layer_block.append(nn.Linear(prev_dim, dim))
 
             if use_layer_norm:
+                logger.debug(f"Adding LayerNorm for dimension {dim}")
                 layer_block.append(nn.LayerNorm(dim))
 
-            # Non-linear activation and stochastic regularisation
             layer_block.extend([
                 nn.ReLU(),
                 nn.Dropout(dropout)
@@ -83,14 +94,17 @@ def __init__(
         # Projection mapping P: ℝ^d → ℝ^{m×t} for temporal quantile predictions
         if quantile_output and num_forecast_steps:
             final_out_features = out_features * num_forecast_steps
+            logger.debug(f"Configuring for quantile output with {num_forecast_steps} steps")
         else:
             final_out_features = out_features
-
+
+        logger.debug(f"Creating output layer: {prev_dim} → {final_out_features}")
         self.output_layer = nn.Linear(prev_dim, final_out_features)
 
         # Output activation definition
         # ψ: ℝ^m → [0,1]^m
         if output_activation == "softmax":
+            logger.debug(f"Setting output activation: {output_activation}")
             self.output_activation = nn.Softmax(dim=-1)
         elif output_activation == "sigmoid":
             self.output_activation = nn.Sigmoid()
@@ -100,6 +114,7 @@ def __init__(
         # Optional layer norm and residual projection
         # g: ℝ^n → ℝ^m
         if use_residual:
+            logger.debug("Initialising residual connection components")
             if quantile_output and num_forecast_steps:
                 self.residual_norm = nn.LayerNorm(out_features)
             else:
@@ -108,10 +123,11 @@ def __init__(
             self.residual_proj = nn.Linear(in_features, out_features)
 
     def reshape_quantile_output(self, x: torch.Tensor) -> torch.Tensor:
-
-        # Reshape output for quantile predictions
+        logger.debug(f"Input shape before reshape: {x.shape}")
         if self.quantile_output and self.num_forecast_steps:
-            return x.reshape(x.shape[0], self.num_forecast_steps, -1)
+            reshaped = x.reshape(x.shape[0], self.num_forecast_steps, -1)
+            logger.debug(f"Reshaped output shape: {reshaped.shape}")
+            return reshaped
         return x
 
     def forward(
@@ -120,47 +136,59 @@ def forward(
         return_intermediates: bool = False
     ) -> Union[torch.Tensor, tuple]:
 
+        logger.info("Starting DynamicOutputNetwork forward pass")
+
        # Forward pass for dynamic output network
        # Handle dict input
        # Concatenate multimodal inputs if dict provided
         if isinstance(x, dict):
+            logger.debug(f"Processing dictionary input with keys: {list(x.keys())}")
             x = torch.cat(list(x.values()), dim=-1)
+            logger.debug(f"Concatenated input shape: {x.shape}")
 
         intermediates = []
         residual = x
 
         # Process through hidden layers
-        for layer in self.layers:
+        for i, layer in enumerate(self.layers):
+            logger.debug(f"Processing layer {i+1}, input shape: {x.shape}")
             x = layer(x)
             if return_intermediates:
                 intermediates.append(x)
 
         # Output transform, reshape and apply residual connection
+        logger.debug(f"Applying output layer to shape: {x.shape}")
         x = self.output_layer(x)
         x = self.reshape_quantile_output(x)
 
         if self.use_residual:
+            logger.debug("Applying residual connection")
+
             # Apply residual projection transformation
             projected_residual = self.residual_proj(residual)
             if self.quantile_output and self.num_forecast_steps:
+                logger.debug("Processing quantile output with residual")
 
                 # Apply residual mapping followed by normalisation
                 projected_residual = projected_residual.reshape(x.shape[0], x.shape[2])
 
                 # Collapse temporal dimensions for normalisation
                 # ℝ^{B×T×F} → ℝ^{BT×F}
                 x = x.reshape(-1, x.shape[2])
+                logger.debug(f"Reshaped for residual: {x.shape}")
                 x = self.residual_norm(x + projected_residual.repeat(self.num_forecast_steps, 1))
 
                 # Restore tensor dimensionality
                 # ℝ^{BT×F} → ℝ^{B×T×F}
                 x = x.reshape(-1, self.num_forecast_steps, self.out_features)
+                logger.debug(f"Final shape after residual: {x.shape}")
             else:
                 x = self.residual_norm(x + projected_residual)
 
         # Apply output activation
         # Non-linear transformation ψ
         if self.output_activation:
+            logger.debug(f"Applying output activation: {type(self.output_activation).__name__}")
             x = self.output_activation(x)
 
         if return_intermediates:
@@ -180,7 +208,10 @@ def __init__(
         hidden_dims: Optional[List[int]] = None,
         dropout: float = 0.1
     ):
-
+
+        logger.info(f"Initialising QuantileOutputNetwork with in_features={in_features}, num_quantiles={num_quantiles}")
+        logger.debug(f"Forecast steps: {num_forecast_steps}, hidden_dims: {hidden_dims}")
+
         # Initialisation of quantile output network
         super().__init__(
             in_features=in_features,
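For orientation, a minimal usage sketch showing where the log messages added in this commit fire during a forward pass. This is not from the commit: the keyword arguments below are inferred from the parameters referenced in the diff, and the full constructor signature and its defaults are not shown in these hunks.

import torch
from pvnet.models.multimodal.linear_networks.output_networks import DynamicOutputNetwork

# Hypothetical instantiation; argument names mirror those used in the diff,
# the exact signature may differ.
net = DynamicOutputNetwork(
    in_features=64,
    out_features=3,
    quantile_output=True,
    num_forecast_steps=16,
)

x = torch.randn(8, 64)   # batch of 8 feature vectors
y = net(x)               # the new logger.info/logger.debug calls trace each stage
print(y.shape)           # expected torch.Size([8, 16, 3]) when quantile output is enabled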
