
Commit

Output networks updated
felix-e-h-p committed Jan 18, 2025
1 parent 2677391 commit a304738
Showing 1 changed file with 43 additions and 12 deletions.
55 changes: 43 additions & 12 deletions pvnet/models/multimodal/linear_networks/output_networks.py
@@ -21,7 +21,8 @@
class DynamicOutputNetwork(AbstractLinearNetwork):
""" Dynamic output network definition """

# Input and output dimensions specified here
# Defines feature mapping ℝ^n → ℝ^m
def __init__(
self,
in_features: int,
@@ -36,35 +37,39 @@ def __init__(
):
# Initialisation of dynamic output network
super().__init__(in_features=in_features, out_features=out_features)
self.out_features = out_features

# Default hidden architecture
# h_i ∈ ℝ^{d_i}, where d_i = [2n, n]
if hidden_dims is None:
hidden_dims = [in_features * 2, in_features]
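# e.g. in_features=64 gives hidden_dims=[128, 64]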

if any(dim <= 0 for dim in hidden_dims):
raise ValueError("hidden_dims must be positive")

# Construction of network layers - config
# Network architecture parameters θ
self.use_layer_norm = use_layer_norm
self.use_residual = use_residual
self.quantile_output = quantile_output
self.num_forecast_steps = num_forecast_steps

# Construction of hidden layers
# H_i: ℝ^{d_i} → ℝ^{d_{i+1}}
# Sequential transformation φ(x) = Dropout(ReLU(LayerNorm(Wx + b)))
self.layers = nn.ModuleList()
prev_dim = in_features

for dim in hidden_dims:

# Affine transformation followed by distribution normalisation
layer_block = []
layer_block.append(nn.Linear(prev_dim, dim))

if use_layer_norm:
layer_block.append(nn.LayerNorm(dim))

# Non-linear activation and stochastic regularisation
layer_block.extend([
nn.ReLU(),
nn.Dropout(dropout)
@@ -73,8 +78,9 @@ def __init__(
self.layers.append(nn.Sequential(*layer_block))
prev_dim = dim

# Output layer transformation definition
# f: ℝ^{d_L} → ℝ^m
# Projection mapping P: ℝ^d → ℝ^{m×t} for temporal quantile predictions
if quantile_output and num_forecast_steps:
final_out_features = out_features * num_forecast_steps
else:
@@ -83,16 +89,23 @@
self.output_layer = nn.Linear(prev_dim, final_out_features)
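# e.g. out_features=7 with num_forecast_steps=16 gives a 112-unit output layer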

# Output activation definition
# ψ: ℝ^m → [0,1]^m
if output_activation == "softmax":
self.output_activation = nn.Softmax(dim=-1)
elif output_activation == "sigmoid":
self.output_activation = nn.Sigmoid()
else:
self.output_activation = None

# Optional layer norm and residual projection
# g: ℝ^n → ℝ^m
if use_residual:
# Feature-dimension layer normalisation for the residual path
# Normalisation is over out_features in both output modes
self.residual_norm = nn.LayerNorm(out_features)
self.residual_proj = nn.Linear(in_features, out_features)

def reshape_quantile_output(self, x: torch.Tensor) -> torch.Tensor:
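# Reshapes (B, out_features * T) → (B, T, out_features) for quantile predictions (inferred from usage in forward)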

@@ -124,11 +137,29 @@ def forward(

# Output transform, reshape and apply residual connection
x = self.output_layer(x)
x = self.reshape_quantile_output(x)

if self.use_residual:
# Apply residual projection transformation
projected_residual = self.residual_proj(residual)
if self.quantile_output and self.num_forecast_steps:

# Align projected residual with the batch and feature dims of x
projected_residual = projected_residual.reshape(x.shape[0], x.shape[2])

# Collapse temporal dimensions for normalisation
# ℝ^{B×T×F} → ℝ^{BT×F}
x = x.reshape(-1, x.shape[2])
# repeat_interleave keeps each sample's residual aligned with its own forecast steps
x = self.residual_norm(x + projected_residual.repeat_interleave(self.num_forecast_steps, dim=0))

# Restore tensor dimensionality
# ℝ^{BT×F} → ℝ^{B×T×F}
x = x.reshape(-1, self.num_forecast_steps, self.out_features)
else:
x = self.residual_norm(x + projected_residual)

# Apply output activation
# Non-linear transformation ψ
if self.output_activation:
x = self.output_activation(x)

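For context, a minimal usage sketch of the updated class follows; the feature sizes, the quantile interpretation of out_features, and the single-tensor forward signature are assumptions for illustration, not taken from the diff.

# Minimal usage sketch (hypothetical sizes; forward(x) signature assumed)
import torch

from pvnet.models.multimodal.linear_networks.output_networks import DynamicOutputNetwork

net = DynamicOutputNetwork(
    in_features=128,        # assumed encoder feature size
    out_features=7,         # e.g. 7 quantiles (assumption)
    quantile_output=True,
    num_forecast_steps=16,
)

x = torch.randn(32, 128)    # (batch, in_features)
y = net(x)                  # expected shape (32, 16, 7) after quantile reshaping
print(y.shape)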
