Commit a3867c1: Make plotting configurable
Parent commit: 693ac46

7 files changed: 59 additions, 66 deletions

pvnet/models/base_model.py (30 additions, 40 deletions)
@@ -241,6 +241,7 @@ def __init__(
         output_quantiles: Optional[list[float]] = None,
         target_key: str = "gsp",
         interval_minutes: int = 30,
+        timestep_intervals_to_plot: Optional[list[int]] = None
     ):
         """Abtstract base class for PVNet submodels.

@@ -258,6 +259,10 @@ def __init__(
         self._optimizer = optimizer
         self._target_key_name = target_key
         self._target_key = BatchKey[f"{target_key}"]
+        if timestep_intervals_to_plot is not None:
+            for interval in timestep_intervals_to_plot:
+                assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(f"timestep_intervals_to_plot must be a list of tuples or lists of length 2, but got {timestep_intervals_to_plot=}")
+        self.time_step_intervals_to_plot = timestep_intervals_to_plot

         # Model must have lr to allow tuning
         # This setting is only used when lr is tuned with callback
@@ -268,12 +273,12 @@ def __init__(
         self.output_quantiles = output_quantiles

         # Number of timestemps for 30 minutely data
-        self.history_len_30 = history_minutes // interval_minutes
-        self.forecast_len_30 = forecast_minutes // interval_minutes
+        self.history_len = history_minutes // interval_minutes
+        self.forecast_len = forecast_minutes // interval_minutes
         # self.forecast_len_15 = forecast_minutes // 15
         # self.history_len_15 = history_minutes // 15

-        self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len_30)
+        self.weighted_losses = WeightedLosses(forecast_length=self.forecast_len)

         self._accumulated_metrics = MetricAccumulator()
         self._accumulated_batches = BatchAccumulator(key_to_keep=self._target_key_name)
@@ -288,9 +293,9 @@ def use_quantile_regression(self):
     def num_output_features(self):
         """Number of ouput features he model chould predict for"""
         if self.use_quantile_regression:
-            out_features = self.forecast_len_30 * len(self.output_quantiles)
+            out_features = self.forecast_len * len(self.output_quantiles)
         else:
-            out_features = self.forecast_len_30
+            out_features = self.forecast_len
         return out_features

     def _quantiles_to_prediction(self, y_quantiles):
@@ -448,7 +453,7 @@ def training_step(self, batch, batch_idx):
         # Make all -1 values 0.0
         batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
         y_hat = self(batch)
-        y = batch[self._target_key][:, -self.forecast_len_30 :, 0]
+        y = batch[self._target_key][:, -self.forecast_len:, 0]

         losses = self._calculate_common_losses(y, y_hat)
         losses = {f"{k}/train": v for k, v in losses.items()}
@@ -467,7 +472,7 @@ def validation_step(self, batch: dict, batch_idx):
         batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
         y_hat = self(batch)
         # Sensor seems to be in batch, station, time order
-        y = batch[self._target_key][:, -self.forecast_len_30 :, 0]
+        y = batch[self._target_key][:, -self.forecast_len:, 0]

         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
@@ -526,39 +531,24 @@ def validation_step(self, batch: dict, batch_idx):
                 )
                 plt.close(fig)

-                # Plot 1:30 to 3 hours ahead
-                fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
-                    timesteps_to_plot=[6, 12],  # 1:30 to 3 hours ahead
-                )
-                self.logger.experiment.log(
-                    {
-                        f"val_forecast_samples/batch_idx_{accum_batch_num}_1.5_to_3hr": wandb.Image(
-                            fig
-                        ),
-                    }
-                )
-                plt.close(fig)
+                if self.time_step_intervals_to_plot is not None:
+                    for interval in self.time_step_intervals_to_plot:
+                        fig = plot_batch_forecasts(
+                            batch,
+                            y_hat,
+                            quantiles=self.output_quantiles,
+                            key_to_plot=self._target_key_name,
+                            timesteps_to_plot=interval,
+                        )
+                        self.logger.experiment.log(
+                            {
+                                f"val_forecast_samples/batch_idx_{accum_batch_num}_timestep_{interval}": wandb.Image(
+                                    fig
+                                ),
+                            }
+                        )
+                        plt.close(fig)

-                # Plot 15 to 39 hours ahead
-                fig = plot_batch_forecasts(
-                    batch,
-                    y_hat,
-                    quantiles=self.output_quantiles,
-                    key_to_plot=self._target_key_name,
-                    timesteps_to_plot=[60, 156],  # 15 to 39 hours ahead
-                )
-                self.logger.experiment.log(
-                    {
-                        f"val_forecast_samples/batch_idx_{accum_batch_num}_15_to_39hr": wandb.Image(
-                            fig
-                        ),
-                    }
-                )
-                plt.close(fig)
                 del self._val_y_hats
                 del self._val_batches

@@ -569,7 +559,7 @@ def test_step(self, batch, batch_idx):
         # Make all -1 values 0.0
         batch[self._target_key] = batch[self._target_key].clamp(min=0.0)
         y_hat = self(batch)
-        y = batch[self._target_key][:, -self.forecast_len_30 :, 0]
+        y = batch[self._target_key][:, -self.forecast_len:, 0]

         losses = self._calculate_common_losses(y, y_hat)
         losses.update(self._calculate_val_losses(y, y_hat))
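
The intervals argument only affects plotting, so the new shape check in BaseModel.__init__ is easy to exercise on its own. Below is a minimal standalone sketch, not part of the commit; the helper name check_intervals and the example values are illustrative. The value [[6, 12], [60, 156]] reproduces the two plot windows that were previously hard-coded (1:30 to 3 hours and 15 to 39 hours ahead at 30-minute resolution).

from typing import Optional


def check_intervals(timestep_intervals_to_plot: Optional[list]) -> None:
    # Mirrors the assert added to BaseModel.__init__ in this commit.
    if timestep_intervals_to_plot is not None:
        for interval in timestep_intervals_to_plot:
            assert type(interval) in [list, tuple] and len(interval) == 2, ValueError(
                f"timestep_intervals_to_plot must be a list of tuples or lists of "
                f"length 2, but got {timestep_intervals_to_plot=}"
            )


check_intervals([[6, 12], [60, 156]])  # passes: the two windows the old code hard-coded
check_intervals(None)                  # passes: no extra interval plots are logged
# check_intervals([6, 12])             # would fail: each interval must itself be a pair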

pvnet/models/baseline/last_value.py (2 additions, 2 deletions)
@@ -36,8 +36,8 @@ def forward(self, x: dict):

         # take the last value non forecaster value and the first in the pv yeild
         # (this is the pv site we are preditcting for)
-        y_hat = gsp_yield[:, -self.forecast_len_30 - 1, 0]
+        y_hat = gsp_yield[:, -self.forecast_len - 1, 0]

         # expand the last valid forward n predict steps
-        out = y_hat.unsqueeze(1).repeat(1, self.forecast_len_30)
+        out = y_hat.unsqueeze(1).repeat(1, self.forecast_len)
         return out
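
The rename does not change what this baseline computes: it still takes the last observed value and repeats it for every forecast step. A small standalone sketch with a toy tensor and an assumed forecast_len of 4 (neither taken from the repo):

import torch

forecast_len = 4  # assumed; in the model this is forecast_minutes // interval_minutes
gsp_yield = torch.tensor([[[0.1], [0.2], [0.3], [0.0], [0.0], [0.0], [0.0]]])  # (batch, time, 1)

y_hat = gsp_yield[:, -forecast_len - 1, 0]        # last value before the forecast period: tensor([0.3000])
out = y_hat.unsqueeze(1).repeat(1, forecast_len)  # persistence forecast, shape (batch, forecast_len)
print(out)                                        # tensor([[0.3000, 0.3000, 0.3000, 0.3000]])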

pvnet/models/baseline/single_value.py (1 addition, 1 deletion)
@@ -33,5 +33,5 @@ def __init__(
     def forward(self, x: dict):
         """Run model forward on dict batch of data"""
         # Returns a single value at all steps
-        y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len_30, 0]) + self._value
+        y_hat = torch.zeros_like(x[BatchKey.gsp][:, : self.forecast_len, 0]) + self._value
         return y_hat

pvnet/models/multimodal/deep_supervision.py (7 additions, 7 deletions)
@@ -161,7 +161,7 @@ def __init__(
         if self.include_sun:
             # the minus 12 is bit of hard coded smudge for pvnet
             self.sun_fc1 = nn.Linear(
-                in_features=2 * (self.forecast_len_30 + self.history_len_30 + 1),
+                in_features=2 * (self.forecast_len + self.history_len + 1),
                 out_features=16,
             )

@@ -170,26 +170,26 @@ def __init__(
             num_cat_features += encoder_out_features
             self.sat_output_network = output_network(
                 in_features=encoder_out_features,
-                out_features=self.forecast_len_30,
+                out_features=self.forecast_len,
                 **output_network_kwargs,
             )
         if include_nwp:
             num_cat_features += encoder_out_features
             self.nwp_output_network = output_network(
                 in_features=encoder_out_features,
-                out_features=self.forecast_len_30,
+                out_features=self.forecast_len,
                 **output_network_kwargs,
             )
         if include_gsp_yield_history:
-            num_cat_features += self.history_len_30
+            num_cat_features += self.history_len
         if embedding_dim:
             num_cat_features += embedding_dim
         if include_sun:
             num_cat_features += 16

         self.output_network = output_network(
             in_features=num_cat_features,
-            out_features=self.forecast_len_30,
+            out_features=self.forecast_len,
             **output_network_kwargs,
         )

@@ -226,7 +226,7 @@ def encode(self, x):
         # *********************** GSP Data ************************************
         # add gsp yield history
         if self.include_gsp_yield_history:
-            gsp_history = x[BatchKey.gsp][:, : self.history_len_30].float()
+            gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
             gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
             gsp_history = self.source_dropout_0d(gsp_history)
             modes["gsp"] = gsp_history
@@ -266,7 +266,7 @@ def multi_mode_forward(self, x):
     def training_step(self, batch, batch_idx):
         """Training step"""
         y_hats = self.multi_mode_forward(batch)
-        y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0]
+        y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]

         losses = self._calculate_common_losses(y, y_hats["all"])
         losses = {f"{k}/train": v for k, v in losses.items()}

pvnet/models/multimodal/multimodal.py (10 additions, 7 deletions)
@@ -63,6 +63,7 @@ def __init__(
         pv_interval_minutes: int = 5,
         sat_interval_minutes: int = 5,
         sensor_interval_minutes: int = 30,
+        timestep_intervals_to_plot: Optional[list[int]] = None,
     ):
         """Neural network which combines information from different sources.

@@ -110,6 +111,7 @@ def __init__(
             pv_interval_minutes: The interval between each sample of the PV data
             sat_interval_minutes: The interval between each sample of the satellite data
             sensor_interval_minutes: The interval between each sample of the sensor data
+            timestep_intervals_to_plot: Intervals, in timesteps, to plot in addition to the full forecast
         """

         self.include_gsp_yield_history = include_gsp_yield_history
@@ -131,6 +133,7 @@ def __init__(
             output_quantiles=output_quantiles,
             target_key=target_key,
             interval_minutes=interval_minutes,
+            timestep_intervals_to_plot=timestep_intervals_to_plot
         )

         # Number of features expected by the output_network
@@ -234,7 +237,7 @@ def __init__(
         if self.include_sun:
             # the minus 12 is bit of hard coded smudge for pvnet
             self.sun_fc1 = nn.Linear(
-                in_features=2 * (self.forecast_len_30 + self.history_len_30 + 1),
+                in_features=2 * (self.forecast_len + self.history_len + 1),
                 out_features=16,
             )

@@ -243,7 +246,7 @@ def __init__(

         if include_gsp_yield_history:
             # Update num features
-            fusion_input_features += self.history_len_30
+            fusion_input_features += self.history_len

         self.output_network = output_network(
             in_features=fusion_input_features,
@@ -286,13 +289,13 @@ def forward(self, x):
             # Target is PV, so only take the history
             # Copy batch
             x_tmp = x.copy()
-            x_tmp[BatchKey.pv] = x_tmp[BatchKey.pv][:, : self.history_len_30]
+            x_tmp[BatchKey.pv] = x_tmp[BatchKey.pv][:, : self.history_len]
             modes["pv"] = self.pv_encoder(x_tmp)

         # *********************** GSP Data ************************************
         # add gsp yield history
         if self.include_gsp_yield_history:
-            gsp_history = x[BatchKey.gsp][:, : self.history_len_30].float()
+            gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
             gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
             modes["gsp"] = gsp_history

@@ -314,7 +317,7 @@ def forward(self, x):
         else:
             # Have to be its own Batch format
             x_tmp = x.copy()
-            x_tmp[BatchKey.wind] = x_tmp[BatchKey.wind][:, : self.history_len_30]
+            x_tmp[BatchKey.wind] = x_tmp[BatchKey.wind][:, : self.history_len]
             # This needs to be a Batch as input
             modes["wind"] = self.wind_encoder(x_tmp)

@@ -324,7 +327,7 @@ def forward(self, x):
             modes["sensor"] = self.sensor_encoder(x)
         else:
             x_tmp = x.copy()
-            x_tmp[BatchKey.sensor] = x_tmp[BatchKey.sensor][:, : self.history_len_30]
+            x_tmp[BatchKey.sensor] = x_tmp[BatchKey.sensor][:, : self.history_len]
             # This needs to be a Batch as input
             modes["sensor"] = self.sensor_encoder(x_tmp)

@@ -339,6 +342,6 @@ def forward(self, x):

         if self.use_quantile_regression:
             # Shape: batch_size, seq_length * num_quantiles
-            out = out.reshape(out.shape[0], self.forecast_len_30, len(self.output_quantiles))
+            out = out.reshape(out.shape[0], self.forecast_len, len(self.output_quantiles))

         return out
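
The final hunk only renames the sequence length used when reshaping the flat network output into one series per quantile. A quick shape check with toy values (assumed, not from the repo):

import torch

batch_size, forecast_len, output_quantiles = 2, 16, [0.1, 0.5, 0.9]  # assumed example values

out = torch.randn(batch_size, forecast_len * len(output_quantiles))   # flat output of the network
out = out.reshape(out.shape[0], forecast_len, len(output_quantiles))  # (batch, seq_length, num_quantiles)
print(out.shape)                                                       # torch.Size([2, 16, 3])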

pvnet/models/multimodal/nwp_weighting.py (1 addition, 1 deletion)
@@ -63,7 +63,7 @@ def __init__(
         self.interpolate = nn.Sequential(
             nn.Linear(
                 in_features=nwp_sequence_len,
-                out_features=self.forecast_len_30,
+                out_features=self.forecast_len,
             ),
             nn.LeakyReLU(negative_slope=0.01),
         )

pvnet/models/multimodal/weather_residual.py (8 additions, 8 deletions)
@@ -176,7 +176,7 @@ def __init__(
         if self.include_sun:
             # the minus 12 is bit of hard coded smudge for pvnet
             self.sun_fc1 = nn.Linear(
-                in_features=2 * (self.forecast_len_30 + self.history_len_30 + 1),
+                in_features=2 * (self.forecast_len + self.history_len + 1),
                 out_features=16,
             )

@@ -187,31 +187,31 @@ def __init__(
         if include_nwp:
             weather_cat_features += encoder_out_features
         if version == 1:
-            weather_cat_features += self.forecast_len_30
+            weather_cat_features += self.forecast_len

         nonweather_cat_features = 0
         if include_gsp_yield_history:
-            nonweather_cat_features += self.history_len_30
+            nonweather_cat_features += self.history_len
         if embedding_dim:
             nonweather_cat_features += embedding_dim
         if include_sun:
             nonweather_cat_features += 16

         self.simple_output_network = output_network(
             in_features=nonweather_cat_features,
-            out_features=self.forecast_len_30,
+            out_features=self.forecast_len,
             **output_network_kwargs,
         )

         self.weather_residual_network = nn.Sequential(
             output_network(
                 in_features=weather_cat_features,
-                out_features=self.forecast_len_30,
+                out_features=self.forecast_len,
                 **output_network_kwargs,
             ),
             # All output network return LeakyReLU activated outputs
             # However, the residual could be positive or negative
-            nn.Linear(self.forecast_len_30, self.forecast_len_30),
+            nn.Linear(self.forecast_len, self.forecast_len),
         )

         self.source_dropout_0d = CompleteDropoutNd(0, p=source_dropout)
@@ -247,7 +247,7 @@ def encode(self, x):
         # *********************** GSP Data ************************************
         # add gsp yield history
         if self.include_gsp_yield_history:
-            gsp_history = x[BatchKey.gsp][:, : self.history_len_30].float()
+            gsp_history = x[BatchKey.gsp][:, : self.history_len].float()
             gsp_history = gsp_history.reshape(gsp_history.shape[0], -1)
             gsp_history = self.source_dropout_0d(gsp_history)
             modes["gsp"] = gsp_history
@@ -295,7 +295,7 @@ def multi_mode_forward(self, x):
     def training_step(self, batch, batch_idx):
         """Run training step"""
         y_hats = self.multi_mode_forward(batch)
-        y = batch[BatchKey.gsp][:, -self.forecast_len_30 :, 0]
+        y = batch[BatchKey.gsp][:, -self.forecast_len:, 0]

         losses = self._calculate_common_losses(y, y_hats["weather_out"])
         losses = {f"{k}/train": v for k, v in losses.items()}
