1
- import os
2
1
import gc
3
- import sys
4
- import yaml
5
2
import logging
3
+ import multiprocessing as mp
4
+ import os
5
+ import sys
6
6
import warnings
7
- from pathlib import Path
8
7
from argparse import ArgumentParser
9
- import multiprocessing as mp
10
8
11
9
# ---------- #
12
10
# Numerics
13
11
from datetime import datetime , timedelta
14
- import xarray as xr
12
+ from pathlib import Path
13
+
15
14
import numpy as np
16
15
17
16
# ---------- #
18
17
import torch
19
-
18
+ import xarray as xr
19
+ import yaml
20
20
# ---------- #
21
21
# credit
22
- from credit .models import load_model
23
- from credit .seed import seed_everything
24
- from credit .distributed import get_rank_info
25
- from credit .datasets import setup_data_loading
26
- from credit .datasets .era5_predict_batcher import (
27
- BatchForecastLenDataLoader ,
28
- Predict_Dataset_Batcher ,
29
- )
30
-
31
22
from credit .data import (
32
23
concat_and_reshape ,
33
24
reshape_only ,
34
25
)
35
-
36
- from credit .transforms import load_transforms , Normalize_ERA5_and_Forcing
37
- from credit .pbs import launch_script , launch_script_mpi
38
- from credit .pol_lapdiff_filt import Diffusion_and_Pole_Filter
26
+ from credit . datasets import setup_data_loading
27
+ from credit .datasets . era5_multistep_batcher import Predict_Dataset_Batcher
28
+ from credit .datasets . load_dataset_and_dataloader import BatchForecastLenDataLoader
29
+ from credit .distributed import distributed_model_wrapper , get_rank_info , setup
39
30
from credit .forecast import load_forecasts
40
- from credit .distributed import distributed_model_wrapper , setup
31
+
32
+ from credit .models import load_model
41
33
from credit .models .checkpoint import load_model_state , load_state_dict_error_handler
42
- from credit .parser import credit_main_parser , predict_data_check
43
34
from credit .output import load_metadata , make_xarray , save_netcdf_increment
44
- from credit .postblock import GlobalMassFixer , GlobalWaterFixer , GlobalEnergyFixer
35
+ from credit .parser import credit_main_parser , predict_data_check
36
+ from credit .pbs import launch_script , launch_script_mpi
37
+ from credit .pol_lapdiff_filt import Diffusion_and_Pole_Filter
38
+ from credit .postblock import GlobalEnergyFixer , GlobalMassFixer , GlobalWaterFixer
39
+ from credit .seed import seed_everything
40
+ from credit .transforms import Normalize_ERA5_and_Forcing , load_transforms
45
41
46
42
logger = logging .getLogger (__name__ )
47
43
warnings .filterwarnings ("ignore" )
@@ -134,9 +130,9 @@ def predict(rank, world_size, conf, p):
134
130
135
131
# Load the forecasts we wish to compute
136
132
forecasts = load_forecasts (conf )
137
- if len (forecasts ) < batch_size :
138
- logger . warning (
139
- f"number of forecast init times { len (forecasts )} is less than batch_size { batch_size } , will result in under-utilization"
133
+ if len (forecasts ) % world_size != 0 :
134
+ raise ValueError (
135
+ f'Number of forecast inits ( { len (forecasts )} ) given by conf["predict"]["duration"] x len(conf["predict"]["start_hours"]) should be divisible by number of processes/GPUs ( { world_size } )'
140
136
)
141
137
142
138
dataset = Predict_Dataset_Batcher (
@@ -214,7 +210,8 @@ def predict(rank, world_size, conf, p):
214
210
215
211
# y_pred allocation and results tracking
216
212
results = []
217
- save_datetimes = [0 ] * len (forecasts )
213
+ # save_datetimes = [0] * len(forecasts)
214
+ init_datetimes = []
218
215
219
216
# model inference loop
220
217
for batch in data_loader :
@@ -230,9 +227,8 @@ def predict(rank, world_size, conf, p):
230
227
)
231
228
for i in range (batch_size )
232
229
]
233
- save_datetimes [forecast_count : forecast_count + batch_size ] = (
234
- init_datetimes
235
- )
230
+ # save_datetimes[forecast_count:forecast_count + batch_size] = init_datetimes
231
+ # save_datetimes
236
232
237
233
if "x_surf" in batch :
238
234
x = (
@@ -340,17 +336,15 @@ def predict(rank, world_size, conf, p):
340
336
(
341
337
all_upper_air ,
342
338
all_single_level ,
343
- save_datetimes [
344
- forecast_count + j
345
- ], # Use correct index for current batch item
339
+ init_datetimes [j ],
346
340
lead_time_periods * forecast_step ,
347
341
meta_data ,
348
342
conf ,
349
343
),
350
344
)
351
345
results .append (result )
352
346
353
- print_str = f"Forecast: { forecast_count + 1 + j } "
347
+ print_str = f"{ rank = : } Forecast: { forecast_count + 1 + j } "
354
348
print_str += f"Date: { utc_datetimes [j ].strftime ('%Y-%m-%d %H:%M:%S' )} "
355
349
print_str += f"Hour: { forecast_step * lead_time_periods } "
356
350
print (print_str )
@@ -360,14 +354,16 @@ def predict(rank, world_size, conf, p):
360
354
361
355
# y_diag is not drawn in predict batcher, if diag is specified in config, it will not be in the input to the model
362
356
if history_len == 1 :
363
- x = y_pred [:, :- varnum_diag , ...].detach ()
357
+ # x = y_pred[:, :-varnum_diag, ...].detach()
358
+ x = y_pred .detach ()
364
359
else :
365
360
if static_dim_size == 0 :
366
361
x_detach = x [:, :, 1 :, ...].detach ()
367
362
else :
368
363
x_detach = x [:, :- static_dim_size , 1 :, ...].detach ()
369
364
370
- x = torch .cat ([x_detach , y_pred [:, :- varnum_diag , ...].detach ()], dim = 2 )
365
+ # x = torch.cat([x_detach, y_pred[:, :-varnum_diag, ...].detach()], dim=2)
366
+ x = torch .cat ([x_detach , y_pred .detach ()], dim = 2 )
371
367
372
368
# Memory cleanup
373
369
torch .cuda .empty_cache ()
@@ -381,13 +377,10 @@ def predict(rank, world_size, conf, p):
381
377
y_pred = None
382
378
gc .collect ()
383
379
384
- if distributed :
385
- torch .distributed .barrier ()
386
-
387
380
forecast_count += batch_size
388
381
389
382
if distributed :
390
- torch .distributed .barrier ()
383
+ torch .distributed .destroy_process_group ()
391
384
392
385
return 1
393
386
@@ -491,9 +484,9 @@ def predict(rank, world_size, conf, p):
491
484
predict_data_check (conf , print_summary = False )
492
485
493
486
# create a save location for rollout
494
- assert (
495
- "save_forecast" in conf [" predict" ]
496
- ), "Please specify the output dir through conf['predict']['save_forecast']"
487
+ assert "save_forecast" in conf [ "predict" ], (
488
+ "Please specify the output dir through conf[' predict']['save_forecast']"
489
+ )
497
490
498
491
forecast_save_loc = conf ["predict" ]["save_forecast" ]
499
492
os .makedirs (forecast_save_loc , exist_ok = True )
0 commit comments