Skip to content

Commit db62743

Browse files
committed
Fixed rollout_*_batcher.py
1 parent 58c2c13 commit db62743

4 files changed

+69
-91
lines changed

applications/rollout_ens_batcher.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,8 @@
3838
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
3939
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
4040
from credit.parser import credit_main_parser, predict_data_check
41-
from credit.datasets.era5_predict_batcher import (
42-
BatchForecastLenDataLoader,
43-
Predict_Dataset_Batcher
44-
)
45-
41+
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
42+
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
4643

4744
logger = logging.getLogger(__name__)
4845
warnings.filterwarnings("ignore")
@@ -365,9 +362,6 @@ def predict(rank, world_size, conf, backend=None, p=None):
365362
y_pred = None
366363
gc.collect()
367364

368-
if distributed:
369-
torch.distributed.barrier()
370-
371365
forecast_count += batch_size
372366

373367
if distributed:

applications/rollout_ensemble.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@
3636
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
3737
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
3838
from credit.parser import credit_main_parser, predict_data_check
39-
from credit.datasets.era5_predict_batcher import (
40-
BatchForecastLenDataLoader,
41-
Predict_Dataset_Batcher,
42-
)
39+
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
40+
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
4341
from credit.ensemble.bred_vector import generate_bred_vectors
4442
from credit.ensemble.crps import calculate_crps_per_channel
4543

@@ -549,8 +547,7 @@ def predict(rank, world_size, conf, backend=None, p=None):
549547
y_pred = None
550548
gc.collect()
551549

552-
if distributed:
553-
torch.distributed.barrier()
550+
554551

555552
forecast_count += batch_size
556553

applications/rollout_metrics_batcher.py

+28-34
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,44 @@
11
# ---------- #
22
# System
3-
import os
43
import gc
5-
import sys
6-
import yaml
74
import logging
8-
import warnings
95
import multiprocessing as mp
10-
from pathlib import Path
6+
import os
7+
import sys
8+
import warnings
119
from argparse import ArgumentParser
1210
from collections import defaultdict
1311

1412
# ---------- #
1513
# Numerics
1614
from datetime import datetime, timedelta
17-
import pandas as pd
18-
import xarray as xr
15+
from pathlib import Path
16+
1917
import numpy as np
18+
import pandas as pd
2019

2120
# ---------- #
2221
import torch
23-
22+
import xarray as xr
23+
import yaml
2424
# ---------- #
2525
# credit
26-
from credit.models import load_model
27-
from credit.seed import seed_everything
2826
from credit.data import concat_and_reshape, reshape_only
2927
from credit.datasets import setup_data_loading
30-
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
31-
from credit.pbs import launch_script, launch_script_mpi
32-
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
33-
from credit.metrics import LatWeightedMetrics, LatWeightedMetricsClimatology
28+
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
29+
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
30+
from credit.distributed import distributed_model_wrapper, get_rank_info, setup
3431
from credit.forecast import load_forecasts
35-
from credit.distributed import distributed_model_wrapper, setup, get_rank_info
32+
from credit.metrics import LatWeightedMetrics, LatWeightedMetricsClimatology
33+
34+
from credit.models import load_model
3635
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
37-
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
3836
from credit.parser import credit_main_parser, predict_data_check
39-
from credit.datasets.era5_predict_batcher import (
40-
BatchForecastLenDataLoader,
41-
Predict_Dataset_Batcher,
42-
)
43-
37+
from credit.pbs import launch_script, launch_script_mpi
38+
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
39+
from credit.postblock import GlobalEnergyFixer, GlobalMassFixer, GlobalWaterFixer
40+
from credit.seed import seed_everything
41+
from credit.transforms import Normalize_ERA5_and_Forcing, load_transforms
4442

4543
logger = logging.getLogger(__name__)
4644
warnings.filterwarnings("ignore")
@@ -146,6 +144,10 @@ def predict(rank, world_size, conf, backend=None, p=None):
146144

147145
# Load the forecasts we wish to compute
148146
forecasts = load_forecasts(conf)
147+
if len(forecasts) % world_size != 0:
148+
raise ValueError(
149+
f'Number of forecast inits ({len(forecasts)}) given by conf["predict"]["duration"] x len(conf["predict"]["start_hours"]) should be divisible by number of processes/GPUs ({world_size})'
150+
)
149151

150152
dataset = Predict_Dataset_Batcher(
151153
varname_upper_air=data_config["varname_upper_air"],
@@ -229,8 +231,6 @@ def predict(rank, world_size, conf, backend=None, p=None):
229231

230232
# y_pred allocation and results tracking
231233
results = []
232-
save_datetimes = [0] * batch_size
233-
234234
# model inference loop
235235
for k, batch in enumerate(data_loader):
236236
batch_size = batch["datetime"].shape[0]
@@ -248,9 +248,6 @@ def predict(rank, world_size, conf, backend=None, p=None):
248248
)
249249
for i in range(batch_size)
250250
]
251-
save_datetimes[forecast_count : forecast_count + batch_size] = (
252-
init_datetimes
253-
)
254251
if "x_surf" in batch:
255252
x = (
256253
concat_and_reshape(batch["x"], batch["x_surf"])
@@ -353,7 +350,7 @@ def predict(rank, world_size, conf, backend=None, p=None):
353350
results.append((j, result)) # Store the batch index with the result
354351

355352
# Print to screen
356-
print_str = f"Forecast: {forecast_count + 1 + j} "
353+
print_str = f"{rank=:} Forecast: {forecast_count + 1 + j} "
357354
print_str += f"Date: {utc_datetime[j].strftime('%Y-%m-%d %H:%M:%S')} "
358355
print_str += f"Hour: {forecast_step * lead_time_periods} "
359356
print(print_str)
@@ -401,13 +398,10 @@ def predict(rank, world_size, conf, backend=None, p=None):
401398
y_pred = None
402399
gc.collect()
403400

404-
if distributed:
405-
torch.distributed.barrier()
406-
407401
forecast_count += batch_size
408402

409403
if distributed:
410-
torch.distributed.barrier()
404+
torch.distributed.destroy_process_group()
411405

412406
return 1
413407

@@ -517,9 +511,9 @@ def predict(rank, world_size, conf, backend=None, p=None):
517511
data_config = setup_data_loading(conf)
518512

519513
# create a save location for rollout
520-
assert (
521-
"save_forecast" in conf["predict"]
522-
), "Please specify the output dir through conf['predict']['save_forecast']"
514+
assert "save_forecast" in conf["predict"], (
515+
"Please specify the output dir through conf['predict']['save_forecast']"
516+
)
523517

524518
forecast_save_loc = conf["predict"]["save_forecast"]
525519
os.makedirs(forecast_save_loc, exist_ok=True)

applications/rollout_to_netcdf_batcher.py

+36-43
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,43 @@
1-
import os
21
import gc
3-
import sys
4-
import yaml
52
import logging
3+
import multiprocessing as mp
4+
import os
5+
import sys
66
import warnings
7-
from pathlib import Path
87
from argparse import ArgumentParser
9-
import multiprocessing as mp
108

119
# ---------- #
1210
# Numerics
1311
from datetime import datetime, timedelta
14-
import xarray as xr
12+
from pathlib import Path
13+
1514
import numpy as np
1615

1716
# ---------- #
1817
import torch
19-
18+
import xarray as xr
19+
import yaml
2020
# ---------- #
2121
# credit
22-
from credit.models import load_model
23-
from credit.seed import seed_everything
24-
from credit.distributed import get_rank_info
25-
from credit.datasets import setup_data_loading
26-
from credit.datasets.era5_predict_batcher import (
27-
BatchForecastLenDataLoader,
28-
Predict_Dataset_Batcher,
29-
)
30-
3122
from credit.data import (
3223
concat_and_reshape,
3324
reshape_only,
3425
)
35-
36-
from credit.transforms import load_transforms, Normalize_ERA5_and_Forcing
37-
from credit.pbs import launch_script, launch_script_mpi
38-
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
26+
from credit.datasets import setup_data_loading
27+
from credit.datasets.era5_multistep_batcher import Predict_Dataset_Batcher
28+
from credit.datasets.load_dataset_and_dataloader import BatchForecastLenDataLoader
29+
from credit.distributed import distributed_model_wrapper, get_rank_info, setup
3930
from credit.forecast import load_forecasts
40-
from credit.distributed import distributed_model_wrapper, setup
31+
32+
from credit.models import load_model
4133
from credit.models.checkpoint import load_model_state, load_state_dict_error_handler
42-
from credit.parser import credit_main_parser, predict_data_check
4334
from credit.output import load_metadata, make_xarray, save_netcdf_increment
44-
from credit.postblock import GlobalMassFixer, GlobalWaterFixer, GlobalEnergyFixer
35+
from credit.parser import credit_main_parser, predict_data_check
36+
from credit.pbs import launch_script, launch_script_mpi
37+
from credit.pol_lapdiff_filt import Diffusion_and_Pole_Filter
38+
from credit.postblock import GlobalEnergyFixer, GlobalMassFixer, GlobalWaterFixer
39+
from credit.seed import seed_everything
40+
from credit.transforms import Normalize_ERA5_and_Forcing, load_transforms
4541

4642
logger = logging.getLogger(__name__)
4743
warnings.filterwarnings("ignore")
@@ -134,9 +130,9 @@ def predict(rank, world_size, conf, p):
134130

135131
# Load the forecasts we wish to compute
136132
forecasts = load_forecasts(conf)
137-
if len(forecasts) < batch_size:
138-
logger.warning(
139-
f"number of forecast init times {len(forecasts)} is less than batch_size {batch_size}, will result in under-utilization"
133+
if len(forecasts) % world_size != 0:
134+
raise ValueError(
135+
f'Number of forecast inits ({len(forecasts)}) given by conf["predict"]["duration"] x len(conf["predict"]["start_hours"]) should be divisible by number of processes/GPUs ({world_size})'
140136
)
141137

142138
dataset = Predict_Dataset_Batcher(
@@ -214,7 +210,8 @@ def predict(rank, world_size, conf, p):
214210

215211
# y_pred allocation and results tracking
216212
results = []
217-
save_datetimes = [0] * len(forecasts)
213+
# save_datetimes = [0] * len(forecasts)
214+
init_datetimes = []
218215

219216
# model inference loop
220217
for batch in data_loader:
@@ -230,9 +227,8 @@ def predict(rank, world_size, conf, p):
230227
)
231228
for i in range(batch_size)
232229
]
233-
save_datetimes[forecast_count : forecast_count + batch_size] = (
234-
init_datetimes
235-
)
230+
# save_datetimes[forecast_count:forecast_count + batch_size] = init_datetimes
231+
# save_datetimes
236232

237233
if "x_surf" in batch:
238234
x = (
@@ -340,17 +336,15 @@ def predict(rank, world_size, conf, p):
340336
(
341337
all_upper_air,
342338
all_single_level,
343-
save_datetimes[
344-
forecast_count + j
345-
], # Use correct index for current batch item
339+
init_datetimes[j],
346340
lead_time_periods * forecast_step,
347341
meta_data,
348342
conf,
349343
),
350344
)
351345
results.append(result)
352346

353-
print_str = f"Forecast: {forecast_count + 1 + j} "
347+
print_str = f"{rank=:} Forecast: {forecast_count + 1 + j} "
354348
print_str += f"Date: {utc_datetimes[j].strftime('%Y-%m-%d %H:%M:%S')} "
355349
print_str += f"Hour: {forecast_step * lead_time_periods} "
356350
print(print_str)
@@ -360,14 +354,16 @@ def predict(rank, world_size, conf, p):
360354

361355
# y_diag is not drawn in predict batcher, if diag is specified in config, it will not be in the input to the model
362356
if history_len == 1:
363-
x = y_pred[:, :-varnum_diag, ...].detach()
357+
# x = y_pred[:, :-varnum_diag, ...].detach()
358+
x = y_pred.detach()
364359
else:
365360
if static_dim_size == 0:
366361
x_detach = x[:, :, 1:, ...].detach()
367362
else:
368363
x_detach = x[:, :-static_dim_size, 1:, ...].detach()
369364

370-
x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
365+
# x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
366+
x = torch.cat([x_detach, y_pred.detach()], dim=2)
371367

372368
# Memory cleanup
373369
torch.cuda.empty_cache()
@@ -381,13 +377,10 @@ def predict(rank, world_size, conf, p):
381377
y_pred = None
382378
gc.collect()
383379

384-
if distributed:
385-
torch.distributed.barrier()
386-
387380
forecast_count += batch_size
388381

389382
if distributed:
390-
torch.distributed.barrier()
383+
torch.distributed.destroy_process_group()
391384

392385
return 1
393386

@@ -491,9 +484,9 @@ def predict(rank, world_size, conf, p):
491484
predict_data_check(conf, print_summary=False)
492485

493486
# create a save location for rollout
494-
assert (
495-
"save_forecast" in conf["predict"]
496-
), "Please specify the output dir through conf['predict']['save_forecast']"
487+
assert "save_forecast" in conf["predict"], (
488+
"Please specify the output dir through conf['predict']['save_forecast']"
489+
)
497490

498491
forecast_save_loc = conf["predict"]["save_forecast"]
499492
os.makedirs(forecast_save_loc, exist_ok=True)

0 commit comments

Comments
 (0)